In [1]:
#| default_exp models.sandwich

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [4]:
#| export
import torch, torch.nn as nn, torch.nn.functional as F, re
from typing import Optional, Dict, Tuple
from dataclasses import dataclass

from transformers.utils.generic import ModelOutput
from transformers import PretrainedConfig, DistilBertConfig, DistilBertPreTrainedModel, DistilBertModel
from transformers.models.distilbert.modeling_distilbert import create_sinusoidal_embeddings, TransformerBlock

from xcai.losses import *
from xcai.learner import XCDataParallel
from xcai.models.modeling_utils import *

In [5]:
#| export
from xcai.models.product_key import *
from xcai.core import store_attr

## Load data

In [6]:
from xcai.main import *

In [7]:
output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/03_ngame-for-wikiseealsotitles'

data_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/'
config_file = 'wikiseealsotitles'
config_key = 'data_meta'

mname = 'sentence-transformers/msmarco-distilbert-base-v4'

In [8]:
pkl_dir = '/scratch/scai/phd/aiz218323/datasets/processed/'

In [9]:
pkl_file = f'{pkl_dir}/mogicX/wikiseealsotitles_data-meta_distilbert-base-uncased_sxc.joblib'

In [12]:
block = build_block(pkl_file, config_file, True, config_key, data_dir=data_dir, n_sdata_meta_samples=3, do_build=True)



In [13]:
batch = block.train.one_batch(100)

In [14]:
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_data2ptr', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_lbl2ptr', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_data2ptr', 'pcat2lbl_data2ptr'])

In [15]:
from transformers import DistilBertModel
m = DistilBertModel.from_pretrained('distilbert-base-uncased')

In [16]:
o = m(input_ids=batch['data_input_ids'], attention_mask=batch['data_attention_mask'])
o = Pooling.mean_pooling(o.last_hidden_state, batch['data_attention_mask'])

In [24]:
o

tensor([[ 0.1333, -0.0122, -0.4621,  ..., -0.2476,  0.0845,  0.1799],
        [ 0.3242,  0.1735, -0.3537,  ...,  0.0098, -0.1589, -0.0305],
        [ 0.1397,  0.1279, -0.5648,  ...,  0.0695,  0.0088,  0.1104],
        ...,
        [ 0.2267, -0.1360, -0.0137,  ...,  0.0121, -0.0757,  0.2140],
        [ 0.3148,  0.0690, -0.1643,  ..., -0.0073, -0.0534, -0.1919],
        [ 0.2646,  0.3013, -0.5160,  ..., -0.0020,  0.2130,  0.0120]],
       grad_fn=<DivBackward0>)

## Parameters

In [194]:
#| export
class Parameters:
    
    @staticmethod
    def from_data_aug_meta_prefix_for_encoder(prefix:str, **kwargs):
        inputs = {}
        args = [arg for arg in kwargs if prefix is not None and re.match(f'^{prefix}.*_(input_ids|attention_mask|data2ptr|idx)$', arg)]
        for arg in args:
            meta,param = arg.split('_', maxsplit=1)
            inputs.setdefault(meta, {})[param] = kwargs[arg]
        return inputs
    
    @staticmethod
    def from_aug_meta_prefix_for_feature(feat:str, prefix:str, **kwargs):
        keys = ['attention_mask', 'input_ids', 'idx']        
        inputs = {f'{prefix}_{k}': kwargs[f'{prefix}_{k}'] for k in keys if f'{prefix}_{k}' in kwargs}
        if prefix is not None and f'{prefix}_{feat}2ptr' in kwargs:
            inputs.update({f'{prefix}_data2ptr': kwargs[f'{prefix}_{feat}2ptr']})
        return inputs

    @staticmethod
    def from_aug_meta_prefix_for_loss(feat:str, prefix:str, **kwargs):
        keys = [f'{prefix}_idx', f'p{prefix}_idx']
        args = {k: kwargs[k] for k in keys if k in kwargs}
        if prefix is not None and f'{prefix}_{feat}2ptr' in kwargs:
            args.update({f'{prefix}_data2ptr': kwargs[f'{prefix}_{feat}2ptr']})
        if prefix is not None and f'p{prefix}_{feat}2ptr' in kwargs:
            args.update({f'p{prefix}_data2ptr': kwargs[f'p{prefix}_{feat}2ptr']})

        inputs = {}
        for arg in args:
            meta,param = arg.split('_', maxsplit=1)
            inputs.setdefault(meta, {})[param] = args[arg]
        return inputs
        

### Example

In [47]:
params = Parameters.from_aug_meta_prefix_for_loss('lbl', 'cat2lbl', **batch)

In [48]:
params

{'cat2lbl': {'idx': tensor([ 67082,  79350, 101577,  71494, 169643,  68324, 135330, 125763,  82231,
           86188, 155467,  28048, 148694,  93180,  85482, 138151, 102793,  84865,
          439022, 439023,  91528,  74225, 141645, 289094, 170004, 488572,  56283,
          129962,  99971, 174524, 490420,  72240, 127050, 126452, 170813, 474981,
           46855, 217367, 179791, 402654, 105939, 130669,  81409,  74830,  70044,
          121141,  62139, 121002,  54379,  68123,  74898, 166643, 166550,  75601,
           68691,  61761,    656, 128580, 190220, 500130,  62668,  71775, 538327,
          131647, 102522,  62594, 114572, 321592,  83631,  53106, 292602, 478961,
           84872, 561039,  77674, 157101, 463424, 157434, 499332, 350836, 426850,
           72514, 490081,  68235,  57133,   9952,  70022,  77811,  64823,  79370,
          167418,  65653, 117839,  74213,  57745, 271915]),
  'data2ptr': tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
        

## Configuration

In [193]:
#| export
class SandwichConfig(DistilBertConfig):

    def __init__(
        self,
        data_aug_meta_prefix:Optional[str] = None, 
        lbl2data_aug_meta_prefix:Optional[str] = None,

        data_enrich:Optional[bool] = True,
        lbl2data_enrich:Optional[bool] = True,
        
        num_batch_labels:Optional[int] = None,
        batch_size:Optional[int] = None,
        margin:Optional[float] = 0.3,
        num_negatives:Optional[int] = 10,
        tau:Optional[float] = 0.1,
        apply_softmax:Optional[bool] = True,

        use_calib_loss:Optional[float] = False,
        calib_margin:Optional[float] = 0.05,
        calib_num_negatives:Optional[int] = 10,
        calib_tau:Optional[float] = 0.1,
        calib_apply_softmax:Optional[bool] = False,
        calib_loss_weight:Optional[float] = 0.1,
        
        use_query_loss:Optional[float] = True,
        
        use_meta_loss:Optional[bool] = False,
        meta_loss_weight:Optional[float] = 0.1,
        
        use_encoder_parallel:Optional[bool] = True,
        
        **kwargs,
    ):
        self.data_aug_meta_prefix = data_aug_meta_prefix
        self.lbl2data_aug_meta_prefix = lbl2data_aug_meta_prefix

        self.data_enrich = data_enrich
        self.lbl2data_enrich = lbl2data_enrich

        self.num_batch_labels = num_batch_labels
        self.batch_size = batch_size
        self.margin = margin
        self.num_negatives = num_negatives
        self.tau = tau
        self.apply_softmax = apply_softmax

        self.use_calib_loss = use_calib_loss
        self.calib_loss_weight = calib_loss_weight
        self.calib_margin = calib_margin
        self.calib_num_negatives = calib_num_negatives
        self.calib_tau = calib_tau
        self.calib_apply_softmax = calib_apply_softmax
        
        self.use_query_loss = use_query_loss
        
        self.use_meta_loss = use_meta_loss
        self.meta_loss_weight = meta_loss_weight

        self.use_encoder_parallel = use_encoder_parallel
        
        super().__init__(**kwargs)
        

### Example

In [57]:
config = SandwichConfig()

In [58]:
config.meta_loss_weight

0.1

## Combiner

In [192]:
#| export
class CrossCombinerBlock(TransformerBlock):

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)

    def post_init(self):
        for module in self.modules(): self._init_weights(module)

    def _initialize_weights(self, module: nn.Module):
        for m in module.modules(): self._init_weights(m)

    def _init_weights(self, module: nn.Module):
        if isinstance(module, nn.Linear):
            torch.nn.init.eye_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        x: torch.Tensor,
        m: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        
        # Cross-Attention
        ca_output = self.attention(
            query=x,
            key=m,
            value=m,
            mask=attn_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        if output_attentions:
            ca_output, ca_weights = ca_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        else:  # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
            if type(ca_output) is not tuple:
                raise TypeError(f"ca_output must be a tuple but it is {type(ca_output)} type")

            ca_output = ca_output[0]
        ca_output = self.sa_layer_norm(ca_output + x)  # (bs, seq_length, dim)

        # Feed Forward Network
        ffn_output = self.ffn(ca_output)  # (bs, seq_length, dim)
        ffn_output: torch.Tensor = self.output_layer_norm(ffn_output + ca_output)  # (bs, seq_length, dim)

        output = (ffn_output,)
        if output_attentions:
            output = (ca_weights,) + output
        return output
        

## BaseEncoder

In [195]:
#| export
@dataclass
class EncoderOutput(ModelOutput):
    data_repr: Optional[torch.FloatTensor] = None
    data_meta_repr: Optional[torch.FloatTensor] = None
    enriched_data_repr: Optional[torch.FloatTensor] = None
    meta_repr: Optional[torch.FloatTensor] = None
    

In [196]:
#| export
class BaseEncoder(DistilBertPreTrainedModel):
    
    config_class= None

    def __init__(
        self, 
        config:PretrainedConfig, 
    ):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.meta_distilbert = DistilBertModel(config)
        
        self.query_head = RepresentationHead(config)
        self.meta_query_head = RepresentationHead(config)
        self.enriched_query_head = RepresentationHead(config)
        
        self.post_init()

    @torch.no_grad()
    def init_heads_to_identity(self):
        self.query_head.post_init()
        self.meta_query_head.post_init()
        self.enriched_query_head.post_init()

    @torch.no_grad()
    def init_meta_distilbert(self):
        sd, msd = self.distilbert.state_dict(), self.meta_distilbert.state_dict()
        sd_keys, msd_keys = sd.keys(), msd.keys()
        assert len(sd_keys) == len(msd_keys), f'mismatched keys: {len(sd_keys)} != {len(msd_keys)}'
        for k in sd_keys:
            assert sd[k].shape == msd[k].shape
            msd[k].copy_(sd[k])

    def get_position_embeddings(self) -> nn.Embedding:
        return self.distilbert.get_position_embeddings()
    
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)
    
    def encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        return self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

    def encode_meta(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        return self.meta_distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        
    def encode_query(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.query_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def encode_meta_query(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.meta_query_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def encode_enriched_query(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.enriched_query_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def enrich_query_representation(self):
        raise NotImplementedError("Override this method in a subclass.")

    def forward(self, data_input_ids: torch.Tensor, data_attention_mask: torch.Tensor, data_aug_meta_prefix: Optional[str]=None,
                data_enrich: Optional[bool]=True, **kwargs):  
        raise NotImplementedError("Override this method in a subclass.")
    

## Encoder

In [197]:
#| export
class Encoder(BaseEncoder):
    
    config_class = SandwichConfig
    
    def __init__(
        self, 
        config:PretrainedConfig, 
    ):
        super().__init__(config)
        self.combiner_head = CrossCombinerBlock(config)
        self.post_init()

    @torch.no_grad()
    def init_heads_to_identity(self):
        self.combiner_head.post_init()
        super().init_heads_to_identity()

    @torch.no_grad()
    def init_combiner_to_last_layer(self):
        lsd, csd = self.distilbert.transformer.layer[-1].state_dict(), self.combiner_head.state_dict()
        lsd_keys, csd_keys = lsd.keys(), csd.keys()
        assert len(lsd_keys) == len(csd_keys), f'mismatched keys: {len(lsd_keys)} != {len(csd_keys)}'
        for k in csd_keys:
            assert csd[k].shape == lsd[k].shape
            csd[k].copy_(lsd[k])

    @torch.no_grad()
    def init_meta_distilbert(self):
        super().init_meta_distilbert()

    def enrich_query_representation(self, data_o:torch.Tensor, data_meta_o:torch.Tensor, data_attention_mask:torch.Tensor):
        attn_mask = data_attention_mask.view(len(data_attention_mask), 1, 1, -1).bool()
        fusion_o = self.combiner_head(x=data_o, m=data_meta_o, attn_mask=attn_mask)[0]
        enriched_data_repr = self.encode_enriched_query(data_o + fusion_o, data_attention_mask)
        return enriched_data_repr

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_enrich: Optional[bool]=True,
        **kwargs
    ):
        data_o = self.encode(data_input_ids, data_attention_mask)
        data_repr = self.encode_query(data_o[0], data_attention_mask)

        data_meta_o = self.encode_meta(data_input_ids, data_attention_mask)
        data_meta_repr = self.encode_meta_query(data_meta_o[0], data_attention_mask)
        
        meta_kwargs = Parameters.from_data_aug_meta_prefix_for_encoder(data_aug_meta_prefix, **kwargs)
        meta_kwargs = meta_kwargs.get(data_aug_meta_prefix, None)

        meta_repr = torch.zeros(0, len(data_repr), device=data_repr.device, dtype=data_repr.dtype)
        if meta_kwargs is not None and len(meta_kwargs['idx']):
            meta_o = self.encode_meta(meta_kwargs['input_ids'], meta_kwargs['attention_mask'])
            meta_repr = self.encode_meta_query(meta_o[0], meta_kwargs['attention_mask'])

        enriched_data_repr = (
            self.enrich_query_representation(data_o[0], data_meta_o[0], data_attention_mask) 
            if data_enrich else torch.zeros(0, len(data_repr), device=data_repr.device, dtype=data_repr.dtype)
        )
        
        return EncoderOutput(
            data_repr=data_repr,
            data_meta_repr=data_meta_repr,
            enriched_data_repr=enriched_data_repr,
            meta_repr=meta_repr,
        )
        

### Example

In [109]:
config = SandwichConfig()
m = Encoder(config)

In [110]:
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_data2ptr', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_lbl2ptr', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_data2ptr', 'pcat2lbl_data2ptr'])

In [111]:
data_aug_meta_prefix='cat2data'
output = m(**batch, data_aug_meta_prefix=data_aug_meta_prefix)

In [112]:
output

EncoderOutput(data_repr=tensor([[-0.0225,  0.0191, -0.0021,  ..., -0.0053, -0.0118,  0.0194],
        [-0.0146, -0.0108, -0.0409,  ..., -0.0205, -0.0109,  0.0028],
        [-0.0462, -0.0089, -0.0374,  ..., -0.0225, -0.0240, -0.0300],
        ...,
        [-0.0383, -0.0020, -0.0393,  ...,  0.0179, -0.0266,  0.0131],
        [-0.0453, -0.0005,  0.0191,  ..., -0.0426,  0.0021,  0.0176],
        [-0.0665,  0.0332, -0.0013,  ..., -0.0054, -0.0349,  0.0024]],
       grad_fn=<DivBackward0>), data_meta_repr=tensor([[-0.0594,  0.0056, -0.0344,  ..., -0.0204,  0.0081,  0.0007],
        [-0.0206, -0.0171, -0.0520,  ..., -0.0341,  0.0685,  0.0280],
        [ 0.0005, -0.0642, -0.0187,  ..., -0.0258,  0.0917, -0.0041],
        ...,
        [-0.0019, -0.0299, -0.0214,  ..., -0.0219,  0.0619,  0.0317],
        [-0.0137, -0.0348, -0.0467,  ..., -0.0402,  0.0678,  0.0130],
        [-0.0162, -0.0277, -0.0211,  ..., -0.0462,  0.0306, -0.0103]],
       grad_fn=<DivBackward0>), enriched_data_repr=tensor([[ 

## `SAW000`

In [198]:
#| export
@dataclass
class SAWModelOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    data_repr: Optional[torch.FloatTensor] = None
    data_enriched_repr: Optional[torch.FloatTensor] = None
    lbl2data_repr: Optional[torch.FloatTensor] = None
    lbl2data_enriched_repr: Optional[torch.FloatTensor] = None
    

In [199]:
#| export
class SAW000(nn.Module):

    config_class = SandwichConfig
    
    def __init__(
        self, 
        config: SandwichConfig,
    ):
        super().__init__(config)
        self.config, self.encoder = config, None
        self.rep_loss_fn = MultiTriplet(margin=config.margin, n_negatives=config.num_negatives, tau=config.tau, 
                                        apply_softmax=config.apply_softmax, reduce='mean')
        self.cab_loss_fn = Calibration(margin=config.calib_margin, tau=config.calib_tau, n_negatives=config.calib_num_negatives, 
                                       apply_softmax=config.calib_apply_softmax, reduce='mean')
        
    def init_heads_to_identity(self):
        if self.encoder is None: raise ValueError('Encoder not initialized.')
        self.encoder.init_heads_to_identity()

    def init_combiner_to_last_layer(self):
        if self.encoder is None: raise ValueError('Encoder not initialized.')
        self.encoder.init_combiner_to_last_layer()

    def init_meta_distilbert(self):
        if self.encoder is None: raise ValueError('Encoder not initialized.')
        self.encoder.init_meta_distilbert()
        
    def compute_loss(self, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        return self.rep_loss_fn(inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def calibration_loss(self, einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        return self.config.calib_loss_weight * self.cab_loss_fn(einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def compute_meta_loss(self, data_o, lbl2data_o, **kwargs):
        loss = 0.0
        meta_kwargs = Parameters.from_aug_meta_prefix_for_loss('data', self.config.data_aug_meta_prefix, **kwargs)
        prefix = self.config.data_aug_meta_prefix
        if meta_kwargs is not None and len(meta_kwargs[prefix]['idx']):
            idx = torch.where(meta_kwargs[prefix]['data2ptr'] > 0)[0]
            loss += self.config.meta_loss_weight * self.compute_loss(data_o.data_meta_repr[idx], 
                                                                     data_o.meta_repr,
                                                                     meta_kwargs[prefix]['data2ptr'][idx],
                                                                     meta_kwargs[prefix]['idx'],
                                                                     meta_kwargs[f'p{prefix}']['data2ptr'][idx],
                                                                     meta_kwargs[f'p{prefix}']['idx'])
            
        meta_kwargs = Parameters.from_aug_meta_prefix_for_loss('lbl', self.config.lbl2data_aug_meta_prefix, **kwargs)
        prefix = self.config.lbl2data_aug_meta_prefix
        if meta_kwargs is not None and len(meta_kwargs[prefix]['idx']):
            idx = torch.where(meta_kwargs[prefix]['data2ptr'] > 0)[0]
            loss += self.config.meta_loss_weight * self.compute_loss(lbl2data_o.data_meta_repr[idx], 
                                                                     lbl2data_o.meta_repr,
                                                                     meta_kwargs[prefix]['data2ptr'][idx],
                                                                     meta_kwargs[prefix]['idx'],
                                                                     meta_kwargs[f'p{prefix}']['data2ptr'][idx],
                                                                     meta_kwargs[f'p{prefix}']['idx'])
        return loss

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ): 
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.config.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = Parameters.from_aug_meta_prefix_for_feature('data', self.config.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.config.data_aug_meta_prefix, data_enrich=self.config.data_enrich, **data_meta_kwargs)
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_aug_meta_prefix_for_feature('lbl', self.config.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.config.lbl2data_aug_meta_prefix, data_enrich=self.config.lbl2data_enrich, **lbl2data_meta_kwargs)
            
            loss = self.compute_loss(data_o.enriched_data_repr, lbl2data_o.enriched_data_repr,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.config.use_query_loss:
                loss += self.compute_loss(data_o.data_repr, lbl2data_o.data_repr,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.config.use_calib_loss:
                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.data_repr,
                                              lbl2data_data2ptr,lbl2data_idx,plbl2data_data2ptr,plbl2data_idx)
                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.enriched_data_repr,
                                              lbl2data_data2ptr,lbl2data_idx, plbl2data_data2ptr,plbl2data_idx)

            if self.config.use_meta_loss:
                loss += self.compute_meta_loss(data_o, lbl2data_o, **kwargs)
            
        if not return_dict:
            o = (data_o.data_repr,data_o.enriched_data_repr,lbl2data_o.data_repr,lbl2data_o.enriched_data_repr)
            return ((loss,) + o) if loss is not None else o
        
        return SAWModelOutput(
            loss=loss,
            data_repr=data_o.data_repr,
            data_enriched_repr=data_o.enriched_data_repr,
            lbl2data_repr=lbl2data_o.data_repr,
            lbl2data_enriched_repr=lbl2data_o.enriched_data_repr,
        )
        

## `SAW001`

In [200]:
#| export
class SAW001(SAW000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    def __init__(self, config):
        super().__init__(config)
        self.encoder = Encoder(config)
        
        self.post_init()
        self.remap_post_init()

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        

### Example

In [201]:
config = SandwichConfig(
    data_aug_meta_prefix='cat2data', 
    lbl2data_aug_meta_prefix='cat2lbl',

    data_enrich=True,
    lbl2data_enrich=True,

    batch_size=100, 
    num_batch_labels=5000, 
    margin=0.3,
    num_negatives=5,
    tau=0.1,
    apply_softmax=True,

    use_calib_loss=True,
    calib_loss_weight=0.1,
    calib_margin=0.05,
    calib_num_negatives=10,
    calib_tau=0.1,
    calib_apply_softmax=False,

    use_query_loss=True,

    use_meta_loss=False,
    meta_loss_weight=0.1,
    
    use_encoder_parallel=False
)

In [202]:
model = SAW001.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', config=config)

Some weights of SAW001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.combiner_head.attention.k_lin.bias', 'encoder.combiner_head.attention.k_lin.weight', 'encoder.combiner_head.attention.out_lin.bias', 'encoder.combiner_head.attention.out_lin.weight', 'encoder.combiner_head.attention.q_lin.bias', 'encoder.combiner_head.attention.q_lin.weight', 'encoder.combiner_head.attention.v_lin.bias', 'encoder.combiner_head.attention.v_lin.weight', 'encoder.combiner_head.ffn.lin1.bias', 'encoder.combiner_head.ffn.lin1.weight', 'encoder.combiner_head.ffn.lin2.bias', 'encoder.combiner_head.ffn.lin2.weight', 'encoder.combiner_head.output_layer_norm.bias', 'encoder.combiner_head.output_layer_norm.weight', 'encoder.combiner_head.sa_layer_norm.bias', 'encoder.combiner_head.sa_layer_norm.weight', 'encoder.enriched_query_head.layer_norm.bias', 'encoder.enriched_query_head.layer_norm.weight', 'encoder.enriched_query_he

In [203]:
model.init_heads_to_identity()

In [204]:
model.init_combiner_to_last_layer()

In [205]:
model.init_meta_distilbert()

In [206]:
output = model(**batch)

In [185]:
from xcai.core import prepare_batch

In [207]:
def func():
    import pdb; pdb.set_trace()
    b = prepare_batch(model, batch, m_args=['pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_data2ptr', 'cat2data_input_ids', 'cat2data_attention_mask', 
                                            'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_lbl2ptr', 'cat2lbl_input_ids', 'cat2lbl_attention_mask'])
    o = model(**b)
    

In [209]:
func()

> [0;32m/tmp/ipykernel_5306/2352069300.py[0m(3)[0;36mfunc[0;34m()[0m
[0;32m      1 [0;31m[0;32mdef[0m [0mfunc[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      2 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 3 [0;31m    b = prepare_batch(model, batch, m_args=['pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_data2ptr', 'cat2data_input_ids', 'cat2data_attention_mask', 
[0m[0;32m      4 [0;31m                                            'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_lbl2ptr', 'cat2lbl_input_ids', 'cat2lbl_attention_mask'])
[0m[0;32m      5 [0;31m    [0mo[0m [0;34m=[0m [0mmodel[0m[0;34m([0m[0;34m**[0m[0mb[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_5306/2352069300.py[0m(5)[0;36mfunc[0;34m()[0m
[0;32m      2 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      3 [0;31m    b = prepare_batch(model, batch, m_args=['pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_data2ptr', 'cat2data_input_ids', 'cat2data_attention_mask', 
[0m[0;32m      4 [0;31m                                            'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_lbl2ptr', 'cat2lbl_input_ids', 'cat2lbl_attention_mask'])
[0m[0;32m----> 5 [0;31m    [0mo[0m [0;34m=[0m [0mmodel[0m[0;34m([0m[0;34m**[0m[0mb[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      6 [0;31m[0;34m[0m[0m
[0m


ipdb>  b


Num Type         Disp Enb   Where
1   breakpoint   keep no    at /tmp/ipykernel_5306/3635970643.py:58
	breakpoint already hit 4 times
2   breakpoint   keep no    at /tmp/ipykernel_5306/3710146577.py:38
	breakpoint already hit 5 times
3   breakpoint   keep no    at /tmp/ipykernel_5306/3710146577.py:32
	breakpoint already hit 3 times
4   breakpoint   keep no    at /tmp/ipykernel_5306/2717498366.py:22
	breakpoint already hit 3 times
5   breakpoint   keep yes   at /tmp/ipykernel_5306/1126467191.py:60
	breakpoint already hit 1 time


ipdb>  n


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(75)[0;36mforward[0;34m()[0m
[0;32m     73 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m    ): 
[0m[0;32m---> 75 [0;31m        [0mreturn_dict[0m [0;34m=[0m [0mreturn_dict[0m [0;32mif[0m [0mreturn_dict[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_return_dict[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     76 [0;31m[0;34m[0m[0m
[0m[0;32m     77 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(77)[0;36mforward[0;34m()[0m
[0;32m     75 [0;31m        [0mreturn_dict[0m [0;34m=[0m [0mreturn_dict[0m [0;32mif[0m [0mreturn_dict[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_return_dict[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     76 [0;31m[0;34m[0m[0m
[0m[0;32m---> 77 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     78 [0;31m            [0mencoder[0m [0;34m=[0m [0mXCDataParallel[0m[0;34m([0m[0mmodule[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mencoder[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     79 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(79)[0;36mforward[0;34m()[0m
[0;32m     77 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     78 [0;31m            [0mencoder[0m [0;34m=[0m [0mXCDataParallel[0m[0;34m([0m[0mmodule[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mencoder[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 79 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m     81 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_feature[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(81)[0;36mforward[0;34m()[0m
[0;32m     79 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m---> 81 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_feature[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     82 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m     83 [0;31m                         data_aug_meta_prefix=self.config.data_aug_meta_prefix, data_enrich=self.config.data_enrich, **data_meta_kwargs)
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(82)[0;36mforward[0;34m()[0m
[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m     81 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_feature[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 82 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m     83 [0;31m                         data_aug_meta_prefix=self.config.data_aug_meta_prefix, data_enrich=self.config.data_enrich, **data_meta_kwargs)
[0m[0;32m     84 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(83)[0;36mforward[0;34m()[0m
[0;32m     81 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_feature[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     82 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m---> 83 [0;31m                         data_aug_meta_prefix=self.config.data_aug_meta_prefix, data_enrich=self.config.data_enrich, **data_meta_kwargs)
[0m[0;32m     84 [0;31m[0;34m[0m[0m
[0m[0;32m     85 [0;31m        [0mloss[0m [0;34m=[0m [0;32mNone[0m[0;34m;[0m [0mlbl2data_o[0m [0;34m=[0m [0mEncoderOutput[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(82)[0;36mforward[0;34m()[0m
[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m     81 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_feature[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 82 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m     83 [0;31m                         data_aug_meta_prefix=self.config.data_aug_meta_prefix, data_enrich=self.config.data_enrich, **data_meta_kwargs)
[0m[0;32m     84 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(83)[0;36mforward[0;34m()[0m
[0;32m     81 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_feature[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     82 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m---> 83 [0;31m                         data_aug_meta_prefix=self.config.data_aug_meta_prefix, data_enrich=self.config.data_enrich, **data_meta_kwargs)
[0m[0;32m     84 [0;31m[0;34m[0m[0m
[0m[0;32m     85 [0;31m        [0mloss[0m [0;34m=[0m [0;32mNone[0m[0;34m;[0m [0mlbl2data_o[0m [0;34m=[0m [0mEncoderOutput[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(82)[0;36mforward[0;34m()[0m
[0;32m     80 [0;31m[0;34m[0m[0m
[0m[0;32m     81 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_feature[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 82 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m     83 [0;31m                         data_aug_meta_prefix=self.config.data_aug_meta_prefix, data_enrich=self.config.data_enrich, **data_meta_kwargs)
[0m[0;32m     84 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(85)[0;36mforward[0;34m()[0m
[0;32m     83 [0;31m                         data_aug_meta_prefix=self.config.data_aug_meta_prefix, data_enrich=self.config.data_enrich, **data_meta_kwargs)
[0m[0;32m     84 [0;31m[0;34m[0m[0m
[0m[0;32m---> 85 [0;31m        [0mloss[0m [0;34m=[0m [0;32mNone[0m[0;34m;[0m [0mlbl2data_o[0m [0;34m=[0m [0mEncoderOutput[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     86 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     87 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_feature[0m[0;34m([0m[0;34m'lbl'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(86)[0;36mforward[0;34m()[0m
[0;32m     84 [0;31m[0;34m[0m[0m
[0m[0;32m     85 [0;31m        [0mloss[0m [0;34m=[0m [0;32mNone[0m[0;34m;[0m [0mlbl2data_o[0m [0;34m=[0m [0mEncoderOutput[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 86 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     87 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_feature[0m[0;34m([0m[0;34m'lbl'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     88 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(87)[0;36mforward[0;34m()[0m
[0;32m     85 [0;31m        [0mloss[0m [0;34m=[0m [0;32mNone[0m[0;34m;[0m [0mlbl2data_o[0m [0;34m=[0m [0mEncoderOutput[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     86 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 87 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_feature[0m[0;34m([0m[0;34m'lbl'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     88 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m     89 [0;31m                                 data_aug_meta_prefix=self

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(88)[0;36mforward[0;34m()[0m
[0;32m     86 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     87 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_feature[0m[0;34m([0m[0;34m'lbl'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 88 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m     89 [0;31m                                 data_aug_meta_prefix=self.config.lbl2data_aug_meta_prefix, data_enrich=self.config.lbl2data_enrich, **lbl2data_meta_kwargs)
[0m[0;32m     90 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(89)[0;36mforward[0;34m()[0m
[0;32m     87 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_feature[0m[0;34m([0m[0;34m'lbl'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     88 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m---> 89 [0;31m                                 data_aug_meta_prefix=self.config.lbl2data_aug_meta_prefix, data_enrich=self.config.lbl2data_enrich, **lbl2data_meta_kwargs)
[0m[0;32m     90 [0;31m[0;34m[0m[0m
[0m[0;32m     91 [0;31m            loss = self.compute_loss(data_o.enriched_data_repr, lbl2data_o.enriched_data_repr,lbl2data_data2ptr,lbl2data_idx,
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(88)[0;36mforward[0;34m()[0m
[0;32m     86 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     87 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_feature[0m[0;34m([0m[0;34m'lbl'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 88 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m     89 [0;31m                                 data_aug_meta_prefix=self.config.lbl2data_aug_meta_prefix, data_enrich=self.config.lbl2data_enrich, **lbl2data_meta_kwargs)
[0m[0;32m     90 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(89)[0;36mforward[0;34m()[0m
[0;32m     87 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_feature[0m[0;34m([0m[0;34m'lbl'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     88 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m---> 89 [0;31m                                 data_aug_meta_prefix=self.config.lbl2data_aug_meta_prefix, data_enrich=self.config.lbl2data_enrich, **lbl2data_meta_kwargs)
[0m[0;32m     90 [0;31m[0;34m[0m[0m
[0m[0;32m     91 [0;31m            loss = self.compute_loss(data_o.enriched_data_repr, lbl2data_o.enriched_data_repr,lbl2data_data2ptr,lbl2data_idx,
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(88)[0;36mforward[0;34m()[0m
[0;32m     86 [0;31m        [0;32mif[0m [0mlbl2data_input_ids[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     87 [0;31m            [0mlbl2data_meta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_feature[0m[0;34m([0m[0;34m'lbl'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 88 [0;31m            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
[0m[0;32m     89 [0;31m                                 data_aug_meta_prefix=self.config.lbl2data_aug_meta_prefix, data_enrich=self.config.lbl2data_enrich, **lbl2data_meta_kwargs)
[0m[0;32m     90 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(91)[0;36mforward[0;34m()[0m
[0;32m     89 [0;31m                                 data_aug_meta_prefix=self.config.lbl2data_aug_meta_prefix, data_enrich=self.config.lbl2data_enrich, **lbl2data_meta_kwargs)
[0m[0;32m     90 [0;31m[0;34m[0m[0m
[0m[0;32m---> 91 [0;31m            loss = self.compute_loss(data_o.enriched_data_repr, lbl2data_o.enriched_data_repr,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m     92 [0;31m                                     plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     93 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(92)[0;36mforward[0;34m()[0m
[0;32m     90 [0;31m[0;34m[0m[0m
[0m[0;32m     91 [0;31m            loss = self.compute_loss(data_o.enriched_data_repr, lbl2data_o.enriched_data_repr,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m---> 92 [0;31m                                     plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     93 [0;31m[0;34m[0m[0m
[0m[0;32m     94 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(91)[0;36mforward[0;34m()[0m
[0;32m     89 [0;31m                                 data_aug_meta_prefix=self.config.lbl2data_aug_meta_prefix, data_enrich=self.config.lbl2data_enrich, **lbl2data_meta_kwargs)
[0m[0;32m     90 [0;31m[0;34m[0m[0m
[0m[0;32m---> 91 [0;31m            loss = self.compute_loss(data_o.enriched_data_repr, lbl2data_o.enriched_data_repr,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m     92 [0;31m                                     plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     93 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(94)[0;36mforward[0;34m()[0m
[0;32m     92 [0;31m                                     plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     93 [0;31m[0;34m[0m[0m
[0m[0;32m---> 94 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     95 [0;31m                loss += self.compute_loss(data_o.data_repr, lbl2data_o.data_repr,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m     96 [0;31m                                          plbl2data_data2ptr,plbl2data_idx)
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(95)[0;36mforward[0;34m()[0m
[0;32m     93 [0;31m[0;34m[0m[0m
[0m[0;32m     94 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 95 [0;31m                loss += self.compute_loss(data_o.data_repr, lbl2data_o.data_repr,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m     96 [0;31m                                          plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     97 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(96)[0;36mforward[0;34m()[0m
[0;32m     94 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     95 [0;31m                loss += self.compute_loss(data_o.data_repr, lbl2data_o.data_repr,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m---> 96 [0;31m                                          plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     97 [0;31m[0;34m[0m[0m
[0m[0;32m     98 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(95)[0;36mforward[0;34m()[0m
[0;32m     93 [0;31m[0;34m[0m[0m
[0m[0;32m     94 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_query_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 95 [0;31m                loss += self.compute_loss(data_o.data_repr, lbl2data_o.data_repr,lbl2data_data2ptr,lbl2data_idx,
[0m[0;32m     96 [0;31m                                          plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     97 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(98)[0;36mforward[0;34m()[0m
[0;32m     96 [0;31m                                          plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m     97 [0;31m[0;34m[0m[0m
[0m[0;32m---> 98 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     99 [0;31m                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.data_repr,
[0m[0;32m    100 [0;31m                                              lbl2data_data2ptr,lbl2data_idx,plbl2data_data2ptr,plbl2data_idx)
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(99)[0;36mforward[0;34m()[0m
[0;32m     97 [0;31m[0;34m[0m[0m
[0m[0;32m     98 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 99 [0;31m                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.data_repr,
[0m[0;32m    100 [0;31m                                              lbl2data_data2ptr,lbl2data_idx,plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m    101 [0;31m                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.enriched_data_repr,
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(100)[0;36mforward[0;34m()[0m
[0;32m     98 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     99 [0;31m                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.data_repr,
[0m[0;32m--> 100 [0;31m                                              lbl2data_data2ptr,lbl2data_idx,plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m    101 [0;31m                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.enriched_data_repr,
[0m[0;32m    102 [0;31m                                              lbl2data_data2ptr,lbl2data_idx, plbl2data_data2ptr,plbl2data_idx)
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(99)[0;36mforward[0;34m()[0m
[0;32m     97 [0;31m[0;34m[0m[0m
[0m[0;32m     98 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_calib_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 99 [0;31m                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.data_repr,
[0m[0;32m    100 [0;31m                                              lbl2data_data2ptr,lbl2data_idx,plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m    101 [0;31m                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.enriched_data_repr,
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(101)[0;36mforward[0;34m()[0m
[0;32m     99 [0;31m                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.data_repr,
[0m[0;32m    100 [0;31m                                              lbl2data_data2ptr,lbl2data_idx,plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m--> 101 [0;31m                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.enriched_data_repr,
[0m[0;32m    102 [0;31m                                              lbl2data_data2ptr,lbl2data_idx, plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m    103 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(102)[0;36mforward[0;34m()[0m
[0;32m    100 [0;31m                                              lbl2data_data2ptr,lbl2data_idx,plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m    101 [0;31m                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.enriched_data_repr,
[0m[0;32m--> 102 [0;31m                                              lbl2data_data2ptr,lbl2data_idx, plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m    103 [0;31m[0;34m[0m[0m
[0m[0;32m    104 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_meta_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(101)[0;36mforward[0;34m()[0m
[0;32m     99 [0;31m                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.data_repr,
[0m[0;32m    100 [0;31m                                              lbl2data_data2ptr,lbl2data_idx,plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m--> 101 [0;31m                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.enriched_data_repr,
[0m[0;32m    102 [0;31m                                              lbl2data_data2ptr,lbl2data_idx, plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m    103 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(104)[0;36mforward[0;34m()[0m
[0;32m    102 [0;31m                                              lbl2data_data2ptr,lbl2data_idx, plbl2data_data2ptr,plbl2data_idx)
[0m[0;32m    103 [0;31m[0;34m[0m[0m
[0m[0;32m--> 104 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_meta_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    105 [0;31m                [0mloss[0m [0;34m+=[0m [0mself[0m[0;34m.[0m[0mcompute_meta_loss[0m[0;34m([0m[0mdata_o[0m[0;34m,[0m [0mlbl2data_o[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    106 [0;31m[0;34m[0m[0m
[0m


ipdb>  self.config.use_meta_loss


False


ipdb>  self.config.use_meta_loss = True
ipdb>  l


[1;32m     99 [0m                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.data_repr,
[1;32m    100 [0m                                              lbl2data_data2ptr,lbl2data_idx,plbl2data_data2ptr,plbl2data_idx)
[1;32m    101 [0m                loss += self.calibration_loss(data_o.enriched_data_repr, data_o.data_repr, lbl2data_o.enriched_data_repr,
[1;32m    102 [0m                                              lbl2data_data2ptr,lbl2data_idx, plbl2data_data2ptr,plbl2data_idx)
[1;32m    103 [0m[0;34m[0m[0m
[0;32m--> 104 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_meta_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[1;32m    105 [0m                [0mloss[0m [0;34m+=[0m [0mself[0m[0;34m.[0m[0mcompute_meta_loss[0m[0;34m([0m[0mdata_o[0m[0;34m,[0m [0mlbl2data_o[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m    106 [0m[0

ipdb>  n


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(105)[0;36mforward[0;34m()[0m
[0;32m    103 [0;31m[0;34m[0m[0m
[0m[0;32m    104 [0;31m            [0;32mif[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_meta_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 105 [0;31m                [0mloss[0m [0;34m+=[0m [0mself[0m[0;34m.[0m[0mcompute_meta_loss[0m[0;34m([0m[0mdata_o[0m[0;34m,[0m [0mlbl2data_o[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    106 [0;31m[0;34m[0m[0m
[0m[0;32m    107 [0;31m        [0;32mif[0m [0;32mnot[0m [0mreturn_dict[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  s


--Call--
> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(35)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     33 [0;31m        [0;32mreturn[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mcalib_loss_weight[0m [0;34m*[0m [0mself[0m[0;34m.[0m[0mcab_loss_fn[0m[0;34m([0m[0meinp_repr[0m[0;34m,[0m [0minp_repr[0m[0;34m,[0m [0mtarg_repr[0m[0;34m,[0m [0mtarg_ptr[0m[0;34m,[0m [0mtarg_idx[0m[0;34m,[0m [0mptarg_ptr[0m[0;34m,[0m [0mptarg_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     34 [0;31m[0;34m[0m[0m
[0m[0;32m---> 35 [0;31m    [0;32mdef[0m [0mcompute_meta_loss[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mdata_o[0m[0;34m,[0m [0mlbl2data_o[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     36 [0;31m        [0mloss[0m [0;34m=[0m [0;36m0.0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     37 [0;31m        [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m

ipdb>  n


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(36)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     34 [0;31m[0;34m[0m[0m
[0m[0;32m     35 [0;31m    [0;32mdef[0m [0mcompute_meta_loss[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mdata_o[0m[0;34m,[0m [0mlbl2data_o[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 36 [0;31m        [0mloss[0m [0;34m=[0m [0;36m0.0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     37 [0;31m        [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_loss[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     38 [0;31m        [0mprefix[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(37)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     35 [0;31m    [0;32mdef[0m [0mcompute_meta_loss[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mdata_o[0m[0;34m,[0m [0mlbl2data_o[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     36 [0;31m        [0mloss[0m [0;34m=[0m [0;36m0.0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 37 [0;31m        [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_loss[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     38 [0;31m        [0mprefix[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     39 [0;31m        [0;32mif[0m [0mmeta_kwargs[0m [0;

ipdb>  n


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(38)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     36 [0;31m        [0mloss[0m [0;34m=[0m [0;36m0.0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     37 [0;31m        [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_loss[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 38 [0;31m        [0mprefix[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     39 [0;31m        [0;32mif[0m [0mmeta_kwargs[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     

ipdb>  meta_kwargs.keys()


dict_keys(['cat2data', 'pcat2data'])


ipdb>  n


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(39)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     37 [0;31m        [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_loss[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     38 [0;31m        [0mprefix[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 39 [0;31m        [0;32mif[0m [0mmeta_kwargs[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     40 [0;31m            [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0m

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(40)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     38 [0;31m        [0mprefix[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     39 [0;31m        [0;32mif[0m [0mmeta_kwargs[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 40 [0;31m            [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m>[0m [0;36m0[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     41 [0;31m            loss += self.config.meta_loss_weight * self.compute_loss(data_o.data_meta_repr[idx], 
[0m[0;32

ipdb>  n


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(41)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     39 [0;31m        [0;32mif[0m [0mmeta_kwargs[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     40 [0;31m            [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m>[0m [0;36m0[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 41 [0;31m            loss += self.config.meta_loss_weight * self.compute_loss(data_o.data_meta_repr[idx], 
[0m[0;32m     42 [0;31m                                                                     [0mdata_o[0m[0;34m.[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m


ipdb>  idx.shape


torch.Size([100])


ipdb>  data_o.data_meta_repr[idx].shape


torch.Size([100, 768])


ipdb>  data_o.meta_repr.shape


torch.Size([253, 768])


ipdb>  self.config.meta_loss_weight


0.1


ipdb>  n


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(42)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     40 [0;31m            [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m>[0m [0;36m0[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     41 [0;31m            loss += self.config.meta_loss_weight * self.compute_loss(data_o.data_meta_repr[idx], 
[0m[0;32m---> 42 [0;31m                                                                     [0mdata_o[0m[0;34m.[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     43 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     44 [0;31m  

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(43)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     41 [0;31m            loss += self.config.meta_loss_weight * self.compute_loss(data_o.data_meta_repr[idx], 
[0m[0;32m     42 [0;31m                                                                     [0mdata_o[0m[0;34m.[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 43 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     44 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     45 [0;31m                                                                     [0mmeta_kwargs[0

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(44)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     42 [0;31m                                                                     [0mdata_o[0m[0;34m.[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     43 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 44 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     45 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0;34mf'p{prefix}'[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;34m]

ipdb>  n


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(45)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     43 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     44 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 45 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0;34mf'p{prefix}'[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     46 [0;31m                                                                     meta_kwargs[f'p{prefix}']['idx'])
[0m

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(46)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     44 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     45 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0;34mf'p{prefix}'[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 46 [0;31m                                                                     meta_kwargs[f'p{prefix}']['idx'])
[0m[0;32m     47 [0;31m[0;34m[0m[0m
[0m[0;32m     48 [0;31m        [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_loss[0m[0;34m([0m[0;34m'lbl'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2d

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(41)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     39 [0;31m        [0;32mif[0m [0mmeta_kwargs[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     40 [0;31m            [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m>[0m [0;36m0[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 41 [0;31m            loss += self.config.meta_loss_weight * self.compute_loss(data_o.data_meta_repr[idx], 
[0m[0;32m     42 [0;31m                                                                     [0mdata_o[0m[0;34m.[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(48)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     46 [0;31m                                                                     meta_kwargs[f'p{prefix}']['idx'])
[0m[0;32m     47 [0;31m[0;34m[0m[0m
[0m[0;32m---> 48 [0;31m        [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_loss[0m[0;34m([0m[0;34m'lbl'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     49 [0;31m        [0mprefix[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     50 [0;31m        [0;32mif[0m [0mmeta_kwargs[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(49)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     47 [0;31m[0;34m[0m[0m
[0m[0;32m     48 [0;31m        [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_loss[0m[0;34m([0m[0;34m'lbl'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 49 [0;31m        [0mprefix[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     50 [0;31m        [0;32mif[0m [0mmeta_kwargs[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     51 [0;31m            [0midx[0m [0;34m=[0m [0mt

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(50)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     48 [0;31m        [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_aug_meta_prefix_for_loss[0m[0;34m([0m[0;34m'lbl'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     49 [0;31m        [0mprefix[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 50 [0;31m        [0;32mif[0m [0mmeta_kwargs[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     51 [0;31m            [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m(

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(51)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     49 [0;31m        [0mprefix[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0mlbl2data_aug_meta_prefix[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     50 [0;31m        [0;32mif[0m [0mmeta_kwargs[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 51 [0;31m            [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m>[0m [0;36m0[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     52 [0;31m            loss += self.config.meta_loss_weight * self.compute_loss(lbl2data_o.data_meta_repr[idx], 
[

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(52)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     50 [0;31m        [0;32mif[0m [0mmeta_kwargs[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     51 [0;31m            [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m>[0m [0;36m0[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 52 [0;31m            loss += self.config.meta_loss_weight * self.compute_loss(lbl2data_o.data_meta_repr[idx], 
[0m[0;32m     53 [0;31m                                                                     [0mlbl2data_o[0m[0;34m.[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(53)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     51 [0;31m            [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m>[0m [0;36m0[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     52 [0;31m            loss += self.config.meta_loss_weight * self.compute_loss(lbl2data_o.data_meta_repr[idx], 
[0m[0;32m---> 53 [0;31m                                                                     [0mlbl2data_o[0m[0;34m.[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     54 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     55 

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(54)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     52 [0;31m            loss += self.config.meta_loss_weight * self.compute_loss(lbl2data_o.data_meta_repr[idx], 
[0m[0;32m     53 [0;31m                                                                     [0mlbl2data_o[0m[0;34m.[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 54 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     55 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     56 [0;31m                                                                     [0mmeta_k

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(55)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     53 [0;31m                                                                     [0mlbl2data_o[0m[0;34m.[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     54 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 55 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     56 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0;34mf'p{prefix}'[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;

ipdb>  prefix


'cat2lbl'


ipdb>  meta_kwargs[f'p{prefix}']['data2ptr'][idx].shape


torch.Size([96])


ipdb>  n


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(56)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     54 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     55 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 56 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0;34mf'p{prefix}'[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     57 [0;31m                                                                     meta_kwargs[f'p{prefix}']['idx'])
[0m

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(57)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     55 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     56 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0;34mf'p{prefix}'[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 57 [0;31m                                                                     meta_kwargs[f'p{prefix}']['idx'])
[0m[0;32m     58 [0;31m        [0;32mreturn[0m [0mloss[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     59 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(52)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     50 [0;31m        [0;32mif[0m [0mmeta_kwargs[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mlen[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'idx'[0m[0;34m][0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     51 [0;31m            [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mmeta_kwargs[0m[0;34m[[0m[0mprefix[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m [0;34m>[0m [0;36m0[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 52 [0;31m            loss += self.config.meta_loss_weight * self.compute_loss(lbl2data_o.data_meta_repr[idx], 
[0m[0;32m     53 [0;31m                                                                     [0mlbl2data_o[0m[0;34m.[0m[0mmeta_repr[0m[0;34m,[0m[0;34m[0m[0;34m

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(58)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     56 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0;34mf'p{prefix}'[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     57 [0;31m                                                                     meta_kwargs[f'p{prefix}']['idx'])
[0m[0;32m---> 58 [0;31m        [0;32mreturn[0m [0mloss[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     59 [0;31m[0;34m[0m[0m
[0m[1;31m5[0;32m    60 [0;31m    def forward(
[0m


ipdb>  


--Return--
tensor(0.0107...AddBackward0>)
> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(58)[0;36mcompute_meta_loss[0;34m()[0m
[0;32m     56 [0;31m                                                                     [0mmeta_kwargs[0m[0;34m[[0m[0;34mf'p{prefix}'[0m[0;34m][0m[0;34m[[0m[0;34m'data2ptr'[0m[0;34m][0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     57 [0;31m                                                                     meta_kwargs[f'p{prefix}']['idx'])
[0m[0;32m---> 58 [0;31m        [0;32mreturn[0m [0mloss[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     59 [0;31m[0;34m[0m[0m
[0m[1;31m5[0;32m    60 [0;31m    def forward(
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(107)[0;36mforward[0;34m()[0m
[0;32m    105 [0;31m                [0mloss[0m [0;34m+=[0m [0mself[0m[0;34m.[0m[0mcompute_meta_loss[0m[0;34m([0m[0mdata_o[0m[0;34m,[0m [0mlbl2data_o[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    106 [0;31m[0;34m[0m[0m
[0m[0;32m--> 107 [0;31m        [0;32mif[0m [0;32mnot[0m [0mreturn_dict[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    108 [0;31m            [0mo[0m [0;34m=[0m [0;34m([0m[0mdata_o[0m[0;34m.[0m[0mdata_repr[0m[0;34m,[0m[0mdata_o[0m[0;34m.[0m[0menriched_data_repr[0m[0;34m,[0m[0mlbl2data_o[0m[0;34m.[0m[0mdata_repr[0m[0;34m,[0m[0mlbl2data_o[0m[0;34m.[0m[0menriched_data_repr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    109 [0;31m            [0;32mreturn[0m [0;34m([0m[0;34m([0m[0mloss[0m[0;34m,[0m[0;34m)[0m [0;34m+[0m [0mo[0m[0;34m)[0m [0;32mif

ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(111)[0;36mforward[0;34m()[0m
[0;32m    109 [0;31m            [0;32mreturn[0m [0;34m([0m[0;34m([0m[0mloss[0m[0;34m,[0m[0;34m)[0m [0;34m+[0m [0mo[0m[0;34m)[0m [0;32mif[0m [0mloss[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mo[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    110 [0;31m[0;34m[0m[0m
[0m[0;32m--> 111 [0;31m        return SAWModelOutput(
[0m[0;32m    112 [0;31m            [0mloss[0m[0;34m=[0m[0mloss[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    113 [0;31m            [0mdata_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(112)[0;36mforward[0;34m()[0m
[0;32m    110 [0;31m[0;34m[0m[0m
[0m[0;32m    111 [0;31m        return SAWModelOutput(
[0m[0;32m--> 112 [0;31m            [0mloss[0m[0;34m=[0m[0mloss[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    113 [0;31m            [0mdata_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    114 [0;31m            [0mdata_enriched_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0menriched_data_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(113)[0;36mforward[0;34m()[0m
[0;32m    111 [0;31m        return SAWModelOutput(
[0m[0;32m    112 [0;31m            [0mloss[0m[0;34m=[0m[0mloss[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 113 [0;31m            [0mdata_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    114 [0;31m            [0mdata_enriched_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0menriched_data_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    115 [0;31m            [0mlbl2data_repr[0m[0;34m=[0m[0mlbl2data_o[0m[0;34m.[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(114)[0;36mforward[0;34m()[0m
[0;32m    112 [0;31m            [0mloss[0m[0;34m=[0m[0mloss[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    113 [0;31m            [0mdata_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 114 [0;31m            [0mdata_enriched_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0menriched_data_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    115 [0;31m            [0mlbl2data_repr[0m[0;34m=[0m[0mlbl2data_o[0m[0;34m.[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    116 [0;31m            [0mlbl2data_enriched_repr[0m[0;34m=[0m[0mlbl2data_o[0m[0;34m.[0m[0menriched_data_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(115)[0;36mforward[0;34m()[0m
[0;32m    113 [0;31m            [0mdata_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    114 [0;31m            [0mdata_enriched_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0menriched_data_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 115 [0;31m            [0mlbl2data_repr[0m[0;34m=[0m[0mlbl2data_o[0m[0;34m.[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    116 [0;31m            [0mlbl2data_enriched_repr[0m[0;34m=[0m[0mlbl2data_o[0m[0;34m.[0m[0menriched_data_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    117 [0;31m        )
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(116)[0;36mforward[0;34m()[0m
[0;32m    114 [0;31m            [0mdata_enriched_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0menriched_data_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    115 [0;31m            [0mlbl2data_repr[0m[0;34m=[0m[0mlbl2data_o[0m[0;34m.[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 116 [0;31m            [0mlbl2data_enriched_repr[0m[0;34m=[0m[0mlbl2data_o[0m[0;34m.[0m[0menriched_data_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    117 [0;31m        )
[0m[0;32m    118 [0;31m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(111)[0;36mforward[0;34m()[0m
[0;32m    109 [0;31m            [0;32mreturn[0m [0;34m([0m[0;34m([0m[0mloss[0m[0;34m,[0m[0;34m)[0m [0;34m+[0m [0mo[0m[0;34m)[0m [0;32mif[0m [0mloss[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mo[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    110 [0;31m[0;34m[0m[0m
[0m[0;32m--> 111 [0;31m        return SAWModelOutput(
[0m[0;32m    112 [0;31m            [0mloss[0m[0;34m=[0m[0mloss[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    113 [0;31m            [0mdata_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


--Return--
SAWModelOutpu...ivBackward0>))
> [0;32m/tmp/ipykernel_5306/1126467191.py[0m(111)[0;36mforward[0;34m()[0m
[0;32m    109 [0;31m            [0;32mreturn[0m [0;34m([0m[0;34m([0m[0mloss[0m[0;34m,[0m[0;34m)[0m [0;34m+[0m [0mo[0m[0;34m)[0m [0;32mif[0m [0mloss[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mo[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    110 [0;31m[0;34m[0m[0m
[0m[0;32m--> 111 [0;31m        return SAWModelOutput(
[0m[0;32m    112 [0;31m            [0mloss[0m[0;34m=[0m[0mloss[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    113 [0;31m            [0mdata_repr[0m[0;34m=[0m[0mdata_o[0m[0;34m.[0m[0mdata_repr[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


--Return--
None
> [0;32m/tmp/ipykernel_5306/2352069300.py[0m(5)[0;36mfunc[0;34m()[0m
[0;32m      2 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      3 [0;31m    b = prepare_batch(model, batch, m_args=['pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_data2ptr', 'cat2data_input_ids', 'cat2data_attention_mask', 
[0m[0;32m      4 [0;31m                                            'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_lbl2ptr', 'cat2lbl_input_ids', 'cat2lbl_attention_mask'])
[0m[0;32m----> 5 [0;31m    [0mo[0m [0;34m=[0m [0mmodel[0m[0;34m([0m[0;34m**[0m[0mb[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      6 [0;31m[0;34m[0m[0m
[0m


ipdb>  


--Call--
> [0;32m/scratch/scai/phd/aiz218323/anaconda3/envs/mogic/lib/python3.10/site-packages/IPython/core/displayhook.py[0m(258)[0;36m__call__[0;34m()[0m
[0;32m    256 [0;31m        [0msys[0m[0;34m.[0m[0mstdout[0m[0;34m.[0m[0mflush[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    257 [0;31m[0;34m[0m[0m
[0m[0;32m--> 258 [0;31m    [0;32mdef[0m [0m__call__[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mresult[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    259 [0;31m        """Printing with history cache management.
[0m[0;32m    260 [0;31m[0;34m[0m[0m
[0m


ipdb>  c


[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

