In [1]:
#| default_exp models.oakX

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
#| export
import torch, re, inspect, pickle, os, torch.nn as nn, math
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any, Union
from transformers import (
    PretrainedConfig,
    DistilBertForMaskedLM,
    DistilBertModel,
    DistilBertPreTrainedModel,
)
from transformers.utils.generic import ModelOutput
from transformers.activations import get_activation

from fastcore.meta import *
from fastcore.utils import *

from xcai.losses import *
from xcai.core import store_attr
from xcai.learner import XCDataParallel
from xcai.models.modeling_utils import *

In [4]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [5]:
from transformers import AutoConfig
from xcai.block import *

In [6]:
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

## Setup

In [7]:
data_dir = '/home/scai/phd/aiz218323/Projects/XC/data/'

block = XCBlock.from_cfg(data_dir, 'data_lnk', transform_type='xcs', tokenizer='distilbert-base-uncased', 
                         sampling_features=[('lbl2data',1), ('lnk2data',3)])



In [8]:
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data-lnk_distilbert-base-uncased_xcs.pkl'

In [9]:
with open(pkl_file, 'wb') as file: pickle.dump(block, file)

In [8]:
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [10]:
block.collator.tfms.tfms[0].sampling_features = [('lbl2data',1), ('lnk2data',3)]

In [11]:
batch = block.train.one_batch(5)
for i,batch in enumerate(block.train.dl):
    if i > 2: break

In [15]:
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_identifier', 'lnk2data_input_text', 'lnk2data_input_ids', 'lnk2data_attention_mask', 'lnk2data_data2ptr', 'lnk2lbl2data_idx', 'lnk2lbl2data_identifier', 'lnk2lbl2data_input_text', 'lnk2lbl2data_input_ids', 'lnk2lbl2data_attention_mask', 'lnk2lbl2data_data2ptr', 'lnk2lbl2data_lbl2data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx'])

## Encoder

In [32]:
#| export
class Encoder(DistilBertPreTrainedModel):
    
    def __init__(
        self, 
        config:PretrainedConfig, 
        num_metadata:int,
        resize_length:Optional[int]=None,
    ):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        
        self.dr_head = RepresentationHead(config)
        self.dr_fused_head = RepresentationHead(config)
        self.meta_head = RepresentationHead(config)
        self.cross_head = CrossAttention(config)

        self.pretrained_meta_embeddings = nn.Embedding(num_metadata, config.dim, sparse=True)
        
        self.ones = torch.ones(resize_length, dtype=torch.long, device=self.device) if resize_length is not None else None
        self.post_init()

    def freeze_pretrained_meta_embeddings(self):
        self.pretrained_meta_embeddings.requires_grad_(False)

    def unfreeze_pretrained_meta_embeddings(self):
        self.pretrained_meta_embeddings.requires_grad_(True)

    def set_pretrained_meta_embeddings(self, embed:torch.Tensor):
        with torch.no_grad():
            self.pretrained_meta_embeddings.weight.copy_(embed)
        
    def get_position_embeddings(self) -> nn.Embedding:
        return self.distilbert.get_position_embeddings()
    
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)
    
    def encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        return self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
    
    def dr(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.dr_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def dr_fused(self, embed:torch.Tensor):
        embed = self.dr_fused_head(embed)
        return F.normalize(embed, dim=1)

    def meta(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.meta_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)
    
    def meta_unnormalized(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.meta_head(embed)
        return Pooling.mean_pooling(embed, attention_mask)

    def resize(self, idx:torch.Tensor, rep:torch.Tensor, num_inputs:torch.Tensor):
        if torch.any(num_inputs == 0): raise ValueError("`num_inputs` should be non-zero positive integer.")
        bsz, total_num_inputs = num_inputs.shape[0], idx.shape[0]
        
        self.ones = self.ones.to(idx.device)
        ones = (
            torch.ones(total_num_inputs, dtype=torch.long, device=idx.device) 
            if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
        )
        
        max_num_inputs = num_inputs.max()
        if (num_inputs == max_num_inputs).all():
            return idx,rep,ones
        
        xnum_inputs = max_num_inputs-num_inputs+1
        
        inputs_ptr = num_inputs.cumsum(dim=0)-1
        repeat_inputs = ones.scatter(0, inputs_ptr, xnum_inputs)
        
        resized_idx = idx.repeat_interleave(repeat_inputs, dim=0)
        resized_rep = rep.repeat_interleave(repeat_inputs, dim=0)
        
        ignore_mask = ones.scatter(0, inputs_ptr, 0).repeat_interleave(repeat_inputs, dim=0).view(bsz, -1)
        ignore_mask[:, -1] = 1; ignore_mask = ignore_mask.flatten()
        
        return resized_idx,resized_rep,ignore_mask
        

    def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
        meta_repr = {}
        
        data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
        for m_key, m_args in meta_kwargs.items():
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
            
            if len(idx):
                m_idx,m_repr,m_repr_mask = self.resize(m_args['idx'], m_args['meta_repr'], m_args['data2ptr'][idx])
                m_repr = F.normalize(m_repr + self.pretrained_meta_embeddings(m_idx), dim=1)
                
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
                meta_repr[m_key] = m_repr[m_repr_mask]
                
                fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
                data_fused_repr[idx] += fused_repr
                
        return data_fused_repr.squeeze(), meta_repr

    
    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):  
        data_o = self.encode(data_input_ids, data_attention_mask)
        
        if data_type is not None and data_type == "meta":
            data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
        else: 
            data_repr = self.dr(data_o[0], data_attention_mask)
        
        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
                                                                            torch.any(data_attention_mask, dim=1), 
                                                                            meta_kwargs)
                data_fused_repr = self.dr_fused(data_fused_repr)
                
        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,
        )
        

## `OAK000`

In [109]:
#| export
class OAK000(nn.Module):
    
    def __init__(
        self, config,
        
        data_aug_meta_prefix:Optional[str]=None, 
        lbl2data_aug_meta_prefix:Optional[str]=None, 

        data_pred_meta_prefix:Optional[str]=None,
        lbl2data_pred_meta_prefix:Optional[str]=None,
        
        num_batch_labels:Optional[int]=None, 
        batch_size:Optional[int]=None,
        margin:Optional[float]=0.3,
        num_negatives:Optional[int]=5,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=True,

        calib_margin:Optional[float]=0.3,
        calib_num_negatives:Optional[int]=10,
        calib_tau:Optional[float]=0.1,
        calib_apply_softmax:Optional[bool]=False,
        calib_loss_weight:Optional[float]=0.1,
        use_calib_loss:Optional[float]=False,
        
        meta_loss_weight:Optional[Union[List,float]]=0.3,
        
        use_fusion_loss:Optional[bool]=False,
        fusion_loss_weight:Optional[float]=0.15,

        use_query_loss:Optional[float]=True,
        
        use_encoder_parallel:Optional[bool]=True,
    ):
        super().__init__(config)
        store_attr('meta_loss_weight,fusion_loss_weight,calib_loss_weight')
        store_attr('data_pred_meta_prefix,lbl2data_pred_meta_prefix')
        store_attr('data_aug_meta_prefix,lbl2data_aug_meta_prefix')
        store_attr('use_fusion_loss,use_query_loss,use_calib_loss,use_encoder_parallel')
        
        self.rep_loss_fn = MultiTriplet(bsz=batch_size, tn_targ=num_batch_labels, margin=margin, n_negatives=num_negatives, 
                                        tau=tau, apply_softmax=apply_softmax, reduce='mean')
        self.cab_loss_fn = Calibration(margin=calib_margin, tau=calib_tau, n_negatives=calib_num_negatives, 
                                       apply_softmax=calib_apply_softmax, reduce='mean')
        self.encoder, self.meta_embeddings = None, None
        
    def init_meta_embeddings(self):
        torch.nn.init.zeros_(self.meta_embeddings.weight)

    def freeze_meta_embeddings(self):
        self.meta_embeddings.requires_grad_(False)

    def unfreeze_meta_embeddings(self):
        self.meta_embeddings.requires_grad_(True)

    def set_meta_embeddings(self, embed:torch.Tensor):
        with torch.no_grad():
            self.meta_embeddings.weight.copy_(embed)
    
    def init_retrieval_head(self):
        if self.encoder is None: raise ValueError('`self.encoder` is not initialized.')
        self.encoder.dr_head.post_init()
        self.encoder.meta_head.post_init()
        self.encoder.dr_fused_head.post_init()

    def init_cross_head(self):
        if self.encoder is None: raise ValueError('`self.encoder` is not initialized.')
        self.encoder.cross_head.post_init()

    def compute_loss(self, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        return self.rep_loss_fn(inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def calibration_loss(self, einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        return self.calib_loss_weight * self.cab_loss_fn(einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)
    
    def compute_meta_loss(self, data_repr, lbl2data_repr, **kwargs):
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
            
        data_meta_inputs = Parameters.from_meta_pred_prefix(self.data_pred_meta_prefix, **kwargs)
        lbl2data_meta_inputs = Parameters.from_meta_pred_prefix(self.lbl2data_pred_meta_prefix, **kwargs)
        meta_inputs = {**data_meta_inputs, **lbl2data_meta_inputs}

        m_lw = Parameters.get_meta_loss_weights(self.meta_loss_weight, len(meta_inputs)) if len(meta_inputs) else []
        
        loss = 0.0
        for inputs,lw in zip(meta_inputs.values(), m_lw):
            if 'lbl2data2ptr' in inputs:
                idx = torch.where(inputs['lbl2data2ptr'])[0]
                if len(idx) > 0:
                    inputs_o = encoder(data_input_ids=inputs['input_ids'], data_attention_mask=inputs['attention_mask'], 
                                       data_type="meta")
                    m_loss = self.rep_loss_fn(lbl2data_repr[idx], inputs_o.rep, inputs['lbl2data2ptr'][idx],
                                              inputs['idx'], inputs['plbl2data2ptr'][idx], inputs['pidx'])
                    loss += lw * m_loss

            elif 'data2ptr' in inputs:
                idx = torch.where(inputs['data2ptr'])[0]
                if len(idx) > 0:
                    inputs_o = encoder(data_input_ids=inputs['input_ids'], data_attention_mask=inputs['attention_mask'], 
                                       data_type="meta")
                    m_loss = self.rep_loss_fn(data_repr[idx], inputs_o.rep, inputs['data2ptr'][idx], inputs['idx'], 
                                              inputs['pdata2ptr'][idx], inputs['pidx'])
                    loss += lw * m_loss       

            else: raise ValueError('Invalid metadata input arguments.')
        return loss

    def compute_fusion_loss(self, data_repr, meta_repr:Dict, prefix:str, **kwargs):
        meta_inputs = Parameters.from_meta_pred_prefix(prefix, **kwargs)
        
        loss = 0.0
        if meta_repr is not None:
            for key,input_repr in meta_repr.items():
                inputs = meta_inputs[key]
                if 'lbl2data2ptr' in inputs:
                    idx = torch.where(inputs['lbl2data2ptr'])[0]
                    if len(idx) > 0:
                        m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['lbl2data2ptr'][idx],
                                                  inputs['idx'], inputs['plbl2data2ptr'][idx], inputs['pidx'])
                        loss += self.fusion_loss_weight * m_loss
    
                elif 'data2ptr' in inputs:
                    idx = torch.where(inputs['data2ptr'])[0]
                    if len(idx) > 0:
                        m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['data2ptr'][idx], inputs['idx'], 
                                                  inputs['pdata2ptr'][idx], inputs['pidx'])
                        loss += self.fusion_loss_weight * m_loss       
    
                else: raise ValueError('Invalid metadata input arguments.')
        return loss


    def get_meta_representation(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
            
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_unnormalized=True, data_type="meta")
        return XCModelOutput(
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
        )

    def _get_encoder_meta_kwargs(self, feat:str, prefix:str, **kwargs):
        meta_kwargs = Parameters.from_feat_meta_aug_prefix(feat, prefix, **kwargs)
        return meta_kwargs
    
        
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = self._get_encoder_meta_kwargs('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
            
            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
            
            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
            
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        

## `OAK001`

In [110]:
#| export
class OAK001(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(
        self, 
        config,
        num_metadata:int,
        resize_length:Optional[int]=None,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.meta_embeddings = nn.Embedding(num_metadata, config.dim, sparse=True)
        self.encoder = Encoder(config, num_metadata=num_metadata, resize_length=resize_length)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def _get_encoder_meta_kwargs(self, feat:str, prefix:str, **kwargs):
        meta_kwargs = Parameters.from_feat_meta_aug_prefix(feat, prefix, **kwargs)
        if f'{prefix}_idx' in meta_kwargs:
            m_idx = meta_kwargs[f'{prefix}_idx']
            if len(m_idx): meta_kwargs[f'{prefix}_meta_repr'] = self.meta_embeddings(m_idx)
        return meta_kwargs

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        

### Example

In [111]:
model = OAK001.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=True,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()

Some weights of OAK001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 

In [112]:
model.encoder.set_pretrained_meta_embeddings(torch.ones(block.train.dset.meta['lnk_meta'].n_meta, model.config.dim))
model.encoder.freeze_pretrained_meta_embeddings()

In [113]:
model = model.to('cuda')

In [114]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [83]:
o = model(**b.to(model.device))

In [84]:
o.loss

tensor(0.0183, device='cuda:0', grad_fn=<AddBackward0>)

In [115]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))
    

In [116]:
func()

> /tmp/ipykernel_36220/3657616883.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))
      4 



ipdb>  b model.forward


Breakpoint 1 at /tmp/ipykernel_36220/1449976710.py:158


ipdb>  r


> /tmp/ipykernel_36220/1449976710.py(173)forward()
    171         **kwargs
    172     ):  
--> 173         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    174 
    175         if self.use_encoder_parallel:



ipdb>  n


> /tmp/ipykernel_36220/1449976710.py(175)forward()
    173         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    174 
--> 175         if self.use_encoder_parallel:
    176             encoder = XCDataParallel(module=self.encoder)
    177         else: encoder = self.encoder



ipdb>  n


> /tmp/ipykernel_36220/1449976710.py(177)forward()
    175         if self.use_encoder_parallel:
    176             encoder = XCDataParallel(module=self.encoder)
--> 177         else: encoder = self.encoder
    178 
    179         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)



ipdb>  n


> /tmp/ipykernel_36220/1449976710.py(179)forward()
    177         else: encoder = self.encoder
    178 
--> 179         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
    180         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    181                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(180)forward()
    178 
    179         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
--> 180         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    181                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    182 



ipdb>  data_meta_kwargs.shape


*** AttributeError: 'dict' object has no attribute 'shape'


ipdb>  data_meta_kwargs.keys()


dict_keys(['lnk2data_attention_mask', 'lnk2data_input_ids', 'lnk2data_idx', 'lnk2data_data2ptr', 'lnk2data_meta_repr'])


ipdb>  data_meta_kwargs['lnk2data_data2ptr'].shape


torch.Size([5])


ipdb>  data_meta_kwargs['lnk2data_data2ptr']


tensor([3, 3, 3, 3, 3], device='cuda:0')


ipdb>  l


    175         if self.use_encoder_parallel:
    176             encoder = XCDataParallel(module=self.encoder)
    177         else: encoder = self.encoder
    178 
    179         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
--> 180         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    181                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    182 
    183 
    184         loss = None; lbl2data_o = EncoderOutput()
    185         if lbl2data_input_ids is not None:



ipdb>  b encoder.forward


Breakpoint 2 at /tmp/ipykernel_36220/179234550.py:111


ipdb>  n


> /tmp/ipykernel_36220/1449976710.py(181)forward()
    179         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
    180         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
--> 181                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    182 
    183 



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(180)forward()
    178 
    179         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
--> 180         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    181                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    182 



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(181)forward()
    179         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
    180         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
--> 181                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    182 
    183 



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(180)forward()
    178 
    179         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
--> 180         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    181                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    182 



ipdb>  n


> /tmp/ipykernel_36220/179234550.py(120)forward()
    118         **kwargs
    119     ):  
--> 120         data_o = self.encode(data_input_ids, data_attention_mask)
    121 
    122         if data_type is not None and data_type == "meta":



ipdb>  n


> /tmp/ipykernel_36220/179234550.py(122)forward()
    120         data_o = self.encode(data_input_ids, data_attention_mask)
    121 
--> 122         if data_type is not None and data_type == "meta":
    123             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    124         else:



ipdb>  data_type
ipdb>  n


> /tmp/ipykernel_36220/179234550.py(125)forward()
    123             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    124         else:
--> 125             data_repr = self.dr(data_o[0], data_attention_mask)
    126 
    127         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_36220/179234550.py(127)forward()
    125             data_repr = self.dr(data_o[0], data_attention_mask)
    126 
--> 127         data_fused_repr = meta_repr = None
    128         if data_aug_meta_prefix is not None:
    129             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_36220/179234550.py(128)forward()
    126 
    127         data_fused_repr = meta_repr = None
--> 128         if data_aug_meta_prefix is not None:
    129             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    130             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_36220/179234550.py(129)forward()
    127         data_fused_repr = meta_repr = None
    128         if data_aug_meta_prefix is not None:
--> 129             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    130             if len(meta_kwargs):
    131                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 



ipdb>  


> /tmp/ipykernel_36220/179234550.py(130)forward()
    128         if data_aug_meta_prefix is not None:
    129             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
--> 130             if len(meta_kwargs):
    131                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    132                                                                             torch.any(data_attention_mask, dim=1),



ipdb>  meta_kwargs.keys()


dict_keys(['lnk2data'])


ipdb>  meta_kwargs['lnk2data'].keys()


dict_keys(['attention_mask', 'input_ids', 'idx', 'data2ptr', 'meta_repr'])


ipdb>  meta_kwargs['lnk2data']['meta_repr'].shape


torch.Size([15, 768])


ipdb>  n


> /tmp/ipykernel_36220/179234550.py(131)forward()
    129             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    130             if len(meta_kwargs):
--> 131                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    132                                                                             torch.any(data_attention_mask, dim=1),
    133                                                                             meta_kwargs)



ipdb>  n


> /tmp/ipykernel_36220/179234550.py(132)forward()
    130             if len(meta_kwargs):
    131                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
--> 132                                                                             torch.any(data_attention_mask, dim=1),
    133                                                                             meta_kwargs)
    134                 data_fused_repr = self.dr_fused(data_fused_repr)



ipdb>  


> /tmp/ipykernel_36220/179234550.py(133)forward()
    131                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    132                                                                             torch.any(data_attention_mask, dim=1),
--> 133                                                                             meta_kwargs)
    134                 data_fused_repr = self.dr_fused(data_fused_repr)
    135 



ipdb>  


> /tmp/ipykernel_36220/179234550.py(131)forward()
    129             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    130             if len(meta_kwargs):
--> 131                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    132                                                                             torch.any(data_attention_mask, dim=1),
    133                                                                             meta_kwargs)



ipdb>  s


--Call--
> /tmp/ipykernel_36220/179234550.py(90)fuse_meta_into_embeddings()
     88 
     89 
---> 90     def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
     91         meta_repr = {}
     92 



ipdb>  n


> /tmp/ipykernel_36220/179234550.py(91)fuse_meta_into_embeddings()
     89 
     90     def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
---> 91         meta_repr = {}
     92 
     93         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)



ipdb>  


> /tmp/ipykernel_36220/179234550.py(93)fuse_meta_into_embeddings()
     91         meta_repr = {}
     92 
---> 93         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
     94         for m_key, m_args in meta_kwargs.items():
     95             idx = torch.where(m_args['data2ptr'] > 0)[0]



ipdb>  


> /tmp/ipykernel_36220/179234550.py(94)fuse_meta_into_embeddings()
     92 
     93         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
---> 94         for m_key, m_args in meta_kwargs.items():
     95             idx = torch.where(m_args['data2ptr'] > 0)[0]
     96             meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)



ipdb>  data_fused_repr.shape


torch.Size([5, 1, 768])


ipdb>  n


> /tmp/ipykernel_36220/179234550.py(95)fuse_meta_into_embeddings()
     93         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
     94         for m_key, m_args in meta_kwargs.items():
---> 95             idx = torch.where(m_args['data2ptr'] > 0)[0]
     96             meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
     97 



ipdb>  


> /tmp/ipykernel_36220/179234550.py(96)fuse_meta_into_embeddings()
     94         for m_key, m_args in meta_kwargs.items():
     95             idx = torch.where(m_args['data2ptr'] > 0)[0]
---> 96             meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
     97 
     98             if len(idx):



ipdb>  idx


tensor([0, 1, 2, 3, 4], device='cuda:0')


ipdb>  n


> /tmp/ipykernel_36220/179234550.py(98)fuse_meta_into_embeddings()
     96             meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
     97 
---> 98             if len(idx):
     99                 m_idx,m_repr,m_repr_mask = self.resize(m_args['idx'], m_args['meta_repr'], m_args['data2ptr'][idx])
    100                 m_repr = F.normalize(m_repr + self.pretrained_meta_embeddings(m_idx), dim=1)



ipdb>  


> /tmp/ipykernel_36220/179234550.py(99)fuse_meta_into_embeddings()
     97 
     98             if len(idx):
---> 99                 m_idx,m_repr,m_repr_mask = self.resize(m_args['idx'], m_args['meta_repr'], m_args['data2ptr'][idx])
    100                 m_repr = F.normalize(m_repr + self.pretrained_meta_embeddings(m_idx), dim=1)
    101 



ipdb>  


> /tmp/ipykernel_36220/179234550.py(100)fuse_meta_into_embeddings()
     98             if len(idx):
     99                 m_idx,m_repr,m_repr_mask = self.resize(m_args['idx'], m_args['meta_repr'], m_args['data2ptr'][idx])
--> 100                 m_repr = F.normalize(m_repr + self.pretrained_meta_embeddings(m_idx), dim=1)
    101 
    102                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)



ipdb>  m_idx


tensor([117344, 161862, 161861, 487671, 309544, 129374, 102557, 102556,  54423,
        559742,  79395, 102815,  84877,  84874,  84871], device='cuda:0')


ipdb>  n


> /tmp/ipykernel_36220/179234550.py(102)fuse_meta_into_embeddings()
    100                 m_repr = F.normalize(m_repr + self.pretrained_meta_embeddings(m_idx), dim=1)
    101 
--> 102                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
    103                 meta_repr[m_key] = m_repr[m_repr_mask]
    104 



ipdb>  n


> /tmp/ipykernel_36220/179234550.py(103)fuse_meta_into_embeddings()
    101 
    102                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
--> 103                 meta_repr[m_key] = m_repr[m_repr_mask]
    104 
    105                 fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]



ipdb>  


> /tmp/ipykernel_36220/179234550.py(105)fuse_meta_into_embeddings()
    103                 meta_repr[m_key] = m_repr[m_repr_mask]
    104 
--> 105                 fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
    106                 data_fused_repr[idx] += fused_repr
    107 



ipdb>  data_fused_repr[idx].shape


torch.Size([5, 1, 768])


ipdb>  data_mask[idx].shape


torch.Size([5, 1])


ipdb>  m_repr


tensor([[[0.0361, 0.0361, 0.0361,  ..., 0.0361, 0.0361, 0.0361],
         [0.0361, 0.0361, 0.0361,  ..., 0.0361, 0.0361, 0.0361],
         [0.0361, 0.0361, 0.0361,  ..., 0.0361, 0.0361, 0.0361]],

        [[0.0361, 0.0361, 0.0361,  ..., 0.0361, 0.0361, 0.0361],
         [0.0361, 0.0361, 0.0361,  ..., 0.0361, 0.0361, 0.0361],
         [0.0361, 0.0361, 0.0361,  ..., 0.0361, 0.0361, 0.0361]],

        [[0.0361, 0.0361, 0.0361,  ..., 0.0361, 0.0361, 0.0361],
         [0.0361, 0.0361, 0.0361,  ..., 0.0361, 0.0361, 0.0361],
         [0.0361, 0.0361, 0.0361,  ..., 0.0361, 0.0361, 0.0361]],

        [[0.0361, 0.0361, 0.0361,  ..., 0.0361, 0.0361, 0.0361],
         [0.0361, 0.0361, 0.0361,  ..., 0.0361, 0.0361, 0.0361],
         [0.0361, 0.0361, 0.0361,  ..., 0.0361, 0.0361, 0.0361]],

        [[0.0361, 0.0361, 0.0361,  ..., 0.0361, 0.0361, 0.0361],
         [0.0361, 0.0361, 0.0361,  ..., 0.0361, 0.0361, 0.0361],
         [0.0361, 0.0361, 0.0361,  ..., 0.0361, 0.0361, 0.0361]]],
       device='

ipdb>  m_repr.shape


torch.Size([5, 3, 768])


ipdb>  l


    100                 m_repr = F.normalize(m_repr + self.pretrained_meta_embeddings(m_idx), dim=1)
    101 
    102                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
    103                 meta_repr[m_key] = m_repr[m_repr_mask]
    104 
--> 105                 fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
    106                 data_fused_repr[idx] += fused_repr
    107 
    108         return data_fused_repr.squeeze(), meta_repr
    109 
    110 



ipdb>  m_idx


tensor([117344, 161862, 161861, 487671, 309544, 129374, 102557, 102556,  54423,
        559742,  79395, 102815,  84877,  84874,  84871], device='cuda:0')


ipdb>  m_args['meta_repr']


tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0',
       grad_fn=<EmbeddingBackward0>)


ipdb>  self.pretrained_meta_embeddings(m_idx)


tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0')


ipdb>  n


> /tmp/ipykernel_36220/179234550.py(106)fuse_meta_into_embeddings()
    104 
    105                 fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
--> 106                 data_fused_repr[idx] += fused_repr
    107 
    108         return data_fused_repr.squeeze(), meta_repr



ipdb>  


> /tmp/ipykernel_36220/179234550.py(94)fuse_meta_into_embeddings()
     92 
     93         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
---> 94         for m_key, m_args in meta_kwargs.items():
     95             idx = torch.where(m_args['data2ptr'] > 0)[0]
     96             meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)



ipdb>  


> /tmp/ipykernel_36220/179234550.py(108)fuse_meta_into_embeddings()
    106                 data_fused_repr[idx] += fused_repr
    107 
--> 108         return data_fused_repr.squeeze(), meta_repr
    109 
    110 



ipdb>  


--Return--
(tensor([[ 0.0...ezeBackward0>), {'lnk2data': tensor([[0.03...dexBackward0>)})
> /tmp/ipykernel_36220/179234550.py(108)fuse_meta_into_embeddings()
    106                 data_fused_repr[idx] += fused_repr
    107 
--> 108         return data_fused_repr.squeeze(), meta_repr
    109 
    110 



ipdb>  


> /tmp/ipykernel_36220/179234550.py(134)forward()
    132                                                                             torch.any(data_attention_mask, dim=1),
    133                                                                             meta_kwargs)
--> 134                 data_fused_repr = self.dr_fused(data_fused_repr)
    135 
    136         return EncoderOutput(



ipdb>  


> /tmp/ipykernel_36220/179234550.py(136)forward()
    134                 data_fused_repr = self.dr_fused(data_fused_repr)
    135 
--> 136         return EncoderOutput(
    137             rep=data_repr,
    138             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_36220/179234550.py(137)forward()
    135 
    136         return EncoderOutput(
--> 137             rep=data_repr,
    138             fused_rep=data_fused_repr,
    139             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_36220/179234550.py(138)forward()
    136         return EncoderOutput(
    137             rep=data_repr,
--> 138             fused_rep=data_fused_repr,
    139             meta_repr=meta_repr,
    140         )



ipdb>  


> /tmp/ipykernel_36220/179234550.py(139)forward()
    137             rep=data_repr,
    138             fused_rep=data_fused_repr,
--> 139             meta_repr=meta_repr,
    140         )
    141 



ipdb>  


> /tmp/ipykernel_36220/179234550.py(136)forward()
    134                 data_fused_repr = self.dr_fused(data_fused_repr)
    135 
--> 136         return EncoderOutput(
    137             rep=data_repr,
    138             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...xBackward0>)})
> /tmp/ipykernel_36220/179234550.py(136)forward()
    134                 data_fused_repr = self.dr_fused(data_fused_repr)
    135 
--> 136         return EncoderOutput(
    137             rep=data_repr,
    138             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(184)forward()
    182 
    183 
--> 184         loss = None; lbl2data_o = EncoderOutput()
    185         if lbl2data_input_ids is not None:
    186             lbl2data_meta_kwargs = self._get_encoder_meta_kwargs('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(185)forward()
    183 
    184         loss = None; lbl2data_o = EncoderOutput()
--> 185         if lbl2data_input_ids is not None:
    186             lbl2data_meta_kwargs = self._get_encoder_meta_kwargs('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
    187             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(186)forward()
    184         loss = None; lbl2data_o = EncoderOutput()
    185         if lbl2data_input_ids is not None:
--> 186             lbl2data_meta_kwargs = self._get_encoder_meta_kwargs('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
    187             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
    188                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(187)forward()
    185         if lbl2data_input_ids is not None:
    186             lbl2data_meta_kwargs = self._get_encoder_meta_kwargs('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
--> 187             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
    188                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    189 



ipdb>  lbl2data_meta_kwargs.keys()


dict_keys([])


ipdb>  n


> /tmp/ipykernel_36220/1449976710.py(188)forward()
    186             lbl2data_meta_kwargs = self._get_encoder_meta_kwargs('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
    187             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
--> 188                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    189 
    190             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(187)forward()
    185         if lbl2data_input_ids is not None:
    186             lbl2data_meta_kwargs = self._get_encoder_meta_kwargs('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
--> 187             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
    188                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    189 



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(188)forward()
    186             lbl2data_meta_kwargs = self._get_encoder_meta_kwargs('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
    187             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
--> 188                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    189 
    190             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(187)forward()
    185         if lbl2data_input_ids is not None:
    186             lbl2data_meta_kwargs = self._get_encoder_meta_kwargs('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
--> 187             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
    188                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    189 



ipdb>  


> /tmp/ipykernel_36220/179234550.py(120)forward()
    118         **kwargs
    119     ):  
--> 120         data_o = self.encode(data_input_ids, data_attention_mask)
    121 
    122         if data_type is not None and data_type == "meta":



ipdb>  c


XCModelOutput(loss=tensor(0.0183, device='cuda:0', grad_fn=<AddBackward0>), logits=None, data_repr=tensor([[-0.0346, -0.0148, -0.0316,  ...,  0.0544,  0.0028, -0.0312],
        [-0.0225, -0.0189, -0.0329,  ...,  0.0344, -0.0342,  0.0102],
        [ 0.0112,  0.0063, -0.0371,  ...,  0.0715, -0.0157,  0.0266],
        [ 0.0698,  0.0139,  0.0030,  ...,  0.0345, -0.0052,  0.0286],
        [-0.0170, -0.0224, -0.0301,  ...,  0.0244, -0.0210, -0.0329]],
       device='cuda:0', grad_fn=<DivBackward0>), data_fused_repr=tensor([[-0.0332, -0.0150, -0.0304,  ...,  0.0533,  0.0018, -0.0301],
        [-0.0219, -0.0187, -0.0315,  ...,  0.0327, -0.0326,  0.0088],
        [ 0.0099,  0.0052, -0.0356,  ...,  0.0717, -0.0159,  0.0252],
        [ 0.0695,  0.0125,  0.0020,  ...,  0.0330, -0.0059,  0.0271],
        [-0.0171, -0.0221, -0.0292,  ...,  0.0230, -0.0208, -0.0318]],
       device='cuda:0', grad_fn=<DivBackward0>), lbl2data_repr=tensor([[-0.0346, -0.0148, -0.0316,  ...,  0.0544,  0.0028, -0.0312],
 

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



## `OAK002`

In [85]:
#| export
class OAK002(OAK001, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK001.__init__)
    def __init__(
        self, 
        config,
        n_labels:int,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.label_embeddings = nn.Embedding(n_labels, config.dim, sparse=True)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def init_label_embeddings(self):
        torch.nn.init.zeros_(self.label_embeddings.weight)

    def get_label_representation(
        self,
        data_idx:Optional[torch.Tensor]=None,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask)
        data_o.rep = F.normalize(data_o.rep + self.label_embeddings(data_idx), dim=1)
        return XCModelOutput(
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
        )

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = self._get_encoder_meta_kwargs('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(lbl2data_idx), dim=1)
            
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
            
            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
            
            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
            
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )



### Example

In [86]:
model = OAK002.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000, n_labels=block.n_lbl,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=True,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()
model.init_label_embeddings()

Some weights of OAK002 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 

In [87]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(block.train.dset.meta['lnk_meta'].n_meta, model.config.dim))
model.encoder.freeze_pretrained_meta_embeddings()

In [88]:
model = model.to('cuda')

In [89]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [36]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [90]:
o = model(**b.to(model.device))

In [54]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))
    

In [91]:
o.loss

tensor(0.0213, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK003`

In [62]:
#| export
class OAK003(OAK001, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK001.__init__)
    def __init__(
        self, 
        config,
        n_labels:int,
        n_clusters:int,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.label_embeddings = nn.Embedding(n_clusters, config.dim, sparse=True)
        self.register_buffer("label_remap", torch.arange(n_labels)%n_clusters, persistent=True)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def init_label_embeddings(self):
        torch.nn.init.zeros_(self.label_embeddings.weight)

    def set_label_embeddings(self, embed:torch.Tensor):
        with torch.no_grad():
            self.label_embeddings.weight.copy_(embed)

    def set_label_remap(self, label_remap:torch.Tensor):
        if label_remap.shape[0] != self.label_remap.shape[0]:
            raise ValueError(f'Shape mismatch, `label_remap` should have {self.label_remap.shape[0]} elements.')
        self.label_remap = label_remap

    def get_label_representation(
        self,
        data_idx:Optional[torch.Tensor]=None,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask)
        data_o.rep = F.normalize(data_o.rep + self.label_embeddings(self.label_remap[data_idx]), dim=1)
        return XCModelOutput(
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
        )

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = self._get_encoder_meta_kwargs('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)
            
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
            
            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
            
            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
            
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )



### Example

In [63]:
model = OAK003.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000, 
                               n_labels=block.n_lbl, n_clusters=block.n_lbl//3,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=True,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()
model.init_label_embeddings()

Some weights of OAK003 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 

In [64]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(block.train.dset.meta['lnk_meta'].n_meta, model.config.dim))
model.encoder.freeze_pretrained_meta_embeddings()

In [65]:
model = model.to('cuda')

In [66]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [62]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [67]:
o = model(**b.to(model.device))

In [68]:
o.loss

tensor(0.0213, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK004`

In [69]:
#| export
class OAK004(OAK001, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK001.__init__)
    def __init__(
        self, 
        config,
        n_labels:int,
        n_clusters:int,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.label_embeddings = nn.Embedding(n_clusters, config.dim)
        self.register_buffer("label_remap", torch.arange(n_labels)%n_clusters, persistent=True)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def init_label_embeddings(self):
        torch.nn.init.zeros_(self.label_embeddings.weight)

    def set_label_embeddings(self, embed:torch.Tensor):
        with torch.no_grad():
            self.label_embeddings.weight.copy_(embed)

    def set_label_remap(self, label_remap:torch.Tensor):
        if label_remap.shape[0] != self.label_remap.shape[0]:
            raise ValueError(f'Shape mismatch, `label_remap` should have {self.label_remap.shape[0]} elements.')
        with torch.no_grad():
            self.label_remap.copy_(label_remap)

    def get_label_representation(
        self,
        data_idx:Optional[torch.Tensor]=None,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask)
        data_o.rep = F.normalize(data_o.rep + self.label_embeddings(self.label_remap[data_idx]), dim=1)
        return XCModelOutput(
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
        )

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = self._get_encoder_meta_kwargs('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)
            
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
            
            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
            
            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
            
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )



### Example

In [71]:
model = OAK004.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000, 
                               n_labels=block.n_lbl, n_clusters=block.n_lbl//3,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=True,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()
model.init_label_embeddings()

Some weights of OAK004 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 

In [73]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(block.train.dset.meta['lnk_meta'].n_meta, model.config.dim))
model.encoder.freeze_pretrained_meta_embeddings()

In [74]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [22]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [75]:
o = model(**b.to(model.device))

In [76]:
o.loss

tensor(0.0213, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK005`

In [117]:
#| export
class OAK005Encoder(Encoder):
    
    def __init__(
        self, 
        config:PretrainedConfig,
        **kwargs
    ):
        super().__init__(config, **kwargs)

    def resize(self, idx:torch.Tensor, num_inputs:torch.Tensor):
        if torch.any(num_inputs == 0): raise ValueError("`num_inputs` should be non-zero positive integer.")
        bsz, total_num_inputs = num_inputs.shape[0], idx.shape[0]
        
        self.ones = self.ones.to(idx.device)
        ones = (
            torch.ones(total_num_inputs, dtype=torch.long, device=idx.device) 
            if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
        )
        
        max_num_inputs = num_inputs.max()
        if (num_inputs == max_num_inputs).all():
            return idx,ones
        
        xnum_inputs = max_num_inputs-num_inputs+1
        
        inputs_ptr = num_inputs.cumsum(dim=0)-1
        repeat_inputs = ones.scatter(0, inputs_ptr, xnum_inputs)
        
        resized_idx = idx.repeat_interleave(repeat_inputs, dim=0)
        
        ignore_mask = ones.scatter(0, inputs_ptr, 0).repeat_interleave(repeat_inputs, dim=0).view(bsz, -1)
        ignore_mask[:, -1] = 1; ignore_mask = ignore_mask.flatten()
        
        return resized_idx,ignore_mask

    def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
        meta_repr = {}
        
        data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
        for m_key, m_args in meta_kwargs.items():
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
            
            if len(idx):
                m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
                m_repr = F.normalize(self.pretrained_meta_embeddings(m_idx), dim=1)
                
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
                meta_repr[m_key] = m_repr[m_repr_mask]
                
                fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
                data_fused_repr[idx] += fused_repr
                
        return data_fused_repr.squeeze(), meta_repr
        

In [118]:
#| export
class OAK005(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(
        self, 
        config,
        num_metadata:int,
        resize_length:Optional[int]=None,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.encoder = OAK005Encoder(config, num_metadata=num_metadata, resize_length=resize_length)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def _get_encoder_meta_kwargs(self, feat:str, prefix:str, **kwargs):
        meta_kwargs = Parameters.from_feat_meta_aug_prefix(feat, prefix, **kwargs)
        return meta_kwargs

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        

### Example

In [119]:
model = OAK005.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=True,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_cross_head()

Some weights of OAK005 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 

In [120]:
model.encoder.set_pretrained_meta_embeddings(torch.ones(block.train.dset.meta['lnk_meta'].n_meta, model.config.dim))
model.encoder.freeze_pretrained_meta_embeddings()

In [105]:
model = model.to('cuda')

In [106]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [107]:
o = model(**b.to(model.device))

In [108]:
o.loss

tensor(0.0183, device='cuda:0', grad_fn=<AddBackward0>)

In [121]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))
    

In [122]:
func()

> /tmp/ipykernel_36220/3657616883.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))
      4 



ipdb>  b model.forward


Breakpoint 3 at /tmp/ipykernel_36220/1449976710.py:158


ipdb>  c


> /tmp/ipykernel_36220/1449976710.py(173)forward()
    171         **kwargs
    172     ):  
--> 173         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    174 
    175         if self.use_encoder_parallel:



ipdb>  n


> /tmp/ipykernel_36220/1449976710.py(175)forward()
    173         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    174 
--> 175         if self.use_encoder_parallel:
    176             encoder = XCDataParallel(module=self.encoder)
    177         else: encoder = self.encoder



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(177)forward()
    175         if self.use_encoder_parallel:
    176             encoder = XCDataParallel(module=self.encoder)
--> 177         else: encoder = self.encoder
    178 
    179         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(179)forward()
    177         else: encoder = self.encoder
    178 
--> 179         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
    180         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    181                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(180)forward()
    178 
    179         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
--> 180         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    181                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    182 



ipdb>  b encoder.forward


Breakpoint 4 at /tmp/ipykernel_36220/179234550.py:111


ipdb>  c


> /tmp/ipykernel_36220/1449976710.py(181)forward()
    179         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
    180         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
--> 181                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    182 
    183 



ipdb>  n


> /tmp/ipykernel_36220/1449976710.py(180)forward()
    178 
    179         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
--> 180         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    181                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    182 



ipdb>  c


> /tmp/ipykernel_36220/1449976710.py(181)forward()
    179         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
    180         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
--> 181                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    182 
    183 



ipdb>  n


> /tmp/ipykernel_36220/1449976710.py(180)forward()
    178 
    179         data_meta_kwargs = self._get_encoder_meta_kwargs('data', self.data_aug_meta_prefix, **kwargs)
--> 180         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    181                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    182 



ipdb>  c


> /tmp/ipykernel_36220/179234550.py(120)forward()
    118         **kwargs
    119     ):  
--> 120         data_o = self.encode(data_input_ids, data_attention_mask)
    121 
    122         if data_type is not None and data_type == "meta":



ipdb>  n


> /tmp/ipykernel_36220/179234550.py(122)forward()
    120         data_o = self.encode(data_input_ids, data_attention_mask)
    121 
--> 122         if data_type is not None and data_type == "meta":
    123             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    124         else:



ipdb>  n


> /tmp/ipykernel_36220/179234550.py(125)forward()
    123             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    124         else:
--> 125             data_repr = self.dr(data_o[0], data_attention_mask)
    126 
    127         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_36220/179234550.py(127)forward()
    125             data_repr = self.dr(data_o[0], data_attention_mask)
    126 
--> 127         data_fused_repr = meta_repr = None
    128         if data_aug_meta_prefix is not None:
    129             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_36220/179234550.py(128)forward()
    126 
    127         data_fused_repr = meta_repr = None
--> 128         if data_aug_meta_prefix is not None:
    129             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    130             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_36220/179234550.py(129)forward()
    127         data_fused_repr = meta_repr = None
    128         if data_aug_meta_prefix is not None:
--> 129             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    130             if len(meta_kwargs):
    131                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 



ipdb>  


> /tmp/ipykernel_36220/179234550.py(130)forward()
    128         if data_aug_meta_prefix is not None:
    129             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
--> 130             if len(meta_kwargs):
    131                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    132                                                                             torch.any(data_attention_mask, dim=1),



ipdb>  meta_kwargs.keys()


dict_keys(['lnk2data'])


ipdb>  n


> /tmp/ipykernel_36220/179234550.py(131)forward()
    129             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    130             if len(meta_kwargs):
--> 131                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    132                                                                             torch.any(data_attention_mask, dim=1),
    133                                                                             meta_kwargs)



ipdb>  


> /tmp/ipykernel_36220/179234550.py(132)forward()
    130             if len(meta_kwargs):
    131                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
--> 132                                                                             torch.any(data_attention_mask, dim=1),
    133                                                                             meta_kwargs)
    134                 data_fused_repr = self.dr_fused(data_fused_repr)



ipdb>  


> /tmp/ipykernel_36220/179234550.py(133)forward()
    131                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    132                                                                             torch.any(data_attention_mask, dim=1),
--> 133                                                                             meta_kwargs)
    134                 data_fused_repr = self.dr_fused(data_fused_repr)
    135 



ipdb>  


> /tmp/ipykernel_36220/179234550.py(131)forward()
    129             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    130             if len(meta_kwargs):
--> 131                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    132                                                                             torch.any(data_attention_mask, dim=1),
    133                                                                             meta_kwargs)



ipdb>  s


--Call--
> /tmp/ipykernel_36220/2023555771.py(37)fuse_meta_into_embeddings()
     35         return resized_idx,ignore_mask
     36 
---> 37     def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
     38         meta_repr = {}
     39 



ipdb>  n


> /tmp/ipykernel_36220/2023555771.py(38)fuse_meta_into_embeddings()
     36 
     37     def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
---> 38         meta_repr = {}
     39 
     40         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)



ipdb>  


> /tmp/ipykernel_36220/2023555771.py(40)fuse_meta_into_embeddings()
     38         meta_repr = {}
     39 
---> 40         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
     41         for m_key, m_args in meta_kwargs.items():
     42             idx = torch.where(m_args['data2ptr'] > 0)[0]



ipdb>  


> /tmp/ipykernel_36220/2023555771.py(41)fuse_meta_into_embeddings()
     39 
     40         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
---> 41         for m_key, m_args in meta_kwargs.items():
     42             idx = torch.where(m_args['data2ptr'] > 0)[0]
     43             meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)



ipdb>  


> /tmp/ipykernel_36220/2023555771.py(42)fuse_meta_into_embeddings()
     40         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
     41         for m_key, m_args in meta_kwargs.items():
---> 42             idx = torch.where(m_args['data2ptr'] > 0)[0]
     43             meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
     44 



ipdb>  


> /tmp/ipykernel_36220/2023555771.py(43)fuse_meta_into_embeddings()
     41         for m_key, m_args in meta_kwargs.items():
     42             idx = torch.where(m_args['data2ptr'] > 0)[0]
---> 43             meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
     44 
     45             if len(idx):



ipdb>  


> /tmp/ipykernel_36220/2023555771.py(45)fuse_meta_into_embeddings()
     43             meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
     44 
---> 45             if len(idx):
     46                 m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
     47                 m_repr = F.normalize(self.pretrained_meta_embeddings(m_idx), dim=1)



ipdb>  


> /tmp/ipykernel_36220/2023555771.py(46)fuse_meta_into_embeddings()
     44 
     45             if len(idx):
---> 46                 m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
     47                 m_repr = F.normalize(self.pretrained_meta_embeddings(m_idx), dim=1)
     48 



ipdb>  n


> /tmp/ipykernel_36220/2023555771.py(47)fuse_meta_into_embeddings()
     45             if len(idx):
     46                 m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
---> 47                 m_repr = F.normalize(self.pretrained_meta_embeddings(m_idx), dim=1)
     48 
     49                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)



ipdb>  


> /tmp/ipykernel_36220/2023555771.py(49)fuse_meta_into_embeddings()
     47                 m_repr = F.normalize(self.pretrained_meta_embeddings(m_idx), dim=1)
     48 
---> 49                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
     50                 meta_repr[m_key] = m_repr[m_repr_mask]
     51 



ipdb>  


> /tmp/ipykernel_36220/2023555771.py(50)fuse_meta_into_embeddings()
     48 
     49                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
---> 50                 meta_repr[m_key] = m_repr[m_repr_mask]
     51 
     52                 fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]



ipdb>  


> /tmp/ipykernel_36220/2023555771.py(52)fuse_meta_into_embeddings()
     50                 meta_repr[m_key] = m_repr[m_repr_mask]
     51 
---> 52                 fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
     53                 data_fused_repr[idx] += fused_repr
     54 



ipdb>  


> /tmp/ipykernel_36220/2023555771.py(53)fuse_meta_into_embeddings()
     51 
     52                 fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
---> 53                 data_fused_repr[idx] += fused_repr
     54 
     55         return data_fused_repr.squeeze(), meta_repr



ipdb>  


> /tmp/ipykernel_36220/2023555771.py(41)fuse_meta_into_embeddings()
     39 
     40         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
---> 41         for m_key, m_args in meta_kwargs.items():
     42             idx = torch.where(m_args['data2ptr'] > 0)[0]
     43             meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)



ipdb>  


> /tmp/ipykernel_36220/2023555771.py(55)fuse_meta_into_embeddings()
     52                 fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
     53                 data_fused_repr[idx] += fused_repr
     54 
---> 55         return data_fused_repr.squeeze(), meta_repr
     56 



ipdb>  


--Return--
(tensor([[ 0.0...ezeBackward0>), {'lnk2data': tensor([[0.03...361, 0.0361]])})
> /tmp/ipykernel_36220/2023555771.py(55)fuse_meta_into_embeddings()
     52                 fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
     53                 data_fused_repr[idx] += fused_repr
     54 
---> 55         return data_fused_repr.squeeze(), meta_repr
     56 



ipdb>  


> /tmp/ipykernel_36220/179234550.py(134)forward()
    132                                                                             torch.any(data_attention_mask, dim=1),
    133                                                                             meta_kwargs)
--> 134                 data_fused_repr = self.dr_fused(data_fused_repr)
    135 
    136         return EncoderOutput(



ipdb>  


> /tmp/ipykernel_36220/179234550.py(136)forward()
    134                 data_fused_repr = self.dr_fused(data_fused_repr)
    135 
--> 136         return EncoderOutput(
    137             rep=data_repr,
    138             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_36220/179234550.py(137)forward()
    135 
    136         return EncoderOutput(
--> 137             rep=data_repr,
    138             fused_rep=data_fused_repr,
    139             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_36220/179234550.py(138)forward()
    136         return EncoderOutput(
    137             rep=data_repr,
--> 138             fused_rep=data_fused_repr,
    139             meta_repr=meta_repr,
    140         )



ipdb>  


> /tmp/ipykernel_36220/179234550.py(139)forward()
    137             rep=data_repr,
    138             fused_rep=data_fused_repr,
--> 139             meta_repr=meta_repr,
    140         )
    141 



ipdb>  


> /tmp/ipykernel_36220/179234550.py(136)forward()
    134                 data_fused_repr = self.dr_fused(data_fused_repr)
    135 
--> 136         return EncoderOutput(
    137             rep=data_repr,
    138             fused_rep=data_fused_repr,



ipdb>  c


> /tmp/ipykernel_36220/179234550.py(120)forward()
    118         **kwargs
    119     ):  
--> 120         data_o = self.encode(data_input_ids, data_attention_mask)
    121 
    122         if data_type is not None and data_type == "meta":



ipdb>  n


> /tmp/ipykernel_36220/179234550.py(122)forward()
    120         data_o = self.encode(data_input_ids, data_attention_mask)
    121 
--> 122         if data_type is not None and data_type == "meta":
    123             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    124         else:



ipdb>  


> /tmp/ipykernel_36220/179234550.py(125)forward()
    123             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    124         else:
--> 125             data_repr = self.dr(data_o[0], data_attention_mask)
    126 
    127         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_36220/179234550.py(127)forward()
    125             data_repr = self.dr(data_o[0], data_attention_mask)
    126 
--> 127         data_fused_repr = meta_repr = None
    128         if data_aug_meta_prefix is not None:
    129             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_36220/179234550.py(128)forward()
    126 
    127         data_fused_repr = meta_repr = None
--> 128         if data_aug_meta_prefix is not None:
    129             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    130             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_36220/179234550.py(136)forward()
    134                 data_fused_repr = self.dr_fused(data_fused_repr)
    135 
--> 136         return EncoderOutput(
    137             rep=data_repr,
    138             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_36220/179234550.py(137)forward()
    135 
    136         return EncoderOutput(
--> 137             rep=data_repr,
    138             fused_rep=data_fused_repr,
    139             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_36220/179234550.py(138)forward()
    136         return EncoderOutput(
    137             rep=data_repr,
--> 138             fused_rep=data_fused_repr,
    139             meta_repr=meta_repr,
    140         )



ipdb>  


> /tmp/ipykernel_36220/179234550.py(139)forward()
    137             rep=data_repr,
    138             fused_rep=data_fused_repr,
--> 139             meta_repr=meta_repr,
    140         )
    141 



ipdb>  


> /tmp/ipykernel_36220/179234550.py(136)forward()
    134                 data_fused_repr = self.dr_fused(data_fused_repr)
    135 
--> 136         return EncoderOutput(
    137             rep=data_repr,
    138             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...eta_repr=None)
> /tmp/ipykernel_36220/179234550.py(136)forward()
    134                 data_fused_repr = self.dr_fused(data_fused_repr)
    135 
--> 136         return EncoderOutput(
    137             rep=data_repr,
    138             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(190)forward()
    188                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    189 
--> 190             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
    191                                      plbl2data_data2ptr,plbl2data_idx)
    192 



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(191)forward()
    189 
    190             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
--> 191                                      plbl2data_data2ptr,plbl2data_idx)
    192 
    193             if self.use_query_loss:



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(190)forward()
    188                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    189 
--> 190             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
    191                                      plbl2data_data2ptr,plbl2data_idx)
    192 



ipdb>  


> /tmp/ipykernel_36220/1449976710.py(193)forward()
    191                                      plbl2data_data2ptr,plbl2data_idx)
    192 
--> 193             if self.use_query_loss:
    194                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
    195                                           plbl2data_data2ptr,plbl2data_idx)



ipdb>  c


XCModelOutput(loss=tensor(0.0183, grad_fn=<AddBackward0>), logits=None, data_repr=tensor([[-0.0346, -0.0148, -0.0316,  ...,  0.0544,  0.0028, -0.0312],
        [-0.0225, -0.0189, -0.0329,  ...,  0.0344, -0.0342,  0.0102],
        [ 0.0112,  0.0063, -0.0371,  ...,  0.0715, -0.0157,  0.0266],
        [ 0.0698,  0.0139,  0.0030,  ...,  0.0345, -0.0052,  0.0286],
        [-0.0170, -0.0224, -0.0301,  ...,  0.0244, -0.0210, -0.0329]],
       grad_fn=<DivBackward0>), data_fused_repr=tensor([[-0.0332, -0.0150, -0.0304,  ...,  0.0533,  0.0018, -0.0301],
        [-0.0219, -0.0187, -0.0315,  ...,  0.0327, -0.0326,  0.0088],
        [ 0.0099,  0.0052, -0.0356,  ...,  0.0717, -0.0159,  0.0252],
        [ 0.0695,  0.0125,  0.0020,  ...,  0.0330, -0.0059,  0.0271],
        [-0.0171, -0.0221, -0.0292,  ...,  0.0230, -0.0208, -0.0318]],
       grad_fn=<DivBackward0>), lbl2data_repr=tensor([[-0.0346, -0.0148, -0.0316,  ...,  0.0544,  0.0028, -0.0312],
        [-0.0225, -0.0189, -0.0329,  ...,  0.0344, -

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

