In [1]:
#| default_exp models.oakY

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
#| export
import torch, re, inspect, pickle, os, torch.nn as nn, math
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any, Union
from transformers import (
    PretrainedConfig,
    DistilBertForMaskedLM,
    DistilBertModel,
    DistilBertPreTrainedModel,
    DistilBertConfig,
)
from transformers.utils.generic import ModelOutput
from transformers.activations import get_activation

from fastcore.meta import *
from fastcore.utils import *

from xcai.losses import *
from xcai.core import store_attr
from xcai.learner import XCDataParallel
from xcai.models.modeling_utils import *

In [4]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [5]:
from transformers import AutoConfig
from xcai.block import *

## Setup

In [6]:
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

In [7]:
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data-lnk_distilbert-base-uncased_xcs.pkl'

In [None]:
with open(pkl_file, 'wb') as file: pickle.dump(block, file)

In [8]:
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [9]:
batch = block.train.one_batch(5)
for i,batch in enumerate(block.train.dl):
    if i > 2: break

In [24]:
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_identifier', 'lnk2data_input_text', 'lnk2data_input_ids', 'lnk2data_attention_mask', 'lnk2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx', 'lnk2lbl2data_idx', 'lnk2lbl2data_identifier', 'lnk2lbl2data_input_text', 'lnk2lbl2data_input_ids', 'lnk2lbl2data_attention_mask', 'lnk2lbl2data_data2ptr', 'lnk2lbl2data_lbl2data2ptr'])

## Encoder

In [26]:
#| export
class Encoder(DistilBertPreTrainedModel):
    
    def __init__(
        self, 
        config:PretrainedConfig, 
    ):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        
        self.dr_head = RepresentationHead(config)
        self.dr_fused_head = RepresentationHead(config)
        self.meta_head = RepresentationHead(config)
        self.cross_head = CrossAttention(config)
        self.meta_embeddings = nn.Embedding(config.vocab_size, config.dim)
        
        self.post_init()

    def freeze_meta_embeddings(self):
        self.meta_embeddings.requires_grad_(False)

    def unfreeze_meta_embeddings(self):
        self.meta_embeddings.requires_grad_(True)

    def set_meta_embeddings(self, embed:torch.Tensor):
        with torch.no_grad():
            self.meta_embeddings.weight.copy_(embed)
        
    def get_position_embeddings(self) -> nn.Embedding:
        return self.distilbert.get_position_embeddings()
    
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)
    
    def encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        return self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
    
    def dr(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.dr_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def dr_fused(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.dr_fused_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def meta(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.meta_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)
    
    def meta_unnormalized(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.meta_head(embed)
        return Pooling.mean_pooling(embed, attention_mask)

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
        meta_repr = {}
        
        for m_key, m_args in meta_kwargs.items():
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
            
            if len(idx):
                assert torch.all(m_args['data2ptr'][idx] == m_args['data2ptr'].max()), f'All datapoints should have same number of metadata.'
                
                if 'meta_repr' in m_args:
                    m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
                    m_repr_mask = m_repr_mask.bool()
                else:
                    m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
                    m_embed = self.meta_embeddings(m_input_ids)
    
                    m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
                    m_repr_mask = torch.any(m_attention_mask, dim=1)
                    
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
                
                meta_repr[m_key] = m_repr[m_repr_mask]
                meta_repr[m_key] = F.normalize(meta_repr[m_key], dim=1)
                
                fused_embed, w = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=True)
                embed[idx] += fused_embed
               
        return embed, meta_repr

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):  
        data_o = self.encode(data_input_ids, data_attention_mask)
        
        if data_type is not None and data_type == "meta":
            data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
        else: 
            data_repr = self.dr(data_o[0], data_attention_mask)
        
        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
                                                                            data_attention_mask, 
                                                                            meta_kwargs)
                data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
                
        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,
        )
        

## `OAK000`

In [29]:
#| export
class OAK000(nn.Module):
    
    def __init__(
        self, config,

        data_aug_meta_prefix:Optional[str]=None, 
        lbl2data_aug_meta_prefix:Optional[str]=None, 

        data_pred_meta_prefix:Optional[str]=None,
        lbl2data_pred_meta_prefix:Optional[str]=None,
        
        num_batch_labels:Optional[int]=None, 
        batch_size:Optional[int]=None,
        margin:Optional[float]=0.3,
        num_negatives:Optional[int]=5,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=True,

        calib_margin:Optional[float]=0.3,
        calib_num_negatives:Optional[int]=10,
        calib_tau:Optional[float]=0.1,
        calib_apply_softmax:Optional[bool]=False,
        calib_loss_weight:Optional[float]=0.1,
        use_calib_loss:Optional[float]=False,
        
        meta_loss_weight:Optional[Union[List,float]]=0.3,
        
        use_fusion_loss:Optional[bool]=False,
        fusion_loss_weight:Optional[float]=0.15,

        use_query_loss:Optional[float]=True,
        
        use_encoder_parallel:Optional[bool]=True,
    ):
        super().__init__(config)
        store_attr('meta_loss_weight,fusion_loss_weight,calib_loss_weight')
        store_attr('data_pred_meta_prefix,lbl2data_pred_meta_prefix')
        store_attr('data_aug_meta_prefix,lbl2data_aug_meta_prefix')
        store_attr('use_fusion_loss,use_query_loss,use_calib_loss,use_encoder_parallel')
        
        self.encoder = None
        self.rep_loss_fn = MultiTriplet(bsz=batch_size, tn_targ=num_batch_labels, margin=margin, n_negatives=num_negatives, 
                                        tau=tau, apply_softmax=apply_softmax, reduce='mean')
        self.cab_loss_fn = Calibration(margin=calib_margin, tau=calib_tau, n_negatives=calib_num_negatives, 
                                       apply_softmax=calib_apply_softmax, reduce='mean')
        
    def init_retrieval_head(self):
        assert self.encoder is not None, 'Encoder is not initialized.'
        self.encoder.dr_head.post_init()
        self.encoder.meta_head.post_init()
        self.encoder.dr_fused_head.post_init()

    def init_cross_head(self):
        assert self.encoder is not None, 'Encoder is not initialized.'
        self.encoder.cross_head.post_init()
        
    def compute_loss(self, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        return self.rep_loss_fn(inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def calibration_loss(self, einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        return self.calib_loss_weight * self.cab_loss_fn(einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)
    
    def compute_meta_loss(self, data_repr, lbl2data_repr, **kwargs):
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
            
        data_meta_inputs = Parameters.from_meta_pred_prefix(self.data_pred_meta_prefix, **kwargs)
        lbl2data_meta_inputs = Parameters.from_meta_pred_prefix(self.lbl2data_pred_meta_prefix, **kwargs)
        meta_inputs = {**data_meta_inputs, **lbl2data_meta_inputs}

        m_lw = Parameters.get_meta_loss_weights(self.meta_loss_weight, len(meta_inputs)) if len(meta_inputs) else []
        
        loss = 0.0
        for inputs,lw in zip(meta_inputs.values(), m_lw):
            if 'lbl2data2ptr' in inputs:
                idx = torch.where(inputs['lbl2data2ptr'])[0]
                if len(idx) > 0:
                    inputs_o = encoder(data_input_ids=inputs['input_ids'], data_attention_mask=inputs['attention_mask'], 
                                       data_type="meta")
                    m_loss = self.rep_loss_fn(lbl2data_repr[idx], inputs_o.rep, inputs['lbl2data2ptr'][idx],
                                              inputs['idx'], inputs['plbl2data2ptr'][idx], inputs['pidx'])
                    loss += lw * m_loss

            elif 'data2ptr' in inputs:
                idx = torch.where(inputs['data2ptr'])[0]
                if len(idx) > 0:
                    inputs_o = encoder(data_input_ids=inputs['input_ids'], data_attention_mask=inputs['attention_mask'], 
                                       data_type="meta")
                    m_loss = self.rep_loss_fn(data_repr[idx], inputs_o.rep, inputs['data2ptr'][idx], inputs['idx'], 
                                              inputs['pdata2ptr'][idx], inputs['pidx'])
                    loss += lw * m_loss       

            else: raise ValueError('Invalid metadata input arguments.')
        return loss

    def compute_fusion_loss(self, data_repr, meta_repr:Dict, prefix:str, **kwargs):
        meta_inputs = Parameters.from_meta_pred_prefix(prefix, **kwargs)
        
        loss = 0.0
        if meta_repr is not None:
            for key,input_repr in meta_repr.items():
                inputs = meta_inputs[key]
                if 'lbl2data2ptr' in inputs:
                    idx = torch.where(inputs['lbl2data2ptr'])[0]
                    if len(idx) > 0:
                        m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['lbl2data2ptr'][idx],
                                                  inputs['idx'], inputs['plbl2data2ptr'][idx], inputs['pidx'])
                        loss += self.fusion_loss_weight * m_loss
    
                elif 'data2ptr' in inputs:
                    idx = torch.where(inputs['data2ptr'])[0]
                    if len(idx) > 0:
                        m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['data2ptr'][idx], inputs['idx'], 
                                                  inputs['pdata2ptr'][idx], inputs['pidx'])
                        loss += self.fusion_loss_weight * m_loss       
    
                else: raise ValueError('Invalid metadata input arguments.')
        return loss


    def get_meta_representation(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
            
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_unnormalized=True, data_type="meta")
        return XCModelOutput(
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
        )

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
            
            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
            
            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
            
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        

## `OAK001`

In [37]:
#| export
class OAK001(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.encoder = Encoder(config)
        self.post_init(); self.remap_post_init();

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        

### Example

In [31]:
model = OAK001.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=True, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_cross_head()
model.encoder.set_meta_embeddings(model.distilbert.embeddings.word_embeddings.weight)

Some weights of OAK001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [32]:
model = model.to('cuda')

In [33]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [39]:
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_identifier', 'lnk2data_input_text', 'lnk2data_input_ids', 'lnk2data_attention_mask', 'lnk2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx', 'lnk2lbl2data_idx', 'lnk2lbl2data_identifier', 'lnk2lbl2data_input_text', 'lnk2lbl2data_input_ids', 'lnk2lbl2data_attention_mask', 'lnk2lbl2data_data2ptr', 'lnk2lbl2data_lbl2data2ptr'])

In [35]:
o = model(**b.to(model.device))

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [36]:
o.loss

tensor(0.0289, device='cuda:0', grad_fn=<AddBackward0>)

In [44]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))
    

In [48]:
o = func()

> /tmp/ipykernel_28793/3657616883.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))
      4 



ipdb>  b model.encoder.cross_head.forward


Breakpoint 5 at /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py:75


ipdb>  c


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(83)forward()
     81         output_attentions:Optional[bool] = False,
     82     ):
---> 83         bs, q_len, dim = q.size()
     84         v, k_len = k, k.size(1)
     85 



ipdb>  !q


tensor([[[-0.6077, -1.4566, -1.2658,  ...,  0.4173,  0.4841, -0.4502],
         [-0.6280, -0.1254, -0.2795,  ...,  0.7303,  0.4780, -0.8035],
         [-0.5878,  0.2601, -0.8291,  ...,  0.7271, -0.0244, -0.1744],
         [-0.3885, -0.0131, -0.3512,  ...,  0.8768, -0.2073, -0.3525],
         [-0.2144, -1.0242, -1.2639,  ...,  1.1083,  0.2689,  0.1441]],

        [[-0.0116, -0.6070, -1.0937,  ...,  0.4919, -0.6339,  0.3495],
         [-0.4328, -0.5217, -0.7185,  ...,  0.3906, -0.9542,  0.9589],
         [-0.6673, -0.8967, -1.2427,  ...,  0.3123, -0.9577, -0.1136],
         [-0.2185,  0.0981, -0.4987,  ...,  0.7605, -0.7738,  0.0544],
         [ 0.1080,  0.3637, -0.7733,  ...,  0.7831, -0.8321,  0.0020]],

        [[-0.1806,  0.5675, -1.1267,  ...,  1.1764,  0.0713,  0.6914],
         [ 0.0746,  0.4647, -1.5286,  ...,  0.7264,  0.2030,  0.5578],
         [-0.2783,  0.2235, -0.8286,  ...,  0.4618,  0.0872, -0.2953],
         [ 0.7958, -0.2424, -1.0187,  ...,  0.6092, -0.0145,  0.7918],
  

ipdb>  !q_m


tensor([[1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0]], device='cuda:0')


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(84)forward()
     82     ):
     83         bs, q_len, dim = q.size()
---> 84         v, k_len = k, k.size(1)
     85 
     86         h_dim = self.dim//self.n_h



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(86)forward()
     84         v, k_len = k, k.size(1)
     85 
---> 86         h_dim = self.dim//self.n_h
     87 
     88         def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(88)forward()
     86         h_dim = self.dim//self.n_h
     87 
---> 88         def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)
     89 
     90         def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(90)forward()
     88         def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)
     89 
---> 90         def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)
     91 
     92         q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(92)forward()
     90         def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)
     91 
---> 92         q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
     93         k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
     94         v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(93)forward()
     91 
     92         q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
---> 93         k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
     94         v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)
     95 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(94)forward()
     92         q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
     93         k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
---> 94         v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)
     95 
     96         q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(96)forward()
     94         v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)
     95 
---> 96         q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)
     97         sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
     98 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(97)forward()
     95 
     96         q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)
---> 97         sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
     98 
     99         q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(99)forward()
     97         sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
     98 
---> 99         q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
    100         mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
    101 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(100)forward()
     98 
     99         q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
--> 100         mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
    101 
    102         sc = sc.masked_fill(mask == 0, torch.tensor(torch.finfo(sc.dtype).min))  # (bs, n_h, q_len, k_len)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(102)forward()
    100         mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
    101 
--> 102         sc = sc.masked_fill(mask == 0, torch.tensor(torch.finfo(sc.dtype).min))  # (bs, n_h, q_len, k_len)
    103 
    104         w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)



ipdb>  mask.shape


torch.Size([5, 12, 5, 3])


ipdb>  mask[0][0]


tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [0., 0., 0.]], device='cuda:0')


ipdb>  mask[1][0]


tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], device='cuda:0')


ipdb>  mask[2][0]


tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], device='cuda:0')


ipdb>  mask[3][0]


tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')


ipdb>  mask[4][0]


tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')


ipdb>  mask[5][0]


*** IndexError: index 5 is out of bounds for dimension 0 with size 5


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(104)forward()
    102         sc = sc.masked_fill(mask == 0, torch.tensor(torch.finfo(sc.dtype).min))  # (bs, n_h, q_len, k_len)
    103 
--> 104         w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
    105         w = self.dropout(w)  # (bs, n_h, q_len, k_len)
    106 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(105)forward()
    103 
    104         w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
--> 105         w = self.dropout(w)  # (bs, n_h, q_len, k_len)
    106 
    107         o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(107)forward()
    105         w = self.dropout(w)  # (bs, n_h, q_len, k_len)
    106 
--> 107         o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)
    108 
    109         if output_attentions: return (o, w)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(109)forward()
    107         o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)
    108 
--> 109         if output_attentions: return (o, w)
    110         else: return (o,)
    111 



ipdb>  


--Return--
(tensor([[[ 4....iewBackward0>), tensor([[[[0....maxBackward0>))
> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(109)forward()
    107         o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)
    108 
--> 109         if output_attentions: return (o, w)
    110         else: return (o,)
    111 



ipdb>  


--Call--
> /tmp/ipykernel_28793/505101603.py(46)dr_fused()
     44         return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)
     45 
---> 46     def dr_fused(self, embed:torch.Tensor, attention_mask:torch.Tensor):
     47         embed = self.dr_fused_head(embed)
     48         return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)



ipdb>  c


    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



## `OAK002`

In [105]:
#| export
class Encoder002(Encoder):
    
    def __init__(
        self, 
        config:PretrainedConfig, 
    ):
        super().__init__(config)
        self.post_init()

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
        meta_repr = {}
        
        for m_key, m_args in meta_kwargs.items():
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
            
            if len(idx):
                assert torch.all(m_args['data2ptr'][idx] == m_args['data2ptr'].max()), f'All datapoints should have same number of metadata.'
                
                if 'meta_repr' in m_args:
                    m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
                    m_repr_mask = m_repr_mask.bool()
                else:
                    m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
                    m_embed = self.meta_embeddings(m_input_ids)

                m_repr, m_repr_mask = m_embed.view(len(idx), -1, self.config.dim), m_attention_mask.view(len(idx), -1)
                meta_repr[m_key] = self.meta(m_embed, m_attention_mask)
                
                fused_embed, w = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=True)
                embed[idx] += fused_embed
               
        return embed, meta_repr
        

In [106]:
#| export
class OAK002(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.encoder = Encoder002(config)
        self.post_init(); self.remap_post_init();

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        

### Example

In [107]:
model = OAK002.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=True, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_cross_head()
model.encoder.set_meta_embeddings(model.distilbert.embeddings.word_embeddings.weight)

Some weights of OAK002 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [97]:
model = model.to('cuda')

In [98]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [99]:
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_identifier', 'lnk2data_input_text', 'lnk2data_input_ids', 'lnk2data_attention_mask', 'lnk2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx', 'lnk2lbl2data_idx', 'lnk2lbl2data_identifier', 'lnk2lbl2data_input_text', 'lnk2lbl2data_input_ids', 'lnk2lbl2data_attention_mask', 'lnk2lbl2data_data2ptr', 'lnk2lbl2data_lbl2data2ptr'])

In [103]:
o = model(**b.to(model.device))

> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/modeling_utils.py(83)forward()
     81         output_attentions:Optional[bool] = False,
     82     ):
---> 83         bs, q_len, dim = q.size()
     84         v, k_len = k, k.size(1)
     85 



ipdb>  c


In [104]:
o.loss

tensor(0.0132, device='cuda:0', grad_fn=<AddBackward0>)

In [108]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))

In [111]:
o = func()

> /tmp/ipykernel_28793/1795951242.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))



ipdb>  b model.encoder.forward


Breakpoint 8 at /tmp/ipykernel_28793/505101603.py:88


ipdb>  c


> /tmp/ipykernel_28793/505101603.py(97)forward()
     95         **kwargs
     96     ):  
---> 97         data_o = self.encode(data_input_ids, data_attention_mask)
     98 
     99         if data_type is not None and data_type == "meta":



ipdb>  n


> /tmp/ipykernel_28793/505101603.py(99)forward()
     97         data_o = self.encode(data_input_ids, data_attention_mask)
     98 
---> 99         if data_type is not None and data_type == "meta":
    100             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    101         else:



ipdb>  


> /tmp/ipykernel_28793/505101603.py(102)forward()
    100             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    101         else:
--> 102             data_repr = self.dr(data_o[0], data_attention_mask)
    103 
    104         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_28793/505101603.py(104)forward()
    102             data_repr = self.dr(data_o[0], data_attention_mask)
    103 
--> 104         data_fused_repr = meta_repr = None
    105         if data_aug_meta_prefix is not None:
    106             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_28793/505101603.py(105)forward()
    103 
    104         data_fused_repr = meta_repr = None
--> 105         if data_aug_meta_prefix is not None:
    106             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    107             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_28793/505101603.py(106)forward()
    104         data_fused_repr = meta_repr = None
    105         if data_aug_meta_prefix is not None:
--> 106             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    107             if len(meta_kwargs):
    108                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 



ipdb>  n


> /tmp/ipykernel_28793/505101603.py(107)forward()
    105         if data_aug_meta_prefix is not None:
    106             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
--> 107             if len(meta_kwargs):
    108                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    109                                                                             data_attention_mask,



ipdb>  n


> /tmp/ipykernel_28793/505101603.py(108)forward()
    106             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    107             if len(meta_kwargs):
--> 108                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    109                                                                             data_attention_mask,
    110                                                                             meta_kwargs)



ipdb>  


> /tmp/ipykernel_28793/505101603.py(109)forward()
    107             if len(meta_kwargs):
    108                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
--> 109                                                                             data_attention_mask,
    110                                                                             meta_kwargs)
    111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)



ipdb>  


> /tmp/ipykernel_28793/505101603.py(110)forward()
    108                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    109                                                                             data_attention_mask,
--> 110                                                                             meta_kwargs)
    111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    112 



ipdb>  


> /tmp/ipykernel_28793/505101603.py(108)forward()
    106             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    107             if len(meta_kwargs):
--> 108                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    109                                                                             data_attention_mask,
    110                                                                             meta_kwargs)



ipdb>  s


--Call--
> /tmp/ipykernel_28793/219820846.py(11)fuse_meta_into_embeddings()
      9         self.post_init()
     10 
---> 11     def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
     12         meta_repr = {}
     13 



ipdb>  n


> /tmp/ipykernel_28793/219820846.py(12)fuse_meta_into_embeddings()
     10 
     11     def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
---> 12         meta_repr = {}
     13 
     14         for m_key, m_args in meta_kwargs.items():



ipdb>  


> /tmp/ipykernel_28793/219820846.py(14)fuse_meta_into_embeddings()
     12         meta_repr = {}
     13 
---> 14         for m_key, m_args in meta_kwargs.items():
     15             idx = torch.where(m_args['data2ptr'] > 0)[0]
     16             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /tmp/ipykernel_28793/219820846.py(15)fuse_meta_into_embeddings()
     13 
     14         for m_key, m_args in meta_kwargs.items():
---> 15             idx = torch.where(m_args['data2ptr'] > 0)[0]
     16             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     17 



ipdb>  


> /tmp/ipykernel_28793/219820846.py(16)fuse_meta_into_embeddings()
     14         for m_key, m_args in meta_kwargs.items():
     15             idx = torch.where(m_args['data2ptr'] > 0)[0]
---> 16             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     17 
     18             if len(idx):



ipdb>  


> /tmp/ipykernel_28793/219820846.py(18)fuse_meta_into_embeddings()
     16             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     17 
---> 18             if len(idx):
     19                 assert torch.all(m_args['data2ptr'][idx] == m_args['data2ptr'].max()), f'All datapoints should have same number of metadata.'
     20 



ipdb>  


> /tmp/ipykernel_28793/219820846.py(19)fuse_meta_into_embeddings()
     17 
     18             if len(idx):
---> 19                 assert torch.all(m_args['data2ptr'][idx] == m_args['data2ptr'].max()), f'All datapoints should have same number of metadata.'
     20 
     21                 if 'meta_repr' in m_args:



ipdb>  


> /tmp/ipykernel_28793/219820846.py(21)fuse_meta_into_embeddings()
     19                 assert torch.all(m_args['data2ptr'][idx] == m_args['data2ptr'].max()), f'All datapoints should have same number of metadata.'
     20 
---> 21                 if 'meta_repr' in m_args:
     22                     m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
     23                     m_repr_mask = m_repr_mask.bool()



ipdb>  


> /tmp/ipykernel_28793/219820846.py(25)fuse_meta_into_embeddings()
     23                     m_repr_mask = m_repr_mask.bool()
     24                 else:
---> 25                     m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
     26                     m_embed = self.meta_embeddings(m_input_ids)
     27 



ipdb>  


> /tmp/ipykernel_28793/219820846.py(26)fuse_meta_into_embeddings()
     24                 else:
     25                     m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
---> 26                     m_embed = self.meta_embeddings(m_input_ids)
     27 
     28                 m_repr, m_repr_mask = m_embed.view(len(idx), -1, self.config.dim), m_attention_mask.view(len(idx), -1)



ipdb>  


> /tmp/ipykernel_28793/219820846.py(28)fuse_meta_into_embeddings()
     26                     m_embed = self.meta_embeddings(m_input_ids)
     27 
---> 28                 m_repr, m_repr_mask = m_embed.view(len(idx), -1, self.config.dim), m_attention_mask.view(len(idx), -1)
     29                 meta_repr[m_key] = self.meta(m_embed, m_attention_mask)
     30 



ipdb>  m_embed.shape


torch.Size([15, 14, 768])


ipdb>  n


> /tmp/ipykernel_28793/219820846.py(29)fuse_meta_into_embeddings()
     27 
     28                 m_repr, m_repr_mask = m_embed.view(len(idx), -1, self.config.dim), m_attention_mask.view(len(idx), -1)
---> 29                 meta_repr[m_key] = self.meta(m_embed, m_attention_mask)
     30 
     31                 fused_embed, w = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=True)



ipdb>  m_repr.shape


torch.Size([5, 42, 768])


ipdb>  m_repr_mask.shape


torch.Size([5, 42])


ipdb>  m_repr_mask


tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
         0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
         0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
         0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]])


ipdb>  n


> /tmp/ipykernel_28793/219820846.py(31)fuse_meta_into_embeddings()
     29                 meta_repr[m_key] = self.meta(m_embed, m_attention_mask)
     30 
---> 31                 fused_embed, w = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=True)
     32                 embed[idx] += fused_embed
     33 



ipdb>  meta_repr[m_key].shape


torch.Size([15, 768])


ipdb>  m_attention_mask.shape


torch.Size([15, 14])


ipdb>  m_repr_mask.shape


torch.Size([5, 42])


ipdb>  m_repr.shape


torch.Size([5, 42, 768])


ipdb>  embed.shape


torch.Size([5, 5, 768])


ipdb>  n


> /tmp/ipykernel_28793/219820846.py(32)fuse_meta_into_embeddings()
     30 
     31                 fused_embed, w = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=True)
---> 32                 embed[idx] += fused_embed
     33 
     34         return embed, meta_repr



ipdb>  fused_embed.shape


torch.Size([5, 5, 768])


ipdb>  n


> /tmp/ipykernel_28793/219820846.py(14)fuse_meta_into_embeddings()
     12         meta_repr = {}
     13 
---> 14         for m_key, m_args in meta_kwargs.items():
     15             idx = torch.where(m_args['data2ptr'] > 0)[0]
     16             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /tmp/ipykernel_28793/219820846.py(34)fuse_meta_into_embeddings()
     31                 fused_embed, w = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=True)
     32                 embed[idx] += fused_embed
     33 
---> 34         return embed, meta_repr
     35 



ipdb>  


--Return--
(tensor([[[-6....PutBackward0>), {'lnk2data': tensor([[-0.0...DivBackward0>)})
> /tmp/ipykernel_28793/219820846.py(34)fuse_meta_into_embeddings()
     31                 fused_embed, w = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=True)
     32                 embed[idx] += fused_embed
     33 
---> 34         return embed, meta_repr
     35 



ipdb>  


> /tmp/ipykernel_28793/505101603.py(111)forward()
    109                                                                             data_attention_mask,
    110                                                                             meta_kwargs)
--> 111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    112 
    113         return EncoderOutput(



ipdb>  c


> /tmp/ipykernel_28793/505101603.py(113)forward()
    111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    112 
--> 113         return EncoderOutput(
    114             rep=data_repr,
    115             fused_rep=data_fused_repr,



ipdb>  c


> /tmp/ipykernel_28793/505101603.py(114)forward()
    112 
    113         return EncoderOutput(
--> 114             rep=data_repr,
    115             fused_rep=data_fused_repr,
    116             meta_repr=meta_repr,



ipdb>  c


> /tmp/ipykernel_28793/505101603.py(115)forward()
    113         return EncoderOutput(
    114             rep=data_repr,
--> 115             fused_rep=data_fused_repr,
    116             meta_repr=meta_repr,
    117         )



ipdb>  n


> /tmp/ipykernel_28793/505101603.py(116)forward()
    114             rep=data_repr,
    115             fused_rep=data_fused_repr,
--> 116             meta_repr=meta_repr,
    117         )
    118 



ipdb>  


> /tmp/ipykernel_28793/505101603.py(113)forward()
    111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    112 
--> 113         return EncoderOutput(
    114             rep=data_repr,
    115             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...vBackward0>)})
> /tmp/ipykernel_28793/505101603.py(113)forward()
    111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    112 
--> 113         return EncoderOutput(
    114             rep=data_repr,
    115             fused_rep=data_fused_repr,



ipdb>  


--Call--
> <string>(2)__init__()



ipdb>  


> <string>(3)__init__()



ipdb>  


> <string>(4)__init__()



ipdb>  


> <string>(5)__init__()



ipdb>  


> <string>(6)__init__()



ipdb>  c


> /tmp/ipykernel_28793/505101603.py(97)forward()
     95         **kwargs
     96     ):  
---> 97         data_o = self.encode(data_input_ids, data_attention_mask)
     98 
     99         if data_type is not None and data_type == "meta":



ipdb>  


> /tmp/ipykernel_28793/505101603.py(113)forward()
    111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    112 
--> 113         return EncoderOutput(
    114             rep=data_repr,
    115             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_28793/505101603.py(114)forward()
    112 
    113         return EncoderOutput(
--> 114             rep=data_repr,
    115             fused_rep=data_fused_repr,
    116             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_28793/505101603.py(115)forward()
    113         return EncoderOutput(
    114             rep=data_repr,
--> 115             fused_rep=data_fused_repr,
    116             meta_repr=meta_repr,
    117         )



ipdb>  


> /tmp/ipykernel_28793/505101603.py(113)forward()
    111                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    112 
--> 113         return EncoderOutput(
    114             rep=data_repr,
    115             fused_rep=data_fused_repr,



ipdb>  


    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



In [112]:
o.loss

tensor(0.0134, grad_fn=<AddBackward0>)