In [1]:
#| default_exp models.NVM0XX

In [2]:
#| hide
%load_ext autoreload
%autoreload 2

In [2]:
#| export
import torch, re, inspect, pickle, os, torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any, Union

from transformers.activations import get_activation

from fastcore.meta import *

from xcai.core import store_attr
from xcai.losses import MultiTriplet
from xcai.models.modeling_nvembed import NVEmbedModel
from xcai.models.modeling_utils import XCModelOutput, Pooling

## Setup

In [4]:
# Root directory for the datasets.
# NOTE(review): absolute, machine-specific path; no other cell shown here reads `data_dir`.
data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'

In [5]:
# Directory holding the pickled datasets and the specific pickle file to load.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data_nv-embed-v2_xcs.pkl'

In [6]:
# Deserialise the pickled object at `pkl_file` into `block`.
# pickle.load can execute arbitrary code while loading, so only use a trusted file here.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [7]:
# Grab a sample batch for the examples further down.
# The loop variable rebinds `batch` on every iteration, so the `one_batch(10)` result
# is discarded; the loop stops once `i` exceeds 3, leaving `batch` as the item at
# index 4 of `block.train.dl` (or the last item if the loader is shorter).
batch = block.train.one_batch(10)
for i,batch in enumerate(block.train.dl):
    if i > 3: break

In [8]:
# Show which fields the sample batch carries.
batch.keys()

dict_keys(['data_idx', 'data_input_ids', 'data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_attention_mask', 'lbl2data_input_ids'])

## Helper functions

In [117]:
#| export
class Pooling:
    """Static helpers that reduce per-position embeddings to one vector per row.

    NOTE(review): this class shadows the `Pooling` name imported from
    `xcai.models.modeling_utils` in the import cell — confirm which one is intended.
    """

    @staticmethod
    def mean_pooling(data_embeds:torch.FloatTensor, data_attention_mask:torch.LongTensor):
        """Mask-weighted mean of `data_embeds` over dim 1.

        The mask gets a trailing axis (dim 2), is expanded to the shape of
        `data_embeds` and cast to float. The denominator is clamped at 1e-9 so a
        row whose mask is all zeros yields zeros instead of a division by zero.
        """
        weights = data_attention_mask.unsqueeze(2).expand(data_embeds.size()).float()
        weighted_sum = (data_embeds * weights).sum(1)
        token_count = weights.sum(1).clamp(min=1e-9)
        return weighted_sum / token_count

    @staticmethod
    def last_pooling(data_embeds:torch.FloatTensor, data_attention_mask:torch.LongTensor):
        """Select, for every row, the embedding at position `mask.sum(dim=1) - 1`.

        NOTE(review): that index is the last attended position only if the ones in
        the mask form a prefix of the row (padding at the end) — confirm with callers.
        """
        row_ids = torch.arange(data_embeds.size(0), device=data_embeds.device)
        last_idx = data_attention_mask.sum(dim=1) - 1
        return data_embeds[row_ids, last_idx]
    

In [3]:
#| export
class RepresentationHead(torch.nn.Module):
    """Projection head: linear -> 'relu' activation -> layer norm -> linear.

    Every layer works at width `config.hidden_size`. Both linear layers are
    initialised to the identity map (eye weights, zero biases) by `post_init`.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        # Creation order is kept as-is so parameter / state-dict ordering is unchanged.
        self.transform = nn.Linear(width, width)
        self.layer_norm = nn.LayerNorm(width)
        self.projector = nn.Linear(width, width)
        self.activation = get_activation('relu')

        self.post_init()

    def post_init(self):
        """Reset both linear layers to identity weights and zero biases."""
        for linear in (self.transform, self.projector):
            nn.init.eye_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x:torch.Tensor):
        """Run `x` through transform, activation, layer norm and projector, in that order."""
        hidden = self.activation(self.transform(x))
        return self.projector(self.layer_norm(hidden))
    

## `NVM009`

In [4]:
#| export
class NVM009Encoder(NVEmbedModel):
    """`NVEmbedModel` subclass that adds a `RepresentationHead` (`dr_head`) on top of
    the latent-attention output and returns a mean-pooled, L2-normalised vector."""

    def __init__(self, config, **kwargs):
        # Extra keyword arguments are accepted but not forwarded to the base class.
        super().__init__(config)
        self.dr_head = RepresentationHead(config)

    @delegates(NVEmbedModel.__call__)
    def forward(
        self,
        input_ids:Optional[torch.Tensor]=None,
        attention_mask:Optional[torch.Tensor]=None,
        pool_mask: Optional[torch.Tensor]=None,
        return_dict: bool=True,
        **kwargs
    ):
        """Encode a batch and return `(backbone_output, representation)`.

        `input_ids` / `attention_mask` go to `self.embedding_model`; its
        `last_hidden_state` and `pool_mask` go to `self.latent_attention_model`.
        The result passes through `dr_head`, is mean-pooled with `attention_mask`
        and normalised along dim 1. `return_dict` and `**kwargs` are not used here.

        NOTE(review): the mean pooling assumes the latent-attention output still has
        one vector per position (i.e. `pool_mask` left as None) — confirm.
        """
        backbone_out = self.embedding_model(input_ids=input_ids, attention_mask=attention_mask)
        latent = self.latent_attention_model(backbone_out.last_hidden_state, pool_mask)
        projected = self.dr_head(latent)
        pooled = Pooling.mean_pooling(projected, attention_mask)
        return backbone_out, F.normalize(pooled, dim=1)
    

In [5]:
#| export
class NVM009(NVEmbedModel):
    """Representation model: an `NVM009Encoder` trained with a `MultiTriplet` loss.

    The same encoder embeds both the data inputs and the label inputs; the loss is
    computed only when label inputs are supplied to `forward`.
    """
    use_generation,use_representation = False,True
    # One entry per tied sub-module. The previous single comma-joined string could
    # never match a parameter name, so the duplicated `encoder.*` weights were not
    # recognised as tied.
    _tied_weights_keys = ["encoder.embedding_model", "encoder.latent_attention_model"]
    
    def __init__(self,
                 config,
                 bsz:Optional[int]=None,
                 tn_targ:Optional[int]=None,
                 margin:Optional[float]=0.3,
                 tau:Optional[float]=0.1,
                 apply_softmax:Optional[bool]=False,
                 n_negatives:Optional[int]=5,
                 use_encoder_parallel:Optional[bool]=True,
                 *args, **kwargs):
        """Build the encoder and the loss.

        `bsz`, `tn_targ`, `margin`, `tau`, `apply_softmax` and `n_negatives` are
        forwarded unchanged to `MultiTriplet` (with `reduce='mean'`).
        `use_encoder_parallel` controls whether `forward` wraps the encoder in
        `nn.DataParallel`. Remaining positional/keyword arguments go to the base class.
        """
        super().__init__(config, *args, **kwargs)
        store_attr('use_encoder_parallel')
        # Was `NVM0XXEncoder`, a name that is not defined or imported in this module.
        self.encoder = NVM009Encoder(config)
        self.loss_fn = MultiTriplet(bsz=bsz, tn_targ=tn_targ, margin=margin, n_negatives=n_negatives, tau=tau, 
                                    apply_softmax=apply_softmax, reduce='mean')
        self.post_init()
        self.remap_post_init()
        
    def init_dr_head(self):
        """Re-run the identity/zero initialisation of the encoder's `dr_head`."""
        self.encoder.dr_head.post_init()
        
    def remap_post_init(self):
        """Point the top-level backbone attributes at the encoder's modules so both
        names refer to the same module objects."""
        self.embedding_model = self.encoder.embedding_model
        self.latent_attention_model = self.encoder.latent_attention_model
    
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Embed the data inputs and, when label inputs are given, compute the loss.

        Args:
            data_input_ids, data_attention_mask: encoder inputs for the data side.
            lbl2data_input_ids, lbl2data_attention_mask: encoder inputs for the
                label side; when `lbl2data_input_ids` is None no loss is computed.
            lbl2data_data2ptr, lbl2data_idx, plbl2data_data2ptr, plbl2data_idx:
                passed through to the loss function together with `**kwargs`.
            return_dict: falls back to `self.config.use_return_dict` when None.

        Returns:
            `XCModelOutput(loss, data_repr, lbl2data_repr)` when `return_dict` is
            truthy; otherwise the tuple `(data_repr, lbl2data_repr)`, prefixed
            with `loss` when a loss was computed.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        # A fresh DataParallel wrapper is created on every call when enabled.
        encoder = nn.DataParallel(module=self.encoder) if self.use_encoder_parallel else self.encoder
        
        # The encoder returns (backbone_output, representation); only the latter is used.
        _, data_repr = encoder(data_input_ids, data_attention_mask)
        
        loss, lbl2data_repr = None, None
        if lbl2data_input_ids is not None:
            _, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask)
            
            loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
                                plbl2data_data2ptr, plbl2data_idx, **kwargs)

        if not return_dict:
            o = (data_repr, lbl2data_repr)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
        )

### Example

In [121]:
# Load the pretrained checkpoint and pass the loss settings through to NVM009.
# use_encoder_parallel=False makes forward() call the encoder directly instead of
# wrapping it in nn.DataParallel.
model = NVM009.from_pretrained('nvidia/NV-Embed-v2', bsz=1024, margin=0.3, tau=0.1, n_negatives=10, apply_softmax=True, 
                               use_encoder_parallel=False)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of LAM009 were not initialized from the model checkpoint at meta-llama/Meta-Llama-3-8B and are newly initialized: ['model.encoder.dr_head.layer_norm.bias', 'model.encoder.dr_head.layer_norm.weight', 'model.encoder.dr_head.projector.bias', 'model.encoder.dr_head.projector.weight', 'model.encoder.dr_head.transform.bias', 'model.encoder.dr_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [122]:
# Reset the dr_head linear layers to identity weights / zero biases after loading.
model.init_dr_head()

In [124]:
# Forward pass on the sample batch; every batch field is passed as a keyword argument.
o = model(**batch)

In [None]:
# Display the loss produced by the forward pass above.
o.loss

In [126]:
# LoRA settings applied to the attention projection modules listed in `target_modules`.
# NOTE(review): `LoraConfig` is not imported in any cell shown here — this cell
# raises NameError on a fresh kernel unless the import exists elsewhere.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj","v_proj","o_proj"],
    bias='none',
)

In [127]:
# Wrap the model with the LoRA configuration defined above.
# NOTE(review): `get_peft_model` is not imported in any cell shown here — confirm
# the import exists elsewhere before running on a fresh kernel.
m = get_peft_model(model, lora_config)