In [1]:
#| default_exp 03_benchmarking_nvembed_bm25

In [2]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [3]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, numpy as np, math, transformers
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

from xcai.basics import *

from xclib.utils.sparse import retain_topk

from fastcore.utils import *

In [4]:
os.environ['WANDB_MODE'] = 'disabled'

In [5]:
#| export
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
os.environ['WANDB_PROJECT']='oakVn_00-wikiseealsotitles'

## Huggingface `NV-Embed-v2` example

In [11]:
task_name_to_instruct = {"example": "Given a question, retrieve passages that answer the question",}

query_prefix = "Instruct: "+task_name_to_instruct["example"]+"\nQuery: "
queries = [
    'are judo throws allowed in wrestling?', 
    'how to become a radiology technician in michigan?'
    ]

passage_prefix = ""
passages = [
    "Since you're reading this, you are probably someone from a judo background or someone who is just wondering how judo techniques can be applied under wrestling rules. So without further ado, let's get to the question. Are Judo throws allowed in wrestling? Yes, judo throws are allowed in freestyle and folkstyle wrestling. You only need to be careful to follow the slam rules when executing judo throws. In wrestling, a slam is lifting and returning an opponent to the mat with unnecessary force.",
    "Below are the basic steps to becoming a radiologic technologist in Michigan:Earn a high school diploma. As with most careers in health care, a high school education is the first step to finding entry-level employment. Taking classes in math and science, such as anatomy, biology, chemistry, physiology, and physics, can help prepare students for their college studies and future careers.Earn an associate degree. Entry-level radiologic positions typically require at least an Associate of Applied Science. Before enrolling in one of these degree programs, students should make sure it has been properly accredited by the Joint Review Committee on Education in Radiologic Technology (JRCERT).Get licensed or certified in the state of Michigan."
]

In [25]:
model = AutoModel.from_pretrained('nvidia/NV-Embed-v2', trust_remote_code=True)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [14]:
max_length = 32768
query_embeddings = model.encode(queries, instruction=query_prefix, max_length=max_length)
passage_embeddings = model.encode(passages, instruction=passage_prefix, max_length=max_length)

  'input_ids': torch.tensor(batch_dict.get('input_ids').to(batch_dict.get('input_ids')).long()),


In [16]:
scores = (query_embeddings @ passage_embeddings.T) * 100
print(scores.tolist())

[[87.42693328857422, 0.46283310651779175], [0.9652641415596008, 86.0372085571289]]


## Load data

In [6]:
#| export
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'

output_dir = '/home/scai/phd/aiz218323/scratch/outputs/mogic/03_benchmarking_nvembed_bm25'

In [7]:
tokenizer = AutoTokenizer.from_pretrained('nvidia/NV-Embed-v2')

In [19]:
pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data_distilbert-base-uncased_xcs.pkl'

In [7]:
#| export
pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data_nv-embed-v2_xcs.pkl'

In [9]:
#| export
block = XCBlock.from_cfg(data_dir, 'data', transform_type='xcs', tokenizer='nvidia/NV-Embed-v2', 
                         sampling_features=[('lbl2data',1)], max_sequence_length=64, oversample=False)

In [73]:
def prompt_func(x):
    return f'''Instruct: Given the title of a wikipedia article and the corresponding categories of that article on wikipedia, \
your task is to predict the titles of all articles which are likely to be listed in the see also section of the mentioned article.\
\nQuery: {x}'''
    

In [10]:
#| export
def prompt_func(x):
    return f'''Instruct: Given the title of a wikipedia article, your task is to predict the titles of all articles which are \
likely to be listed in the see also section of the mentioned article.\nQuery: {x}'''
    

In [11]:
#| export
input_text = [prompt_func(o) for o in block.train.dset.data.data_info['input_text']]
tokenized_text = tokenizer.batch_encode_plus(input_text, truncation=True, max_length=64)
block.train.dset.data.data_info.update(tokenized_text)

input_text = [prompt_func(o) for o in block.test.dset.data.data_info['input_text']]
tokenized_text = tokenizer.batch_encode_plus(input_text, truncation=True, max_length=64)
block.test.dset.data.data_info.update(tokenized_text)

In [12]:
#| export
with open(pkl_file, 'wb') as file: pickle.dump(block, file)

In [8]:
#| export
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [9]:
batch = next(iter(block.train.dl))

In [10]:
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx'])

In [102]:
m = AutoModel.from_pretrained('nvidia/NV-Embed-v2', trust_remote_code=True)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [103]:
o = m(**{'input_ids': batch['data_input_ids'], 'attention_mask': batch['data_attention_mask']})



In [107]:
o['sentence_embeddings'].shape

torch.Size([10, 47, 4096])

In [11]:
batch['data_input_ids'].shape

torch.Size([10, 47])

## Model

In [11]:
#| export
from contextlib import nullcontext
from xcai.models.modeling_nvembed import NVEmbedModel
from transformers.activations import get_activation

import torch.nn as nn
from xcai.losses import MultiTriplet

from xcai.models.modeling_utils import XCModelOutput, Pooling

from fastcore.meta import *

In [25]:
#| export
class RepresentationHead(torch.nn.Module):
    
    def __init__(self, config):
        super().__init__()
        self.transform = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.projector = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = get_activation('relu')
        
        self.post_init()
        
    def post_init(self):
        torch.nn.init.eye_(self.transform.weight)
        torch.nn.init.eye_(self.projector.weight)
        
    def forward(self, x:torch.Tensor):
        x = self.transform(x)
        #x = self.activation(x)
        #x = self.layer_norm(x)
        #x = self.projector(x)
        return x
        

In [26]:
#| export
class NVM0XXEncoder(NVEmbedModel):
    
    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.dr_head = RepresentationHead(config)
        
    @delegates(NVEmbedModel.__call__)
    def forward(
        self, 
        input_ids:Optional[torch.Tensor]=None, 
        attention_mask:Optional[torch.Tensor]=None,
        pool_mask: Optional[torch.Tensor]=None,
        return_dict: bool=True,
        **kwargs
    ):
        autocast_ctx = torch.autocast if torch.cuda.is_available() else nullcontext
        
        with autocast_ctx("cuda"):
            outputs = self.embedding_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            embeds = self.latent_attention_model(
                outputs.last_hidden_state,
                pool_mask,
            )
            rep = self.dr_head(embeds)
        
        return outputs, F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)
    

In [27]:
#| export
class NVM009(NVEmbedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.embedding_model,encoder.latent_attention_model"]
    
    def __init__(self,
                 config,
                 bsz:Optional[int]=None,
                 tn_targ:Optional[int]=None,
                 margin:Optional[float]=0.3,
                 tau:Optional[float]=0.1,
                 apply_softmax:Optional[bool]=False,
                 n_negatives:Optional[int]=5,
                 use_encoder_parallel:Optional[bool]=True,
                 *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        store_attr('use_encoder_parallel')
        self.encoder = NVM0XXEncoder(config)
        self.loss_fn = MultiTriplet(bsz=bsz, tn_targ=tn_targ, margin=margin, n_negatives=n_negatives, tau=tau, 
                                    apply_softmax=apply_softmax, reduce='mean')
        self.post_init()
        self.remap_post_init()
        
    def init_dr_head(self):
        self.encoder.dr_head.post_init()
        
    def remap_post_init(self):
        self.embedding_model = self.encoder.embedding_model
        self.latent_attention_model = self.encoder.latent_attention_model
    
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_o, data_repr = encoder(data_input_ids, data_attention_mask)
        
        loss, lbl2data_repr = None, None
        if lbl2data_input_ids is not None:
            lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask)
            
            loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
                                plbl2data_data2ptr, plbl2data_idx, **kwargs)

        if not return_dict:
            o = (data_repr, lbl2data_repr)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
        )

In [28]:
#| export
model = NVM009.from_pretrained('nvidia/NV-Embed-v2', bsz=1024, margin=0.3, tau=0.1, n_negatives=10, apply_softmax=True, 
                               use_encoder_parallel=False)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of NVM009 were not initialized from the model checkpoint at nvidia/NV-Embed-v2 and are newly initialized: ['encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [29]:
#| export
o = model(**batch)

In [30]:
#| export
o.loss

tensor(0.0333, grad_fn=<DivBackward0>)

In [65]:
def func():
    import pdb; pdb.set_trace()
    o = model(**batch)
    

In [66]:
func()

> /tmp/ipykernel_5460/3800815954.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     o = model(**batch)
      4 



ipdb>  b
ipdb>  b model.forward


Breakpoint 5 at /tmp/ipykernel_5460/4263273048.py:30


ipdb>  b model.encoder.forward


Breakpoint 6 at /tmp/ipykernel_5460/1004996700.py:7


ipdb>  n


> /tmp/ipykernel_5460/4263273048.py(43)forward()
     41         **kwargs
     42     ):
---> 43         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     44 
     45         if self.use_encoder_parallel:



ipdb>  n


> /tmp/ipykernel_5460/4263273048.py(45)forward()
     43         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     44 
---> 45         if self.use_encoder_parallel:
     46             encoder = nn.DataParallel(module=self.encoder)
     47         else: encoder = self.encoder



ipdb>  


> /tmp/ipykernel_5460/4263273048.py(47)forward()
     45         if self.use_encoder_parallel:
     46             encoder = nn.DataParallel(module=self.encoder)
---> 47         else: encoder = self.encoder
     48 
     49         data_o, data_repr = encoder(data_input_ids, data_attention_mask)



ipdb>  


> /tmp/ipykernel_5460/4263273048.py(49)forward()
     47         else: encoder = self.encoder
     48 
---> 49         data_o, data_repr = encoder(data_input_ids, data_attention_mask)
     50 
     51         loss, lbl2data_repr = None, None



ipdb>  


> /tmp/ipykernel_5460/1004996700.py(16)forward()
     14         **kwargs
     15     ):
---> 16         autocast_ctx = torch.autocast if torch.cuda.is_available() else nullcontext
     17 
     18         with autocast_ctx("cuda"):



ipdb>  n


> /tmp/ipykernel_5460/1004996700.py(18)forward()
     16         autocast_ctx = torch.autocast if torch.cuda.is_available() else nullcontext
     17 
---> 18         with autocast_ctx("cuda"):
     19             outputs = self.embedding_model(
     20                 input_ids=input_ids,



ipdb>  


> /tmp/ipykernel_5460/1004996700.py(19)forward()
     17 
     18         with autocast_ctx("cuda"):
---> 19             outputs = self.embedding_model(
     20                 input_ids=input_ids,
     21                 attention_mask=attention_mask,



ipdb>  


> /tmp/ipykernel_5460/1004996700.py(20)forward()
     18         with autocast_ctx("cuda"):
     19             outputs = self.embedding_model(
---> 20                 input_ids=input_ids,
     21                 attention_mask=attention_mask,
     22             )



ipdb>  


> /tmp/ipykernel_5460/1004996700.py(21)forward()
     19             outputs = self.embedding_model(
     20                 input_ids=input_ids,
---> 21                 attention_mask=attention_mask,
     22             )
     23             embeds = self.latent_attention_model(



ipdb>  


> /tmp/ipykernel_5460/1004996700.py(19)forward()
     17 
     18         with autocast_ctx("cuda"):
---> 19             outputs = self.embedding_model(
     20                 input_ids=input_ids,
     21                 attention_mask=attention_mask,



ipdb>  


> /tmp/ipykernel_5460/1004996700.py(23)forward()
     21                 attention_mask=attention_mask,
     22             )
---> 23             embeds = self.latent_attention_model(
     24                 outputs.last_hidden_state,
     25                 pool_mask,



ipdb>  n


> /tmp/ipykernel_5460/1004996700.py(24)forward()
     22             )
     23             embeds = self.latent_attention_model(
---> 24                 outputs.last_hidden_state,
     25                 pool_mask,
     26             )



ipdb>  


> /tmp/ipykernel_5460/1004996700.py(25)forward()
     23             embeds = self.latent_attention_model(
     24                 outputs.last_hidden_state,
---> 25                 pool_mask,
     26             )
     27             rep = self.dr_head(embeds)



ipdb>  


> /tmp/ipykernel_5460/1004996700.py(23)forward()
     21                 attention_mask=attention_mask,
     22             )
---> 23             embeds = self.latent_attention_model(
     24                 outputs.last_hidden_state,
     25                 pool_mask,



ipdb>  


> /tmp/ipykernel_5460/1004996700.py(27)forward()
     25                 pool_mask,
     26             )
---> 27             rep = self.dr_head(embeds)
     28 
     29         return outputs, F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)



ipdb>  torch.nn.Identity()


Identity()


ipdb>  self.dr_head = torch.nn.Identity()
ipdb>  n


> /tmp/ipykernel_5460/1004996700.py(29)forward()
     26             )
     27             rep = self.dr_head(embeds)
     28 
---> 29         return outputs, F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)
     30 



ipdb>  rep.shape


torch.Size([10, 47, 4096])


ipdb>  rep


tensor([[[-2.7431e+00, -1.6791e+00,  1.5581e+00,  ..., -2.8492e+00,
           2.5091e-02,  5.8852e-01],
         [-5.7464e+00, -7.0676e+00, -5.2332e-01,  ..., -1.0842e+00,
          -3.1414e+00,  2.7235e+00],
         [ 9.3980e-01, -7.7707e+00,  8.3928e+00,  ..., -1.5965e-01,
           1.6055e+00,  1.3432e+01],
         ...,
         [-7.6854e+00, -2.5286e+00,  4.4305e+00,  ..., -6.1472e+00,
          -2.4904e+00,  2.9824e+00],
         [-5.0773e+00,  1.8037e+00,  2.7820e+00,  ..., -1.3706e-02,
           3.6260e+00,  7.7177e+00],
         [-5.5169e+00,  1.3697e+00,  6.0640e+00,  ...,  3.9210e+00,
           1.3565e+00,  7.5124e+00]],

        [[-2.9780e+00, -7.2495e-01,  9.7949e-01,  ..., -2.5192e+00,
          -8.6461e-01,  5.7428e-01],
         [-7.6525e+00, -3.4010e+00,  8.5716e+00,  ...,  3.2418e+00,
           8.1959e-01,  1.0069e+01],
         [ 2.7054e-01, -7.2799e+00,  8.2326e+00,  ...,  6.8420e-01,
           3.3271e+00,  1.2162e+01],
         ...,
         [-8.6517e+00,  2

ipdb>  rep.dtype


torch.float32


ipdb>  rep.shape


torch.Size([10, 47, 4096])


ipdb>  l


     24                 outputs.last_hidden_state,
     25                 pool_mask,
     26             )
     27             rep = self.dr_head(embeds)
     28 
---> 29         return outputs, F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)
     30 



ipdb>  attention_mask.shape


torch.Size([10, 47])


ipdb>  attention_mask.dtype


torch.int64


ipdb>  xx = Pooling.mean_pooling(rep, attention_mask)
ipdb>  xx.shape


torch.Size([10, 4096])


ipdb>  xx


tensor([[-0.2438, -1.9693,  8.1109,  ...,  0.9516, -0.6723,  8.0693],
        [-0.9251, -0.3462,  9.0514,  ...,  1.3616, -2.7477, 10.3213],
        [ 6.9295, -0.1380,  6.1666,  ..., -0.9735,  0.6603,  4.2732],
        ...,
        [ 2.9733, -0.8720,  7.4740,  ...,  1.1270, -0.8351,  5.5786],
        [ 0.6036, -0.7551,  1.2821,  ...,  0.3292, -3.1174,  1.5676],
        [ 7.2001,  1.4673,  4.1804,  ..., -1.9378, -1.5556,  7.7727]],
       grad_fn=<DivBackward0>)


ipdb>  xx = F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)
ipdb>  xx


tensor([[-0.0009, -0.0071,  0.0293,  ...,  0.0034, -0.0024,  0.0292],
        [-0.0033, -0.0012,  0.0322,  ...,  0.0048, -0.0098,  0.0367],
        [ 0.0246, -0.0005,  0.0219,  ..., -0.0035,  0.0023,  0.0152],
        ...,
        [ 0.0109, -0.0032,  0.0274,  ...,  0.0041, -0.0031,  0.0204],
        [ 0.0022, -0.0028,  0.0047,  ...,  0.0012, -0.0114,  0.0057],
        [ 0.0260,  0.0053,  0.0151,  ..., -0.0070, -0.0056,  0.0280]],
       grad_fn=<DivBackward0>)


ipdb>  xx.shape


torch.Size([10, 4096])


ipdb>  torch.norm(xx, dim=1)


tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000], grad_fn=<LinalgVectorNormBackward0>)


ipdb>  l





ipdb>  n


--Return--
(BaseModelOutp...tentions=None), tensor([[-0.0...DivBackward0>))
> /tmp/ipykernel_5460/1004996700.py(29)forward()
     26             )
     27             rep = self.dr_head(embeds)
     28 
---> 29         return outputs, F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)
     30 



ipdb>  xx = F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)
ipdb>  xx


tensor([[-0.0009, -0.0071,  0.0293,  ...,  0.0034, -0.0024,  0.0292],
        [-0.0033, -0.0012,  0.0322,  ...,  0.0048, -0.0098,  0.0367],
        [ 0.0246, -0.0005,  0.0219,  ..., -0.0035,  0.0023,  0.0152],
        ...,
        [ 0.0109, -0.0032,  0.0274,  ...,  0.0041, -0.0031,  0.0204],
        [ 0.0022, -0.0028,  0.0047,  ...,  0.0012, -0.0114,  0.0057],
        [ 0.0260,  0.0053,  0.0151,  ..., -0.0070, -0.0056,  0.0280]],
       grad_fn=<DivBackward0>)


ipdb>  n


> /tmp/ipykernel_5460/4263273048.py(51)forward()
     49         data_o, data_repr = encoder(data_input_ids, data_attention_mask)
     50 
---> 51         loss, lbl2data_repr = None, None
     52         if lbl2data_input_ids is not None:
     53             lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask)



ipdb>  data_repr


tensor([[-0.0009, -0.0071,  0.0293,  ...,  0.0034, -0.0024,  0.0292],
        [-0.0033, -0.0012,  0.0322,  ...,  0.0048, -0.0098,  0.0367],
        [ 0.0246, -0.0005,  0.0219,  ..., -0.0035,  0.0023,  0.0152],
        ...,
        [ 0.0109, -0.0032,  0.0274,  ...,  0.0041, -0.0031,  0.0204],
        [ 0.0022, -0.0028,  0.0047,  ...,  0.0012, -0.0114,  0.0057],
        [ 0.0260,  0.0053,  0.0151,  ..., -0.0070, -0.0056,  0.0280]],
       grad_fn=<DivBackward0>)


ipdb>  n


> /tmp/ipykernel_5460/4263273048.py(52)forward()
     50 
     51         loss, lbl2data_repr = None, None
---> 52         if lbl2data_input_ids is not None:
     53             lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask)
     54 



ipdb>  


> /tmp/ipykernel_5460/4263273048.py(53)forward()
     51         loss, lbl2data_repr = None, None
     52         if lbl2data_input_ids is not None:
---> 53             lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask)
     54 
     55             loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 



ipdb>  


> /tmp/ipykernel_5460/1004996700.py(16)forward()
     14         **kwargs
     15     ):
---> 16         autocast_ctx = torch.autocast if torch.cuda.is_available() else nullcontext
     17 
     18         with autocast_ctx("cuda"):



ipdb>  r


> /tmp/ipykernel_5460/1004996700.py(18)forward()
     16         autocast_ctx = torch.autocast if torch.cuda.is_available() else nullcontext
     17 
---> 18         with autocast_ctx("cuda"):
     19             outputs = self.embedding_model(
     20                 input_ids=input_ids,



ipdb>  r


--Return--
(BaseModelOutp...tentions=None), tensor([[-0.0...DivBackward0>))
> /tmp/ipykernel_5460/1004996700.py(29)forward()
     26             )
     27             rep = self.dr_head(embeds)
     28 
---> 29         return outputs, F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)
     30 



ipdb>  rep


tensor([[[ -4.2426,  -1.9024,  -1.6038,  ...,  -3.0897,   0.3716,  -2.5756],
         [ -5.6590,   1.9587,  -4.9972,  ...,   0.5762,   3.0439,   2.9400],
         [ -7.4704,  -1.7220,   1.1262,  ...,  -6.5749,  -0.9878,  -2.3778],
         ...,
         [ -6.6515,   3.4415,  -2.5020,  ...,   6.7126,   4.5213,   0.8689],
         [ -6.5692,   3.0849,  -2.4015,  ...,   6.2121,   4.5357,   0.9597],
         [ -6.3492,   2.9202,  -2.3292,  ...,   6.1498,   4.2114,   0.9382]],

        [[ -4.7313,  -1.4236,  -0.8584,  ...,  -3.0321,  -0.8166,  -2.0879],
         [  5.0878,  -0.5830,   8.3816,  ...,  -2.1120,  -1.0463,   4.9551],
         [ -8.4701,   3.0259,   9.7580,  ...,  -0.6330,  -0.3507,   5.6358],
         ...,
         [ -5.9997,   3.7220,   1.0707,  ...,   3.9083,  -0.0986,   3.9553],
         [ -6.1704,   3.8101,   0.9107,  ...,   3.9931,   0.1715,   3.7594],
         [ -6.1943,   3.7836,   0.8842,  ...,   4.1534,   0.2901,   3.6224]],

        [[ -3.5082,  -1.1946,  -1.5158,  ...

ipdb>  F.normalize(Pooling.mean_pooling(rep, attention_mask), dim=1)


tensor([[-0.0283, -0.0004, -0.0078,  ..., -0.0162,  0.0121, -0.0020],
        [-0.0134,  0.0017,  0.0285,  ..., -0.0095, -0.0037,  0.0140],
        [-0.0187,  0.0083,  0.0042,  ..., -0.0048,  0.0006, -0.0172],
        ...,
        [-0.0070, -0.0051, -0.0001,  ..., -0.0058,  0.0058,  0.0011],
        [-0.0183,  0.0182, -0.0097,  ...,  0.0027, -0.0196, -0.0172],
        [-0.0004, -0.0003, -0.0051,  ..., -0.0139,  0.0081, -0.0012]],
       grad_fn=<DivBackward0>)


ipdb>  n


> /tmp/ipykernel_5460/4263273048.py(55)forward()
     53             lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask)
     54 
---> 55             loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
     56                                 plbl2data_data2ptr, plbl2data_idx, **kwargs)
     57 



ipdb>  lbl2data_repr


tensor([[-0.0283, -0.0004, -0.0078,  ..., -0.0162,  0.0121, -0.0020],
        [-0.0134,  0.0017,  0.0285,  ..., -0.0095, -0.0037,  0.0140],
        [-0.0187,  0.0083,  0.0042,  ..., -0.0048,  0.0006, -0.0172],
        ...,
        [-0.0070, -0.0051, -0.0001,  ..., -0.0058,  0.0058,  0.0011],
        [-0.0183,  0.0182, -0.0097,  ...,  0.0027, -0.0196, -0.0172],
        [-0.0004, -0.0003, -0.0051,  ..., -0.0139,  0.0081, -0.0012]],
       grad_fn=<DivBackward0>)


ipdb>  n


> /tmp/ipykernel_5460/4263273048.py(56)forward()
     54 
     55             loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
---> 56                                 plbl2data_data2ptr, plbl2data_idx, **kwargs)
     57 
     58         if not return_dict:



ipdb>  


> /tmp/ipykernel_5460/4263273048.py(55)forward()
     53             lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask)
     54 
---> 55             loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
     56                                 plbl2data_data2ptr, plbl2data_idx, **kwargs)
     57 



ipdb>  


> /tmp/ipykernel_5460/4263273048.py(56)forward()
     54 
     55             loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
---> 56                                 plbl2data_data2ptr, plbl2data_idx, **kwargs)
     57 
     58         if not return_dict:



ipdb>  


> /tmp/ipykernel_5460/4263273048.py(55)forward()
     53             lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask)
     54 
---> 55             loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
     56                                 plbl2data_data2ptr, plbl2data_idx, **kwargs)
     57 



ipdb>  


> /tmp/ipykernel_5460/4263273048.py(58)forward()
     56                                 plbl2data_data2ptr, plbl2data_idx, **kwargs)
     57 
---> 58         if not return_dict:
     59             o = (data_repr, lbl2data_repr)
     60             return ((loss,) + o) if loss is not None else o



ipdb>  loss


tensor(0.0140, grad_fn=<DivBackward0>)


ipdb>  n


> /tmp/ipykernel_5460/4263273048.py(62)forward()
     60             return ((loss,) + o) if loss is not None else o
     61 
---> 62         return XCModelOutput(
     63             loss=loss,
     64             data_repr=data_repr,



ipdb>  


> /tmp/ipykernel_5460/4263273048.py(63)forward()
     61 
     62         return XCModelOutput(
---> 63             loss=loss,
     64             data_repr=data_repr,
     65             lbl2data_repr=lbl2data_repr,



ipdb>  


> /tmp/ipykernel_5460/4263273048.py(64)forward()
     62         return XCModelOutput(
     63             loss=loss,
---> 64             data_repr=data_repr,
     65             lbl2data_repr=lbl2data_repr,
     66         )



ipdb>  


> /tmp/ipykernel_5460/4263273048.py(65)forward()
     62         return XCModelOutput(
     63             loss=loss,
     64             data_repr=data_repr,
---> 65             lbl2data_repr=lbl2data_repr,
     66         )



ipdb>  


> /tmp/ipykernel_5460/4263273048.py(62)forward()
     60             return ((loss,) + o) if loss is not None else o
     61 
---> 62         return XCModelOutput(
     63             loss=loss,
     64             data_repr=data_repr,



ipdb>  


--Return--
XCModelOutput...sed_repr=None)
> /tmp/ipykernel_5460/4263273048.py(62)forward()
     60             return ((loss,) + o) if loss is not None else o
     61 
---> 62         return XCModelOutput(
     63             loss=loss,
     64             data_repr=data_repr,



ipdb>  


--Return--
None
> /tmp/ipykernel_5460/3800815954.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     o = model(**batch)
      4 



ipdb>  o.loss


tensor(0.0140, grad_fn=<DivBackward0>)


ipdb>  n


--Call--
> /home/scai/phd/aiz218323/.local/lib/python3.9/site-packages/IPython/core/displayhook.py(258)__call__()
    256         sys.stdout.flush()
    257 
--> 258     def __call__(self, result=None):
    259         """Printing with history cache management.
    260 



ipdb>  c


    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



## Driver

In [None]:
if __name__ == '__main__':
    build_block = True
    pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
    data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'
    
    output_dir = '/home/scai/phd/aiz218323/scratch/outputs/mogic/00_oak-for-wikiseealsotitles-trained-with-linker-predictions'
    meta_embed_file = '/home/aiscuser/scratch/OGB_Weights/LF-WikiSeeAlsoTitles-320K/emb_weights.npy'

    """ Load data """
    pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data-lnk_distilbert-base-uncased_xcs.pkl'

    if build_block:
        block = XCBlock.from_cfg(data_dir, 'data_lnk', transform_type='xcs', tokenizer='distilbert-base-uncased', 
                                 sampling_features=[('lbl2data',4), ('lnk2data',3)], oversample=False)
        with open(pkl_file, 'wb') as file: pickle.dump(block, file)
        exit()
    else:
        with open(pkl_file, 'rb') as file: block = pickle.load(file)
    
    """ Prune metadata """
    data_meta = retain_topk(block.train.dset.meta['lnk_meta'].data_meta, k=5)
    lbl_meta = block.train.dset.meta['lnk_meta'].lbl_meta
    block.train.dset.meta['lnk_meta'].update_meta_matrix(data_meta, lbl_meta)
    
    data_meta = retain_topk(block.test.dset.meta['lnk_meta'].data_meta, k=3)
    lbl_meta = block.test.dset.meta['lnk_meta'].lbl_meta
    block.test.dset.meta['lnk_meta'].update_meta_matrix(data_meta, lbl_meta)

    block.collator.tfms.tfms[0].sampling_features = [('lbl2data',4),('lnk2data',3)]
    block.collator.tfms.tfms[0].oversample = False
    
    block.train.dset.meta['lnk_meta'].meta_info = None
    block.test.dset.meta['lnk_meta'].meta_info = None

    """ Training arguements """
    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        evaluation_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,
        representation_search_type='BRUTEFORCE',
        
        output_representation_attribute='data_fused_repr',
        label_representation_attribute='data_repr',
        metadata_representation_attribute='data_repr',
        data_augmentation_attribute='data_repr',
        representation_attribute='data_fused_repr',
        clustering_representation_attribute='data_fused_repr',
    
        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        use_data_metadata_for_clustering=True,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',
        
        use_distributional_representation=False,
        use_encoder_parallel=True,
        max_grad_norm=None, 
        fp16=True,
        
        label_names=['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lnk2data_idx'],
        
        prune_metadata=False,
        num_metadata_prune_warmup_epochs=10,
        num_metadata_prune_epochs=5,
        metadata_prune_batch_size=2048,
        prune_metadata_names=['lnk_meta'],
        use_data_metadata_for_pruning=True,
    
        predict_with_augmentation=False,
        use_augmentation_index_representation=True,
    
        data_aug_meta_name='lnk',
        augmentation_num_beams=None,
        data_aug_prefix='lnk',
        use_label_metadata=False,
        
        data_meta_batch_size=2048,
        augment_metadata=False,
        num_metadata_augment_warmup_epochs=10,
        num_metadata_augment_epochs=5,
    
        use_cpu_for_searching=False,
        use_cpu_for_clustering=True,
    )

    """ model """
    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()
    model = OAK001.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=bsz, num_batch_labels=5000, 
                                   margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                                   data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                                   data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                                   
                                   num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000,
                                   
                                   calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                                   calib_loss_weight=0.1, use_calib_loss=False,
    
                                   use_query_loss=True,
    
                                   meta_loss_weight=0.0, 
                                   
                                   fusion_loss_weight=0.0, use_fusion_loss=False,
                                   
                                   use_encoder_parallel=True)
    
    model.init_retrieval_head()
    model.init_cross_head()
    model.init_meta_embeddings()
    
    meta_embeddings = np.load(meta_embed_file)
    model.encoder.set_pretrained_meta_embeddings(torch.tensor(meta_embeddings, dtype=torch.float32))
    model.encoder.freeze_pretrained_meta_embeddings()
    
    """ Training """
    metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                      pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

    learn = XCLearner(
        model=model, 
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )
    
    mp.freeze_support()
    learn.train()
    

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,P@1,P@10,P@3,P@5,N@1,N@10,N@3,N@5,Psp@1,Psp@10,Psp@3,Psp@5,Psn@1,Psn@10,Psn@3,Psn@5,R@200,R@10,R@100
10,0.0788,0.086276,0.175101,0.056812,0.115025,0.087132,0.175101,0.194849,0.173305,0.180583,0.163741,0.209324,0.168619,0.180018,0.163741,0.201765,0.174354,0.185467,0.429889,0.235065,0.384685


  0%|          | 0/15617 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)


  0%|          | 0/15617 [00:00<?, ?it/s]