In [285]:
#| default_exp models.upma

In [286]:
# Development convenience: load IPython's autoreload extension in mode 2.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [287]:
# Pull in nbdev's doc helpers and run nbdev's export step.
# NOTE(review): running the export at the top of the notebook exports whatever is saved on disk — confirm this is intended.
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [381]:
#| export
import torch, torch.nn as nn, re
from dataclasses import dataclass
from torch.nn.parallel import DataParallel
from typing import Optional, Union, Tuple, Any, Dict, Sequence

from transformers import DistilBertConfig, DistilBertModel, PretrainedConfig, PreTrainedModel
from transformers.activations import get_activation
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
from transformers.modeling_outputs import BaseModelOutput, ModelOutput
from transformers.models.distilbert.modeling_distilbert import Embeddings, TransformerBlock
from transformers.models.distilbert.modeling_distilbert import create_sinusoidal_embeddings
from transformers.pytorch_utils import apply_chunking_to_forward
from transformers.utils import logging as hf_logging

from xcai.core import *
from xcai.losses import *
from xcai.learner import XCDataParallel
from xcai.models.modeling_utils import Pooling

## Load data

In [8]:
from xcai.main import *

In [9]:
# Dataset location and the configuration entry used to build the data block.
# NOTE(review): `data_dir` is an absolute, machine-specific path; the notebook will not run elsewhere without editing it.
data_dir = '/Users/suchith720/Projects/data'
config_file = 'wikiseealsotitles'
config_key = 'data_meta'

# NOTE(review): `mname` is not referenced by any other visible cell.
mname = 'sentence-transformers/msmarco-distilbert-base-v4'

In [40]:
# Path of the joblib file handed to `build_block` below.
# `pkl_dir` already ends with '/', so `pkl_file` contains a doubled slash.
pkl_dir = f'{data_dir}/processed/mogicX/'
pkl_file = f'{pkl_dir}/wikiseealsotitles_data-meta_distilbert-base-uncased_sxc.joblib'

In [41]:
# Build the data block from `pkl_file`.
# NOTE(review): `do_build=False` presumably loads the cached file instead of rebuilding it — confirm against `build_block`.
block = build_block(pkl_file, config_file, True, config_key, data_dir=data_dir, 
                    n_sdata_meta_samples=3, do_build=False)

In [208]:
# Fetch a two-example batch (dataset indices 10 and 30) to exercise the code below.
batch = block.train.dset.__getitems__([10, 30])

In [209]:
# List the fields available in the batch; the prefixes (`cat2data`, `cat2lbl`, ...) are what `Parameters` selects on.
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_data2ptr', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_lbl2ptr', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_data2ptr', 'pcat2lbl_data2ptr'])

## Configuration

In [297]:
#| export
class UPMAConfig(DistilBertConfig):
    """DistilBERT configuration extended with metadata-memory and loss hyper-parameters.

    Args:
        num_total_metadata: Number of rows of the metadata embedding table built by the memory module.
        num_input_metadata: Number of metadata items expected per input; also the size of the
            rank-embedding table.
        pad_metadata_idx: Passed as `padding_idx` to the metadata embedding table.
        metadata_dropout: Dropout probability applied inside the memory module.
        memory_injection_layer: 1-based index of the transformer layer right before which the memory
            embeddings are concatenated to the hidden states. `None` selects the last layer (`n_layers`).
        memory_module_name: Key resolved by `get_memory_module`.
        data_aug_meta_prefix: Prefix of the batch fields holding the metadata of the data inputs.
        lbl2data_aug_meta_prefix: Prefix of the batch fields holding the metadata of the label inputs.
        data_enrich: Flag forwarded by `UPA000.forward` to the encoder for data inputs.
        lbl2data_enrich: Flag forwarded by `UPA000.forward` to the encoder for label inputs.
        margin, num_negatives, tau, apply_softmax: Settings of the representation loss.
        calib_margin, calib_num_negatives, calib_tau, calib_apply_softmax: Settings of the calibration loss.
        calib_loss_weight: Multiplier applied to the calibration loss.
        use_calib_loss: Stored on the config; not read by the code in this file.
        use_encoder_parallel: Whether `UPA000.forward` wraps the encoder in `XCDataParallel`.
        **kwargs: Forwarded to `DistilBertConfig`.
    """

    def __init__(
        self,
        num_total_metadata: Optional[int] = None,
        num_input_metadata: Optional[int] = 3,
        pad_metadata_idx: Optional[int] = None,
        metadata_dropout: Optional[float] = 0.1,
        memory_injection_layer: Optional[int] = None,
        memory_module_name: Optional[str] = "embeddings",
        
        data_aug_meta_prefix: Optional[str] = None, 
        lbl2data_aug_meta_prefix: Optional[str] = None,

        data_enrich: Optional[bool] = True,
        lbl2data_enrich: Optional[bool] = True,

        margin: Optional[float] = 0.3,
        num_negatives: Optional[int] = 10,
        tau: Optional[float] = 0.1,
        apply_softmax: Optional[bool] = True,

        calib_margin: Optional[float] = 0.05,
        calib_num_negatives: Optional[int] = 10,
        calib_tau: Optional[float] = 0.1,
        calib_apply_softmax: Optional[bool] = False,
        
        calib_loss_weight: Optional[float] = 0.1,
        use_calib_loss: Optional[bool] = False,

        use_encoder_parallel: Optional[bool] = False,
        **kwargs,
    ):
        store_attr('num_total_metadata,num_input_metadata,pad_metadata_idx,metadata_dropout,memory_module_name')
        store_attr('data_aug_meta_prefix,lbl2data_aug_meta_prefix,data_enrich,lbl2data_enrich')
        store_attr('margin,num_negatives,tau,apply_softmax')
        store_attr('calib_margin,calib_num_negatives,calib_tau,calib_apply_softmax')
        # `use_encoder_parallel` is read by `UPA000.forward`, so it has to live on the config as well.
        store_attr('calib_loss_weight,use_calib_loss,use_encoder_parallel')
        
        super().__init__(**kwargs)
        
        # Resolved after `super().__init__` because the default depends on `n_layers`.
        # An explicitly passed value must be stored too; otherwise the check below fails with AttributeError.
        self.memory_injection_layer = self.n_layers if memory_injection_layer is None else memory_injection_layer
            
        assert self.memory_injection_layer <= self.n_layers, (
            f"Invalid memory injection layer: {self.memory_injection_layer}. "
            f"It must be at most the total number of layers ({self.n_layers})."
        )
        

## Helper functions

In [142]:
#| export
def get_memory_module(name: str):
    """Return the memory-module class registered under `name`.

    Only "embeddings" is known; any other name raises `ValueError`.
    """
    if name != "embeddings":
        raise ValueError(f"Invalid memory module: {name}")
    return UPMAEmbeddingMemory
    

## Outputs

In [289]:
#| export
@dataclass
class UPMAEncoderOutput(BaseModelOutput):
    """`BaseModelOutput` extended with one pooled representation tensor."""
    # Filled by `UPMAEncoder.forward` with the result of `Pooling.mean_pooling`.
    repr: Optional[torch.FloatTensor] = None
    
@dataclass
class UPMAModelOutput(ModelOutput):
    """Container for a training loss together with data and label representations."""
    # Optional loss tensor; `None` when no loss was computed.
    loss: Optional[torch.FloatTensor] = None
    # Representation of the data inputs.
    data_repr: Optional[torch.FloatTensor] = None
    # Representation of the label inputs.
    lbl2data_repr: Optional[torch.FloatTensor] = None
    

## Modeling

### `FFN`

* $\hat{x} = \mathrm{dropout}\big(W_2\,\sigma(W_1 x + b_1) + b_2\big)$, where $\sigma$ is the activation selected by `config.activation`.

In [133]:
#| export
class FFN(nn.Module):
    """Feed-forward block: linear -> activation -> linear -> dropout.

    The activation is looked up from `config.activation`, the dropout probability is
    `config.dropout`, and `config.chunk_size_feed_forward` is handed to
    `apply_chunking_to_forward` together with dimension 1 as the chunking dimension.
    """

    def __init__(
        self,
        config:PretrainedConfig,
        input_dim:int,
        hidden_dim:int,
        output_dim:int,
    ):
        super().__init__()
        # Chunking settings consumed by `forward`.
        self.chunk_size_feed_forward, self.seq_len_dim = config.chunk_size_feed_forward, 1

        # Sub-modules are created in the same order as before so parameter registration is unchanged.
        self.dropout = nn.Dropout(p=config.dropout)
        self.lin1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.lin2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)
        self.activation = get_activation(config.activation)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run `ff_chunk` over `input`, chunked according to the stored chunking settings."""
        # NOTE(review): `apply_chunking_to_forward` lives in `transformers.pytorch_utils`; make sure it is imported.
        chunked = apply_chunking_to_forward(
            self.ff_chunk,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            input,
        )
        return chunked

    def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the two linear layers with the activation in between, then dropout."""
        return self.dropout(self.lin2(self.activation(self.lin1(input))))
        

### `Embedding Memory`

* Let $\hat{\mathcal{A}} = \{a_1, a_2, a_3, \dots, a_n\}$ be the relevant metadata predicted by the linker.

* Let $\hat{\mathcal{S}} = \{s_1, s_2, s_3, \dots, s_n\}$ be the scores of the predicted metadata.

* Let $\hat{\mathcal{R}} = \{1, 2, 3, \dots, n\}$ be the ranks of the predicted metadata.

* Let $\mathcal{K} \in \mathbb{R}^{M \times D}$ hold the $M$ memory items, one per metadata entry.

* Let $\mathcal{P} \in \mathbb{R}^{N \times D}$ hold the $N$ rank (positional) embeddings.

* Rank- and score-aware metadata representation: $x_m = \mathcal{K}(\hat{\mathcal{A}}) + \mathrm{MLP}(\hat{\mathcal{S}}) + \mathcal{P}(\hat{\mathcal{R}})$

In [260]:
#| export
class UPMAEmbeddingMemory(nn.Module):
    """Memory module that turns metadata indices into rank- and score-aware embeddings.

    Each metadata item is embedded through a learned table, summed with an embedding of its rank
    and, when scores are given, with an MLP projection of its score. The sum goes through
    LayerNorm, dropout and an output feed-forward block.
    """
    
    def __init__(
        self, 
        config: PretrainedConfig
    ):
        super().__init__()
        # Read by `forward` to validate the number of metadata items per example.
        self.num_input_metadata = config.num_input_metadata

        self.metadata_embeddings = nn.Embedding(config.num_total_metadata, config.dim, padding_idx=config.pad_metadata_idx)
        self.rank_embeddings = nn.Embedding(config.num_input_metadata, config.dim)
        
        self.score_ffn = FFN(config, input_dim=1, hidden_dim=config.hidden_dim, output_dim=config.dim)
        self.out_ffn = FFN(config, input_dim=config.dim, hidden_dim=config.hidden_dim, output_dim=config.dim)

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.metadata_dropout)
        self.register_buffer(
            "position_ids", torch.arange(config.num_input_metadata).expand((1, -1)), persistent=False
        )
        
    def forward(
        self,
        input_idx: Optional[torch.Tensor],
        input_embeds: Optional[torch.Tensor] = None,
        input_scores: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Build the memory embeddings for a batch.

        Args:
            input_idx: Metadata indices, `(bs, num_input_metadata)`. When given, they are embedded
                and take precedence over `input_embeds`.
            input_embeds: Precomputed metadata embeddings, `(bs, num_input_metadata, dim)`; used
                only when `input_idx` is `None`.
            input_scores: Optional metadata scores, `(bs, num_input_metadata)`.
            position_ids: Optional rank of every metadata item. Defaults to the rank induced by
                `input_scores` (highest score -> rank 0), or to `0..n-1` when no scores are given.
            **kwargs: Accepted and ignored so callers can pass extra metadata fields.

        Returns:
            Tensor of shape `(bs, num_input_metadata, dim)`.

        Raises:
            ValueError: If neither `input_idx` nor `input_embeds` is given, or if the number of
                metadata items differs from `num_input_metadata`.
        """
        if input_idx is not None:
            input_embeds = self.metadata_embeddings(input_idx) # (bs, num_input_metadata, dim)

        if input_embeds is None:
            raise ValueError("You have to specify either input_idx or input_embeds")
            
        if input_embeds.size(1) != self.num_input_metadata:
            raise ValueError(
                f"Invalid input: expected {self.num_input_metadata} metadata items, "
                f"but got {input_embeds.size(1)}."
            )
            
        if position_ids is None:
            if input_scores is None:
                position_ids = self.position_ids[:, :self.num_input_metadata]
            else:
                # argsort gives the sort order; a second argsort turns it into the rank of
                # every item at its own position.
                order = torch.argsort(input_scores, dim=1, descending=True)
                position_ids = torch.argsort(order, dim=1)
            
        rank_embeddings = self.rank_embeddings(position_ids) # (bs, num_input_metadata, dim) or (1, num_input_metadata, dim)

        embeddings = input_embeds + rank_embeddings

        if input_scores is not None:
            # `score_ffn` starts with Linear(1, hidden_dim), so every score needs a trailing feature dim.
            score_features = input_scores.unsqueeze(-1).to(embeddings.dtype) # (bs, num_input_metadata, 1)
            score_embeddings = self.score_ffn(score_features) # (bs, num_input_metadata, dim)
            embeddings = embeddings + score_embeddings
            
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return self.out_ffn(embeddings)
        

## `UPMA Model`

In [280]:
#| export
class UPMAModel(PreTrainedModel):
    """DistilBERT-style encoder whose hidden states can be extended with metadata memory tokens.

    The token embeddings run through the transformer layers as in DistilBERT. Right before the
    layer selected by `config.memory_injection_layer` (1-based), the memory embeddings produced
    by the memory module are concatenated to the hidden states along the sequence dimension, and
    the attention mask is extended to cover the extra positions.
    """
    config: UPMAConfig
    # Explicit binding for transformers versions that do not derive it from the annotation above.
    config_class = UPMAConfig
    load_tf_weights = None
    base_model_prefix = "distilbert"
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.embeddings = Embeddings(config)
        self.memory_module = get_memory_module(config.memory_module_name)(config)
        
        self.n_layers = config.n_layers
        self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.gradient_checkpointing = False

        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self._use_sdpa = config._attn_implementation == "sdpa"

    def _init_weights(self, module: nn.Module):
        """Initialise one sub-module following the DistilBERT scheme."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(
                self.config.max_position_embeddings, self.config.dim, module.position_embeddings.weight
            )
        elif isinstance(module, UPMAEmbeddingMemory) and self.config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(
                self.config.num_input_metadata, self.config.dim, module.rank_embeddings.weight
            )
            
    def get_position_embeddings(self) -> nn.Embedding:
        """Return the token position-embedding table."""
        return self.embeddings.position_embeddings

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """Resize the token position embeddings to `new_num_position_embeddings` rows.

        Growing keeps the old rows and leaves the new ones freshly initialised; shrinking keeps
        the leading rows. With sinusoidal embeddings the table is regenerated instead.
        """
        num_position_embeds_diff = new_num_position_embeddings - self.config.max_position_embeddings

        # no resizing needs to be done if the length stays the same
        if num_position_embeds_diff == 0:
            return

        hf_logging.get_logger(__name__).info(
            f"Setting `config.max_position_embeddings={new_num_position_embeddings}`..."
        )
        self.config.max_position_embeddings = new_num_position_embeddings

        old_position_embeddings_weight = self.embeddings.position_embeddings.weight.clone()

        self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim)

        if self.config.sinusoidal_pos_embds:
            # The table lives under `self.embeddings`; the model has no `position_embeddings` attribute.
            create_sinusoidal_embeddings(
                n_pos=self.config.max_position_embeddings,
                dim=self.config.dim,
                out=self.embeddings.position_embeddings.weight,
            )
        else:
            with torch.no_grad():
                if num_position_embeds_diff > 0:
                    self.embeddings.position_embeddings.weight[:-num_position_embeds_diff] = nn.Parameter(
                        old_position_embeddings_weight
                    )
                else:
                    self.embeddings.position_embeddings.weight = nn.Parameter(
                        old_position_embeddings_weight[:num_position_embeds_diff]
                    )
        # move position_embeddings to correct device
        self.embeddings.position_embeddings.to(self.device)

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the word-embedding table."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings: nn.Embedding):
        """Replace the word-embedding table."""
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune: dict[int, list[list[int]]]):
        """Prune attention heads; `heads_to_prune` maps a layer index to the heads to remove."""
        for layer, heads in heads_to_prune.items():
            # The layers are stored directly on the model; there is no `transformer` sub-module.
            self.layer[layer].attention.prune_heads(heads)

    def _prepare_attention_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        input_shape,
        device,
        dtype,
        use_sdpa_mask: bool,
    ) -> Optional[torch.Tensor]:
        """Convert a 2D padding mask into the form expected by the configured attention backend."""
        if self._use_flash_attention_2:
            return attention_mask if (attention_mask is not None and 0 in attention_mask) else None

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)

        if use_sdpa_mask:
            attention_mask = _prepare_4d_attention_mask_for_sdpa(
                attention_mask, dtype, tgt_len=input_shape[1]
            )
        return attention_mask

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        
        metadata_idx: Optional[torch.Tensor] = None,
        metadata_ids: Optional[torch.Tensor] = None,
        metadata_attention_mask: Optional[torch.Tensor] = None,
        metadata_scores: Optional[torch.Tensor] = None,
        inject_memory: Optional[bool] = True,
        
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutput, tuple[torch.Tensor, ...]]:
        """Encode the inputs, optionally appending metadata memory tokens.

        Args:
            input_ids / inputs_embeds: Token ids or precomputed token embeddings (exactly one of them).
            attention_mask: 2D padding mask for the tokens, `(bs, seq_length)`.
            head_mask: Optional head mask, as in DistilBERT.
            metadata_idx: Metadata indices for the memory module, `(bs, num_input_metadata)`.
            metadata_ids, metadata_attention_mask: Forwarded to the memory module.
            metadata_scores: Optional metadata scores, `(bs, num_input_metadata)`.
            inject_memory: Whether to append memory tokens. Injection is skipped when
                `metadata_idx` is `None`, because there is nothing to build the memory from.
            output_attentions, output_hidden_states, return_dict: Standard transformers switches.

        Returns:
            A `BaseModelOutput` (or tuple). When memory is injected, `last_hidden_state` has
            `seq_length + num_input_metadata` positions, with the memory tokens at the end.
        """
         
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        head_mask_is_none = head_mask is None
        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embeddings = self.embeddings(input_ids, inputs_embeds)  # (bs, seq_length, dim)

        # Without metadata indices the model runs as a plain encoder.
        inject_memory = bool(inject_memory) and metadata_idx is not None
        use_sdpa_mask = self._use_sdpa and head_mask_is_none and not output_attentions

        memory_embeddings, memory_attention_mask = None, None
        if inject_memory:
            memory_embeddings = self.memory_module(
                input_idx=metadata_idx,
                input_ids=metadata_ids,
                input_attention_mask=metadata_attention_mask,
                input_scores=metadata_scores
            ) # (bs, num_input_metadata, dim)

            # Mask used from the injection layer onwards: token padding mask plus the memory slots.
            # NOTE(review): every memory slot is attended to; mask padded metadata here if
            # `pad_metadata_idx` entries should be ignored.
            padding_mask = attention_mask if attention_mask is not None else torch.ones(input_shape, device=device)
            memory_mask = torch.ones(
                memory_embeddings.shape[:2], dtype=padding_mask.dtype, device=padding_mask.device
            )
            extended_mask = torch.cat([padding_mask, memory_mask], dim=1)  # (bs, seq_length + num_input_metadata)
            memory_attention_mask = self._prepare_attention_mask(
                extended_mask, tuple(extended_mask.shape), device, embeddings.dtype, use_sdpa_mask
            )

        attention_mask = self._prepare_attention_mask(
            attention_mask, input_shape, device, embeddings.dtype, use_sdpa_mask
        )
                
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_state = embeddings
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

            if inject_memory and i+1 == self.config.memory_injection_layer:
                hidden_state = torch.cat([hidden_state, memory_embeddings], dim=1)
                # The sequence grew by the memory tokens, so the mask has to grow with it.
                attention_mask = memory_attention_mask
                
            layer_outputs = layer_module(
                hidden_state,
                attention_mask,
                head_mask[i],
                output_attentions,
            )

            hidden_state = layer_outputs[-1]

            if output_attentions:
                if len(layer_outputs) != 2:
                    raise ValueError(f"The length of the layer_outputs should be 2, but it is {len(layer_outputs)}")

                attentions = layer_outputs[0]
                all_attentions = all_attentions + (attentions,)
            else:
                if len(layer_outputs) != 1:
                    raise ValueError(f"The length of the layer_outputs should be 1, but it is {len(layer_outputs)}")

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        if not return_dict:
            return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
            
        return BaseModelOutput(
            last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
        )
        

## Parameters

In [268]:
#| export
class Parameters:
    """Static helpers that pick metadata-related entries out of a batch's keyword arguments."""

    @staticmethod
    def from_data_aug_meta_prefix_for_encoder(prefix:str, **kwargs):
        """Group the fields whose name starts with `prefix` for the encoder.

        A field is kept when its name matches `prefix` followed by anything and ends in
        `_input_ids`, `_attention_mask`, `_data2ptr` or `_idx`. The result maps the part of the
        name before the first underscore to a dict keyed by `metadata_<rest of the name>`.
        A `None` prefix selects nothing.
        """
        grouped = {}
        if prefix is None:
            return grouped

        pattern = f'^{prefix}.*_(input_ids|attention_mask|data2ptr|idx)$'
        for name, value in kwargs.items():
            if not re.match(pattern, name):
                continue
            meta, param = name.split('_', maxsplit=1)
            grouped.setdefault(meta, {})[f"metadata_{param}"] = value
        return grouped

    @staticmethod
    def from_aug_meta_prefix_for_feature(feat:str, prefix:str, **kwargs):
        """Select the `prefix` fields of one feature, renaming its pointer field.

        Copies `<prefix>_attention_mask`, `<prefix>_input_ids` and `<prefix>_idx` when present and
        exposes `<prefix>_<feat>2ptr` under the name `<prefix>_data2ptr`.
        """
        selected = {}
        for suffix in ('attention_mask', 'input_ids', 'idx'):
            name = f'{prefix}_{suffix}'
            if name in kwargs:
                selected[name] = kwargs[name]

        ptr_name = f'{prefix}_{feat}2ptr'
        if prefix is not None and ptr_name in kwargs:
            selected[f'{prefix}_data2ptr'] = kwargs[ptr_name]
        return selected

    @staticmethod
    def from_aug_meta_prefix_for_loss(feat:str, prefix:str, **kwargs):
        """Collect index and pointer fields for `prefix` and its `p`-prefixed twin.

        The result maps `<prefix>` and `p<prefix>` to dicts holding `idx` and `data2ptr`
        (the latter taken from `<name>_<feat>2ptr`), for whichever of those fields exist.
        """
        picked = {}
        for name in (f'{prefix}_idx', f'p{prefix}_idx'):
            if name in kwargs:
                picked[name] = kwargs[name]

        if prefix is not None:
            for lead in (prefix, f'p{prefix}'):
                ptr_name = f'{lead}_{feat}2ptr'
                if ptr_name in kwargs:
                    picked[f'{lead}_data2ptr'] = kwargs[ptr_name]

        grouped = {}
        for name, value in picked.items():
            meta, param = name.split('_', maxsplit=1)
            grouped.setdefault(meta, {})[param] = value
        return grouped
        

### Examples

In [264]:
# Select the `cat2lbl` fields of the batch; `cat2lbl_lbl2ptr` is exposed as `cat2lbl_data2ptr`.
kwargs = Parameters.from_aug_meta_prefix_for_feature('lbl', 'cat2lbl', **batch)
kwargs.keys()

dict_keys(['cat2lbl_attention_mask', 'cat2lbl_input_ids', 'cat2lbl_idx', 'cat2lbl_data2ptr'])

In [265]:
# Regroup the selected fields under their prefix, renamed to `metadata_*`.
Parameters.from_data_aug_meta_prefix_for_encoder('cat2lbl', **kwargs)

{'cat2lbl': {'metadata_attention_mask': tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0]]),
  'metadata_input_ids': tensor([[  101, 17151, 12412,  4155,   102,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0],
          [  101, 12943, 23296, 21823, 19833,  3512,  4155,   102,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0],
          [  101,  7939,  2050, 29652,  4710, 18595, 20240,  2319,  4155,   102,
               0,     0,     0,     0,     0,     0, 

## `UPMAEncoder`

In [374]:
class UPMAEncoder(UPMAModel):
    """`UPMAModel` wrapper that selects the metadata inputs by prefix and mean-pools the output."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str = 'distilbert-base-uncased', config: Optional[PretrainedConfig] = None):
        """Build an encoder from `config` and copy matching weights from a DistilBERT checkpoint.

        Args:
            pretrained_model_name_or_path: DistilBERT checkpoint to copy the weights from.
            config: Configuration of the new encoder. When omitted, a module-level variable named
                `config` is used, which is what the argument-less original did.

        Returns:
            `(model, new_keys)`, where `new_keys` is the set of state-dict keys of the new model
            that have no counterpart in the checkpoint and therefore keep their initial values.

        Raises:
            ValueError: If no configuration is given and no module-level `config` exists.
        """
        if config is None:
            # Backward compatibility with callers that rely on a global `config`.
            config = globals().get('config')
            if config is None:
                raise ValueError("A `config` is required to build the encoder.")

        src_model = DistilBertModel.from_pretrained(pretrained_model_name_or_path)
        targ_model = cls(config)

        src_sd, targ_sd = src_model.state_dict(), targ_model.state_dict()
        src_keys, targ_keys = set(src_sd.keys()), set(targ_sd.keys())
        
        # Keys with identical names in both models.
        for k in src_keys.intersection(targ_keys):
            assert targ_sd[k].shape == src_sd[k].shape, (
                f"Shape mismatch at key '{k}'. "
                f"Expected {targ_sd[k].shape}, but got {src_sd[k].shape} in source state_dict."
            )
            targ_sd[k].copy_(src_sd[k])

        # The checkpoint stores its layers under `transformer.`, the target directly under `layer.`.
        diff_keys = targ_keys.difference(src_keys)
        transformer_keys = [k for k in src_keys if k.startswith("transformer")]
        for k in transformer_keys:
            targ_k = k.split('.', maxsplit=1)[1]
            
            assert targ_k in targ_sd, (
                f"Unexpected key '{targ_k}' encountered, not found in target state_dict."
            )
            
            assert targ_sd[targ_k].shape == src_sd[k].shape, (
                f"Shape mismatch at key '{k}'. "
                f"Expected {targ_sd[targ_k].shape}, but got {src_sd[k].shape} in source state_dict."
            )
            
            targ_sd[targ_k].copy_(src_sd[k])
            diff_keys.remove(targ_k)
        return targ_model, diff_keys
    
    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_inject_memory: Optional[bool]=True,
        **kwargs
    ):
        """Encode the inputs and return the model output together with a mean-pooled representation.

        Args:
            data_input_ids: Token ids of the inputs.
            data_attention_mask: Padding mask of the inputs.
            data_aug_meta_prefix: Prefix of the metadata fields in `kwargs` to feed to the memory module.
            data_inject_memory: Whether the memory tokens should be injected.
            **kwargs: Batch fields; only those selected by `data_aug_meta_prefix` are used.
        """
        meta_kwargs = Parameters.from_data_aug_meta_prefix_for_encoder(data_aug_meta_prefix, **kwargs)
        meta_kwargs = dict(meta_kwargs.get(data_aug_meta_prefix, dict()))

        # `Parameters` emits `metadata_input_ids`, while `UPMAModel.forward` calls that argument `metadata_ids`.
        if 'metadata_input_ids' in meta_kwargs:
            meta_kwargs['metadata_ids'] = meta_kwargs.pop('metadata_input_ids')

        # `UPMAModel.forward` has no `**kwargs`, so only the arguments it declares may be forwarded.
        # NOTE(review): `metadata_data2ptr` is dropped here; confirm `metadata_idx` is already
        # shaped (bs, num_input_metadata) and does not need regrouping with the pointer.
        accepted = ('metadata_idx', 'metadata_ids', 'metadata_attention_mask', 'metadata_scores')
        meta_kwargs = {k: meta_kwargs[k] for k in accepted if k in meta_kwargs}
        
        output = super().forward(
            input_ids=data_input_ids, 
            attention_mask=data_attention_mask,
            inject_memory=data_inject_memory,
            **meta_kwargs
        )
        
        # Memory tokens, when injected, are appended after the text tokens. Pool only over the
        # positions described by `data_attention_mask` so the two shapes agree.
        # NOTE: Pooling can be done according to modified attention mask.
        token_states = output[0][:, :data_attention_mask.size(1)]
        data_repr = Pooling.mean_pooling(token_states, data_attention_mask)
        return UPMAEncoderOutput(repr=data_repr, **output)
        

### Example

In [376]:
# Example configuration for the encoder; the metadata table is sized from the `cat_meta` metadata of the training set.
# NOTE(review): the same configuration is built again, verbatim, in the `UPA000` example further down.
config = UPMAConfig(
    num_total_metadata=block.train.dset.meta['cat_meta'].n_meta,
    num_input_metadata = 5,
    pad_metadata_idx=None,
    metadata_dropout=0.1,
    memory_injection_layer=None,
    memory_module_name="embeddings",

    data_aug_meta_prefix="cat2data",
    lbl2data_aug_meta_prefix="cat2lbl",

    data_enrich=True,
    lbl2data_enrich=False,

    margin=0.3,
    num_negatives=5,
    tau=0.1,
    apply_softmax=True,

    calib_margin=0.3,
    calib_num_negatives=10,
    calib_tau=0.1,
    calib_apply_softmax=False,
    
    calib_loss_weight=0.1,
    use_calib_loss=True,
    
    use_encoder_parallel=False
)

In [384]:
# Build the encoder and copy the DistilBERT weights into it.
# NOTE(review): called without arguments, this relies on the global `config` defined in the previous cell.
model, new_keys = UPMAEncoder.from_pretrained()

loading configuration file config.json from cache at /Users/suchith720/.cache/huggingface/hub/models--distilbert-base-uncased/snapshots/12040accade4e8a0f71eabdb258fecc2e7e948be/config.json
Model config DistilBertConfig {
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.49.0",
  "vocab_size": 30522
}

loading weights file model.safetensors from cache at /Users/suchith720/.cache/huggingface/hub/models--distilbert-base-uncased/snapshots/12040accade4e8a0f71eabdb258fecc2e7e948be/model.safetensors
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertM

In [385]:
# Parameters that exist only in the new model (not copied from the checkpoint).
new_keys

{'memory_module.LayerNorm.bias',
 'memory_module.LayerNorm.weight',
 'memory_module.metadata_embeddings.weight',
 'memory_module.out_ffn.lin1.bias',
 'memory_module.out_ffn.lin1.weight',
 'memory_module.out_ffn.lin2.bias',
 'memory_module.out_ffn.lin2.weight',
 'memory_module.rank_embeddings.weight',
 'memory_module.score_ffn.lin1.bias',
 'memory_module.score_ffn.lin1.weight',
 'memory_module.score_ffn.lin2.bias',
 'memory_module.score_ffn.lin2.weight'}

In [387]:
# Inspect the metadata embedding table created from `num_total_metadata` and `dim`.
model.memory_module.metadata_embeddings

Embedding(656086, 768)

In [391]:
# NOTE(review): `linker_dset` is not used by any other visible cell.
linker_dset = block.linker_dset('cat_meta')

In [433]:
# NOTE(review): `meta_dset` is only assigned in a later cell; on a fresh kernel this cell raises NameError.
meta_dset

<xcai.sdata.SMainXCDataset at 0x16872ae40>

In [434]:
# Shorthand for the training dataset used in the cells below.
dset = block.train.dset

In [435]:
# Collect the attribute names of `dset.data`, minus dunder names and the ones listed in `_keys_to_ignore`.
_keys_to_ignore = ["data_info", "data_lbl", "lbl_info", "data_lbl_filterer", "curr_data_lbl", "data_lbl_scores"]
args = [o for o in vars(dset.data).keys() if not o.startswith('__') and o not in _keys_to_ignore]

In [436]:
# Turn the attribute names into a name -> value mapping, letting `kwargs` override the dataset's own values.
# NOTE(review): the only visible assignment of `kwargs` is the `Parameters` example above, so this
# depends on leftover notebook state; `args` is also rebound from a list to a dict.
args = {k: kwargs.get(k, getattr(dset.data, k)) for k in args}

In [437]:
# Build a dataset of the same type as `dset.data`, with the `cat_meta` metadata as its `data_info`.
meta_dset = type(dset.data)(data_info=dset.meta['cat_meta'].meta_info, **args)

## `UPA000`

In [300]:
#| export
class UPA000(PreTrainedModel):
    """Two-tower model: one shared `UPMAEncoder` for data and labels, trained with a multi-triplet loss."""
    
    def __init__(
        self, 
        config: UPMAConfig,
    ):
        super().__init__(config)
        self.config, self.encoder = config, UPMAEncoder(config)
        self.rep_loss_fn = MultiTriplet(margin=config.margin, n_negatives=config.num_negatives, tau=config.tau, 
                                        apply_softmax=config.apply_softmax, reduce='mean')
        self.cab_loss_fn = Calibration(margin=config.calib_margin, tau=config.calib_tau, n_negatives=config.calib_num_negatives, 
                                       apply_softmax=config.calib_apply_softmax, reduce='mean')
        
    # NOTE(review): the three delegating methods below call encoder methods that are not
    # defined on `UPMAEncoder` in this file — confirm they exist before relying on them.
    def init_heads_to_identity(self):
        """Delegate to the encoder's `init_heads_to_identity`."""
        if self.encoder is None: raise ValueError('Encoder not initialized.')
        self.encoder.init_heads_to_identity()

    def init_combiner_to_last_layer(self):
        """Delegate to the encoder's `init_combiner_to_last_layer`."""
        if self.encoder is None: raise ValueError('Encoder not initialized.')
        self.encoder.init_combiner_to_last_layer()

    def set_memory_embeddings(self, embed:torch.Tensor):
        """Delegate to the encoder's `set_memory_embeddings`."""
        if self.encoder is None: raise ValueError('Encoder not initialized.')
        self.encoder.set_memory_embeddings(embed)
        
    def compute_loss(
        self, 
        inp_repr: Optional[torch.Tensor] = None, 
        targ_repr: Optional[torch.Tensor] = None,
        targ_ptr: Optional[torch.Tensor] = None, 
        targ_idx: Optional[torch.Tensor] = None,
        ptarg_ptr: Optional[torch.Tensor] = None,
        ptarg_idx: Optional[torch.Tensor] = None
    ):
        """Representation loss between input and target representations."""
        return self.rep_loss_fn(inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def calibration_loss(
        self, 
        einp_repr: Optional[torch.Tensor] = None, 
        inp_repr: Optional[torch.Tensor] = None,
        targ_repr: Optional[torch.Tensor] = None,
        targ_ptr: Optional[torch.Tensor] = None,
        targ_idx: Optional[torch.Tensor] = None,
        ptarg_ptr: Optional[torch.Tensor] = None,
        ptarg_idx: Optional[torch.Tensor] = None
    ):
        """Calibration loss scaled by `config.calib_loss_weight`."""
        return self.config.calib_loss_weight * self.cab_loss_fn(einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def forward(
        self,
        data_input_ids: Optional[torch.Tensor] = None,
        data_attention_mask: Optional[torch.Tensor] = None,
        lbl2data_data2ptr: Optional[torch.Tensor] = None,
        lbl2data_idx: Optional[torch.Tensor] = None,
        lbl2data_input_ids: Optional[torch.Tensor] = None,
        lbl2data_attention_mask: Optional[torch.Tensor] = None,
        plbl2data_data2ptr: Optional[torch.Tensor] = None,
        plbl2data_idx: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Encode data (and labels, when given) and compute the representation loss.

        Args:
            data_input_ids, data_attention_mask: Tokenised data inputs.
            lbl2data_input_ids, lbl2data_attention_mask: Tokenised label inputs; when absent, no
                loss is computed and the label representation is `None`.
            lbl2data_data2ptr, lbl2data_idx, plbl2data_data2ptr, plbl2data_idx: Forwarded to the
                representation loss.
            output_attentions, output_hidden_states: Accepted for interface compatibility; unused.
            return_dict: Return a `UPMAModelOutput` instead of a tuple.
            **kwargs: Remaining batch fields; the metadata fields selected by the configured
                prefixes are forwarded to the encoder.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # `getattr` keeps this working with configs that never stored the flag.
        use_parallel = getattr(self.config, 'use_encoder_parallel', False)
        encoder = XCDataParallel(module=self.encoder) if use_parallel else self.encoder
        
        # The encoder names its switch `data_inject_memory`; passing `data_enrich` would only land
        # in its `**kwargs` and be ignored.
        data_meta_kwargs = Parameters.from_aug_meta_prefix_for_feature('data', self.config.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.config.data_aug_meta_prefix,
                         data_inject_memory=self.config.data_enrich, **data_meta_kwargs)
        
        loss = None; lbl2data_o = UPMAEncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_aug_meta_prefix_for_feature('lbl', self.config.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.config.lbl2data_aug_meta_prefix,
                                 data_inject_memory=self.config.lbl2data_enrich, **lbl2data_meta_kwargs)
            
            # `UPMAEncoderOutput` exposes the pooled representation as `repr`.
            loss = self.compute_loss(data_o.repr, lbl2data_o.repr,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)
            
        if not return_dict:
            o = (data_o.repr,lbl2data_o.repr)
            return ((loss,) + o) if loss is not None else o

        return UPMAModelOutput(
            loss=loss,
            data_repr=data_o.repr,
            lbl2data_repr=lbl2data_o.repr,
        )
        

### Example

In [301]:
# Configuration for the `UPA000` example.
# NOTE(review): identical to the configuration built in the `UPMAEncoder` example above; this cell rebinds `config`.
config = UPMAConfig(
    num_total_metadata=block.train.dset.meta['cat_meta'].n_meta,
    num_input_metadata = 5,
    pad_metadata_idx=None,
    metadata_dropout=0.1,
    memory_injection_layer=None,
    memory_module_name="embeddings",

    data_aug_meta_prefix="cat2data",
    lbl2data_aug_meta_prefix="cat2lbl",

    data_enrich=True,
    lbl2data_enrich=False,

    margin=0.3,
    num_negatives=5,
    tau=0.1,
    apply_softmax=True,

    calib_margin=0.3,
    calib_num_negatives=10,
    calib_tau=0.1,
    calib_apply_softmax=False,
    
    calib_loss_weight=0.1,
    use_calib_loss=True,
    
    use_encoder_parallel=False
)

In [302]:
# Instantiate the full model; this rebinds `model`, which previously held the `UPMAEncoder` instance.
model = UPA000(config)