In [None]:
#| default_exp models.modeling_utils

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import torch, re, inspect, pickle, os, torch.nn as nn, math
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any, Union
from transformers import (
    PretrainedConfig,
    DistilBertForMaskedLM,
    DistilBertModel,
    DistilBertPreTrainedModel,
)
from transformers.utils.generic import ModelOutput
from transformers.activations import get_activation

from fastcore.meta import *
from fastcore.utils import *

In [None]:
# Development-only cell (not exported): pull in nbdev's doc helpers and regenerate the
# library module from the cells marked `#| export`.
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
from transformers import AutoConfig
from xcai.block import *

## Setup

In [None]:
# Expose only GPUs 0 and 1 to CUDA for this process.
# NOTE(review): torch is imported in an earlier cell; this variable is only honoured if it is
# set before CUDA is first initialised — confirm that holds for this kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

In [None]:
# Data root passed to `XCBlock.from_cfg` in the next cell.
# NOTE(review): machine-specific absolute path — consider reading it from configuration.
data_dir = '/home/scai/phd/aiz218323/Projects/XC/data/'

In [None]:
# Build the data block from `data_dir` and the 'data_metas' configuration, using the
# distilbert-base-uncased tokenizer.
# NOTE(review): the meaning of `tfm='rm'` and of the `smp_features` tuples is defined by
# XCBlock, which is not visible in this notebook — confirm before changing them.
block = XCBlock.from_cfg(data_dir, 'data_metas', tfm='rm', tokenizer='distilbert-base-uncased', 
                         smp_features=[('lbl2data|cat2lbl2data',1,(1,3)), ('cat2data',1,3)])



In [None]:
# Location of the pickled copy of `block` written and read by the next two cells.
# NOTE(review): machine-specific absolute path — consider reading it from configuration.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealso_data-meta_distilbert-base-uncased_rm_radga-cat.pkl'

In [None]:
# Serialise the freshly built block to `pkl_file`; an existing file at that path is overwritten.
with open(pkl_file, 'wb') as file: pickle.dump(block, file)

In [None]:
# Restore the block from `pkl_file`, replacing the in-memory `block`.
# NOTE(review): unpickling can execute arbitrary code — only load files you produced yourself.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [None]:
# Fetch a sample batch to inspect.
# NOTE(review): the loop rebinds `batch` on every iteration, so the `one_batch(5)` result is
# discarded. The loop stops at index 3, leaving `batch` as the fourth batch yielded by the
# dataloader (or the last one if it yields fewer).
batch = block.train.one_batch(5)
for i,batch in enumerate(block.train.dl):
    if i > 2: break

In [None]:
# Show which fields a batch carries; these names are what the `Parameters` helpers filter on.
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'cat2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx', 'hlk2lbl2data_idx', 'hlk2lbl2data_identifier', 'hlk2lbl2data_input_text', 'hlk2lbl2data_input_ids', 'hlk2lbl2data_attention_mask', 'hlk2lbl2data_data2ptr', 'hlk2lbl2data_plbl2data2ptr', 'hlk2data_idx', 'hlk2data_identifier', 'hlk2data_input_text', 'hlk2data_input_ids', 'hlk2data_attention_mask', 'hlk2data_data2ptr'])

## Helper

In [None]:
#| export
@dataclass
class XCModelOutput(ModelOutput):
    """Output container of a model forward pass; every field is an optional tensor defaulting to `None`.

    NOTE(review): the per-field notes below are inferred from the field names only — confirm
    them against the models that populate this class.
    """
    loss: Optional[torch.FloatTensor] = None  # presumably the training loss
    logits: Optional[torch.FloatTensor] = None  # presumably prediction scores
    data_repr: Optional[torch.FloatTensor] = None  # presumably the representation of the `data` input
    data_fused_repr: Optional[torch.FloatTensor] = None  # presumably `data_repr` after metadata fusion
    lbl2data_repr: Optional[torch.FloatTensor] = None  # presumably the representation of the `lbl2data` input
    lbl2data_fused_repr: Optional[torch.FloatTensor] = None  # presumably `lbl2data_repr` after metadata fusion
        

In [None]:
#| export
@dataclass
class EncoderOutput(ModelOutput):
    """Output container of an encoder; every field is an optional tensor defaulting to `None`.

    NOTE(review): the per-field notes below are inferred from the field names only — confirm
    them against the encoder that populates this class.
    """
    rep: Optional[torch.FloatTensor] = None  # presumably the plain representation
    fused_rep: Optional[torch.FloatTensor] = None  # presumably the representation after metadata fusion
    logits: Optional[torch.FloatTensor] = None  # presumably prediction scores
    fusion_weights: Optional[torch.FloatTensor] = None  # presumably the attention weights used for fusion
    meta_repr: Optional[torch.FloatTensor] = None  # presumably the metadata representations
        

In [None]:
#| export
class Pooling:
    """Namespace for routines that collapse a sequence of embeddings into a single vector."""

    @staticmethod
    def mean_pooling(data_embeds:torch.FloatTensor, data_attention_mask:torch.LongTensor):
        """Mask-weighted average of `data_embeds` along dimension 1.

        The mask gets a trailing axis, is broadcast to the size of `data_embeds` and cast to
        float before being used as per-position weights.

        Returns:
            `data_embeds` with dimension 1 reduced away.
        """
        weights = data_attention_mask.unsqueeze(2).expand(data_embeds.size()).float()
        weighted_sum = torch.sum(data_embeds * weights, 1)
        # Lower-bound the denominator so a fully masked row divides by 1e-9 rather than by zero.
        denom = torch.clamp(weights.sum(1), min=1e-9)
        return weighted_sum / denom


## CrossAttention

In [None]:
#| export
class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention in which `q` attends over a second sequence `k`.

    The key sequence is also used as the value sequence. Sizes are read from `config.n_heads`,
    `config.dim` and `config.attention_dropout`.
    """

    def __init__(self, config: PretrainedConfig):
        """Create the attention dropout and the four `dim -> dim` linear projections.

        Raises:
            ValueError: if `config.dim` is not a multiple of `config.n_heads`.
        """
        super().__init__()
        self.config, self.n_h, self.dim = config, config.n_heads, config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        if self.dim % self.n_h != 0:
            raise ValueError(f"self.n_heads: {self.n_h} must divide self.dim: {self.dim} evenly.")

        self.q = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.o = nn.Linear(in_features=config.dim, out_features=config.dim)

    def post_init(self):
        """Reset the weight of every projection to the identity matrix; biases are left untouched.

        The weights are overwritten in place, so each parameter keeps its current device and
        dtype. Rebinding `.weight.data` to a newly built `torch.eye(...)` would instead swap in
        a tensor on the default device, silently moving an already-relocated module's weights.

        `__init__` does not call this method; the owner of the module decides when to apply it.
        """
        for proj in (self.q, self.k, self.v, self.o):
            nn.init.eye_(proj.weight)

    def forward(
        self, 
        q: torch.Tensor,
        q_m: torch.Tensor,
        k: torch.Tensor, 
        k_m: torch.Tensor,
        output_attentions:Optional[bool] = False,
    ):
        """Attend from `q` over `k`.

        Args:
            q: query sequence, `(bs, q_len, dim)`.
            q_m: query mask holding `bs * q_len` entries; zero marks an ignored query position.
            k: key sequence, `(bs, k_len, dim)`; it is also used as the value sequence.
            k_m: key mask holding `bs * k_len` entries; zero marks an ignored key position.
            output_attentions: when true, the attention weights are returned as well.

        Returns:
            `(o,)`, or `(o, w)` when `output_attentions` is true, where `o` is
            `(bs, q_len, dim)` and `w` is `(bs, n_h, q_len, k_len)` (taken after dropout).

        Note:
            Masked scores are set to the finite minimum of the score dtype rather than `-inf`.
            A row with no valid (query, key) pair therefore receives uniform weights, not NaN.
        """
        bs, _, _ = q.size()  # unpacking also enforces that `q` is 3-dimensional
        v = k

        h_dim = self.dim//self.n_h

        def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)

        def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)

        q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
        k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
        v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)

        q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)
        sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)

        # The outer product of the two masks is non-zero only where both positions are valid.
        q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
        mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)

        # A plain scalar keeps the fill value exact for the score dtype and avoids building a
        # default-dtype tensor on every call (which cannot represent the float64 minimum).
        sc = sc.masked_fill(mask == 0, torch.finfo(sc.dtype).min)  # (bs, n_h, q_len, k_len)

        w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
        w = self.dropout(w)  # (bs, n_h, q_len, k_len)

        o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)

        if output_attentions: return (o, w)
        else: return (o,)
        

### Example

In [None]:
# Instantiate the attention layer with the sizes of the distilbert-base-uncased configuration.
config = AutoConfig.from_pretrained('distilbert-base-uncased')
fuser = CrossAttention(config)

In [None]:
# Random example inputs: `data` plays the query role, `meta` the key/value role.
# The masks are random 0/1 values stored as floats; no seed is set, so they differ per run.
bsz, data_seq_len, n_meta, dim, dtype = 2, 3, 2, config.dim, torch.float32
data, meta = torch.randn(bsz, data_seq_len, dim, dtype=dtype), torch.randn(bsz, n_meta, dim, dtype=dtype)
data_mask = torch.randint(0, 2, size=(bsz,data_seq_len), dtype=dtype)
meta_mask = torch.randint(0, 2, size=(bsz,n_meta), dtype=dtype)

In [None]:
# `output_attentions` is left at its default, so the result is a 1-tuple holding the output.
o = fuser(data, data_mask, meta, meta_mask)

In [None]:
# The attended output has the same shape as the query input `data`.
o[0].shape

torch.Size([2, 3, 768])

## Blocks

In [None]:
#| export
class RepresentationHead(nn.Module):
    """Projection head: linear transform, activation, layer norm, then a second linear layer.

    Input and output both have `config.dim` features. Both linear weights start as the identity.
    """

    def __init__(self, config):
        """Build the layers from `config.dim` and `config.activation`, then apply `post_init`."""
        super().__init__()
        self.transform = nn.Linear(config.dim, config.dim)
        self.layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.projector = nn.Linear(config.dim, config.dim)
        self.activation = get_activation(config.activation)

        self.post_init()

    def post_init(self):
        """Reset the weights of `transform` and `projector` to the identity; biases are untouched.

        The weights are overwritten in place, so the parameters keep their device and dtype even
        when this method is called again after the module has been moved. Rebinding
        `.weight.data` to a newly built `torch.eye(...)` would swap in a tensor on the default
        device instead.
        """
        for layer in (self.transform, self.projector):
            nn.init.eye_(layer.weight)

    def forward(self, x:torch.Tensor):
        """Apply transform -> activation -> layer norm -> projector to the last dimension of `x`."""
        x = self.transform(x)
        x = self.activation(x)
        x = self.layer_norm(x)
        x = self.projector(x)
        return x
    

In [None]:
#| export
class GenerationHead(nn.Module):
    """Head mapping `config.dim` features to `config.vocab_size` scores.

    Pipeline: linear transform, activation, layer norm, linear projector.
    """

    def __init__(self, config):
        """Build the layers from `config.dim`, `config.vocab_size` and `config.activation`."""
        super().__init__()
        hidden = config.dim
        # Creation order is kept as-is: it fixes both parameter order and random-init order.
        self.transform = nn.Linear(hidden, hidden)
        self.layer_norm = nn.LayerNorm(hidden, eps=1e-12)
        self.projector = nn.Linear(hidden, config.vocab_size)
        self.activation = get_activation(config.activation)

    def forward(self, x:torch.Tensor):
        """Return the projector output for `x`, computed over its last dimension."""
        hidden_states = self.activation(self.transform(x))
        return self.projector(self.layer_norm(hidden_states))
    

### Example

In [None]:
# Configuration and a random input for the head examples below.
# NOTE(review): `x` is created here but no later cell passes it through a head.
config = AutoConfig.from_pretrained('distilbert-base-uncased')
x = torch.randn(10, 20, config.dim)

In [None]:
# Construction check only; the forward pass is not exercised here.
m = RepresentationHead(config)

In [None]:
# Construction check only; note that this rebinds `m` from the previous cell.
m = GenerationHead(config)

## Parameters

In [None]:
#| export
class Parameters:
    """Static helpers that pick entries out of a batch's keyword arguments by key name.

    Keys are expected to look like `<meta>_<param>`; they are split on the first underscore.
    In every method `prefix` is inserted into a regular expression as-is (it is not escaped).
    """

    @staticmethod
    def from_meta_aug_prefix(prefix:str, **kwargs):
        """Group augmentation inputs whose key starts with `prefix` by metadata name.

        Only keys ending in `_input_ids`, `_attention_mask`, `_data2ptr`, `_meta_repr` or `_idx`
        are kept.

        Returns:
            `{meta: {param: value}}`; empty when `prefix` is `None` or nothing matches.
        """
        inputs = {}
        if prefix is None: return inputs
        wanted = re.compile(f'^{prefix}.*_(input_ids|attention_mask|data2ptr|meta_repr|idx)$')
        for arg, value in kwargs.items():
            if not wanted.match(arg): continue
            meta,param = arg.split('_', maxsplit=1)
            inputs.setdefault(meta, {})[param] = value
        return inputs

    @staticmethod
    def from_feat_meta_aug_prefix(feat:str, prefix:str, **kwargs):
        """Select the inputs of the single metadata `prefix`, keeping the flat `<prefix>_<param>` keys.

        `attention_mask`, `input_ids`, `meta_repr` and `idx` are copied when present. The pointer
        stored under `<prefix>_<feat>2ptr` is returned under the key `<prefix>_data2ptr`.

        Returns:
            A flat dict; empty when `prefix` is `None` or nothing matches.
        """
        # Same `None` convention as the sibling helpers; avoids looking up 'None_*' keys.
        if prefix is None: return {}
        keys = ['attention_mask', 'input_ids', 'meta_repr', 'idx']

        inputs = {f'{prefix}_{k}': kwargs[f'{prefix}_{k}'] for k in keys if f'{prefix}_{k}' in kwargs}
        if f'{prefix}_{feat}2ptr' in kwargs:
            inputs.update({f'{prefix}_data2ptr': kwargs[f'{prefix}_{feat}2ptr']})
        return inputs

    @staticmethod
    def from_meta_pred_prefix(prefix:str, **kwargs):
        """Group every key starting with `prefix` or `p<prefix>` by metadata name.

        For a `p<prefix>...` key the leading `p` is moved from the metadata name to the
        parameter name, so `pcat2lbl_idx` is stored as `inputs['cat2lbl']['pidx']`.

        Returns:
            `{meta: {param: value}}`; empty when `prefix` is `None` or nothing matches.
        """
        inputs = {}
        if prefix is None: return inputs
        wanted = re.compile(f'^[p]?{prefix}.*')
        # Detect the "p" variant from the full `p<prefix>` form rather than from the first
        # character alone, so a prefix that itself begins with "p" is not mistaken for one.
        p_variant = re.compile(f'p(?:{prefix})')
        for arg, value in kwargs.items():
            if not wanted.match(arg): continue
            meta,param = arg.split('_', maxsplit=1)
            if p_variant.match(arg):
                inputs.setdefault(meta[1:], {})[f'p{param}'] = value
            else:
                inputs.setdefault(meta, {})[param] = value
        return inputs

    @staticmethod
    def get_meta_loss_weights(lw:Union[float,List], n_meta:int):
        """Return one loss weight per metadata.

        A scalar `lw` (float or int) is divided evenly, giving `n_meta` copies of `lw / n_meta`
        (an empty list when `n_meta` is 0). A sequence is returned unchanged after its length
        has been checked.

        Raises:
            ValueError: if `lw` is a sequence whose length differs from `n_meta`.
        """
        # Accept ints as well: previously `lw=1` fell through to `len(lw)` and raised TypeError.
        if isinstance(lw, (int, float)):
            lw = lw/n_meta if n_meta else None
            return [lw] * n_meta
        if len(lw) != n_meta: raise ValueError(f'length of `lw` should be equal to number of metadata.')
        return lw
        

### Example

In [None]:
# Take the first batch from the training dataloader as input for the examples below.
b = next(iter(block.train.dl))

In [None]:
# Batch entries whose key starts with 'cat' are grouped under the metadata name that precedes
# the first underscore.
p = Parameters.from_meta_aug_prefix('cat', **b); p.keys()

dict_keys(['cat2lbl', 'cat2data'])

In [None]:
# Inputs of the single metadata 'cat2lbl' keep their flat keys; with feat='data' the pointer is
# read from 'cat2lbl_data2ptr' and returned under that same key.
p = Parameters.from_feat_meta_aug_prefix('data', 'cat2lbl', **b); p.keys()

dict_keys(['cat2lbl_attention_mask', 'cat2lbl_input_ids', 'cat2lbl_idx', 'cat2lbl_data2ptr'])

In [None]:
# Keys starting with 'cat' or 'pcat' are grouped together; the 'p' of a 'pcat...' key is moved
# onto the parameter name (e.g. 'pcat2lbl_idx' -> p['cat2lbl']['pidx']).
p = Parameters.from_meta_pred_prefix('cat', **b); p.keys()

dict_keys(['cat2lbl', 'cat2data'])