In [None]:
#| default_exp models.lora

In [None]:
# Reload edited modules (e.g. xcai) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import torch, numpy as np, os, pickle
from typing import Optional
import torch.nn as nn
from dataclasses import dataclass

from xcai.core import store_attr
from xcai.losses import MultiTriplet

from xcai.models.modeling_utils import XCModelOutput, Parameters

from transformers import DistilBertPreTrainedModel,DistilBertConfig
from transformers.utils.generic import ModelOutput

from peft import (
    LoraConfig, 
    get_peft_model, 
    TaskType,
    PeftModel
)

In [None]:
# Regenerate the library module from the `#| export` cells (this notebook targets `models.lora`).
# NOTE(review): with no arguments `nbdev_export()` presumably exports every notebook in the project — confirm.
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
from transformers import AutoConfig
from xcai.block import *
from xcai.models.PPP0XX import DBT010

## Setup: GPUs, data paths, data block and a sample batch

In [None]:
# Expose only GPUs 0 and 1 to this process. This only takes effect if CUDA has not
# been initialised in the kernel yet, so keep this cell before any GPU work.
os.environ.update({'CUDA_VISIBLE_DEVICES': "0,1"})

In [None]:
# Root of the datasets directory. Set the XCAI_DATA_DIR environment variable to point
# elsewhere instead of editing the user-specific default below.
pkl_dir = os.environ.get('XCAI_DATA_DIR', '/home/scai/phd/aiz218323/scratch/datasets')
# Pickled, pre-processed data block. The filename suggests WikiSeeAlso data with category
# metadata, tokenised for distilbert-base-uncased — confirm against the preprocessing script.
pkl_file = f'{pkl_dir}/processed/wikiseealso_data-meta_distilbert-base-uncased_rm_ramen-cat.pkl'

In [None]:
# Load the pre-processed data block.
# NOTE: `pickle.load` can execute arbitrary code; only load files you produced yourself.
with open(pkl_file, mode='rb') as fh:
    block = pickle.load(fh)

In [None]:
# Pick the sample batch used in the example below.
#   USE_SMALL_BATCH = True  -> `block.train.one_batch(5)` (presumably a 5-example batch)
#   USE_SMALL_BATCH = False -> the batch at index 3 of the training dataloader (the loop stops
#                              once i > 2; with fewer batches, `batch` is the last one yielded)
USE_SMALL_BATCH = False
if USE_SMALL_BATCH:
    batch = block.train.one_batch(5)
else:
    for i, batch in enumerate(block.train.dl):
        if i > 2: break

In [None]:
# List the fields carried by the collated sample batch (data, label and category fields).
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'cat2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx'])

## `LOR001`: LoRA-wrapped encoder with an auxiliary metadata loss

In [None]:
#| export
class LOR001(DistilBertPreTrainedModel):
    """LoRA-wrapped encoder trained on its own loss plus an auxiliary metadata loss.

    The base `model` is wrapped with `get_peft_model(model, peft_config)`. Then *every* parameter
    of the resulting PEFT model is made trainable, not only the adapters.

    `forward` returns the wrapped model's loss. When metadata inputs for `pred_meta_prefix` are
    present, it adds `meta_loss_weight` times a `MultiTriplet` loss between the data
    representations and the metadata representations.
    """
    use_representation,use_generation = True,False
    _tied_weights_keys = ["peft_model.base_model.model.encoder.distilbert"]

    def __init__(
        self, config, model, peft_config,

        pred_meta_prefix:Optional[str]=None,

        num_batch_labels:Optional[int]=None,
        batch_size:Optional[int]=None,
        margin:Optional[float]=0.3,
        num_negatives:Optional[int]=5,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=True,

        meta_loss_weight:Optional[float]=0.1,

        **kwargs
    ):
        """
        Args:
            config: config forwarded to `DistilBertPreTrainedModel.__init__`.
            model: base model to wrap. It is called with `data_input_ids`/`data_attention_mask`
                (plus any extra forward kwargs), and its output must expose `loss`,
                `data_repr` and `lbl2data_repr`.
            peft_config: PEFT config (e.g. `LoraConfig`) passed to `get_peft_model`.
            pred_meta_prefix: prefix of the metadata fields in the forward kwargs (e.g. 'cat2data').
                It is resolved by `Parameters.from_meta_pred_prefix`.
            num_batch_labels: passed as `tn_targ` to the `MultiTriplet` metadata loss.
            batch_size: passed as `bsz` to the `MultiTriplet` metadata loss.
            margin, tau, apply_softmax: passed unchanged to `MultiTriplet`.
            num_negatives: passed as `n_negatives` to `MultiTriplet`.
            meta_loss_weight: scale applied to the metadata loss before it is added to the data loss.
            **kwargs: forwarded to `DistilBertPreTrainedModel.__init__`.
        """
        super().__init__(config, **kwargs)
        store_attr('pred_meta_prefix,meta_loss_weight')
        self.peft_model = get_peft_model(model, peft_config)
        self.rep_loss_fn = MultiTriplet(bsz=batch_size, tn_targ=num_batch_labels, margin=margin, n_negatives=num_negatives,
                                        tau=tau, apply_softmax=apply_softmax, reduce='mean')

        # get_peft_model freezes the base weights. This model trains them too, so unfreeze everything.
        self._mark_entire_model_as_trainable()

    def _mark_entire_model_as_trainable(self):
        """Set `requires_grad=True` on every parameter of the wrapped PEFT model."""
        for p in self.peft_model.parameters(): p.requires_grad_(True)

    def _mark_only_adapters_as_trainable(self):
        """Leave only the PEFT adapter parameters trainable, using PEFT's own helper."""
        self.peft_model.base_model._mark_only_adapters_as_trainable(self.peft_model)

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Encode the data, take the wrapped model's loss, and optionally add the metadata loss.

        The data pass runs with all parameters trainable. The metadata pass runs only when the
        kwargs contain fields for `pred_meta_prefix` and the data loss is not None. During that
        pass only the adapter parameters are trainable, so the metadata encoder sends gradients
        only to the adapters. The base weights still receive metadata-loss gradients through
        `data_repr`, which was computed before the switch.

        Only rows with a non-zero `data2ptr` entry take part in the metadata loss. Presumably
        `data2ptr` holds per-row metadata counts — confirm.

        Returns:
            XCModelOutput with `loss`, `data_repr` and `lbl2data_repr`.
        """
        data_o = self.peft_model(data_input_ids, data_attention_mask, **kwargs)

        loss = data_o.loss
        meta_inputs = Parameters.from_meta_pred_prefix(self.pred_meta_prefix, **kwargs)
        try:
            if meta_inputs and loss is not None:
                self._mark_only_adapters_as_trainable()
                # Only the first matching metadata group is used.
                meta_inputs = next(iter(meta_inputs.values()))

                idx = torch.where(meta_inputs['data2ptr'])[0]
                if len(idx) > 0:
                    meta_o = self.peft_model(data_input_ids=meta_inputs['input_ids'], data_attention_mask=meta_inputs['attention_mask'])
                    m_loss = self.rep_loss_fn(data_o.data_repr[idx], meta_o.data_repr, meta_inputs['data2ptr'][idx], meta_inputs['idx'],
                                              meta_inputs['pdata2ptr'][idx], meta_inputs['pidx'])
                    # Out-of-place add: do not mutate the wrapped model's loss tensor in place.
                    loss = loss + self.meta_loss_weight * m_loss
        finally:
            # Always restore full trainability, even if the metadata branch raises. Otherwise a
            # failed step could leave the base model frozen for every later step.
            self._mark_entire_model_as_trainable()

        return XCModelOutput(
            loss=loss,
            data_repr=data_o.data_repr,
            lbl2data_repr=data_o.lbl2data_repr,
        )
        

### Example: wrap a `DBT010` encoder with LoRA and run one forward pass

In [None]:
# Base encoder from the msmarco DistilBERT checkpoint. The `dr_*` head weights are not in the
# checkpoint (see the warning below); `init_dr_head()` presumably initialises them.
model = DBT010.from_pretrained(
    'sentence-transformers/msmarco-distilbert-base-v4',
    bsz=1600, tn_targ=5000, margin=0.3, tau=0.1,
    n_negatives=10, apply_softmax=True, use_encoder_parallel=False,
)
model.init_dr_head()

Some weights of DBT010 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# LoRA adapters (rank 8, alpha 32, dropout 0.05) on DistilBERT's attention query/key/value
# projections, with bias='none'. No `task_type` is set, so PEFT wraps the model generically.
lora_config = LoraConfig(
    target_modules=["q_lin", "k_lin", "v_lin"],
    r=8, lora_alpha=32, lora_dropout=0.05,
    bias='none',
)

In [None]:
# Wrap the DBT010 encoder in LoRA adapters and add the 'cat2data' metadata loss at weight 1.0.
# NOTE(review): this rebinds `model`. Re-running the cell without re-running the DBT010 cell
# would pass an already-wrapped LOR001 back into LOR001.
model = LOR001(
    DistilBertConfig(), model, lora_config,
    pred_meta_prefix='cat2data',
    batch_size=1600,
    num_batch_labels=5000,
    margin=0.3,
    num_negatives=10,
    tau=0.1,
    apply_softmax=True,
    meta_loss_weight=1.0,
)

In [None]:
# Fields passed to `prepare_batch` as `m_args`, grouped by source. Concatenation keeps the
# original order: category->data, category->label, then label->data.
cat2data_args = ['pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids',
                 'cat2data_attention_mask', 'cat2data_data2ptr']
cat2lbl_args = ['pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx',
                'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr']
lbl2data_args = ['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids',
                 'lbl2data_attention_mask', 'lbl2data_data2ptr']
b = prepare_batch(model, batch, m_args=cat2data_args + cat2lbl_args + lbl2data_args)

In [None]:
# Forward pass on the prepared batch: data loss, plus the weighted metadata loss when category fields are present.
# NOTE(review): the saved output below came from an older `LOR001.forward` that called `pdb.set_trace()`.
# The current class has no breakpoint — re-run this cell to refresh the output.
o = model(**b)

> /tmp/ipykernel_22440/1180721033.py(44)forward()
     42         import pdb; pdb.set_trace()
     43 
---> 44         data_o = self.peft_model(data_input_ids, data_attention_mask, **kwargs)
     45 
     46         loss = data_o.loss



ipdb>  c


In [None]:
# Total loss returned by LOR001.forward (data loss + meta_loss_weight * metadata loss).
o.loss

tensor(0.0745, grad_fn=<AddBackward0>)