In [2]:
#| default_exp models.upma

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [5]:
#| export
import torch, torch.nn as nn, re, os
from tqdm.auto import tqdm
from dataclasses import dataclass
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader
from typing import Optional, Union, Tuple, Any, Dict, Sequence

from transformers import DistilBertConfig, DistilBertModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput, ModelOutput
from transformers.models.distilbert.modeling_distilbert import Embeddings, TransformerBlock, create_sinusoidal_embeddings
from transformers.activations import get_activation
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
from transformers.pytorch_utils import apply_chunking_to_forward

from xcai.core import *
from xcai.losses import *
from xcai.data import MainXCDataset
from xcai.sdata import SMainXCDataset, identity_collate_fn
from xcai.learner import XCDataParallel
from xcai.models.modeling_utils import Pooling

## Load data

In [6]:
from xcai.main import *

In [7]:
data_dir = '/Users/suchith720/Projects/data'
config_file = 'wikiseealsotitles'
config_key = 'data_meta'

mname = 'sentence-transformers/msmarco-distilbert-base-v4'

In [8]:
pkl_dir = f'{data_dir}/processed/mogicX/'
pkl_file = f'{pkl_dir}/wikiseealsotitles_data-meta_distilbert-base-uncased_sxc.joblib'

In [9]:
block = build_block(pkl_file, config_file, True, config_key, data_dir=data_dir, 
                    n_sdata_meta_samples=3, do_build=False)

In [10]:
batch = block.train.dset.__getitems__([10, 30])
batch = block.train.dset.__getitems__([10, 200, 500])

In [11]:
batch.keys()

dict_keys(['data_idx', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_data2ptr', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'pcat2lbl_idx', 'pcat2lbl_lbl2ptr', 'cat2lbl_idx', 'cat2lbl_lbl2ptr', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_data2ptr', 'pcat2lbl_data2ptr'])

## Configuration

In [105]:
#| export
class UPMAConfig(DistilBertConfig):

    def __init__(
        self,
        num_total_metadata: Optional[int] = None,
        num_input_metadata: Optional[int] = 3,
        metadata_dropout: Optional[float] = 0.1,
        memory_injection_layer: Optional[int] = None,
        memory_module_name: Optional[str] = "embeddings",
        
        data_aug_meta_prefix: Optional[str] = None, 
        lbl2data_aug_meta_prefix: Optional[str] = None,

        data_inject_memory: Optional[bool] = True,
        lbl2data_inject_memory: Optional[bool] = True,
        data_repr_pooling: Optional[bool] = True,

        margin: Optional[float] = 0.3,
        num_negatives: Optional[int] = 10,
        tau: Optional[float] = 0.1,
        apply_softmax: Optional[bool] = True,

        calib_margin: Optional[float] = 0.05,
        calib_num_negatives: Optional[int] = 10,
        calib_tau: Optional[float] = 0.1,
        calib_apply_softmax: Optional[bool] = False,
        
        calib_loss_weight: Optional[float] = 0.1,
        use_calib_loss: Optional[bool] = False,

        use_encoder_parallel: Optional[bool] = False,
        
        initialize_memory_embeddings_from_injection_layer_mean: Optional[bool] = True,
        metadata_embedding_file: Optional[str] = None,
        
        **kwargs,
    ):
        store_attr('num_total_metadata,num_input_metadata,metadata_dropout,memory_module_name')
        store_attr('data_aug_meta_prefix,lbl2data_aug_meta_prefix')
        store_attr('data_inject_memory,lbl2data_inject_memory,data_repr_pooling')
        store_attr('margin,num_negatives,tau,apply_softmax')
        store_attr('calib_margin,calib_num_negatives,calib_tau,calib_apply_softmax')
        store_attr('calib_loss_weight,use_calib_loss,use_encoder_parallel')
        store_attr('initialize_memory_embeddings_from_injection_layer_mean,metadata_embedding_file')
        
        super().__init__(**kwargs)
        
        if memory_injection_layer is None: self.memory_injection_layer = self.n_layers
            
        assert self.memory_injection_layer <= self.n_layers, (
            f"Invalid memory injection layer: {self.memory_injection_layer}. "
            f"it must be less than the total number of layers ({self.n_layers})."
        )
        

### Example

In [106]:
config = UPMAConfig(
    num_total_metadata = block.train.dset.meta['cat_meta'].n_meta,
    num_input_metadata = 3,
    metadata_dropout = 0.1,
    memory_injection_layer = None,
    memory_module_name = "embeddings",

    data_aug_meta_prefix="cat2data",
    lbl2data_aug_meta_prefix="cat2lbl",

    data_inject_memory=True,
    lbl2data_inject_memory=False,
    data_repr_pooling=True,

    margin=0.3,
    num_negatives=5,
    tau=0.1,
    apply_softmax=True,

    calib_margin=0.3,
    calib_num_negatives=10,
    calib_tau=0.1,
    calib_apply_softmax=False,
    
    calib_loss_weight=0.1,
    use_calib_loss=True,
    
    use_encoder_parallel=False
)

In [108]:
config.data_inject_memory, config.lbl2data_inject_memory

(True, False)

## Helper functions

In [15]:
#| export
def get_memory_module(name: str):
    if name == "embeddings": return UPMAEmbeddingMemory
    else: raise ValueError(f"Invalid memory module: {name}")

def align_tensor(tensor:torch.Tensor, indptr:torch.Tensor, pad_tok:Optional[Union[int,float]]=0):
    tensor_shape = tensor.shape
    r, c = len(indptr), indptr.max()

    row_idx = torch.repeat_interleave(torch.arange(r, device=tensor.device), indptr)
    indptr = torch.cat([indptr.new_tensor([0]), indptr.cumsum(dim=0)[:-1]], dim=0)
    within_idx = torch.arange(tensor_shape[0], device=indptr.device) - indptr[row_idx]

    output = torch.full((r, c, *tensor_shape[1:]), pad_tok, device=tensor.device, dtype=tensor.dtype)
    mask = torch.zeros((r, c), device=tensor.device)

    output[row_idx, within_idx] = tensor
    mask[row_idx, within_idx] = 1.0

    return output, mask

def alignment_mask(indptr:torch.Tensor):
    n, r, c = indptr.sum(), len(indptr), indptr.max()

    row_idx = torch.repeat_interleave(torch.arange(r, device=indptr.device), indptr)
    indptr = torch.cat([indptr.new_tensor([0]), indptr.cumsum(dim=0)[:-1]], dim=0)
    within_idx = torch.arange(n, device=indptr.device) - indptr[row_idx]

    mask = torch.zeros((r, c), device=indptr.device, dtype=torch.int64)
    mask[row_idx, within_idx] = 1
    return mask


### Examples

In [60]:
t = torch.randint(100, size=(5, 3))
indptr = torch.tensor([1, 1, 0, 3])

In [61]:
o = align_tensor(t, indptr)

In [62]:
o

(tensor([[[ 6, 84, 31],
          [ 0,  0,  0],
          [ 0,  0,  0]],
 
         [[16, 12, 85],
          [ 0,  0,  0],
          [ 0,  0,  0]],
 
         [[ 0,  0,  0],
          [ 0,  0,  0],
          [ 0,  0,  0]],
 
         [[95, 29, 80],
          [82, 45, 10],
          [81, 43, 63]]]),
 tensor([[1., 0., 0.],
         [1., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.]]))

## Outputs

In [124]:
#| export
@dataclass
class UPMAEncoderOutput(BaseModelOutput):
    repr: Optional[torch.FloatTensor] = None
    
@dataclass
class UPMAModelOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    data_repr: Optional[torch.FloatTensor] = None
    lbl2data_repr: Optional[torch.FloatTensor] = None
    

## Modeling

### `FFN`

* $\hat{x} = dropout(W_2 * max(0, W_1 * x))$

In [17]:
#| export
class FFN(nn.Module):
    def __init__(
        self, 
        config:PretrainedConfig,
        input_dim:int,
        hidden_dim:int,
        output_dim:int,
    ):
        super().__init__()
        self.config = config
        self.dropout = nn.Dropout(p=config.dropout)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.lin1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.lin2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)
        self.activation = get_activation(config.activation)
        
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)

    def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:
        x = self.lin1(input)
        x = self.activation(x)
        x = self.lin2(x)
        x = self.dropout(x)
        return x
        

### `Embedding Memory`

* $\hat{\mathcal{A}} = \{a_1, a_2, a_3, \dots a_n\}$ be relevant metadata predicted by the linker.

* $\hat{\mathcal{S}} = \{s_1, s_2, s_3, \dots s_n\}$ be scores of the predicted metadata.

* $\hat{\mathcal{R}} = \{1, 2, 3, \dots n\}$ be rank for the predicted metadata.

* $\mathcal{K} \in R^{M \times D}$ be the $M$ memory items for each metadata.

* $\mathcal{P} \in R^{N \times D}$ be the $N$ positional embeddings.

* Rank and score aware metadata representation: $x_m = \mathcal{K}(\hat{\mathcal{A}}) + MLP(\hat{\mathcal{S}}) + \mathcal{P}(\hat{\mathcal{R}})$

In [18]:
#| export
class UPMAEmbeddingMemory(nn.Module):
    
    def __init__(
        self, 
        config: PretrainedConfig
    ):
        super().__init__()
        self.config = config
        self.metadata_embeddings = nn.Embedding(config.num_total_metadata+1, config.dim, padding_idx=config.num_total_metadata)
        self.rank_embeddings = nn.Embedding(config.num_input_metadata, config.dim)
        
        self.score_ffn = FFN(config, input_dim=1, hidden_dim=config.hidden_dim, output_dim=config.dim)
        self.out_ffn = FFN(config, input_dim=config.dim, hidden_dim=config.hidden_dim, output_dim=config.dim)

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.metadata_dropout)
        self.register_buffer(
            "position_ids", torch.arange(config.num_input_metadata).expand((1, -1)), persistent=False
        )
        
    def get_metadata_embeddings(self) -> torch.Tensor:
        return self.metadata_embeddings.weight

    def set_metadata_embeddings(self, new_embeddings: torch.Tensor):
        self.metadata_embeddings.weight.copy_(new_embeddings)

    def get_rank_embeddings(self) -> torch.Tensor:
        return self.rank_embeddings.weight

    def set_rank_embeddings(self, new_embeddings: torch.Tensor):
        self.rank_embeddings.weight.copy_(new_embeddings)
        
    def forward(
        self,
        input_idx: torch.Tensor,
        embeds: Optional[torch.Tensor] = None,
        scores: Optional[torch.Tensor] = None,
        indptr: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # `input_idx`: (total_input_metadata)
        # `scores`: (total_input_metadata)
        
        if input_idx is not None:
            input_idx, mask = align_tensor(input_idx, indptr, pad_tok=self.config.num_total_metadata) # (bs, num_input_metadata)
            embeds = self.metadata_embeddings(input_idx) # (bs, num_input_metadata, dim)
        else:
            assert embeds is not None, "Invalid input: both `input_idx` and `embeds` cannot be None." 
            embeds, mask = align_tensor(embeds, indptr)
            
        if embeds.size(1) != self.config.num_input_metadata:
            raise ValueError(
                f"Invalid input: expected {self.config.num_input_metadata} metadata items, "
                f"but got {embeds.size(1)}."
            )

        if scores is not None:
            scores, mask = align_tensor(scores, indptr) # (bs, num_input_metadata)
            
        if position_ids is None:
            position_ids = (
                self.position_ids[:, :self.config.num_input_metadata]
                if scores is None else 
                torch.argsort(scores, dim=1, descending=True)
            )
            
        rank_embeddings = self.rank_embeddings(position_ids) # (bs, num_input_metadata, dim) or (1, num_input_metadata, dim)

        embeddings = embeds + rank_embeddings

        if scores is not None:
            score_embeddings = self.score_ffn(scores.unsqueeze(2)) # (bs, num_input_metadata, dim)
            embeddings = embeddings + score_embeddings
            
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return self.out_ffn(embeddings), mask # (bs, num_input_metadata, dim), (bs, num_input_metadata)
        

#### Example

In [139]:
model = UPMAEmbeddingMemory(config)

In [141]:
inputs = {
    'input_idx': torch.randint(config.num_total_metadata, size=(5,)),
    'scores': torch.rand((5,)),
    'indptr': torch.tensor([1, 1, 0, 3], dtype=torch.int64),
}

In [143]:
embeds, mask = model(**inputs)

## Parameters

In [19]:
#| export
class Parameters:
    
    @staticmethod
    def from_data_aug_meta_prefix_for_encoder(prefix:str, **kwargs):
        inputs = {}
        args = [arg for arg in kwargs if prefix is not None and re.match(f'^{prefix}.*_(input_ids|attention_mask|data2ptr|idx)$', arg)]
        for arg in args:
            meta,param = arg.split('_', maxsplit=1)
            inputs.setdefault(meta, {})[f"metadata_{param}"] = kwargs[arg]
        return inputs
    
    @staticmethod
    def from_aug_meta_prefix_for_feature(feat:str, prefix:str, **kwargs):
        keys = ['attention_mask', 'input_ids', 'idx']        
        inputs = {f'{prefix}_{k}': kwargs[f'{prefix}_{k}'] for k in keys if f'{prefix}_{k}' in kwargs}
        if prefix is not None and f'{prefix}_{feat}2ptr' in kwargs:
            inputs.update({f'{prefix}_data2ptr': kwargs[f'{prefix}_{feat}2ptr']})
        return inputs

    @staticmethod
    def from_aug_meta_prefix_for_loss(feat:str, prefix:str, **kwargs):
        keys = [f'{prefix}_idx', f'p{prefix}_idx']
        args = {k: kwargs[k] for k in keys if k in kwargs}
        if prefix is not None and f'{prefix}_{feat}2ptr' in kwargs:
            args.update({f'{prefix}_data2ptr': kwargs[f'{prefix}_{feat}2ptr']})
        if prefix is not None and f'p{prefix}_{feat}2ptr' in kwargs:
            args.update({f'p{prefix}_data2ptr': kwargs[f'p{prefix}_{feat}2ptr']})

        inputs = {}
        for arg in args:
            meta,param = arg.split('_', maxsplit=1)
            inputs.setdefault(meta, {})[param] = args[arg]
        return inputs
        

### Examples

In [39]:
kwargs = Parameters.from_aug_meta_prefix_for_feature('data', 'cat2data', **batch)
kwargs.keys()

dict_keys(['cat2data_attention_mask', 'cat2data_input_ids', 'cat2data_idx', 'cat2data_data2ptr'])

In [40]:
upma_model_inputs = Parameters.from_data_aug_meta_prefix_for_encoder('cat2data', **kwargs)['cat2data']

In [41]:
upma_model_inputs.keys()

dict_keys(['metadata_attention_mask', 'metadata_input_ids', 'metadata_idx', 'metadata_data2ptr'])

## `UPMA Model`

In [55]:
#| export
class UPMAModel(PreTrainedModel):
    config: UPMAConfig
    load_tf_weights = None
    base_model_prefix = "distilbert"
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.embeddings = Embeddings(config)
        self.memory_module = get_memory_module(config.memory_module_name)(config)
        
        self.n_layers = config.n_layers
        self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.gradient_checkpointing = False

        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self._use_sdpa = config._attn_implementation == "sdpa"

    def _init_weights(self, module: nn.Module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(
                self.config.max_position_embeddings, self.config.dim, module.position_embeddings.weight
            )
        elif isinstance(module, UPMAEmbeddingMemory) and self.config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(
                self.config.num_input_metadata, self.config.dim, module.rank_embeddings.weight
            )
            
    def get_position_embeddings(self) -> nn.Embedding:
        return self.embeddings.position_embeddings

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        num_position_embeds_diff = new_num_position_embeddings - self.config.max_position_embeddings

        # no resizing needs to be done if the length stays the same
        if num_position_embeds_diff == 0:
            return

        logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
        self.config.max_position_embeddings = new_num_position_embeddings

        old_position_embeddings_weight = self.embeddings.position_embeddings.weight.clone()

        self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim)

        if self.config.sinusoidal_pos_embds:
            create_sinusoidal_embeddings(
                n_pos=self.config.max_position_embeddings, dim=self.config.dim, out=self.position_embeddings.weight
            )
        else:
            with torch.no_grad():
                if num_position_embeds_diff > 0:
                    self.embeddings.position_embeddings.weight[:-num_position_embeds_diff] = nn.Parameter(
                        old_position_embeddings_weight
                    )
                else:
                    self.embeddings.position_embeddings.weight = nn.Parameter(
                        old_position_embeddings_weight[:num_position_embeds_diff]
                    )
        # move position_embeddings to correct device
        self.embeddings.position_embeddings.to(self.device)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings: nn.Embedding):
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune: dict[int, list[list[int]]]):
        for layer, heads in heads_to_prune.items():
            self.transformer.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        
        metadata_idx: Optional[torch.Tensor] = None,
        metadata_input_ids: Optional[torch.Tensor] = None,
        metadata_attention_mask: Optional[torch.Tensor] = None,
        metadata_scores: Optional[torch.Tensor] = None,
        metadata_data2ptr: Optional[torch.Tensor] = None,
        metadata_embeds: Optional[torch.Tensor] = None,
        inject_memory: Optional[bool] = True,
        
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutput, tuple[torch.Tensor, ...]]:
         
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        head_mask_is_none = head_mask is None
        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embeddings = self.embeddings(input_ids, inputs_embeds)  # (bs, seq_length, dim)
        
        if inject_memory:
            memory_embeddings, memory_mask = self.memory_module(
                input_idx = metadata_idx,
                embeds = metadata_embeds,
                scores = metadata_scores,
                indptr = metadata_data2ptr,
                input_ids = metadata_input_ids,
                attention_mask = metadata_attention_mask,
            ) # (bs, num_input_metadata, dim), (bs, num_input_metadata)
            memory_mask = torch.cat([attention_mask, memory_mask], dim=1)
            
        def _prepare_attention_mask(attention_mask, input_shape):
            if self._use_flash_attention_2:
                attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
            else:
                if attention_mask is None:
                    attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)
    
                if self._use_sdpa and head_mask_is_none and not output_attentions:
                    attention_mask = _prepare_4d_attention_mask_for_sdpa(
                        attention_mask, embeddings.dtype, tgt_len=input_shape[1]
                    )
            return attention_mask

        attention_mask = _prepare_attention_mask(attention_mask, input_shape)
                
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_state = embeddings
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

            if inject_memory and i+1 == self.config.memory_injection_layer:
                hidden_state = torch.cat([hidden_state, memory_embeddings], dim=1)
                attention_mask = _prepare_attention_mask(memory_mask, memory_mask.size())
                
            layer_outputs = layer_module(
                hidden_state,
                attention_mask,
                head_mask[i],
                output_attentions,
            )

            hidden_state = layer_outputs[-1]

            if output_attentions:
                if len(layer_outputs) != 2:
                    raise ValueError(f"The length of the layer_outputs should be 2, but it is {len(layer_outputs)}")

                attentions = layer_outputs[0]
                all_attentions = all_attentions + (attentions,)
            else:
                if len(layer_outputs) != 1:
                    raise ValueError(f"The length of the layer_outputs should be 1, but it is {len(layer_outputs)}")

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        if not return_dict:
            return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
            
        return BaseModelOutput(
            last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
        )
        

### Example

In [48]:
model = UPMAModel(config)

In [49]:
upma_model_inputs.keys()

dict_keys(['metadata_attention_mask', 'metadata_input_ids', 'metadata_idx', 'metadata_data2ptr'])

In [51]:
o = model(input_ids=batch['data_input_ids'], attention_mask=batch['data_attention_mask'], **upma_model_inputs)

In [53]:
o[0].shape

torch.Size([3, 35, 768])

## `UPMAEncoder`

In [103]:
#| export
class UPMAEncoder(UPMAModel):

    @classmethod
    def from_pretrained(
        cls,
        config:PretrainedConfig,
        meta_dset:Optional[Union[MainXCDataset, SMainXCDataset]] = None,
        batch_size:Optional[int] = 100,
    ):
        src_model = DistilBertModel.from_pretrained('distilbert-base-uncased')
        targ_model = cls(config)

        targ_model.init_weights()
        targ_model.eval()
        
        src_sd, targ_sd = src_model.state_dict(), targ_model.state_dict()
        src_keys, targ_keys = set(src_sd.keys()), set(targ_sd.keys())
        
        for k in src_keys.intersection(targ_keys):
            assert targ_sd[k].shape == src_sd[k].shape, (
                f"Shape mismatch at key '{k}'. "
                f"Expected {targ_sd[k].shape}, but got {src_sd[k].shape} in source state_dict."
            )
            targ_sd[k].copy_(src_sd[k])

        diff_keys = targ_keys.difference(src_keys)
        transformer_keys = [k for k in src_keys if k.startswith("transformer")]
        for k in transformer_keys:
            targ_k = k.split('.', maxsplit=1)[1]
            
            assert targ_k in targ_sd, (
                f"Unexpected key '{targ_k}' encountered, not found in target state_dict."
            )
            
            assert targ_sd[targ_k].shape == src_sd[k].shape, (
                f"Shape mismatch at key '{k}'. "
                f"Expected {targ_sd[targ_k].shape}, but got {src_sd[k].shape} in source state_dict."
            )
            
            targ_sd[targ_k].copy_(src_sd[k])
            diff_keys.remove(targ_k)

        if config.initialize_memory_embeddings_from_injection_layer_mean:
            targ_model.initialize_memory_embeddings_from_injection_layer_mean(
                meta_dset,
                save_file=config.metadata_embedding_file,
                batch_size=batch_size,
                use_encoder_parallel=config.use_encoder_parallel,
            )
            
        return targ_model

    def initialize_memory_embeddings_from_injection_layer_mean(
        self,
        meta_dset:Optional[Union[MainXCDataset, SMainXCDataset]] = None,
        save_file:Optional[str] = None,
        batch_size:Optional[int] = 100,
        use_encoder_parallel:Optional[bool] = True
    ):
        if save_file is not None and os.path.exists(save_file):
            meta_embeds = torch.load(save_file)
        else:
            if meta_dset is None: 
                raise ValueError(
                    f"Invalid argument: 'meta_dset' cannot be None. "
                    f"Please pass a valid dataset."
                )
                
            meta_embeds, device = [], next(self.parameters()).device
            meta_dl = DataLoader(meta_dset, batch_size=batch_size, collate_fn=identity_collate_fn)
    
            model = XCDataParallel(module=self) if use_encoder_parallel else self
            for batch in tqdm(meta_dl):
                for k, v in batch.items(): 
                    if isinstance(v, torch.Tensor): batch[k] = v.to(device)
                output = model(**batch, data_inject_memory=False, data_output_hidden_states=True)
                embeds = output.hidden_states[self.config.memory_injection_layer - 1]
                meta_embeds.append(Pooling.mean_pooling(embeds, batch['data_attention_mask']))
            meta_embeds = torch.cat(meta_embeds, dim=0)
            if save_file is not None: torch.save(meta_embeds, save_file)
            
        self.memory_module.set_metadata_embeddings(meta_embeds)
        return meta_embeds
    
    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_inject_memory: Optional[bool]=True,
        data_output_attentions: Optional[bool] = None,
        data_output_hidden_states: Optional[bool] = None,
        data_return_dict: Optional[bool] = None,
        **kwargs
    ):
        meta_kwargs = Parameters.from_data_aug_meta_prefix_for_encoder(data_aug_meta_prefix, **kwargs)
        meta_kwargs = meta_kwargs.get(data_aug_meta_prefix, dict())
        
        output = super().forward(
            input_ids=data_input_ids, 
            attention_mask=data_attention_mask,
            inject_memory=data_inject_memory,
            output_attentions=data_output_attentions,
            output_hidden_states=data_output_hidden_states,
            return_dict=data_return_dict,
            **meta_kwargs
        )
        
        if self.config.data_repr_pooling:
            embeds = output[0][:, :data_attention_mask.shape[1], :]
            attention_mask = data_attention_mask
        else:
            embeds = output[0]
            if 'metadata_data2ptr' in meta_kwargs:
                memory_mask = alignment_mask(meta_kwargs['metadata_data2ptr'])
                attention_mask = torch.cat([data_attention_mask, memory_mask], dim=1)
            else:
                attention_mask = data_attention_mask
                
        assert embeds.shape[:2] == attention_mask.shape, (
            f"Shape mismatch: embeds.shape[:2] = {embeds.shape[:2]} "
            f"but attention_mask.shape = {attention_mask.shape}."
        )
        
        data_repr = Pooling.mean_pooling(embeds, attention_mask)
        return UPMAEncoderOutput(repr=data_repr, **output)
        

### Example

In [58]:
config = UPMAConfig(
    num_total_metadata=block.train.dset.meta['cat_meta'].n_meta,
    num_input_metadata = 3,
    pad_metadata_idx=None,
    metadata_dropout=0.1,
    memory_injection_layer=None,
    memory_module_name="embeddings",

    data_aug_meta_prefix="cat2data",
    lbl2data_aug_meta_prefix="cat2lbl",

    data_enrich=True,
    lbl2data_enrich=False,
    data_repr_pooling=False,

    margin=0.3,
    num_negatives=5,
    tau=0.1,
    apply_softmax=True,

    calib_margin=0.3,
    calib_num_negatives=10,
    calib_tau=0.1,
    calib_apply_softmax=False,
    
    calib_loss_weight=0.1,
    use_calib_loss=True,
    
    use_encoder_parallel=False,

    initialize_memory_embeddings_from_injection_layer_mean=False,
    metadata_embedding_file=None,
)

In [59]:
meta_dset = block.train.dset.meta_dset('cat_meta')
model = UPMAEncoder.from_pretrained(config, meta_dset=meta_dset)

In [64]:
data_meta_kwargs = Parameters.from_aug_meta_prefix_for_feature('data', config.data_aug_meta_prefix, **batch)

o = model(data_input_ids=batch["data_input_ids"], data_attention_mask=batch["data_attention_mask"], 
          data_aug_meta_prefix=config.data_aug_meta_prefix, data_enrich=config.data_enrich, **data_meta_kwargs)

In [67]:
o[0].shape

torch.Size([3, 35, 768])

## `UPA000`

In [125]:
#| export
class UPA000(PreTrainedModel):
    
    def __init__(
        self, 
        config: UPMAConfig,
        meta_dset:Optional[Union[MainXCDataset, SMainXCDataset]] = None,
        batch_size:Optional[int] = 100,
    ):
        super().__init__(config)
        self.config = config
        self.encoder = UPMAEncoder.from_pretrained(config, meta_dset=meta_dset, batch_size=batch_size)
        
        self.rep_loss_fn = MultiTriplet(margin=config.margin, n_negatives=config.num_negatives, tau=config.tau, 
                                        apply_softmax=config.apply_softmax, reduce='mean')
        self.cab_loss_fn = Calibration(margin=config.calib_margin, tau=config.calib_tau, n_negatives=config.calib_num_negatives, 
                                       apply_softmax=config.calib_apply_softmax, reduce='mean')
        
    @classmethod
    def from_pretrained(
        cls,
        config: PretrainedConfig,
        meta_dset: Optional[Union[MainXCDataset, SMainXCDataset]] = None,
        batch_size: Optional[int] = 100,
    ):
        return cls(config, meta_dset=meta_dset, batch_size=batch_size)
        
    def compute_loss(
        self,
        inp_repr: Optional[torch.Tensor] = None,
        targ_repr: Optional[torch.Tensor] = None,
        targ_ptr: Optional[torch.Tensor] = None,
        targ_idx: Optional[torch.Tensor] = None,
        ptarg_ptr: Optional[torch.Tensor] = None,
        ptarg_idx: Optional[torch.Tensor] = None
    ):
        return self.rep_loss_fn(inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def calibration_loss(
        self,
        einp_repr: Optional[torch.Tensor] = None,
        inp_repr: Optional[torch.Tensor] = None,
        targ_repr: Optional[torch.Tensor] = None,
        targ_ptr: Optional[torch.Tensor] = None,
        targ_idx: Optional[torch.Tensor] = None,
        ptarg_ptr: Optional[torch.Tensor] = None,
        ptarg_idx: Optional[torch.Tensor] = None
    ):
        return self.config.calib_loss_weight * self.cab_loss_fn(einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def forward(
        self,
        data_input_ids: Optional[torch.Tensor] = None,
        data_attention_mask: Optional[torch.Tensor] = None,
        lbl2data_data2ptr: Optional[torch.Tensor] = None,
        lbl2data_idx: Optional[torch.Tensor] = None,
        lbl2data_input_ids: Optional[torch.Tensor] = None,
        lbl2data_attention_mask: Optional[torch.Tensor] = None,
        plbl2data_data2ptr: Optional[torch.Tensor] = None,
        plbl2data_idx: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        encoder = XCDataParallel(module=self.encoder) if self.config.use_encoder_parallel else self.encoder
        
        data_meta_kwargs = Parameters.from_aug_meta_prefix_for_feature('data', self.config.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.config.data_aug_meta_prefix, data_inject_memory=self.config.data_inject_memory, **data_meta_kwargs)
        
        loss = None; lbl2data_o = UPMAEncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_aug_meta_prefix_for_feature('lbl', self.config.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.config.lbl2data_aug_meta_prefix, data_inject_memory=self.config.lbl2data_inject_memory, **lbl2data_meta_kwargs)
            
            loss = self.compute_loss(data_o.repr, lbl2data_o.repr,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)
            
        if not return_dict:
            o = (data_o.repr,lbl2data_o.repr)
            return ((loss,) + o) if loss is not None else o

        return UPMAModelOutput(
            loss=loss,
            data_repr=data_o.repr,
            lbl2data_repr=lbl2data_o.repr,
        )
        

### Example

In [133]:
config = UPMAConfig(
    num_total_metadata=block.train.dset.meta['cat_meta'].n_meta,
    num_input_metadata = 3,
    metadata_dropout = 0.1,
    memory_injection_layer = None,
    memory_module_name = "embeddings",

    data_aug_meta_prefix = "cat2data",
    lbl2data_aug_meta_prefix = "cat2lbl",

    data_inject_memory = True,
    lbl2data_inject_memory = False,

    margin = 0.3,
    num_negatives = 5,
    tau = 0.1,
    apply_softmax = True,

    calib_margin = 0.3,
    calib_num_negatives = 10,
    calib_tau = 0.1,
    calib_apply_softmax = False,
    
    calib_loss_weight = 0.1,
    use_calib_loss = True,
    
    use_encoder_parallel = False,
    initialize_memory_embeddings_from_injection_layer_mean = False,
)

In [134]:
model = UPA000.from_pretrained(config)

In [150]:
o = model(**batch)

In [None]:
block.train

In [151]:
o.loss

tensor(0., grad_fn=<DivBackward0>)

In [144]:
o.data_repr

tensor([[ 0.0329,  0.4165, -0.1727,  ..., -0.3043,  0.0117,  0.1269],
        [-0.1260,  0.1324, -0.6800,  ..., -0.5539,  0.0975,  0.0054],
        [-0.1265, -0.6460, -0.4209,  ...,  0.0807, -0.0174, -0.0341]],
       grad_fn=<DivBackward0>)

In [145]:
o.lbl2data_repr

tensor([[ 0.1009,  0.3114, -0.2503,  ..., -0.2568, -0.0038,  0.0075],
        [ 0.0146,  0.5108, -0.4399,  ..., -0.3527, -0.1128, -0.0712],
        [-0.3037, -0.0469, -0.4431,  ..., -0.5149, -0.0814,  0.1870],
        [-0.1228,  0.1947, -0.6093,  ..., -0.4707,  0.0389, -0.0159],
        [-0.3955, -0.2311, -0.4383,  ..., -0.0789, -0.0377,  0.2048],
        [-0.3194, -0.6375, -0.2985,  ...,  0.1485, -0.0762,  0.2454]],
       grad_fn=<DivBackward0>)