In [1]:
#| default_exp models.oakY

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
#| export
import torch, re, inspect, pickle, os, torch.nn as nn, math
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any, Union
from transformers import (
    PretrainedConfig,
    DistilBertForMaskedLM,
    DistilBertModel,
    DistilBertPreTrainedModel,
    DistilBertConfig,
)
from transformers.utils.generic import ModelOutput
from transformers.activations import get_activation

from fastcore.meta import *
from fastcore.utils import *

from xcai.losses import *
from xcai.core import store_attr
from xcai.learner import XCDataParallel
from xcai.models.modeling_utils import *

In [4]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [5]:
from transformers import AutoConfig
from xcai.block import *

## Setup

In [67]:
from xcai.main import *

In [6]:
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

In [64]:
data_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/'
config_file = 'wikiseealsotitles'
config_key = 'data_lnk'

mname = 'sentence-transformers/msmarco-distilbert-base-v4'

In [65]:
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/processed/'
pkl_file = f'{pkl_dir}/wikiseealsotitles_data-meta_distilbert-base-uncased_sxc.joblib'

In [68]:
block = build_block(pkl_file, config_file, True, config_key, data_dir=data_dir, n_sdata_meta_samples=3, do_build=False)



In [69]:
batch = block.train.one_batch(5)
for i,batch in enumerate(block.train.dl):
    if i > 2: break
    

In [70]:
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_data2ptr', 'lnk2data_identifier', 'lnk2data_input_text', 'lnk2data_input_ids', 'lnk2data_attention_mask', 'plnk2lbl_idx', 'plnk2lbl_lbl2ptr', 'lnk2lbl_idx', 'lnk2lbl_lbl2ptr', 'lnk2lbl_identifier', 'lnk2lbl_input_text', 'lnk2lbl_input_ids', 'lnk2lbl_attention_mask', 'lnk2lbl_data2ptr', 'plnk2lbl_data2ptr'])

## Encoder

In [71]:
#| export
class Encoder(DistilBertPreTrainedModel):
    
    def __init__(
        self, 
        config:PretrainedConfig, 
    ):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        
        self.dr_head = RepresentationHead(config)
        self.dr_fused_head = RepresentationHead(config)
        self.meta_head = RepresentationHead(config)
        self.cross_head = CrossAttention(config)
        self.meta_embeddings = nn.Embedding(config.vocab_size, config.dim)
        
        self.post_init()

    def freeze_meta_embeddings(self):
        self.meta_embeddings.requires_grad_(False)

    def unfreeze_meta_embeddings(self):
        self.meta_embeddings.requires_grad_(True)

    def set_meta_embeddings(self, embed:torch.Tensor):
        with torch.no_grad():
            self.meta_embeddings.weight.copy_(embed)
        
    def get_position_embeddings(self) -> nn.Embedding:
        return self.distilbert.get_position_embeddings()
    
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)
    
    def encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        return self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
    
    def dr(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.dr_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def dr_fused(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.dr_fused_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def meta(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.meta_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)
    
    def meta_unnormalized(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.meta_head(embed)
        return Pooling.mean_pooling(embed, attention_mask)

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
        meta_repr = {}
        
        for m_key, m_args in meta_kwargs.items():
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
            
            if len(idx):
                assert torch.all(m_args['data2ptr'][idx] == m_args['data2ptr'].max()), f'All datapoints should have same number of metadata.'
                
                if 'meta_repr' in m_args:
                    m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
                    m_repr_mask = m_repr_mask.bool()
                else:
                    m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
                    m_embed = self.meta_embeddings(m_input_ids)
    
                    m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
                    m_repr_mask = torch.any(m_attention_mask, dim=1)
                    
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
                
                meta_repr[m_key] = m_repr[m_repr_mask]
                meta_repr[m_key] = F.normalize(meta_repr[m_key], dim=1)
                
                fused_embed, w = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=True)
                embed[idx] += fused_embed
               
        return embed, meta_repr

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):  
        data_o = self.encode(data_input_ids, data_attention_mask)
        
        if data_type is not None and data_type == "meta":
            data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
        else: 
            data_repr = self.dr(data_o[0], data_attention_mask)
        
        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
                                                                            data_attention_mask, 
                                                                            meta_kwargs)
                data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
                
        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,
        )
        

## `OAK000`

In [72]:
#| export
class OAK000(nn.Module):
    
    def __init__(
        self, config,

        data_aug_meta_prefix:Optional[str]=None, 
        lbl2data_aug_meta_prefix:Optional[str]=None, 

        data_pred_meta_prefix:Optional[str]=None,
        lbl2data_pred_meta_prefix:Optional[str]=None,
        
        num_batch_labels:Optional[int]=None, 
        batch_size:Optional[int]=None,
        margin:Optional[float]=0.3,
        num_negatives:Optional[int]=5,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=True,

        calib_margin:Optional[float]=0.3,
        calib_num_negatives:Optional[int]=10,
        calib_tau:Optional[float]=0.1,
        calib_apply_softmax:Optional[bool]=False,
        calib_loss_weight:Optional[float]=0.1,
        use_calib_loss:Optional[float]=False,
        
        meta_loss_weight:Optional[Union[List,float]]=0.3,
        
        use_fusion_loss:Optional[bool]=False,
        fusion_loss_weight:Optional[float]=0.15,

        use_query_loss:Optional[float]=True,
        
        use_encoder_parallel:Optional[bool]=True,
    ):
        super().__init__(config)
        store_attr('meta_loss_weight,fusion_loss_weight,calib_loss_weight')
        store_attr('data_pred_meta_prefix,lbl2data_pred_meta_prefix')
        store_attr('data_aug_meta_prefix,lbl2data_aug_meta_prefix')
        store_attr('use_fusion_loss,use_query_loss,use_calib_loss,use_encoder_parallel')
        
        self.encoder = None
        self.rep_loss_fn = MultiTriplet(bsz=batch_size, tn_targ=num_batch_labels, margin=margin, n_negatives=num_negatives, 
                                        tau=tau, apply_softmax=apply_softmax, reduce='mean')
        self.cab_loss_fn = Calibration(margin=calib_margin, tau=calib_tau, n_negatives=calib_num_negatives, 
                                       apply_softmax=calib_apply_softmax, reduce='mean')
        
    def init_retrieval_head(self):
        assert self.encoder is not None, 'Encoder is not initialized.'
        self.encoder.dr_head.post_init()
        self.encoder.meta_head.post_init()
        self.encoder.dr_fused_head.post_init()

    def init_cross_head(self):
        assert self.encoder is not None, 'Encoder is not initialized.'
        self.encoder.cross_head.post_init()
        
    def compute_loss(self, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        return self.rep_loss_fn(inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def calibration_loss(self, einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        return self.calib_loss_weight * self.cab_loss_fn(einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)
    
    def compute_meta_loss(self, data_repr, lbl2data_repr, **kwargs):
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
            
        data_meta_inputs = Parameters.from_meta_pred_prefix(self.data_pred_meta_prefix, **kwargs)
        lbl2data_meta_inputs = Parameters.from_meta_pred_prefix(self.lbl2data_pred_meta_prefix, **kwargs)
        meta_inputs = {**data_meta_inputs, **lbl2data_meta_inputs}

        m_lw = Parameters.get_meta_loss_weights(self.meta_loss_weight, len(meta_inputs)) if len(meta_inputs) else []
        
        loss = 0.0
        for inputs,lw in zip(meta_inputs.values(), m_lw):
            if 'lbl2data2ptr' in inputs:
                idx = torch.where(inputs['lbl2data2ptr'])[0]
                if len(idx) > 0:
                    inputs_o = encoder(data_input_ids=inputs['input_ids'], data_attention_mask=inputs['attention_mask'], 
                                       data_type="meta")
                    m_loss = self.rep_loss_fn(lbl2data_repr[idx], inputs_o.rep, inputs['lbl2data2ptr'][idx],
                                              inputs['idx'], inputs['plbl2data2ptr'][idx], inputs['pidx'])
                    loss += lw * m_loss

            elif 'data2ptr' in inputs:
                idx = torch.where(inputs['data2ptr'])[0]
                if len(idx) > 0:
                    inputs_o = encoder(data_input_ids=inputs['input_ids'], data_attention_mask=inputs['attention_mask'], 
                                       data_type="meta")
                    m_loss = self.rep_loss_fn(data_repr[idx], inputs_o.rep, inputs['data2ptr'][idx], inputs['idx'], 
                                              inputs['pdata2ptr'][idx], inputs['pidx'])
                    loss += lw * m_loss       

            else: raise ValueError('Invalid metadata input arguments.')
        return loss

    def compute_fusion_loss(self, data_repr, meta_repr:Dict, prefix:str, **kwargs):
        meta_inputs = Parameters.from_meta_pred_prefix(prefix, **kwargs)
        
        loss = 0.0
        if meta_repr is not None:
            for key,input_repr in meta_repr.items():
                inputs = meta_inputs[key]
                if 'lbl2data2ptr' in inputs:
                    idx = torch.where(inputs['lbl2data2ptr'])[0]
                    if len(idx) > 0:
                        m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['lbl2data2ptr'][idx],
                                                  inputs['idx'], inputs['plbl2data2ptr'][idx], inputs['pidx'])
                        loss += self.fusion_loss_weight * m_loss
    
                elif 'data2ptr' in inputs:
                    idx = torch.where(inputs['data2ptr'])[0]
                    if len(idx) > 0:
                        m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['data2ptr'][idx], inputs['idx'], 
                                                  inputs['pdata2ptr'][idx], inputs['pidx'])
                        loss += self.fusion_loss_weight * m_loss       
    
                else: raise ValueError('Invalid metadata input arguments.')
        return loss


    def get_meta_representation(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
            
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_unnormalized=True, data_type="meta")
        return XCModelOutput(
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
        )

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
            
            # loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
            
            # if self.use_fusion_loss:
            #    loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
            #    loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
            
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        

## `OAK001`

In [None]:
#| export
class OAK001(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.encoder = Encoder(config)
        self.post_init(); self.remap_post_init();

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        

### Example

In [None]:
model = OAK001.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=True, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_cross_head()
model.encoder.set_meta_embeddings(model.distilbert.embeddings.word_embeddings.weight)

Some weights of OAK001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_identifier', 'lnk2data_input_text', 'lnk2data_input_ids', 'lnk2data_attention_mask', 'lnk2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx', 'lnk2lbl2data_idx', 'lnk2lbl2data_identifier', 'lnk2lbl2data_input_text', 'lnk2lbl2data_input_ids', 'lnk2lbl2data_attention_mask', 'lnk2lbl2data_data2ptr', 'lnk2lbl2data_lbl2data2ptr'])

In [None]:
o = model(**b.to(model.device))

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
o.loss

tensor(0.0289, device='cuda:0', grad_fn=<AddBackward0>)

In [None]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))
    

In [None]:
o = func()

> /tmp/ipykernel_28793/3657616883.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))
      4 



ipdb>  b model.encoder.cross_head.forward


Breakpoint 5 at /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py:75


ipdb>  c


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(83)forward()
     81         output_attentions:Optional[bool] = False,
     82     ):
---> 83         bs, q_len, dim = q.size()
     84         v, k_len = k, k.size(1)
     85 



ipdb>  !q


tensor([[[-0.6077, -1.4566, -1.2658,  ...,  0.4173,  0.4841, -0.4502],
         [-0.6280, -0.1254, -0.2795,  ...,  0.7303,  0.4780, -0.8035],
         [-0.5878,  0.2601, -0.8291,  ...,  0.7271, -0.0244, -0.1744],
         [-0.3885, -0.0131, -0.3512,  ...,  0.8768, -0.2073, -0.3525],
         [-0.2144, -1.0242, -1.2639,  ...,  1.1083,  0.2689,  0.1441]],

        [[-0.0116, -0.6070, -1.0937,  ...,  0.4919, -0.6339,  0.3495],
         [-0.4328, -0.5217, -0.7185,  ...,  0.3906, -0.9542,  0.9589],
         [-0.6673, -0.8967, -1.2427,  ...,  0.3123, -0.9577, -0.1136],
         [-0.2185,  0.0981, -0.4987,  ...,  0.7605, -0.7738,  0.0544],
         [ 0.1080,  0.3637, -0.7733,  ...,  0.7831, -0.8321,  0.0020]],

        [[-0.1806,  0.5675, -1.1267,  ...,  1.1764,  0.0713,  0.6914],
         [ 0.0746,  0.4647, -1.5286,  ...,  0.7264,  0.2030,  0.5578],
         [-0.2783,  0.2235, -0.8286,  ...,  0.4618,  0.0872, -0.2953],
         [ 0.7958, -0.2424, -1.0187,  ...,  0.6092, -0.0145,  0.7918],
  

ipdb>  !q_m


tensor([[1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0]], device='cuda:0')


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(84)forward()
     82     ):
     83         bs, q_len, dim = q.size()
---> 84         v, k_len = k, k.size(1)
     85 
     86         h_dim = self.dim//self.n_h



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(86)forward()
     84         v, k_len = k, k.size(1)
     85 
---> 86         h_dim = self.dim//self.n_h
     87 
     88         def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(88)forward()
     86         h_dim = self.dim//self.n_h
     87 
---> 88         def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)
     89 
     90         def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(90)forward()
     88         def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)
     89 
---> 90         def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)
     91 
     92         q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(92)forward()
     90         def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)
     91 
---> 92         q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
     93         k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
     94         v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(93)forward()
     91 
     92         q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
---> 93         k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
     94         v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)
     95 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(94)forward()
     92         q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
     93         k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
---> 94         v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)
     95 
     96         q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(96)forward()
     94         v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)
     95 
---> 96         q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)
     97         sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
     98 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(97)forward()
     95 
     96         q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)
---> 97         sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
     98 
     99         q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(99)forward()
     97         sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
     98 
---> 99         q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
    100         mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
    101 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(100)forward()
     98 
     99         q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
--> 100         mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
    101 
    102         sc = sc.masked_fill(mask == 0, torch.tensor(torch.finfo(sc.dtype).min))  # (bs, n_h, q_len, k_len)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(102)forward()
    100         mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
    101 
--> 102         sc = sc.masked_fill(mask == 0, torch.tensor(torch.finfo(sc.dtype).min))  # (bs, n_h, q_len, k_len)
    103 
    104         w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)



ipdb>  mask.shape


torch.Size([5, 12, 5, 3])


ipdb>  mask[0][0]


tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [0., 0., 0.]], device='cuda:0')


ipdb>  mask[1][0]


tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], device='cuda:0')


ipdb>  mask[2][0]


tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], device='cuda:0')


ipdb>  mask[3][0]


tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')


ipdb>  mask[4][0]


tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')


ipdb>  mask[5][0]


*** IndexError: index 5 is out of bounds for dimension 0 with size 5


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(104)forward()
    102         sc = sc.masked_fill(mask == 0, torch.tensor(torch.finfo(sc.dtype).min))  # (bs, n_h, q_len, k_len)
    103 
--> 104         w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
    105         w = self.dropout(w)  # (bs, n_h, q_len, k_len)
    106 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(105)forward()
    103 
    104         w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
--> 105         w = self.dropout(w)  # (bs, n_h, q_len, k_len)
    106 
    107         o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(107)forward()
    105         w = self.dropout(w)  # (bs, n_h, q_len, k_len)
    106 
--> 107         o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)
    108 
    109         if output_attentions: return (o, w)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(109)forward()
    107         o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)
    108 
--> 109         if output_attentions: return (o, w)
    110         else: return (o,)
    111 



ipdb>  


--Return--
(tensor([[[ 4....iewBackward0>), tensor([[[[0....maxBackward0>))
> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(109)forward()
    107         o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)
    108 
--> 109         if output_attentions: return (o, w)
    110         else: return (o,)
    111 



ipdb>  


--Call--
> /tmp/ipykernel_28793/505101603.py(46)dr_fused()
     44         return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)
     45 
---> 46     def dr_fused(self, embed:torch.Tensor, attention_mask:torch.Tensor):
     47         embed = self.dr_fused_head(embed)
     48         return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)



ipdb>  c


    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



## `OAK002`

In [None]:
#| export
class Encoder002(Encoder):
    
    def __init__(
        self, 
        config:PretrainedConfig, 
    ):
        super().__init__(config)
        self.post_init()

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
        meta_repr = {}
        
        for m_key, m_args in meta_kwargs.items():
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
            
            if len(idx):
                assert torch.all(m_args['data2ptr'][idx] == m_args['data2ptr'].max()), f'All datapoints should have same number of metadata.'
                
                if 'meta_repr' in m_args:
                    m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
                    m_repr_mask = m_repr_mask.bool()
                else:
                    m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
                    m_embed = self.meta_embeddings(m_input_ids)

                m_repr, m_repr_mask = m_embed.view(len(idx), -1, self.config.dim), m_attention_mask.view(len(idx), -1)
                meta_repr[m_key] = self.meta(m_embed, m_attention_mask)
                
                fused_embed, w = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=True)
                embed[idx] += fused_embed
               
        return embed, meta_repr
        

In [None]:
#| export
class OAK002(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.encoder = Encoder002(config)
        self.post_init(); self.remap_post_init();

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        

### Example

In [None]:
model = OAK002.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=True, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_cross_head()
model.encoder.set_meta_embeddings(model.distilbert.embeddings.word_embeddings.weight)

Some weights of OAK002 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_identifier', 'lnk2data_input_text', 'lnk2data_input_ids', 'lnk2data_attention_mask', 'lnk2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx', 'lnk2lbl2data_idx', 'lnk2lbl2data_identifier', 'lnk2lbl2data_input_text', 'lnk2lbl2data_input_ids', 'lnk2lbl2data_attention_mask', 'lnk2lbl2data_data2ptr', 'lnk2lbl2data_lbl2data2ptr'])

In [None]:
o = model(**b.to(model.device))

> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(83)forward()
     81         output_attentions:Optional[bool] = False,
     82     ):
---> 83         bs, q_len, dim = q.size()
     84         v, k_len = k, k.size(1)
     85 



ipdb>  c


In [None]:
o.loss

tensor(0.0132, device='cuda:0', grad_fn=<AddBackward0>)

In [None]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))

In [None]:
o = func()

> /tmp/ipykernel_28793/1795951242.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))



ipdb>  b model.encoder.forward


Breakpoint 8 at /tmp/ipykernel_28793/505101603.py:88


ipdb>  c


> /tmp/ipykernel_28793/505101603.py(97)forward()
     95         **kwargs
     96     ):  
---> 97         data_o = self.encode(data_input_ids, data_attention_mask)
     98 
     99         if data_type is not None and data_type == "meta":



ipdb>  n


> /tmp/ipykernel_28793/505101603.py(99)forward()
     97         data_o = self.encode(data_input_ids, data_attention_mask)
     98 
---> 99         if data_type is not None and data_type == "meta":
    100             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    101         else:



ipdb>  


> /tmp/ipykernel_28793/505101603.py(102)forward()
    100             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    101         else:
--> 102             data_repr = self.dr(data_o[0], data_attention_mask)
    103 
    104         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_28793/505101603.py(104)forward()
    102             data_repr = self.dr(data_o[0], data_attention_mask)
    103 
--> 104         data_fused_repr = meta_repr = None
    105         if data_aug_meta_prefix is not None:
    106             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_28793/505101603.py(105)forward()
    103 
    104         data_fused_repr = meta_repr = None
--> 105         if data_aug_meta_prefix is not None:
    106             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    107             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_28793/505101603.py(106)forward()
    104         data_fused_repr = meta_repr = None
    105         if data_aug_meta_prefix is not None:
--> 106             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    107             if len(meta_kwargs):
    108                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 



ipdb>  n


> /tmp/ipykernel_28793/505101603.py(107)forward()
    105         if data_aug_meta_prefix is not None:
    106             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
--> 107             if len(meta_kwargs):
    108                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    109                                                                             data_attention_mask,



ipdb>  n


> /tmp/ipykernel_28793/505101603.py(108)forward()
    106             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    107             if len(meta_kwargs):
--> 108                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    109                                                                             data_attention_mask,
    110                                                                             meta_kwargs)



ipdb>  


> /tmp/ipykernel_28793/505101603.py(109)forward()
    107             if len(meta_kwargs):
    108                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
--> 109                                                                             data_attention_mask,
    110                                                                             meta_kwargs)
    111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)



ipdb>  


> /tmp/ipykernel_28793/505101603.py(110)forward()
    108                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    109                                                                             data_attention_mask,
--> 110                                                                             meta_kwargs)
    111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    112 



ipdb>  


> /tmp/ipykernel_28793/505101603.py(108)forward()
    106             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    107             if len(meta_kwargs):
--> 108                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    109                                                                             data_attention_mask,
    110                                                                             meta_kwargs)



ipdb>  s


--Call--
> /tmp/ipykernel_28793/219820846.py(11)fuse_meta_into_embeddings()
      9         self.post_init()
     10 
---> 11     def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
     12         meta_repr = {}
     13 



ipdb>  n


> /tmp/ipykernel_28793/219820846.py(12)fuse_meta_into_embeddings()
     10 
     11     def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
---> 12         meta_repr = {}
     13 
     14         for m_key, m_args in meta_kwargs.items():



ipdb>  


> /tmp/ipykernel_28793/219820846.py(14)fuse_meta_into_embeddings()
     12         meta_repr = {}
     13 
---> 14         for m_key, m_args in meta_kwargs.items():
     15             idx = torch.where(m_args['data2ptr'] > 0)[0]
     16             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /tmp/ipykernel_28793/219820846.py(15)fuse_meta_into_embeddings()
     13 
     14         for m_key, m_args in meta_kwargs.items():
---> 15             idx = torch.where(m_args['data2ptr'] > 0)[0]
     16             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     17 



ipdb>  


> /tmp/ipykernel_28793/219820846.py(16)fuse_meta_into_embeddings()
     14         for m_key, m_args in meta_kwargs.items():
     15             idx = torch.where(m_args['data2ptr'] > 0)[0]
---> 16             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     17 
     18             if len(idx):



ipdb>  


> /tmp/ipykernel_28793/219820846.py(18)fuse_meta_into_embeddings()
     16             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     17 
---> 18             if len(idx):
     19                 assert torch.all(m_args['data2ptr'][idx] == m_args['data2ptr'].max()), f'All datapoints should have same number of metadata.'
     20 



ipdb>  


> /tmp/ipykernel_28793/219820846.py(19)fuse_meta_into_embeddings()
     17 
     18             if len(idx):
---> 19                 assert torch.all(m_args['data2ptr'][idx] == m_args['data2ptr'].max()), f'All datapoints should have same number of metadata.'
     20 
     21                 if 'meta_repr' in m_args:



ipdb>  


> /tmp/ipykernel_28793/219820846.py(21)fuse_meta_into_embeddings()
     19                 assert torch.all(m_args['data2ptr'][idx] == m_args['data2ptr'].max()), f'All datapoints should have same number of metadata.'
     20 
---> 21                 if 'meta_repr' in m_args:
     22                     m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
     23                     m_repr_mask = m_repr_mask.bool()



ipdb>  


> /tmp/ipykernel_28793/219820846.py(25)fuse_meta_into_embeddings()
     23                     m_repr_mask = m_repr_mask.bool()
     24                 else:
---> 25                     m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
     26                     m_embed = self.meta_embeddings(m_input_ids)
     27 



ipdb>  


> /tmp/ipykernel_28793/219820846.py(26)fuse_meta_into_embeddings()
     24                 else:
     25                     m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
---> 26                     m_embed = self.meta_embeddings(m_input_ids)
     27 
     28                 m_repr, m_repr_mask = m_embed.view(len(idx), -1, self.config.dim), m_attention_mask.view(len(idx), -1)



ipdb>  


> /tmp/ipykernel_28793/219820846.py(28)fuse_meta_into_embeddings()
     26                     m_embed = self.meta_embeddings(m_input_ids)
     27 
---> 28                 m_repr, m_repr_mask = m_embed.view(len(idx), -1, self.config.dim), m_attention_mask.view(len(idx), -1)
     29                 meta_repr[m_key] = self.meta(m_embed, m_attention_mask)
     30 



ipdb>  m_embed.shape


torch.Size([15, 14, 768])


ipdb>  n


> /tmp/ipykernel_28793/219820846.py(29)fuse_meta_into_embeddings()
     27 
     28                 m_repr, m_repr_mask = m_embed.view(len(idx), -1, self.config.dim), m_attention_mask.view(len(idx), -1)
---> 29                 meta_repr[m_key] = self.meta(m_embed, m_attention_mask)
     30 
     31                 fused_embed, w = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=True)



ipdb>  m_repr.shape


torch.Size([5, 42, 768])


ipdb>  m_repr_mask.shape


torch.Size([5, 42])


ipdb>  m_repr_mask


tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
         0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
         0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
         0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]])


ipdb>  n


> /tmp/ipykernel_28793/219820846.py(31)fuse_meta_into_embeddings()
     29                 meta_repr[m_key] = self.meta(m_embed, m_attention_mask)
     30 
---> 31                 fused_embed, w = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=True)
     32                 embed[idx] += fused_embed
     33 



ipdb>  meta_repr[m_key].shape


torch.Size([15, 768])


ipdb>  m_attention_mask.shape


torch.Size([15, 14])


ipdb>  m_repr_mask.shape


torch.Size([5, 42])


ipdb>  m_repr.shape


torch.Size([5, 42, 768])


ipdb>  embed.shape


torch.Size([5, 5, 768])


ipdb>  n


> /tmp/ipykernel_28793/219820846.py(32)fuse_meta_into_embeddings()
     30 
     31                 fused_embed, w = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=True)
---> 32                 embed[idx] += fused_embed
     33 
     34         return embed, meta_repr



ipdb>  fused_embed.shape


torch.Size([5, 5, 768])


ipdb>  n


> /tmp/ipykernel_28793/219820846.py(14)fuse_meta_into_embeddings()
     12         meta_repr = {}
     13 
---> 14         for m_key, m_args in meta_kwargs.items():
     15             idx = torch.where(m_args['data2ptr'] > 0)[0]
     16             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /tmp/ipykernel_28793/219820846.py(34)fuse_meta_into_embeddings()
     31                 fused_embed, w = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=True)
     32                 embed[idx] += fused_embed
     33 
---> 34         return embed, meta_repr
     35 



ipdb>  


--Return--
(tensor([[[-6....PutBackward0>), {'lnk2data': tensor([[-0.0...DivBackward0>)})
> /tmp/ipykernel_28793/219820846.py(34)fuse_meta_into_embeddings()
     31                 fused_embed, w = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=True)
     32                 embed[idx] += fused_embed
     33 
---> 34         return embed, meta_repr
     35 



ipdb>  


> /tmp/ipykernel_28793/505101603.py(111)forward()
    109                                                                             data_attention_mask,
    110                                                                             meta_kwargs)
--> 111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    112 
    113         return EncoderOutput(



ipdb>  c


> /tmp/ipykernel_28793/505101603.py(113)forward()
    111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    112 
--> 113         return EncoderOutput(
    114             rep=data_repr,
    115             fused_rep=data_fused_repr,



ipdb>  c


> /tmp/ipykernel_28793/505101603.py(114)forward()
    112 
    113         return EncoderOutput(
--> 114             rep=data_repr,
    115             fused_rep=data_fused_repr,
    116             meta_repr=meta_repr,



ipdb>  c


> /tmp/ipykernel_28793/505101603.py(115)forward()
    113         return EncoderOutput(
    114             rep=data_repr,
--> 115             fused_rep=data_fused_repr,
    116             meta_repr=meta_repr,
    117         )



ipdb>  n


> /tmp/ipykernel_28793/505101603.py(116)forward()
    114             rep=data_repr,
    115             fused_rep=data_fused_repr,
--> 116             meta_repr=meta_repr,
    117         )
    118 



ipdb>  


> /tmp/ipykernel_28793/505101603.py(113)forward()
    111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    112 
--> 113         return EncoderOutput(
    114             rep=data_repr,
    115             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...vBackward0>)})
> /tmp/ipykernel_28793/505101603.py(113)forward()
    111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    112 
--> 113         return EncoderOutput(
    114             rep=data_repr,
    115             fused_rep=data_fused_repr,



ipdb>  


--Call--
> <string>(2)__init__()



ipdb>  


> <string>(3)__init__()



ipdb>  


> <string>(4)__init__()



ipdb>  


> <string>(5)__init__()



ipdb>  


> <string>(6)__init__()



ipdb>  c


> /tmp/ipykernel_28793/505101603.py(97)forward()
     95         **kwargs
     96     ):  
---> 97         data_o = self.encode(data_input_ids, data_attention_mask)
     98 
     99         if data_type is not None and data_type == "meta":



ipdb>  


> /tmp/ipykernel_28793/505101603.py(113)forward()
    111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    112 
--> 113         return EncoderOutput(
    114             rep=data_repr,
    115             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_28793/505101603.py(114)forward()
    112 
    113         return EncoderOutput(
--> 114             rep=data_repr,
    115             fused_rep=data_fused_repr,
    116             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_28793/505101603.py(115)forward()
    113         return EncoderOutput(
    114             rep=data_repr,
--> 115             fused_rep=data_fused_repr,
    116             meta_repr=meta_repr,
    117         )



ipdb>  


> /tmp/ipykernel_28793/505101603.py(113)forward()
    111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    112 
--> 113         return EncoderOutput(
    114             rep=data_repr,
    115             fused_rep=data_fused_repr,



ipdb>  


    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



In [None]:
o.loss

tensor(0.0134, grad_fn=<AddBackward0>)

## `OAK003`

In [24]:
#| export
class CrossAttention003(nn.Module):
    
    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config, self.n_h, self.dim = config, config.n_heads, config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        if self.dim % self.n_h != 0:
            raise ValueError(f"self.n_heads: {self.n_h} must divide self.dim: {self.dim} evenly.")
            
        self.q = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.o = nn.Linear(in_features=config.dim, out_features=config.dim)

    def post_init(self):
        torch.nn.init.eye_(self.q.weight)
        torch.nn.init.eye_(self.k.weight)
        torch.nn.init.eye_(self.v.weight)
        torch.nn.init.eye_(self.o.weight)

    def forward(
        self, 
        q: torch.Tensor,
        k: torch.Tensor, 
        output_attentions:Optional[bool] = False,
    ):  
        bsz, dim, v = q.size(0), q.size(1), k
        h_dim = self.dim//self.n_h

        def shape(x: torch.Tensor): return x.view(bsz, -1, self.n_h, h_dim).transpose(1, 2)
        def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bsz, -1, self.n_h * h_dim)

        q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
        k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
        v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)

        y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        o = unshape(y)

        return o.squeeze(1)
        

In [25]:
#| export
class Encoder003(DistilBertPreTrainedModel):
    
    def __init__(
        self, 
        config:PretrainedConfig, 
    ):
        super().__init__(config)
        self.dr_distilbert = DistilBertModel(config)
        self.meta_distilbert = DistilBertModel(config)
        
        self.dr_head = RepresentationHead(config)
        self.cross_head = CrossAttention003(config)
        
        self.post_init()
    
    def dr_encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        o = self.dr_distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        return Pooling.mean_pooling(o[0], attention_mask)

    def meta_encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        o = self.meta_distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        return Pooling.mean_pooling(o[0], attention_mask)

    def dr(self, embed:torch.Tensor):
        return F.normalize(self.dr_head(embed), dim=1)

    def init_meta_encoder(self):
        sd_meta, sd_dr = self.meta_distilbert.state_dict(), self.dr_distilbert.state_dict()
        for k in sd_dr:
            assert sd_meta[k].shape == sd_dr[k].shape
            with torch.no_grad():
                sd_meta[k].copy_(sd_dr[k])

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, meta_kwargs:Dict):
        meta_repr, bsz = {}, embed.size(0)
        
        for m_key, m_args in meta_kwargs.items():
            n_meta, valid_meta_idx = m_args['data2ptr'].max(), torch.where(m_args['data2ptr'] > 0)[0]
            assert torch.all(m_args['data2ptr'][valid_meta_idx] == n_meta), f'All datapoints should have same number of metadata.'
            
            m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']    
            if m_input_ids.shape[0] > 0: m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
            else: m_embed = torch.zeros(size=(0,self.config.dim)).to(m_input_ids.device)
            meta_repr[m_key] = m_embed
                            
            m_embed = m_embed.view(len(valid_meta_idx), -1, self.config.dim)  
            fused_embed = self.cross_head(embed[valid_meta_idx], m_embed)
            embed[valid_meta_idx] = embed[valid_meta_idx] + fused_embed
               
        return embed, meta_repr

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):  
        data_repr = self.dr_encode(data_input_ids, data_attention_mask)
        
        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)
                data_fused_repr = self.dr(data_fused_repr)

        data_repr = F.normalize(data_repr, dim=1)
                
        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,
        )
        

In [26]:
#| export
class OAK003(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.dr_distilbert"]
    _keys_to_ignore_on_load_missing = ["encoder.meta_distilbert"]

    @delegates(OAK000.__init__)
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.encoder = Encoder003(config)
        self.post_init(); self.remap_post_init();

    def init_retrieval_head(self):
        self.encoder.dr_head.post_init()

    def init_meta_encoder(self):
        self.encoder.init_meta_encoder()

    def remap_post_init(self):
        self.distilbert = self.encoder.dr_distilbert
        

### Example

In [None]:
model = OAK003.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=True, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.0, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_meta_encoder()
model.init_cross_head()

Some weights of OAK003 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_identifier', 'lnk2data_input_text', 'lnk2data_input_ids', 'lnk2data_attention_mask', 'lnk2data_data2ptr', 'lnk2lbl2data_idx', 'lnk2lbl2data_identifier', 'lnk2lbl2data_input_text', 'lnk2lbl2data_input_ids', 'lnk2lbl2data_attention_mask', 'lnk2lbl2data_data2ptr', 'lnk2lbl2data_lbl2data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx'])

In [None]:
o = model(**b.to(model.device))

In [None]:
o.loss

tensor(0.0175, device='cuda:0', grad_fn=<AddBackward0>)

In [None]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))

In [None]:
o = func()

> /tmp/ipykernel_7033/1795951242.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))



ipdb>  b model.forward


Breakpoint 6 at /tmp/ipykernel_7033/2804362729.py:140


ipdb>  b model.encoder.forward


Breakpoint 7 at /tmp/ipykernel_7033/1294998010.py:60


ipdb>  c


> /tmp/ipykernel_7033/2804362729.py(155)forward()
    153         **kwargs
    154     ):  
--> 155         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    156 
    157         if self.use_encoder_parallel:



ipdb>  c


> /tmp/ipykernel_7033/2804362729.py(157)forward()
    155         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    156 
--> 157         if self.use_encoder_parallel:
    158             encoder = XCDataParallel(module=self.encoder)
    159         else: encoder = self.encoder



ipdb>  n


> /tmp/ipykernel_7033/2804362729.py(159)forward()
    157         if self.use_encoder_parallel:
    158             encoder = XCDataParallel(module=self.encoder)
--> 159         else: encoder = self.encoder
    160 
    161         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_7033/2804362729.py(161)forward()
    159         else: encoder = self.encoder
    160 
--> 161         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
    162         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    163                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_7033/2804362729.py(162)forward()
    160 
    161         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
--> 162         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    163                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    164 



ipdb>  c


> /tmp/ipykernel_7033/1294998010.py(69)forward()
     67         **kwargs
     68     ):  
---> 69         data_repr = self.dr_encode(data_input_ids, data_attention_mask)
     70 
     71         data_fused_repr = meta_repr = None



ipdb>  n


> /tmp/ipykernel_7033/1294998010.py(71)forward()
     69         data_repr = self.dr_encode(data_input_ids, data_attention_mask)
     70 
---> 71         data_fused_repr = meta_repr = None
     72         if data_aug_meta_prefix is not None:
     73             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  n


> /tmp/ipykernel_7033/1294998010.py(72)forward()
     70 
     71         data_fused_repr = meta_repr = None
---> 72         if data_aug_meta_prefix is not None:
     73             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     74             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_7033/1294998010.py(73)forward()
     71         data_fused_repr = meta_repr = None
     72         if data_aug_meta_prefix is not None:
---> 73             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     74             if len(meta_kwargs):
     75                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)



ipdb>  


> /tmp/ipykernel_7033/1294998010.py(74)forward()
     72         if data_aug_meta_prefix is not None:
     73             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
---> 74             if len(meta_kwargs):
     75                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)
     76                 data_fused_repr = self.dr(data_fused_repr)



ipdb>  


> /tmp/ipykernel_7033/1294998010.py(75)forward()
     73             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     74             if len(meta_kwargs):
---> 75                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)
     76                 data_fused_repr = self.dr(data_fused_repr)
     77 



ipdb>  


> /tmp/ipykernel_7033/1294998010.py(76)forward()
     74             if len(meta_kwargs):
     75                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)
---> 76                 data_fused_repr = self.dr(data_fused_repr)
     77 
     78         data_repr = F.normalize(data_repr, dim=1)



ipdb>  


> /tmp/ipykernel_7033/1294998010.py(78)forward()
     76                 data_fused_repr = self.dr(data_fused_repr)
     77 
---> 78         data_repr = F.normalize(data_repr, dim=1)
     79 
     80         return EncoderOutput(



ipdb>  


> /tmp/ipykernel_7033/1294998010.py(80)forward()
     78         data_repr = F.normalize(data_repr, dim=1)
     79 
---> 80         return EncoderOutput(
     81             rep=data_repr,
     82             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_7033/1294998010.py(81)forward()
     79 
     80         return EncoderOutput(
---> 81             rep=data_repr,
     82             fused_rep=data_fused_repr,
     83             meta_repr=meta_repr,



ipdb>  c


> /tmp/ipykernel_7033/1294998010.py(69)forward()
     67         **kwargs
     68     ):  
---> 69         data_repr = self.dr_encode(data_input_ids, data_attention_mask)
     70 
     71         data_fused_repr = meta_repr = None



ipdb>  n


> /tmp/ipykernel_7033/1294998010.py(71)forward()
     69         data_repr = self.dr_encode(data_input_ids, data_attention_mask)
     70 
---> 71         data_fused_repr = meta_repr = None
     72         if data_aug_meta_prefix is not None:
     73             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  r


--Return--
EncoderOutput...eta_repr=None)
> /tmp/ipykernel_7033/1294998010.py(80)forward()
     78         data_repr = F.normalize(data_repr, dim=1)
     79 
---> 80         return EncoderOutput(
     81             rep=data_repr,
     82             fused_rep=data_fused_repr,



ipdb>  n


> /tmp/ipykernel_7033/2804362729.py(172)forward()
    170                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    171 
--> 172             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
    173                                      plbl2data_data2ptr,plbl2data_idx)
    174 



ipdb>  data_o.fused_rep.norm(dim=1)


tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000], device='cuda:0',
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  lbl2data_o.rep.norm(dim=1)


tensor([1., 1., 1., 1., 1.], device='cuda:0',
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  lbl2data_o.rep.dtype


torch.float32


ipdb>  lbl2data_o.rep.norm(dim=1)


tensor([1., 1., 1., 1., 1.], device='cuda:0',
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  data_o.fused_rep.norm(dim=1)


tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000], device='cuda:0',
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  data_o.fused_rep.dtype


torch.float32


ipdb>  c


> /tmp/ipykernel_7033/2804362729.py(179)forward()
    177                                           plbl2data_data2ptr,plbl2data_idx)
    178 
--> 179             if self.use_calib_loss:
    180                 loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
    181                                               plbl2data_data2ptr,plbl2data_idx)



ipdb>  c


    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



In [None]:
sd1 = model.encoder.dr_distilbert.state_dict()
sd2 = model.encoder.meta_distilbert.state_dict()

## `OAK004`

In [None]:
#| export
class Encoder004(DistilBertPreTrainedModel):
    
    def __init__(
        self, 
        config:PretrainedConfig, 
    ):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.dr_head = RepresentationHead(config)
        self.cross_head = CrossAttention003(config)
        self.post_init()
    
    def encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        o = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        return Pooling.mean_pooling(o[0], attention_mask)

    def dr(self, embed:torch.Tensor):
        return F.normalize(self.dr_head(embed), dim=1)

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, meta_kwargs:Dict):
        meta_repr, bsz = {}, embed.size(0)
        
        for m_key, m_args in meta_kwargs.items():
            n_meta, valid_meta_idx = m_args['data2ptr'].max(), torch.where(m_args['data2ptr'] > 0)[0]
            assert torch.all(m_args['data2ptr'][valid_meta_idx] == n_meta), f'All datapoints should have same number of metadata.'
            
            m_embed = m_args['meta_repr']
            meta_repr[m_key] = m_embed
                    
            m_embed = m_embed.view(len(valid_meta_idx), -1, self.config.dim)  
            fused_embed = self.cross_head(embed[valid_meta_idx], m_embed)
            embed[valid_meta_idx] = embed[valid_meta_idx] + fused_embed
               
        return embed, meta_repr
        
    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):  
        data_repr = self.encode(data_input_ids, data_attention_mask)
        
        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)
                data_fused_repr = self.dr(data_fused_repr)

        data_repr = F.normalize(data_repr, dim=1)
                
        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,
        )
        

In [None]:
#| export
class OAK004(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(self, config, num_metadata:int, **kwargs):
        super().__init__(config, **kwargs)
        self.encoder = Encoder004(config)
        self.meta_embeddings = nn.Embedding(num_metadata, config.dim, sparse=True)
        self.post_init(); self.remap_post_init();

    def init_retrieval_head(self):
        self.encoder.dr_head.post_init()

    def init_meta_embeddings(self):
        torch.nn.init.zeros_(self.meta_embeddings.weight)

    def freeze_meta_embeddings(self):
        self.meta_embeddings.requires_grad_(False)

    def unfreeze_meta_embeddings(self):
        self.meta_embeddings.requires_grad_(True)

    def set_meta_embeddings(self, embed:torch.Tensor):
        with torch.no_grad():
            self.meta_embeddings.weight.copy_(embed)

    def _get_encoder_meta_kwargs(self, feat:str, prefix:str, **kwargs):
        meta_kwargs = Parameters.from_feat_meta_aug_prefix(feat, prefix, **kwargs)
        if f'{prefix}_idx' in meta_kwargs:
            m_idx = meta_kwargs[f'{prefix}_idx']
            if len(m_idx): meta_kwargs[f'{prefix}_meta_repr'] = self.meta_embeddings(m_idx)
        return meta_kwargs

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = self._get_encoder_meta_kwargs('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        
        

### Example

In [None]:
model = OAK004.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=True, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta,

                               use_query_loss=True,
                               
                               use_encoder_parallel=False)

model.init_retrieval_head()
model.init_cross_head()

model.set_meta_embeddings(torch.ones(block.train.dset.meta['lnk_meta'].n_meta, model.config.dim))

Some weights of OAK004 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'meta_embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_data2ptr',
])

In [None]:
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_identifier', 'lnk2data_input_text', 'lnk2data_input_ids', 'lnk2data_attention_mask', 'lnk2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx', 'lnk2lbl2data_idx', 'lnk2lbl2data_identifier', 'lnk2lbl2data_input_text', 'lnk2lbl2data_input_ids', 'lnk2lbl2data_attention_mask', 'lnk2lbl2data_data2ptr', 'lnk2lbl2data_lbl2data2ptr'])

In [None]:
o = model(**b.to(model.device))

In [None]:
o.loss

In [None]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))

In [None]:
o = func()

> /tmp/ipykernel_843/1795951242.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))



ipdb>  b model.forward


Breakpoint 1 at /tmp/ipykernel_843/2142014626.py:39


ipdb>  b model.encoder.forward


Breakpoint 2 at /tmp/ipykernel_843/2257606702.py:41


ipdb>  c


> /tmp/ipykernel_843/2142014626.py(54)forward()
     52         **kwargs
     53     ):  
---> 54         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     55 
     56         if self.use_encoder_parallel:



ipdb>  n


> /tmp/ipykernel_843/2142014626.py(56)forward()
     54         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     55 
---> 56         if self.use_encoder_parallel:
     57             encoder = XCDataParallel(module=self.encoder)
     58         else: encoder = self.encoder



ipdb>  


> /tmp/ipykernel_843/2142014626.py(58)forward()
     56         if self.use_encoder_parallel:
     57             encoder = XCDataParallel(module=self.encoder)
---> 58         else: encoder = self.encoder
     59 
     60         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_843/2142014626.py(60)forward()
     58         else: encoder = self.encoder
     59 
---> 60         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
     61         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     62                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_843/2142014626.py(61)forward()
     59 
     60         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
---> 61         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     62                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     63 



ipdb>  data_meta_kwargs.keys()


dict_keys(['lnk2data_idx', 'lnk2data_data2ptr', 'lnk2data_meta_repr'])


ipdb>  c


> /tmp/ipykernel_843/2257606702.py(50)forward()
     48         **kwargs
     49     ):  
---> 50         data_repr = self.encode(data_input_ids, data_attention_mask)
     51 
     52         data_fused_repr = meta_repr = None



ipdb>  n


> /tmp/ipykernel_843/2257606702.py(52)forward()
     50         data_repr = self.encode(data_input_ids, data_attention_mask)
     51 
---> 52         data_fused_repr = meta_repr = None
     53         if data_aug_meta_prefix is not None:
     54             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  data_repr.shape


torch.Size([5, 768])


ipdb>  n


> /tmp/ipykernel_843/2257606702.py(53)forward()
     51 
     52         data_fused_repr = meta_repr = None
---> 53         if data_aug_meta_prefix is not None:
     54             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     55             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_843/2257606702.py(54)forward()
     52         data_fused_repr = meta_repr = None
     53         if data_aug_meta_prefix is not None:
---> 54             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     55             if len(meta_kwargs):
     56                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)



ipdb>  


> /tmp/ipykernel_843/2257606702.py(55)forward()
     53         if data_aug_meta_prefix is not None:
     54             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
---> 55             if len(meta_kwargs):
     56                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)
     57                 data_fused_repr = self.dr(data_fused_repr)



ipdb>  


> /tmp/ipykernel_843/2257606702.py(56)forward()
     54             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     55             if len(meta_kwargs):
---> 56                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)
     57                 data_fused_repr = self.dr(data_fused_repr)
     58 



ipdb>  s


--Call--
> /tmp/ipykernel_843/2257606702.py(25)fuse_meta_into_embeddings()
     23         return F.normalize(self.dr_head(embed), dim=1)
     24 
---> 25     def fuse_meta_into_embeddings(self, embed:torch.Tensor, meta_kwargs:Dict):
     26         meta_repr, bsz = {}, embed.size(0)
     27 



ipdb>  n


> /tmp/ipykernel_843/2257606702.py(26)fuse_meta_into_embeddings()
     24 
     25     def fuse_meta_into_embeddings(self, embed:torch.Tensor, meta_kwargs:Dict):
---> 26         meta_repr, bsz = {}, embed.size(0)
     27 
     28         for m_key, m_args in meta_kwargs.items():



ipdb>  


> /tmp/ipykernel_843/2257606702.py(28)fuse_meta_into_embeddings()
     26         meta_repr, bsz = {}, embed.size(0)
     27 
---> 28         for m_key, m_args in meta_kwargs.items():
     29             n_meta = m_args['data2ptr'].max()
     30             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'



ipdb>  


> /tmp/ipykernel_843/2257606702.py(29)fuse_meta_into_embeddings()
     27 
     28         for m_key, m_args in meta_kwargs.items():
---> 29             n_meta = m_args['data2ptr'].max()
     30             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
     31 



ipdb>  


> /tmp/ipykernel_843/2257606702.py(30)fuse_meta_into_embeddings()
     28         for m_key, m_args in meta_kwargs.items():
     29             n_meta = m_args['data2ptr'].max()
---> 30             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
     31 
     32             m_embed = m_args['meta_repr']



ipdb>  n_meta


tensor(3, device='cuda:0')


ipdb>  n


> /tmp/ipykernel_843/2257606702.py(32)fuse_meta_into_embeddings()
     30             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
     31 
---> 32             m_embed = m_args['meta_repr']
     33             meta_repr[m_key] = m_embed
     34 



ipdb>  


> /tmp/ipykernel_843/2257606702.py(33)fuse_meta_into_embeddings()
     31 
     32             m_embed = m_args['meta_repr']
---> 33             meta_repr[m_key] = m_embed
     34 
     35             m_embed = m_embed.view(bsz, -1, self.config.dim)



ipdb>  m_embed.shape


torch.Size([15, 768])


ipdb>  n


> /tmp/ipykernel_843/2257606702.py(35)fuse_meta_into_embeddings()
     33             meta_repr[m_key] = m_embed
     34 
---> 35             m_embed = m_embed.view(bsz, -1, self.config.dim)
     36             fused_embed = self.cross_head(embed, m_embed)
     37             embed = embed + fused_embed



ipdb>  


> /tmp/ipykernel_843/2257606702.py(36)fuse_meta_into_embeddings()
     34 
     35             m_embed = m_embed.view(bsz, -1, self.config.dim)
---> 36             fused_embed = self.cross_head(embed, m_embed)
     37             embed = embed + fused_embed
     38 



ipdb>  m_embed.shape


torch.Size([5, 3, 768])


ipdb>  n


> /tmp/ipykernel_843/2257606702.py(37)fuse_meta_into_embeddings()
     35             m_embed = m_embed.view(bsz, -1, self.config.dim)
     36             fused_embed = self.cross_head(embed, m_embed)
---> 37             embed = embed + fused_embed
     38 
     39         return embed, meta_repr



ipdb>  embed.shape


torch.Size([5, 768])


ipdb>  m_embed.shape


torch.Size([5, 3, 768])


ipdb>  n


> /tmp/ipykernel_843/2257606702.py(28)fuse_meta_into_embeddings()
     26         meta_repr, bsz = {}, embed.size(0)
     27 
---> 28         for m_key, m_args in meta_kwargs.items():
     29             n_meta = m_args['data2ptr'].max()
     30             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'



ipdb>  


> /tmp/ipykernel_843/2257606702.py(39)fuse_meta_into_embeddings()
     37             embed = embed + fused_embed
     38 
---> 39         return embed, meta_repr
     40 
2    41     def forward(



ipdb>  


--Return--
(tensor([[ 0.4...AddBackward0>), {'lnk2data': tensor([[1., ...ingBackward0>)})
> /tmp/ipykernel_843/2257606702.py(39)fuse_meta_into_embeddings()
     37             embed = embed + fused_embed
     38 
---> 39         return embed, meta_repr
     40 
2    41     def forward(



ipdb>  


> /tmp/ipykernel_843/2257606702.py(57)forward()
     55             if len(meta_kwargs):
     56                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)
---> 57                 data_fused_repr = self.dr(data_fused_repr)
     58 
     59         data_repr = F.normalize(data_repr, dim=1)



ipdb>  


> /tmp/ipykernel_843/2257606702.py(59)forward()
     57                 data_fused_repr = self.dr(data_fused_repr)
     58 
---> 59         data_repr = F.normalize(data_repr, dim=1)
     60 
     61         return EncoderOutput(



ipdb>  


> /tmp/ipykernel_843/2257606702.py(61)forward()
     59         data_repr = F.normalize(data_repr, dim=1)
     60 
---> 61         return EncoderOutput(
     62             rep=data_repr,
     63             fused_rep=data_fused_repr,



ipdb>  c


> /tmp/ipykernel_843/2257606702.py(50)forward()
     48         **kwargs
     49     ):  
---> 50         data_repr = self.encode(data_input_ids, data_attention_mask)
     51 
     52         data_fused_repr = meta_repr = None



ipdb>  c


  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



## `OAK005`

In [96]:
#| export
class Encoder005(Encoder003):
    
    def __init__(
        self, 
        config:PretrainedConfig, 
    ):
        super().__init__(config)
        self.alpha = nn.Parameter(torch.ones(1))

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, meta_kwargs:Dict):
        meta_repr, bsz = {}, embed.size(0)
        
        for m_key, m_args in meta_kwargs.items():
            n_meta = m_args['data2ptr'].max()
            assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
            
            m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
            m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
            m_embed = m_embed + m_args['meta_repr'] * self.alpha
            
            meta_repr[m_key] = m_embed
                    
            m_embed = m_embed.view(bsz, -1, self.config.dim)  
            fused_embed = self.cross_head(embed, m_embed)
            embed = embed + fused_embed
               
        return embed, meta_repr
        

In [97]:
#| export
class OAK005(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.dr_distilbert"]
    _keys_to_ignore_on_load_missing = ["encoder.meta_distilbert"]

    @delegates(OAK000.__init__)
    def __init__(self, config, num_metadata:int, num_metadata_clusters:int, do_meta_embed_sparse:Optional[bool]=True, **kwargs):
        super().__init__(config, **kwargs)
        self.encoder = Encoder005(config)
        
        self.meta_embeddings = nn.Embedding(num_metadata_clusters, config.dim, sparse=do_meta_embed_sparse)
        self.register_buffer("metadata_cluster_mapping", torch.arange(num_metadata)%num_metadata_clusters, persistent=True)
        
        self.post_init(); self.remap_post_init();

    def init_meta_embeddings(self):
        torch.nn.init.zeros_(self.meta_embeddings.weight)

    def set_meta_embeddings(self, embed:torch.Tensor):
        with torch.no_grad():
            self.meta_embeddings.weight.copy_(embed)

    def set_metadata_cluster_mapping(self, metadata_cluster_mapping:torch.Tensor):
        if metadata_cluster_mapping.shape[0] != self.metadata_cluster_mapping.shape[0]:
            raise ValueError(f'Shape mismatch, `metadata_cluster_mapping` should have {self.metadata_cluster_mapping.shape[0]} elements.')
        with torch.no_grad():
            self.metadata_cluster_mapping.copy_(metadata_cluster_mapping)

    def init_retrieval_head(self):
        self.encoder.dr_head.post_init()

    def init_meta_encoder(self):
        self.encoder.init_meta_encoder()

    def remap_post_init(self):
        self.distilbert = self.encoder.dr_distilbert

    def _get_encoder_meta_kwargs(self, feat:str, prefix:str, **kwargs):
        meta_kwargs = Parameters.from_feat_meta_aug_prefix(feat, prefix, **kwargs)
        if f'{prefix}_idx' in meta_kwargs:
            m_idx = meta_kwargs[f'{prefix}_idx']
            if len(m_idx): meta_kwargs[f'{prefix}_meta_repr'] = self.meta_embeddings(self.metadata_cluster_mapping[m_idx])
        return meta_kwargs

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = self._get_encoder_meta_kwargs('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
                
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        

### Example

In [98]:
from xcai.core import prepare_batch

In [99]:
meta_name = 'lnk'

In [100]:
n_clusters = block.train.dset.meta[f'{meta_name}_meta'].n_meta // 3
meta_cluster_mapping = torch.arange(block.train.dset.meta[f'{meta_name}_meta'].n_meta) % n_clusters

In [101]:
model = OAK005.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix=f'{meta_name}2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,

                               num_metadata=block.train.dset.meta[f'{meta_name}_meta'].n_meta, num_metadata_clusters=n_clusters,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=True, 
                               calib_loss_weight=0.1, use_calib_loss=True,

                               use_query_loss=True,
                               
                               use_encoder_parallel=False, do_meta_embed_sparse=False)
model.init_retrieval_head()
model.init_meta_encoder()
model.init_cross_head()

Some weights of OAK005 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.alpha', 'encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'meta_embeddings.weight', 'metadata_cluster_mapping']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [102]:
model.init_meta_embeddings()
model.set_metadata_cluster_mapping(meta_cluster_mapping)

In [None]:
model = model.to('cuda')

In [103]:
b = prepare_batch(model, batch, m_args=[
    f'p{meta_name}2data_idx', f'p{meta_name}2data_data2ptr', f'{meta_name}2data_idx', f'{meta_name}2data_input_ids', 
    f'{meta_name}2data_attention_mask', f'{meta_name}2data_data2ptr',
])

In [104]:
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_data2ptr', 'lnk2data_identifier', 'lnk2data_input_text', 'lnk2data_input_ids', 'lnk2data_attention_mask', 'plnk2lbl_idx', 'plnk2lbl_lbl2ptr', 'lnk2lbl_idx', 'lnk2lbl_lbl2ptr', 'lnk2lbl_identifier', 'lnk2lbl_input_text', 'lnk2lbl_input_ids', 'lnk2lbl_attention_mask', 'lnk2lbl_data2ptr', 'plnk2lbl_data2ptr'])

In [106]:
o = model(**b.to(model.device))

> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(61)[0;36mforward[0;34m()[0m
[0;32m     59 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     60 [0;31m    ):  
[0m[0;32m---> 61 [0;31m        [0mreturn_dict[0m [0;34m=[0m [0mreturn_dict[0m [0;32mif[0m [0mreturn_dict[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_return_dict[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     62 [0;31m[0;34m[0m[0m
[0m[0;32m     63 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(63)[0;36mforward[0;34m()[0m
[0;32m     61 [0;31m        [0mreturn_dict[0m [0;34m=[0m [0mreturn_dict[0m [0;32mif[0m [0mreturn_dict[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_return_dict[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     62 [0;31m[0;34m[0m[0m
[0m[0;32m---> 63 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m            [0mencoder[0m [0;34m=[0m [0mXCDataParallel[0m[0;34m([0m[0mmodule[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mencoder[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     65 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(65)[0;36mforward[0;34m()[0m
[0;32m     63 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m            [0mencoder[0m [0;34m=[0m [0mXCDataParallel[0m[0;34m([0m[0mmodule[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mencoder[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     66 [0;31m[0;34m[0m[0m
[0m[0;32m     67 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(67)[0;36mforward[0;34m()[0m
[0;32m     65 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     66 [0;31m[0;34m[0m[0m
[0m[0;32m---> 67 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     68 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m     69 [0;31m                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(68)[0;36mforward[0;34m()[0m
[0;32m     66 [0;31m[0;34m[0m[0m
[0m[0;32m     67 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 68 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m     69 [0;31m                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
[0m[0;32m     70 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(69)[0;36mforward[0;34m()[0m
[0;32m     67 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     68 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m---> 69 [0;31m                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
[0m[0;32m     70 [0;31m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(70)[0;36mforward[0;34m()[0m
[0;32m     68 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     69 [0;31m    ):  
[0m[0;32m---> 70 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_encode[0m[0;34m([0m[0mdata_input_ids[0m[0;34m,[0m [0mdata_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m     72 [0;31m        [0mdata_fused_repr[0m [0;34m=[0m [0mmeta_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(72)[0;36mforward[0;34m()[0m
[0;32m     70 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_encode[0m[0;34m([0m[0mdata_input_ids[0m[0;34m,[0m [0mdata_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m---> 72 [0;31m        [0mdata_fused_repr[0m [0;34m=[0m [0mmeta_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m        [0;32mif[0m [0mdata_aug_meta_prefix[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_meta_aug_prefix[0m[0;34m([0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(73)[0;36mforward[0;34m()[0m
[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m     72 [0;31m        [0mdata_fused_repr[0m [0;34m=[0m [0mmeta_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 73 [0;31m        [0;32mif[0m [0mdata_aug_meta_prefix[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_meta_aug_prefix[0m[0;34m([0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(74)[0;36mforward[0;34m()[0m
[0;32m     72 [0;31m        [0mdata_fused_repr[0m [0;34m=[0m [0mmeta_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m        [0;32mif[0m [0mdata_aug_meta_prefix[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 74 [0;31m            [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_meta_aug_prefix[0m[0;34m([0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     76 [0;31m                [0mdata_fused_repr[0m[0;34m,[0m [0mmeta_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mfuse_meta_into_embeddings[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mmeta_kwargs[0m

ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(75)[0;36mforward[0;34m()[0m
[0;32m     73 [0;31m        [0;32mif[0m [0mdata_aug_meta_prefix[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_meta_aug_prefix[0m[0;34m([0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 75 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     76 [0;31m                [0mdata_fused_repr[0m[0;34m,[0m [0mmeta_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mfuse_meta_into_embeddings[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mmeta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0

ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(76)[0;36mforward[0;34m()[0m
[0;32m     74 [0;31m            [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_meta_aug_prefix[0m[0;34m([0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 76 [0;31m                [0mdata_fused_repr[0m[0;34m,[0m [0mmeta_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mfuse_meta_into_embeddings[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mmeta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr[0m[0;34m([0m[0mdata_fused_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     78 [0;31m[0;34m[0m[0m
[0m


ipdb>  b self.fuse_meta_into_embeddings


Breakpoint 4 at /tmp/ipykernel_15122/3934909451.py:11


ipdb>  n


> [0;32m/tmp/ipykernel_15122/3934909451.py[0m(12)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     10 [0;31m[0;34m[0m[0m
[0m[1;31m4[0;32m    11 [0;31m    [0;32mdef[0m [0mfuse_meta_into_embeddings[0m[0;34m([0m[0mself[0m[0;34m,[0m [0membed[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mmeta_kwargs[0m[0;34m:[0m[0mDict[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 12 [0;31m        [0mmeta_repr[0m[0;34m,[0m [0mbsz[0m [0;34m=[0m [0;34m{[0m[0;34m}[0m[0;34m,[0m [0membed[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     13 [0;31m[0;34m[0m[0m
[0m[0;32m     14 [0;31m        [0;32mfor[0m [0mm_key[0m[0;34m,[0m [0mm_args[0m [0;32min[0m [0mmeta_kwargs[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_15122/3934909451.py[0m(14)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     12 [0;31m        [0mmeta_repr[0m[0;34m,[0m [0mbsz[0m [0;34m=[0m [0;34m{[0m[0;34m}[0m[0;34m,[0m [0membed[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     13 [0;31m[0;34m[0m[0m
[0m[0;32m---> 14 [0;31m        [0;32mfor[0m [0mm_key[0m[0;34m,[0m [0mm_args[0m [0;32min[0m [0mmeta_kwargs[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     15 [0;31m            [0mn_meta[0m [0;34m=[0m [0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     16 [0;31m            [0;32massert[0m [0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m==[0m [0mn_meta[0m[0;34m)[0m[0;34m,[0m [0;34mf'All

ipdb>  


> [0;32m/tmp/ipykernel_15122/3934909451.py[0m(15)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     13 [0;31m[0;34m[0m[0m
[0m[0;32m     14 [0;31m        [0;32mfor[0m [0mm_key[0m[0;34m,[0m [0mm_args[0m [0;32min[0m [0mmeta_kwargs[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 15 [0;31m            [0mn_meta[0m [0;34m=[0m [0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     16 [0;31m            [0;32massert[0m [0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m==[0m [0mn_meta[0m[0;34m)[0m[0;34m,[0m [0;34mf'All datapoints should have same number of metadata.'[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/3934909451.py[0m(16)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     14 [0;31m        [0;32mfor[0m [0mm_key[0m[0;34m,[0m [0mm_args[0m [0;32min[0m [0mmeta_kwargs[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     15 [0;31m            [0mn_meta[0m [0;34m=[0m [0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 16 [0;31m            [0;32massert[0m [0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m==[0m [0mn_meta[0m[0;34m)[0m[0;34m,[0m [0;34mf'All datapoints should have same number of metadata.'[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m[0;34m[0m[0m
[0m[0;32m     18 [0;31m            [0mm_input_ids[0m[0;34m,[0m [0mm_attention_mask[0m [0;34m=[0m [0mm_args[0m[0;34m[[0m[0;34m'inpu

ipdb>  


> [0;32m/tmp/ipykernel_15122/3934909451.py[0m(18)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     16 [0;31m            [0;32massert[0m [0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m==[0m [0mn_meta[0m[0;34m)[0m[0;34m,[0m [0;34mf'All datapoints should have same number of metadata.'[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m[0;34m[0m[0m
[0m[0;32m---> 18 [0;31m            [0mm_input_ids[0m[0;34m,[0m [0mm_attention_mask[0m [0;34m=[0m [0mm_args[0m[0;34m[[0m[0;34m'input_ids'[0m[0;34m][0m[0;34m,[0m [0mm_args[0m[0;34m[[0m[0;34m'attention_mask'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     19 [0;31m            [0mm_embed[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mmeta_encode[0m[0;34m([0m[0minput_ids[0m[0;34m=[0m[0mm_input_ids[0m[0;34m,[0m [0mattention_mask[0m[0;34m=[0m[0mm_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0

ipdb>  


> [0;32m/tmp/ipykernel_15122/3934909451.py[0m(19)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     17 [0;31m[0;34m[0m[0m
[0m[0;32m     18 [0;31m            [0mm_input_ids[0m[0;34m,[0m [0mm_attention_mask[0m [0;34m=[0m [0mm_args[0m[0;34m[[0m[0;34m'input_ids'[0m[0;34m][0m[0;34m,[0m [0mm_args[0m[0;34m[[0m[0;34m'attention_mask'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 19 [0;31m            [0mm_embed[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mmeta_encode[0m[0;34m([0m[0minput_ids[0m[0;34m=[0m[0mm_input_ids[0m[0;34m,[0m [0mattention_mask[0m[0;34m=[0m[0mm_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     20 [0;31m            [0mm_embed[0m [0;34m=[0m [0mm_embed[0m [0;34m+[0m [0mm_args[0m[0;34m[[0m[0;34m'meta_repr'[0m[0;34m][0m [0;34m*[0m [0mself[0m[0;34m.[0m[0malpha[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     21 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/3934909451.py[0m(20)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     18 [0;31m            [0mm_input_ids[0m[0;34m,[0m [0mm_attention_mask[0m [0;34m=[0m [0mm_args[0m[0;34m[[0m[0;34m'input_ids'[0m[0;34m][0m[0;34m,[0m [0mm_args[0m[0;34m[[0m[0;34m'attention_mask'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     19 [0;31m            [0mm_embed[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mmeta_encode[0m[0;34m([0m[0minput_ids[0m[0;34m=[0m[0mm_input_ids[0m[0;34m,[0m [0mattention_mask[0m[0;34m=[0m[0mm_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 20 [0;31m            [0mm_embed[0m [0;34m=[0m [0mm_embed[0m [0;34m+[0m [0mm_args[0m[0;34m[[0m[0;34m'meta_repr'[0m[0;34m][0m [0;34m*[0m [0mself[0m[0;34m.[0m[0malpha[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     21 [0;31m[0;34m[0m[0m
[0m[0;32m     22 [0;31m            [0mmeta_repr[0m[0;34m[[0m[0mm_key[

ipdb>  


> [0;32m/tmp/ipykernel_15122/3934909451.py[0m(22)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     20 [0;31m            [0mm_embed[0m [0;34m=[0m [0mm_embed[0m [0;34m+[0m [0mm_args[0m[0;34m[[0m[0;34m'meta_repr'[0m[0;34m][0m [0;34m*[0m [0mself[0m[0;34m.[0m[0malpha[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     21 [0;31m[0;34m[0m[0m
[0m[0;32m---> 22 [0;31m            [0mmeta_repr[0m[0;34m[[0m[0mm_key[0m[0;34m][0m [0;34m=[0m [0mm_embed[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     23 [0;31m[0;34m[0m[0m
[0m[0;32m     24 [0;31m            [0mm_embed[0m [0;34m=[0m [0mm_embed[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdim[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/3934909451.py[0m(24)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     22 [0;31m            [0mmeta_repr[0m[0;34m[[0m[0mm_key[0m[0;34m][0m [0;34m=[0m [0mm_embed[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     23 [0;31m[0;34m[0m[0m
[0m[0;32m---> 24 [0;31m            [0mm_embed[0m [0;34m=[0m [0mm_embed[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdim[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     25 [0;31m            [0mfused_embed[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcross_head[0m[0;34m([0m[0membed[0m[0;34m,[0m [0mm_embed[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m            [0membed[0m [0;34m=[0m [0membed[0m [0;34m+[0m [0mfused_embed[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/3934909451.py[0m(25)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     23 [0;31m[0;34m[0m[0m
[0m[0;32m     24 [0;31m            [0mm_embed[0m [0;34m=[0m [0mm_embed[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdim[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 25 [0;31m            [0mfused_embed[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcross_head[0m[0;34m([0m[0membed[0m[0;34m,[0m [0mm_embed[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m            [0membed[0m [0;34m=[0m [0membed[0m [0;34m+[0m [0mfused_embed[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     27 [0;31m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(70)[0;36mforward[0;34m()[0m
[0;32m     68 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     69 [0;31m    ):  
[0m[0;32m---> 70 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_encode[0m[0;34m([0m[0mdata_input_ids[0m[0;34m,[0m [0mdata_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m     72 [0;31m        [0mdata_fused_repr[0m [0;34m=[0m [0mmeta_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(72)[0;36mforward[0;34m()[0m
[0;32m     70 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_encode[0m[0;34m([0m[0mdata_input_ids[0m[0;34m,[0m [0mdata_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m---> 72 [0;31m        [0mdata_fused_repr[0m [0;34m=[0m [0mmeta_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m        [0;32mif[0m [0mdata_aug_meta_prefix[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_meta_aug_prefix[0m[0;34m([0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(73)[0;36mforward[0;34m()[0m
[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m     72 [0;31m        [0mdata_fused_repr[0m [0;34m=[0m [0mmeta_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 73 [0;31m        [0;32mif[0m [0mdata_aug_meta_prefix[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_meta_aug_prefix[0m[0;34m([0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(79)[0;36mforward[0;34m()[0m
[0;32m     77 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr[0m[0;34m([0m[0mdata_fused_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     78 [0;31m[0;34m[0m[0m
[0m[0;32m---> 79 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mnormalize[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m     81 [0;31m        return EncoderOutput(
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(81)[0;36mforward[0;34m()[0m
[0;32m     79 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mnormalize[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m---> 81 [0;31m        return EncoderOutput(
[0m[0;32m     82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(82)[0;36mforward[0;34m()[0m
[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m     81 [0;31m        return EncoderOutput(
[0m[0;32m---> 82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     84 [0;31m            [0mmeta_repr[0m[0;34m=[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(83)[0;36mforward[0;34m()[0m
[0;32m     81 [0;31m        return EncoderOutput(
[0m[0;32m     82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     84 [0;31m            [0mmeta_repr[0m[0;34m=[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     85 [0;31m        )
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(84)[0;36mforward[0;34m()[0m
[0;32m     82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 84 [0;31m            [0mmeta_repr[0m[0;34m=[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     85 [0;31m        )
[0m[0;32m     86 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(81)[0;36mforward[0;34m()[0m
[0;32m     79 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mnormalize[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m---> 81 [0;31m        return EncoderOutput(
[0m[0;32m     82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


--Return--
EncoderOutput...eta_repr=None)
> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(81)[0;36mforward[0;34m()[0m
[0;32m     79 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mnormalize[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m---> 81 [0;31m        return EncoderOutput(
[0m[0;32m     82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(78)[0;36mforward[0;34m()[0m
[0;32m     76 [0;31m                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m[0;32m---> 78 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m     79 [0;31m                                     plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m


ipdb>  c


In [108]:
o.loss

tensor(0.0167, grad_fn=<AddBackward0>)

In [84]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))
    

In [85]:
o = func()

> [0;32m/tmp/ipykernel_15122/3657616883.py[0m(3)[0;36mfunc[0;34m()[0m
[0;32m      1 [0;31m[0;32mdef[0m [0mfunc[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      2 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 3 [0;31m    [0;32mreturn[0m [0mmodel[0m[0;34m([0m[0;34m**[0m[0mb[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mmodel[0m[0;34m.[0m[0mdevice[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      4 [0;31m[0;34m[0m[0m
[0m


ipdb>  b model.forward


Breakpoint 1 at /tmp/ipykernel_15122/1326464437.py:46


ipdb>  b model.encoder.forward


Breakpoint 2 at /tmp/ipykernel_15122/1379713280.py:61


ipdb>  c


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(61)[0;36mforward[0;34m()[0m
[0;32m     59 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     60 [0;31m    ):  
[0m[0;32m---> 61 [0;31m        [0mreturn_dict[0m [0;34m=[0m [0mreturn_dict[0m [0;32mif[0m [0mreturn_dict[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_return_dict[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     62 [0;31m[0;34m[0m[0m
[0m[0;32m     63 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(63)[0;36mforward[0;34m()[0m
[0;32m     61 [0;31m        [0mreturn_dict[0m [0;34m=[0m [0mreturn_dict[0m [0;32mif[0m [0mreturn_dict[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_return_dict[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     62 [0;31m[0;34m[0m[0m
[0m[0;32m---> 63 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m            [0mencoder[0m [0;34m=[0m [0mXCDataParallel[0m[0;34m([0m[0mmodule[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mencoder[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     65 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(65)[0;36mforward[0;34m()[0m
[0;32m     63 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m            [0mencoder[0m [0;34m=[0m [0mXCDataParallel[0m[0;34m([0m[0mmodule[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mencoder[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     66 [0;31m[0;34m[0m[0m
[0m[0;32m     67 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(67)[0;36mforward[0;34m()[0m
[0;32m     65 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     66 [0;31m[0;34m[0m[0m
[0m[0;32m---> 67 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     68 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m     69 [0;31m                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(68)[0;36mforward[0;34m()[0m
[0;32m     66 [0;31m[0;34m[0m[0m
[0m[0;32m     67 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 68 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m     69 [0;31m                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
[0m[0;32m     70 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(69)[0;36mforward[0;34m()[0m
[0;32m     67 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     68 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m---> 69 [0;31m                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
[0m[0;32m     70 [0;31m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(68)[0;36mforward[0;34m()[0m
[0;32m     66 [0;31m[0;34m[0m[0m
[0m[0;32m     67 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 68 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m     69 [0;31m                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
[0m[0;32m     70 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(69)[0;36mforward[0;34m()[0m
[0;32m     67 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     68 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m---> 69 [0;31m                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
[0m[0;32m     70 [0;31m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(68)[0;36mforward[0;34m()[0m
[0;32m     66 [0;31m[0;34m[0m[0m
[0m[0;32m     67 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 68 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m     69 [0;31m                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
[0m[0;32m     70 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(70)[0;36mforward[0;34m()[0m
[0;32m     68 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     69 [0;31m    ):  
[0m[0;32m---> 70 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_encode[0m[0;34m([0m[0mdata_input_ids[0m[0;34m,[0m [0mdata_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m     72 [0;31m        [0mdata_fused_repr[0m [0;34m=[0m [0mmeta_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  data_input_ids.shape


torch.Size([5, 32])


ipdb>  data_attention_mask.shape


torch.Size([5, 32])


ipdb>  n


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(72)[0;36mforward[0;34m()[0m
[0;32m     70 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_encode[0m[0;34m([0m[0mdata_input_ids[0m[0;34m,[0m [0mdata_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m---> 72 [0;31m        [0mdata_fused_repr[0m [0;34m=[0m [0mmeta_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m        [0;32mif[0m [0mdata_aug_meta_prefix[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_meta_aug_prefix[0m[0;34m([0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  data_repr.shape


torch.Size([5, 768])


ipdb>  n


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(73)[0;36mforward[0;34m()[0m
[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m     72 [0;31m        [0mdata_fused_repr[0m [0;34m=[0m [0mmeta_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 73 [0;31m        [0;32mif[0m [0mdata_aug_meta_prefix[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_meta_aug_prefix[0m[0;34m([0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(74)[0;36mforward[0;34m()[0m
[0;32m     72 [0;31m        [0mdata_fused_repr[0m [0;34m=[0m [0mmeta_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m        [0;32mif[0m [0mdata_aug_meta_prefix[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 74 [0;31m            [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_meta_aug_prefix[0m[0;34m([0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     76 [0;31m                [0mdata_fused_repr[0m[0;34m,[0m [0mmeta_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mfuse_meta_into_embeddings[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mmeta_kwargs[0m

ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(75)[0;36mforward[0;34m()[0m
[0;32m     73 [0;31m        [0;32mif[0m [0mdata_aug_meta_prefix[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_meta_aug_prefix[0m[0;34m([0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 75 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     76 [0;31m                [0mdata_fused_repr[0m[0;34m,[0m [0mmeta_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mfuse_meta_into_embeddings[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mmeta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0

ipdb>  meta_kwargs.keys()


dict_keys(['lnk2data'])


ipdb>  n


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(76)[0;36mforward[0;34m()[0m
[0;32m     74 [0;31m            [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_meta_aug_prefix[0m[0;34m([0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 76 [0;31m                [0mdata_fused_repr[0m[0;34m,[0m [0mmeta_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mfuse_meta_into_embeddings[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mmeta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr[0m[0;34m([0m[0mdata_fused_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     78 [0;31m[0;34m[0m[0m
[0m


ipdb>  meta_kwargs['lnk2data'].keys()


dict_keys(['attention_mask', 'input_ids', 'idx', 'data2ptr', 'meta_repr'])


ipdb>  meta_kwargs['lnk2data']['meta_repr'].shape


torch.Size([15, 768])


ipdb>  b self.fuse_meta_into_embeddings


Breakpoint 3 at /tmp/ipykernel_15122/374209400.py:10


ipdb>  c


> [0;32m/tmp/ipykernel_15122/374209400.py[0m(11)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m      9 [0;31m[0;34m[0m[0m
[0m[1;31m3[0;32m    10 [0;31m    [0;32mdef[0m [0mfuse_meta_into_embeddings[0m[0;34m([0m[0mself[0m[0;34m,[0m [0membed[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mmeta_kwargs[0m[0;34m:[0m[0mDict[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 11 [0;31m        [0mmeta_repr[0m[0;34m,[0m [0mbsz[0m [0;34m=[0m [0;34m{[0m[0;34m}[0m[0;34m,[0m [0membed[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     12 [0;31m[0;34m[0m[0m
[0m[0;32m     13 [0;31m        [0;32mfor[0m [0mm_key[0m[0;34m,[0m [0mm_args[0m [0;32min[0m [0mmeta_kwargs[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_15122/374209400.py[0m(13)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     11 [0;31m        [0mmeta_repr[0m[0;34m,[0m [0mbsz[0m [0;34m=[0m [0;34m{[0m[0;34m}[0m[0;34m,[0m [0membed[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     12 [0;31m[0;34m[0m[0m
[0m[0;32m---> 13 [0;31m        [0;32mfor[0m [0mm_key[0m[0;34m,[0m [0mm_args[0m [0;32min[0m [0mmeta_kwargs[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     14 [0;31m            [0mn_meta[0m [0;34m=[0m [0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     15 [0;31m            [0;32massert[0m [0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m==[0m [0mn_meta[0m[0;34m)[0m[0;34m,[0m [0;34mf'All 

ipdb>  embed.shape


torch.Size([5, 768])


ipdb>  n


> [0;32m/tmp/ipykernel_15122/374209400.py[0m(14)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     12 [0;31m[0;34m[0m[0m
[0m[0;32m     13 [0;31m        [0;32mfor[0m [0mm_key[0m[0;34m,[0m [0mm_args[0m [0;32min[0m [0mmeta_kwargs[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 14 [0;31m            [0mn_meta[0m [0;34m=[0m [0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     15 [0;31m            [0;32massert[0m [0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m==[0m [0mn_meta[0m[0;34m)[0m[0;34m,[0m [0;34mf'All datapoints should have same number of metadata.'[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     16 [0;31m[0;34m[0m[0m
[0m


ipdb>  m_key


'lnk2data'


ipdb>  n


> [0;32m/tmp/ipykernel_15122/374209400.py[0m(15)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     13 [0;31m        [0;32mfor[0m [0mm_key[0m[0;34m,[0m [0mm_args[0m [0;32min[0m [0mmeta_kwargs[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     14 [0;31m            [0mn_meta[0m [0;34m=[0m [0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 15 [0;31m            [0;32massert[0m [0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m==[0m [0mn_meta[0m[0;34m)[0m[0;34m,[0m [0;34mf'All datapoints should have same number of metadata.'[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     16 [0;31m[0;34m[0m[0m
[0m[0;32m     17 [0;31m            [0mm_input_ids[0m[0;34m,[0m [0mm_attention_mask[0m [0;34m=[0m [0mm_args[0m[0;34m[[0m[0;34m'input

ipdb>  n_meta


tensor(3)


ipdb>  n


> [0;32m/tmp/ipykernel_15122/374209400.py[0m(17)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     15 [0;31m            [0;32massert[0m [0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m==[0m [0mn_meta[0m[0;34m)[0m[0;34m,[0m [0;34mf'All datapoints should have same number of metadata.'[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     16 [0;31m[0;34m[0m[0m
[0m[0;32m---> 17 [0;31m            [0mm_input_ids[0m[0;34m,[0m [0mm_attention_mask[0m [0;34m=[0m [0mm_args[0m[0;34m[[0m[0;34m'input_ids'[0m[0;34m][0m[0;34m,[0m [0mm_args[0m[0;34m[[0m[0;34m'attention_mask'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     18 [0;31m            [0mm_embed[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mmeta_encode[0m[0;34m([0m[0minput_ids[0m[0;34m=[0m[0mm_input_ids[0m[0;34m,[0m [0mattention_mask[0m[0;34m=[0m[0mm_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m

ipdb>  


> [0;32m/tmp/ipykernel_15122/374209400.py[0m(18)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     16 [0;31m[0;34m[0m[0m
[0m[0;32m     17 [0;31m            [0mm_input_ids[0m[0;34m,[0m [0mm_attention_mask[0m [0;34m=[0m [0mm_args[0m[0;34m[[0m[0;34m'input_ids'[0m[0;34m][0m[0;34m,[0m [0mm_args[0m[0;34m[[0m[0;34m'attention_mask'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 18 [0;31m            [0mm_embed[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mmeta_encode[0m[0;34m([0m[0minput_ids[0m[0;34m=[0m[0mm_input_ids[0m[0;34m,[0m [0mattention_mask[0m[0;34m=[0m[0mm_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     19 [0;31m            [0mm_embed[0m [0;34m=[0m [0mm_embed[0m [0;34m+[0m [0mm_args[0m[0;34m[[0m[0;34m'meta_repr'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     20 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/374209400.py[0m(19)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     17 [0;31m            [0mm_input_ids[0m[0;34m,[0m [0mm_attention_mask[0m [0;34m=[0m [0mm_args[0m[0;34m[[0m[0;34m'input_ids'[0m[0;34m][0m[0;34m,[0m [0mm_args[0m[0;34m[[0m[0;34m'attention_mask'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     18 [0;31m            [0mm_embed[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mmeta_encode[0m[0;34m([0m[0minput_ids[0m[0;34m=[0m[0mm_input_ids[0m[0;34m,[0m [0mattention_mask[0m[0;34m=[0m[0mm_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 19 [0;31m            [0mm_embed[0m [0;34m=[0m [0mm_embed[0m [0;34m+[0m [0mm_args[0m[0;34m[[0m[0;34m'meta_repr'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     20 [0;31m[0;34m[0m[0m
[0m[0;32m     21 [0;31m            [0mmeta_repr[0m[0;34m[[0m[0mm_key[0m[0;34m][0m [0;34m=[0m [0mm_embed[0m[0;34m[

ipdb>  m_embed.shape


torch.Size([15, 768])


ipdb>  n


> [0;32m/tmp/ipykernel_15122/374209400.py[0m(21)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     19 [0;31m            [0mm_embed[0m [0;34m=[0m [0mm_embed[0m [0;34m+[0m [0mm_args[0m[0;34m[[0m[0;34m'meta_repr'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     20 [0;31m[0;34m[0m[0m
[0m[0;32m---> 21 [0;31m            [0mmeta_repr[0m[0;34m[[0m[0mm_key[0m[0;34m][0m [0;34m=[0m [0mm_embed[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     22 [0;31m[0;34m[0m[0m
[0m[0;32m     23 [0;31m            [0mm_embed[0m [0;34m=[0m [0mm_embed[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdim[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/374209400.py[0m(23)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     21 [0;31m            [0mmeta_repr[0m[0;34m[[0m[0mm_key[0m[0;34m][0m [0;34m=[0m [0mm_embed[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     22 [0;31m[0;34m[0m[0m
[0m[0;32m---> 23 [0;31m            [0mm_embed[0m [0;34m=[0m [0mm_embed[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdim[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m            [0mfused_embed[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcross_head[0m[0;34m([0m[0membed[0m[0;34m,[0m [0mm_embed[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     25 [0;31m            [0membed[0m [0;34m=[0m [0membed[0m [0;34m+[0m [0mfused_embed[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/374209400.py[0m(24)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     22 [0;31m[0;34m[0m[0m
[0m[0;32m     23 [0;31m            [0mm_embed[0m [0;34m=[0m [0mm_embed[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdim[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 24 [0;31m            [0mfused_embed[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcross_head[0m[0;34m([0m[0membed[0m[0;34m,[0m [0mm_embed[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     25 [0;31m            [0membed[0m [0;34m=[0m [0membed[0m [0;34m+[0m [0mfused_embed[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/374209400.py[0m(25)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     23 [0;31m            [0mm_embed[0m [0;34m=[0m [0mm_embed[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mbsz[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdim[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     24 [0;31m            [0mfused_embed[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcross_head[0m[0;34m([0m[0membed[0m[0;34m,[0m [0mm_embed[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 25 [0;31m            [0membed[0m [0;34m=[0m [0membed[0m [0;34m+[0m [0mfused_embed[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m[0;34m[0m[0m
[0m[0;32m     27 [0;31m        [0;32mreturn[0m [0membed[0m[0;34m,[0m [0mmeta_repr[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/374209400.py[0m(13)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     11 [0;31m        [0mmeta_repr[0m[0;34m,[0m [0mbsz[0m [0;34m=[0m [0;34m{[0m[0;34m}[0m[0;34m,[0m [0membed[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     12 [0;31m[0;34m[0m[0m
[0m[0;32m---> 13 [0;31m        [0;32mfor[0m [0mm_key[0m[0;34m,[0m [0mm_args[0m [0;32min[0m [0mmeta_kwargs[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     14 [0;31m            [0mn_meta[0m [0;34m=[0m [0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     15 [0;31m            [0;32massert[0m [0mtorch[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mm_args[0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m==[0m [0mn_meta[0m[0;34m)[0m[0;34m,[0m [0;34mf'All 

ipdb>  embed.shape


torch.Size([5, 768])


ipdb>  n


> [0;32m/tmp/ipykernel_15122/374209400.py[0m(27)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     24 [0;31m            [0mfused_embed[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcross_head[0m[0;34m([0m[0membed[0m[0;34m,[0m [0mm_embed[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     25 [0;31m            [0membed[0m [0;34m=[0m [0membed[0m [0;34m+[0m [0mfused_embed[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m[0;34m[0m[0m
[0m[0;32m---> 27 [0;31m        [0;32mreturn[0m [0membed[0m[0;34m,[0m [0mmeta_repr[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     28 [0;31m[0;34m[0m[0m
[0m


ipdb>  


--Return--
(tensor([[-0.9...AddBackward0>), {'lnk2data': tensor([[-0.2...AddBackward0>)})
> [0;32m/tmp/ipykernel_15122/374209400.py[0m(27)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     24 [0;31m            [0mfused_embed[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcross_head[0m[0;34m([0m[0membed[0m[0;34m,[0m [0mm_embed[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     25 [0;31m            [0membed[0m [0;34m=[0m [0membed[0m [0;34m+[0m [0mfused_embed[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     26 [0;31m[0;34m[0m[0m
[0m[0;32m---> 27 [0;31m        [0;32mreturn[0m [0membed[0m[0;34m,[0m [0mmeta_repr[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     28 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(77)[0;36mforward[0;34m()[0m
[0;32m     75 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     76 [0;31m                [0mdata_fused_repr[0m[0;34m,[0m [0mmeta_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mfuse_meta_into_embeddings[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mmeta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 77 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr[0m[0;34m([0m[0mdata_fused_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     78 [0;31m[0;34m[0m[0m
[0m[0;32m     79 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mnormalize[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(79)[0;36mforward[0;34m()[0m
[0;32m     77 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr[0m[0;34m([0m[0mdata_fused_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     78 [0;31m[0;34m[0m[0m
[0m[0;32m---> 79 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mnormalize[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m     81 [0;31m        return EncoderOutput(
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(81)[0;36mforward[0;34m()[0m
[0;32m     79 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mnormalize[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m---> 81 [0;31m        return EncoderOutput(
[0m[0;32m     82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(82)[0;36mforward[0;34m()[0m
[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m     81 [0;31m        return EncoderOutput(
[0m[0;32m---> 82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     84 [0;31m            [0mmeta_repr[0m[0;34m=[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(83)[0;36mforward[0;34m()[0m
[0;32m     81 [0;31m        return EncoderOutput(
[0m[0;32m     82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     84 [0;31m            [0mmeta_repr[0m[0;34m=[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     85 [0;31m        )
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(84)[0;36mforward[0;34m()[0m
[0;32m     82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 84 [0;31m            [0mmeta_repr[0m[0;34m=[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     85 [0;31m        )
[0m[0;32m     86 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(81)[0;36mforward[0;34m()[0m
[0;32m     79 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mnormalize[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m---> 81 [0;31m        return EncoderOutput(
[0m[0;32m     82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


--Return--
EncoderOutput...dBackward0>)})
> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(81)[0;36mforward[0;34m()[0m
[0;32m     79 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mnormalize[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m---> 81 [0;31m        return EncoderOutput(
[0m[0;32m     82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(72)[0;36mforward[0;34m()[0m
[0;32m     70 [0;31m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m---> 72 [0;31m        [0mloss[0m [0;34m=[0m [0;32mNone[0m[0;34m;[0m [0mlbl2data_o[0m [0;34m=[0m [0mEncoderOutput[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'lbl2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(73)[0;36mforward[0;34m()[0m
[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m     72 [0;31m        [0mloss[0m [0;34m=[0m [0;32mNone[0m[0;34m;[0m [0mlbl2data_o[0m [0;34m=[0m [0mEncoderOutput[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 73 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'lbl2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(74)[0;36mforward[0;34m()[0m
[0;32m     72 [0;31m        [0mloss[0m [0;34m=[0m [0;32mNone[0m[0;34m;[0m [0mlbl2data_o[0m [0;34m=[0m [0mEncoderOutput[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 74 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'lbl2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m     76 [0;31m                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2d

ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(75)[0;36mforward[0;34m()[0m
[0;32m     73 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'lbl2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 75 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m     76 [0;31m                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m


ipdb>  lbl2data_meta_kwargs


{}


ipdb>  n


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(76)[0;36mforward[0;34m()[0m
[0;32m     74 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'lbl2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m---> 76 [0;31m                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m[0;32m     78 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(75)[0;36mforward[0;34m()[0m
[0;32m     73 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'lbl2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 75 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m     76 [0;31m                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(76)[0;36mforward[0;34m()[0m
[0;32m     74 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'lbl2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m---> 76 [0;31m                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m[0;32m     78 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(75)[0;36mforward[0;34m()[0m
[0;32m     73 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'lbl2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 75 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m     76 [0;31m                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(70)[0;36mforward[0;34m()[0m
[0;32m     68 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     69 [0;31m    ):  
[0m[0;32m---> 70 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_encode[0m[0;34m([0m[0mdata_input_ids[0m[0;34m,[0m [0mdata_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m     72 [0;31m        [0mdata_fused_repr[0m [0;34m=[0m [0mmeta_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(72)[0;36mforward[0;34m()[0m
[0;32m     70 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_encode[0m[0;34m([0m[0mdata_input_ids[0m[0;34m,[0m [0mdata_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m---> 72 [0;31m        [0mdata_fused_repr[0m [0;34m=[0m [0mmeta_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m        [0;32mif[0m [0mdata_aug_meta_prefix[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_meta_aug_prefix[0m[0;34m([0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(73)[0;36mforward[0;34m()[0m
[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m     72 [0;31m        [0mdata_fused_repr[0m [0;34m=[0m [0mmeta_repr[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 73 [0;31m        [0;32mif[0m [0mdata_aug_meta_prefix[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m            [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_meta_aug_prefix[0m[0;34m([0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(79)[0;36mforward[0;34m()[0m
[0;32m     77 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr[0m[0;34m([0m[0mdata_fused_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     78 [0;31m[0;34m[0m[0m
[0m[0;32m---> 79 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mnormalize[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m     81 [0;31m        return EncoderOutput(
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(81)[0;36mforward[0;34m()[0m
[0;32m     79 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mnormalize[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m---> 81 [0;31m        return EncoderOutput(
[0m[0;32m     82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(82)[0;36mforward[0;34m()[0m
[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m     81 [0;31m        return EncoderOutput(
[0m[0;32m---> 82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     84 [0;31m            [0mmeta_repr[0m[0;34m=[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(83)[0;36mforward[0;34m()[0m
[0;32m     81 [0;31m        return EncoderOutput(
[0m[0;32m     82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     84 [0;31m            [0mmeta_repr[0m[0;34m=[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     85 [0;31m        )
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(84)[0;36mforward[0;34m()[0m
[0;32m     82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 84 [0;31m            [0mmeta_repr[0m[0;34m=[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     85 [0;31m        )
[0m[0;32m     86 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(81)[0;36mforward[0;34m()[0m
[0;32m     79 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mnormalize[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m---> 81 [0;31m        return EncoderOutput(
[0m[0;32m     82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


--Return--
EncoderOutput...eta_repr=None)
> [0;32m/tmp/ipykernel_15122/1379713280.py[0m(81)[0;36mforward[0;34m()[0m
[0;32m     79 [0;31m        [0mdata_repr[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mnormalize[0m[0;34m([0m[0mdata_repr[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m---> 81 [0;31m        return EncoderOutput(
[0m[0;32m     82 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(78)[0;36mforward[0;34m()[0m
[0;32m     76 [0;31m                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m[0;32m---> 78 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m     79 [0;31m                                     plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(79)[0;36mforward[0;34m()[0m
[0;32m     77 [0;31m[0;34m[0m[0m
[0m[0;32m     78 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m---> 79 [0;31m                                     plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m     81 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(78)[0;36mforward[0;34m()[0m
[0;32m     76 [0;31m                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m[0;32m---> 78 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m     79 [0;31m                                     plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(81)[0;36mforward[0;34m()[0m
[0;32m     79 [0;31m                                     plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m---> 81 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     82 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m     83 [0;31m                                          plbl2data_data2ptr,plbl2data_idx)
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(82)[0;36mforward[0;34m()[0m
[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m     81 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 82 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m     83 [0;31m                                          plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     84 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(83)[0;36mforward[0;34m()[0m
[0;32m     81 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     82 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m---> 83 [0;31m                                          plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     84 [0;31m[0;34m[0m[0m
[0m[0;32m     85 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(82)[0;36mforward[0;34m()[0m
[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m     81 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 82 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m     83 [0;31m                                          plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     84 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(85)[0;36mforward[0;34m()[0m
[0;32m     83 [0;31m                                          plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     84 [0;31m[0;34m[0m[0m
[0m[0;32m---> 85 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     86 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m     87 [0;31m                                              plbl2data_data2ptr,plbl2data_idx)
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(86)[0;36mforward[0;34m()[0m
[0;32m     84 [0;31m[0;34m[0m[0m
[0m[0;32m     85 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 86 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m     87 [0;31m                                              plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     88 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(87)[0;36mforward[0;34m()[0m
[0;32m     85 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     86 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m---> 87 [0;31m                                              plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     88 [0;31m[0;34m[0m[0m
[0m[0;32m     89 [0;31m        [0;32mif[0m [0;32mnot[0m [0mreturn_dict[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(86)[0;36mforward[0;34m()[0m
[0;32m     84 [0;31m[0;34m[0m[0m
[0m[0;32m     85 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 86 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m     87 [0;31m                                              plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     88 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(89)[0;36mforward[0;34m()[0m
[0;32m     87 [0;31m                                              plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     88 [0;31m[0;34m[0m[0m
[0m[0;32m---> 89 [0;31m        [0;32mif[0m [0;32mnot[0m [0mreturn_dict[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     90 [0;31m            [0mo[0m [0;34m=[0m [0;34m([0m[0mdata_o[0m[0;34m.[0m[0mlogits[0m[0;34m,[0m[0mdata_o[0m[0;34m.[0m[0mrep[0m[0;34m,[0m[0mdata_o[0m[0;34m.[0m[0mfused_rep[0m[0;34m,[0m[0mlbl2data_o[0m[0;34m.[0m[0mlogits[0m[0;34m,[0m[0mlbl2data_o[0m[0;34m.[0m[0mrep[0m[0;34m,[0m[0mlbl2data_o[0m[0;34m.[0m[0mfused_rep[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     91 [0;31m            [0;32mreturn[0m [0;34m([0m[0;34m([0m[0mloss[0m[0;34m,[0m[0;34m)[0m [0;34m+[0m [0mo[0m[0;34m)[0m [0;32mif[0m [0mloss[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32me

ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(94)[0;36mforward[0;34m()[0m
[0;32m     92 [0;31m[0;34m[0m[0m
[0m[0;32m     93 [0;31m[0;34m[0m[0m
[0m[0;32m---> 94 [0;31m        return XCModelOutput(
[0m[0;32m     95 [0;31m            [0mloss[0m[0;34m=[0m[0mloss[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     96 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(95)[0;36mforward[0;34m()[0m
[0;32m     93 [0;31m[0;34m[0m[0m
[0m[0;32m     94 [0;31m        return XCModelOutput(
[0m[0;32m---> 95 [0;31m            [0mloss[0m[0;34m=[0m[0mloss[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     96 [0;31m[0;34m[0m[0m
[0m[0;32m     97 [0;31m            [0mdata_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0mrep[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(97)[0;36mforward[0;34m()[0m
[0;32m     95 [0;31m            [0mloss[0m[0;34m=[0m[0mloss[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     96 [0;31m[0;34m[0m[0m
[0m[0;32m---> 97 [0;31m            [0mdata_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0mrep[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     98 [0;31m            [0mdata_fused_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0mfused_rep[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     99 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(98)[0;36mforward[0;34m()[0m
[0;32m     96 [0;31m[0;34m[0m[0m
[0m[0;32m     97 [0;31m            [0mdata_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0mrep[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 98 [0;31m            [0mdata_fused_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0mfused_rep[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     99 [0;31m[0;34m[0m[0m
[0m[0;32m    100 [0;31m            [0mlbl2data_repr[0m[0;34m=[0m[0mlbl2data_o[0m[0;34m.[0m[0mrep[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(100)[0;36mforward[0;34m()[0m
[0;32m     98 [0;31m            [0mdata_fused_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0mfused_rep[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     99 [0;31m[0;34m[0m[0m
[0m[0;32m--> 100 [0;31m            [0mlbl2data_repr[0m[0;34m=[0m[0mlbl2data_o[0m[0;34m.[0m[0mrep[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    101 [0;31m            [0mlbl2data_fused_repr[0m[0;34m=[0m[0mlbl2data_o[0m[0;34m.[0m[0mfused_rep[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    102 [0;31m        )
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(101)[0;36mforward[0;34m()[0m
[0;32m     99 [0;31m[0;34m[0m[0m
[0m[0;32m    100 [0;31m            [0mlbl2data_repr[0m[0;34m=[0m[0mlbl2data_o[0m[0;34m.[0m[0mrep[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 101 [0;31m            [0mlbl2data_fused_repr[0m[0;34m=[0m[0mlbl2data_o[0m[0;34m.[0m[0mfused_rep[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    102 [0;31m        )
[0m[0;32m    103 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_15122/1326464437.py[0m(94)[0;36mforward[0;34m()[0m
[0;32m     92 [0;31m[0;34m[0m[0m
[0m[0;32m     93 [0;31m[0;34m[0m[0m
[0m[0;32m---> 94 [0;31m        return XCModelOutput(
[0m[0;32m     95 [0;31m            [0mloss[0m[0;34m=[0m[0mloss[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     96 [0;31m[0;34m[0m[0m
[0m


ipdb>  c


[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m



In [83]:
o = model(**b.to(model.device))

In [86]:
o.loss

tensor(0.0167, grad_fn=<AddBackward0>)

## `OAK006`

In [None]:
#| export
class Encoder006(Encoder003):
    
    def __init__(
        self, 
        config:PretrainedConfig,
    ):
        super().__init__(config)

    def dr_encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        o = self.dr_distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        return F.normalize(Pooling.mean_pooling(o[0], attention_mask), dim=1)

    def meta_encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        o = self.meta_distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        return F.normalize(Pooling.mean_pooling(o[0], attention_mask), dim=1)

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, meta_kwargs:Dict):
        meta_repr, bsz = {}, embed.size(0)
        
        for m_key, m_args in meta_kwargs.items():
            n_meta = m_args['data2ptr'].max()
            assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
            
            m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
            m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)

            m_alpha = F.sigmoid(m_args['alpha']).unsqueeze(1)
            m_embed = m_alpha * m_embed + (1 - m_alpha) * F.normalize(m_args['meta_repr'], dim=1)
            
            meta_repr[m_key] = m_embed
                    
            m_embed = m_embed.view(bsz, -1, self.config.dim)  
            fused_embed = self.cross_head(embed, m_embed)
            embed = embed + fused_embed
               
        return embed, meta_repr

    def params_from_meta_aug_prefix(self, prefix:str, **kwargs):
        inputs = {}
        args = [arg for arg in kwargs if prefix is not None and re.match(f'^{prefix}.*_(input_ids|attention_mask|data2ptr|meta_repr|idx|alpha)$', arg)]
        for arg in args:
            meta,param = arg.split('_', maxsplit=1)
            inputs.setdefault(meta, {})[param] = kwargs[arg]
        return inputs

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):  
        data_repr = self.dr_encode(data_input_ids, data_attention_mask)
        
        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = self.params_from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)
                data_fused_repr = self.dr(data_fused_repr)
                
        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,
        )
        

In [None]:
#| export
class OAK006(OAK005, DistilBertPreTrainedModel):
    
    @delegates(OAK005.__init__)
    def __init__(self, config, num_metadata_clusters:int, **kwargs):
        super().__init__(config, num_metadata_clusters=num_metadata_clusters, **kwargs)
        self.encoder = Encoder006(config)
        
        self.alpha = nn.Parameter(torch.rand(num_metadata_clusters))
        self.post_init(); self.remap_post_init();

    def _get_encoder_meta_kwargs(self, feat:str, prefix:str, **kwargs):
        meta_kwargs = Parameters.from_feat_meta_aug_prefix(feat, prefix, **kwargs)
        if f'{prefix}_idx' in meta_kwargs:
            m_idx = meta_kwargs[f'{prefix}_idx']
            if len(m_idx): 
                meta_kwargs[f'{prefix}_meta_repr'] = self.meta_embeddings(self.metadata_cluster_mapping[m_idx])
                meta_kwargs[f'{prefix}_alpha'] = self.alpha[self.metadata_cluster_mapping[m_idx]]
        return meta_kwargs


### Example

In [None]:
n_clusters = block.train.dset.meta['lnk_meta'].n_meta // 3
meta_cluster_mapping = torch.arange(block.train.dset.meta['lnk_meta'].n_meta) % n_clusters

In [None]:
model = OAK006.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,

                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, num_metadata_clusters=n_clusters,
                               
                               calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=True,

                               use_query_loss=True,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_meta_encoder()
model.init_cross_head()

Some weights of OAK006 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['alpha', 'encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'meta_embeddings.weight', 'metadata_cluster_mapping']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
model.init_meta_embeddings()
model.set_metadata_cluster_mapping(meta_cluster_mapping)

In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_identifier', 'lnk2data_input_text', 'lnk2data_input_ids', 'lnk2data_attention_mask', 'lnk2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx', 'lnk2lbl2data_idx', 'lnk2lbl2data_identifier', 'lnk2lbl2data_input_text', 'lnk2lbl2data_input_ids', 'lnk2lbl2data_attention_mask', 'lnk2lbl2data_data2ptr', 'lnk2lbl2data_lbl2data2ptr'])

In [None]:
o = model(**b.to(model.device))

In [None]:
o.loss

tensor(0.0615, grad_fn=<AddBackward0>)

In [None]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))

In [None]:
o = func()

> /tmp/ipykernel_6442/1795951242.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))



ipdb>  b model.forward


Breakpoint 1 at /tmp/ipykernel_6442/3737515001.py:46


ipdb>  b model.encoder.forward


Breakpoint 2 at /tmp/ipykernel_6442/3083163642.py:55


ipdb>  c


> /tmp/ipykernel_6442/3737515001.py(61)forward()
     59         **kwargs
     60     ):  
---> 61         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     62 
     63         if self.use_encoder_parallel:



ipdb>  n


> /tmp/ipykernel_6442/3737515001.py(63)forward()
     61         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     62 
---> 63         if self.use_encoder_parallel:
     64             encoder = XCDataParallel(module=self.encoder)
     65         else: encoder = self.encoder



ipdb>  


> /tmp/ipykernel_6442/3737515001.py(65)forward()
     63         if self.use_encoder_parallel:
     64             encoder = XCDataParallel(module=self.encoder)
---> 65         else: encoder = self.encoder
     66 
     67         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_6442/3737515001.py(67)forward()
     65         else: encoder = self.encoder
     66 
---> 67         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
     68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_6442/3737515001.py(68)forward()
     66 
     67         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
---> 68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     70 



ipdb>  data_meta_kwargs.keys()


dict_keys(['lnk2data_attention_mask', 'lnk2data_input_ids', 'lnk2data_idx', 'lnk2data_data2ptr', 'lnk2data_meta_repr', 'lnk2data_alpha'])


ipdb>  self.alpha.shape


torch.Size([218695])


ipdb>  n


> /tmp/ipykernel_6442/3737515001.py(69)forward()
     67         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
     68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     70 
     71 



ipdb>  


> /tmp/ipykernel_6442/3737515001.py(68)forward()
     66 
     67         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
---> 68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     70 



ipdb>  


> /tmp/ipykernel_6442/3737515001.py(69)forward()
     67         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
     68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     70 
     71 



ipdb>  


> /tmp/ipykernel_6442/3737515001.py(68)forward()
     66 
     67         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
---> 68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     70 



ipdb>  


> /tmp/ipykernel_6442/3083163642.py(64)forward()
     62         **kwargs
     63     ):  
---> 64         data_repr = self.dr_encode(data_input_ids, data_attention_mask)
     65 
     66         data_fused_repr = meta_repr = None



ipdb>  n


> /tmp/ipykernel_6442/3083163642.py(66)forward()
     64         data_repr = self.dr_encode(data_input_ids, data_attention_mask)
     65 
---> 66         data_fused_repr = meta_repr = None
     67         if data_aug_meta_prefix is not None:
     68             meta_kwargs = self.params_from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  data_repr.norm(dim=1)


tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  n


> /tmp/ipykernel_6442/3083163642.py(67)forward()
     65 
     66         data_fused_repr = meta_repr = None
---> 67         if data_aug_meta_prefix is not None:
     68             meta_kwargs = self.params_from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     69             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_6442/3083163642.py(68)forward()
     66         data_fused_repr = meta_repr = None
     67         if data_aug_meta_prefix is not None:
---> 68             meta_kwargs = self.params_from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     69             if len(meta_kwargs):
     70                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)



ipdb>  


> /tmp/ipykernel_6442/3083163642.py(69)forward()
     67         if data_aug_meta_prefix is not None:
     68             meta_kwargs = self.params_from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
---> 69             if len(meta_kwargs):
     70                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)
     71                 data_fused_repr = self.dr(data_fused_repr)



ipdb>  meta_kwargs.keys()


dict_keys(['lnk2data'])


ipdb>  n


> /tmp/ipykernel_6442/3083163642.py(70)forward()
     68             meta_kwargs = self.params_from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     69             if len(meta_kwargs):
---> 70                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)
     71                 data_fused_repr = self.dr(data_fused_repr)
     72 



ipdb>  s


--Call--
> /tmp/ipykernel_6442/3083163642.py(26)fuse_meta_into_embeddings()
     24         return F.normalize(Pooling.mean_pooling(o[0], attention_mask), dim=1)
     25 
---> 26     def fuse_meta_into_embeddings(self, embed:torch.Tensor, meta_kwargs:Dict):
     27         meta_repr, bsz = {}, embed.size(0)
     28 



ipdb>  n


> /tmp/ipykernel_6442/3083163642.py(27)fuse_meta_into_embeddings()
     25 
     26     def fuse_meta_into_embeddings(self, embed:torch.Tensor, meta_kwargs:Dict):
---> 27         meta_repr, bsz = {}, embed.size(0)
     28 
     29         for m_key, m_args in meta_kwargs.items():



ipdb>  


> /tmp/ipykernel_6442/3083163642.py(29)fuse_meta_into_embeddings()
     27         meta_repr, bsz = {}, embed.size(0)
     28 
---> 29         for m_key, m_args in meta_kwargs.items():
     30             n_meta = m_args['data2ptr'].max()
     31             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'



ipdb>  


> /tmp/ipykernel_6442/3083163642.py(30)fuse_meta_into_embeddings()
     28 
     29         for m_key, m_args in meta_kwargs.items():
---> 30             n_meta = m_args['data2ptr'].max()
     31             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
     32 



ipdb>  


> /tmp/ipykernel_6442/3083163642.py(31)fuse_meta_into_embeddings()
     29         for m_key, m_args in meta_kwargs.items():
     30             n_meta = m_args['data2ptr'].max()
---> 31             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
     32 
     33             m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']



ipdb>  


> /tmp/ipykernel_6442/3083163642.py(33)fuse_meta_into_embeddings()
     31             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
     32 
---> 33             m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
     34             m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
     35 



ipdb>  


> /tmp/ipykernel_6442/3083163642.py(34)fuse_meta_into_embeddings()
     32 
     33             m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
---> 34             m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
     35 
     36             m_alpha = F.sigmoid(m_args['alpha']).unsqueeze(1)



ipdb>  


> /tmp/ipykernel_6442/3083163642.py(36)fuse_meta_into_embeddings()
     34             m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
     35 
---> 36             m_alpha = F.sigmoid(m_args['alpha']).unsqueeze(1)
     37             m_embed = m_alpha * m_embed + (1 - m_alpha) * F.normalize(m_args['meta_repr'], dim=1)
     38 



ipdb>  m_embed.norm(dim=1)


tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  m_embed.shape


torch.Size([15, 768])


ipdb>  m_args['alpha']


tensor([0.6730, 0.7557, 0.7357, 0.4133, 0.9625, 0.7035, 0.9866, 0.1353, 0.7073,
        0.7709, 0.3240, 0.9944, 0.4719, 0.6533, 0.1304],
       grad_fn=<IndexBackward0>)


ipdb>  n


> /tmp/ipykernel_6442/3083163642.py(37)fuse_meta_into_embeddings()
     35 
     36             m_alpha = F.sigmoid(m_args['alpha']).unsqueeze(1)
---> 37             m_embed = m_alpha * m_embed + (1 - m_alpha) * F.normalize(m_args['meta_repr'], dim=1)
     38 
     39             meta_repr[m_key] = m_embed



ipdb>  m_alpha.shape


torch.Size([15, 1])


ipdb>  m_embed.shape


torch.Size([15, 768])


ipdb>  m_embed


tensor([[ 0.0315,  0.1107, -0.0196,  ..., -0.0062,  0.0025,  0.0325],
        [-0.0350,  0.0437, -0.0851,  ...,  0.0075,  0.0263, -0.0171],
        [-0.0178,  0.0711, -0.0700,  ...,  0.0660,  0.0611,  0.0241],
        ...,
        [ 0.0054,  0.0062,  0.0217,  ..., -0.0415,  0.0254, -0.0445],
        [-0.0017, -0.0252, -0.0213,  ...,  0.0156, -0.0147, -0.0399],
        [-0.0581, -0.0226, -0.0050,  ...,  0.0149,  0.0453, -0.0547]],
       grad_fn=<DivBackward0>)


ipdb>  m_alpha * m_embed


tensor([[ 0.0209,  0.0733, -0.0130,  ..., -0.0041,  0.0016,  0.0215],
        [-0.0238,  0.0297, -0.0579,  ...,  0.0051,  0.0179, -0.0116],
        [-0.0120,  0.0481, -0.0474,  ...,  0.0446,  0.0413,  0.0163],
        ...,
        [ 0.0033,  0.0038,  0.0134,  ..., -0.0255,  0.0156, -0.0274],
        [-0.0011, -0.0166, -0.0140,  ...,  0.0103, -0.0097, -0.0263],
        [-0.0310, -0.0120, -0.0026,  ...,  0.0079,  0.0241, -0.0291]],
       grad_fn=<MulBackward0>)


ipdb>  m_embed[0] * m_alpha[0]


tensor([ 2.0858e-02,  7.3272e-02, -1.2969e-02,  9.0335e-03, -1.8998e-03,
         1.3676e-02,  1.0693e-02,  2.4822e-02, -2.4273e-02,  1.3627e-02,
         1.6462e-02, -1.0414e-02,  2.5757e-02,  5.3670e-03, -1.3406e-02,
         2.2081e-02,  1.7104e-02,  5.8230e-03,  2.6549e-03, -1.8256e-02,
        -2.5436e-02,  2.6339e-02,  1.2694e-03,  1.3380e-02, -7.7870e-03,
         8.4381e-03,  4.5814e-02, -7.3728e-03,  3.2947e-04, -3.0776e-02,
         4.4523e-02,  2.2318e-02,  1.6200e-02,  2.0518e-02, -1.6694e-02,
         2.1503e-02,  1.9884e-02, -2.8900e-02, -1.8003e-02, -1.3380e-02,
        -4.6685e-03, -1.6721e-02,  1.8952e-02, -2.1304e-02, -2.0984e-03,
         1.4539e-02,  3.0005e-02, -2.3299e-02, -2.2655e-02,  3.3433e-02,
        -1.1487e-02, -9.4277e-03, -8.3740e-03, -2.3791e-02, -3.9788e-02,
         1.1913e-02, -3.2841e-02,  1.3988e-02, -2.1339e-02,  2.7591e-02,
         1.0017e-02,  2.6924e-02,  2.4008e-02, -1.6604e-02,  2.1425e-03,
        -4.2652e-02,  8.8554e-03,  2.0540e-02,  2.0

ipdb>  m_embed[1] * m_alpha[1]


tensor([-2.3788e-02,  2.9714e-02, -5.7876e-02, -2.1765e-04,  1.2507e-02,
        -2.3835e-02,  5.5061e-02,  1.3210e-02, -3.4549e-02,  3.4790e-02,
         2.9185e-02,  5.8161e-03,  5.2941e-03, -3.6154e-03, -3.2608e-02,
         1.5085e-02,  1.7241e-02, -1.1665e-02, -1.4718e-02,  2.2050e-04,
        -2.7501e-02,  2.3613e-02, -1.1609e-03,  2.3307e-02,  1.5139e-02,
         2.2802e-02, -7.8086e-03, -9.5263e-03, -2.3886e-02, -1.3995e-03,
         6.6057e-03, -1.3894e-03,  1.8815e-02,  1.5496e-02,  6.6989e-03,
         1.6695e-02,  2.1264e-02,  6.0278e-03, -9.8589e-03,  7.6538e-03,
        -1.5353e-02, -1.4067e-02,  1.4137e-02, -1.4810e-02, -8.5191e-03,
        -2.1247e-02,  1.0426e-02, -1.4058e-02, -1.0089e-02, -1.4398e-03,
         1.1911e-02, -6.1911e-03,  1.8214e-02,  2.3857e-02, -3.6973e-02,
        -5.8674e-02, -2.8984e-02, -3.8586e-02,  2.2215e-02,  1.9510e-02,
        -3.8010e-03, -6.5723e-03, -2.0161e-02,  1.9062e-02, -3.2118e-02,
         4.1047e-03, -2.4865e-02,  6.7011e-02,  2.4

ipdb>  l


     32 
     33             m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
     34             m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
     35 
     36             m_alpha = F.sigmoid(m_args['alpha']).unsqueeze(1)
---> 37             m_embed = m_alpha * m_embed + (1 - m_alpha) * F.normalize(m_args['meta_repr'], dim=1)
     38 
     39             meta_repr[m_key] = m_embed
     40 
     41             m_embed = m_embed.view(bsz, -1, self.config.dim)
     42             fused_embed = self.cross_head(embed, m_embed)



ipdb>  m_args['meta_repr'].norm(dim=1)


tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  (1 - m_alpha)


tensor([[0.3378],
        [0.3196],
        [0.3240],
        [0.3981],
        [0.2764],
        [0.3310],
        [0.2716],
        [0.4662],
        [0.3302],
        [0.3163],
        [0.4197],
        [0.2700],
        [0.3842],
        [0.3422],
        [0.4674]], grad_fn=<RsubBackward1>)


ipdb>  n


> /tmp/ipykernel_6442/3083163642.py(39)fuse_meta_into_embeddings()
     37             m_embed = m_alpha * m_embed + (1 - m_alpha) * F.normalize(m_args['meta_repr'], dim=1)
     38 
---> 39             meta_repr[m_key] = m_embed
     40 
     41             m_embed = m_embed.view(bsz, -1, self.config.dim)



ipdb>  


> /tmp/ipykernel_6442/3083163642.py(41)fuse_meta_into_embeddings()
     39             meta_repr[m_key] = m_embed
     40 
---> 41             m_embed = m_embed.view(bsz, -1, self.config.dim)
     42             fused_embed = self.cross_head(embed, m_embed)
     43             embed = embed + fused_embed



ipdb>  


> /tmp/ipykernel_6442/3083163642.py(42)fuse_meta_into_embeddings()
     40 
     41             m_embed = m_embed.view(bsz, -1, self.config.dim)
---> 42             fused_embed = self.cross_head(embed, m_embed)
     43             embed = embed + fused_embed
     44 



ipdb>  


> /tmp/ipykernel_6442/3083163642.py(43)fuse_meta_into_embeddings()
     41             m_embed = m_embed.view(bsz, -1, self.config.dim)
     42             fused_embed = self.cross_head(embed, m_embed)
---> 43             embed = embed + fused_embed
     44 
     45         return embed, meta_repr



ipdb>  fused_embed.shape


torch.Size([5, 768])


ipdb>  fused_embed.norm(dim=1)


tensor([0.5004, 0.4864, 0.4790, 0.4631, 0.4581],
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  n


> /tmp/ipykernel_6442/3083163642.py(29)fuse_meta_into_embeddings()
     27         meta_repr, bsz = {}, embed.size(0)
     28 
---> 29         for m_key, m_args in meta_kwargs.items():
     30             n_meta = m_args['data2ptr'].max()
     31             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'



ipdb>  embed.norm(dim=1)


tensor([1.2409, 1.3489, 1.1833, 1.3607, 1.1987],
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  n


> /tmp/ipykernel_6442/3083163642.py(45)fuse_meta_into_embeddings()
     43             embed = embed + fused_embed
     44 
---> 45         return embed, meta_repr
     46 
     47     def params_from_meta_aug_prefix(self, prefix:str, **kwargs):



ipdb>  


--Return--
(tensor([[-4.4...AddBackward0>), {'lnk2data': tensor([[ 0.0...AddBackward0>)})
> /tmp/ipykernel_6442/3083163642.py(45)fuse_meta_into_embeddings()
     43             embed = embed + fused_embed
     44 
---> 45         return embed, meta_repr
     46 
     47     def params_from_meta_aug_prefix(self, prefix:str, **kwargs):



ipdb>  


> /tmp/ipykernel_6442/3083163642.py(71)forward()
     69             if len(meta_kwargs):
     70                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)
---> 71                 data_fused_repr = self.dr(data_fused_repr)
     72 
     73         return EncoderOutput(



ipdb>  n


> /tmp/ipykernel_6442/3083163642.py(73)forward()
     71                 data_fused_repr = self.dr(data_fused_repr)
     72 
---> 73         return EncoderOutput(
     74             rep=data_repr,
     75             fused_rep=data_fused_repr,



ipdb>  data_fused_repr.norm(dim=1)


tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  n


> /tmp/ipykernel_6442/3083163642.py(74)forward()
     72 
     73         return EncoderOutput(
---> 74             rep=data_repr,
     75             fused_rep=data_fused_repr,
     76             meta_repr=meta_repr,



ipdb>  c


> /tmp/ipykernel_6442/3083163642.py(64)forward()
     62         **kwargs
     63     ):  
---> 64         data_repr = self.dr_encode(data_input_ids, data_attention_mask)
     65 
     66         data_fused_repr = meta_repr = None



ipdb>  n


> /tmp/ipykernel_6442/3083163642.py(66)forward()
     64         data_repr = self.dr_encode(data_input_ids, data_attention_mask)
     65 
---> 66         data_fused_repr = meta_repr = None
     67         if data_aug_meta_prefix is not None:
     68             meta_kwargs = self.params_from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  c


    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



In [None]:
o = model(**b.to(model.device))

> /tmp/ipykernel_38498/1476390753.py(61)forward()
     59         **kwargs
     60     ):  
---> 61         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     62 
     63         if self.use_encoder_parallel:



ipdb>  c


> /tmp/ipykernel_38498/1294998010.py(69)forward()
     67         **kwargs
     68     ):  
---> 69         data_repr = self.dr_encode(data_input_ids, data_attention_mask)
     70 
     71         data_fused_repr = meta_repr = None



ipdb>  n


> /tmp/ipykernel_38498/1294998010.py(71)forward()
     69         data_repr = self.dr_encode(data_input_ids, data_attention_mask)
     70 
---> 71         data_fused_repr = meta_repr = None
     72         if data_aug_meta_prefix is not None:
     73             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_38498/1294998010.py(72)forward()
     70 
     71         data_fused_repr = meta_repr = None
---> 72         if data_aug_meta_prefix is not None:
     73             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     74             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_38498/1294998010.py(73)forward()
     71         data_fused_repr = meta_repr = None
     72         if data_aug_meta_prefix is not None:
---> 73             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     74             if len(meta_kwargs):
     75                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)



ipdb>  


> /tmp/ipykernel_38498/1294998010.py(74)forward()
     72         if data_aug_meta_prefix is not None:
     73             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
---> 74             if len(meta_kwargs):
     75                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)
     76                 data_fused_repr = self.dr(data_fused_repr)



ipdb>  len(meta_kwargs)


1


ipdb>  n


> /tmp/ipykernel_38498/1294998010.py(75)forward()
     73             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     74             if len(meta_kwargs):
---> 75                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, meta_kwargs)
     76                 data_fused_repr = self.dr(data_fused_repr)
     77 



ipdb>  s


--Call--
> /tmp/ipykernel_38498/1294998010.py(43)fuse_meta_into_embeddings()
     41                 sd_meta[k].copy_(sd_dr[k])
     42 
---> 43     def fuse_meta_into_embeddings(self, embed:torch.Tensor, meta_kwargs:Dict):
     44         meta_repr, bsz = {}, embed.size(0)
     45 



ipdb>  n


> /tmp/ipykernel_38498/1294998010.py(44)fuse_meta_into_embeddings()
     42 
     43     def fuse_meta_into_embeddings(self, embed:torch.Tensor, meta_kwargs:Dict):
---> 44         meta_repr, bsz = {}, embed.size(0)
     45 
     46         for m_key, m_args in meta_kwargs.items():



ipdb>  n


> /tmp/ipykernel_38498/1294998010.py(46)fuse_meta_into_embeddings()
     44         meta_repr, bsz = {}, embed.size(0)
     45 
---> 46         for m_key, m_args in meta_kwargs.items():
     47             n_meta = m_args['data2ptr'].max()
     48             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'



ipdb>  n


> /tmp/ipykernel_38498/1294998010.py(47)fuse_meta_into_embeddings()
     45 
     46         for m_key, m_args in meta_kwargs.items():
---> 47             n_meta = m_args['data2ptr'].max()
     48             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
     49 



ipdb>  m_key


'lnk2data'


ipdb>  n


> /tmp/ipykernel_38498/1294998010.py(48)fuse_meta_into_embeddings()
     46         for m_key, m_args in meta_kwargs.items():
     47             n_meta = m_args['data2ptr'].max()
---> 48             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
     49 
     50             m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']



ipdb>  


> /tmp/ipykernel_38498/1294998010.py(50)fuse_meta_into_embeddings()
     48             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
     49 
---> 50             m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
     51             m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
     52             meta_repr[m_key] = m_embed



ipdb>  


> /tmp/ipykernel_38498/1294998010.py(51)fuse_meta_into_embeddings()
     49 
     50             m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
---> 51             m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
     52             meta_repr[m_key] = m_embed
     53 



ipdb>  m_input_ids.shape


torch.Size([15, 14])


ipdb>  m_attention_mask.shape


torch.Size([15, 14])


ipdb>  q


In [None]:
o.loss

tensor(0.0615, grad_fn=<AddBackward0>)

## `OAK007`

In [None]:
#| export
class OAK007(OAK003, DistilBertPreTrainedModel):
    
    @delegates(OAK003.__init__)
    def __init__(
        self, 
        config,
        n_labels:int,
        n_clusters:int,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.label_embeddings = nn.Embedding(n_clusters, config.dim)
        self.register_buffer("label_cluster_mapping", torch.arange(n_labels)%n_clusters, persistent=True)
        self.post_init(); self.remap_post_init();

    def init_label_embeddings(self):
        torch.nn.init.zeros_(self.label_embeddings.weight)

    def set_label_embeddings(self, embed:torch.Tensor):
        with torch.no_grad():
            self.label_embeddings.weight.copy_(embed)

    def set_label_cluster_mapping(self, label_cluster_mapping:torch.Tensor):
        if self.label_cluster_mapping.shape[0] != label_cluster_mapping.shape[0]:
            raise ValueError(f'Shape mismatch, `label_cluster_mapping` should have {self.label_cluster_mapping.shape[0]} elements.')
        with torch.no_grad():
            self.label_cluster_mapping.copy_(label_cluster_mapping)

    def get_label_representation(
        self,
        data_idx:Optional[torch.Tensor]=None,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask)
        data_o.rep = F.normalize(data_o.rep + self.label_embeddings(self.label_cluster_mapping[data_idx]), dim=1)
        return XCModelOutput(
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
        )

    def forward(
        self,
        data_idx:Optional[torch.Tensor]=None,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)

        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_cluster_mapping[lbl2data_idx]), dim=1)
            
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
                
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        

### Example

In [None]:
model = OAK007.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=1000, num_batch_labels=5000, 
                               margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=True, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               n_labels=block.n_lbl, n_clusters=block.n_lbl//3,hk

                               use_query_loss=True,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_meta_encoder()
model.init_cross_head()

model.init_label_embeddings()
#model.set_label_cluster_mapping(lbl_remap)

Some weights of OAK007 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'label_cluster_mapping', 'label_embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
o = model(**b.to(model.device))

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
o.loss

tensor(0.0610, device='cuda:0', grad_fn=<AddBackward0>)