In [1]:
#| default_exp models.dexa

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
#| export
import torch, re, inspect, pickle, os, torch.nn as nn, math
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any, Union
from transformers import (
    PretrainedConfig,
    DistilBertForMaskedLM,
    DistilBertModel,
    DistilBertPreTrainedModel,
)
from transformers.utils.generic import ModelOutput
from transformers.activations import get_activation

from fastcore.meta import *
from fastcore.utils import *

from xcai.losses import *
from xcai.core import store_attr
from xcai.learner import XCDataParallel
from xcai.models.modeling_utils import *

comet_ml is installed but `COMET_API_KEY` is not set.


In [4]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [5]:
from transformers import AutoConfig
from xcai.block import *

## Setup

In [6]:
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

In [7]:
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data-meta_distilbert-base-uncased_xcs.pkl'

In [None]:
with open(pkl_file, 'wb') as file: pickle.dump(block, file)

In [8]:
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [9]:
batch = block.train.one_batch(5)
for i,batch in enumerate(block.train.dl):
    if i > 2: break

In [10]:
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'cat2lbl2data_idx', 'cat2lbl2data_identifier', 'cat2lbl2data_input_text', 'cat2lbl2data_input_ids', 'cat2lbl2data_attention_mask', 'cat2lbl2data_data2ptr', 'cat2lbl2data_lbl2data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx', 'cat2data_idx', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'cat2data_data2ptr'])

## Encoder

In [11]:
#| export
class Encoder(DistilBertPreTrainedModel):
    
    def __init__(
        self, 
        config:PretrainedConfig, 
    ):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.dr_head = RepresentationHead(config)
        self.post_init()
        
    def get_position_embeddings(self) -> nn.Embedding:
        return self.distilbert.get_position_embeddings()
    
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)
    
    def encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        return self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
    
    def dr(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.dr_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        **kwargs
    ):  
        data_o = self.encode(data_input_ids, data_attention_mask)
        data_repr = self.dr(data_o[0], data_attention_mask)
        return EncoderOutput(
            rep=data_repr,
        )
        

## `DEX001`

In [12]:
#| export
class DEX001(DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]
    
    def __init__(
        self, config,
        n_labels:int,
        n_clusters:int,
        num_batch_labels:Optional[int]=None, 
        batch_size:Optional[int]=None,
        margin:Optional[float]=0.3,
        num_negatives:Optional[int]=5,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=True,
        use_encoder_parallel:Optional[bool]=True,
        
    ):
        super().__init__(config)
        store_attr('use_encoder_parallel')
        
        self.encoder = Encoder(config)
        self.label_embeddings = nn.Embedding(n_clusters, config.dim)
        self.register_buffer("label_remap", torch.arange(n_labels)%n_clusters, persistent=True)
        
        self.rep_loss_fn = MultiTriplet(bsz=batch_size, tn_targ=num_batch_labels, margin=margin, n_negatives=num_negatives, 
                                        tau=tau, apply_softmax=apply_softmax, reduce='mean')
        self.post_init(); self.remap_post_init(); self.init_retrieval_head()

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        
    def init_retrieval_head(self):
        if self.encoder is None: raise ValueError('`self.encoder` is not initialized.')
        self.encoder.dr_head.post_init()

    def init_label_embeddings(self):
        self.label_embeddings.weight.data = torch.zeros_like(self.label_embeddings.weight.data)

    def set_label_embeddings(self, embed:torch.Tensor):
        self.label_embeddings.weight.data = embed

    def set_label_remap(self, label_remap:torch.Tensor):
        if label_remap.shape[0] != self.label_remap.shape[0]:
            raise ValueError(f'Shape mismatch, `label_remap` should have {self.label_remap.shape[0]} elements.')
        self.label_remap = label_remap

    def compute_loss(self, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        return self.rep_loss_fn(inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def get_label_representation(
        self,
        data_idx:Optional[torch.Tensor]=None,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder
            
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask)
        data_o.rep = F.normalize(data_o.rep + self.label_embeddings(self.label_remap[data_idx]), dim=1)
        return XCModelOutput(
            data_repr=data_o.rep,
        )
        
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = nn.DataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask)
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask)
            lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)
            
            loss = self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return XCModelOutput(
            loss=loss,
            data_repr=data_o.rep,
            lbl2data_repr=lbl2data_o.rep,
        )
        

### Example

In [14]:
n_clusters = block.n_lbl//3

In [15]:
model = DEX001.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True, use_encoder_parallel=False,
                               n_labels=block.n_lbl, n_clusters=n_clusters)
model.init_retrieval_head()
model.init_label_embeddings()

Some weights of DEX001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'label_embeddings.weight', 'label_remap']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
model.set_label_remap(torch.arange(block.n_lbl)%n_clusters)

In [19]:
model = model.to('cuda')

In [21]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [22]:
o = model(**b.to(model.device))

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [86]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))
    

In [87]:
func()

> /tmp/ipykernel_35716/3657616883.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))
      4 



ipdb>  b model.forward


Breakpoint 1 at /tmp/ipykernel_35716/1394788769.py:67


ipdb>  c


> /tmp/ipykernel_35716/1394788769.py(82)forward()
     80         **kwargs
     81     ):  
---> 82         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     83 
     84         if self.use_encoder_parallel:



ipdb>  n


> /tmp/ipykernel_35716/1394788769.py(84)forward()
     82         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     83 
---> 84         if self.use_encoder_parallel:
     85             encoder = nn.DataParallel(module=self.encoder)
     86         else: encoder = self.encoder



ipdb>  


> /tmp/ipykernel_35716/1394788769.py(86)forward()
     84         if self.use_encoder_parallel:
     85             encoder = nn.DataParallel(module=self.encoder)
---> 86         else: encoder = self.encoder
     87 
     88         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask)



ipdb>  


> /tmp/ipykernel_35716/1394788769.py(88)forward()
     86         else: encoder = self.encoder
     87 
---> 88         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask)
     89 
     90         loss = None; lbl2data_o = EncoderOutput()



ipdb>  n


> /tmp/ipykernel_35716/1394788769.py(90)forward()
     88         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask)
     89 
---> 90         loss = None; lbl2data_o = EncoderOutput()
     91         if lbl2data_input_ids is not None:
     92             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask)



ipdb>  data_o[0].shape


torch.Size([5, 768])


ipdb>  n


> /tmp/ipykernel_35716/1394788769.py(91)forward()
     89 
     90         loss = None; lbl2data_o = EncoderOutput()
---> 91         if lbl2data_input_ids is not None:
     92             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask)
     93             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)



ipdb>  


> /tmp/ipykernel_35716/1394788769.py(92)forward()
     90         loss = None; lbl2data_o = EncoderOutput()
     91         if lbl2data_input_ids is not None:
---> 92             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask)
     93             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)
     94 



ipdb>  


> /tmp/ipykernel_35716/1394788769.py(93)forward()
     91         if lbl2data_input_ids is not None:
     92             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask)
---> 93             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)
     94 
     95             loss = self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,



ipdb>  lbl2data_o[0].shape


torch.Size([13, 768])


ipdb>  self.label_remap.shape


torch.Size([312330])


ipdb>  n


> /tmp/ipykernel_35716/1394788769.py(95)forward()
     93             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)
     94 
---> 95             loss = self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     96                                      plbl2data_data2ptr,plbl2data_idx)
     97 



ipdb>  n


> /tmp/ipykernel_35716/1394788769.py(96)forward()
     94 
     95             loss = self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
---> 96                                      plbl2data_data2ptr,plbl2data_idx)
     97 
     98         if not return_dict:



ipdb>  


> /tmp/ipykernel_35716/1394788769.py(95)forward()
     93             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)
     94 
---> 95             loss = self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     96                                      plbl2data_data2ptr,plbl2data_idx)
     97 



ipdb>  


> /tmp/ipykernel_35716/1394788769.py(98)forward()
     96                                      plbl2data_data2ptr,plbl2data_idx)
     97 
---> 98         if not return_dict:
     99             o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
    100             return ((loss,) + o) if loss is not None else o



ipdb>  


> /tmp/ipykernel_35716/1394788769.py(103)forward()
    101 
    102 
--> 103         return XCModelOutput(
    104             loss=loss,
    105             data_repr=data_o.rep,



ipdb>  


> /tmp/ipykernel_35716/1394788769.py(104)forward()
    102 
    103         return XCModelOutput(
--> 104             loss=loss,
    105             data_repr=data_o.rep,
    106             lbl2data_repr=lbl2data_o.rep,



ipdb>  


> /tmp/ipykernel_35716/1394788769.py(105)forward()
    103         return XCModelOutput(
    104             loss=loss,
--> 105             data_repr=data_o.rep,
    106             lbl2data_repr=lbl2data_o.rep,
    107         )



ipdb>  


> /tmp/ipykernel_35716/1394788769.py(106)forward()
    104             loss=loss,
    105             data_repr=data_o.rep,
--> 106             lbl2data_repr=lbl2data_o.rep,
    107         )
    108 



ipdb>  


> /tmp/ipykernel_35716/1394788769.py(103)forward()
    101 
    102 
--> 103         return XCModelOutput(
    104             loss=loss,
    105             data_repr=data_o.rep,



ipdb>  c


XCModelOutput(loss=tensor(0.0066, device='cuda:0', grad_fn=<DivBackward0>), logits=None, data_repr=tensor([[-0.0346, -0.0148, -0.0316,  ...,  0.0544,  0.0028, -0.0312],
        [-0.0225, -0.0189, -0.0329,  ...,  0.0344, -0.0342,  0.0102],
        [ 0.0112,  0.0063, -0.0371,  ...,  0.0715, -0.0157,  0.0266],
        [ 0.0698,  0.0139,  0.0030,  ...,  0.0345, -0.0052,  0.0286],
        [-0.0170, -0.0224, -0.0301,  ...,  0.0244, -0.0210, -0.0329]],
       device='cuda:0', grad_fn=<DivBackward0>), data_fused_repr=None, lbl2data_repr=tensor([[-0.0217, -0.0110, -0.0350,  ...,  0.0344, -0.0210,  0.0674],
        [-0.0346, -0.0148, -0.0316,  ...,  0.0544,  0.0028, -0.0312],
        [-0.0244, -0.0302, -0.0259,  ...,  0.0584,  0.0058, -0.0318],
        ...,
        [-0.0170, -0.0224, -0.0301,  ...,  0.0244, -0.0210, -0.0329],
        [-0.0183, -0.0332, -0.0331,  ..., -0.0178,  0.0172, -0.0340],
        [ 0.0358,  0.0146,  0.0086,  ...,  0.0373, -0.0104, -0.0009]],
       device='cuda:0', grad_fn

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



In [23]:
o.loss

tensor(0.0066, device='cuda:0', grad_fn=<DivBackward0>)

## `DEX002`

In [24]:
#| export
class DEX002(DEX001):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(DEX001.__init__)
    def __init__(
        self, config,
        n_labels:int,
        n_clusters:int,
        **kwargs
    ):
        super().__init__(config, n_labels=n_labels, n_clusters=n_clusters, **kwargs)
        self.label_embeddings = nn.Embedding(n_clusters, config.dim, sparse=True)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head()
        

### Example

In [25]:
n_clusters = block.n_lbl//3

In [26]:
model = DEX002.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True, use_encoder_parallel=False,
                               n_labels=block.n_lbl, n_clusters=n_clusters)
                               
                               
model.init_retrieval_head()
model.init_label_embeddings()

Some weights of DEX002 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'label_embeddings.weight', 'label_remap']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [27]:
model.set_label_remap(torch.arange(block.n_lbl)%n_clusters)

In [28]:
model = model.to('cuda')

In [29]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [30]:
o = model(**b.to(model.device))

In [None]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))
    

In [None]:
o = func()

In [31]:
o.loss

tensor(0.0130, device='cuda:0', grad_fn=<DivBackward0>)

In [None]:
model = OAK003.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=True,
                               
                               use_encoder_parallel=False)

model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()

Some weights of OAK003 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [None]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [None]:
o = model(**b.to(model.device))

> /tmp/ipykernel_26893/4002658968.py(65)resize()
     63         #debug
     64 
---> 65         if torch.any(num_inputs == 0): raise ValueError("`num_inputs` should be non-zero positive integer.")
     66         bsz, total_num_inputs = num_inputs.shape[0], idx.shape[0]
     67 



ipdb>  num_inputs


tensor([3, 3, 3, 3, 3])


ipdb>  n


> /tmp/ipykernel_26893/4002658968.py(66)resize()
     64 
     65         if torch.any(num_inputs == 0): raise ValueError("`num_inputs` should be non-zero positive integer.")
---> 66         bsz, total_num_inputs = num_inputs.shape[0], idx.shape[0]
     67 
     68         self.ones = self.ones.to(idx.device)



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(68)resize()
     66         bsz, total_num_inputs = num_inputs.shape[0], idx.shape[0]
     67 
---> 68         self.ones = self.ones.to(idx.device)
     69         ones = (
     70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(71)resize()
     69         ones = (
     70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
---> 71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
     72         )
     73 



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(70)resize()
     68         self.ones = self.ones.to(idx.device)
     69         ones = (
---> 70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
     71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
     72         )



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(71)resize()
     69         ones = (
     70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
---> 71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
     72         )
     73 



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(70)resize()
     68         self.ones = self.ones.to(idx.device)
     69         ones = (
---> 70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
     71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
     72         )



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(71)resize()
     69         ones = (
     70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
---> 71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
     72         )
     73 



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(69)resize()
     67 
     68         self.ones = self.ones.to(idx.device)
---> 69         ones = (
     70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
     71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(74)resize()
     72         )
     73 
---> 74         max_num_inputs = num_inputs.max()
     75         if (num_inputs == max_num_inputs).all():
     76             return idx,ones



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(75)resize()
     73 
     74         max_num_inputs = num_inputs.max()
---> 75         if (num_inputs == max_num_inputs).all():
     76             return idx,ones
     77 



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(76)resize()
     74         max_num_inputs = num_inputs.max()
     75         if (num_inputs == max_num_inputs).all():
---> 76             return idx,ones
     77 
     78         xnum_inputs = max_num_inputs-num_inputs+1



ipdb>  idx.shape


torch.Size([15])


ipdb>  ones.shape


torch.Size([15])


ipdb>  n


--Return--
(tensor([16184...4875,  84879]), tensor([1, 1,..., 1, 1, 1, 1]))
> /tmp/ipykernel_26893/4002658968.py(76)resize()
     74         max_num_inputs = num_inputs.max()
     75         if (num_inputs == max_num_inputs).all():
---> 76             return idx,ones
     77 
     78         xnum_inputs = max_num_inputs-num_inputs+1



ipdb>  n


> /tmp/ipykernel_26893/3123686291.py(36)fuse_meta_into_embeddings()
     34             if len(idx):
     35                 m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
---> 36                 m_repr = F.normalize(self.meta_embeddings(m_idx) + self.pretrained_meta_embeddings(m_idx), dim=1)
     37 
     38                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)



ipdb>  m_idx.shape


torch.Size([15])


ipdb>  m_repr_mask.shape


torch.Size([15])


ipdb>  n


> /tmp/ipykernel_26893/3123686291.py(38)fuse_meta_into_embeddings()
     36                 m_repr = F.normalize(self.meta_embeddings(m_idx) + self.pretrained_meta_embeddings(m_idx), dim=1)
     37 
---> 38                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
     39                 meta_repr[m_key] = m_repr[m_repr_mask]
     40 



ipdb>  c


  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
o.loss

tensor(0.0675, grad_fn=<AddBackward0>)

In [None]:
model = OAK004.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               cross_tau=1.0, cross_dropout=0.1,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_cross_head()

Some weights of OAK004 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.tau', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_

In [None]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

model.init_meta_embeddings()

In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
o = model(**b.to(model.device))

In [None]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))
    

In [None]:
o.loss

tensor(0.0618, device='cuda:0', grad_fn=<AddBackward0>)

In [None]:
model = OAK006.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)

model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()

Some weights of OAK006 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [None]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [None]:
o = model(**b.to(model.device))

In [None]:
o.loss

tensor(0.0160, device='cuda:0', grad_fn=<AddBackward0>)