In [1]:
#| default_exp models.LLL0XX

In [2]:
#| hide
%load_ext autoreload
%autoreload 2

In [112]:
#| export
import torch, re, inspect, pickle, os, torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any, Union
from transformers import (
    PretrainedConfig,
    LlamaConfig,
    LlamaModel,
    LlamaPreTrainedModel,
)
from transformers.activations import ACT2FN

from fastcore.meta import *

from xcai.losses import *
from xcai.core import store_attr
from xcai.learner import XCDataParallel
from xcai.models.modeling_utils import *

from peft import (
    LoraConfig, 
    get_peft_model, 
    TaskType,
    PeftModel
)

## Setup

In [4]:
data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'

In [5]:
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data_meta-llama-3-8b_oak.pkl'

In [6]:
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [7]:
batch = block.train.one_batch(10)
for i,batch in enumerate(block.train.dl):
    if i > 3: break

In [11]:
batch.keys()

dict_keys(['data_idx', 'data_input_ids', 'data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_attention_mask', 'lbl2data_input_ids'])

## Helper functions

In [128]:
#| export
class RepresentationHead(nn.Module):
    
    def __init__(self, config):
        super().__init__()
        self.transform = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.projector = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = ACT2FN[config.hidden_act]
        
        self.post_init()
        
    def post_init(self):
        self.transform.weight.data = torch.eye(self.transform.out_features, self.transform.in_features, 
                                               dtype=self.transform.weight.dtype)
        self.projector.weight.data = torch.eye(self.projector.out_features, self.projector.in_features, 
                                               dtype=self.projector.weight.dtype)
        
    def forward(self, x:torch.Tensor):
        x = self.transform(x)
        x = self.activation(x)
        x = self.layer_norm(x)
        x = self.projector(x)
        return x
    

## `LAM009`

In [129]:
#| export
class LAM009Encoder(LlamaModel):
    
    def __init__(
        self, 
        config:LlamaConfig, 
        *args, **kwargs,
    ):
        super().__init__(config, *args, **kwargs)
        self.dr_head = RepresentationHead(config)
        self.post_init()
        
    @delegates(LlamaModel.__call__)
    def forward(
        self, 
        input_ids:Optional[torch.Tensor]=None, 
        attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        o = super().forward(input_ids, attention_mask)
        rep = self.dr_head(o[0])
        return o, F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)
        
    

In [130]:
#| export
class LAM009(LlamaPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["embed_tokens", "layers", "norm"]
    
    def __init__(self,
                 config,
                 bsz:Optional[int]=None,
                 tn_targ:Optional[int]=None,
                 margin:Optional[float]=0.3,
                 tau:Optional[float]=0.1,
                 apply_softmax:Optional[bool]=False,
                 n_negatives:Optional[int]=5,
                 use_encoder_parallel:Optional[bool]=True,
                 *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        store_attr('use_encoder_parallel')
        self.encoder = LAM009Encoder(config)
        self.loss_fn = MultiTriplet(bsz=bsz, tn_targ=tn_targ, margin=margin, n_negatives=n_negatives, tau=tau, 
                                    apply_softmax=apply_softmax, reduce='mean')
        self.post_init()
        self.remap_post_init()
        
    def remap_post_init(self):
        self.layers = self.encoder.layers
        self.norm = self.encoder.norm
        self.embed_tokens = self.encoder.embed_tokens

    def init_retrieval_head(self):
        if self.encoder is None: raise ValueError('`self.encoder` is not initialized.')
        self.encoder.dr_head.post_init()

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_o, data_repr = encoder(data_input_ids, data_attention_mask, output_attentions=output_attentions, 
                                    output_hidden_states=output_hidden_states, return_dict=return_dict)
        
        loss, lbl2data_repr = None, None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask, output_attentions=output_attentions, 
                                                output_hidden_states=output_hidden_states, return_dict=return_dict)
            
            loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
                                plbl2data_data2ptr, plbl2data_idx, **kwargs)

        if not return_dict:
            o = (data_repr, lbl2data_repr)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
        )
        

### Example

In [105]:
model = LAM009.from_pretrained('meta-llama/Meta-Llama-3-8B', bsz=1024, margin=0.3, tau=0.1, n_negatives=10, apply_softmax=True, 
                               use_encoder_parallel=False)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of LAM009 were not initialized from the model checkpoint at meta-llama/Meta-Llama-3-8B and are newly initialized: ['model.encoder.dr_head.layer_norm.bias', 'model.encoder.dr_head.layer_norm.weight', 'model.encoder.dr_head.projector.bias', 'model.encoder.dr_head.projector.weight', 'model.encoder.dr_head.transform.bias', 'model.encoder.dr_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [106]:
model.init_retrieval_head()

In [107]:
vocab_size = model.encoder.embed_tokens.num_embeddings
model.encoder.resize_token_embeddings(vocab_size+1)

Embedding(128257, 4096)

In [108]:
o = model(**batch)

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [110]:
o.loss, o.data_repr.shape, o.lbl2data_repr.shape

(tensor(0.0259, grad_fn=<DivBackward0>),
 torch.Size([10, 4096]),
 torch.Size([29, 4096]))

In [116]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj","v_proj","o_proj"],
    bias='none',
)

In [117]:
m = get_peft_model(model, lora_config)

In [125]:
m.base_model.encoder.dr_head.requires_grad_(True)

RepresentationHead(
  (transform): Linear(in_features=4096, out_features=4096, bias=True)
  (layer_norm): LayerNorm((4096,), eps=1e-12, elementwise_affine=True)
  (projector): Linear(in_features=4096, out_features=4096, bias=True)
  (activation): SiLU()
)