In [1]:
#| default_exp models.cachew

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [234]:
#| export
import torch, torch.nn as nn, torch.nn.functional as F, re
from typing import Optional, Dict, Tuple
from dataclasses import dataclass

from transformers.utils.generic import ModelOutput
from transformers import PretrainedConfig, DistilBertConfig, DistilBertPreTrainedModel, DistilBertModel
from transformers.models.distilbert.modeling_distilbert import create_sinusoidal_embeddings, TransformerBlock

from xcai.losses import *
from xcai.learner import XCDataParallel
from xcai.models.modeling_utils import *

## Load data

In [5]:
from xcai.main import *

In [6]:
# Output/data locations. These are absolute, user-specific paths; adjust them for your environment.
# NOTE(review): `output_dir` and `mname` are not referenced by the cells below (from_pretrained repeats the string literal).
output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/03_ngame-for-wikiseealsotitles'

data_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/'
config_file = 'wikiseealsotitles'
config_key = 'data_meta'

mname = 'sentence-transformers/msmarco-distilbert-base-v4'

In [7]:
pkl_dir = '/scratch/scai/phd/aiz218323/datasets/processed/'

In [8]:
# `pkl_dir` already ends with '/', so this path contains '//' (harmless on POSIX).
# NOTE(review): a .joblib file is a pickle; only load caches you produced yourself.
pkl_file = f'{pkl_dir}/mogicX/wikiseealsotitles_data-meta_distilbert-base-uncased_sxc.joblib'

In [9]:
# `build_block` comes from the `xcai.main` star-import; the meaning of its positional args is not visible here.
block = build_block(pkl_file, config_file, True, config_key, data_dir=data_dir, n_sdata_meta_samples=3)

In [10]:
# Draw one training batch of 100 queries (downstream shapes below have a leading 100).
batch = block.train.one_batch(100)

In [11]:
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_data2ptr', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_lbl2ptr', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_data2ptr', 'pcat2lbl_data2ptr'])

In [23]:
from transformers import DistilBertModel
# NOTE(review): this loads plain 'distilbert-base-uncased', not `mname`; `m` is reassigned to other models further down.
m = DistilBertModel.from_pretrained('distilbert-base-uncased')

In [53]:
# NOTE(review): execution count 53 is reused by a later cell, so the notebook was run out of order.
# Pool the token states with the attention mask (Pooling presumably comes from the xcai.models.modeling_utils star-import).
o = m(input_ids=batch['data_input_ids'], attention_mask=batch['data_attention_mask'])
o = Pooling.mean_pooling(o.last_hidden_state, batch['data_attention_mask'])

## Parameters

In [394]:
#| export
class Parameters:
    """Helpers that pick metadata-augmentation arguments out of a flat batch ``**kwargs`` dict."""

    @staticmethod
    def from_data_aug_meta_prefix_for_encoder(prefix:str, **kwargs):
        """Group ``{prefix}..._{param}`` kwargs into ``{meta_name: {param: value}}``.

        Only the ``input_ids``, ``attention_mask``, ``data2ptr`` and ``idx`` parameters are collected.
        ``meta_name`` is everything before the trailing ``_{param}``, so a prefix that itself
        contains underscores is kept intact. ``prefix`` is matched literally, not as a regex.

        Args:
            prefix: metadata prefix such as ``'cat2data'``; ``None`` disables collection.
            **kwargs: the flat batch.

        Returns:
            A dict of per-metadata parameter dicts, in kwargs order; ``{}`` if ``prefix`` is ``None``.
        """
        if prefix is None: return {}
        # Escape the prefix so characters like '.' are matched literally.
        pattern = re.compile(f'^({re.escape(prefix)}.*)_(input_ids|attention_mask|data2ptr|idx)$')
        inputs = {}
        for arg, value in kwargs.items():
            match = pattern.match(arg)
            if match is None: continue
            meta, param = match.groups()
            inputs.setdefault(meta, {})[param] = value
        return inputs

    @staticmethod
    def from_data_aug_meta_prefix_for_feature(feat:str, prefix:str, **kwargs):
        """Select the ``{prefix}_*`` kwargs to forward to an encoder call for feature ``feat``.

        Copies ``{prefix}_attention_mask``, ``{prefix}_input_ids`` and ``{prefix}_idx`` when present.
        Maps ``{prefix}_{feat}2ptr`` to ``{prefix}_data2ptr``, the name the encoder side reads.

        Returns:
            The selected kwargs; ``{}`` when ``prefix`` is ``None``.
        """
        if prefix is None: return {}
        inputs = {f'{prefix}_{k}': kwargs[f'{prefix}_{k}'] for k in ('attention_mask', 'input_ids', 'idx') if f'{prefix}_{k}' in kwargs}
        ptr_key = f'{prefix}_{feat}2ptr'
        if ptr_key in kwargs:
            inputs[f'{prefix}_data2ptr'] = kwargs[ptr_key]
        return inputs
        

## Configuration

In [395]:
#| export
class MemoryConfig(DistilBertConfig):
    """``DistilBertConfig`` extended with the settings used by the metadata memory.

    Args:
        top_k_metadata: number of memory slots retrieved per query.
        num_metadata: number of rows in the memory table.
        **kwargs: forwarded unchanged to ``DistilBertConfig``.
    """

    def __init__(
        self,
        top_k_metadata:Optional[int] = 5,
        num_metadata:Optional[int] = 100_000,
        **kwargs,
    ):
        # Memory-specific fields are set before delegating to the parent initialiser.
        self.top_k_metadata, self.num_metadata = top_k_metadata, num_metadata
        super().__init__(**kwargs)
    

In [396]:
#| export
class CachewConfig(MemoryConfig):
    """Configuration of the CAW models: memory settings plus encoder-augmentation and loss options.

    Args:
        data_aug_meta_prefix: batch-key prefix of the metadata used to enrich queries (``None`` disables it).
        lbl2data_aug_meta_prefix: batch-key prefix of the metadata used to enrich labels (``None`` disables it).
        num_batch_labels, batch_size, margin, num_negatives, tau, apply_softmax:
            forwarded to the representation loss (``MultiTriplet``).
        calib_margin, calib_num_negatives, calib_tau, calib_apply_softmax:
            forwarded to the calibration loss (``Calibration``).
        calib_loss_weight: multiplier applied to the calibration loss.
        use_calib_loss: add the calibration loss to the total loss.
        use_query_loss: also add the representation loss computed on the non-enriched query.
        use_encoder_parallel: wrap the encoder in ``XCDataParallel`` during ``forward``.
        **kwargs: forwarded to ``MemoryConfig`` / ``DistilBertConfig``.
    """

    def __init__(
        self,
        data_aug_meta_prefix:Optional[str] = None, 
        lbl2data_aug_meta_prefix:Optional[str] = None, 
        
        num_batch_labels:Optional[int] = None,
        batch_size:Optional[int] = None,
        margin:Optional[float] = 0.3,
        num_negatives:Optional[int] = 10,
        tau:Optional[float] = 0.1,
        apply_softmax:Optional[bool] = True,

        calib_margin:Optional[float] = 0.05,
        calib_num_negatives:Optional[int] = 10,
        calib_tau:Optional[float] = 0.1,
        calib_apply_softmax:Optional[bool] = False,
        calib_loss_weight:Optional[float] = 0.1,
        # These two are on/off switches, hence bool (previously mis-annotated as float).
        use_calib_loss:Optional[bool] = False,
        
        use_query_loss:Optional[bool] = True,
        
        use_encoder_parallel:Optional[bool] = True,
        
        **kwargs,
    ):
        self.data_aug_meta_prefix = data_aug_meta_prefix
        self.lbl2data_aug_meta_prefix = lbl2data_aug_meta_prefix

        self.num_batch_labels = num_batch_labels
        self.batch_size = batch_size
        self.margin = margin
        self.num_negatives = num_negatives
        self.tau = tau
        self.apply_softmax = apply_softmax

        self.calib_margin = calib_margin
        self.calib_num_negatives = calib_num_negatives
        self.calib_tau = calib_tau
        self.calib_apply_softmax = calib_apply_softmax
        self.calib_loss_weight = calib_loss_weight
        self.use_calib_loss = use_calib_loss

        self.use_query_loss = use_query_loss

        self.use_encoder_parallel = use_encoder_parallel

        super().__init__(**kwargs)
        

## Memory

In [397]:
#| export
class Memory(nn.Module):
    """Learnable metadata memory that retrieves the memory rows closest to each query.

    The table holds ``config.num_metadata`` embeddings of size ``config.dim``. ``forward``
    scores every row against each query by cosine similarity and embeds the
    ``config.top_k_metadata`` best rows. Optionally it appends the embeddings of
    caller-supplied metadata indices, padded per query.
    """

    def __init__(self, config: "PretrainedConfig"):
        super().__init__()
        self.top_k_metadata = config.top_k_metadata

        self.memory_embeddings = nn.Embedding(config.num_metadata, config.dim)

        # Second table indexed by the same metadata ids; optionally filled with sinusoids.
        self.position_embeddings = nn.Embedding(config.num_metadata, config.dim)
        if config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(
                n_pos=config.num_metadata, dim=config.dim, out=self.position_embeddings.weight
            )

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)

    def set_memory_embeddings(self, embed:torch.Tensor):
        """Copy ``embed`` into the ``(num_metadata, dim)`` memory table without tracking gradients."""
        with torch.no_grad():
            self.memory_embeddings.weight.copy_(embed)

    def _embed_indices(self, indices:torch.Tensor):
        # Memory + position embedding of the given metadata ids, then LayerNorm and dropout.
        embeddings = self.memory_embeddings(indices) + self.position_embeddings(indices)
        return self.dropout(self.LayerNorm(embeddings))

    def align_embeddings(self, embeddings:torch.Tensor, group_lengths:torch.Tensor):
        """Scatter group-concatenated rows into a zero-padded ``(num_groups, max_len, dim)`` tensor.

        Args:
            embeddings: ``(n, dim)`` rows; the rows of each group are contiguous and groups appear in order.
            group_lengths: ``(num_groups,)`` number of rows per group; expected to sum to ``n``.

        Returns:
            ``(output, mask)``: ``output`` has ``embeddings.dtype``; ``mask`` is a float
            ``(num_groups, max_len)`` tensor holding 1.0 at filled slots and 0.0 at padding.
        """
        device = embeddings.device
        # All index arithmetic below mixes both tensors, so keep them on one device.
        group_lengths = group_lengths.to(device)
        n, dim = embeddings.shape
        num_groups = len(group_lengths)
        max_len = int(group_lengths.max()) if num_groups > 0 else 0

        group_ids = torch.repeat_interleave(torch.arange(num_groups, device=device), group_lengths)
        group_start = group_lengths.cumsum(0) - group_lengths  # exclusive prefix sum
        within_idx = torch.arange(n, device=device) - group_start[group_ids]

        output = torch.zeros((num_groups, max_len, dim), dtype=embeddings.dtype, device=device)
        mask = torch.zeros((num_groups, max_len), device=device)
        output[group_ids, within_idx] = embeddings
        mask[group_ids, within_idx] = 1.0

        return output, mask

    def forward(self, input_embeds:torch.Tensor, input_indices:Optional[torch.Tensor]=None, input_data2ptr:Optional[torch.Tensor]=None):
        """Retrieve and embed memory slots for each query.

        Args:
            input_embeds: ``(batch, dim)`` query vectors.
            input_indices: optional flat metadata ids to append per query.
            input_data2ptr: ``(batch,)`` number of ``input_indices`` entries per query; required with ``input_indices``.

        Returns:
            ``(embeddings, mask, scores)``: ``embeddings`` is ``(batch, top_k [+ max_len], dim)``,
            ``mask`` is the matching float 0/1 mask, and ``scores`` is the ``(batch, num_metadata)``
            cosine-similarity matrix.

        Raises:
            ValueError: if ``input_indices`` is given without ``input_data2ptr``.
        """
        assert input_embeds.dim() == 2, f'Input embeddings should be 2-dimensional, but got dim:{input_embeds.dim()}'

        meta_norm = F.normalize(self.memory_embeddings.weight, dim=-1)
        input_norm = F.normalize(input_embeds, dim=-1)

        scores = input_norm@meta_norm.T
        _, indices = torch.topk(scores, self.top_k_metadata, dim=-1)

        embeddings = self._embed_indices(indices)
        mask = torch.ones(embeddings.shape[0], embeddings.shape[1], device=embeddings.device)

        if input_indices is not None:
            if input_data2ptr is None:
                raise ValueError('`input_data2ptr` is required when `input_indices` is given.')
            input_embeddings, input_mask = self.align_embeddings(self._embed_indices(input_indices), input_data2ptr)
            embeddings = torch.cat([embeddings, input_embeddings], dim=1)
            mask = torch.cat([mask, input_mask], dim=1)

        return embeddings, mask, scores
        

In [53]:
# One memory row per category-metadata item of the training block.
config = MemoryConfig(num_metadata=block.train.dset.meta['cat_meta'].n_meta)

In [54]:
m = Memory(config)

In [55]:
# Collect the `cat2data_*` entries of the batch as {'cat2data': {'idx': ..., 'data2ptr': ..., ...}}.
data_aug_meta_prefix = 'cat2data'
meta_kwargs = Parameters.from_data_aug_meta_prefix_for_encoder(data_aug_meta_prefix, **batch)

In [56]:
# Random stand-in for encoder outputs: one `config.dim` vector per query in the batch.
data_repr = torch.randn(batch['data_input_ids'].shape[0], config.dim)

In [57]:
embeddings, mask, scores = m(data_repr, meta_kwargs[data_aug_meta_prefix]['idx'], meta_kwargs[data_aug_meta_prefix]['data2ptr'])

In [58]:
# Second axis = top_k_metadata (5) + max appended metadata per query (3 here, presumably from n_sdata_meta_samples=3).
embeddings.shape, mask.shape, scores.shape

(torch.Size([100, 8, 768]), torch.Size([100, 8]), torch.Size([100, 656086]))

## Combiner

In [398]:
#| export
class CrossCombinerBlock(TransformerBlock):
    """A DistilBERT ``TransformerBlock`` repurposed as a cross-attention layer.

    Queries come from ``x``; keys and values come from ``m``. The attention module,
    the feed-forward network and both LayerNorms are those built by
    ``TransformerBlock.__init__``.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)

    def post_init(self):
        """Run :meth:`_init_weights` over this block and every submodule."""
        for submodule in self.modules():
            self._init_weights(submodule)

    def _init_weights(self, module: nn.Module):
        """Identity-style initialisation.

        ``nn.LayerNorm``: weight filled with 1.0, bias zeroed.
        ``nn.Linear``: weight set with ``torch.nn.init.eye_``, bias (when present) zeroed.
        Any other module is left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear):
            torch.nn.init.eye_(module.weight)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(
        self,
        x: torch.Tensor,
        m: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """Attend from ``x`` to ``m``, then apply the feed-forward sub-layer.

        Args:
            x: query states, (bs, q_len, dim).
            m: states used as both keys and values, (bs, k_len, dim).
            attn_mask: passed as ``mask`` to ``self.attention``.
            head_mask: passed unchanged to ``self.attention``.
            output_attentions: also return the attention weights (bs, n_heads, q_len, k_len).

        Returns:
            ``(hidden,)``, or ``(attention_weights, hidden)`` when ``output_attentions`` is True.

        Raises:
            TypeError: if ``self.attention`` returns a non-tuple while ``output_attentions`` is False.
        """
        attention_result = self.attention(
            query=x, key=m, value=m,
            mask=attn_mask, head_mask=head_mask, output_attentions=output_attentions,
        )

        attention_weights = None
        if output_attentions:
            attended, attention_weights = attention_result
        elif type(attention_result) is tuple:
            attended = attention_result[0]
        else:
            raise TypeError(f"ca_output must be a tuple but it is {type(attention_result)} type")

        # Post-LN residual blocks: attention sub-layer, then feed-forward sub-layer.
        attended = self.sa_layer_norm(attended + x)
        hidden = self.output_layer_norm(self.ffn(attended) + attended)

        return (attention_weights, hidden) if output_attentions else (hidden,)
        

## Encoder

In [399]:
#| export
@dataclass
class EncoderOutput(ModelOutput):
    """Outputs of ``Encoder.forward``; a field stays ``None`` when it was not computed."""
    # Pooled query representation, L2-normalised along dim 1.
    repr: Optional[torch.FloatTensor] = None
    # Memory-enriched query representation, L2-normalised along dim 1; None when no metadata kwargs were given.
    enriched_repr: Optional[torch.FloatTensor] = None
    # Cosine similarity of each query to every memory row (as returned by Memory); None when no metadata kwargs were given.
    meta_scores: Optional[torch.FloatTensor] = None
    

In [400]:
#| export
class Encoder(DistilBertPreTrainedModel):
    """DistilBERT query encoder with a metadata ``Memory`` and a cross-attention combiner.

    ``forward`` always returns the pooled query representation. When the kwargs carry
    metadata for ``data_aug_meta_prefix``, it also returns a memory-enriched
    representation and the memory similarity scores.
    """
    
    config_class = MemoryConfig
    
    def __init__(
        self, 
        config:PretrainedConfig, 
    ):
        super().__init__(config)
        # Creation order fixes submodule/parameter registration order; keep it stable.
        self.distilbert = DistilBertModel(config)
        self.query_head = RepresentationHead(config)
        self.combiner_head = CrossCombinerBlock(config)
        self.enriched_query_head = RepresentationHead(config)

        self.memory = Memory(config)
        
        self.post_init()

    @torch.no_grad()
    def init_heads_to_identity(self):
        """Re-initialise the query, combiner and enriched-query heads through their own ``post_init``."""
        self.query_head.post_init()
        self.combiner_head.post_init()
        self.enriched_query_head.post_init()

    @torch.no_grad()
    def init_combiner_to_last_layer(self):
        """Copy the weights of DistilBERT's last transformer layer into ``combiner_head``.

        Raises:
            AssertionError: if the two state dicts differ in key names or tensor shapes.
        """
        layer_sd = self.distilbert.transformer.layer[-1].state_dict()
        combiner_sd = self.combiner_head.state_dict()

        only_combiner, only_layer = set(combiner_sd) - set(layer_sd), set(layer_sd) - set(combiner_sd)
        assert not only_combiner and not only_layer, \
            f'mismatched keys: combiner-only={sorted(only_combiner)}, layer-only={sorted(only_layer)}'

        for k, dst in combiner_sd.items():
            src = layer_sd[k]
            assert dst.shape == src.shape, f'shape mismatch for {k}: {tuple(dst.shape)} != {tuple(src.shape)}'
            # state_dict() tensors share storage with the parameters, so this writes into combiner_head.
            dst.copy_(src)
            
    @torch.no_grad()
    def set_memory_embeddings(self, embed:torch.Tensor):
        """Overwrite the memory table with ``embed``."""
        self.memory.set_memory_embeddings(embed)
        
    def get_position_embeddings(self) -> nn.Embedding:
        return self.distilbert.get_position_embeddings()
    
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)
    
    def encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        """Run the DistilBERT backbone; extra kwargs are forwarded unchanged."""
        return self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        
    def encode_query(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        """Apply ``query_head`` to token states, pool with ``attention_mask`` and L2-normalise."""
        embed = self.query_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def encode_enriched_query(self, embed:torch.Tensor):
        """Apply ``enriched_query_head`` and L2-normalise along dim 1."""
        return F.normalize(self.enriched_query_head(embed), dim=1)

    def enrich_query_representation(self, data_repr:torch.Tensor, meta_kwargs:Dict):
        """Enrich ``data_repr`` (bs, dim) by cross-attending over its memory slots.

        Args:
            data_repr: query vectors.
            meta_kwargs: may hold ``idx`` / ``data2ptr`` (extra metadata ids and per-query counts);
                missing entries are passed to ``Memory`` as ``None``.

        Returns:
            ``(enriched_data_repr, meta_scores)``.
        """
        meta_repr, meta_mask, meta_scores = self.memory(data_repr, meta_kwargs.get('idx'), meta_kwargs.get('data2ptr'))
        
        # (bs, slots) -> (bs, 1, 1, slots) boolean mask, the layout DistilBERT attention reshapes to.
        meta_mask = meta_mask.view(len(meta_mask), 1, 1, -1).bool()
        # Each query vector becomes a length-1 sequence attending over its memory slots.
        fusion_repr = self.combiner_head(x=data_repr.view(len(data_repr), 1, -1), m=meta_repr, attn_mask=meta_mask)
        fusion_repr = fusion_repr[0].squeeze(dim=1)
        
        # Residual connection before the enriched head.
        enriched_data_repr = self.encode_enriched_query(data_repr + fusion_repr)
        return enriched_data_repr, meta_scores

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        **kwargs
    ):
        """Encode queries and, when metadata for ``data_aug_meta_prefix`` is present, enrich them.

        Returns:
            ``EncoderOutput``; ``enriched_repr`` and ``meta_scores`` are ``None`` without metadata.
        """
        data_o = self.encode(data_input_ids, data_attention_mask)
        data_repr = self.encode_query(data_o[0], data_attention_mask)
        
        enriched_data_repr = meta_scores = None
        # Only metadata collected under exactly `data_aug_meta_prefix` is used; others are ignored.
        meta_kwargs = Parameters.from_data_aug_meta_prefix_for_encoder(data_aug_meta_prefix, **kwargs).get(data_aug_meta_prefix)
        if meta_kwargs:
            enriched_data_repr, meta_scores = self.enrich_query_representation(data_repr, meta_kwargs)
            
        return EncoderOutput(
            repr=data_repr,
            enriched_repr=enriched_data_repr,
            meta_scores=meta_scores
        )
        

In [180]:
config = MemoryConfig(num_metadata=block.train.dset.meta['cat_meta'].n_meta)
# Built directly from the config (no from_pretrained), so all weights are freshly initialised.
m = Encoder(config)

In [181]:
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_data2ptr', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_lbl2ptr', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_data2ptr', 'pcat2lbl_data2ptr'])

In [182]:
data_aug_meta_prefix='cat2data'
# The encoder enriches queries only because the batch carries `cat2data_*` entries.
output = m(**batch, data_aug_meta_prefix=data_aug_meta_prefix)

In [183]:
output

EncoderOutput(repr=tensor([[-0.0538, -0.0170,  0.0591,  ..., -0.0791, -0.0830,  0.0450],
        [-0.0295,  0.0319, -0.0031,  ..., -0.0622, -0.0621,  0.0216],
        [-0.0148,  0.0062, -0.0024,  ..., -0.0583, -0.0546,  0.0029],
        ...,
        [ 0.0536, -0.0087, -0.0208,  ..., -0.0220, -0.0539,  0.0342],
        [-0.0158,  0.0252,  0.0377,  ..., -0.0533, -0.0447,  0.0333],
        [ 0.0298, -0.0146, -0.0188,  ..., -0.0568, -0.0479,  0.0317]],
       grad_fn=<DivBackward0>), enriched_repr=tensor([[ 0.0427, -0.0454, -0.0224,  ..., -0.0174,  0.0331,  0.0002],
        [ 0.0302, -0.0223,  0.0149,  ...,  0.0240, -0.0156,  0.0016],
        [ 0.0192,  0.0120,  0.0273,  ..., -0.0221, -0.0642, -0.0499],
        ...,
        [ 0.0385,  0.0089,  0.0263,  ..., -0.0362,  0.0579, -0.0503],
        [ 0.0260,  0.0138,  0.0007,  ..., -0.0468,  0.0162,  0.0187],
        [-0.0065, -0.0112,  0.0086,  ..., -0.0254,  0.0095,  0.0348]],
       grad_fn=<DivBackward0>), meta_scores=tensor([[ 0.0004, -0.07

## `CAW000`

In [401]:
#| export
@dataclass
class CAWModelOutput(ModelOutput):
    """Outputs of ``CAW000.forward``; a field stays ``None`` when it was not computed."""
    # Sum of the enabled losses; None when no label inputs were passed.
    loss: Optional[torch.FloatTensor] = None
    # Query representation and its memory-enriched variant.
    data_repr: Optional[torch.FloatTensor] = None
    data_enriched_repr: Optional[torch.FloatTensor] = None
    # Label representation and its memory-enriched variant; None when no label inputs were passed.
    lbl2data_repr: Optional[torch.FloatTensor] = None
    lbl2data_enriched_repr: Optional[torch.FloatTensor] = None
    

In [402]:
#| export
class CAW000(nn.Module):
    """Shared training logic of the CAW models: encodes queries and labels and computes the losses.

    Works as a mixin. ``__init__`` calls ``super().__init__(config)``, which needs a
    ``PreTrainedModel`` later in the MRO (as in ``CAW001``). Subclasses must assign
    ``self.encoder``.
    """

    config_class = CachewConfig
    
    def __init__(
        self, 
        config: CachewConfig,
    ):
        super().__init__(config)
        self.config, self.encoder = config, None
        self.rep_loss_fn = MultiTriplet(bsz=config.batch_size, tn_targ=config.num_batch_labels, margin=config.margin, 
                                        n_negatives=config.num_negatives, tau=config.tau, apply_softmax=config.apply_softmax, 
                                        reduce='mean')
        self.cab_loss_fn = Calibration(margin=config.calib_margin, tau=config.calib_tau, n_negatives=config.calib_num_negatives, 
                                       apply_softmax=config.calib_apply_softmax, reduce='mean')
        
    def init_heads_to_identity(self):
        if self.encoder is None: raise ValueError('Encoder not initialized.')
        self.encoder.init_heads_to_identity()

    def init_combiner_to_last_layer(self):
        if self.encoder is None: raise ValueError('Encoder not initialized.')
        self.encoder.init_combiner_to_last_layer()

    def set_memory_embeddings(self, embed:torch.Tensor):
        if self.encoder is None: raise ValueError('Encoder not initialized.')
        self.encoder.set_memory_embeddings(embed)
        
    def compute_loss(self, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        """Representation loss (``MultiTriplet``) between queries and their labels."""
        return self.rep_loss_fn(inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def calibration_loss(self, einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        """Weighted calibration loss between enriched and plain query representations."""
        return self.config.calib_loss_weight * self.cab_loss_fn(einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Encode queries (and labels, when given) and compute the configured losses.

        The main loss uses the enriched query representation when the encoder produced one.
        Otherwise it falls back to the plain representation. In that case the query and
        calibration terms are skipped, since both need an enriched representation.

        Returns:
            ``CAWModelOutput``, or a tuple ``([loss,] data_repr, data_enriched_repr,
            lbl2data_repr, lbl2data_enriched_repr)`` when ``return_dict`` is False.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        encoder = XCDataParallel(module=self.encoder) if self.config.use_encoder_parallel else self.encoder
        
        data_meta_kwargs = Parameters.from_data_aug_meta_prefix_for_feature('data', self.config.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.config.data_aug_meta_prefix, **data_meta_kwargs)
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl_prefix = self.config.lbl2data_aug_meta_prefix
            # Label-metadata pointers are grouped per label (`{prefix}_lbl2ptr`, e.g. `cat2lbl_lbl2ptr`);
            # a `{prefix}_lbl2data2ptr` key is still honoured if a batch provides one.
            lbl_feat = 'lbl2data' if f'{lbl_prefix}_lbl2data2ptr' in kwargs else 'lbl'
            lbl2data_meta_kwargs = Parameters.from_data_aug_meta_prefix_for_feature(lbl_feat, lbl_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=lbl_prefix, **lbl2data_meta_kwargs)

            targets = (lbl2data_o.repr, lbl2data_data2ptr, lbl2data_idx, plbl2data_data2ptr, plbl2data_idx)
            has_enriched = data_o.enriched_repr is not None

            loss = self.compute_loss(data_o.enriched_repr if has_enriched else data_o.repr, *targets)

            # Out-of-place adds keep the autograd graph free of in-place updates.
            if has_enriched and self.config.use_query_loss:
                loss = loss + self.compute_loss(data_o.repr, *targets)

            if has_enriched and self.config.use_calib_loss:
                loss = loss + self.calibration_loss(data_o.enriched_repr, data_o.repr, *targets)
            
        if not return_dict:
            o = (data_o.repr,data_o.enriched_repr,lbl2data_o.repr,lbl2data_o.enriched_repr)
            return ((loss,) + o) if loss is not None else o
        
        return CAWModelOutput(
            loss=loss,
            data_repr=data_o.repr,
            data_enriched_repr=data_o.enriched_repr,
            lbl2data_repr=lbl2data_o.repr,
            lbl2data_enriched_repr=lbl2data_o.enriched_repr,
        )
        

## `CAW001`

In [403]:
#| export
class CAW001(CAW000, DistilBertPreTrainedModel):
    """Concrete CAW model: the ``CAW000`` losses on top of a DistilBERT-based ``Encoder``.

    ``DistilBertPreTrainedModel`` provides ``from_pretrained``/``post_init``; ``CAW000.__init__``
    reaches its initialiser through ``super().__init__(config)``.
    """
    # Class-level flags read elsewhere in xcai (not visible here); presumably mark a retrieval-only model.
    use_generation,use_representation = False,True
    # NOTE(review): marks `encoder.distilbert` as tied; it is aliased as `self.distilbert` in remap_post_init.
    _tied_weights_keys = ["encoder.distilbert"]

    def __init__(self, config):
        super().__init__(config)
        self.encoder = Encoder(config)
        
        self.post_init()
        # The alias is added only after post_init (presumably so initialisation visits the backbone once via `encoder`).
        self.remap_post_init()

    def remap_post_init(self):
        # Expose the encoder's backbone as `self.distilbert`. Assumption: this lets DistilBertPreTrainedModel's
        # base-model-prefix handling map checkpoint keys onto it in from_pretrained. TODO confirm.
        self.distilbert = self.encoder.distilbert
        

### Example

In [404]:
# Example configuration; DistilBERT fields not listed keep their DistilBertConfig defaults.
config = CachewConfig(
    top_k_metadata = 5,
    num_metadata=block.train.dset.meta['cat_meta'].n_meta,  # one memory row per category-metadata item

    data_aug_meta_prefix='cat2data',   # queries are enriched from the `cat2data_*` batch entries
    lbl2data_aug_meta_prefix=None,     # labels are encoded without enrichment

    batch_size=100, 
    num_batch_labels=5000, 
    margin=0.3,
    num_negatives=5,
    tau=0.1,
    apply_softmax=True,

    calib_margin=0.3,  # overrides the CachewConfig default of 0.05
    calib_num_negatives=10,
    calib_tau=0.1,
    calib_apply_softmax=False,
    calib_loss_weight=0.1,
    use_calib_loss=True,

    use_query_loss=True, 
    use_encoder_parallel=False  # call the encoder directly instead of wrapping it in XCDataParallel
)

In [405]:
# Parameters missing from the checkpoint (combiner/enriched heads, memory, ...) are freshly initialised; see the warning below.
model = CAW001.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', config=config)

Some weights of CAW001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.combiner_head.attention.k_lin.bias', 'encoder.combiner_head.attention.k_lin.weight', 'encoder.combiner_head.attention.out_lin.bias', 'encoder.combiner_head.attention.out_lin.weight', 'encoder.combiner_head.attention.q_lin.bias', 'encoder.combiner_head.attention.q_lin.weight', 'encoder.combiner_head.attention.v_lin.bias', 'encoder.combiner_head.attention.v_lin.weight', 'encoder.combiner_head.ffn.lin1.bias', 'encoder.combiner_head.ffn.lin1.weight', 'encoder.combiner_head.ffn.lin2.bias', 'encoder.combiner_head.ffn.lin2.weight', 'encoder.combiner_head.output_layer_norm.bias', 'encoder.combiner_head.output_layer_norm.weight', 'encoder.combiner_head.sa_layer_norm.bias', 'encoder.combiner_head.sa_layer_norm.weight', 'encoder.enriched_query_head.layer_norm.bias', 'encoder.enriched_query_head.layer_norm.weight', 'encoder.enriched_query_he

In [333]:
# NOTE(review): execution counts 333/334 are lower than the from_pretrained cell (405), so in the recorded run
# these inits were applied to an earlier `model` object, not the one used below. Re-run the cells in order.
model.init_heads_to_identity()

In [334]:
model.init_combiner_to_last_layer()

In [406]:
# Full forward pass: query + label encoding and the combined loss.
o = model(**batch)

In [408]:
o

CAWModelOutput(loss=tensor(0.1652, grad_fn=<AddBackward0>), data_repr=tensor([[-0.0091, -0.0404,  0.0069,  ..., -0.0133,  0.0293, -0.0296],
        [ 0.0700, -0.0143, -0.0293,  ...,  0.0230, -0.0322,  0.0585],
        [ 0.0543,  0.0195, -0.0226,  ...,  0.0164,  0.0351, -0.0225],
        ...,
        [ 0.0748, -0.0329, -0.0415,  ...,  0.0313,  0.0020,  0.0383],
        [ 0.0754, -0.0068, -0.0007,  ..., -0.0117, -0.0502,  0.0220],
        [-0.0040,  0.0592, -0.0239,  ..., -0.0101,  0.0522,  0.0024]],
       grad_fn=<DivBackward0>), data_enriched_repr=tensor([[-0.0785, -0.0036,  0.1046,  ...,  0.0229,  0.0140, -0.0448],
        [-0.0200,  0.0137, -0.0136,  ...,  0.0444,  0.0249, -0.0375],
        [-0.0187, -0.0391,  0.0246,  ...,  0.0032,  0.0197,  0.0781],
        ...,
        [-0.0304,  0.0034,  0.0236,  ...,  0.0043,  0.0382, -0.0257],
        [-0.0267, -0.0444, -0.0131,  ...,  0.0010, -0.0422,  0.0187],
        [-0.0179, -0.0321, -0.0272,  ...,  0.0418,  0.0102, -0.0213]],
       grad

In [409]:
# Re-inspect the batch keys (same keys as shown right after loading the batch).
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_data2ptr', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_lbl2ptr', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_data2ptr', 'pcat2lbl_data2ptr'])

## `CAW002`

In [376]:
def align_indices(indices:torch.Tensor, group_lengths:torch.Tensor):
    """Scatter a flat, group-concatenated index tensor into a zero-padded 2-D layout.

    Example: ``indices=[5, 6, 7]`` with ``group_lengths=[2, 0, 1]`` gives
    ``[[5, 6], [0, 0], [7, 0]]`` and mask ``[[1, 1], [0, 0], [1, 0]]``.

    Args:
        indices: ``(n,)`` values; the values of each group are contiguous and groups appear in order.
        group_lengths: ``(num_groups,)`` number of values per group; expected to sum to ``n``.

    Returns:
        ``(output, mask)``: ``output`` is ``(num_groups, max_len)`` in ``indices.dtype`` (padding 0);
        ``mask`` is a float tensor of the same shape with 1.0 at filled slots.
    """
    # Everything lives on the device of `indices` (the previous version read an undefined global).
    device = indices.device
    group_lengths = group_lengths.to(device)
    n, num_groups = len(indices), len(group_lengths)
    max_len = int(group_lengths.max()) if num_groups > 0 else 0

    group_ids = torch.repeat_interleave(torch.arange(num_groups, device=device), group_lengths)
    group_start = group_lengths.cumsum(0) - group_lengths  # exclusive prefix sum
    within_idx = torch.arange(n, device=device) - group_start[group_ids]

    output = torch.zeros((num_groups, max_len), dtype=indices.dtype, device=device)
    mask = torch.zeros((num_groups, max_len), device=device)
    output[group_ids, within_idx] = indices
    mask[group_ids, within_idx] = 1.0

    return output, mask
    

In [377]:
# Pad the positive category ids of each query into a (num_queries, max_positives) grid plus a 0/1 mask.
pos_indices, pos_mask = align_indices(batch['pcat2data_idx'], batch['pcat2data_data2ptr'])

In [392]:
# NOTE(review): in top-to-bottom order `o` is the CAWModelOutput from In [406], which has no `.gather`.
# The recorded output came from a different `o` (count 392 predates 406). Define the score tensor explicitly here.
o.gather(1, pos_indices)

tensor([[-0.0248, -0.0620,  0.0081,  ..., -0.0392, -0.0392, -0.0392],
        [-0.0060,  0.0322,  0.0235,  ..., -0.0391, -0.0391, -0.0391],
        [ 0.0424,  0.0510, -0.0347,  ..., -0.0060, -0.0060, -0.0060],
        ...,
        [-0.0385,  0.0220,  0.0220,  ...,  0.0220,  0.0220,  0.0220],
        [ 0.0434, -0.0090, -0.0090,  ..., -0.0090, -0.0090, -0.0090],
        [ 0.0067,  0.0091, -0.0174,  ..., -0.0174, -0.0174, -0.0174]],
       grad_fn=<GatherBackward0>)

In [384]:
class _Loss(torch.nn.Module):
    def __init__(self, reduction='mean', pad_ind=None):
        super(_Loss, self).__init__()
        self.reduction = reduction
        self.pad_ind = pad_ind

    def _reduce(self, loss):
        if self.reduction == 'none':
            return loss
        elif self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'custom':
            return loss.sum(dim=1).mean()
        else:
            return loss.sum()

    def _mask_at_pad(self, loss):
        """
        Mask the loss at padding index, i.e., make it zero
        """
        if self.pad_ind is not None:
            loss[:, self.pad_ind] = 0.0
        return loss

    def _mask(self, loss, mask=None):
        """
        Mask the loss at padding index, i.e., make it zero
        * Mask should be a boolean array with 1 where loss needs
        to be considered.
        * it'll make it zero where value is 0
        """
        if mask is not None:
            loss = loss.masked_fill(~mask, 0.0)
        return loss
        

In [385]:
class TripletMarginLoss(_Loss):
    """Triplet margin loss with hard-positive / hard-negative mining over a score matrix.

    ``output`` holds similarity scores and ``target`` marks positives with 1. Both are
    presumably (batch, num_candidates); TODO confirm with callers.

    Positives:
        - When ``target`` is square, the diagonal is the single positive per row.
        - Otherwise the ``num_positives`` lowest-scoring positives are used.

    Negatives: the ``num_negatives`` highest-scoring non-positives. Each violating
    triplet is weighted by a softmax over its row of violations.

    Returns:
        The reduced loss, or ``(reduced_loss, mean #violators)`` when ``num_violators`` is truthy.
    """

    def __init__(self, margin=1.0, eps=1.0e-6, reduction='mean',
                 num_positives=3, num_negatives=10,
                 num_violators=False, alpha=0.9):
        super().__init__(reduction)
        self.mx_lim = 100   # sentinel for non-positives when selecting the hardest positives
        self.mn_lim = -100  # sentinel for positives when selecting negatives, and logit for zero-loss entries
        self.alpha = alpha  # stored for API compatibility; not used by forward
        self._eps = eps
        self._margin = margin
        self._reduction = reduction
        self.num_positives = num_positives
        self.num_negatives = num_negatives
        self.num_violators = num_violators

    def forward(self, output, target, *args):
        B = target.size(0)
        # Move target before its first comparison with `output` (both branches below use it).
        target = target.to(output.device)
        if target.size(0) != target.size(1):
            MX_LIM = torch.full_like(output, self.mx_lim)
            sim_p = output.where(target == 1, MX_LIM)
            # Rows with fewer positives than num_positives pick the sentinel, which yields zero loss.
            indices = sim_p.topk(largest=False, dim=1, k=self.num_positives)[1]
            sim_p = sim_p.gather(1, indices)
        else:
            sim_p = output.diagonal().view(B, 1)
        
        MN_LIM = torch.full_like(output, self.mn_lim)

        num_p = sim_p.size(1)
        sim_p = sim_p.view(B, num_p, 1)
        sim_m = MN_LIM.where(target == 1, output)
        indices = sim_m.topk(largest=True, dim=1, k=self.num_negatives)[1]
        sim_n = output.gather(1, indices).unsqueeze(1).repeat_interleave(num_p, dim=1)
        loss = F.relu(sim_n - sim_p + self._margin)
        # Softmax-weight the violations; zero-loss entries get a very negative logit.
        prob = loss.clone()
        prob.masked_fill_(prob == 0, self.mn_lim)
        loss = F.softmax(prob, dim=-1)*loss
        reduced_loss = loss.mean() if self._reduction == "mean" else loss.sum()
        if self.num_violators:
            nnz = torch.sum((loss > 0), axis=1).float().mean()
            return reduced_loss, nnz
        return reduced_loss
            