In [1]:
#| default_exp models.cachew

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [4]:
#| export
import torch, torch.nn as nn, torch.nn.functional as F, re
from typing import Optional, Dict, Tuple
from dataclasses import dataclass

from transformers.utils.generic import ModelOutput
from transformers import PretrainedConfig, DistilBertConfig, DistilBertPreTrainedModel, DistilBertModel
from transformers.models.distilbert.modeling_distilbert import create_sinusoidal_embeddings, TransformerBlock

from xcai.losses import *
from xcai.learner import XCDataParallel
from xcai.models.modeling_utils import *

  from .autonotebook import tqdm as notebook_tqdm


## Load data

In [5]:
from xcai.main import *

In [6]:
output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/03_ngame-for-wikiseealsotitles'

data_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/'
config_file = 'wikiseealsotitles'
config_key = 'data_meta'

mname = 'sentence-transformers/msmarco-distilbert-base-v4'

In [7]:
pkl_dir = '/scratch/scai/phd/aiz218323/datasets/processed/'

In [8]:
pkl_file = f'{pkl_dir}/mogicX/wikiseealsotitles_data-meta_distilbert-base-uncased_sxc.joblib'

In [9]:
block = build_block(pkl_file, config_file, True, config_key, data_dir=data_dir, n_sdata_meta_samples=3)

In [115]:
batch = block.train.one_batch(100)

In [116]:
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_data2ptr', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_lbl2ptr', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_data2ptr', 'pcat2lbl_data2ptr'])

In [23]:
from transformers import DistilBertModel
m = DistilBertModel.from_pretrained('distilbert-base-uncased')

In [53]:
o = m(input_ids=batch['data_input_ids'], attention_mask=batch['data_attention_mask'])
o = Pooling.mean_pooling(o.last_hidden_state, batch['data_attention_mask'])

## Parameters

In [153]:
#| export
class Parameters:
    
    @staticmethod
    def from_data_aug_meta_prefix_for_encoder(prefix:str, **kwargs):
        inputs = {}
        args = [arg for arg in kwargs if prefix is not None and re.match(f'^{prefix}.*_(input_ids|attention_mask|data2ptr|idx)$', arg)]
        for arg in args:
            meta,param = arg.split('_', maxsplit=1)
            inputs.setdefault(meta, {})[param] = kwargs[arg]
        return inputs
    
    @staticmethod
    def from_aug_meta_prefix_for_feature(feat:str, prefix:str, **kwargs):
        keys = ['attention_mask', 'input_ids', 'idx']        
        inputs = {f'{prefix}_{k}': kwargs[f'{prefix}_{k}'] for k in keys if f'{prefix}_{k}' in kwargs}
        if prefix is not None and f'{prefix}_{feat}2ptr' in kwargs:
            inputs.update({f'{prefix}_data2ptr': kwargs[f'{prefix}_{feat}2ptr']})
        return inputs

    @staticmethod
    def from_aug_meta_prefix_for_loss(feat:str, prefix:str, **kwargs):
        keys = [f'{prefix}_idx', f'p{prefix}_idx']
        args = {k: kwargs[k] for k in keys if k in kwargs}
        if prefix is not None and f'{prefix}_{feat}2ptr' in kwargs:
            args.update({f'{prefix}_data2ptr': kwargs[f'{prefix}_{feat}2ptr']})
        if prefix is not None and f'p{prefix}_{feat}2ptr' in kwargs:
            args.update({f'p{prefix}_data2ptr': kwargs[f'p{prefix}_{feat}2ptr']})

        inputs = {}
        for arg in args:
            meta,param = arg.split('_', maxsplit=1)
            inputs.setdefault(meta, {})[param] = args[arg]
        return inputs
        

In [146]:
params = Parameters.from_aug_meta_prefix_for_loss('lbl', 'cat2lbl', **batch)

In [147]:
params

{'cat2lbl': {'idx': tensor([ 87814,    656, 155406,   3056, 326685, 120394, 258077, 143634, 138944,
          174633, 155467,   1228, 148694,  72242, 161845, 138151, 395355, 439024,
          100619,  68113,  91529,  74225, 569108, 131690, 110921, 508118, 489998,
          129962, 354599, 200014, 490420, 342971, 127050, 126452, 170818, 474981,
          148836, 120396,  84773,  73276,  28131, 296564,  92821,  69899,  74828,
          302326, 152673,  46822, 244562,  68249,  68123, 102584, 150784,  75601,
           68691, 506612,  60180, 303594, 190220, 399247,  62666, 174024, 131647,
          102518,  62592, 491957,  66439, 130779,  53106,  86585, 478961, 144409,
          257443,  72352, 157102, 463424, 330411, 184623, 426850, 141850, 490081,
          426908, 153510,  74217, 167418,  69981, 137586, 167418,  77661, 249965,
          123368,  74827]),
  'data2ptr': tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

## Configuration

In [97]:
#| export
class MemoryConfig(DistilBertConfig):

    def __init__(
        self,
        top_k_metadata:Optional[int] = 5,
        num_metadata:Optional[int] = 100_000,
        **kwargs,
    ):
        self.top_k_metadata = top_k_metadata
        self.num_metadata = num_metadata
        super().__init__(**kwargs)
    

In [98]:
#| export
class CachewConfig(MemoryConfig):

    def __init__(
        self,
        data_aug_meta_prefix:Optional[str] = None, 
        lbl2data_aug_meta_prefix:Optional[str] = None,

        data_enrich:Optional[bool] = True,
        lbl2data_enrich:Optional[bool] = True,
        
        num_batch_labels:Optional[int] = None,
        batch_size:Optional[int] = None,
        margin:Optional[float] = 0.3,
        num_negatives:Optional[int] = 10,
        tau:Optional[float] = 0.1,
        apply_softmax:Optional[bool] = True,

        calib_margin:Optional[float] = 0.05,
        calib_num_negatives:Optional[int] = 10,
        calib_tau:Optional[float] = 0.1,
        calib_apply_softmax:Optional[bool] = False,
        calib_loss_weight:Optional[float] = 0.1,
        use_calib_loss:Optional[float] = False,
        
        use_query_loss:Optional[float] = True,

        meta_loss_weight:Optional[float] = 0.1,
        use_meta_loss:Optional[bool] = False,
        
        use_encoder_parallel:Optional[bool] = True,
        
        **kwargs,
    ):
        self.data_aug_meta_prefix = data_aug_meta_prefix
        self.lbl2data_aug_meta_prefix = lbl2data_aug_meta_prefix

        self.data_enrich = data_enrich
        self.lbl2data_enrich = lbl2data_enrich

        self.num_batch_labels = num_batch_labels
        self.batch_size = batch_size
        self.margin = margin
        self.num_negatives = num_negatives
        self.tau = tau
        self.apply_softmax = apply_softmax

        self.calib_margin = calib_margin
        self.calib_num_negatives = calib_num_negatives
        self.calib_tau = calib_tau
        self.calib_apply_softmax = calib_apply_softmax
        self.calib_loss_weight = calib_loss_weight
        self.use_calib_loss = use_calib_loss

        self.use_query_loss = use_query_loss

        self.meta_loss_weight = meta_loss_weight
        self.use_meta_loss = use_meta_loss

        self.use_encoder_parallel = use_encoder_parallel

        super().__init__(**kwargs)
        

## Memory

In [22]:
#| export
class Memory(nn.Module):

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.top_k_metadata = config.top_k_metadata
        
        self.memory_embeddings = nn.Embedding(config.num_metadata, config.dim)
        
        self.position_embeddings = nn.Embedding(config.num_metadata, config.dim)
        if config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(
                n_pos=config.num_metadata, dim=config.dim, out=self.position_embeddings.weight
            )
        
        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def set_memory_embeddings(self, embed:torch.Tensor):
        with torch.no_grad():
            self.memory_embeddings.weight.copy_(embed)

    def align_embeddings(self, embeddings:torch.Tensor, group_lengths:torch.Tensor):
        n, dim = embeddings.shape
        num_groups, max_len = len(group_lengths), group_lengths.max()
        group_ids = torch.repeat_interleave(torch.arange(num_groups, device=embeddings.device), group_lengths)

        row_indices = torch.arange(n, device=embeddings.device)

        group_start = torch.cat([torch.zeros(1, dtype=group_lengths.dtype, device=group_lengths.device), group_lengths.cumsum(0)[:-1]], dim=0)

        within_idx = row_indices - group_start[group_ids]

        output, mask = torch.zeros((num_groups, max_len, dim), device=embeddings.device), torch.zeros((num_groups, max_len), device=embeddings.device)
        output[group_ids, within_idx] = embeddings
        mask[group_ids, within_idx] = 1.0

        return output, mask
        
    def forward(self, input_embeds:torch.Tensor, input_indices:Optional[torch.Tensor]=None, input_data2ptr:Optional[torch.Tensor]=None):
        assert input_embeds.dim() == 2, f'Input embeddings should be 2-dimensional, but got dim:{input_embeds.dim()}'
        
        meta_norm = F.normalize(self.memory_embeddings.weight, dim=-1)
        input_norm = F.normalize(input_embeds, dim=-1)
        
        scores = input_norm@meta_norm.T
        values, indices = torch.topk(scores, self.top_k_metadata, dim=-1)
        
        pred_embeddings = self.memory_embeddings(indices) + self.position_embeddings(indices)
        pred_embeddings = self.LayerNorm(pred_embeddings)
        pred_embeddings = self.dropout(pred_embeddings)
        pred_mask = torch.ones(pred_embeddings.shape[0], pred_embeddings.shape[1], device=pred_embeddings.device)

        input_embeddings = input_mask = None
        if input_indices is not None:
            input_embeddings = self.memory_embeddings(input_indices) + self.position_embeddings(input_indices)
            input_embeddings = self.LayerNorm(input_embeddings)
            input_embeddings = self.dropout(input_embeddings)
            input_embeddings, input_mask = self.align_embeddings(input_embeddings, input_data2ptr)

        embeddings = pred_embeddings if input_embeddings is None else torch.cat([pred_embeddings, input_embeddings], dim=1)
        mask = pred_mask if input_mask is None else torch.cat([pred_mask, input_mask], dim=1)
        
        return embeddings, mask, scores
        

In [53]:
config = MemoryConfig(num_metadata=block.train.dset.meta['cat_meta'].n_meta)

In [54]:
m = Memory(config)

In [55]:
data_aug_meta_prefix = 'cat2data'
meta_kwargs = Parameters.from_data_aug_meta_prefix_for_encoder(data_aug_meta_prefix, **batch)

In [56]:
data_repr = torch.randn(batch['data_input_ids'].shape[0], config.dim)

In [57]:
embeddings, mask, scores = m(data_repr, meta_kwargs[data_aug_meta_prefix]['idx'], meta_kwargs[data_aug_meta_prefix]['data2ptr'])

In [58]:
embeddings.shape, mask.shape, scores.shape

(torch.Size([100, 8, 768]), torch.Size([100, 8]), torch.Size([100, 656086]))

## Combiner

In [16]:
#| export
class CrossCombinerBlock(TransformerBlock):

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)

    def post_init(self):
        for module in self.modules(): self._init_weights(module)

    def _init_weights(self, module: nn.Module):
        if isinstance(module, nn.Linear):
            torch.nn.init.eye_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        x: torch.Tensor,
        m: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        
        # Cross-Attention
        ca_output = self.attention(
            query=x,
            key=m,
            value=m,
            mask=attn_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        if output_attentions:
            ca_output, ca_weights = ca_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        else:  # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
            if type(ca_output) is not tuple:
                raise TypeError(f"ca_output must be a tuple but it is {type(ca_output)} type")

            ca_output = ca_output[0]
        ca_output = self.sa_layer_norm(ca_output + x)  # (bs, seq_length, dim)

        # Feed Forward Network
        ffn_output = self.ffn(ca_output)  # (bs, seq_length, dim)
        ffn_output: torch.Tensor = self.output_layer_norm(ffn_output + ca_output)  # (bs, seq_length, dim)

        output = (ffn_output,)
        if output_attentions:
            output = (ca_weights,) + output
        return output
        

## Encoder

In [134]:
#| export
@dataclass
class EncoderOutput(ModelOutput):
    repr: Optional[torch.FloatTensor] = None
    enriched_repr: Optional[torch.FloatTensor] = None
    meta_scores: Optional[torch.FloatTensor] = None
    

In [135]:
#| export
class Encoder(DistilBertPreTrainedModel):
    
    config_class = MemoryConfig
    
    def __init__(
        self, 
        config:PretrainedConfig, 
    ):
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        self.query_head = RepresentationHead(config)
        self.combiner_head = CrossCombinerBlock(config)
        self.enriched_query_head = RepresentationHead(config)

        self.memory = Memory(config)
        
        self.post_init()

    @torch.no_grad()
    def init_heads_to_identity(self):
        self.query_head.post_init()
        self.combiner_head.post_init()
        self.enriched_query_head.post_init()

    @torch.no_grad()
    def init_combiner_to_last_layer(self):
        lsd = self.distilbert.transformer.layer[-1].state_dict()
        lsd_keys = lsd.keys()        
        csd = self.combiner_head.state_dict()
        csd_keys = csd.keys()
        
        assert len(lsd_keys) == len(csd_keys), f'mismatched keys: {len(lsd_keys)} != {len(csd_keys)}'
        
        for k in csd_keys:
            assert csd[k].shape == lsd[k].shape
            csd[k].copy_(lsd[k])
            
    @torch.no_grad()
    def set_memory_embeddings(self, embed:torch.Tensor):
        self.memory.set_memory_embeddings(embed)
        
    def get_position_embeddings(self) -> nn.Embedding:
        return self.distilbert.get_position_embeddings()
    
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)
    
    def encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        return self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        
    def encode_query(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        embed = self.query_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def encode_enriched_query(self, embed:torch.Tensor):
        return F.normalize(self.enriched_query_head(embed), dim=1)

    def enrich_query_representation(self, data_repr:torch.Tensor, meta_kwargs:Optional[Dict]=None):
        meta_repr, meta_mask, meta_scores = self.memory(data_repr) if meta_kwargs is None else self.memory(data_repr, meta_kwargs['idx'], meta_kwargs['data2ptr'])
        
        meta_mask = meta_mask.view(len(meta_mask), 1, 1, -1).bool()
        fusion_repr = self.combiner_head(x=data_repr.view(len(data_repr), 1, -1), m=meta_repr, attn_mask=meta_mask)
        fusion_repr = fusion_repr[0].squeeze(dim=1)
        
        enriched_data_repr = self.encode_enriched_query(data_repr + fusion_repr)
        return enriched_data_repr, meta_scores

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_enrich: Optional[bool]=True,
        **kwargs
    ):  
        data_o = self.encode(data_input_ids, data_attention_mask)
        data_repr = self.encode_query(data_o[0], data_attention_mask)
        
        enriched_data_repr = meta_scores = None
        meta_kwargs = Parameters.from_data_aug_meta_prefix_for_encoder(data_aug_meta_prefix, **kwargs)
        meta_kwargs = meta_kwargs.get(data_aug_meta_prefix, None)
        
        if data_enrich:
            enriched_data_repr, meta_scores = self.enrich_query_representation(data_repr, meta_kwargs)
            
        return EncoderOutput(
            repr=data_repr,
            enriched_repr=enriched_data_repr,
            meta_scores=meta_scores
        )
        

In [180]:
config = MemoryConfig(num_metadata=block.train.dset.meta['cat_meta'].n_meta)
m = Encoder(config)

In [181]:
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_data2ptr', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_lbl2ptr', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_data2ptr', 'pcat2lbl_data2ptr'])

In [182]:
data_aug_meta_prefix='cat2data'
output = m(**batch, data_aug_meta_prefix=data_aug_meta_prefix)

In [183]:
output

EncoderOutput(repr=tensor([[-0.0538, -0.0170,  0.0591,  ..., -0.0791, -0.0830,  0.0450],
        [-0.0295,  0.0319, -0.0031,  ..., -0.0622, -0.0621,  0.0216],
        [-0.0148,  0.0062, -0.0024,  ..., -0.0583, -0.0546,  0.0029],
        ...,
        [ 0.0536, -0.0087, -0.0208,  ..., -0.0220, -0.0539,  0.0342],
        [-0.0158,  0.0252,  0.0377,  ..., -0.0533, -0.0447,  0.0333],
        [ 0.0298, -0.0146, -0.0188,  ..., -0.0568, -0.0479,  0.0317]],
       grad_fn=<DivBackward0>), enriched_repr=tensor([[ 0.0427, -0.0454, -0.0224,  ..., -0.0174,  0.0331,  0.0002],
        [ 0.0302, -0.0223,  0.0149,  ...,  0.0240, -0.0156,  0.0016],
        [ 0.0192,  0.0120,  0.0273,  ..., -0.0221, -0.0642, -0.0499],
        ...,
        [ 0.0385,  0.0089,  0.0263,  ..., -0.0362,  0.0579, -0.0503],
        [ 0.0260,  0.0138,  0.0007,  ..., -0.0468,  0.0162,  0.0187],
        [-0.0065, -0.0112,  0.0086,  ..., -0.0254,  0.0095,  0.0348]],
       grad_fn=<DivBackward0>), meta_scores=tensor([[ 0.0004, -0.07

## `CAW000`

In [149]:
#| export
@dataclass
class CAWModelOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    data_repr: Optional[torch.FloatTensor] = None
    data_enriched_repr: Optional[torch.FloatTensor] = None
    lbl2data_repr: Optional[torch.FloatTensor] = None
    lbl2data_enriched_repr: Optional[torch.FloatTensor] = None
    

In [150]:
#| export
class CAW000(nn.Module):

    config_class = CachewConfig
    
    def __init__(
        self, 
        config: CachewConfig,
    ):
        super().__init__(config)
        self.config, self.encoder = config, None
        self.rep_loss_fn = MultiTriplet(margin=config.margin, n_negatives=config.num_negatives, tau=config.tau, 
                                        apply_softmax=config.apply_softmax, reduce='mean')
        self.cab_loss_fn = Calibration(margin=config.calib_margin, tau=config.calib_tau, n_negatives=config.calib_num_negatives, 
                                       apply_softmax=config.calib_apply_softmax, reduce='mean')
        
    def init_heads_to_identity(self):
        if self.encoder is None: raise ValueError('Encoder not initialized.')
        self.encoder.init_heads_to_identity()

    def init_combiner_to_last_layer(self):
        if self.encoder is None: raise ValueError('Encoder not initialized.')
        self.encoder.init_combiner_to_last_layer()

    def set_memory_embeddings(self, embed:torch.Tensor):
        if self.encoder is None: raise ValueError('Encoder not initialized.')
        self.encoder.set_memory_embeddings(embed)
        
    def compute_loss(self, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        return self.rep_loss_fn(inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def calibration_loss(self, einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        return self.config.calib_loss_weight * self.cab_loss_fn(einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ): 
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.config.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = Parameters.from_aug_meta_prefix_for_feature('data', self.config.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.config.data_aug_meta_prefix, data_enrich=self.config.data_enrich, **data_meta_kwargs)
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_aug_meta_prefix_for_feature('lbl', self.config.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.config.lbl2data_aug_meta_prefix, data_enrich=self.config.lbl2data_enrich, **lbl2data_meta_kwargs)
            
            loss = self.compute_loss(data_o.enriched_repr, lbl2data_o.repr,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.config.use_query_loss:
                loss += self.compute_loss(data_o.repr, lbl2data_o.repr,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.config.use_calib_loss:
                loss += self.calibration_loss(data_o.enriched_repr, data_o.repr, lbl2data_o.repr,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
            
        if not return_dict:
            o = (data_o.repr,data_o.enriched_repr,lbl2data_o.repr,lbl2data_o.enriched_repr)
            return ((loss,) + o) if loss is not None else o
        
        return CAWModelOutput(
            loss=loss,
            data_repr=data_o.repr,
            data_enriched_repr=data_o.enriched_repr,
            lbl2data_repr=lbl2data_o.repr,
            lbl2data_enriched_repr=lbl2data_o.enriched_repr,
        )
        

## `CAW001`

In [151]:
#| export
class CAW001(CAW000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    def __init__(self, config):
        super().__init__(config)
        self.encoder = Encoder(config)
        
        self.post_init()
        self.remap_post_init()

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert
        

### Example

In [50]:
config = CachewConfig(
    top_k_metadata = 5,
    num_metadata=block.train.dset.meta['cat_meta'].n_meta,

    data_aug_meta_prefix='cat2data', 
    lbl2data_aug_meta_prefix=None,

    data_enrich=True,
    lbl2data_enrich=False,

    batch_size=100, 
    num_batch_labels=5000, 
    margin=0.3,
    num_negatives=5,
    tau=0.1,
    apply_softmax=True,

    calib_margin=0.3,
    calib_num_negatives=10,
    calib_tau=0.1,
    calib_apply_softmax=False,
    calib_loss_weight=0.1,
    use_calib_loss=True,

    use_query_loss=True, 
    use_encoder_parallel=False
)

In [23]:
model = CAW001.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', config=config)

Some weights of CAW001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.combiner_head.attention.k_lin.bias', 'encoder.combiner_head.attention.k_lin.weight', 'encoder.combiner_head.attention.out_lin.bias', 'encoder.combiner_head.attention.out_lin.weight', 'encoder.combiner_head.attention.q_lin.bias', 'encoder.combiner_head.attention.q_lin.weight', 'encoder.combiner_head.attention.v_lin.bias', 'encoder.combiner_head.attention.v_lin.weight', 'encoder.combiner_head.ffn.lin1.bias', 'encoder.combiner_head.ffn.lin1.weight', 'encoder.combiner_head.ffn.lin2.bias', 'encoder.combiner_head.ffn.lin2.weight', 'encoder.combiner_head.output_layer_norm.bias', 'encoder.combiner_head.output_layer_norm.weight', 'encoder.combiner_head.sa_layer_norm.bias', 'encoder.combiner_head.sa_layer_norm.weight', 'encoder.enriched_query_head.layer_norm.bias', 'encoder.enriched_query_head.layer_norm.weight', 'encoder.enriched_query_he

In [539]:
model.init_heads_to_identity()

In [24]:
model.init_combiner_to_last_layer()

In [None]:
output = model(**batch)

In [430]:
def func():
    import pdb; pdb.set_trace()
    b = prepare_batch(model, batch, m_args=['cat2data_idx', 'cat2data_data2ptr'])
    o = model(**b)
    

In [27]:
o

CAWModelOutput(loss=tensor(0.1293, grad_fn=<AddBackward0>), data_repr=tensor([[-3.0410e-02, -3.0241e-02, -3.2113e-02,  ..., -3.5623e-02,
         -2.5720e-02,  7.6438e-02],
        [-2.5134e-02,  1.2242e-02, -2.8343e-02,  ...,  7.1558e-02,
         -2.2071e-02, -1.9431e-02],
        [-2.2176e-02, -1.7220e-02, -2.3142e-02,  ...,  4.7774e-02,
         -2.0947e-03, -1.5657e-02],
        ...,
        [-6.7133e-03, -2.3140e-02,  1.3361e-04,  ..., -1.4936e-02,
         -9.2534e-03, -8.7000e-04],
        [-2.4098e-02,  2.2817e-02,  3.1067e-02,  ..., -4.0922e-03,
         -1.7239e-02, -3.6599e-03],
        [ 2.4830e-02,  4.8248e-02, -2.4326e-03,  ...,  1.3954e-01,
          4.4811e-02, -2.6876e-02]], grad_fn=<DivBackward0>), data_enriched_repr=tensor([[-2.0211e-02, -5.9858e-03, -2.2439e-02,  ..., -2.1754e-02,
         -2.2301e-02,  7.6863e-02],
        [-9.8610e-03, -2.1608e-02,  1.0650e-02,  ...,  3.6301e-02,
         -2.0757e-02,  2.3284e-02],
        [ 2.1760e-02, -2.2429e-02, -1.9076e-02, 

In [546]:
from xcai.core import *
from xcai.losses import *

In [543]:
b = prepare_batch(model, batch, m_args=['cat2data_idx', 'cat2data_data2ptr'])

In [544]:
m = model.to('cuda')
b = b.to('cuda')

In [None]:
output = m(**b)

In [424]:
o

CAWModelOutput(loss=tensor(0.1652, device='cuda:0', grad_fn=<AddBackward0>), data_repr=tensor([[-0.0091, -0.0404,  0.0069,  ..., -0.0133,  0.0293, -0.0296],
        [ 0.0700, -0.0143, -0.0293,  ...,  0.0230, -0.0322,  0.0585],
        [ 0.0543,  0.0195, -0.0226,  ...,  0.0164,  0.0351, -0.0225],
        ...,
        [ 0.0748, -0.0329, -0.0415,  ...,  0.0313,  0.0020,  0.0383],
        [ 0.0754, -0.0068, -0.0007,  ..., -0.0117, -0.0502,  0.0220],
        [-0.0040,  0.0592, -0.0239,  ..., -0.0101,  0.0522,  0.0024]],
       device='cuda:0', grad_fn=<DivBackward0>), data_enriched_repr=tensor([[-0.0785, -0.0036,  0.1046,  ...,  0.0229,  0.0140, -0.0448],
        [-0.0200,  0.0137, -0.0136,  ...,  0.0444,  0.0249, -0.0375],
        [-0.0187, -0.0391,  0.0246,  ...,  0.0032,  0.0197,  0.0781],
        ...,
        [-0.0304,  0.0034,  0.0236,  ...,  0.0043,  0.0382, -0.0257],
        [-0.0267, -0.0444, -0.0131,  ...,  0.0010, -0.0422,  0.0187],
        [-0.0179, -0.0321, -0.0272,  ...,  0.041

## `New` MultiTriplet

In [67]:
import xcai.losses as xloss

from xcai.losses import BaseLoss
from xcai.core import store_attr

In [48]:
class BaseMultiTriplet(BaseLoss):

    def __init__(
        self,
        margin:Optional[float]=0.8,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        **kwargs
    ):
        super().__init__(**kwargs)
        store_attr('margin,tau,apply_softmax,n_negatives')

    def align_indices(self, indices:torch.Tensor, group_lengths:torch.Tensor):
        n, num_groups, max_len = len(indices), len(group_lengths), group_lengths.max()
        group_ids = torch.repeat_interleave(torch.arange(num_groups, device=indices.device), group_lengths)
    
        row_indices = torch.arange(n, device=indices.device)
    
        group_start = torch.cat([torch.zeros(1, dtype=group_lengths.dtype, device=group_lengths.device), group_lengths.cumsum(0)[:-1]], dim=0)
    
        within_idx = row_indices - group_start[group_ids]
    
        output = torch.zeros((num_groups, max_len), dtype=indices.dtype, device=indices.device)
        mask = torch.zeros((num_groups, max_len), device=indices.device)
        output[group_ids, within_idx] = indices
        mask[group_ids, within_idx] = 1.0
    
        return output, mask

    def remove_redundant_indices(self, inp2targ_idx:torch.Tensor, n_inp2targ:torch.Tensor, pinp2targ_idx:torch.Tensor, n_pinp2targ:torch.Tensor):
        mask = torch.isin(pinp2targ_idx, inp2targ_idx)
        new_pinp2targ_idx = pinp2targ_idx[mask]
    
        num_groups = len(n_pinp2targ)
        group_ids = torch.repeat_interleave(torch.arange(num_groups, device=n_pinp2targ.device), n_pinp2targ)
        new_n_pinp2targ = torch.bincount(group_ids[mask], minlength=num_groups)
    
        return new_pinp2targ_idx, new_n_pinp2targ

    def reset_indices(self, inp2targ_idx:torch.Tensor, n_inp2targ:torch.Tensor, pinp2targ_idx:torch.Tensor, n_pinp2targ:torch.Tensor):
        _, reset_indices, counts = torch.unique(torch.cat([inp2targ_idx, pinp2targ_idx]), return_inverse=True, return_counts=True)
    
        _, idx_sorted = torch.sort(reset_indices, stable=True)
        cum_sum = torch.cat((torch.zeros((1,), dtype=counts.dtype, device=counts.device), counts.cumsum(0)[:-1]))
        indices = idx_sorted[cum_sum]
    
        inp2targ_idx = reset_indices[:len(inp2targ_idx)]
        pinp2targ_idx = reset_indices[len(inp2targ_idx):]
    
        return inp2targ_idx, pinp2targ_idx, indices

    def compute_scores(self, inp, targ, indices=None):
        if indices is not None: targ = targ[indices]
        return inp@targ.T

    def forward(
        self, 
        
        n_inp2targ:torch.LongTensor,
        inp2targ_idx:torch.LongTensor,
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor,

        inp:Optional[torch.FloatTensor]=None, 
        targ:Optional[torch.FloatTensor]=None,
        scores:Optional[torch.FloatTensor]=None,
        
        margin:Optional[float]=None,
        tau:Optional[float]=None,
        apply_softmax:Optional[bool]=None,
        n_negatives:Optional[int]=None,
        **kwargs
    ):
        store_attr('margin,tau,apply_softmax,n_negatives', is_none=False)

        pinp2targ_idx, n_pinp2targ = self.remove_redundant_indices(inp2targ_idx, n_inp2targ, pinp2targ_idx, n_pinp2targ)
        inp2targ_idx, pinp2targ_idx, indices = self.reset_indices(inp2targ_idx, n_inp2targ, pinp2targ_idx, n_pinp2targ)

        scores = self.compute_scores(inp, targ, indices=indices) if scores is None else scores[:, indices]

        pos_indices, pos_mask = self.align_indices(inp2targ_idx, n_inp2targ)
        pos_scores = scores.gather(1, pos_indices)

        pos_incidence = torch.zeros_like(scores)
        ppos_indices, _ = self.align_indices(inp2targ_idx, n_inp2targ)
        pos_incidence = pos_incidence.scatter(1, ppos_indices, 1)
        neg_incidence = 1 - pos_incidence

        loss = scores.unsqueeze(1) - pos_scores.unsqueeze(2) + self.margin
        loss = F.relu(loss * neg_incidence.unsqueeze(1))

        scores = scores.unsqueeze(1).expand_as(loss)
        neg_incidence = neg_incidence.unsqueeze(1).expand_as(loss)

        if self.n_negatives is not None:
            loss, idx = torch.topk(loss, min(self.n_negatives, loss.shape[2]), dim=2, largest=True)
            scores, neg_incidence = scores.gather(2, idx), neg_incidence.gather(2, idx)

        if self.apply_softmax:
            mask = loss != 0
            penalty = scores / self.tau * mask
            penalty[neg_incidence == 0] = torch.finfo(penalty.dtype).min
            penalty = torch.softmax(penalty, dim=2)
            loss = loss*penalty
        
        loss /= (neg_incidence.sum(dim=2, keepdim=True) + 1e-9)
        loss = loss[pos_mask.bool()].sum(dim=1)

        if self.reduction == 'mean': return loss.mean()
        elif self.reduction == 'sum': return loss.sum()
        else: raise ValueError(f'`reduction` cannot be `{self.reduction}`')
    

In [49]:
class MultiTriplet(BaseMultiTriplet):

    def forward(
        self, 
        inp:torch.FloatTensor, 
        targ:torch.FloatTensor, 
        n_inp2targ:torch.LongTensor,
        inp2targ_idx:torch.LongTensor,
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor,
        margin:Optional[float]=None,
        tau:Optional[float]=None,
        apply_softmax:Optional[bool]=None,
        n_negatives:Optional[int]=None,
        **kwargs
    ):
        return super().forward(n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx, inp=inp, targ=targ, margin=margin, tau=tau, 
                               apply_softmax=apply_softmax, n_negatives=n_negatives, **kwargs)
        

In [50]:
class MultiTripletFromScores(BaseMultiTriplet):

    def forward(
        self, 
        scores:torch.FloatTensor,  
        n_inp2targ:torch.LongTensor,
        inp2targ_idx:torch.LongTensor,
        n_pinp2targ:torch.LongTensor,
        pinp2targ_idx:torch.LongTensor,
        margin:Optional[float]=None,
        tau:Optional[float]=None,
        apply_softmax:Optional[bool]=None,
        n_negatives:Optional[int]=None,
        **kwargs
    ):
        return super().forward(n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx, scores=scores, margin=margin, tau=tau, 
                               apply_softmax=apply_softmax, n_negatives=n_negatives, **kwargs)
        

In [51]:
import torch.autograd.profiler as profiler

In [52]:
inp, targ = output.data_repr, output.lbl2data_repr

inp2targ_idx, n_inp2targ = batch['lbl2data_idx'], batch['lbl2data_data2ptr']
pinp2targ_idx, n_pinp2targ = batch['plbl2data_idx'], batch['plbl2data_data2ptr']

In [53]:
inp.shape, targ.shape

(torch.Size([1000, 768]), torch.Size([1000, 768]))

In [54]:
margin, tau = 0.3, 0.1
apply_softmax = True
n_negatives = 10

In [63]:
new_loss_fn = MultiTriplet(margin, tau, apply_softmax, n_negatives, reduce='mean')

with profiler.profile(with_stack=True, profile_memory=True) as prof:
    new_loss = new_loss_fn(inp, targ, n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx)
    print(new_loss)
    

tensor(0.0304, grad_fn=<MeanBackward0>)


In [64]:
print(prof)

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   aten::isin         0.29%      70.581us         3.05%     755.362us     755.362us       4.19 Kb    -206.17 Kb             1  
                aten::_unique         0.20%      50.246us         1.38%     340.472us     340.472us      65.38 Kb     -67.06 Kb             1  
                  aten::empty         0.02%       4.002us         0.02%       4.002us       4.002us           0 b           0 b             1  
                  aten::empty         0.00%       0.425us         0.00%       0.425us       0.425us           0 b           0 b         

In [57]:
scores = inp@targ.T
new_loss_fn = MultiTripletFromScores(margin, tau, apply_softmax, n_negatives, reduce='mean')

with profiler.profile(with_stack=True, profile_memory=True) as prof:
    new_loss = new_loss_fn(scores, n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx)
    print(new_loss)
    

tensor(0.0304, grad_fn=<MeanBackward0>)


In [58]:
print(prof)

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   aten::isin         0.40%      67.912us         4.37%     746.408us     746.408us       4.19 Kb    -206.17 Kb             1  
                aten::_unique         0.29%      50.171us         1.92%     327.497us     327.497us      65.38 Kb     -67.06 Kb             1  
                  aten::empty         0.03%       4.309us         0.03%       4.309us       4.309us           0 b           0 b             1  
                  aten::empty         0.00%       0.504us         0.00%       0.504us       0.504us           0 b           0 b         

In [61]:
old_loss_fn = xloss.MultiTriplet(bsz=1000, tn_targ=1000, margin=margin, tau=tau, apply_softmax=apply_softmax, 
                                 n_negatives=n_negatives, reduce='mean')

with profiler.profile(with_stack=True, profile_memory=True) as prof:
    old_loss = old_loss_fn(inp, targ, n_inp2targ, inp2targ_idx, n_pinp2targ, pinp2targ_idx)
    print(old_loss)
    

tensor(0.0300, grad_fn=<DivBackward0>)


In [62]:
print(prof)

------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                      Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                  aten::to         0.01%       2.103us         0.01%       2.103us       2.103us           0 b           0 b             1  
                                  aten::to         0.00%       0.148us         0.00%       0.148us       0.148us           0 b           0 b             1  
                                 aten::max         0.05%      18.833us         0.08%      31.685us      31.685us           8 b           0 b             1  
                               aten::empty         0.03%  

## `CAW002`

In [None]:
#| export
class CAW002(CAW000, DistilBertPreTrainedModel):
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    def __init__(self, config):
        super().__init__(config)
        self.encoder = Encoder(config)
        self.meta_loss_fn = MultiTripletFromScores(margin=config.margin, n_negatives=config.num_negatives, tau=config.tau, 
                                                  apply_softmax=config.apply_softmax, reduce='mean')
        self.post_init()
        self.remap_post_init()

    def remap_post_init(self):
        self.distilbert = self.encoder.distilbert

    def compute_meta_loss(self, scores, feat, prefix, **kwargs):
        loss = 0.0
        meta_kwargs = Parameters.from_aug_meta_prefix_for_loss(feat, prefix, **kwargs)
        if len(meta_kwargs):
            args, pargs = meta_kwargs[prefix], meta_kwargs[f'p{prefix}']
            loss = self.config.meta_loss_weight * self.meta_loss_fn(scores[:, args['idx']], args['data2ptr'], args['idx'], 
                                                                    pargs['data2ptr'], pargs['idx'])
        return loss

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ): 
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.config.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = Parameters.from_aug_meta_prefix_for_feature('data', self.config.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.config.data_aug_meta_prefix, data_enrich=self.config.data_enrich, **data_meta_kwargs)
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_aug_meta_prefix_for_feature('lbl', self.config.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.config.lbl2data_aug_meta_prefix, data_enrich=self.config.lbl2data_enrich, **lbl2data_meta_kwargs)
            
            loss = self.compute_loss(data_o.enriched_repr, lbl2data_o.repr,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)
            
            if self.config.use_query_loss:
                loss += self.compute_loss(data_o.repr, lbl2data_o.repr,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)
                
            if self.config.use_calib_loss:
                loss += self.calibration_loss(data_o.enriched_repr, data_o.repr, lbl2data_o.repr,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)

            if self.config.use_meta_loss:
                loss += self.compute_meta_loss(data_o.meta_scores, 'data', self.config.data_aug_meta_prefix, **kwargs)
                loss += self.compute_meta_loss(lbl2data_o.meta_scores, 'lbl', self.config.lbl2data_aug_meta_prefix, **kwargs)
            
        if not return_dict:
            o = (data_o.repr,data_o.enriched_repr,lbl2data_o.repr,lbl2data_o.enriched_repr)
            return ((loss,) + o) if loss is not None else o
        
        return CAWModelOutput(
            loss=loss,
            data_repr=data_o.repr,
            data_enriched_repr=data_o.enriched_repr,
            lbl2data_repr=lbl2data_o.repr,
            lbl2data_enriched_repr=lbl2data_o.enriched_repr,
        )
        

### Example

In [138]:
config = CachewConfig(
    top_k_metadata = 5,
    num_metadata=block.train.dset.meta['cat_meta'].n_meta,

    data_aug_meta_prefix='cat2data', 
    lbl2data_aug_meta_prefix=None,

    data_enrich=True,
    lbl2data_enrich=False,

    batch_size=100, 
    num_batch_labels=5000, 
    margin=0.3,
    num_negatives=10,
    tau=0.1,
    apply_softmax=True,

    calib_margin=0.3,
    calib_num_negatives=10,
    calib_tau=0.1,
    calib_apply_softmax=False,
    calib_loss_weight=0.1,
    use_calib_loss=True,

    meta_loss_weight=0.1,
    use_meta_loss=True,
    
    use_query_loss=True, 
    use_encoder_parallel=False
)

In [127]:
model = CAW002.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', config=config)

Some weights of CAW002 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.combiner_head.attention.k_lin.bias', 'encoder.combiner_head.attention.k_lin.weight', 'encoder.combiner_head.attention.out_lin.bias', 'encoder.combiner_head.attention.out_lin.weight', 'encoder.combiner_head.attention.q_lin.bias', 'encoder.combiner_head.attention.q_lin.weight', 'encoder.combiner_head.attention.v_lin.bias', 'encoder.combiner_head.attention.v_lin.weight', 'encoder.combiner_head.ffn.lin1.bias', 'encoder.combiner_head.ffn.lin1.weight', 'encoder.combiner_head.ffn.lin2.bias', 'encoder.combiner_head.ffn.lin2.weight', 'encoder.combiner_head.output_layer_norm.bias', 'encoder.combiner_head.output_layer_norm.weight', 'encoder.combiner_head.sa_layer_norm.bias', 'encoder.combiner_head.sa_layer_norm.weight', 'encoder.enriched_query_head.layer_norm.bias', 'encoder.enriched_query_head.layer_norm.weight', 'encoder.enriched_query_he

In [128]:
model.init_combiner_to_last_layer()

In [129]:
output = model(**batch)

In [130]:
output

CAWModelOutput(loss=tensor(0.1707, grad_fn=<AddBackward0>), data_repr=tensor([[-0.0214, -0.0066,  0.0303,  ...,  0.0182, -0.0135, -0.0793],
        [ 0.0273,  0.0134, -0.0062,  ...,  0.0202,  0.0661,  0.0051],
        [-0.0364,  0.0412,  0.0306,  ...,  0.0285, -0.0069, -0.0259],
        ...,
        [-0.0645,  0.0458, -0.0585,  ...,  0.0032,  0.0341, -0.0968],
        [ 0.0521, -0.0134, -0.0446,  ...,  0.0345, -0.0384, -0.0503],
        [ 0.0612, -0.0032,  0.0068,  ...,  0.0308,  0.0379,  0.0087]],
       grad_fn=<DivBackward0>), data_enriched_repr=tensor([[ 0.0099,  0.0407,  0.0103,  ...,  0.0601, -0.0036, -0.0469],
        [ 0.0121, -0.0055, -0.0352,  ...,  0.0269, -0.0329,  0.0282],
        [ 0.0361,  0.0599, -0.0128,  ..., -0.0444,  0.0144, -0.0105],
        ...,
        [ 0.0085,  0.0783,  0.0237,  ...,  0.0411, -0.0147, -0.0539],
        [ 0.0198,  0.0504, -0.0206,  ..., -0.0257,  0.0670,  0.0072],
        [ 0.1197,  0.0223,  0.0419,  ...,  0.0094, -0.0042,  0.0105]],
       grad

In [140]:
def func():
    import pdb; pdb.set_trace()
    output = model(**batch)
    