In [1]:
#| default_exp models.oak

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
#| export
import torch, re, inspect, pickle, os, torch.nn as nn, math
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any, Union
from transformers import (
    PretrainedConfig,
    DistilBertForMaskedLM,
    DistilBertModel,
    DistilBertPreTrainedModel,
    DistilBertConfig,
)
from transformers.utils.generic import ModelOutput
from transformers.activations import get_activation

from fastcore.meta import *
from fastcore.utils import *

from xcai.losses import *
from xcai.core import store_attr
from xcai.learner import XCDataParallel
from xcai.models.modeling_utils import *

In [4]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [5]:
import copy
from transformers import AutoConfig
from xcai.block import *
from xcai.main import *

## Setup

In [6]:
from xcai.core import prepare_batch

In [7]:
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

In [8]:
data_dir = '/Users/suchith720/Projects/data'
config_file = 'wikiseealsotitles'
config_key = 'data_meta'

mname = 'sentence-transformers/msmarco-distilbert-base-v4'

pkl_dir = '/Users/suchith720/Projects/data/processed/mogicX'
os.makedirs(pkl_dir, exist_ok=True)

pkl_file = f'{pkl_dir}/wikiseealsotitles_data_distilbert-base-uncased_sxc.joblib'

In [9]:
block = build_block(pkl_file, config_file, True, config_key, data_dir=data_dir, n_slbl_samples=1, do_build=False, 
                    main_oversample=True, meta_oversample=True, return_scores=True)

In [10]:
block.train.dset.meta['neg_meta'] = copy.deepcopy(block.train.dset.meta['cat_meta'])
block.train.dset.meta['neg_meta'].prefix = 'neg'

In [11]:
batch = block.train.one_batch(4)

In [12]:
for i,batch in enumerate(block.train.dl):
    if i > 2: break

In [20]:
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_scores', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_scores', 'cat2data_data2ptr', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_scores', 'cat2lbl_lbl2ptr', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_data2ptr', 'pcat2lbl_data2ptr', 'pneg2data_idx', 'pneg2data_data2ptr', 'neg2data_idx', 'neg2data_scores', 'neg2data_data2ptr', 'neg2data_identifier', 'neg2data_input_text', 'neg2data_input_ids', 'neg2data_attention_mask', 'pneg2lbl_idx', 'pneg2lbl_lbl2ptr', 'neg2lbl_idx', 'neg2lbl_scores', 'neg2lbl_lbl2ptr', 'neg2lbl_identifier'

In [14]:
[o for o in batch.keys() if 'scores' in o]

['lbl2data_scores',
 'cat2data_scores',
 'cat2lbl_scores',
 'neg2data_scores',
 'neg2lbl_scores']

## CrossAttention

In [15]:
#| export
class CrossAttention(nn.Module):
    
    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config, self.n_h, self.dim = config, config.n_heads, config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        if self.dim % self.n_h != 0:
            raise ValueError(f"self.n_heads: {self.n_h} must divide self.dim: {self.dim} evenly.")
            
        self.q = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.o = nn.Linear(in_features=config.dim, out_features=config.dim)

    @torch.no_grad()
    def post_init(self):
        nn.init.eye_(self.q.weight.data)
        nn.init.eye_(self.k.weight.data)
        nn.init.eye_(self.v.weight.data)
        nn.init.eye_(self.o.weight.data)

    def forward(
        self, 
        q: torch.Tensor,
        q_m: torch.Tensor,
        k: torch.Tensor, 
        k_m: torch.Tensor,
        output_attentions:Optional[bool] = False,
    ):
        bs, q_len, dim = q.size()
        v, k_len = k, k.size(1) 

        h_dim = self.dim//self.n_h

        def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)

        def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)

        q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
        k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
        v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)

        q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)
        sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
        
        q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
        mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
        
        sc = sc.masked_fill(mask == 0, torch.tensor(torch.finfo(sc.dtype).min))  # (bs, n_h, q_len, k_len)

        w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
        w = self.dropout(w)  # (bs, n_h, q_len, k_len)

        o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)
        
        if output_attentions: return (o, w)
        else: return (o,)
        

### Example

In [None]:
config = AutoConfig.from_pretrained('distilbert-base-uncased')
fuser = CrossAttention(config)

In [None]:
bsz, data_seq_len, n_meta, dim, dtype = 2, 3, 2, config.dim, torch.float32
data, meta = torch.randn(bsz, data_seq_len, dim, dtype=dtype), torch.randn(bsz, n_meta, dim, dtype=dtype)
data_mask = torch.randint(0, 2, size=(bsz,data_seq_len), dtype=dtype)
meta_mask = torch.randint(0, 2, size=(bsz,n_meta), dtype=dtype)

In [None]:
o = fuser(data, data_mask, meta, meta_mask)

In [None]:
o[0].shape

torch.Size([2, 3, 768])

## NormCrossAttention

In [None]:
#| export
class NormCrossAttention(nn.Module):
    
    def __init__(self, config: PretrainedConfig, tau:Optional[float]=0.1, dropout:Optional[float]=0.1):
        super().__init__()
        self.tau = nn.Parameter(torch.tensor(tau, dtype=torch.float32))
        
        self.config, self.n_h, self.dim = config, config.n_heads, config.dim
        self.dropout = nn.Dropout(p=dropout)

        if self.dim % self.n_h != 0:
            raise ValueError(f"self.n_heads: {self.n_h} must divide self.dim: {self.dim} evenly.")
            
        self.q = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.o = nn.Linear(in_features=config.dim, out_features=config.dim)

    def post_init(self):
        self.q.weight.data = torch.eye(self.q.out_features, self.q.in_features, dtype=self.q.weight.dtype)
        self.k.weight.data = torch.eye(self.k.out_features, self.k.in_features, dtype=self.k.weight.dtype)
        self.v.weight.data = torch.eye(self.v.out_features, self.v.in_features, dtype=self.v.weight.dtype)
        self.o.weight.data = torch.eye(self.o.out_features, self.o.in_features, dtype=self.o.weight.dtype)

    def forward(
        self, 
        q: torch.Tensor,
        q_m: torch.Tensor,
        k: torch.Tensor, 
        k_m: torch.Tensor,
        output_attentions:Optional[bool] = False,
    ):
        bs, q_len, dim = q.size()
        v, k_len = k, k.size(1) 

        h_dim = self.dim//self.n_h

        def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)

        def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)

        q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
        k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
        v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)

        q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)
        sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
        sc = sc * self.tau
        
        q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
        mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
        
        sc = sc.masked_fill(mask == 0, torch.tensor(torch.finfo(sc.dtype).min))  # (bs, n_h, q_len, k_len)
        
        w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
        w = self.dropout(w)  # (bs, n_h, q_len, k_len)

        o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)
        
        if output_attentions: return (o, w)
        else: return (o,)
        

### Example

In [None]:
config = AutoConfig.from_pretrained('distilbert-base-uncased')
fuser = NormCrossAttention(config)

In [None]:
bsz, data_seq_len, n_meta, dim, dtype = 2, 3, 2, config.dim, torch.float32

data, meta = torch.randn(bsz, data_seq_len, dim, dtype=dtype), torch.randn(bsz, n_meta, dim, dtype=dtype)
data_mask = torch.randint(0, 2, size=(bsz,data_seq_len), dtype=dtype)
meta_mask = torch.randint(0, 2, size=(bsz,n_meta), dtype=dtype)

In [None]:
o = fuser(data, data_mask, meta, meta_mask)

In [None]:
o[0].shape

torch.Size([2, 3, 768])

## Encoder

In [16]:
#| export
class Encoder(DistilBertPreTrainedModel):
    
    def __init__(
        self, 
        config:PretrainedConfig, 
        num_metadata:int,
        resize_length:Optional[int]=None,
        normalize:Optional[bool]=True,
    ):
        super().__init__(config)
        store_attr('normalize')
        self.distilbert = DistilBertModel(config)
        
        self.dr_head = RepresentationHead(config)
        self.dr_fused_head = RepresentationHead(config)
        
        self.meta_head = RepresentationHead(config)
        
        self.cross_head = CrossAttention(config)
        
        self.meta_embeddings = nn.Embedding(num_metadata, config.dim)

        if resize_length is None: self.ones = None
        else: self.register_buffer('ones', torch.ones(resize_length, dtype=torch.long, device=self.device))
        
        self.post_init()

    def freeze_meta_embeddings(self):
        self.meta_embeddings.requires_grad_(False)

    def unfreeze_meta_embeddings(self):
        self.meta_embeddings.requires_grad_(True)

    @torch.no_grad()
    def set_meta_embeddings(self, embed:torch.Tensor):
        self.meta_embeddings.weight.data.copy_(embed)
        
    def get_position_embeddings(self) -> nn.Embedding:
        return self.distilbert.get_position_embeddings()
    
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)
    
    def encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        return self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
    
    def dr(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.dr_head(embed)
        embed = Pooling.mean_pooling(embed, attention_mask)
        return F.normalize(embed, dim=1) if self.normalize else embed

    def dr_fused(self, embed:torch.Tensor):
        embed = self.dr_fused_head(embed)
        return F.normalize(embed, dim=1) if self.normalize else embed

    def meta(self, embed:torch.Tensor, attention_mask:torch.Tensor, normalize:Optional[bool]=True):
        embed = self.meta_head(embed)
        embed = Pooling.mean_pooling(embed, attention_mask)
        return F.normalize(embed, dim=1) if normalize else embed

    def resize(self, idx:torch.Tensor, num_inputs:torch.Tensor):
        if torch.any(num_inputs == 0): raise ValueError("`num_inputs` should be non-zero positive integer.")
        bsz, total_num_inputs = num_inputs.shape[0], idx.shape[0]
        
        self.ones = self.ones.to(idx.device)
        ones = (
            torch.ones(total_num_inputs, dtype=torch.long, device=idx.device) 
            if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
        )

        max_num_inputs = num_inputs.max()
        if (num_inputs == max_num_inputs).all():
            return idx,ones
        
        xnum_inputs = max_num_inputs-num_inputs+1

        inputs_ptr = num_inputs.cumsum(dim=0)-1
        repeat_inputs = ones.scatter(0, inputs_ptr, xnum_inputs)
        
        resized_idx = idx.repeat_interleave(repeat_inputs, dim=0)
        ignore_mask = ones.scatter(0, inputs_ptr, 0).repeat_interleave(repeat_inputs, dim=0).view(bsz, -1)
        ignore_mask[:, -1] = 1; ignore_mask = ignore_mask.flatten()
        
        return resized_idx,ignore_mask

    
    def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
        meta_repr = {}
        
        data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
        for m_key, m_args in meta_kwargs.items():
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
            
            if len(idx):
                m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
                m_repr = self.meta_embeddings(m_idx)
                m_repr = F.normalize(m_repr, dim=1) if self.normalize else m_repr
                
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
                meta_repr[m_key] = m_repr[m_repr_mask]
                
                fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
                data_fused_repr[idx] += fused_repr
                
        return data_fused_repr.squeeze(), meta_repr

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):  
        data_o = self.encode(data_input_ids, data_attention_mask)
        
        if data_type is not None and data_type == "meta":
            data_repr = self.meta(data_o[0], data_attention_mask, not data_unnormalized)
        else: 
            data_repr = self.dr(data_o[0], data_attention_mask)
        
        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
                                                                            torch.any(data_attention_mask, dim=1), 
                                                                            meta_kwargs)
                data_fused_repr = self.dr_fused(data_fused_repr)
                
        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,
        )
        

## `OAK000`

In [21]:
#| export
class OAK000(nn.Module):
    
    def __init__(
        self, config,

        data_aug_meta_prefix:Optional[str]=None, 
        lbl2data_aug_meta_prefix:Optional[str]=None, 

        data_pred_meta_prefix:Optional[str]=None,
        lbl2data_pred_meta_prefix:Optional[str]=None,
        
        margin:Optional[float]=0.3,
        num_negatives:Optional[int]=5,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=True,

        calib_margin:Optional[float]=0.3,
        calib_num_negatives:Optional[int]=10,
        calib_tau:Optional[float]=0.1,
        calib_apply_softmax:Optional[bool]=False,
        calib_loss_weight:Optional[float]=0.1,
        use_calib_loss:Optional[float]=False,
        
        meta_loss_weight:Optional[Union[List,float]]=0.3,
        
        use_fusion_loss:Optional[bool]=False,
        fusion_loss_weight:Optional[float]=0.15,

        use_query_loss:Optional[float]=True,
        
        use_encoder_parallel:Optional[bool]=True,
        
        normalize:Optional[bool]=True,
    ):
        super().__init__(config)
        store_attr('meta_loss_weight,fusion_loss_weight,calib_loss_weight')
        store_attr('data_pred_meta_prefix,lbl2data_pred_meta_prefix')
        store_attr('data_aug_meta_prefix,lbl2data_aug_meta_prefix')
        store_attr('use_fusion_loss,use_query_loss,use_calib_loss,use_encoder_parallel')
        store_attr('normalize')
        
        self.encoder = None
        self.rep_loss_fn = MultiTriplet(margin=margin, n_negatives=num_negatives, tau=tau, 
                                        apply_softmax=apply_softmax, reduce='mean')
        self.cab_loss_fn = Calibration(margin=calib_margin, tau=calib_tau, n_negatives=calib_num_negatives, 
                                       apply_softmax=calib_apply_softmax, reduce='mean')
        
    @torch.no_grad()
    def init_retrieval_head(self):
        assert self.encoder is not None, "`self.encoder` is not initialized."
        self.encoder.dr_head.post_init()
        self.encoder.meta_head.post_init()
        self.encoder.dr_fused_head.post_init()

    @torch.no_grad()
    def init_cross_head(self):
        assert self.encoder is not None, "`self.encoder` is not initialized."
        self.encoder.cross_head.post_init()

    def compute_loss(self, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        return self.rep_loss_fn(inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def calibration_loss(self, einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        return self.calib_loss_weight * self.cab_loss_fn(einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)
    
    def compute_meta_loss(self, data_repr, lbl2data_repr, **kwargs):
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
            
        data_meta_inputs = Parameters.from_meta_pred_prefix(self.data_pred_meta_prefix, **kwargs)
        lbl2data_meta_inputs = Parameters.from_meta_pred_prefix(self.lbl2data_pred_meta_prefix, **kwargs)
        meta_inputs = {**data_meta_inputs, **lbl2data_meta_inputs}

        m_lw = Parameters.get_meta_loss_weights(self.meta_loss_weight, len(meta_inputs)) if len(meta_inputs) else []
        
        loss = 0.0
        for inputs,lw in zip(meta_inputs.values(), m_lw):
            if 'lbl2data2ptr' in inputs:
                idx = torch.where(inputs['lbl2data2ptr'])[0]
                if len(idx) > 0:
                    inputs_o = encoder(data_input_ids=inputs['input_ids'], data_attention_mask=inputs['attention_mask'], 
                                       data_type="meta")
                    m_loss = self.rep_loss_fn(lbl2data_repr[idx], inputs_o.rep, inputs['lbl2data2ptr'][idx],
                                              inputs['idx'], inputs['plbl2data2ptr'][idx], inputs['pidx'])
                    loss += lw * m_loss

            elif 'data2ptr' in inputs:
                idx = torch.where(inputs['data2ptr'])[0]
                if len(idx) > 0:
                    inputs_o = encoder(data_input_ids=inputs['input_ids'], data_attention_mask=inputs['attention_mask'], 
                                       data_type="meta")
                    m_loss = self.rep_loss_fn(data_repr[idx], inputs_o.rep, inputs['data2ptr'][idx], inputs['idx'], 
                                              inputs['pdata2ptr'][idx], inputs['pidx'])
                    loss += lw * m_loss       

            else: raise ValueError('Invalid metadata input arguments.')
        return loss

    def compute_fusion_loss(self, data_repr, meta_repr:Dict, prefix:str, **kwargs):
        meta_inputs = Parameters.from_meta_pred_prefix(prefix, **kwargs)
        
        loss = 0.0
        if meta_repr is not None:
            for key,input_repr in meta_repr.items():
                inputs = meta_inputs[key]
                if 'lbl2data2ptr' in inputs:
                    idx = torch.where(inputs['lbl2data2ptr'])[0]
                    if len(idx) > 0:
                        m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['lbl2data2ptr'][idx],
                                                  inputs['idx'], inputs['plbl2data2ptr'][idx], inputs['pidx'])
                        loss += self.fusion_loss_weight * m_loss
    
                elif 'data2ptr' in inputs:
                    idx = torch.where(inputs['data2ptr'])[0]
                    if len(idx) > 0:
                        m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['data2ptr'][idx], inputs['idx'], 
                                                  inputs['pdata2ptr'][idx], inputs['pidx'])
                        loss += self.fusion_loss_weight * m_loss       
    
                else: raise ValueError('Invalid metadata input arguments.')
        return loss


    def get_meta_representation(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
            
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_unnormalized=True, data_type="meta")
        return XCModelOutput(
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
        )
        
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
            
            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
            
            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
            
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        

## `OAK001`

In [14]:
#| export
class OAK001(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(
        self, 
        config,
        num_metadata:int,
        resize_length:Optional[int]=None,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.encoder = Encoder(config, num_metadata=num_metadata, resize_length=resize_length)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        

### Example

In [58]:
model = OAK001.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=True, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=True,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_cross_head()

Some weights of OAK001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [59]:
model.encoder.set_meta_embeddings(torch.zeros(656086, 768))

In [60]:
model = model.to('cuda')

In [61]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [62]:
o = model(**b.to(model.device))

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [63]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))
    

In [64]:
o = func()

> /tmp/ipykernel_5515/3657616883.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))
      4 



ipdb>  b model.forward


Breakpoint 1 at /tmp/ipykernel_5515/3933810934.py:143


ipdb>  b model.encoder.forward


Breakpoint 2 at /tmp/ipykernel_5515/4122862164.py:106


ipdb>  b model.encoder.fuse_meta_into_embeddings


Breakpoint 3 at /tmp/ipykernel_5515/4122862164.py:86


ipdb>  c


> /tmp/ipykernel_5515/3933810934.py(158)forward()
    156         **kwargs
    157     ):  
--> 158         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    159 
    160         if self.use_encoder_parallel:



ipdb>  c


> /tmp/ipykernel_5515/4122862164.py(115)forward()
    113         **kwargs
    114     ):  
--> 115         data_o = self.encode(data_input_ids, data_attention_mask)
    116 
    117         if data_type is not None and data_type == "meta":



ipdb>  c


> /tmp/ipykernel_5515/4122862164.py(87)fuse_meta_into_embeddings()
     85 
3    86     def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
---> 87         meta_repr = {}
     88 
     89         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)



ipdb>  n


> /tmp/ipykernel_5515/4122862164.py(89)fuse_meta_into_embeddings()
     87         meta_repr = {}
     88 
---> 89         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
     90         for m_key, m_args in meta_kwargs.items():
     91             idx = torch.where(m_args['data2ptr'] > 0)[0]



ipdb>  


> /tmp/ipykernel_5515/4122862164.py(90)fuse_meta_into_embeddings()
     88 
     89         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
---> 90         for m_key, m_args in meta_kwargs.items():
     91             idx = torch.where(m_args['data2ptr'] > 0)[0]
     92             meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)



ipdb>  data_fused_repr.shape


torch.Size([5, 1, 768])


ipdb>  n


> /tmp/ipykernel_5515/4122862164.py(91)fuse_meta_into_embeddings()
     89         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
     90         for m_key, m_args in meta_kwargs.items():
---> 91             idx = torch.where(m_args['data2ptr'] > 0)[0]
     92             meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
     93 



ipdb>  


> /tmp/ipykernel_5515/4122862164.py(92)fuse_meta_into_embeddings()
     90         for m_key, m_args in meta_kwargs.items():
     91             idx = torch.where(m_args['data2ptr'] > 0)[0]
---> 92             meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
     93 
     94             if len(idx):



ipdb>  


> /tmp/ipykernel_5515/4122862164.py(94)fuse_meta_into_embeddings()
     92             meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
     93 
---> 94             if len(idx):
     95                 m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
     96                 m_repr = F.normalize(self.meta_embeddings(m_idx), dim=1)



ipdb>  


> /tmp/ipykernel_5515/4122862164.py(95)fuse_meta_into_embeddings()
     93 
     94             if len(idx):
---> 95                 m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
     96                 m_repr = F.normalize(self.meta_embeddings(m_idx), dim=1)
     97 



ipdb>  


> /tmp/ipykernel_5515/4122862164.py(96)fuse_meta_into_embeddings()
     94             if len(idx):
     95                 m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
---> 96                 m_repr = F.normalize(self.meta_embeddings(m_idx), dim=1)
     97 
     98                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)



ipdb>  m_idx


tensor([161850,  72729, 161852, 120509, 306634, 355897, 110482,  68899, 102557,
         74011,  93075, 272047,  84732,  68113,  84876], device='cuda:0')


ipdb>  m_repr_mask


tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')


ipdb>  n


> /tmp/ipykernel_5515/4122862164.py(98)fuse_meta_into_embeddings()
     96                 m_repr = F.normalize(self.meta_embeddings(m_idx), dim=1)
     97 
---> 98                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
     99                 meta_repr[m_key] = m_repr[m_repr_mask]
    100 



ipdb>  m_repr.shape


torch.Size([15, 768])


ipdb>  n


> /tmp/ipykernel_5515/4122862164.py(99)fuse_meta_into_embeddings()
     97 
     98                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
---> 99                 meta_repr[m_key] = m_repr[m_repr_mask]
    100 
    101                 fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]



ipdb>  m_repr.shape


torch.Size([5, 3, 768])


ipdb>  m_repr_mask.shape


torch.Size([5, 3])


ipdb>  n


> /tmp/ipykernel_5515/4122862164.py(101)fuse_meta_into_embeddings()
     99                 meta_repr[m_key] = m_repr[m_repr_mask]
    100 
--> 101                 fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
    102                 data_fused_repr[idx] += fused_repr
    103 



ipdb>  q


In [None]:
o.loss

tensor(0.0670, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK002`

In [None]:
#| export
class Encoder002(Encoder):

    def __init__(
        self, 
        config,
        cross_tau:Optional[float]=0.1, 
        cross_dropout:Optional[float]=0.1, 
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.cross_head = NormCrossAttention(config, tau=cross_tau, dropout=cross_dropout)
        self.post_init()
    

In [None]:
#| export
class OAK002(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(
        self, 
        config,
        num_metadata:int,
        resize_length:Optional[int]=None,
        
        cross_tau:Optional[float]=0.1,
        cross_dropout:Optional[float]=0.1,
        
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.encoder = Encoder002(config, cross_tau=cross_tau, cross_dropout=cross_dropout,
                                  num_metadata=num_metadata, resize_length=resize_length)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        

### Example

In [None]:
model = OAK002.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=True, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               cross_tau=1.0, cross_dropout=0.1,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=True,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_cross_head()

Some weights of OAK002 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.tau', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_

In [None]:
model.encoder.set_meta_embeddings(torch.zeros(656086, 768))

In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
o = model(**b.to(model.device))

> /tmp/ipykernel_36176/2795290057.py(33)forward()
     31         output_attentions:Optional[bool] = False,
     32     ):
---> 33         bs, q_len, dim = q.size()
     34         v, k_len = k, k.size(1)
     35 



ipdb>  c


> /tmp/ipykernel_36176/2795290057.py(34)forward()
     32     ):
     33         bs, q_len, dim = q.size()
---> 34         v, k_len = k, k.size(1)
     35 
     36         h_dim = self.dim//self.n_h



ipdb>  c


In [None]:
o.loss

tensor(0.0638, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK003`

In [22]:
#| export
class Encoder003(Encoder):

    def __init__(
        self, 
        config,
        num_metadata:int,
        **kwargs
    ):
        super().__init__(config, num_metadata=num_metadata, **kwargs)
        self.pretrained_meta_embeddings = nn.Embedding(num_metadata, config.dim)
        self.post_init()

    def freeze_pretrained_meta_embeddings(self):
        self.pretrained_meta_embeddings.requires_grad_(False)

    def unfreeze_pretrained_meta_embeddings(self):
        self.pretrained_meta_embeddings.requires_grad_(True)

    @torch.no_grad()
    def set_pretrained_meta_embeddings(self, embed:torch.Tensor):
        self.pretrained_meta_embeddings.weight.data.copy_(embed)

    @torch.no_grad()
    def init_meta_embeddings(self):
        nn.init.zeros_(self.meta_embeddings.weight.data)

    def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
        meta_repr = {}
        
        data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
        for m_key, m_args in meta_kwargs.items():
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
            
            if len(idx):
                m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
                m_repr = self.meta_embeddings(m_idx) + self.pretrained_meta_embeddings(m_idx)
                m_repr = F.normalize(m_repr, dim=1) if self.normalize else m_repr
                
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
                meta_repr[m_key] = m_repr[m_repr_mask]
                
                fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
                data_fused_repr[idx] += fused_repr
                
        return data_fused_repr.squeeze(), meta_repr
    

In [23]:
#| export
class OAK003(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(
        self, 
        config,
        num_metadata:int,
        resize_length:Optional[int]=None,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.encoder = Encoder003(config, num_metadata=num_metadata, resize_length=resize_length)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    @torch.no_grad()
    def init_meta_embeddings(self):
        self.encoder.init_meta_embeddings()

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        

### Example

In [None]:
model = OAK003.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=True,
                               
                               use_encoder_parallel=False)

model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()

Some weights of OAK003 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [None]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [None]:
o = model(**b.to(model.device))

> /tmp/ipykernel_26893/4002658968.py(65)resize()
     63         #debug
     64 
---> 65         if torch.any(num_inputs == 0): raise ValueError("`num_inputs` should be non-zero positive integer.")
     66         bsz, total_num_inputs = num_inputs.shape[0], idx.shape[0]
     67 



ipdb>  num_inputs


tensor([3, 3, 3, 3, 3])


ipdb>  n


> /tmp/ipykernel_26893/4002658968.py(66)resize()
     64 
     65         if torch.any(num_inputs == 0): raise ValueError("`num_inputs` should be non-zero positive integer.")
---> 66         bsz, total_num_inputs = num_inputs.shape[0], idx.shape[0]
     67 
     68         self.ones = self.ones.to(idx.device)



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(68)resize()
     66         bsz, total_num_inputs = num_inputs.shape[0], idx.shape[0]
     67 
---> 68         self.ones = self.ones.to(idx.device)
     69         ones = (
     70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(71)resize()
     69         ones = (
     70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
---> 71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
     72         )
     73 



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(70)resize()
     68         self.ones = self.ones.to(idx.device)
     69         ones = (
---> 70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
     71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
     72         )



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(71)resize()
     69         ones = (
     70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
---> 71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
     72         )
     73 



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(70)resize()
     68         self.ones = self.ones.to(idx.device)
     69         ones = (
---> 70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
     71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
     72         )



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(71)resize()
     69         ones = (
     70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
---> 71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
     72         )
     73 



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(69)resize()
     67 
     68         self.ones = self.ones.to(idx.device)
---> 69         ones = (
     70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
     71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(74)resize()
     72         )
     73 
---> 74         max_num_inputs = num_inputs.max()
     75         if (num_inputs == max_num_inputs).all():
     76             return idx,ones



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(75)resize()
     73 
     74         max_num_inputs = num_inputs.max()
---> 75         if (num_inputs == max_num_inputs).all():
     76             return idx,ones
     77 



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(76)resize()
     74         max_num_inputs = num_inputs.max()
     75         if (num_inputs == max_num_inputs).all():
---> 76             return idx,ones
     77 
     78         xnum_inputs = max_num_inputs-num_inputs+1



ipdb>  idx.shape


torch.Size([15])


ipdb>  ones.shape


torch.Size([15])


ipdb>  n


--Return--
(tensor([16184...4875,  84879]), tensor([1, 1,..., 1, 1, 1, 1]))
> /tmp/ipykernel_26893/4002658968.py(76)resize()
     74         max_num_inputs = num_inputs.max()
     75         if (num_inputs == max_num_inputs).all():
---> 76             return idx,ones
     77 
     78         xnum_inputs = max_num_inputs-num_inputs+1



ipdb>  n


> /tmp/ipykernel_26893/3123686291.py(36)fuse_meta_into_embeddings()
     34             if len(idx):
     35                 m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
---> 36                 m_repr = F.normalize(self.meta_embeddings(m_idx) + self.pretrained_meta_embeddings(m_idx), dim=1)
     37 
     38                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)



ipdb>  m_idx.shape


torch.Size([15])


ipdb>  m_repr_mask.shape


torch.Size([15])


ipdb>  n


> /tmp/ipykernel_26893/3123686291.py(38)fuse_meta_into_embeddings()
     36                 m_repr = F.normalize(self.meta_embeddings(m_idx) + self.pretrained_meta_embeddings(m_idx), dim=1)
     37 
---> 38                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
     39                 meta_repr[m_key] = m_repr[m_repr_mask]
     40 



ipdb>  c


  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
o.loss

tensor(0.0675, grad_fn=<AddBackward0>)

## `OAK004`

In [None]:
#| export
class Encoder004(Encoder003):

    def __init__(
        self, 
        config,
        cross_tau:Optional[float]=0.1, 
        cross_dropout:Optional[float]=0.1, 
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.cross_head = NormCrossAttention(config, tau=cross_tau, dropout=cross_dropout)
        self.post_init()
    

In [None]:
#| export
class OAK004(OAK003, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK003.__init__)
    def __init__(
        self, 
        config,
        num_metadata:int,
        resize_length:Optional[int]=None,
        
        cross_tau:Optional[float]=1.0,
        cross_dropout:Optional[float]=0.1,
        
        **kwargs
    ):
        super().__init__(config, **kwargs, num_metadata=num_metadata, resize_length=resize_length)
        self.encoder = Encoder004(config, cross_tau=cross_tau, cross_dropout=cross_dropout,
                                  num_metadata=num_metadata, resize_length=resize_length)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert

### Example

In [None]:
model = OAK004.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               cross_tau=1.0, cross_dropout=0.1,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)
model.init_retrieval_head()
model.init_cross_head()

Some weights of OAK004 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.tau', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_

In [None]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

model.init_meta_embeddings()

In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
o = model(**b.to(model.device))

In [None]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))
    

In [None]:
o.loss

tensor(0.0618, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK005`

In [None]:
#| export
class Encoder005(Encoder003):

    def __init__(
        self, 
        config,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.gen_head = GenerationHead(config)
        
    def get_output_embeddings(self) -> nn.Module:
        return self.gen_head.projector

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.gen_head.projector = new_embeddings
    
    def gen(self, x:torch.Tensor):
        return self.gen_head(x)

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,

        data_gen_idx:Optional[torch.Tensor]=None,
        
        **kwargs
    ):  
        data_o = self.encode(data_input_ids, data_attention_mask)
        
        if data_type is not None and data_type == "meta":
            data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
        else: 
            data_repr = self.dr(data_o[0], data_attention_mask)
        
        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
                                                                            torch.any(data_attention_mask, dim=1), 
                                                                            meta_kwargs)
                data_fused_repr = self.dr_fused(data_fused_repr)

        data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx]) 
                
        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,
            
            logits=data_logits,
        )
    

In [None]:
#| export
class OAK005(OAK003, DistilBertPreTrainedModel):
    use_generation,use_representation = True,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK003.__init__)
    def __init__(
        self, 
        config,
        num_metadata:int,
        resize_length:Optional[int]=None,

        num_batch_labels:Optional[int]=None, 
        ignore_token:Optional[int]=0,
        gen_loss_weight:Optional[float]=1.0,
        use_gen_loss:Optional[bool]=True,
        
        **kwargs
    ):
        super().__init__(config, **kwargs, num_metadata=num_metadata, resize_length=resize_length, num_batch_labels=num_batch_labels)
        store_attr('use_gen_loss')
        self.g_lw = gen_loss_weight
        self.encoder = Encoder005(config, num_metadata=num_metadata, resize_length=resize_length)

        self.gen_loss_fn = MultiCrossEntropy(tn_targ=num_batch_labels, ig_tok=ignore_token, reduce='mean')
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert

    def compute_gen_loss(self, inp_logits, targ_logits, inp_input_ids, targ_input_ids, targ_ptr):
        gen_loss = self.gen_loss_fn(inp_logits, targ_input_ids, targ_ptr) + self.gen_loss_fn(targ_logits, inp_input_ids)
        return self.g_lw * gen_loss

    def init_generation_head(self):
        self.encoder.gen_head.projector.weight.data = self.get_input_embeddings().weight.data.clone()

    def get_last_item_mask(self, num_input:torch.Tensor, input_sz:int):
        idx = torch.where(num_input > 0)[0]
        input_ptr = num_input[idx].cumsum(dim=0)-1
        return torch.zeros(input_sz, dtype=torch.bool, device=num_input.device).scatter(0, input_ptr, 1)

    def freeze(self):
        for n,p in self.named_parameters():
            p.requires_grad_(False)

    def unfreeze(self):
        for n,p in self.named_parameters():
            p.requires_grad_(True)

    def unfreeze_head(self):
        for n,p in self.encoder.gen_head.named_parameters():
            p.requires_grad_(True)

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_gen_idx = self.get_last_item_mask(lbl2data_data2ptr, len(lbl2data_idx))
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, data_gen_idx=lbl2data_gen_idx,
                                 **lbl2data_meta_kwargs)
            
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_gen_loss:
                loss += self.compute_gen_loss(data_o.logits, lbl2data_o.logits, data_input_ids,lbl2data_input_ids,
                                              lbl2data_data2ptr)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
            
            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
            
            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
            
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,

            logits=data_o.logits,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
    

### Example

In [None]:
model = OAK005.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,
                               
                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,

                               ignore_token=0, gen_loss_weight=1.0, use_gen_loss=True,
                               
                               use_encoder_parallel=False)

model.init_generation_head()

Some weights of OAK005 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.gen_head.layer_norm.bias', 'encoder.gen_head.layer_norm.weight', 'encoder.gen_head.projector.bias', 'encoder.gen_head.projector.weight', 'enc

In [None]:
model.freeze()
model.unfreeze_head()

In [None]:
model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()

model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
o = model(**b.to(model.device))

> /tmp/ipykernel_3071/2330200620.py(58)forward()
     56         **kwargs
     57     ):  
---> 58         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     59 
     60         if self.use_encoder_parallel:



ipdb>  c


> /tmp/ipykernel_3071/1501785349.py(33)forward()
     31         **kwargs
     32     ):  
---> 33         data_o = self.encode(data_input_ids, data_attention_mask)
     34 
     35         if data_type is not None and data_type == "meta":



ipdb>  c


> /tmp/ipykernel_3071/1501785349.py(33)forward()
     31         **kwargs
     32     ):  
---> 33         data_o = self.encode(data_input_ids, data_attention_mask)
     34 
     35         if data_type is not None and data_type == "meta":



ipdb>  c


In [None]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))
    

In [None]:
o = func()

> /tmp/ipykernel_3071/3657616883.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))
      4 



ipdb>  b model.forward


Breakpoint 1 at /tmp/ipykernel_3071/2330200620.py:43


ipdb>  b model.encoder.forward


Breakpoint 2 at /tmp/ipykernel_3071/1501785349.py:21


ipdb>  c


> /tmp/ipykernel_3071/2330200620.py(58)forward()
     56         **kwargs
     57     ):  
---> 58         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     59 
     60         if self.use_encoder_parallel:



ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(60)forward()
     58         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     59 
---> 60         if self.use_encoder_parallel:
     61             encoder = XCDataParallel(module=self.encoder)
     62         else: encoder = self.encoder



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(62)forward()
     60         if self.use_encoder_parallel:
     61             encoder = XCDataParallel(module=self.encoder)
---> 62         else: encoder = self.encoder
     63 
     64         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(64)forward()
     62         else: encoder = self.encoder
     63 
---> 64         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     65         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     66                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(65)forward()
     63 
     64         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 65         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     66                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     67 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(66)forward()
     64         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     65         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 66                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     67 
     68 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(65)forward()
     63 
     64         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 65         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     66                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     67 



ipdb>  s


> /tmp/ipykernel_3071/2330200620.py(66)forward()
     64         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     65         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 66                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     67 
     68 



ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(65)forward()
     63 
     64         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 65         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     66                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     67 



ipdb>  s


--Call--
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1507)_wrapped_call_impl()
   1505         return result
   1506 
-> 1507     def _wrapped_call_impl(self, *args, **kwargs):
   1508         if self._compiled_call_impl is not None:
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]



ipdb>  c


> /tmp/ipykernel_3071/1501785349.py(33)forward()
     31         **kwargs
     32     ):  
---> 33         data_o = self.encode(data_input_ids, data_attention_mask)
     34 
     35         if data_type is not None and data_type == "meta":



ipdb>  n


> /tmp/ipykernel_3071/1501785349.py(35)forward()
     33         data_o = self.encode(data_input_ids, data_attention_mask)
     34 
---> 35         if data_type is not None and data_type == "meta":
     36             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
     37         else:



ipdb>  data_o.shape


*** AttributeError: 'BaseModelOutput' object has no attribute 'shape'


ipdb>  data_o[0].shape


torch.Size([5, 5, 768])


ipdb>  n


> /tmp/ipykernel_3071/1501785349.py(38)forward()
     36             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
     37         else:
---> 38             data_repr = self.dr(data_o[0], data_attention_mask)
     39 
     40         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(40)forward()
     38             data_repr = self.dr(data_o[0], data_attention_mask)
     39 
---> 40         data_fused_repr = meta_repr = None
     41         if data_aug_meta_prefix is not None:
     42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(41)forward()
     39 
     40         data_fused_repr = meta_repr = None
---> 41         if data_aug_meta_prefix is not None:
     42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     43             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(42)forward()
     40         data_fused_repr = meta_repr = None
     41         if data_aug_meta_prefix is not None:
---> 42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     43             if len(meta_kwargs):
     44                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(43)forward()
     41         if data_aug_meta_prefix is not None:
     42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
---> 43             if len(meta_kwargs):
     44                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
     45                                                                             torch.any(data_attention_mask, dim=1),



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(44)forward()
     42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     43             if len(meta_kwargs):
---> 44                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
     45                                                                             torch.any(data_attention_mask, dim=1),
     46                                                                             meta_kwargs)



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(45)forward()
     43             if len(meta_kwargs):
     44                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
---> 45                                                                             torch.any(data_attention_mask, dim=1),
     46                                                                             meta_kwargs)
     47                 data_fused_repr = self.dr_fused(data_fused_repr)



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(46)forward()
     44                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
     45                                                                             torch.any(data_attention_mask, dim=1),
---> 46                                                                             meta_kwargs)
     47                 data_fused_repr = self.dr_fused(data_fused_repr)
     48 



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(44)forward()
     42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     43             if len(meta_kwargs):
---> 44                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
     45                                                                             torch.any(data_attention_mask, dim=1),
     46                                                                             meta_kwargs)



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(47)forward()
     45                                                                             torch.any(data_attention_mask, dim=1),
     46                                                                             meta_kwargs)
---> 47                 data_fused_repr = self.dr_fused(data_fused_repr)
     48 
     49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(49)forward()
     47                 data_fused_repr = self.dr_fused(data_fused_repr)
     48 
---> 49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
     51         return EncoderOutput(



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(51)forward()
     49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
---> 51         return EncoderOutput(
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,



ipdb>  data_logits.shape


torch.Size([5, 5, 30522])


ipdb>  data_gen_idx
ipdb>  data_gen_idx is None


True


ipdb>  n


> /tmp/ipykernel_3071/1501785349.py(52)forward()
     50 
     51         return EncoderOutput(
---> 52             rep=data_repr,
     53             fused_rep=data_fused_repr,
     54             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(53)forward()
     51         return EncoderOutput(
     52             rep=data_repr,
---> 53             fused_rep=data_fused_repr,
     54             meta_repr=meta_repr,
     55 



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(54)forward()
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,
---> 54             meta_repr=meta_repr,
     55 
     56             logits=data_logits,



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(56)forward()
     54             meta_repr=meta_repr,
     55 
---> 56             logits=data_logits,
     57         )
     58 



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(51)forward()
     49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
---> 51         return EncoderOutput(
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...xBackward0>)})
> /tmp/ipykernel_3071/1501785349.py(51)forward()
     49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
---> 51         return EncoderOutput(
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,



ipdb>  n


--Return--
EncoderOutput...xBackward0>)})
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1511)_wrapped_call_impl()
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(69)forward()
     67 
     68 
---> 69         loss = None; lbl2data_o = EncoderOutput()
     70         if lbl2data_input_ids is not None:
     71             lbl2data_gen_idx = self.get_last_item_mask(lbl2data_data2ptr, len(lbl2data_idx))



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(70)forward()
     68 
     69         loss = None; lbl2data_o = EncoderOutput()
---> 70         if lbl2data_input_ids is not None:
     71             lbl2data_gen_idx = self.get_last_item_mask(lbl2data_data2ptr, len(lbl2data_idx))
     72             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)



ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(71)forward()
     69         loss = None; lbl2data_o = EncoderOutput()
     70         if lbl2data_input_ids is not None:
---> 71             lbl2data_gen_idx = self.get_last_item_mask(lbl2data_data2ptr, len(lbl2data_idx))
     72             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     73 



ipdb>  lbl2data_data2ptr


tensor([1, 1, 1, 1, 1])


ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(72)forward()
     70         if lbl2data_input_ids is not None:
     71             lbl2data_gen_idx = self.get_last_item_mask(lbl2data_data2ptr, len(lbl2data_idx))
---> 72             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     73 
     74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 



ipdb>  lbl2data_gen_idx


tensor([True, True, True, True, True])


ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(74)forward()
     72             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     73 
---> 74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, data_gen_idx=lbl2data_gen_idx,
     76                                  **lbl2data_meta_kwargs)



ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(75)forward()
     73 
     74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
---> 75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, data_gen_idx=lbl2data_gen_idx,
     76                                  **lbl2data_meta_kwargs)
     77 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(74)forward()
     72             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     73 
---> 74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, data_gen_idx=lbl2data_gen_idx,
     76                                  **lbl2data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(76)forward()
     74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, data_gen_idx=lbl2data_gen_idx,
---> 76                                  **lbl2data_meta_kwargs)
     77 
     78             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(74)forward()
     72             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     73 
---> 74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, data_gen_idx=lbl2data_gen_idx,
     76                                  **lbl2data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(33)forward()
     31         **kwargs
     32     ):  
---> 33         data_o = self.encode(data_input_ids, data_attention_mask)
     34 
     35         if data_type is not None and data_type == "meta":



ipdb>  n


> /tmp/ipykernel_3071/1501785349.py(35)forward()
     33         data_o = self.encode(data_input_ids, data_attention_mask)
     34 
---> 35         if data_type is not None and data_type == "meta":
     36             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
     37         else:



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(38)forward()
     36             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
     37         else:
---> 38             data_repr = self.dr(data_o[0], data_attention_mask)
     39 
     40         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(40)forward()
     38             data_repr = self.dr(data_o[0], data_attention_mask)
     39 
---> 40         data_fused_repr = meta_repr = None
     41         if data_aug_meta_prefix is not None:
     42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(41)forward()
     39 
     40         data_fused_repr = meta_repr = None
---> 41         if data_aug_meta_prefix is not None:
     42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     43             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(49)forward()
     47                 data_fused_repr = self.dr_fused(data_fused_repr)
     48 
---> 49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
     51         return EncoderOutput(



ipdb>  data_gen_idx


tensor([True, True, True, True, True])


ipdb>  n


> /tmp/ipykernel_3071/1501785349.py(51)forward()
     49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
---> 51         return EncoderOutput(
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,



ipdb>  data_logits.shape


torch.Size([5, 11, 30522])


ipdb>  n


> /tmp/ipykernel_3071/1501785349.py(52)forward()
     50 
     51         return EncoderOutput(
---> 52             rep=data_repr,
     53             fused_rep=data_fused_repr,
     54             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(53)forward()
     51         return EncoderOutput(
     52             rep=data_repr,
---> 53             fused_rep=data_fused_repr,
     54             meta_repr=meta_repr,
     55 



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(54)forward()
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,
---> 54             meta_repr=meta_repr,
     55 
     56             logits=data_logits,



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(56)forward()
     54             meta_repr=meta_repr,
     55 
---> 56             logits=data_logits,
     57         )
     58 



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(51)forward()
     49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
---> 51         return EncoderOutput(
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...eta_repr=None)
> /tmp/ipykernel_3071/1501785349.py(51)forward()
     49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
---> 51         return EncoderOutput(
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(78)forward()
     76                                  **lbl2data_meta_kwargs)
     77 
---> 78             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     79                                      plbl2data_data2ptr,plbl2data_idx)
     80 



ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(79)forward()
     77 
     78             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
---> 79                                      plbl2data_data2ptr,plbl2data_idx)
     80 
     81             if self.use_gen_loss:



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(78)forward()
     76                                  **lbl2data_meta_kwargs)
     77 
---> 78             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     79                                      plbl2data_data2ptr,plbl2data_idx)
     80 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(81)forward()
     79                                      plbl2data_data2ptr,plbl2data_idx)
     80 
---> 81             if self.use_gen_loss:
     82                 loss += self.compute_gen_loss(data_o.logits, lbl2data_o.logits, data_input_ids,lbl2data_input_ids,
     83                                               lbl2data_data2ptr)



ipdb>  self.use_gen_loss


True


ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(82)forward()
     80 
     81             if self.use_gen_loss:
---> 82                 loss += self.compute_gen_loss(data_o.logits, lbl2data_o.logits, data_input_ids,lbl2data_input_ids,
     83                                               lbl2data_data2ptr)
     84 



ipdb>  data_input_ids.shape


torch.Size([5, 5])


ipdb>  lbl2data_input_ids.shape


torch.Size([5, 11])


ipdb>  lbl2data_data2ptr.shape


torch.Size([5])


ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(83)forward()
     81             if self.use_gen_loss:
     82                 loss += self.compute_gen_loss(data_o.logits, lbl2data_o.logits, data_input_ids,lbl2data_input_ids,
---> 83                                               lbl2data_data2ptr)
     84 
     85             if self.use_query_loss:



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(82)forward()
     80 
     81             if self.use_gen_loss:
---> 82                 loss += self.compute_gen_loss(data_o.logits, lbl2data_o.logits, data_input_ids,lbl2data_input_ids,
     83                                               lbl2data_data2ptr)
     84 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(85)forward()
     83                                               lbl2data_data2ptr)
     84 
---> 85             if self.use_query_loss:
     86                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     87                                           plbl2data_data2ptr,plbl2data_idx)



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(86)forward()
     84 
     85             if self.use_query_loss:
---> 86                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     87                                           plbl2data_data2ptr,plbl2data_idx)
     88 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(87)forward()
     85             if self.use_query_loss:
     86                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
---> 87                                           plbl2data_data2ptr,plbl2data_idx)
     88 
     89             if self.use_calib_loss:



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(86)forward()
     84 
     85             if self.use_query_loss:
---> 86                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     87                                           plbl2data_data2ptr,plbl2data_idx)
     88 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(89)forward()
     87                                           plbl2data_data2ptr,plbl2data_idx)
     88 
---> 89             if self.use_calib_loss:
     90                 loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     91                                               plbl2data_data2ptr,plbl2data_idx)



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(93)forward()
     91                                               plbl2data_data2ptr,plbl2data_idx)
     92 
---> 93             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
     94 
     95             if self.use_fusion_loss:



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(95)forward()
     93             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
     94 
---> 95             if self.use_fusion_loss:
     96                 loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
     97                 loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(100)forward()
     98 
     99 
--> 100         if not return_dict:
    101             o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
    102             return ((loss,) + o) if loss is not None else o



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(105)forward()
    103 
    104 
--> 105         return XCModelOutput(
    106             loss=loss,
    107 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(106)forward()
    104 
    105         return XCModelOutput(
--> 106             loss=loss,
    107 
    108             data_repr=data_o.rep,



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(108)forward()
    106             loss=loss,
    107 
--> 108             data_repr=data_o.rep,
    109             data_fused_repr=data_o.fused_rep,
    110 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(109)forward()
    107 
    108             data_repr=data_o.rep,
--> 109             data_fused_repr=data_o.fused_rep,
    110 
    111             logits=data_o.logits,



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(111)forward()
    109             data_fused_repr=data_o.fused_rep,
    110 
--> 111             logits=data_o.logits,
    112 
    113             lbl2data_repr=lbl2data_o.rep,



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(113)forward()
    111             logits=data_o.logits,
    112 
--> 113             lbl2data_repr=lbl2data_o.rep,
    114             lbl2data_fused_repr=lbl2data_o.fused_rep,
    115         )



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(114)forward()
    112 
    113             lbl2data_repr=lbl2data_o.rep,
--> 114             lbl2data_fused_repr=lbl2data_o.fused_rep,
    115         )
    116 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(105)forward()
    103 
    104 
--> 105         return XCModelOutput(
    106             loss=loss,
    107 



ipdb>  


--Return--
XCModelOutput...sed_repr=None)
> /tmp/ipykernel_3071/2330200620.py(105)forward()
    103 
    104 
--> 105         return XCModelOutput(
    106             loss=loss,
    107 



ipdb>  


--Return--
XCModelOutput...sed_repr=None)
> /tmp/ipykernel_3071/3657616883.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))
      4 



ipdb>  


--Return--
None
> /tmp/ipykernel_3071/492731717.py(1)<module>()
----> 1 o = func()



ipdb>  c


    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



In [None]:
o.loss

tensor(21.9783, grad_fn=<AddBackward0>)

## `OAK006`

In [None]:
#| export
class Encoder006(DistilBertPreTrainedModel):
    
    def __init__(
        self, 
        config:PretrainedConfig, 
        num_metadata:int,
        resize_length:Optional[int]=None,
    ):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        
        self.dr_head = RepresentationHead(config)
        self.dr_fused_head = RepresentationHead(config)
        self.meta_head = RepresentationHead(config)
        self.cross_head = CrossAttention(config)
        self.meta_embeddings = nn.Embedding(num_metadata, config.dim, sparse=True)
        self.pretrained_meta_embeddings = nn.Embedding(num_metadata, config.dim)

        self.ones = torch.ones(resize_length, dtype=torch.long, device=self.device) if resize_length is not None else None
        self.post_init()

    def freeze_meta_embeddings(self):
        self.meta_embeddings.requires_grad_(False)

    def unfreeze_meta_embeddings(self):
        self.meta_embeddings.requires_grad_(True)

    def set_meta_embeddings(self, embed:torch.Tensor):
        self.meta_embeddings.weight.data = embed

    def init_meta_embeddings(self):
        self.meta_embeddings.weight.data = torch.zeros_like(self.meta_embeddings.weight.data)

    def freeze_pretrained_meta_embeddings(self):
        self.pretrained_meta_embeddings.requires_grad_(False)

    def unfreeze_pretrained_meta_embeddings(self):
        self.pretrained_meta_embeddings.requires_grad_(True)

    def set_pretrained_meta_embeddings(self, embed:torch.Tensor):
        self.pretrained_meta_embeddings.weight.data = embed
        
    def get_position_embeddings(self) -> nn.Embedding:
        return self.distilbert.get_position_embeddings()
    
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)
    
    def encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        return self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
    
    def dr(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.dr_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def dr_fused(self, embed:torch.Tensor):
        embed = self.dr_fused_head(embed)
        return F.normalize(embed, dim=1)

    def meta(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.meta_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)
    
    def meta_unnormalized(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.meta_head(embed)
        return Pooling.mean_pooling(embed, attention_mask)

    def resize(self, idx:torch.Tensor, num_inputs:torch.Tensor):
        if torch.any(num_inputs == 0): raise ValueError("`num_inputs` should be non-zero positive integer.")
        bsz, total_num_inputs = num_inputs.shape[0], idx.shape[0]
        
        self.ones = self.ones.to(idx.device)
        ones = (
            torch.ones(total_num_inputs, dtype=torch.long, device=idx.device) 
            if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
        )

        max_num_inputs = num_inputs.max()
        if (num_inputs == max_num_inputs).all():
            return idx,ones
        
        xnum_inputs = max_num_inputs-num_inputs+1

        inputs_ptr = num_inputs.cumsum(dim=0)-1
        repeat_inputs = ones.scatter(0, inputs_ptr, xnum_inputs)
        
        resized_idx = idx.repeat_interleave(repeat_inputs, dim=0)
        ignore_mask = ones.scatter(0, inputs_ptr, 0).repeat_interleave(repeat_inputs, dim=0).view(bsz, -1)
        ignore_mask[:, -1] = 1; ignore_mask = ignore_mask.flatten()
        
        return resized_idx,ignore_mask

    def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
        meta_repr = {}
        
        data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
        for m_key, m_args in meta_kwargs.items():
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
            
            if len(idx):
                m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
                m_repr = F.normalize(self.meta_embeddings(m_idx) + self.pretrained_meta_embeddings(m_idx), dim=1)
                
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
                meta_repr[m_key] = m_repr[m_repr_mask]
                
                fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
                data_fused_repr[idx] += fused_repr
                
        return data_fused_repr.squeeze(), meta_repr

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):  
        data_o = self.encode(data_input_ids, data_attention_mask)
        
        if data_type is not None and data_type == "meta":
            data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
        else: 
            data_repr = self.dr(data_o[0], data_attention_mask)
        
        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
                                                                            torch.any(data_attention_mask, dim=1), 
                                                                            meta_kwargs)
                data_fused_repr = self.dr_fused(data_fused_repr)
                
        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,
        )
        

In [None]:
#| export
class OAK006(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(
        self, 
        config,
        num_metadata:int,
        resize_length:Optional[int]=None,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.encoder = Encoder006(config, num_metadata=num_metadata, resize_length=resize_length)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def init_meta_embeddings(self):
        self.encoder.init_meta_embeddings()

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        

### Example

In [None]:
model = OAK006.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)

model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()

Some weights of OAK006 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [None]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [None]:
o = model(**b.to(model.device))

In [None]:
o.loss

tensor(0.0160, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK007`

In [20]:
#| export
class OAK007(OAK003, DistilBertPreTrainedModel):
    
    @delegates(OAK003.__init__)
    def __init__(
        self, 
        config,
        n_labels:int,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.label_embeddings = nn.Embedding(n_labels, config.dim)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def init_label_embeddings(self):
        self.label_embeddings.weight.data = torch.zeros_like(self.label_embeddings.weight.data)

    def get_label_representation(
        self,
        data_idx:Optional[torch.Tensor]=None,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask)
        data_o.rep = F.normalize(data_o.rep + self.label_embeddings(data_idx), dim=1)
        return XCModelOutput(
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
        )

    def forward(
        self,
        data_idx:Optional[torch.Tensor]=None,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)

        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(lbl2data_idx), dim=1)
            
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
            
            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
            
            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
            
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        

### Example

In [22]:
model = OAK007.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, resize_length=5000, n_labels=block.n_lbl,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)

model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()
model.init_label_embeddings()

Some weights of OAK007 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [None]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

In [None]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [30]:
o = model(**b.to(model.device))

In [31]:
o.loss

tensor(0.0609, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK008`

In [17]:
#| export
class OAK008(OAK003, DistilBertPreTrainedModel):
    
    @delegates(OAK003.__init__)
    def __init__(
        self, 
        config,
        n_labels:int,
        n_clusters:int,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.label_embeddings = nn.Embedding(n_clusters, config.dim)
        self.register_buffer("label_remap", torch.arange(n_labels)%n_clusters, persistent=True)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def init_label_embeddings(self):
        self.label_embeddings.weight.data = torch.zeros_like(self.label_embeddings.weight.data)

    def set_label_embeddings(self, embed:torch.Tensor):
        self.label_embeddings.weight.data = embed

    def set_label_remap(self, label_remap:torch.Tensor):
        if label_remap.shape[0] != self.label_remap.shape[0]:
            raise ValueError(f'Shape mismatch, `label_remap` should have {self.label_remap.shape[0]} elements.')
        self.label_remap = label_remap

    def get_label_representation(
        self,
        data_idx:Optional[torch.Tensor]=None,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask)
        data_o.rep = F.normalize(data_o.rep + self.label_embeddings(self.label_remap[data_idx]), dim=1)
        return XCModelOutput(
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
        )

    def forward(
        self,
        data_idx:Optional[torch.Tensor]=None,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)

        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)
            
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
            
            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
            
            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
            
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        

### Example

In [20]:
model = OAK008.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, resize_length=5000, 
                               n_labels=block.n_lbl, n_clusters=block.n_lbl//3,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)

model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()
model.init_label_embeddings()

Some weights of OAK008 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [21]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

In [22]:
model = model.to('cuda')

In [23]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [24]:
o = model(**b.to(model.device))

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [25]:
o.loss

tensor(0.0260, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK009`

In [17]:
#| export
class OAK009(OAK008, DistilBertPreTrainedModel):
    
    @delegates(OAK008.__init__)
    def __init__(
        self, 
        config,
        embed_dim:int,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.transform = nn.Linear(config.dim, embed_dim)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def init_transform(self):
        self.transform.weight.data = torch.eye(self.transform.out_features, self.transform.in_features, 
                                               dtype=self.transform.weight.dtype)

    def get_label_representation(
        self,
        data_idx:Optional[torch.Tensor]=None,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask)
        data_o.rep = F.normalize(data_o.rep + self.label_embeddings(self.label_remap[data_idx]), dim=1)

        if data_o.rep is not None: data_o.rep = F.normalize(self.transform(data_o.rep), dim=1)
        if data_o.fused_rep is not None: data_o.fused_rep = F.normalize(self.transform(data_o.fused_rep), dim=1)
        
        return XCModelOutput(
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
        )

    def forward(
        self,
        data_idx:Optional[torch.Tensor]=None,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)

        if data_o.rep is not None: data_o.rep = F.normalize(self.transform(data_o.rep), dim=1)
        if data_o.fused_rep is not None: data_o.fused_rep = F.normalize(self.transform(data_o.fused_rep), dim=1)

        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)

            if lbl2data_o.rep is not None: lbl2data_o.rep = F.normalize(self.transform(lbl2data_o.rep), dim=1)
            if lbl2data_o.fused_rep is not None: lbl2data_o.fused_rep = F.normalize(self.transform(lbl2data_o.fused_rep), dim=1)
            
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
            
            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
            
            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
            
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        

### Examples

In [18]:
model = OAK009.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, resize_length=5000, 
                               n_labels=block.n_lbl, n_clusters=block.n_lbl//3, embed_dim=4096,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)

model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()
model.init_label_embeddings()
model.init_transform()

Some weights of OAK009 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [19]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

In [20]:
model = model.to('cuda')

In [21]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [22]:
def func():
    import pdb; pdb.set_trace()
    o = model(**b.to(model.device))
    

In [129]:
func()

> /tmp/ipykernel_18914/3721260802.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     o = model(**b.to(model.device))
      4 



ipdb>  b model.forward


Breakpoint 2 at /tmp/ipykernel_18914/134643332.py:19


ipdb>  c


> /tmp/ipykernel_18914/134643332.py(35)forward()
     33         **kwargs
     34     ):  
---> 35         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     36 
     37         if self.use_encoder_parallel:



ipdb>  n


> /tmp/ipykernel_18914/134643332.py(37)forward()
     35         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     36 
---> 37         if self.use_encoder_parallel:
     38             encoder = XCDataParallel(module=self.encoder)
     39         else: encoder = self.encoder



ipdb>  


> /tmp/ipykernel_18914/134643332.py(39)forward()
     37         if self.use_encoder_parallel:
     38             encoder = XCDataParallel(module=self.encoder)
---> 39         else: encoder = self.encoder
     40 
     41         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_18914/134643332.py(41)forward()
     39         else: encoder = self.encoder
     40 
---> 41         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     42         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     43                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_18914/134643332.py(42)forward()
     40 
     41         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 42         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     43                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     44 



ipdb>  


> /tmp/ipykernel_18914/134643332.py(43)forward()
     41         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     42         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 43                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     44 
     45         if data_o.rep is not None:



ipdb>  


> /tmp/ipykernel_18914/134643332.py(42)forward()
     40 
     41         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 42         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     43                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     44 



ipdb>  


> /tmp/ipykernel_18914/134643332.py(43)forward()
     41         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     42         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 43                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     44 
     45         if data_o.rep is not None:



ipdb>  


> /tmp/ipykernel_18914/134643332.py(42)forward()
     40 
     41         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 42         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     43                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     44 



ipdb>  


> /tmp/ipykernel_18914/134643332.py(45)forward()
     43                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     44 
---> 45         if data_o.rep is not None:
     46             data_o.rep = self.transform(data_o.rep)
     47         if data_o.fused_rep is not None:



ipdb>  data_o.rep.shape


torch.Size([5, 768])


ipdb>  data_o.fused_rep.shape


torch.Size([5, 768])


ipdb>  n


> /tmp/ipykernel_18914/134643332.py(46)forward()
     44 
     45         if data_o.rep is not None:
---> 46             data_o.rep = self.transform(data_o.rep)
     47         if data_o.fused_rep is not None:
     48             data_o.fused_rep = self.transform(data_o.fused_rep)



ipdb>  n


> /tmp/ipykernel_18914/134643332.py(47)forward()
     45         if data_o.rep is not None:
     46             data_o.rep = self.transform(data_o.rep)
---> 47         if data_o.fused_rep is not None:
     48             data_o.fused_rep = self.transform(data_o.fused_rep)
     49 



ipdb>  


> /tmp/ipykernel_18914/134643332.py(48)forward()
     46             data_o.rep = self.transform(data_o.rep)
     47         if data_o.fused_rep is not None:
---> 48             data_o.fused_rep = self.transform(data_o.fused_rep)
     49 
     50         loss = None; lbl2data_o = EncoderOutput()



ipdb>  n


> /tmp/ipykernel_18914/134643332.py(50)forward()
     48             data_o.fused_rep = self.transform(data_o.fused_rep)
     49 
---> 50         loss = None; lbl2data_o = EncoderOutput()
     51         if lbl2data_input_ids is not None:
     52             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_18914/134643332.py(51)forward()
     49 
     50         loss = None; lbl2data_o = EncoderOutput()
---> 51         if lbl2data_input_ids is not None:
     52             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     53             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 



ipdb>  


> /tmp/ipykernel_18914/134643332.py(52)forward()
     50         loss = None; lbl2data_o = EncoderOutput()
     51         if lbl2data_input_ids is not None:
---> 52             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     53             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     54                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_18914/134643332.py(53)forward()
     51         if lbl2data_input_ids is not None:
     52             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
---> 53             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     54                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
     55             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)



ipdb>  


> /tmp/ipykernel_18914/134643332.py(54)forward()
     52             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     53             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
---> 54                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
     55             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)
     56 



ipdb>  


> /tmp/ipykernel_18914/134643332.py(53)forward()
     51         if lbl2data_input_ids is not None:
     52             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
---> 53             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     54                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
     55             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)



ipdb>  


> /tmp/ipykernel_18914/134643332.py(54)forward()
     52             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     53             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
---> 54                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
     55             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)
     56 



ipdb>  


> /tmp/ipykernel_18914/134643332.py(53)forward()
     51         if lbl2data_input_ids is not None:
     52             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
---> 53             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     54                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
     55             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)



ipdb>  


> /tmp/ipykernel_18914/134643332.py(55)forward()
     53             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     54                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
---> 55             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)
     56 
     57             if lbl2data_o.rep is not None:



ipdb>  


> /tmp/ipykernel_18914/134643332.py(57)forward()
     55             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)
     56 
---> 57             if lbl2data_o.rep is not None:
     58                 lbl2data_o.rep = self.transform(lbl2data_o.rep)
     59             if lbl2data_o.fused_rep is not None:



ipdb>  


> /tmp/ipykernel_18914/134643332.py(58)forward()
     56 
     57             if lbl2data_o.rep is not None:
---> 58                 lbl2data_o.rep = self.transform(lbl2data_o.rep)
     59             if lbl2data_o.fused_rep is not None:
     60                 lbl2data_o.fused_rep = self.transform(lbl2data_o.fused_rep)



ipdb>  


> /tmp/ipykernel_18914/134643332.py(59)forward()
     57             if lbl2data_o.rep is not None:
     58                 lbl2data_o.rep = self.transform(lbl2data_o.rep)
---> 59             if lbl2data_o.fused_rep is not None:
     60                 lbl2data_o.fused_rep = self.transform(lbl2data_o.fused_rep)
     61 



ipdb>  


> /tmp/ipykernel_18914/134643332.py(62)forward()
     60                 lbl2data_o.fused_rep = self.transform(lbl2data_o.fused_rep)
     61 
---> 62             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     63                                      plbl2data_data2ptr,plbl2data_idx)
     64 



ipdb>  


> /tmp/ipykernel_18914/134643332.py(63)forward()
     61 
     62             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
---> 63                                      plbl2data_data2ptr,plbl2data_idx)
     64 
     65             if self.use_query_loss:



ipdb>  


> /tmp/ipykernel_18914/134643332.py(62)forward()
     60                 lbl2data_o.fused_rep = self.transform(lbl2data_o.fused_rep)
     61 
---> 62             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     63                                      plbl2data_data2ptr,plbl2data_idx)
     64 



ipdb>  


> /tmp/ipykernel_18914/134643332.py(65)forward()
     63                                      plbl2data_data2ptr,plbl2data_idx)
     64 
---> 65             if self.use_query_loss:
     66                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     67                                           plbl2data_data2ptr,plbl2data_idx)



ipdb>  


> /tmp/ipykernel_18914/134643332.py(66)forward()
     64 
     65             if self.use_query_loss:
---> 66                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     67                                           plbl2data_data2ptr,plbl2data_idx)
     68 



ipdb>  


> /tmp/ipykernel_18914/134643332.py(67)forward()
     65             if self.use_query_loss:
     66                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
---> 67                                           plbl2data_data2ptr,plbl2data_idx)
     68 
     69             if self.use_calib_loss:



ipdb>  


> /tmp/ipykernel_18914/134643332.py(66)forward()
     64 
     65             if self.use_query_loss:
---> 66                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     67                                           plbl2data_data2ptr,plbl2data_idx)
     68 



ipdb>  


> /tmp/ipykernel_18914/134643332.py(69)forward()
     67                                           plbl2data_data2ptr,plbl2data_idx)
     68 
---> 69             if self.use_calib_loss:
     70                 loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     71                                               plbl2data_data2ptr,plbl2data_idx)



ipdb>  


> /tmp/ipykernel_18914/134643332.py(73)forward()
     71                                               plbl2data_data2ptr,plbl2data_idx)
     72 
---> 73             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
     74 
     75             if self.use_fusion_loss:



ipdb>  


> /tmp/ipykernel_18914/134643332.py(75)forward()
     73             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
     74 
---> 75             if self.use_fusion_loss:
     76                 loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
     77                 loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_18914/134643332.py(80)forward()
     78 
     79 
---> 80         if not return_dict:
     81             o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
     82             return ((loss,) + o) if loss is not None else o



ipdb>  


> /tmp/ipykernel_18914/134643332.py(84)forward()
     82             return ((loss,) + o) if loss is not None else o
     83 
---> 84         return XCModelOutput(
     85             loss=loss,
     86 



ipdb>  


> /tmp/ipykernel_18914/134643332.py(85)forward()
     83 
     84         return XCModelOutput(
---> 85             loss=loss,
     86 
     87             data_repr=data_o.rep,



ipdb>  


> /tmp/ipykernel_18914/134643332.py(87)forward()
     85             loss=loss,
     86 
---> 87             data_repr=data_o.rep,
     88             data_fused_repr=data_o.fused_rep,
     89 



ipdb>  


> /tmp/ipykernel_18914/134643332.py(88)forward()
     86 
     87             data_repr=data_o.rep,
---> 88             data_fused_repr=data_o.fused_rep,
     89 
     90             lbl2data_repr=lbl2data_o.rep,



ipdb>  


> /tmp/ipykernel_18914/134643332.py(90)forward()
     88             data_fused_repr=data_o.fused_rep,
     89 
---> 90             lbl2data_repr=lbl2data_o.rep,
     91             lbl2data_fused_repr=lbl2data_o.fused_rep,
     92         )



ipdb>  


> /tmp/ipykernel_18914/134643332.py(91)forward()
     89 
     90             lbl2data_repr=lbl2data_o.rep,
---> 91             lbl2data_fused_repr=lbl2data_o.fused_rep,
     92         )
     93 



ipdb>  


> /tmp/ipykernel_18914/134643332.py(84)forward()
     82             return ((loss,) + o) if loss is not None else o
     83 
---> 84         return XCModelOutput(
     85             loss=loss,
     86 



ipdb>  


--Return--
XCModelOutput...sed_repr=None)
> /tmp/ipykernel_18914/134643332.py(84)forward()
     82             return ((loss,) + o) if loss is not None else o
     83 
---> 84         return XCModelOutput(
     85             loss=loss,
     86 



ipdb>  


--Return--
None
> /tmp/ipykernel_18914/3721260802.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     o = model(**b.to(model.device))
      4 



ipdb>  


--Call--
> /home/scai/phd/aiz218323/.local/lib/python3.9/site-packages/IPython/core/displayhook.py(258)__call__()
    256         sys.stdout.flush()
    257 
--> 258     def __call__(self, result=None):
    259         """Printing with history cache management.
    260 



ipdb>  c


    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



In [23]:
o = model(**b.to(model.device))

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [25]:
o.loss

tensor(0.0260, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK010`

In [30]:
#| export
class Encoder010(Encoder003):

    def __init__(
        self, 
        config,
        n_clusters:int,
        n_metadata:int,
        **kwargs
    ):
        super().__init__(config, num_metadata=n_clusters, **kwargs)
        self.pretrained_meta_embeddings = nn.Embedding(n_metadata, config.dim)
        self.register_buffer("metadata_remap", torch.arange(n_metadata)%n_clusters, persistent=True)
        self.post_init()

    def set_metadata_remap(self, metadata_remap:torch.Tensor):
        if metadata_remap.shape[0] != self.metadata_remap.shape[0]:
            raise ValueError(f'Shape mismatch, `metadata_remap` should have {self.metadata_remap.shape[0]} elements.')
        self.metadata_remap = metadata_remap

    def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
        meta_repr = {}
        
        data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
        for m_key, m_args in meta_kwargs.items():
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
            
            if len(idx):
                m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
                m_repr = F.normalize(self.meta_embeddings(self.metadata_remap[m_idx]) + self.pretrained_meta_embeddings(m_idx), dim=1)
                
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
                meta_repr[m_key] = m_repr[m_repr_mask]
                
                fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
                data_fused_repr[idx] += fused_repr
                
        return data_fused_repr.squeeze(), meta_repr
    

In [21]:
#| export
class OAK010(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(
        self, 
        config,
        n_clusters:int,
        n_metadata:int,
        resize_length:Optional[int]=None,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.encoder = Encoder010(config, n_clusters=n_clusters, n_metadata=n_metadata, resize_length=resize_length)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def init_meta_embeddings(self):
        self.encoder.init_meta_embeddings()

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        

### Example

In [32]:
model = OAK010.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               n_metadata=block.train.dset.meta['cat_meta'].n_meta, n_clusters=1000, resize_length=5000,
                               
                               calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.0, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)

model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()

Some weights of OAK010 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [33]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

remap = torch.arange(block.train.dset.meta['cat_meta'].n_meta)%1000
model.encoder.set_metadata_remap(remap)

In [34]:
model = model.to('cuda')

In [None]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [35]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [36]:
o = model(**b.to(model.device))

In [37]:
o.loss

tensor(0.0260, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK011`

In [26]:
#| export
class Encoder011(Encoder003):
    
    def __init__(
        self, 
        config, 
        n_clusters:int,
        n_metadata:int,
        **kwargs
    ):
        super().__init__(config, num_metadata=n_clusters, **kwargs)
        self.pretrained_meta_embeddings = nn.Embedding(n_clusters, config.dim)
        self.register_buffer("metadata_remap", torch.arange(n_metadata)%n_clusters, persistent=True)
        self.post_init()

    def set_metadata_remap(self, metadata_remap:torch.Tensor):
        if metadata_remap.shape[0] != self.metadata_remap.shape[0]:
            raise ValueError(f'Shape mismatch, `metadata_remap` should have {self.metadata_remap.shape[0]} elements.')
        self.metadata_remap = metadata_remap

    def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
        meta_repr = {}
        
        data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
        for m_key, m_args in meta_kwargs.items():
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
            
            if len(idx):
                m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
                m_repr = F.normalize(self.meta_embeddings(self.metadata_remap[m_idx]) + self.pretrained_meta_embeddings(self.metadata_remap[m_idx]), dim=1)
                
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
                meta_repr[m_key] = m_repr[m_repr_mask]
                
                fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
                data_fused_repr[idx] += fused_repr
                
        return data_fused_repr.squeeze(), meta_repr

    

In [27]:
#| export
class OAK011(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(
        self, 
        config,
        n_clusters:int,
        n_metadata:int,
        resize_length:Optional[int]=None,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.encoder = Encoder011(config, n_clusters=n_clusters, n_metadata=n_metadata, resize_length=resize_length)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def init_meta_embeddings(self):
        self.encoder.init_meta_embeddings()

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert

### Example

In [28]:
model = OAK011.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               n_metadata=block.train.dset.meta['cat_meta'].n_meta, n_clusters=1000, resize_length=5000,
                               
                               calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.0, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)

model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()

Some weights of OAK011 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [29]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

remap = torch.arange(block.train.dset.meta['cat_meta'].n_meta)%1000
model.encoder.set_metadata_remap(remap)

In [30]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [31]:
o = model(**b.to(model.device))

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [32]:
o.loss

tensor(0.0252, grad_fn=<AddBackward0>)

## `OAK012`

In [37]:
#| export
class Encoder012(Encoder):

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.meta_transform = RepresentationHead(config)
        self.post_init()

    def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
        meta_repr = {}
        
        data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
        for m_key, m_args in meta_kwargs.items():
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
            
            if len(idx):
                m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
                m_repr = F.normalize(self.meta_transform(self.meta_embeddings(m_idx)), dim=1)
                
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
                meta_repr[m_key] = m_repr[m_repr_mask]
                
                fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
                data_fused_repr[idx] += fused_repr
                
        return data_fused_repr.squeeze(), meta_repr
    

In [38]:
#| export
class OAK012(OAK000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(
        self, 
        config,
        num_metadata:int,
        resize_length:Optional[int]=None,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.encoder = Encoder012(config, num_metadata=num_metadata, resize_length=resize_length)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()
        
    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        

### Example

In [39]:
model = OAK012.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)

model.init_retrieval_head()
model.init_cross_head()

Some weights of OAK012 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [41]:
model.encoder.meta_transform.post_init()

In [43]:
model.encoder.set_meta_embeddings(torch.zeros(656086, 768))

In [44]:
model = model.to('cuda')

In [45]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [46]:
b.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 'cat2data_data2ptr', 'data_input_ids', 'data_attention_mask'])

In [47]:
o = model(**b.to(model.device))

In [48]:
o.loss

tensor(0.0260, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK013`

In [55]:
#| export
class Encoder013(Encoder003):

    def __init__(
        self, 
        config,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        self.meta_distilbert = DistilBertModel(config)
        self.post_init()

    def meta_encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        o = self.meta_distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        return Pooling.mean_pooling(o[0], attention_mask)

    def init_meta_encoder(self):
        sd_meta, sd_dr = self.meta_distilbert.state_dict(), self.distilbert.state_dict()
        for k in sd_dr:
            assert sd_meta[k].shape == sd_dr[k].shape
            with torch.no_grad():
                sd_meta[k].copy_(sd_dr[k])

    def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
        meta_repr = {}
        
        data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
        for m_key, m_args in meta_kwargs.items():
            n_meta, bsz = m_args['data2ptr'].max(), len(m_args['data2ptr'])
            assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
            
            m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
            m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
            
            m_idx = m_args['idx']
            m_repr = F.normalize(self.meta_embeddings(m_idx) + m_embed + self.pretrained_meta_embeddings(m_idx), dim=1)
            
            m_repr, m_repr_mask = m_repr.view(bsz, -1, self.config.dim), torch.ones((m_repr.shape[0],), device=m_repr.device, dtype=torch.bool).view(bsz, -1)
            meta_repr[m_key] = m_repr[m_repr_mask]
            
            fused_repr = self.cross_head(data_fused_repr, data_mask, m_repr, m_repr_mask)[0]
            data_fused_repr += fused_repr
                
        return data_fused_repr.squeeze(), meta_repr
    

In [56]:
#| export
class OAK013(OAK008, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]
    _keys_to_ignore_on_load_missing = ["encoder.meta_distilbert"]

    @delegates(OAK008.__init__)
    def __init__(
        self, 
        config,
        num_metadata:int,
        resize_length:Optional[int]=None,
        **kwargs
    ):
        super().__init__(config, num_metadata=num_metadata, resize_length=resize_length, **kwargs)
        self.encoder = Encoder013(config, num_metadata=num_metadata, resize_length=resize_length)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert

    def init_meta_encoder(self):
        self.encoder.init_meta_encoder()
        

### Example

In [57]:
model = OAK013.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000, 
                               n_labels=block.n_lbl, n_clusters=block.n_lbl//3,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)

model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()
model.init_label_embeddings()

Some weights of OAK013 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [58]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

model.init_meta_encoder()

In [23]:
model = model.to('cuda')

In [24]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [25]:
o = model(**b.to(model.device))

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [99]:
o.loss

tensor(0.0149, device='cuda:0', grad_fn=<AddBackward0>)

In [100]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))

In [101]:
o = func()

> /tmp/ipykernel_26512/1795951242.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))



ipdb>  b model.forward


Breakpoint 1 at /tmp/ipykernel_26512/2049060262.py:45


ipdb>  b model.encoder.forward


Breakpoint 2 at /tmp/ipykernel_26512/4122862164.py:106


ipdb>  c


> /tmp/ipykernel_26512/2049060262.py(61)forward()
     59         **kwargs
     60     ):  
---> 61         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     62 
     63         if self.use_encoder_parallel:



ipdb>  n


> /tmp/ipykernel_26512/2049060262.py(63)forward()
     61         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     62 
---> 63         if self.use_encoder_parallel:
     64             encoder = XCDataParallel(module=self.encoder)
     65         else: encoder = self.encoder



ipdb>  n


> /tmp/ipykernel_26512/2049060262.py(65)forward()
     63         if self.use_encoder_parallel:
     64             encoder = XCDataParallel(module=self.encoder)
---> 65         else: encoder = self.encoder
     66 
     67         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)



ipdb>  n


> /tmp/ipykernel_26512/2049060262.py(67)forward()
     65         else: encoder = self.encoder
     66 
---> 67         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_26512/2049060262.py(68)forward()
     66 
     67         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     70 



ipdb>  data_meta_kwargs.keys()


dict_keys(['lnk2data_attention_mask', 'lnk2data_input_ids', 'lnk2data_idx', 'lnk2data_data2ptr'])


ipdb>  n


> /tmp/ipykernel_26512/2049060262.py(69)forward()
     67         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     70 
     71         loss = None; lbl2data_o = EncoderOutput()



ipdb>  


> /tmp/ipykernel_26512/2049060262.py(68)forward()
     66 
     67         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     70 



ipdb>  


> /tmp/ipykernel_26512/2049060262.py(69)forward()
     67         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     70 
     71         loss = None; lbl2data_o = EncoderOutput()



ipdb>  


> /tmp/ipykernel_26512/2049060262.py(68)forward()
     66 
     67         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     70 



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(115)forward()
    113         **kwargs
    114     ):  
--> 115         data_o = self.encode(data_input_ids, data_attention_mask)
    116 
    117         if data_type is not None and data_type == "meta":



ipdb>  n


> /tmp/ipykernel_26512/4122862164.py(117)forward()
    115         data_o = self.encode(data_input_ids, data_attention_mask)
    116 
--> 117         if data_type is not None and data_type == "meta":
    118             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    119         else:



ipdb>  data_o


BaseModelOutput(last_hidden_state=tensor([[[-0.6077, -1.4566, -1.2658,  ...,  0.4173,  0.4841, -0.4502],
         [-0.6280, -0.1254, -0.2795,  ...,  0.7303,  0.4780, -0.8034],
         [-0.5878,  0.2601, -0.8291,  ...,  0.7271, -0.0244, -0.1744],
         [-0.3885, -0.0131, -0.3512,  ...,  0.8768, -0.2073, -0.3525],
         [-0.2144, -1.0242, -1.2639,  ...,  1.1083,  0.2689,  0.1441]],

        [[-0.0116, -0.6070, -1.0937,  ...,  0.4919, -0.6339,  0.3495],
         [-0.4328, -0.5217, -0.7185,  ...,  0.3906, -0.9542,  0.9589],
         [-0.6673, -0.8967, -1.2427,  ...,  0.3123, -0.9577, -0.1136],
         [-0.2185,  0.0981, -0.4987,  ...,  0.7605, -0.7738,  0.0544],
         [ 0.1080,  0.3637, -0.7733,  ...,  0.7831, -0.8321,  0.0020]],

        [[-0.1806,  0.5675, -1.1267,  ...,  1.1764,  0.0713,  0.6914],
         [ 0.0746,  0.4647, -1.5286,  ...,  0.7264,  0.2030,  0.5578],
         [-0.2783,  0.2235, -0.8286,  ...,  0.4618,  0.0872, -0.2953],
         [ 0.7958, -0.2424, -1.0187,  .

ipdb>  l


    112         data_unnormalized:Optional[bool]=False,
    113         **kwargs
    114     ):  
    115         data_o = self.encode(data_input_ids, data_attention_mask)
    116 
--> 117         if data_type is not None and data_type == "meta":
    118             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    119         else:
    120             data_repr = self.dr(data_o[0], data_attention_mask)
    121 
    122         data_fused_repr = meta_repr = None



ipdb>  n


> /tmp/ipykernel_26512/4122862164.py(120)forward()
    118             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    119         else:
--> 120             data_repr = self.dr(data_o[0], data_attention_mask)
    121 
    122         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(122)forward()
    120             data_repr = self.dr(data_o[0], data_attention_mask)
    121 
--> 122         data_fused_repr = meta_repr = None
    123         if data_aug_meta_prefix is not None:
    124             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  data_repr.norm(dim=1)


tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  n


> /tmp/ipykernel_26512/4122862164.py(123)forward()
    121 
    122         data_fused_repr = meta_repr = None
--> 123         if data_aug_meta_prefix is not None:
    124             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    125             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(124)forward()
    122         data_fused_repr = meta_repr = None
    123         if data_aug_meta_prefix is not None:
--> 124             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    125             if len(meta_kwargs):
    126                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(125)forward()
    123         if data_aug_meta_prefix is not None:
    124             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
--> 125             if len(meta_kwargs):
    126                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    127                                                                             torch.any(data_attention_mask, dim=1),



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(126)forward()
    124             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    125             if len(meta_kwargs):
--> 126                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    127                                                                             torch.any(data_attention_mask, dim=1),
    128                                                                             meta_kwargs)



ipdb>  s


> /tmp/ipykernel_26512/4122862164.py(127)forward()
    125             if len(meta_kwargs):
    126                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
--> 127                                                                             torch.any(data_attention_mask, dim=1),
    128                                                                             meta_kwargs)
    129                 data_fused_repr = self.dr_fused(data_fused_repr)



ipdb>  n


> /tmp/ipykernel_26512/4122862164.py(128)forward()
    126                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    127                                                                             torch.any(data_attention_mask, dim=1),
--> 128                                                                             meta_kwargs)
    129                 data_fused_repr = self.dr_fused(data_fused_repr)
    130 



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(126)forward()
    124             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    125             if len(meta_kwargs):
--> 126                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    127                                                                             torch.any(data_attention_mask, dim=1),
    128                                                                             meta_kwargs)



ipdb>  s


--Call--
> /tmp/ipykernel_26512/3565704745.py(28)fuse_meta_into_embeddings()
     26                 sd_meta[k].copy_(sd_dr[k])
     27 
---> 28     def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
     29         meta_repr = {}
     30 



ipdb>  n


> /tmp/ipykernel_26512/3565704745.py(29)fuse_meta_into_embeddings()
     27 
     28     def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
---> 29         meta_repr = {}
     30 
     31         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)



ipdb>  


> /tmp/ipykernel_26512/3565704745.py(31)fuse_meta_into_embeddings()
     29         meta_repr = {}
     30 
---> 31         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
     32         for m_key, m_args in meta_kwargs.items():
     33             n_meta, bsz = m_args['data2ptr'].max(), len(m_args['data2ptr'])



ipdb>  


> /tmp/ipykernel_26512/3565704745.py(32)fuse_meta_into_embeddings()
     30 
     31         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
---> 32         for m_key, m_args in meta_kwargs.items():
     33             n_meta, bsz = m_args['data2ptr'].max(), len(m_args['data2ptr'])
     34             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'



ipdb>  data_fused_repr.shape


torch.Size([5, 1, 768])


ipdb>  data_mask.shape


torch.Size([5, 1])


ipdb>  data_mask


tensor([[True],
        [True],
        [True],
        [True],
        [True]])


ipdb>  n


> /tmp/ipykernel_26512/3565704745.py(33)fuse_meta_into_embeddings()
     31         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
     32         for m_key, m_args in meta_kwargs.items():
---> 33             n_meta, bsz = m_args['data2ptr'].max(), len(m_args['data2ptr'])
     34             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
     35 



ipdb>  n


> /tmp/ipykernel_26512/3565704745.py(34)fuse_meta_into_embeddings()
     32         for m_key, m_args in meta_kwargs.items():
     33             n_meta, bsz = m_args['data2ptr'].max(), len(m_args['data2ptr'])
---> 34             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
     35 
     36             m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']



ipdb>  n_meta


tensor(3)


ipdb>  bsz


5


ipdb>  n


> /tmp/ipykernel_26512/3565704745.py(36)fuse_meta_into_embeddings()
     34             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
     35 
---> 36             m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
     37             m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
     38 



ipdb>  n


> /tmp/ipykernel_26512/3565704745.py(37)fuse_meta_into_embeddings()
     35 
     36             m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
---> 37             m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
     38 
     39             m_idx = m_args['idx']



ipdb>  n


> /tmp/ipykernel_26512/3565704745.py(39)fuse_meta_into_embeddings()
     37             m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
     38 
---> 39             m_idx = m_args['idx']
     40             m_repr = F.normalize(self.meta_embeddings(m_idx) + m_embed + self.pretrained_meta_embeddings(m_idx), dim=1)
     41 



ipdb>  m_embed


tensor([[ 0.4711, -0.5405, -0.3613,  ...,  0.6627,  0.2056, -0.4132],
        [-0.5355,  0.3508, -1.1401,  ..., -0.0078,  0.1995, -0.0930],
        [-0.4559, -0.0452, -0.4907,  ...,  0.2223,  0.1730, -0.4363],
        ...,
        [-0.7420, -0.2882, -0.0634,  ...,  0.1905,  0.5780, -0.6984],
        [-0.0296,  0.0637,  0.8700,  ..., -0.6256, -0.0730, -1.0056],
        [-0.1513, -0.3573, -1.3101,  ...,  0.4211, -0.1453, -0.6851]],
       grad_fn=<DivBackward0>)


ipdb>  embed.shape


*** NameError: name 'embed' is not defined


ipdb>  m_embed.shape


torch.Size([15, 768])


ipdb>  n


> /tmp/ipykernel_26512/3565704745.py(40)fuse_meta_into_embeddings()
     38 
     39             m_idx = m_args['idx']
---> 40             m_repr = F.normalize(self.meta_embeddings(m_idx) + m_embed + self.pretrained_meta_embeddings(m_idx), dim=1)
     41 
     42             m_repr, m_repr_mask = m_repr.view(bsz, -1, self.config.dim), torch.ones((m_repr.shape[0],), device=m_repr.device, dtype=torch.bool).view(bsz, -1)



ipdb>  m_idx


tensor([ 72729, 161854, 161858, 120509, 306634, 487671, 102556,  68239,  54422,
         79395, 174407,  72843,  84871,  84868,  84732])


ipdb>  self.meta_embeddings(m_idx).shape


torch.Size([15, 768])


ipdb>  m_embed.shape


torch.Size([15, 768])


ipdb>  self.pretrained_meta_embeddings(m_idx).shape


torch.Size([15, 768])


ipdb>  n


> /tmp/ipykernel_26512/3565704745.py(42)fuse_meta_into_embeddings()
     40             m_repr = F.normalize(self.meta_embeddings(m_idx) + m_embed + self.pretrained_meta_embeddings(m_idx), dim=1)
     41 
---> 42             m_repr, m_repr_mask = m_repr.view(bsz, -1, self.config.dim), torch.ones((m_repr.shape[0],), device=m_repr.device, dtype=torch.bool).view(bsz, -1)
     43             meta_repr[m_key] = m_repr[m_repr_mask]
     44 



ipdb>  n


> /tmp/ipykernel_26512/3565704745.py(43)fuse_meta_into_embeddings()
     41 
     42             m_repr, m_repr_mask = m_repr.view(bsz, -1, self.config.dim), torch.ones((m_repr.shape[0],), device=m_repr.device, dtype=torch.bool).view(bsz, -1)
---> 43             meta_repr[m_key] = m_repr[m_repr_mask]
     44 
     45             fused_repr = self.cross_head(data_fused_repr, data_mask, m_repr, m_repr_mask)[0]



ipdb>  m_repr.shape


torch.Size([5, 3, 768])


ipdb>  m_repr_mask.shape


torch.Size([5, 3])


ipdb>  m_repr_mask


tensor([[True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True]])


ipdb>  n


> /tmp/ipykernel_26512/3565704745.py(45)fuse_meta_into_embeddings()
     43             meta_repr[m_key] = m_repr[m_repr_mask]
     44 
---> 45             fused_repr = self.cross_head(data_fused_repr, data_mask, m_repr, m_repr_mask)[0]
     46             data_fused_repr += fused_repr
     47 



ipdb>  data_mask.shape


torch.Size([5, 1])


ipdb>  data_mask


tensor([[True],
        [True],
        [True],
        [True],
        [True]])


ipdb>  m_repr_mask.shape


torch.Size([5, 3])


ipdb>  s


--Call--
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1696)__getattr__()
   1694     # See full discussion on the problems with returning `Union` here
   1695     # https://github.com/microsoft/pyright/issues/4213
-> 1696     def __getattr__(self, name: str) -> Any:
   1697         if '_parameters' in self.__dict__:
   1698             _parameters = self.__dict__['_parameters']



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1697)__getattr__()
   1695     # https://github.com/microsoft/pyright/issues/4213
   1696     def __getattr__(self, name: str) -> Any:
-> 1697         if '_parameters' in self.__dict__:
   1698             _parameters = self.__dict__['_parameters']
   1699             if name in _parameters:



ipdb>  r


--Return--
CrossAttentio..., bias=True)
)
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1708)__getattr__()
   1706             modules = self.__dict__['_modules']
   1707             if name in modules:
-> 1708                 return modules[name]
   1709         raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
   1710 



ipdb>  n


--Call--
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1528)_wrapped_call_impl()
   1526         return result
   1527 
-> 1528     def _wrapped_call_impl(self, *args, **kwargs):
   1529         if self._compiled_call_impl is not None:
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1529)_wrapped_call_impl()
   1527 
   1528     def _wrapped_call_impl(self, *args, **kwargs):
-> 1529         if self._compiled_call_impl is not None:
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1532)_wrapped_call_impl()
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):



ipdb>  s


--Call--
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1534)_call_impl()
   1532             return self._call_impl(*args, **kwargs)
   1533 
-> 1534     def _call_impl(self, *args, **kwargs):
   1535         forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
   1536         # If we don't have any hooks, we want to skip the rest of the logic in



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1535)_call_impl()
   1533 
   1534     def _call_impl(self, *args, **kwargs):
-> 1535         forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
   1536         # If we don't have any hooks, we want to skip the rest of the logic in
   1537         # this function, and just call forward.



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1538)_call_impl()
   1536         # If we don't have any hooks, we want to skip the rest of the logic in
   1537         # this function, and just call forward.
-> 1538         if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1539)_call_impl()
   1537         # this function, and just call forward.
   1538         if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
-> 1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
   1541             return forward_call(*args, **kwargs)



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1538)_call_impl()
   1536         # If we don't have any hooks, we want to skip the rest of the logic in
   1537         # this function, and just call forward.
-> 1538         if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1539)_call_impl()
   1537         # this function, and just call forward.
   1538         if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
-> 1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
   1541             return forward_call(*args, **kwargs)



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1538)_call_impl()
   1536         # If we don't have any hooks, we want to skip the rest of the logic in
   1537         # this function, and just call forward.
-> 1538         if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1540)_call_impl()
   1538         if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
-> 1540                 or _global_forward_hooks or _global_forward_pre_hooks):
   1541             return forward_call(*args, **kwargs)
   1542 



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1538)_call_impl()
   1536         # If we don't have any hooks, we want to skip the rest of the logic in
   1537         # this function, and just call forward.
-> 1538         if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1540)_call_impl()
   1538         if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
-> 1540                 or _global_forward_hooks or _global_forward_pre_hooks):
   1541             return forward_call(*args, **kwargs)
   1542 



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1538)_call_impl()
   1536         # If we don't have any hooks, we want to skip the rest of the logic in
   1537         # this function, and just call forward.
-> 1538         if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1541)_call_impl()
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:



ipdb>  s


--Call--
> /tmp/ipykernel_26512/1765836756.py(23)forward()
     21         self.o.weight.data = torch.eye(self.o.out_features, self.o.in_features, dtype=self.o.weight.dtype)
     22 
---> 23     def forward(
     24         self,
     25         q: torch.Tensor,



ipdb>  n


> /tmp/ipykernel_26512/1765836756.py(31)forward()
     29         output_attentions:Optional[bool] = False,
     30     ):
---> 31         bs, q_len, dim = q.size()
     32         v, k_len = k, k.size(1)
     33 



ipdb>  !q.size()


torch.Size([5, 1, 768])


ipdb>  n


> /tmp/ipykernel_26512/1765836756.py(32)forward()
     30     ):
     31         bs, q_len, dim = q.size()
---> 32         v, k_len = k, k.size(1)
     33 
     34         h_dim = self.dim//self.n_h



ipdb>  n


> /tmp/ipykernel_26512/1765836756.py(34)forward()
     32         v, k_len = k, k.size(1)
     33 
---> 34         h_dim = self.dim//self.n_h
     35 
     36         def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)



ipdb>  k_len


3


ipdb>  n


> /tmp/ipykernel_26512/1765836756.py(36)forward()
     34         h_dim = self.dim//self.n_h
     35 
---> 36         def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)
     37 
     38         def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)



ipdb>  h_dim


64


ipdb>  n


> /tmp/ipykernel_26512/1765836756.py(38)forward()
     36         def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)
     37 
---> 38         def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)
     39 
     40         q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)



ipdb>  


> /tmp/ipykernel_26512/1765836756.py(40)forward()
     38         def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)
     39 
---> 40         q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
     41         k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
     42         v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)



ipdb>  


> /tmp/ipykernel_26512/1765836756.py(41)forward()
     39 
     40         q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
---> 41         k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
     42         v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)
     43 



ipdb>  


> /tmp/ipykernel_26512/1765836756.py(42)forward()
     40         q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
     41         k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
---> 42         v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)
     43 
     44         q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)



ipdb>  


> /tmp/ipykernel_26512/1765836756.py(44)forward()
     42         v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)
     43 
---> 44         q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)
     45         sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
     46 



ipdb>  !q.shape


torch.Size([5, 12, 1, 64])


ipdb>  !k.shape


torch.Size([5, 12, 3, 64])


ipdb>  n


> /tmp/ipykernel_26512/1765836756.py(45)forward()
     43 
     44         q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)
---> 45         sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
     46 
     47         q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)



ipdb>  k.transpose(2, 3).shape


torch.Size([5, 12, 64, 3])


ipdb>  n


> /tmp/ipykernel_26512/1765836756.py(47)forward()
     45         sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
     46 
---> 47         q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
     48         mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
     49 



ipdb>  sc.shape


torch.Size([5, 12, 1, 3])


ipdb>  sc


tensor([[[[-4.1424e-04, -2.2880e-04, -6.0960e-05]],

         [[ 9.9972e-04,  9.5448e-04,  7.1228e-04]],

         [[ 2.2895e-03,  1.8841e-03,  1.4976e-04]],

         [[-2.1307e-03, -5.1904e-04,  1.2063e-03]],

         [[ 1.9698e-03,  6.5417e-04,  8.0679e-04]],

         [[ 1.2304e-03,  1.1282e-03,  1.5732e-03]],

         [[ 2.9919e-03,  1.0506e-03,  2.1831e-03]],

         [[ 4.1416e-04,  1.0185e-03,  7.6501e-04]],

         [[ 5.5429e-04,  4.4133e-03,  5.2506e-03]],

         [[ 1.3871e-03,  9.2414e-04,  2.5209e-03]],

         [[ 5.8925e-04,  2.4008e-04,  1.0212e-03]],

         [[ 4.2999e-03,  1.7437e-03,  2.4098e-03]]],


        [[[ 1.5860e-03,  4.0973e-03,  6.9158e-03]],

         [[ 2.4058e-03,  7.0596e-03,  1.0392e-02]],

         [[-9.3684e-04,  6.7379e-03,  8.7702e-03]],

         [[ 2.2622e-04,  3.5357e-03,  4.1718e-03]],

         [[ 1.9554e-03,  6.1795e-03,  7.3816e-03]],

         [[ 9.0992e-04,  7.4287e-03,  7.7551e-03]],

         [[ 2.3973e-03,  1.8775e-03,  4.6831

ipdb>  n


> /tmp/ipykernel_26512/1765836756.py(48)forward()
     46 
     47         q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
---> 48         mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
     49 
     50         sc = sc.masked_fill(mask == 0, torch.tensor(torch.finfo(sc.dtype).min))  # (bs, n_h, q_len, k_len)



ipdb>  q_m.shape


torch.Size([5, 1, 1, 1])


ipdb>  k_m.shape


torch.Size([5, 1, 1, 3])


ipdb>  sc.shape


torch.Size([5, 12, 1, 3])


ipdb>  n


> /tmp/ipykernel_26512/1765836756.py(50)forward()
     48         mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
     49 
---> 50         sc = sc.masked_fill(mask == 0, torch.tensor(torch.finfo(sc.dtype).min))  # (bs, n_h, q_len, k_len)
     51 
     52         w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)



ipdb>  mask


tensor([[[[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]]],


        [[[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]]],


        [[[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]],

         [[1., 1., 1.]]],


        [[[1., 1., 1.]],

         [[1., 1., 1.]],

      

ipdb>  mask.shape


torch.Size([5, 12, 1, 3])


ipdb>  l


     45         sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
     46 
     47         q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
     48         mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
     49 
---> 50         sc = sc.masked_fill(mask == 0, torch.tensor(torch.finfo(sc.dtype).min))  # (bs, n_h, q_len, k_len)
     51 
     52         w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
     53         w = self.dropout(w)  # (bs, n_h, q_len, k_len)
     54 
     55         o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)



ipdb>  sc.shape


torch.Size([5, 12, 1, 3])


ipdb>  mask.shape


torch.Size([5, 12, 1, 3])


ipdb>  mask.shape


torch.Size([5, 12, 1, 3])


ipdb>  mask.sum()


tensor(180.)


ipdb>  sc.shape


torch.Size([5, 12, 1, 3])


ipdb>  n


> /tmp/ipykernel_26512/1765836756.py(52)forward()
     50         sc = sc.masked_fill(mask == 0, torch.tensor(torch.finfo(sc.dtype).min))  # (bs, n_h, q_len, k_len)
     51 
---> 52         w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
     53         w = self.dropout(w)  # (bs, n_h, q_len, k_len)
     54 



ipdb>  n


> /tmp/ipykernel_26512/1765836756.py(53)forward()
     51 
     52         w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
---> 53         w = self.dropout(w)  # (bs, n_h, q_len, k_len)
     54 
     55         o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)



ipdb>  (mask == 0).sum()


tensor(0)


ipdb>  !w.shape


torch.Size([5, 12, 1, 3])


ipdb>  !w.sum(dim=-1)


tensor([[[1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000]],

        [[1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000]],

        [[1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000]],

        [[1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000]],

        [[1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
    

ipdb>  !w.sum(dim=-1) == 1


tensor([[[ True],
         [False],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [False],
         [False],
         [False]],

        [[False],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [False],
         [ True],
         [False],
         [ True]],

        [[ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [False],
         [ True],
         [ True],
         [ True],
         [ True]],

        [[ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [False],
         [ True],
         [ True],
         [False],
         [ True],
         [ True]],

        [[False],
         [ True],
         [ True],
         [ True],
         [ True],
         [ True],
         [False],
  

ipdb>  !w.sum(dim=-1)-1


tensor([[[ 0.0000e+00],
         [ 1.1921e-07],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [-1.1921e-07],
         [-1.1921e-07],
         [-5.9605e-08]],

        [[-5.9605e-08],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 1.1921e-07],
         [ 0.0000e+00],
         [-5.9605e-08],
         [ 0.0000e+00]],

        [[ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [-5.9605e-08],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00]],

        [[ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [ 0.0000e+00],
         [

ipdb>  !w.sum(dim=-1)


tensor([[[1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000]],

        [[1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000]],

        [[1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000]],

        [[1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
         [1.0000]],

        [[1.0000],
         [1.0000],
         [1.0000],
         [1.0000],
    

ipdb>  !(w.sum(dim=-1)-1).sum()


tensor(-2.9802e-07, grad_fn=<SumBackward0>)


ipdb>  l


     48         mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
     49 
     50         sc = sc.masked_fill(mask == 0, torch.tensor(torch.finfo(sc.dtype).min))  # (bs, n_h, q_len, k_len)
     51 
     52         w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
---> 53         w = self.dropout(w)  # (bs, n_h, q_len, k_len)
     54 
     55         o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)
     56 
     57         if output_attentions: return (o, w)
     58         else: return (o,)



ipdb>  n


> /tmp/ipykernel_26512/1765836756.py(55)forward()
     53         w = self.dropout(w)  # (bs, n_h, q_len, k_len)
     54 
---> 55         o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)
     56 
     57         if output_attentions: return (o, w)



ipdb>  !w.shape


torch.Size([5, 12, 1, 3])


ipdb>  !v.shape


torch.Size([5, 12, 3, 64])


ipdb>  n


> /tmp/ipykernel_26512/1765836756.py(57)forward()
     55         o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)
     56 
---> 57         if output_attentions: return (o, w)
     58         else: return (o,)
     59 



ipdb>  !o.shape


torch.Size([5, 1, 768])


ipdb>  n


> /tmp/ipykernel_26512/1765836756.py(58)forward()
     55         o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)
     56 
     57         if output_attentions: return (o, w)
---> 58         else: return (o,)
     59 



ipdb>  


--Return--
(tensor([[[-0....iewBackward0>),)
> /tmp/ipykernel_26512/1765836756.py(58)forward()
     55         o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)
     56 
     57         if output_attentions: return (o, w)
---> 58         else: return (o,)
     59 



ipdb>  


--Return--
(tensor([[[-0....iewBackward0>),)
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1541)_call_impl()
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:



ipdb>  


--Return--
(tensor([[[-0....iewBackward0>),)
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1532)_wrapped_call_impl()
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):



ipdb>  


> /tmp/ipykernel_26512/3565704745.py(46)fuse_meta_into_embeddings()
     44 
     45             fused_repr = self.cross_head(data_fused_repr, data_mask, m_repr, m_repr_mask)[0]
---> 46             data_fused_repr += fused_repr
     47 
     48         return data_fused_repr.squeeze(), meta_repr



ipdb>  data_fused_repr.shape


torch.Size([5, 1, 768])


ipdb>  fused_repr.shape


torch.Size([5, 1, 768])


ipdb>  n


> /tmp/ipykernel_26512/3565704745.py(32)fuse_meta_into_embeddings()
     30 
     31         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
---> 32         for m_key, m_args in meta_kwargs.items():
     33             n_meta, bsz = m_args['data2ptr'].max(), len(m_args['data2ptr'])
     34             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'



ipdb>  


> /tmp/ipykernel_26512/3565704745.py(48)fuse_meta_into_embeddings()
     45             fused_repr = self.cross_head(data_fused_repr, data_mask, m_repr, m_repr_mask)[0]
     46             data_fused_repr += fused_repr
     47 
---> 48         return data_fused_repr.squeeze(), meta_repr
     49 



ipdb>  n


--Return--
(tensor([[-0.0...ezeBackward0>), {'lnk2data': tensor([[ 0.0...dexBackward0>)})
> /tmp/ipykernel_26512/3565704745.py(48)fuse_meta_into_embeddings()
     45             fused_repr = self.cross_head(data_fused_repr, data_mask, m_repr, m_repr_mask)[0]
     46             data_fused_repr += fused_repr
     47 
---> 48         return data_fused_repr.squeeze(), meta_repr
     49 



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(129)forward()
    127                                                                             torch.any(data_attention_mask, dim=1),
    128                                                                             meta_kwargs)
--> 129                 data_fused_repr = self.dr_fused(data_fused_repr)
    130 
    131         return EncoderOutput(



ipdb>  n


> /tmp/ipykernel_26512/4122862164.py(131)forward()
    129                 data_fused_repr = self.dr_fused(data_fused_repr)
    130 
--> 131         return EncoderOutput(
    132             rep=data_repr,
    133             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(132)forward()
    130 
    131         return EncoderOutput(
--> 132             rep=data_repr,
    133             fused_rep=data_fused_repr,
    134             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(133)forward()
    131         return EncoderOutput(
    132             rep=data_repr,
--> 133             fused_rep=data_fused_repr,
    134             meta_repr=meta_repr,
    135         )



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(134)forward()
    132             rep=data_repr,
    133             fused_rep=data_fused_repr,
--> 134             meta_repr=meta_repr,
    135         )
    136 



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(131)forward()
    129                 data_fused_repr = self.dr_fused(data_fused_repr)
    130 
--> 131         return EncoderOutput(
    132             rep=data_repr,
    133             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...xBackward0>)})
> /tmp/ipykernel_26512/4122862164.py(131)forward()
    129                 data_fused_repr = self.dr_fused(data_fused_repr)
    130 
--> 131         return EncoderOutput(
    132             rep=data_repr,
    133             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_26512/2049060262.py(71)forward()
     69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     70 
---> 71         loss = None; lbl2data_o = EncoderOutput()
     72         if lbl2data_input_ids is not None:
     73             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_26512/2049060262.py(72)forward()
     70 
     71         loss = None; lbl2data_o = EncoderOutput()
---> 72         if lbl2data_input_ids is not None:
     73             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 



ipdb>  


> /tmp/ipykernel_26512/2049060262.py(73)forward()
     71         loss = None; lbl2data_o = EncoderOutput()
     72         if lbl2data_input_ids is not None:
---> 73             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_26512/2049060262.py(74)forward()
     72         if lbl2data_input_ids is not None:
     73             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
---> 74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
     76             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)



ipdb>  


> /tmp/ipykernel_26512/2049060262.py(75)forward()
     73             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
---> 75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
     76             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)
     77 



ipdb>  


> /tmp/ipykernel_26512/2049060262.py(74)forward()
     72         if lbl2data_input_ids is not None:
     73             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
---> 74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
     76             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)



ipdb>  


> /tmp/ipykernel_26512/2049060262.py(75)forward()
     73             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
---> 75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
     76             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)
     77 



ipdb>  


> /tmp/ipykernel_26512/2049060262.py(74)forward()
     72         if lbl2data_input_ids is not None:
     73             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
---> 74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
     76             lbl2data_o.rep = F.normalize(lbl2data_o.rep + self.label_embeddings(self.label_remap[lbl2data_idx]), dim=1)



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(115)forward()
    113         **kwargs
    114     ):  
--> 115         data_o = self.encode(data_input_ids, data_attention_mask)
    116 
    117         if data_type is not None and data_type == "meta":



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(117)forward()
    115         data_o = self.encode(data_input_ids, data_attention_mask)
    116 
--> 117         if data_type is not None and data_type == "meta":
    118             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    119         else:



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(120)forward()
    118             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    119         else:
--> 120             data_repr = self.dr(data_o[0], data_attention_mask)
    121 
    122         data_fused_repr = meta_repr = None



ipdb>  c


    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



In [102]:
o.loss

tensor(0.0149, grad_fn=<AddBackward0>)

## `OAK014`

In [118]:
#| export
class Encoder014(Encoder013):

    def __init__(
        self,
        config,
        n_meta_clusters:int,
        n_metadata:int,
        **kwargs
    ):
        super().__init__(config, num_metadata=n_meta_clusters, **kwargs)
        self.pretrained_meta_embeddings = nn.Embedding(n_meta_clusters, config.dim)
        self.register_buffer("metadata_remap", torch.arange(n_metadata)%n_meta_clusters, persistent=True)
        self.post_init()

    def set_metadata_remap(self, metadata_remap:torch.Tensor):
        if metadata_remap.shape[0] != self.metadata_remap.shape[0]:
            raise ValueError(f'Shape mismatch, `metadata_remap` should have {self.metadata_remap.shape[0]} elements.')
        with torch.no_grad():
            self.metadata_remap.copy_(metadata_remap)

    def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
        meta_repr = {}
        
        data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
        for m_key, m_args in meta_kwargs.items():
            n_meta, bsz = m_args['data2ptr'].max(), len(m_args['data2ptr'])
            assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
            
            m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
            m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
            
            m_idx = m_args['idx']
            m_repr = F.normalize(self.meta_embeddings(self.metadata_remap[m_idx]) + m_embed + self.pretrained_meta_embeddings(self.metadata_remap[m_idx]), dim=1)
            
            m_repr, m_repr_mask = m_repr.view(bsz, -1, self.config.dim), torch.ones((m_repr.shape[0],), device=m_repr.device, dtype=torch.bool).view(bsz, -1)
            meta_repr[m_key] = m_repr[m_repr_mask]
            
            fused_repr = self.cross_head(data_fused_repr, data_mask, m_repr, m_repr_mask)[0]
            data_fused_repr += fused_repr
                
        return data_fused_repr.squeeze(), meta_repr
        

In [122]:
#| export
class OAK014(OAK013, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]
    _keys_to_ignore_on_load_missing = ["encoder.meta_distilbert"]

    @delegates(OAK013.__init__)
    def __init__(
        self, 
        config,
        n_metadata:int,
        n_meta_clusters:int,
        resize_length:Optional[int]=None,
        **kwargs
    ):
        super().__init__(config, num_metadata=n_metadata, resize_length=resize_length, **kwargs)
        self.encoder = Encoder014(config, n_metadata=n_metadata, n_meta_clusters=n_meta_clusters, resize_length=resize_length)
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()
        
        

### Example

In [123]:
model = OAK014.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               n_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000, n_meta_clusters=block.train.dset.meta['lnk_meta'].n_meta//3,
                               n_labels=block.n_lbl, n_clusters=block.n_lbl//3, 
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)

model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()
model.init_label_embeddings()

Some weights of OAK014 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [124]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

model.init_meta_encoder()

remap = torch.arange(block.train.dset.meta['lnk_meta'].n_meta)%1000
model.encoder.set_metadata_remap(remap)

In [51]:
model = model.to('cuda')

In [52]:
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [53]:
o = model(**b.to(model.device))

In [54]:
o.loss

tensor(0.0149, device='cuda:0', grad_fn=<AddBackward0>)

In [116]:
def func():
    import pdb; pdb.set_trace()
    return model(**b.to(model.device))

In [None]:
o = func()

> /tmp/ipykernel_26512/1795951242.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))



ipdb>  b model.forward


Breakpoint 3 at /tmp/ipykernel_26512/2049060262.py:45


ipdb>  b model.encoder.forward


Breakpoint 4 at /tmp/ipykernel_26512/4122862164.py:106


ipdb>  c


> /tmp/ipykernel_26512/2049060262.py(61)forward()
     59         **kwargs
     60     ):  
---> 61         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     62 
     63         if self.use_encoder_parallel:



ipdb>  n


> /tmp/ipykernel_26512/2049060262.py(63)forward()
     61         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     62 
---> 63         if self.use_encoder_parallel:
     64             encoder = XCDataParallel(module=self.encoder)
     65         else: encoder = self.encoder



ipdb>  


> /tmp/ipykernel_26512/2049060262.py(65)forward()
     63         if self.use_encoder_parallel:
     64             encoder = XCDataParallel(module=self.encoder)
---> 65         else: encoder = self.encoder
     66 
     67         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_26512/2049060262.py(67)forward()
     65         else: encoder = self.encoder
     66 
---> 67         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)



ipdb>  c


> /tmp/ipykernel_26512/2049060262.py(68)forward()
     66 
     67         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     70 



ipdb>  c


> /tmp/ipykernel_26512/2049060262.py(68)forward()
     66 
     67         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     70 



ipdb>  n


> /tmp/ipykernel_26512/2049060262.py(69)forward()
     67         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     70 
     71         loss = None; lbl2data_o = EncoderOutput()



ipdb>  


> /tmp/ipykernel_26512/2049060262.py(68)forward()
     66 
     67         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 68         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     69                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     70 



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(115)forward()
    113         **kwargs
    114     ):  
--> 115         data_o = self.encode(data_input_ids, data_attention_mask)
    116 
    117         if data_type is not None and data_type == "meta":



ipdb>  n


> /tmp/ipykernel_26512/4122862164.py(117)forward()
    115         data_o = self.encode(data_input_ids, data_attention_mask)
    116 
--> 117         if data_type is not None and data_type == "meta":
    118             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    119         else:



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(120)forward()
    118             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    119         else:
--> 120             data_repr = self.dr(data_o[0], data_attention_mask)
    121 
    122         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(122)forward()
    120             data_repr = self.dr(data_o[0], data_attention_mask)
    121 
--> 122         data_fused_repr = meta_repr = None
    123         if data_aug_meta_prefix is not None:
    124             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(123)forward()
    121 
    122         data_fused_repr = meta_repr = None
--> 123         if data_aug_meta_prefix is not None:
    124             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    125             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(124)forward()
    122         data_fused_repr = meta_repr = None
    123         if data_aug_meta_prefix is not None:
--> 124             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    125             if len(meta_kwargs):
    126                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(125)forward()
    123         if data_aug_meta_prefix is not None:
    124             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
--> 125             if len(meta_kwargs):
    126                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    127                                                                             torch.any(data_attention_mask, dim=1),



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(126)forward()
    124             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    125             if len(meta_kwargs):
--> 126                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    127                                                                             torch.any(data_attention_mask, dim=1),
    128                                                                             meta_kwargs)



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(127)forward()
    125             if len(meta_kwargs):
    126                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
--> 127                                                                             torch.any(data_attention_mask, dim=1),
    128                                                                             meta_kwargs)
    129                 data_fused_repr = self.dr_fused(data_fused_repr)



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(128)forward()
    126                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    127                                                                             torch.any(data_attention_mask, dim=1),
--> 128                                                                             meta_kwargs)
    129                 data_fused_repr = self.dr_fused(data_fused_repr)
    130 



ipdb>  


> /tmp/ipykernel_26512/4122862164.py(126)forward()
    124             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    125             if len(meta_kwargs):
--> 126                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
    127                                                                             torch.any(data_attention_mask, dim=1),
    128                                                                             meta_kwargs)



ipdb>  s


--Call--
> /tmp/ipykernel_26512/1735825316.py(22)fuse_meta_into_embeddings()
     20             self.metadata_remap.copy_(metadata_remap)
     21 
---> 22     def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
     23         meta_repr = {}
     24 



ipdb>  n


> /tmp/ipykernel_26512/1735825316.py(23)fuse_meta_into_embeddings()
     21 
     22     def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
---> 23         meta_repr = {}
     24 
     25         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)



ipdb>  


> /tmp/ipykernel_26512/1735825316.py(25)fuse_meta_into_embeddings()
     23         meta_repr = {}
     24 
---> 25         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
     26         for m_key, m_args in meta_kwargs.items():
     27             n_meta, bsz = m_args['data2ptr'].max(), len(m_args['data2ptr'])



ipdb>  


> /tmp/ipykernel_26512/1735825316.py(26)fuse_meta_into_embeddings()
     24 
     25         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
---> 26         for m_key, m_args in meta_kwargs.items():
     27             n_meta, bsz = m_args['data2ptr'].max(), len(m_args['data2ptr'])
     28             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'



ipdb>  


> /tmp/ipykernel_26512/1735825316.py(27)fuse_meta_into_embeddings()
     25         data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
     26         for m_key, m_args in meta_kwargs.items():
---> 27             n_meta, bsz = m_args['data2ptr'].max(), len(m_args['data2ptr'])
     28             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
     29 



ipdb>  


> /tmp/ipykernel_26512/1735825316.py(28)fuse_meta_into_embeddings()
     26         for m_key, m_args in meta_kwargs.items():
     27             n_meta, bsz = m_args['data2ptr'].max(), len(m_args['data2ptr'])
---> 28             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
     29 
     30             m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']



ipdb>  


> /tmp/ipykernel_26512/1735825316.py(30)fuse_meta_into_embeddings()
     28             assert torch.all(m_args['data2ptr'] == n_meta), f'All datapoints should have same number of metadata.'
     29 
---> 30             m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
     31             m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
     32 



ipdb>  


> /tmp/ipykernel_26512/1735825316.py(31)fuse_meta_into_embeddings()
     29 
     30             m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
---> 31             m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
     32 
     33             m_idx = m_args['idx']



ipdb>  


> /tmp/ipykernel_26512/1735825316.py(33)fuse_meta_into_embeddings()
     31             m_embed = self.meta_encode(input_ids=m_input_ids, attention_mask=m_attention_mask)
     32 
---> 33             m_idx = m_args['idx']
     34             m_repr = F.normalize(self.meta_embeddings(self.metadata_remap[m_idx]) + m_embed + self.pretrained_meta_embeddings(self.metadata_remap[m_idx]), dim=1)
     35 



## `OAK015`

In [24]:
#| export
class OAK015(OAK003):

    def __init__(
        self, 
        config,
        neg2data_aug_meta_prefix:Optional[str]=None,
        **kwargs,
    ):
        super().__init__(config, **kwargs)
        store_attr('neg2data_aug_meta_prefix')
        self.rep_loss_fn = MultiTripletWithNegatives(margin=kwargs['margin'], n_negatives=kwargs['num_negatives'], 
                                                     tau=kwargs['tau'], apply_softmax=kwargs['apply_softmax'], 
                                                     reduce='mean')
        
    def compute_loss(self, data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, neg2data_repr, 
                     neg2data_data2ptr, neg2data_idx, plbl2data_data2ptr, plbl2data_idx, **kwargs):
        return self.rep_loss_fn(data_repr, pos_targ=lbl2data_repr, n_pos=lbl2data_data2ptr, pos_idx=lbl2data_idx, 
                                neg_targ=neg2data_repr, n_neg=neg2data_data2ptr, neg_idx=neg2data_idx, 
                                n_ppos=plbl2data_data2ptr, ppos_idx=plbl2data_idx, **kwargs)
        
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,

        neg2data_input_ids:Optional[torch.Tensor]=None,
        neg2data_attention_mask:Optional[torch.Tensor]=None,
        neg2data_data2ptr:Optional[torch.Tensor]=None,
        neg2data_idx:Optional[torch.Tensor]=None,
        
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        loss = None; lbl2data_o = EncoderOutput()
        if (
            lbl2data_input_ids is not None and 
            neg2data_input_ids is not None
        ):
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)

            neg2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('neg2data', self.neg2data_aug_meta_prefix, **kwargs)
            neg2data_o = encoder(data_input_ids=neg2data_input_ids, data_attention_mask=neg2data_attention_mask, 
                                 data_aug_meta_prefix=self.neg2data_aug_meta_prefix, **neg2data_meta_kwargs)
            
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx, neg2data_o.rep, 
                                     neg2data_data2ptr, neg2data_idx, plbl2data_data2ptr, plbl2data_idx, **kwargs)
            
            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx, neg2data_o.rep, 
                                          neg2data_data2ptr, neg2data_idx, plbl2data_data2ptr, plbl2data_idx, **kwargs)
                
            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, neg2data_o.rep, neg2data_data2ptr, neg2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
                
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
            

### Example

In [79]:
model = OAK015.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', margin=0.3, num_negatives=5, 
                               tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None,
                               neg2data_aug_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=True,
                               
                               use_encoder_parallel=False, normalize=True)

model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()

Some weights of OAK015 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [96]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

In [None]:
model = model.to('cuda')

In [102]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [108]:
o = model(**b.to(model.device))

In [105]:
o.loss

tensor(0.1556, grad_fn=<AddBackward0>)

In [106]:
def func():
    import pdb; pdb.set_trace()
    o = model(**b.to(model.device))
    

In [110]:
func()

> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/3721260802.py[0m(2)[0;36mfunc[0;34m()[0m
[0;32m      1 [0;31m[0;32mdef[0m [0mfunc[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 2 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      3 [0;31m    [0mo[0m [0;34m=[0m [0mmodel[0m[0;34m([0m[0;34m**[0m[0mb[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mmodel[0m[0;34m.[0m[0mdevice[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      4 [0;31m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(45)[0;36mforward[0;34m()[0m
[0;32m     43 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     44 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m1[0;32m--> 45 [0;31m        [0mreturn_dict[0m [0;34m=[0m [0mreturn_dict[0m [0;32mif[0m [0mreturn_dict[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_return_dict[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     46 [0;31m[0;34m[0m[0m
[0m[0;32m     47 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(119)[0;36mforward[0;34m()[0m
[0;32m    117 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    118 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m2[0;32m-> 119 [0;31m        [0mdata_o[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencode[0m[0;34m([0m[0mdata_input_ids[0m[0;34m,[0m [0mdata_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    120 [0;31m[0;34m[0m[0m
[0m[0;32m    121 [0;31m        [0;32mif[0m [0mdata_type[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mdata_type[0m [0;34m==[0m [0;34m"meta"[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2221889608.py[0m(29)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     27 [0;31m[0;34m[0m[0m
[0m[0;32m     28 [0;31m    [0;32mdef[0m [0mfuse_meta_into_embeddings[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mdata_repr[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mdata_mask[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mmeta_kwargs[0m[0;34m:[0m[0mDict[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m3[0;32m--> 29 [0;31m        [0mmeta_repr[0m [0;34m=[0m [0;34m{[0m[0;34m}[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     30 [0;31m[0;34m[0m[0m
[0m[0;32m     31 [0;31m        [0mdata_fused_repr[0m[0;34m,[0m [0mdata_mask[0m [0;34m=[0m [0mdata_repr[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m,[0m [0;36m1[0m[0;34m,[0m [0mself[0m[0;34m.

ipdb>  c


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(119)[0;36mforward[0;34m()[0m
[0;32m    117 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    118 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m2[0;32m-> 119 [0;31m        [0mdata_o[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencode[0m[0;34m([0m[0mdata_input_ids[0m[0;34m,[0m [0mdata_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    120 [0;31m[0;34m[0m[0m
[0m[0;32m    121 [0;31m        [0;32mif[0m [0mdata_type[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mdata_type[0m [0;34m==[0m [0;34m"meta"[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  r


--Return--
EncoderOutput...eta_repr=None)
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(135)[0;36mforward[0;34m()[0m
[0;32m    133 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_fused[0m[0;34m([0m[0mdata_fused_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    134 [0;31m[0;34m[0m[0m
[0m[0;32m--> 135 [0;31m        return EncoderOutput(
[0m[0;32m    136 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    137 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


--Return--
EncoderOutput...eta_repr=None)
> [0;32m/Users/suchith720/miniconda3/envs/mogic/lib/python3.13/site-packages/torch/nn/modules/module.py[0m(1750)[0;36m_call_impl[0;34m()[0m
[0;32m   1748 [0;31m                [0;32mor[0m [0m_global_backward_pre_hooks[0m [0;32mor[0m [0m_global_backward_hooks[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1749 [0;31m                [0;32mor[0m [0m_global_forward_hooks[0m [0;32mor[0m [0m_global_forward_pre_hooks[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1750 [0;31m            [0;32mreturn[0m [0mforward_call[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1751 [0;31m[0;34m[0m[0m
[0m[0;32m   1752 [0;31m        [0mresult[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


--Return--
EncoderOutput...eta_repr=None)
> [0;32m/Users/suchith720/miniconda3/envs/mogic/lib/python3.13/site-packages/torch/nn/modules/module.py[0m(1739)[0;36m_wrapped_call_impl[0;34m()[0m
[0;32m   1737 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_compiled_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m  [0;31m# type: ignore[misc][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1738 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1739 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1740 [0;31m[0;34m[0m[0m
[0m[0;32m   1741 [0;31m    [0;31m# torchrec tests the code consistency with the following code[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(64)[0;36mforward[0;34m()[0m
[0;32m     62 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mlbl2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     63 [0;31m[0;34m[0m[0m
[0m[0;32m---> 64 [0;31m            [0mneg2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'neg2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     65 [0;31m            neg2data_o = encoder(data_input_ids=neg2data_input_ids, data_attention_mask=neg2data_attention_mask, 
[0m[0;32m     66 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0m

ipdb>  lbl2data_o


EncoderOutput(rep=tensor([[-0.0252,  0.0914, -0.0155,  ..., -0.0417,  0.0890, -0.0145],
        [ 0.0107, -0.0238, -0.0349,  ...,  0.0151, -0.0291, -0.0294],
        [-0.0068, -0.0017, -0.0149,  ...,  0.0273,  0.0346,  0.0232],
        [-0.0244, -0.0302, -0.0259,  ...,  0.0584,  0.0058, -0.0318]],
       grad_fn=<DivBackward0>), fused_rep=None, logits=None, fusion_weights=None, meta_repr=None)


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(65)[0;36mforward[0;34m()[0m
[0;32m     63 [0;31m[0;34m[0m[0m
[0m[0;32m     64 [0;31m            [0mneg2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'neg2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m            neg2data_o = encoder(data_input_ids=neg2data_input_ids, data_attention_mask=neg2data_attention_mask, 
[0m[0;32m     66 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(66)[0;36mforward[0;34m()[0m
[0;32m     64 [0;31m            [0mneg2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'neg2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     65 [0;31m            neg2data_o = encoder(data_input_ids=neg2data_input_ids, data_attention_mask=neg2data_attention_mask, 
[0m[0;32m---> 66 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m[0;34m[0m[0m
[0m[0;32m     68 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_id

ipdb>  neg2data_meta_kwargs


{}


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(65)[0;36mforward[0;34m()[0m
[0;32m     63 [0;31m[0;34m[0m[0m
[0m[0;32m     64 [0;31m            [0mneg2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'neg2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m            neg2data_o = encoder(data_input_ids=neg2data_input_ids, data_attention_mask=neg2data_attention_mask, 
[0m[0;32m     66 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(66)[0;36mforward[0;34m()[0m
[0;32m     64 [0;31m            [0mneg2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'neg2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     65 [0;31m            neg2data_o = encoder(data_input_ids=neg2data_input_ids, data_attention_mask=neg2data_attention_mask, 
[0m[0;32m---> 66 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m[0;34m[0m[0m
[0m[0;32m     68 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_id

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(65)[0;36mforward[0;34m()[0m
[0;32m     63 [0;31m[0;34m[0m[0m
[0m[0;32m     64 [0;31m            [0mneg2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'neg2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m            neg2data_o = encoder(data_input_ids=neg2data_input_ids, data_attention_mask=neg2data_attention_mask, 
[0m[0;32m     66 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(119)[0;36mforward[0;34m()[0m
[0;32m    117 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    118 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m2[0;32m-> 119 [0;31m        [0mdata_o[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencode[0m[0;34m([0m[0mdata_input_ids[0m[0;34m,[0m [0mdata_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    120 [0;31m[0;34m[0m[0m
[0m[0;32m    121 [0;31m        [0;32mif[0m [0mdata_type[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mdata_type[0m [0;34m==[0m [0;34m"meta"[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  r


--Return--
EncoderOutput...eta_repr=None)
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(135)[0;36mforward[0;34m()[0m
[0;32m    133 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_fused[0m[0;34m([0m[0mdata_fused_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    134 [0;31m[0;34m[0m[0m
[0m[0;32m--> 135 [0;31m        return EncoderOutput(
[0m[0;32m    136 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    137 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


--Return--
EncoderOutput...eta_repr=None)
> [0;32m/Users/suchith720/miniconda3/envs/mogic/lib/python3.13/site-packages/torch/nn/modules/module.py[0m(1750)[0;36m_call_impl[0;34m()[0m
[0;32m   1748 [0;31m                [0;32mor[0m [0m_global_backward_pre_hooks[0m [0;32mor[0m [0m_global_backward_hooks[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1749 [0;31m                [0;32mor[0m [0m_global_forward_hooks[0m [0;32mor[0m [0m_global_forward_pre_hooks[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1750 [0;31m            [0;32mreturn[0m [0mforward_call[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1751 [0;31m[0;34m[0m[0m
[0m[0;32m   1752 [0;31m        [0mresult[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


--Return--
EncoderOutput...eta_repr=None)
> [0;32m/Users/suchith720/miniconda3/envs/mogic/lib/python3.13/site-packages/torch/nn/modules/module.py[0m(1739)[0;36m_wrapped_call_impl[0;34m()[0m
[0;32m   1737 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_compiled_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m  [0;31m# type: ignore[misc][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1738 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1739 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1740 [0;31m[0;34m[0m[0m
[0m[0;32m   1741 [0;31m    [0;31m# torchrec tests the code consistency with the following code[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(68)[0;36mforward[0;34m()[0m
[0;32m     66 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m[0;34m[0m[0m
[0m[0;32m---> 68 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx, neg2data_o.rep, 
[0m[0;32m     69 [0;31m                                     [0mneg2data_data2ptr[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m,[0m [0mplbl2data_data2ptr[0m[0;34m,[0m [0mplbl2data_idx[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     70 [0;31m[0;34m[0m[0m
[0m


ipdb>  lbl2data_o.rep.shape


torch.Size([4, 768])


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(69)[0;36mforward[0;34m()[0m
[0;32m     67 [0;31m[0;34m[0m[0m
[0m[0;32m     68 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx, neg2data_o.rep, 
[0m[0;32m---> 69 [0;31m                                     [0mneg2data_data2ptr[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m,[0m [0mplbl2data_data2ptr[0m[0;34m,[0m [0mplbl2data_idx[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     70 [0;31m[0;34m[0m[0m
[0m[0;32m     71 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(68)[0;36mforward[0;34m()[0m
[0;32m     66 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m[0;34m[0m[0m
[0m[0;32m---> 68 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx, neg2data_o.rep, 
[0m[0;32m     69 [0;31m                                     [0mneg2data_data2ptr[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m,[0m [0mplbl2data_data2ptr[0m[0;34m,[0m [0mplbl2data_idx[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     70 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(69)[0;36mforward[0;34m()[0m
[0;32m     67 [0;31m[0;34m[0m[0m
[0m[0;32m     68 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx, neg2data_o.rep, 
[0m[0;32m---> 69 [0;31m                                     [0mneg2data_data2ptr[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m,[0m [0mplbl2data_data2ptr[0m[0;34m,[0m [0mplbl2data_idx[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     70 [0;31m[0;34m[0m[0m
[0m[0;32m     71 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(68)[0;36mforward[0;34m()[0m
[0;32m     66 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m[0;34m[0m[0m
[0m[0;32m---> 68 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx, neg2data_o.rep, 
[0m[0;32m     69 [0;31m                                     [0mneg2data_data2ptr[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m,[0m [0mplbl2data_data2ptr[0m[0;34m,[0m [0mplbl2data_idx[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     70 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(71)[0;36mforward[0;34m()[0m
[0;32m     69 [0;31m                                     [0mneg2data_data2ptr[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m,[0m [0mplbl2data_data2ptr[0m[0;34m,[0m [0mplbl2data_idx[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     70 [0;31m[0;34m[0m[0m
[0m[0;32m---> 71 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     72 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx, neg2data_o.rep, 
[0m[0;32m     73 [0;31m                                          [0mneg2data_data2ptr[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m,[0m [0mplbl2data_data2ptr[0m[0;34m,[0m [0mplbl2data_idx[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  loss


tensor(0.0780, grad_fn=<DivBackward0>)


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(72)[0;36mforward[0;34m()[0m
[0;32m     70 [0;31m[0;34m[0m[0m
[0m[0;32m     71 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 72 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx, neg2data_o.rep, 
[0m[0;32m     73 [0;31m                                          [0mneg2data_data2ptr[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m,[0m [0mplbl2data_data2ptr[0m[0;34m,[0m [0mplbl2data_idx[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(73)[0;36mforward[0;34m()[0m
[0;32m     71 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     72 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx, neg2data_o.rep, 
[0m[0;32m---> 73 [0;31m                                          [0mneg2data_data2ptr[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m,[0m [0mplbl2data_data2ptr[0m[0;34m,[0m [0mplbl2data_idx[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(72)[0;36mforward[0;34m()[0m
[0;32m     70 [0;31m[0;34m[0m[0m
[0m[0;32m     71 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 72 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx, neg2data_o.rep, 
[0m[0;32m     73 [0;31m                                          [0mneg2data_data2ptr[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m,[0m [0mplbl2data_data2ptr[0m[0;34m,[0m [0mplbl2data_idx[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(73)[0;36mforward[0;34m()[0m
[0;32m     71 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     72 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx, neg2data_o.rep, 
[0m[0;32m---> 73 [0;31m                                          [0mneg2data_data2ptr[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m,[0m [0mplbl2data_data2ptr[0m[0;34m,[0m [0mplbl2data_idx[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m[0;34m[0m[0m
[0m[0;32m     75 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(72)[0;36mforward[0;34m()[0m
[0;32m     70 [0;31m[0;34m[0m[0m
[0m[0;32m     71 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 72 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx, neg2data_o.rep, 
[0m[0;32m     73 [0;31m                                          [0mneg2data_data2ptr[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m,[0m [0mplbl2data_data2ptr[0m[0;34m,[0m [0mplbl2data_idx[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(75)[0;36mforward[0;34m()[0m
[0;32m     73 [0;31m                                          [0mneg2data_data2ptr[0m[0;34m,[0m [0mneg2data_idx[0m[0;34m,[0m [0mplbl2data_data2ptr[0m[0;34m,[0m [0mplbl2data_idx[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m[0;34m[0m[0m
[0m[0;32m---> 75 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     76 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx,
[0m[0;32m     77 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py[0m(81)[0;36mforward[0;34m()[0m
[0;32m     79 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m---> 81 [0;31m        [0;32mif[0m [0;32mnot[0m [0mreturn_dict[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     82 [0;31m            [0mo[0m [0;34m=[0m [0;34m([0m[0mdata_o[0m[0;34m.[0m[0mlogits[0m[0;34m,[0m[0mdata_o[0m[0;34m.[0m[0mrep[0m[0;34m,[0m[0mdata_o[0m[0;34m.[0m[0mfused_rep[0m[0;34m,[0m[0mlbl2data_o[0m[0;34m.[0m[0mlogits[0m[0;34m,[0m[0mlbl2data_o[0m[0;34m.[0m[0mrep[0m[0;34m,[0m[0mlbl2data_o[0m[0;34m.[0m[0mfused_rep[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     83 [0;31m            [0;32mreturn[0m [0;34m([0m[0;34m([0m[0mloss[0m[0;34m,[0m[0;34m)[0m [0;34m

ipdb>  c


[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m



## `OAK16`

In [25]:
#| export
class OAK016(OAK003):

    def __init__(
        self, 
        config,
        neg2data_aug_meta_prefix:Optional[str]=None,
        **kwargs,
    ):
        super().__init__(config, **kwargs)
        store_attr('neg2data_aug_meta_prefix')
        self.rep_loss_fn = MarginMSEWithNegatives()
        
    def compute_loss(self, data_repr, lbl2data_repr, lbl2data_scores, neg2data_repr, neg2data_scores, **kwargs):
        return self.rep_loss_fn(data_repr, lbl2data_repr, lbl2data_scores, neg2data_repr, neg2data_scores, **kwargs)
        
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_scores:Optional[torch.Tensor]=None,

        neg2data_input_ids:Optional[torch.Tensor]=None,
        neg2data_attention_mask:Optional[torch.Tensor]=None,
        neg2data_data2ptr:Optional[torch.Tensor]=None,
        neg2data_idx:Optional[torch.Tensor]=None,
        neg2data_scores:Optional[torch.Tensor]=None,
        
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        loss = None; lbl2data_o = EncoderOutput()
        if (
            lbl2data_input_ids is not None and 
            neg2data_input_ids is not None
        ):
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)

            neg2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('neg2data', self.neg2data_aug_meta_prefix, **kwargs)
            neg2data_o = encoder(data_input_ids=neg2data_input_ids, data_attention_mask=neg2data_attention_mask, 
                                 data_aug_meta_prefix=self.neg2data_aug_meta_prefix, **neg2data_meta_kwargs)

            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_scores, neg2data_o.rep, 
                                     neg2data_scores, **kwargs)
            
            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, lbl2data_scores, neg2data_o.rep, 
                                         neg2data_scores, **kwargs)
                
            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, neg2data_o.rep, neg2data_data2ptr, neg2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
                
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
            

### Example

In [113]:
model = OAK016.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', margin=0.3, num_negatives=5, 
                               tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None,
                               neg2data_aug_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=True,
                               
                               use_encoder_parallel=False, normalize=True)

model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()

Some weights of OAK016 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [116]:
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

In [None]:
model = model.to('cuda')

In [119]:
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [123]:
o = model(**b.to(model.device))

In [124]:
o.loss

tensor(0.0051, grad_fn=<AddBackward0>)

In [125]:
def func():
    import pdb; pdb.set_trace()
    o = model(**b.to(model.device))
    

In [126]:
func()

> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/3721260802.py[0m(2)[0;36mfunc[0;34m()[0m
[0;32m      1 [0;31m[0;32mdef[0m [0mfunc[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 2 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      3 [0;31m    [0mo[0m [0;34m=[0m [0mmodel[0m[0;34m([0m[0;34m**[0m[0mb[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mmodel[0m[0;34m.[0m[0mdevice[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      4 [0;31m[0;34m[0m[0m
[0m


ipdb>  b


Num Type         Disp Enb   Where
1   breakpoint   keep no    at /var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py:45
	breakpoint already hit 3 times
2   breakpoint   keep no    at /var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py:119
	breakpoint already hit 8 times
3   breakpoint   keep no    at /var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2221889608.py:29
	breakpoint already hit 3 times


ipdb>  enable 1


Enabled breakpoint 1 at /var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1017695108.py:45


ipdb>  enable 2 


Enabled breakpoint 2 at /var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py:119


ipdb>  enable 3


Enabled breakpoint 3 at /var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2221889608.py:29


ipdb>  c


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(119)[0;36mforward[0;34m()[0m
[0;32m    117 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    118 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m2[0;32m-> 119 [0;31m        [0mdata_o[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencode[0m[0;34m([0m[0mdata_input_ids[0m[0;34m,[0m [0mdata_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    120 [0;31m[0;34m[0m[0m
[0m[0;32m    121 [0;31m        [0;32mif[0m [0mdata_type[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mdata_type[0m [0;34m==[0m [0;34m"meta"[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  r


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2221889608.py[0m(29)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     27 [0;31m[0;34m[0m[0m
[0m[0;32m     28 [0;31m    [0;32mdef[0m [0mfuse_meta_into_embeddings[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mdata_repr[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mdata_mask[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mmeta_kwargs[0m[0;34m:[0m[0mDict[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m3[0;32m--> 29 [0;31m        [0mmeta_repr[0m [0;34m=[0m [0;34m{[0m[0;34m}[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     30 [0;31m[0;34m[0m[0m
[0m[0;32m     31 [0;31m        [0mdata_fused_repr[0m[0;34m,[0m [0mdata_mask[0m [0;34m=[0m [0mdata_repr[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m,[0m [0;36m1[0m[0;34m,[0m [0mself[0m[0;34m.

ipdb>  r


--Return--
(tensor([[-0.0...ezeBackward0>), {'cat2data': tensor([[0., ...dexBackward0>)})
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2221889608.py[0m(46)[0;36mfuse_meta_into_embeddings[0;34m()[0m
[0;32m     43 [0;31m                [0mfused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcross_head[0m[0;34m([0m[0mdata_fused_repr[0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m [0mdata_mask[0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m [0mm_repr[0m[0;34m,[0m [0mm_repr_mask[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     44 [0;31m                [0mdata_fused_repr[0m[0;34m[[0m[0midx[0m[0;34m][0m [0;34m+=[0m [0mfused_repr[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     45 [0;31m[0;34m[0m[0m
[0m[0;32m---> 46 [0;31m        [0;32mreturn[0m [0mdata_fused_repr[0m[0;34m.[0m[0msqueeze[0m[0;34m([0m[0;34m)[0m[0;34m,[0m [0mmeta_repr[0m[0;34m[0m[0;34m[0m[0m
[0

ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(133)[0;36mforward[0;34m()[0m
[0;32m    131 [0;31m                                                                            [0mtorch[0m[0;34m.[0m[0many[0m[0;34m([0m[0mdata_attention_mask[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    132 [0;31m                                                                            [0mmeta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 133 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_fused[0m[0;34m([0m[0mdata_fused_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    134 [0;31m[0;34m[0m[0m
[0m[0;32m    135 [0;31m        return EncoderOutput(
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(135)[0;36mforward[0;34m()[0m
[0;32m    133 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_fused[0m[0;34m([0m[0mdata_fused_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    134 [0;31m[0;34m[0m[0m
[0m[0;32m--> 135 [0;31m        return EncoderOutput(
[0m[0;32m    136 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    137 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(136)[0;36mforward[0;34m()[0m
[0;32m    134 [0;31m[0;34m[0m[0m
[0m[0;32m    135 [0;31m        return EncoderOutput(
[0m[0;32m--> 136 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    137 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    138 [0;31m            [0mmeta_repr[0m[0;34m=[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(137)[0;36mforward[0;34m()[0m
[0;32m    135 [0;31m        return EncoderOutput(
[0m[0;32m    136 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 137 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    138 [0;31m            [0mmeta_repr[0m[0;34m=[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    139 [0;31m        [0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(138)[0;36mforward[0;34m()[0m
[0;32m    136 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    137 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 138 [0;31m            [0mmeta_repr[0m[0;34m=[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    139 [0;31m        [0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    140 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(135)[0;36mforward[0;34m()[0m
[0;32m    133 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_fused[0m[0;34m([0m[0mdata_fused_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    134 [0;31m[0;34m[0m[0m
[0m[0;32m--> 135 [0;31m        return EncoderOutput(
[0m[0;32m    136 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    137 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


--Return--
EncoderOutput...xBackward0>)})
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(135)[0;36mforward[0;34m()[0m
[0;32m    133 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_fused[0m[0;34m([0m[0mdata_fused_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    134 [0;31m[0;34m[0m[0m
[0m[0;32m--> 135 [0;31m        return EncoderOutput(
[0m[0;32m    136 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    137 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


--Return--
EncoderOutput...xBackward0>)})
> [0;32m/Users/suchith720/miniconda3/envs/mogic/lib/python3.13/site-packages/torch/nn/modules/module.py[0m(1750)[0;36m_call_impl[0;34m()[0m
[0;32m   1748 [0;31m                [0;32mor[0m [0m_global_backward_pre_hooks[0m [0;32mor[0m [0m_global_backward_hooks[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1749 [0;31m                [0;32mor[0m [0m_global_forward_hooks[0m [0;32mor[0m [0m_global_forward_pre_hooks[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1750 [0;31m            [0;32mreturn[0m [0mforward_call[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1751 [0;31m[0;34m[0m[0m
[0m[0;32m   1752 [0;31m        [0mresult[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


--Return--
EncoderOutput...xBackward0>)})
> [0;32m/Users/suchith720/miniconda3/envs/mogic/lib/python3.13/site-packages/torch/nn/modules/module.py[0m(1739)[0;36m_wrapped_call_impl[0;34m()[0m
[0;32m   1737 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_compiled_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m  [0;31m# type: ignore[misc][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1738 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1739 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1740 [0;31m[0;34m[0m[0m
[0m[0;32m   1741 [0;31m    [0;31m# torchrec tests the code consistency with the following code[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(52)[0;36mforward[0;34m()[0m
[0;32m     50 [0;31m                         [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mdata_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     51 [0;31m[0;34m[0m[0m
[0m[0;32m---> 52 [0;31m        [0mloss[0m [0;34m=[0m [0;32mNone[0m[0;34m;[0m [0mlbl2data_o[0m [0;34m=[0m [0mEncoderOutput[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     53 [0;31m        if (
[0m[0;32m     54 [0;31m            [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(54)[0;36mforward[0;34m()[0m
[0;32m     52 [0;31m        [0mloss[0m [0;34m=[0m [0;32mNone[0m[0;34m;[0m [0mlbl2data_o[0m [0;34m=[0m [0mEncoderOutput[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     53 [0;31m        if (
[0m[0;32m---> 54 [0;31m            [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     55 [0;31m            [0mneg2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     56 [0;31m        [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(55)[0;36mforward[0;34m()[0m
[0;32m     53 [0;31m        if (
[0m[0;32m     54 [0;31m            [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 55 [0;31m            [0mneg2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     56 [0;31m        [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     57 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'lbl2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(57)[0;36mforward[0;34m()[0m
[0;32m     55 [0;31m            [0mneg2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     56 [0;31m        [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 57 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'lbl2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     58 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m     59 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mlbl2data_me

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(58)[0;36mforward[0;34m()[0m
[0;32m     56 [0;31m        [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     57 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'lbl2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 58 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m     59 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mlbl2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     60 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(59)[0;36mforward[0;34m()[0m
[0;32m     57 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'lbl2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     58 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m---> 59 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mlbl2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     60 [0;31m[0;34m[0m[0m
[0m[0;32m     61 [0;31m            [0mneg2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_met

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(58)[0;36mforward[0;34m()[0m
[0;32m     56 [0;31m        [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     57 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'lbl2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 58 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m     59 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mlbl2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     60 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(59)[0;36mforward[0;34m()[0m
[0;32m     57 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'lbl2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     58 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m---> 59 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mlbl2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     60 [0;31m[0;34m[0m[0m
[0m[0;32m     61 [0;31m            [0mneg2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_met

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(58)[0;36mforward[0;34m()[0m
[0;32m     56 [0;31m        [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     57 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'lbl2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 58 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m     59 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mlbl2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     60 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(119)[0;36mforward[0;34m()[0m
[0;32m    117 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    118 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m2[0;32m-> 119 [0;31m        [0mdata_o[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencode[0m[0;34m([0m[0mdata_input_ids[0m[0;34m,[0m [0mdata_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    120 [0;31m[0;34m[0m[0m
[0m[0;32m    121 [0;31m        [0;32mif[0m [0mdata_type[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mdata_type[0m [0;34m==[0m [0;34m"meta"[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  r


--Return--
EncoderOutput...eta_repr=None)
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(135)[0;36mforward[0;34m()[0m
[0;32m    133 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_fused[0m[0;34m([0m[0mdata_fused_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    134 [0;31m[0;34m[0m[0m
[0m[0;32m--> 135 [0;31m        return EncoderOutput(
[0m[0;32m    136 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    137 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


--Return--
EncoderOutput...eta_repr=None)
> [0;32m/Users/suchith720/miniconda3/envs/mogic/lib/python3.13/site-packages/torch/nn/modules/module.py[0m(1750)[0;36m_call_impl[0;34m()[0m
[0;32m   1748 [0;31m                [0;32mor[0m [0m_global_backward_pre_hooks[0m [0;32mor[0m [0m_global_backward_hooks[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1749 [0;31m                [0;32mor[0m [0m_global_forward_hooks[0m [0;32mor[0m [0m_global_forward_pre_hooks[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1750 [0;31m            [0;32mreturn[0m [0mforward_call[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1751 [0;31m[0;34m[0m[0m
[0m[0;32m   1752 [0;31m        [0mresult[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


--Return--
EncoderOutput...eta_repr=None)
> [0;32m/Users/suchith720/miniconda3/envs/mogic/lib/python3.13/site-packages/torch/nn/modules/module.py[0m(1739)[0;36m_wrapped_call_impl[0;34m()[0m
[0;32m   1737 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_compiled_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m  [0;31m# type: ignore[misc][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1738 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1739 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1740 [0;31m[0;34m[0m[0m
[0m[0;32m   1741 [0;31m    [0;31m# torchrec tests the code consistency with the following code[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(61)[0;36mforward[0;34m()[0m
[0;32m     59 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mlbl2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     60 [0;31m[0;34m[0m[0m
[0m[0;32m---> 61 [0;31m            [0mneg2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'neg2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     62 [0;31m            neg2data_o = encoder(data_input_ids=neg2data_input_ids, data_attention_mask=neg2data_attention_mask, 
[0m[0;32m     63 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0m

ipdb>  neg2data_meta_kwargs


*** NameError: name 'neg2data_meta_kwargs' is not defined


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(62)[0;36mforward[0;34m()[0m
[0;32m     60 [0;31m[0;34m[0m[0m
[0m[0;32m     61 [0;31m            [0mneg2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'neg2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 62 [0;31m            neg2data_o = encoder(data_input_ids=neg2data_input_ids, data_attention_mask=neg2data_attention_mask, 
[0m[0;32m     63 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(63)[0;36mforward[0;34m()[0m
[0;32m     61 [0;31m            [0mneg2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'neg2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     62 [0;31m            neg2data_o = encoder(data_input_ids=neg2data_input_ids, data_attention_mask=neg2data_attention_mask, 
[0m[0;32m---> 63 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m[0;34m[0m[0m
[0m[0;32m     65 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_scores, neg2data_o.re

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(62)[0;36mforward[0;34m()[0m
[0;32m     60 [0;31m[0;34m[0m[0m
[0m[0;32m     61 [0;31m            [0mneg2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'neg2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 62 [0;31m            neg2data_o = encoder(data_input_ids=neg2data_input_ids, data_attention_mask=neg2data_attention_mask, 
[0m[0;32m     63 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(63)[0;36mforward[0;34m()[0m
[0;32m     61 [0;31m            [0mneg2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'neg2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     62 [0;31m            neg2data_o = encoder(data_input_ids=neg2data_input_ids, data_attention_mask=neg2data_attention_mask, 
[0m[0;32m---> 63 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m[0;34m[0m[0m
[0m[0;32m     65 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_scores, neg2data_o.re

ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(62)[0;36mforward[0;34m()[0m
[0;32m     60 [0;31m[0;34m[0m[0m
[0m[0;32m     61 [0;31m            [0mneg2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0;34m'neg2data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 62 [0;31m            neg2data_o = encoder(data_input_ids=neg2data_input_ids, data_attention_mask=neg2data_attention_mask, 
[0m[0;32m     63 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(119)[0;36mforward[0;34m()[0m
[0;32m    117 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    118 [0;31m    [0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;31m2[0;32m-> 119 [0;31m        [0mdata_o[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencode[0m[0;34m([0m[0mdata_input_ids[0m[0;34m,[0m [0mdata_attention_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    120 [0;31m[0;34m[0m[0m
[0m[0;32m    121 [0;31m        [0;32mif[0m [0mdata_type[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mdata_type[0m [0;34m==[0m [0;34m"meta"[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  r


--Return--
EncoderOutput...eta_repr=None)
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/1768784971.py[0m(135)[0;36mforward[0;34m()[0m
[0;32m    133 [0;31m                [0mdata_fused_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mdr_fused[0m[0;34m([0m[0mdata_fused_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    134 [0;31m[0;34m[0m[0m
[0m[0;32m--> 135 [0;31m        return EncoderOutput(
[0m[0;32m    136 [0;31m            [0mrep[0m[0;34m=[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    137 [0;31m            [0mfused_rep[0m[0;34m=[0m[0mdata_fused_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


--Return--
EncoderOutput...eta_repr=None)
> [0;32m/Users/suchith720/miniconda3/envs/mogic/lib/python3.13/site-packages/torch/nn/modules/module.py[0m(1750)[0;36m_call_impl[0;34m()[0m
[0;32m   1748 [0;31m                [0;32mor[0m [0m_global_backward_pre_hooks[0m [0;32mor[0m [0m_global_backward_hooks[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1749 [0;31m                [0;32mor[0m [0m_global_forward_hooks[0m [0;32mor[0m [0m_global_forward_pre_hooks[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1750 [0;31m            [0;32mreturn[0m [0mforward_call[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1751 [0;31m[0;34m[0m[0m
[0m[0;32m   1752 [0;31m        [0mresult[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


--Return--
EncoderOutput...eta_repr=None)
> [0;32m/Users/suchith720/miniconda3/envs/mogic/lib/python3.13/site-packages/torch/nn/modules/module.py[0m(1739)[0;36m_wrapped_call_impl[0;34m()[0m
[0;32m   1737 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_compiled_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m  [0;31m# type: ignore[misc][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1738 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1739 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1740 [0;31m[0;34m[0m[0m
[0m[0;32m   1741 [0;31m    [0;31m# torchrec tests the code consistency with the following code[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(65)[0;36mforward[0;34m()[0m
[0;32m     63 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_scores, neg2data_o.rep, 
[0m[0;32m     66 [0;31m                                     [0mneg2data_scores[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(66)[0;36mforward[0;34m()[0m
[0;32m     64 [0;31m[0;34m[0m[0m
[0m[0;32m     65 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_scores, neg2data_o.rep, 
[0m[0;32m---> 66 [0;31m                                     [0mneg2data_scores[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m[0;34m[0m[0m
[0m[0;32m     68 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(65)[0;36mforward[0;34m()[0m
[0;32m     63 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_scores, neg2data_o.rep, 
[0m[0;32m     66 [0;31m                                     [0mneg2data_scores[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(66)[0;36mforward[0;34m()[0m
[0;32m     64 [0;31m[0;34m[0m[0m
[0m[0;32m     65 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_scores, neg2data_o.rep, 
[0m[0;32m---> 66 [0;31m                                     [0mneg2data_scores[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m[0;34m[0m[0m
[0m[0;32m     68 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(65)[0;36mforward[0;34m()[0m
[0;32m     63 [0;31m                                 [0mdata_aug_meta_prefix[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mneg2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mneg2data_meta_kwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, lbl2data_scores, neg2data_o.rep, 
[0m[0;32m     66 [0;31m                                     [0mneg2data_scores[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(68)[0;36mforward[0;34m()[0m
[0;32m     66 [0;31m                                     [0mneg2data_scores[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m[0;34m[0m[0m
[0m[0;32m---> 68 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     69 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, lbl2data_scores, neg2data_o.rep, 
[0m[0;32m     70 [0;31m                                         [0mneg2data_scores[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  loss


tensor(0.0026, grad_fn=<MseLossBackward0>)


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(69)[0;36mforward[0;34m()[0m
[0;32m     67 [0;31m[0;34m[0m[0m
[0m[0;32m     68 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 69 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, lbl2data_scores, neg2data_o.rep, 
[0m[0;32m     70 [0;31m                                         [0mneg2data_scores[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(70)[0;36mforward[0;34m()[0m
[0;32m     68 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     69 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, lbl2data_scores, neg2data_o.rep, 
[0m[0;32m---> 70 [0;31m                                         [0mneg2data_scores[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m     72 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(69)[0;36mforward[0;34m()[0m
[0;32m     67 [0;31m[0;34m[0m[0m
[0m[0;32m     68 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 69 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, lbl2data_scores, neg2data_o.rep, 
[0m[0;32m     70 [0;31m                                         [0mneg2data_scores[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(70)[0;36mforward[0;34m()[0m
[0;32m     68 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     69 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, lbl2data_scores, neg2data_o.rep, 
[0m[0;32m---> 70 [0;31m                                         [0mneg2data_scores[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m     72 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(69)[0;36mforward[0;34m()[0m
[0;32m     67 [0;31m[0;34m[0m[0m
[0m[0;32m     68 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 69 [0;31m                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, lbl2data_scores, neg2data_o.rep, 
[0m[0;32m     70 [0;31m                                         [0mneg2data_scores[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(72)[0;36mforward[0;34m()[0m
[0;32m     70 [0;31m                                         [0mneg2data_scores[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m---> 72 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx,
[0m[0;32m     74 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  self.use_calib_loss


False


ipdb>  self.use_calib_loss = True
ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(73)[0;36mforward[0;34m()[0m
[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m     72 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 73 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx,
[0m[0;32m     74 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, neg2data_o.rep, neg2data_data2ptr, neg2data_idx,
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(74)[0;36mforward[0;34m()[0m
[0;32m     72 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx,
[0m[0;32m---> 74 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, neg2data_o.rep, neg2data_data2ptr, neg2data_idx,
[0m[0;32m     76 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(73)[0;36mforward[0;34m()[0m
[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m     72 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 73 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx,
[0m[0;32m     74 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, neg2data_o.rep, neg2data_data2ptr, neg2data_idx,
[0m


ipdb>  


  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(75)[0;36mforward[0;34m()[0m
[0;32m     73 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx,
[0m[0;32m     74 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 75 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, neg2data_o.rep, neg2data_data2ptr, neg2data_idx,
[0m[0;32m     76 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(76)[0;36mforward[0;34m()[0m
[0;32m     74 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     75 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, neg2data_o.rep, neg2data_data2ptr, neg2data_idx,
[0m[0;32m---> 76 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m[0;32m     78 [0;31m        [0;32mif[0m [0;32mnot[0m [0mreturn_dict[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(75)[0;36mforward[0;34m()[0m
[0;32m     73 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx,
[0m[0;32m     74 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 75 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, neg2data_o.rep, neg2data_data2ptr, neg2data_idx,
[0m[0;32m     76 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m


ipdb>  


IndexError: index 34 is out of bounds for dimension 0 with size 34
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(75)[0;36mforward[0;34m()[0m
[0;32m     73 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx,
[0m[0;32m     74 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 75 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, neg2data_o.rep, neg2data_data2ptr, neg2data_idx,
[0m[0;32m     76 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m


ipdb>  


--Return--
None
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/2665181823.py[0m(75)[0;36mforward[0;34m()[0m
[0;32m     73 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep, lbl2data_data2ptr, lbl2data_idx,
[0m[0;32m     74 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 75 [0;31m                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, neg2data_o.rep, neg2data_data2ptr, neg2data_idx,
[0m[0;32m     76 [0;31m                                              [0mplbl2data_data2ptr[0m[0;34m,[0m[0mplbl2data_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     77 [0;31m[0;34m[0m[0m
[0m


ipdb>  


IndexError: index 34 is out of bounds for dimension 0 with size 34
> [0;32m/Users/suchith720/miniconda3/envs/mogic/lib/python3.13/site-packages/torch/nn/modules/module.py[0m(1750)[0;36m_call_impl[0;34m()[0m
[0;32m   1748 [0;31m                [0;32mor[0m [0m_global_backward_pre_hooks[0m [0;32mor[0m [0m_global_backward_hooks[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1749 [0;31m                [0;32mor[0m [0m_global_forward_hooks[0m [0;32mor[0m [0m_global_forward_pre_hooks[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1750 [0;31m            [0;32mreturn[0m [0mforward_call[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1751 [0;31m[0;34m[0m[0m
[0m[0;32m   1752 [0;31m        [0mresult[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


--Return--
None
> [0;32m/Users/suchith720/miniconda3/envs/mogic/lib/python3.13/site-packages/torch/nn/modules/module.py[0m(1750)[0;36m_call_impl[0;34m()[0m
[0;32m   1748 [0;31m                [0;32mor[0m [0m_global_backward_pre_hooks[0m [0;32mor[0m [0m_global_backward_hooks[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1749 [0;31m                [0;32mor[0m [0m_global_forward_hooks[0m [0;32mor[0m [0m_global_forward_pre_hooks[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1750 [0;31m            [0;32mreturn[0m [0mforward_call[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1751 [0;31m[0;34m[0m[0m
[0m[0;32m   1752 [0;31m        [0mresult[0m [0;34m=[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


IndexError: index 34 is out of bounds for dimension 0 with size 34
> [0;32m/Users/suchith720/miniconda3/envs/mogic/lib/python3.13/site-packages/torch/nn/modules/module.py[0m(1739)[0;36m_wrapped_call_impl[0;34m()[0m
[0;32m   1737 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_compiled_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m  [0;31m# type: ignore[misc][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1738 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1739 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1740 [0;31m[0;34m[0m[0m
[0m[0;32m   1741 [0;31m    [0;31m# torchrec tests the code consistency with the following code[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


--Return--
None
> [0;32m/Users/suchith720/miniconda3/envs/mogic/lib/python3.13/site-packages/torch/nn/modules/module.py[0m(1739)[0;36m_wrapped_call_impl[0;34m()[0m
[0;32m   1737 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_compiled_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m  [0;31m# type: ignore[misc][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1738 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1739 [0;31m            [0;32mreturn[0m [0mself[0m[0;34m.[0m[0m_call_impl[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1740 [0;31m[0;34m[0m[0m
[0m[0;32m   1741 [0;31m    [0;31m# torchrec tests the code consistency with the following code[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


IndexError: index 34 is out of bounds for dimension 0 with size 34
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/3721260802.py[0m(3)[0;36mfunc[0;34m()[0m
[0;32m      1 [0;31m[0;32mdef[0m [0mfunc[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      2 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 3 [0;31m    [0mo[0m [0;34m=[0m [0mmodel[0m[0;34m([0m[0;34m**[0m[0mb[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mmodel[0m[0;34m.[0m[0mdevice[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      4 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


--Return--
None
> [0;32m/var/folders/x0/r_2wlyls39s3_q5g99w33dn80000gn/T/ipykernel_22918/3721260802.py[0m(3)[0;36mfunc[0;34m()[0m
[0;32m      1 [0;31m[0;32mdef[0m [0mfunc[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      2 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 3 [0;31m    [0mo[0m [0;34m=[0m [0mmodel[0m[0;34m([0m[0;34m**[0m[0mb[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mmodel[0m[0;34m.[0m[0mdevice[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      4 [0;31m[0;34m[0m[0m
[0m


ipdb>  o


XCModelOutput(loss=tensor(0.0051, grad_fn=<AddBackward0>), logits=None, data_repr=tensor([[-0.0278, -0.0147, -0.0297,  ..., -0.0008,  0.0018, -0.0341],
        [-0.0027, -0.0311, -0.0260,  ...,  0.0229, -0.0220, -0.0337],
        [-0.0068, -0.0017, -0.0149,  ...,  0.0273,  0.0346,  0.0232],
        [-0.0346, -0.0148, -0.0316,  ...,  0.0544,  0.0028, -0.0312]],
       grad_fn=<DivBackward0>), data_fused_repr=tensor([[-0.0270, -0.0149, -0.0287,  ..., -0.0018,  0.0008, -0.0328],
        [-0.0035, -0.0299, -0.0253,  ...,  0.0212, -0.0216, -0.0323],
        [-0.0075, -0.0026, -0.0150,  ...,  0.0257,  0.0329,  0.0216],
        [-0.0331, -0.0150, -0.0304,  ...,  0.0532,  0.0017, -0.0300]],
       grad_fn=<DivBackward0>), lbl2data_repr=tensor([[-0.0252,  0.0914, -0.0155,  ..., -0.0417,  0.0890, -0.0145],
        [ 0.0107, -0.0238, -0.0349,  ...,  0.0151, -0.0291, -0.0294],
        [-0.0068, -0.0017, -0.0149,  ...,  0.0273,  0.0346,  0.0232],
        [-0.0244, -0.0302, -0.0259,  ...,  0.0584,  

ipdb>  o.loss


tensor(0.0051, grad_fn=<AddBackward0>)


ipdb>  c


[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m



IndexError: index 34 is out of bounds for dimension 0 with size 34

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

