In [None]:
#| default_exp models.radga_lora

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import torch, re, inspect, pickle, os, torch.nn as nn, math
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any, Union
from transformers import (
    PretrainedConfig,
    DistilBertForMaskedLM,
    DistilBertModel,
    DistilBertPreTrainedModel,
)
from transformers.utils.generic import ModelOutput
from transformers.activations import get_activation

from fastcore.meta import *
from fastcore.utils import *

from xcai.losses import *
from xcai.core import store_attr
from xcai.learner import XCDataParallel

from peft import (
    LoraConfig, 
    get_peft_model, 
    TaskType,
    PeftModel
)

comet_ml is installed but `COMET_API_KEY` is not set.


In [None]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
from transformers import AutoConfig
from xcai.block import *

## Setup

In [None]:
# Expose only GPUs 0 and 1 to this kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

In [None]:
# Root directory of the datasets.
# NOTE(review): absolute, user-specific path — consider making it configurable.
data_dir = '/home/scai/phd/aiz218323/Projects/XC/data/'

In [None]:
# Build the dataset block from the 'data_metas' configuration under `data_dir`,
# using the distilbert-base-uncased tokenizer.
# NOTE(review): `tfm='rm'` and the `smp_features` tuples are interpreted by XCBlock; their
# exact meaning (presumably per-feature sampling settings) is not visible here — confirm.
block = XCBlock.from_cfg(data_dir, 'data_metas', tfm='rm', tokenizer='distilbert-base-uncased', 
                         smp_features=[('lbl2data|cat2lbl2data',1,(1,3)), ('cat2data',1,3)])



In [None]:
# Location of the pickled, pre-built block used as an on-disk cache.
# NOTE(review): absolute, user-specific path — consider making it configurable.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealso_data-meta_distilbert-base-uncased_rm_radga-cat.pkl'

In [None]:
# Write the built block to the cache file; mode 'wb' overwrites any existing file at `pkl_file`.
with open(pkl_file, 'wb') as file: pickle.dump(block, file)

In [None]:
# Reload the cached block, replacing the in-memory `block`.
# NOTE(review): `pickle.load` can execute arbitrary code — only load files you created yourself.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [None]:
# Fetch a sample batch (the argument 5 is presumably the batch size — confirm against `one_batch`).
batch = block.train.one_batch(5)
# The loop below rebinds `batch` on every iteration and stops once i == 3, so (if the loader
# yields that many) `batch` ends up holding the fourth dataloader batch and the `one_batch`
# result above is discarded.
for i,batch in enumerate(block.train.dl):
    if i > 2: break

In [None]:
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_identifier', 'cat2lbl_input_text', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'cat2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx', 'hlk2lbl2data_idx', 'hlk2lbl2data_identifier', 'hlk2lbl2data_input_text', 'hlk2lbl2data_input_ids', 'hlk2lbl2data_attention_mask', 'hlk2lbl2data_data2ptr', 'hlk2lbl2data_plbl2data2ptr', 'hlk2data_idx', 'hlk2data_identifier', 'hlk2data_input_text', 'hlk2data_input_ids', 'hlk2data_attention_mask', 'hlk2data_data2ptr'])

## Helper

In [None]:
#| export
@dataclass
class RADOutput(ModelOutput):
    """Output container returned by `RAD001`; every field is optional and defaults to `None`."""
    # Total training loss; `RAD001.forward` leaves it `None` when no label inputs are supplied.
    loss: Optional[torch.FloatTensor] = None
    # Not populated by any code in this notebook.
    logits: Optional[torch.FloatTensor] = None
    # Representation of the data (query) inputs, taken from `EncoderOutput.rep`.
    data_repr: Optional[torch.FloatTensor] = None
    # Metadata-fused representation of the data inputs, taken from `EncoderOutput.fused_rep`.
    data_fused_repr: Optional[torch.FloatTensor] = None
    # Representation of the label inputs, taken from `EncoderOutput.rep`.
    lbl2data_repr: Optional[torch.FloatTensor] = None
    # Metadata-fused representation of the label inputs, taken from `EncoderOutput.fused_rep`.
    lbl2data_fused_repr: Optional[torch.FloatTensor] = None
        

In [None]:
#| export
@dataclass
class EncoderOutput(ModelOutput):
    """Output container returned by `Encoder.forward`; every field is optional and defaults to `None`."""
    # Pooled representation of the input sequence (before any metadata fusion).
    rep: Optional[torch.FloatTensor] = None
    # Pooled representation after metadata was fused into the token embeddings; `None` when
    # no fusion was performed.
    fused_rep: Optional[torch.FloatTensor] = None
    # Not populated by `Encoder.forward` in this notebook.
    logits: Optional[torch.FloatTensor] = None
    # Not populated by `Encoder.forward` in this notebook.
    fusion_weights: Optional[torch.FloatTensor] = None
    # Mapping from metadata name to the normalised representations of its valid entries, as
    # returned by `Encoder.fuse_meta_into_embeddings` (a dict despite the tensor annotation).
    meta_repr: Optional[torch.FloatTensor] = None
        

In [None]:
#| export
class Pooling:
    """Namespace for sequence-pooling helpers."""

    @staticmethod
    def mean_pooling(data_embeds:torch.FloatTensor, data_attention_mask:torch.LongTensor):
        """Average the token embeddings of each sequence over its unmasked positions.

        `data_embeds` is indexed as (batch, token, feature) and `data_attention_mask` as
        (batch, token); positions whose mask is 0 do not contribute. The divisor is clamped
        to 1e-9, so a fully masked sequence yields a zero vector instead of NaN.
        """
        weights = data_attention_mask.unsqueeze(2).expand(data_embeds.size()).float()
        weighted_sum = (data_embeds * weights).sum(1)
        token_count = weights.sum(1).clamp(min=1e-9)
        return weighted_sum / token_count


## CrossAttention

In [None]:
#| export
class CrossAttention(nn.Module):
    """Multi-head cross-attention from a query sequence onto a key/value sequence.

    The configuration must provide `n_heads`, `dim` and `attention_dropout`. Keys and
    values are both derived from the same input (`k`).
    """

    def __init__(self, config: "PretrainedConfig"):
        super().__init__()
        self.config, self.n_h, self.dim = config, config.n_heads, config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        if self.dim % self.n_h != 0:
            raise ValueError(f"self.n_heads: {self.n_h} must divide self.dim: {self.dim} evenly.")

        self.q = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.o = nn.Linear(in_features=config.dim, out_features=config.dim)

    def post_init(self):
        """Reset the four projection weight matrices to the identity (biases are left untouched).

        The weights are overwritten in place, so they keep their current device and dtype;
        assigning a freshly built `torch.eye` to `weight.data` would silently move them to
        the CPU when the module already lives on an accelerator.
        """
        for proj in (self.q, self.k, self.v, self.o):
            nn.init.eye_(proj.weight)

    def forward(
        self, 
        q: torch.Tensor,
        q_m: torch.Tensor,
        k: torch.Tensor, 
        k_m: torch.Tensor,
        output_attentions:Optional[bool] = False,
    ):
        """Attend from `q` onto `k`.

        Args:
            q: query features, (bs, q_len, dim).
            q_m: query mask with bs * q_len entries; 0 marks an ignored query position.
            k: key features, (bs, k_len, dim); also used as the values.
            k_m: key mask with bs * k_len entries; 0 marks an ignored key position.
            output_attentions: when true, also return the attention weights.

        Returns:
            `(o,)` or `(o, w)` where `o` is (bs, q_len, dim) and `w` is (bs, n_h, q_len, k_len).
            A query row whose mask is 0 (or that sees no valid key) receives uniform weights.
        """
        bs = q.size(0)
        v = k  # values share the key input

        h_dim = self.dim//self.n_h

        def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)

        def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)

        q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
        k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
        v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)

        q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)
        sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)

        # Outer product of the two masks: an entry is kept only if both its query and key are valid.
        q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
        mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)

        # Fill with the most negative finite value of the score dtype. A plain scalar is used
        # because wrapping it in a default (float32) tensor overflows to -inf for float64
        # scores, which turns fully masked rows into NaN after the softmax.
        sc = sc.masked_fill(mask == 0, torch.finfo(sc.dtype).min)  # (bs, n_h, q_len, k_len)

        w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
        w = self.dropout(w)  # (bs, n_h, q_len, k_len)

        o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)

        if output_attentions: return (o, w)
        else: return (o,)
        

### Example

In [None]:
# The DistilBERT config supplies `dim`, `n_heads` and `attention_dropout` for the attention layer.
config = AutoConfig.from_pretrained('distilbert-base-uncased')
fuser = CrossAttention(config)

In [None]:
# Toy inputs: 2 sequences of 3 data tokens attending onto 2 metadata vectors each.
bsz, data_seq_len, n_meta, dim, dtype = 2, 3, 2, config.dim, torch.float32
data, meta = torch.randn(bsz, data_seq_len, dim, dtype=dtype), torch.randn(bsz, n_meta, dim, dtype=dtype)
# Random 0/1 masks (no seed is set, so they differ between runs).
data_mask = torch.randint(0, 2, size=(bsz,data_seq_len), dtype=dtype)
meta_mask = torch.randint(0, 2, size=(bsz,n_meta), dtype=dtype)

In [None]:
o = fuser(data, data_mask, meta, meta_mask)

In [None]:
o[0].shape

torch.Size([2, 3, 768])

## Blocks

In [None]:
#| export
class RepresentationHead(nn.Module):
    """Projection head: linear -> activation -> layer norm -> linear, all of width `config.dim`.

    Both linear layers start with identity weight matrices (biases keep their default
    initialisation).
    """

    def __init__(self, config):
        super().__init__()
        self.transform = nn.Linear(config.dim, config.dim)
        self.layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.projector = nn.Linear(config.dim, config.dim)
        self.activation = get_activation(config.activation)

        self.post_init()

    def post_init(self):
        """Reset the weights of `transform` and `projector` to the identity matrix.

        The update happens in place so the parameters keep their device and dtype; replacing
        `weight.data` with a new `torch.eye` would move them to the CPU if the head had
        already been transferred to another device.
        """
        nn.init.eye_(self.transform.weight)
        nn.init.eye_(self.projector.weight)

    def forward(self, x:torch.Tensor):
        """Apply the head to `x`, whose last dimension must be `config.dim`; the shape is preserved."""
        x = self.transform(x)
        x = self.activation(x)
        x = self.layer_norm(x)
        x = self.projector(x)
        return x
    

In [None]:
#| export
class GenerationHead(nn.Module):
    """Vocabulary head: linear -> activation -> layer norm -> linear onto `config.vocab_size`."""

    def __init__(self, config):
        super().__init__()
        hidden = config.dim
        # Submodules are created in the same order as before so parameter registration
        # (and random initialisation for a fixed seed) is unchanged.
        self.transform = nn.Linear(hidden, hidden)
        self.layer_norm = nn.LayerNorm(hidden, eps=1e-12)
        self.projector = nn.Linear(hidden, config.vocab_size)
        self.activation = get_activation(config.activation)

    def forward(self, x:torch.Tensor):
        """Map features of width `config.dim` to scores of width `config.vocab_size`."""
        hidden_states = self.activation(self.transform(x))
        return self.projector(self.layer_norm(hidden_states))
    

### Example

In [None]:
# Config and a random (10, 20, dim) input for trying out the heads below.
config = AutoConfig.from_pretrained('distilbert-base-uncased')
x = torch.randn(10, 20, config.dim)

In [None]:
m = RepresentationHead(config)

In [None]:
m = GenerationHead(config)

## Parameters

In [None]:
#| export
class Parameters:
    """Helpers that pick metadata-related entries out of a flat `**kwargs` batch.

    Batch keys follow the pattern `<meta>_<param>` (e.g. `cat2data_input_ids`); keys carrying
    an extra leading `p` (e.g. `pcat2data_idx`) are the "positive" companions of `<meta>`.
    """

    @staticmethod
    def from_meta_aug_prefix(prefix:str, **kwargs):
        """Group augmentation inputs whose key starts with `prefix` by metadata name.

        Only keys ending in `_input_ids`, `_attention_mask`, `_data2ptr`, `_meta_repr` or
        `_idx` are selected. Returns `{meta: {param: value}}`, or `{}` when `prefix` is `None`.
        `prefix` is inserted into a regular expression as-is.
        """
        inputs = {}
        if prefix is None: return inputs

        pattern = re.compile(f'^{prefix}.*_(input_ids|attention_mask|data2ptr|meta_repr|idx)$')
        for arg, value in kwargs.items():
            if pattern.match(arg) is None: continue
            meta,param = arg.split('_', maxsplit=1)
            inputs.setdefault(meta, {})[param] = value
        return inputs

    @staticmethod
    def from_feat_meta_aug_prefix(feat:str, prefix:str, **kwargs):
        """Collect the inputs of metadata `prefix` that are linked to feature `feat`.

        Keys stay flat (`<prefix>_<param>`). The pointer `<prefix>_<feat>2ptr`, when present,
        is returned under the generic name `<prefix>_data2ptr`, overriding any entry of that
        name. Returns `{}` when `prefix` is `None`.
        """
        if prefix is None: return {}

        keys = ['attention_mask', 'input_ids', 'meta_repr', 'idx']

        inputs = {f'{prefix}_{k}': kwargs[f'{prefix}_{k}'] for k in keys if f'{prefix}_{k}' in kwargs}
        if f'{prefix}_{feat}2ptr' in kwargs:
            inputs.update({f'{prefix}_data2ptr': kwargs[f'{prefix}_{feat}2ptr']})
        return inputs

    @staticmethod
    def from_meta_pred_prefix(prefix:str, **kwargs):
        """Group prediction inputs (and their `p`-prefixed companions) by metadata name.

        A key `<meta>_<param>` is stored as `out[meta][param]`; a key `p<meta>_<param>` is
        stored as `out[meta]['p' + param]`. Returns `{}` when `prefix` is `None`.
        """
        inputs = {}
        if prefix is None: return inputs

        selector = re.compile(f'^[p]?{prefix}.*')
        direct = re.compile(f'^{prefix}')
        for arg, value in kwargs.items():
            if selector.match(arg) is None: continue
            # A matching key without '_' cannot be split into (meta, param); ignore it
            # instead of failing on the unpacking below.
            if '_' not in arg: continue
            meta,param = arg.split('_', maxsplit=1)
            # Only treat the leading 'p' as the positive marker when the key does not already
            # start with the prefix itself; otherwise a prefix that begins with 'p'
            # (e.g. 'par2data') would lose its first letter.
            if arg[0] == 'p' and direct.match(arg) is None:
                inputs.setdefault(meta[1:], {})[f'p{param}'] = value
            else:
                inputs.setdefault(meta, {})[param] = value
        return inputs

    @staticmethod
    def get_meta_loss_weights(lw:Union[float,List], n_meta:int):
        """Return one loss weight per metadata input.

        A scalar `lw` (float or int) is divided evenly, giving `[lw / n_meta] * n_meta`
        (an empty list when `n_meta` is 0). A sequence is returned unchanged.

        Raises:
            ValueError: if `lw` is a sequence whose length differs from `n_meta`.
        """
        # Accept ints as well: a weight such as `0` or `1` used to fall through to `len(lw)`.
        if isinstance(lw, (int, float)):
            lw = lw/n_meta if n_meta else None
            return [lw] * n_meta

        if len(lw) != n_meta: raise ValueError('length of `lw` should be equal to number of metadata.')
        return lw
        

### Example

In [None]:
# First batch of the training dataloader, used to demonstrate the `Parameters` helpers.
b = next(iter(block.train.dl))

In [None]:
p = Parameters.from_meta_aug_prefix('cat', **b); p.keys()

dict_keys(['cat2lbl', 'cat2data'])

In [None]:
# NOTE(review): the recorded output below lacks `cat2lbl_idx` although 'idx' is in the helper's
# key list — it may predate that change or the batch may not carry the key; re-run to confirm.
p = Parameters.from_feat_meta_aug_prefix('data', 'cat2lbl', **b); p.keys()

dict_keys(['cat2lbl_attention_mask', 'cat2lbl_input_ids', 'cat2lbl_data2ptr'])

In [None]:
p = Parameters.from_meta_pred_prefix('cat', **b); p.keys()

dict_keys(['cat2lbl', 'cat2data'])

## Encoder

In [None]:
#| export
class Encoder(DistilBertPreTrainedModel):
    """DistilBERT encoder wrapped with LoRA adapters that can fuse metadata into its token embeddings.

    The base model is wrapped by PEFT with a default adapter named "lbl2data"; one extra
    adapter is registered for each non-`None` augmentation prefix. Three `RepresentationHead`s
    produce pooled representations (plain, fused, metadata) and a `CrossAttention` layer
    injects metadata representations into the token embeddings.
    """

    def __init__(
        self, 
        config:PretrainedConfig,
        base_model:nn.Module, 
        resize_length:Optional[int]=None,

        lora_r:Optional[int]=8,
        lora_alpha:Optional[int]=32,

        data_aug_meta_prefix:Optional[str]=None, 
        lbl2data_aug_meta_prefix:Optional[str]=None, 
    ):
        """
        Args:
            config: model configuration (provides `dim`, `n_heads`, ... for the heads).
            base_model: model to wrap with LoRA adapters.
            resize_length: optional size of a cached vector of ones reused by `resize`.
            lora_r, lora_alpha: LoRA rank and scaling passed to `LoraConfig`.
            data_aug_meta_prefix, lbl2data_aug_meta_prefix: names of the additional adapters;
                `fuse_meta_into_embeddings` activates an adapter by metadata name.
        """
        super().__init__(config)
        store_attr('data_aug_meta_prefix,lbl2data_aug_meta_prefix')

        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=0.05,
            target_modules=["q_lin", "k_lin","v_lin"],
            bias='none',
        )
        self.distilbert = get_peft_model(base_model, lora_config, adapter_name="lbl2data")
        if self.data_aug_meta_prefix is not None: self.distilbert.add_adapter(self.data_aug_meta_prefix, lora_config)
        if self.lbl2data_aug_meta_prefix is not None: self.distilbert.add_adapter(self.lbl2data_aug_meta_prefix, lora_config)
        # By default every encoder parameter (base weights and adapters) is trainable.
        self._mark_entire_encoder_as_trainable()

        self.dr_head = RepresentationHead(config)
        self.dr_fused_head =  RepresentationHead(config)
        self.meta_head = RepresentationHead(config)
        self.cross_head = CrossAttention(config)

        # Plain attribute (not a registered buffer); `resize` moves it to the input device on use.
        self.ones = torch.ones(resize_length, dtype=torch.long, device=self.device) if resize_length is not None else None
        self.post_init()

    def _mark_entire_encoder_as_trainable(self):
        """Enable gradients for every parameter of the wrapped encoder."""
        for p in self.distilbert.parameters(): p.requires_grad_(True)

    def _mark_only_adapters_as_trainable(self):
        """Delegate to the PEFT base model so that only adapter parameters stay trainable."""
        self.distilbert.base_model._mark_only_adapters_as_trainable(self.distilbert)

    def get_position_embeddings(self) -> nn.Embedding:
        """Return the position embeddings of the wrapped model."""
        return self.distilbert.get_position_embeddings()

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """Resize the position embeddings of the wrapped model."""
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)

    def resize(self, inputs:torch.Tensor, mask:torch.Tensor, num_inputs:torch.Tensor):
        """Pad variable-sized groups of rows to a common size.

        `inputs` and `mask` hold the rows of all groups back to back; `num_inputs[i]` is the
        size of group `i`. Every group is expanded to `num_inputs.max()` rows by repeating its
        last row. Within a padded group the layout is: the first `n-1` original rows, then the
        padding copies (mask zeroed), and the original last row in the final slot.

        Returns:
            `(resized_inputs, resized_mask)`, each with `len(num_inputs) * num_inputs.max()` rows.

        Raises:
            ValueError: if any group is empty.
        """
        if torch.any(num_inputs == 0): raise ValueError("`num_inputs` should be non-zero positive integer.")
        bsz, total_num_inputs = num_inputs.shape[0], inputs.shape[0]

        # `self.ones` is None when no `resize_length` was given; only move the cache if it exists.
        if self.ones is not None: self.ones = self.ones.to(inputs.device)
        ones = (
            torch.ones(total_num_inputs, dtype=torch.long, device=inputs.device) 
            if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
        )

        max_num_inputs = num_inputs.max()
        # Number of times the last row of each group must appear to reach the common size.
        xnum_inputs = max_num_inputs-num_inputs+1

        # Position of the last row of each group in the flat layout.
        inputs_ptr = num_inputs.cumsum(dim=0)-1
        repeat_inputs = ones.scatter(0, inputs_ptr, xnum_inputs)

        resized_inputs = inputs.repeat_interleave(repeat_inputs, dim=0)
        resized_mask = mask.repeat_interleave(repeat_inputs, dim=0)

        # 1 for the leading original rows, 0 for every copy of the last row ...
        ignore_mask_idx = ones.scatter(0, inputs_ptr, 0).repeat_interleave(repeat_inputs, dim=0).view(bsz, -1)
        # ... except the final slot of each group, which keeps the real last row.
        ignore_mask_idx[:, -1] = 1; ignore_mask_idx = ignore_mask_idx.view(-1, 1)

        resized_mask *= ignore_mask_idx

        return resized_inputs,resized_mask

    def encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        """Run the wrapped encoder with the currently active adapter."""
        return self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

    def dr(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        """L2-normalised, mean-pooled output of `dr_head`."""
        embed = self.dr_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def dr_fused(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        """L2-normalised, mean-pooled output of `dr_fused_head`."""
        embed = self.dr_fused_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def meta(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        """L2-normalised, mean-pooled output of `meta_head`."""
        embed = self.meta_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def meta_unnormalized(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        """Mean-pooled output of `meta_head` without normalisation."""
        embed = self.meta_head(embed)
        return Pooling.mean_pooling(embed, attention_mask)

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
        """Add cross-attended metadata information to the token embeddings.

        For every metadata entry in `meta_kwargs` (as produced by
        `Parameters.from_meta_aug_prefix`) the adapter of that name is activated, the metadata
        of the rows that have any (`data2ptr > 0`) is encoded — or taken from a precomputed
        `meta_repr` — padded per row with `resize`, and attended to by the row's tokens. The
        attention output is added to the embeddings of those rows. Entries are processed
        sequentially, so a later entry attends from embeddings already updated by earlier ones.

        Returns:
            `(embed, meta_repr)`: the fused embeddings and a dict mapping each metadata name to
            the normalised representations of its valid items (an empty `(0, dim)` tensor when
            no row has that metadata).

        Note: the active adapter is left at the last metadata name; the caller restores it.
        """
        meta_repr = {}

        for m_key, m_args in meta_kwargs.items():
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)

            self.distilbert.set_adapter(m_key)

            if len(idx):
                if 'meta_repr' in m_args:
                    m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
                    m_repr,m_repr_mask = self.resize(m_repr, m_repr_mask, m_args['data2ptr'][idx])
                    m_repr_mask = m_repr_mask.bool()
                else:
                    m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
                                                                m_args['data2ptr'][idx])

                    m_embed = self.encode(m_input_ids, m_attention_mask)[0]

                    m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
                    # Padding items produced by `resize` have an all-zero attention mask.
                    m_repr_mask = torch.any(m_attention_mask, dim=1)

                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)

                meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)

                fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]
                # Update a copy rather than the incoming tensor: that tensor is the encoder
                # output already consumed by `dr` in `forward`, and writing into it in place
                # would invalidate the input autograd saved for that head.
                updated = embed.clone()
                updated[idx] += fused_embed
                embed = updated

        return embed, meta_repr

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):
        """Encode a batch and, optionally, its metadata-fused variant.

        Args:
            data_input_ids, data_attention_mask: tokenised inputs.
            data_aug_meta_prefix: when given, metadata inputs in `kwargs` matching this prefix
                are fused into the embeddings.
            data_type: `"meta"` routes pooling through `meta_head`; anything else uses `dr_head`.
            data_unnormalized: with `data_type="meta"`, skip the L2 normalisation.
            **kwargs: flat metadata inputs (see `Parameters.from_meta_aug_prefix`).

        Returns:
            `EncoderOutput` with `rep` always set and `fused_rep` / `meta_repr` set only when
            matching metadata inputs were found.
        """
        data_o = self.encode(data_input_ids, data_attention_mask)

        if data_type == "meta":
            data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
        else: 
            data_repr = self.dr(data_o[0], data_attention_mask)

        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                # While encoding metadata during training, only adapter weights are trainable.
                if self.training: self._mark_only_adapters_as_trainable()
                data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
                                                                             data_attention_mask, 
                                                                             meta_kwargs)
                data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)

                # Restore the default adapter and full trainability for subsequent calls.
                self.distilbert.set_adapter('lbl2data')
                if self.training: self._mark_entire_encoder_as_trainable()

        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,
        )
        

## `RAD001`

In [None]:
#| export
class RAD001(DistilBertPreTrainedModel):
    """Retrieval model that trains an `Encoder` with a triplet loss plus optional auxiliary losses.

    The main loss compares the metadata-fused data representation with the label
    representation. Optional terms: a loss on the plain data representation
    (`use_query_loss`), a calibration loss (`use_calib_loss`), metadata-prediction losses
    (driven by the `*_pred_meta_prefix` arguments) and fusion losses (`use_fusion_loss`).
    """
    use_generation,use_representation = False,True

    def __init__(
        self, config,

        base_model:nn.Module, 
        resize_length:Optional[int]=None,
        lora_r:Optional[int]=8,
        lora_alpha:Optional[int]=32,

        num_batch_labels:Optional[int]=None, 
        batch_size:Optional[int]=None,
        margin:Optional[float]=0.3,
        num_negatives:Optional[int]=5,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=True,

        data_aug_meta_prefix:Optional[str]=None, 
        lbl2data_aug_meta_prefix:Optional[str]=None, 

        data_pred_meta_prefix:Optional[str]=None,
        lbl2data_pred_meta_prefix:Optional[str]=None,

        use_query_loss:Optional[bool]=False,

        calib_margin:Optional[float]=0.3,
        calib_num_negatives:Optional[int]=10,
        calib_tau:Optional[float]=0.1,
        calib_apply_softmax:Optional[bool]=True,
        calib_loss_weight:Optional[float]=0.1,
        use_calib_loss:Optional[bool]=False,

        meta_loss_weight:Optional[Union[List,float]]=0.3,
        use_fusion_loss:Optional[bool]=False,
        fusion_loss_weight:Optional[float]=0.15,

        use_encoder_parallel:Optional[bool]=True,
    ):
        """
        Args:
            config: model configuration shared with the encoder.
            base_model, resize_length, lora_r, lora_alpha: forwarded to `Encoder`.
            num_batch_labels, batch_size, margin, num_negatives, tau, apply_softmax:
                forwarded to the `MultiTriplet` representation loss.
            data_aug_meta_prefix, lbl2data_aug_meta_prefix: metadata fused into the data /
                label encodings.
            data_pred_meta_prefix, lbl2data_pred_meta_prefix: metadata used by the
                metadata-prediction loss.
            use_query_loss: also apply the representation loss to the un-fused data representation.
            calib_*: settings of the `Calibration` loss; `use_calib_loss` enables it.
            meta_loss_weight: scalar (split evenly across metadata inputs) or one weight per input.
            use_fusion_loss, fusion_loss_weight: enable / scale the fusion loss.
            use_encoder_parallel: wrap the encoder in `XCDataParallel` for each call.
        """
        super().__init__(config)
        self.m_lw, self.f_lw, self.c_lw = meta_loss_weight, fusion_loss_weight,calib_loss_weight
        store_attr('data_pred_meta_prefix,lbl2data_pred_meta_prefix')
        store_attr('data_aug_meta_prefix,lbl2data_aug_meta_prefix')
        store_attr('use_query_loss,use_calib_loss,use_fusion_loss,use_encoder_parallel')

        self.encoder = Encoder(config, base_model=base_model, resize_length=resize_length, lora_r=lora_r, lora_alpha=lora_alpha, 
                               data_aug_meta_prefix=data_aug_meta_prefix, lbl2data_aug_meta_prefix=lbl2data_aug_meta_prefix)

        self.rep_loss_fn = MultiTriplet(bsz=batch_size, tn_targ=num_batch_labels, margin=margin, n_negatives=num_negatives, 
                                        tau=tau, apply_softmax=apply_softmax, reduce='mean')
        self.cab_loss_fn = Calibration(margin=calib_margin, tau=calib_tau, n_negatives=calib_num_negatives, 
                                       apply_softmax=calib_apply_softmax, reduce='mean')

    def init_retrieval_head(self):
        """Reset the three representation heads of the encoder to identity weights."""
        if self.encoder is None: raise ValueError('`self.encoder` is not initialized.')
        self.encoder.dr_head.post_init()
        self.encoder.dr_fused_head.post_init()
        self.encoder.meta_head.post_init()

    def init_cross_head(self):
        """Reset the cross-attention projections of the encoder to identity weights."""
        if self.encoder is None: raise ValueError('`self.encoder` is not initialized.')
        self.encoder.cross_head.post_init()

    def _get_encoder(self):
        """Return the encoder, wrapped in a fresh `XCDataParallel` when `use_encoder_parallel` is set."""
        if self.use_encoder_parallel: return XCDataParallel(module=self.encoder)
        return self.encoder

    @staticmethod
    def _pointer_key(inputs):
        """Name of the pointer tensor linking a metadata input to its anchors.

        `lbl2data2ptr` takes precedence over `data2ptr` because label-side metadata batches
        may carry both.

        Raises:
            ValueError: if neither pointer is present.
        """
        if 'lbl2data2ptr' in inputs: return 'lbl2data2ptr'
        if 'data2ptr' in inputs: return 'data2ptr'
        raise ValueError('Invalid metadata input arguments.')

    def compute_loss(self, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        """Representation loss between inputs and targets (delegates to `rep_loss_fn`)."""
        return self.rep_loss_fn(inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def calibration_loss(self, einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        """Calibration loss scaled by `calib_loss_weight`."""
        return self.c_lw * self.cab_loss_fn(einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def compute_meta_loss(self, data_repr, lbl2data_repr, **kwargs):
        """Weighted sum of representation losses between anchors and their metadata.

        Metadata inputs are selected from `kwargs` with the two prediction prefixes. Inputs
        linked through `lbl2data2ptr` use `lbl2data_repr` as anchors, those linked through
        `data2ptr` use `data_repr`. The metadata itself is encoded with `data_type="meta"`.
        Returns `0.0` when nothing contributes.
        """
        encoder = self._get_encoder()

        data_meta_inputs = Parameters.from_meta_pred_prefix(self.data_pred_meta_prefix, **kwargs)
        lbl2data_meta_inputs = Parameters.from_meta_pred_prefix(self.lbl2data_pred_meta_prefix, **kwargs)
        meta_inputs = {**data_meta_inputs, **lbl2data_meta_inputs}

        m_lw = Parameters.get_meta_loss_weights(self.m_lw, len(meta_inputs)) if len(meta_inputs) else []

        loss = 0.0
        for inputs,lw in zip(meta_inputs.values(), m_lw):
            ptr_key = self._pointer_key(inputs)
            anchor_repr = lbl2data_repr if ptr_key == 'lbl2data2ptr' else data_repr

            # Only anchors that actually have metadata take part; skip the encoder call otherwise.
            idx = torch.where(inputs[ptr_key])[0]
            if len(idx) > 0:
                inputs_o = encoder(data_input_ids=inputs['input_ids'], data_attention_mask=inputs['attention_mask'], 
                                   data_type="meta")
                m_loss = self.rep_loss_fn(anchor_repr[idx], inputs_o.rep, inputs[ptr_key][idx], inputs['idx'], 
                                          inputs[f'p{ptr_key}'][idx], inputs['pidx'])
                loss += lw * m_loss
        return loss

    def compute_fusion_loss(self, data_repr, meta_repr:Dict, prefix:str, **kwargs):
        """Representation loss between `data_repr` and the metadata representations used for fusion.

        `meta_repr` is the dict returned by the encoder (`EncoderOutput.meta_repr`); the
        matching pointers and indices are looked up in `kwargs` via `prefix`. Each term is
        scaled by `fusion_loss_weight`. Returns `0.0` when `meta_repr` is `None` or nothing
        contributes.
        """
        meta_inputs = Parameters.from_meta_pred_prefix(prefix, **kwargs)

        loss = 0.0
        if meta_repr is not None:
            for key,input_repr in meta_repr.items():
                inputs = meta_inputs[key]
                ptr_key = self._pointer_key(inputs)

                idx = torch.where(inputs[ptr_key])[0]
                if len(idx) > 0:
                    m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs[ptr_key][idx], inputs['idx'], 
                                              inputs[f'p{ptr_key}'][idx], inputs['pidx'])
                    loss += self.f_lw * m_loss
        return loss

    def get_meta_representation(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Encode inputs as metadata: un-normalised, mean-pooled `meta_head` representations.

        No metadata prefix is passed to the encoder, so `data_fused_repr` of the result is `None`.
        """
        encoder = self._get_encoder()

        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_unnormalized=True, data_type="meta")
        return RADOutput(
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
        )

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        """Encode data (and labels, when given) and compute the training loss.

        The loss is only computed when `lbl2data_input_ids` is provided. `kwargs` carries the
        flat metadata inputs consumed by the `Parameters` helpers. `output_attentions` and
        `output_hidden_states` are accepted but not used.

        NOTE(review): the main loss reads `data_o.fused_rep`, which is `None` when no metadata
        is fused into the data side (e.g. `data_aug_meta_prefix=None`) — confirm that
        training is never run in that configuration.

        Returns:
            `RADOutput`, or a tuple `(loss?, data logits, data rep, data fused rep,
            label logits, label rep, label fused rep)` when `return_dict` is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder = self._get_encoder()

        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)

        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)

            # Main term: fused data representation against label representation.
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)

            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)

            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)

        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o

        return RADOutput(
            loss=loss,

            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,

            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        

### Example

In [None]:
from transformers import DistilBertConfig

In [None]:
# Pretrained DistilBERT checkpoint that the LoRA adapters will wrap.
base_model = DistilBertModel.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4')

In [None]:
# Example configuration: fuse `cat2data` metadata into the data side only, enable the extra
# query loss, and disable the calibration, metadata-prediction (no pred prefixes) and fusion
# losses. The encoder is called directly (no `XCDataParallel`).
model = RAD001(DistilBertConfig(), resize_length=5000, base_model=base_model, lora_r=8, lora_alpha=32,
               
               batch_size=100, num_batch_labels=5000, margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
               
               use_query_loss=True,
               
               calib_margin=0.3, calib_num_negatives=5, calib_tau=0.1, calib_apply_softmax=True, calib_loss_weight=0.1,
               use_calib_loss=False,
               
               meta_loss_weight=0.0, fusion_loss_weight=0.0, use_fusion_loss=False,
               use_encoder_parallel=False)

# Reset the representation heads and the cross-attention projections to identity weights.
model.init_retrieval_head()
model.init_cross_head()

In [None]:
# Prepare the sample batch for `model`, requesting the listed metadata fields in addition.
# NOTE(review): `prepare_batch` comes from a star import; presumably it keeps the arguments
# of `model.forward` plus `m_args` — confirm against its definition.
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
    'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_input_ids', 
    'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr',
])

In [None]:
o = model(**b.to(model.device))

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
o.loss

tensor(0.0638, grad_fn=<AddBackward0>)

In [None]:
def func(debug:bool=False):
    """Run one forward pass of `model` on the prepared batch `b`.

    Parameters
    ----------
    debug : bool, default False
        When True, drop into `pdb` immediately before the forward call so the
        pass can be stepped through interactively. The default is
        non-interactive so that running every cell of the notebook never
        blocks on a debugger prompt.

    Returns
    -------
    Whatever `model(**b.to(model.device))` returns.
    """
    if debug:
        # Opt-in breakpoint: execution stops on the `return` line below.
        import pdb; pdb.set_trace()
    return model(**b.to(model.device))
    

In [None]:
# Call the helper defined in the previous cell.
func()

> /tmp/ipykernel_22396/3657616883.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))
      4 



ipdb>  b model.forward


Breakpoint 1 at /tmp/ipykernel_22396/2456184125.py:151


ipdb>  b model.encoder.forward


Breakpoint 2 at /tmp/ipykernel_22396/3427015654.py:141


ipdb>  r


> /tmp/ipykernel_22396/2456184125.py(166)forward()
    164         **kwargs
    165     ):  
--> 166         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    167 
    168         if self.use_encoder_parallel:



ipdb>  n


> /tmp/ipykernel_22396/2456184125.py(168)forward()
    166         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    167 
--> 168         if self.use_encoder_parallel:
    169             encoder = XCDataParallel(module=self.encoder)
    170         else: encoder = self.encoder



ipdb>  


> /tmp/ipykernel_22396/2456184125.py(170)forward()
    168         if self.use_encoder_parallel:
    169             encoder = XCDataParallel(module=self.encoder)
--> 170         else: encoder = self.encoder
    171 
    172         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_22396/2456184125.py(172)forward()
    170         else: encoder = self.encoder
    171 
--> 172         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
    173         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    174                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)



ipdb>  n


> /tmp/ipykernel_22396/2456184125.py(173)forward()
    171 
    172         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
--> 173         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    174                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    175 



ipdb>  data_meta_kwargs


{'cat2data_attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), 'cat2data_input_ids': tensor([[  101,  3803,  1011,  4092,  3032,  1998,  6500,   102,     0,     0,
             0,     0,     0,     0],
        [  101,  7139,  3032,  1997,  1996,  2983,  1997,  1996,  4549,   102,
             0,     0,     0,     0],
        [  101, 17867,  2575,  4487,  

ipdb>  data_meta_kwargs.keys()


dict_keys(['cat2data_attention_mask', 'cat2data_input_ids', 'cat2data_idx', 'cat2data_data2ptr'])


ipdb>  n


> /tmp/ipykernel_22396/2456184125.py(174)forward()
    172         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
    173         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
--> 174                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    175 
    176 



ipdb>  r


> /tmp/ipykernel_22396/3427015654.py(150)forward()
    148         **kwargs
    149     ):
--> 150         data_o = self.encode(data_input_ids, data_attention_mask)
    151 
    152         if data_type is not None and data_type == "meta":



ipdb>  xx = [n for n,p in self.distilbert.named_parameters()]
ipdb>  xx


['base_model.model.embeddings.word_embeddings.weight', 'base_model.model.embeddings.position_embeddings.weight', 'base_model.model.embeddings.LayerNorm.weight', 'base_model.model.embeddings.LayerNorm.bias', 'base_model.model.transformer.layer.0.attention.q_lin.base_layer.weight', 'base_model.model.transformer.layer.0.attention.q_lin.base_layer.bias', 'base_model.model.transformer.layer.0.attention.q_lin.lora_A.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_A.cat2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_B.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_B.cat2data.weight', 'base_model.model.transformer.layer.0.attention.k_lin.base_layer.weight', 'base_model.model.transformer.layer.0.attention.k_lin.base_layer.bias', 'base_model.model.transformer.layer.0.attention.k_lin.lora_A.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.k_lin.lora_A.cat2data.weight', 'base_model.model.transform

ipdb>  n


> /tmp/ipykernel_22396/3427015654.py(152)forward()
    150         data_o = self.encode(data_input_ids, data_attention_mask)
    151 
--> 152         if data_type is not None and data_type == "meta":
    153             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    154         else:



ipdb>  n


> /tmp/ipykernel_22396/3427015654.py(155)forward()
    153             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    154         else:
--> 155             data_repr = self.dr(data_o[0], data_attention_mask)
    156 
    157         if self.training: self._mark_only_adapters_as_trainable()



ipdb>  n


> /tmp/ipykernel_22396/3427015654.py(157)forward()
    155             data_repr = self.dr(data_o[0], data_attention_mask)
    156 
--> 157         if self.training: self._mark_only_adapters_as_trainable()
    158 
    159         data_fused_repr = meta_repr = None



ipdb>  self.training


True


ipdb>  n


> /tmp/ipykernel_22396/3427015654.py(159)forward()
    157         if self.training: self._mark_only_adapters_as_trainable()
    158 
--> 159         data_fused_repr = meta_repr = None
    160         if data_aug_meta_prefix is not None:
    161             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  xx = [n for n,p in self.distilbert.named_parameters()]
ipdb>  xx


['base_model.model.embeddings.word_embeddings.weight', 'base_model.model.embeddings.position_embeddings.weight', 'base_model.model.embeddings.LayerNorm.weight', 'base_model.model.embeddings.LayerNorm.bias', 'base_model.model.transformer.layer.0.attention.q_lin.base_layer.weight', 'base_model.model.transformer.layer.0.attention.q_lin.base_layer.bias', 'base_model.model.transformer.layer.0.attention.q_lin.lora_A.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_A.cat2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_B.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_B.cat2data.weight', 'base_model.model.transformer.layer.0.attention.k_lin.base_layer.weight', 'base_model.model.transformer.layer.0.attention.k_lin.base_layer.bias', 'base_model.model.transformer.layer.0.attention.k_lin.lora_A.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.k_lin.lora_A.cat2data.weight', 'base_model.model.transform

ipdb>  self._mark_only_adapters_as_trainable()
ipdb>  xx = [n for n,p in self.distilbert.named_parameters() if p.requires_grad]
ipdb>  xx


['base_model.model.transformer.layer.0.attention.q_lin.lora_A.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_A.cat2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_B.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_B.cat2data.weight', 'base_model.model.transformer.layer.0.attention.k_lin.lora_A.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.k_lin.lora_A.cat2data.weight', 'base_model.model.transformer.layer.0.attention.k_lin.lora_B.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.k_lin.lora_B.cat2data.weight', 'base_model.model.transformer.layer.0.attention.v_lin.lora_A.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.v_lin.lora_A.cat2data.weight', 'base_model.model.transformer.layer.0.attention.v_lin.lora_B.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.v_lin.lora_B.cat2data.weight', 'base_model.model.transformer.layer.1.attention.q_l

ipdb>  n


> /tmp/ipykernel_22396/3427015654.py(160)forward()
    158 
    159         data_fused_repr = meta_repr = None
--> 160         if data_aug_meta_prefix is not None:
    161             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    162             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(161)forward()
    159         data_fused_repr = meta_repr = None
    160         if data_aug_meta_prefix is not None:
--> 161             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    162             if len(meta_kwargs):
    163                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(162)forward()
    160         if data_aug_meta_prefix is not None:
    161             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
--> 162             if len(meta_kwargs):
    163                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    164                                                                              data_attention_mask,



ipdb>  meta_kwargs.keys()


dict_keys(['cat2data'])


ipdb>  n


> /tmp/ipykernel_22396/3427015654.py(163)forward()
    161             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    162             if len(meta_kwargs):
--> 163                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    164                                                                              data_attention_mask,
    165                                                                              meta_kwargs)



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(164)forward()
    162             if len(meta_kwargs):
    163                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
--> 164                                                                              data_attention_mask,
    165                                                                              meta_kwargs)
    166                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(165)forward()
    163                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    164                                                                              data_attention_mask,
--> 165                                                                              meta_kwargs)
    166                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    167 



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(163)forward()
    161             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    162             if len(meta_kwargs):
--> 163                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    164                                                                              data_attention_mask,
    165                                                                              meta_kwargs)



ipdb>  s


--Call--
> /tmp/ipykernel_22396/3427015654.py(108)fuse_meta_into_embeddings()
    106 
    107 
--> 108     def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
    109         meta_repr = {}
    110 



ipdb>  n


> /tmp/ipykernel_22396/3427015654.py(109)fuse_meta_into_embeddings()
    107 
    108     def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
--> 109         meta_repr = {}
    110 
    111         for m_key, m_args in meta_kwargs.items():



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(111)fuse_meta_into_embeddings()
    109         meta_repr = {}
    110 
--> 111         for m_key, m_args in meta_kwargs.items():
    112             idx = torch.where(m_args['data2ptr'] > 0)[0]
    113             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(112)fuse_meta_into_embeddings()
    110 
    111         for m_key, m_args in meta_kwargs.items():
--> 112             idx = torch.where(m_args['data2ptr'] > 0)[0]
    113             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
    114 



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(113)fuse_meta_into_embeddings()
    111         for m_key, m_args in meta_kwargs.items():
    112             idx = torch.where(m_args['data2ptr'] > 0)[0]
--> 113             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
    114 
    115             self.distilbert.set_adapter(m_key)



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(115)fuse_meta_into_embeddings()
    113             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
    114 
--> 115             self.distilbert.set_adapter(m_key)
    116 
    117             if len(idx):



ipdb>  m_key


'cat2data'


ipdb>  n


> /tmp/ipykernel_22396/3427015654.py(117)fuse_meta_into_embeddings()
    115             self.distilbert.set_adapter(m_key)
    116 
--> 117             if len(idx):
    118                 if 'meta_repr' in m_args:
    119                     m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)



ipdb>  xx = [n for n,p in self.distilbert.named_parameters() if p.requires_grad]
ipdb>  xx


['base_model.model.transformer.layer.0.attention.q_lin.lora_A.cat2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_B.cat2data.weight', 'base_model.model.transformer.layer.0.attention.k_lin.lora_A.cat2data.weight', 'base_model.model.transformer.layer.0.attention.k_lin.lora_B.cat2data.weight', 'base_model.model.transformer.layer.0.attention.v_lin.lora_A.cat2data.weight', 'base_model.model.transformer.layer.0.attention.v_lin.lora_B.cat2data.weight', 'base_model.model.transformer.layer.1.attention.q_lin.lora_A.cat2data.weight', 'base_model.model.transformer.layer.1.attention.q_lin.lora_B.cat2data.weight', 'base_model.model.transformer.layer.1.attention.k_lin.lora_A.cat2data.weight', 'base_model.model.transformer.layer.1.attention.k_lin.lora_B.cat2data.weight', 'base_model.model.transformer.layer.1.attention.v_lin.lora_A.cat2data.weight', 'base_model.model.transformer.layer.1.attention.v_lin.lora_B.cat2data.weight', 'base_model.model.transformer.layer.2.attention.q_l

ipdb>  self.distilbert.active_adapters


['cat2data']


ipdb>  n


> /tmp/ipykernel_22396/3427015654.py(118)fuse_meta_into_embeddings()
    116 
    117             if len(idx):
--> 118                 if 'meta_repr' in m_args:
    119                     m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
    120                     m_repr,m_repr_mask = self.resize(m_repr, m_repr_mask, m_args['data2ptr'][idx])



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(123)fuse_meta_into_embeddings()
    121                     m_repr_mask = m_repr_mask.bool()
    122                 else:
--> 123                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
    124                                                                 m_args['data2ptr'][idx])
    125                     n_meta = m_args['data2ptr'].max()



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(124)fuse_meta_into_embeddings()
    122                 else:
    123                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
--> 124                                                                 m_args['data2ptr'][idx])
    125                     n_meta = m_args['data2ptr'].max()
    126 



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(123)fuse_meta_into_embeddings()
    121                     m_repr_mask = m_repr_mask.bool()
    122                 else:
--> 123                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
    124                                                                 m_args['data2ptr'][idx])
    125                     n_meta = m_args['data2ptr'].max()



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(125)fuse_meta_into_embeddings()
    123                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
    124                                                                 m_args['data2ptr'][idx])
--> 125                     n_meta = m_args['data2ptr'].max()
    126 
    127                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(127)fuse_meta_into_embeddings()
    125                     n_meta = m_args['data2ptr'].max()
    126 
--> 127                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]
    128 
    129                     m_repr = self.meta_unnormalized(m_embed, m_attention_mask)



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(129)fuse_meta_into_embeddings()
    127                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]
    128 
--> 129                     m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
    130                     m_repr_mask = torch.any(m_attention_mask, dim=1)
    131 



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(130)fuse_meta_into_embeddings()
    128 
    129                     m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
--> 130                     m_repr_mask = torch.any(m_attention_mask, dim=1)
    131 
    132                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(132)fuse_meta_into_embeddings()
    130                     m_repr_mask = torch.any(m_attention_mask, dim=1)
    131 
--> 132                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
    133 
    134                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(134)fuse_meta_into_embeddings()
    132                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
    133 
--> 134                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)
    135 
    136                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(136)fuse_meta_into_embeddings()
    134                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)
    135 
--> 136                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]
    137                 embed[idx] += fused_embed
    138 



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(137)fuse_meta_into_embeddings()
    135 
    136                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]
--> 137                 embed[idx] += fused_embed
    138 
    139         return embed, meta_repr



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(111)fuse_meta_into_embeddings()
    109         meta_repr = {}
    110 
--> 111         for m_key, m_args in meta_kwargs.items():
    112             idx = torch.where(m_args['data2ptr'] > 0)[0]
    113             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(139)fuse_meta_into_embeddings()
    137                 embed[idx] += fused_embed
    138 
--> 139         return embed, meta_repr
    140 
2   141     def forward(



ipdb>  


--Return--
(tensor([[[-1....PutBackward0>), {'cat2data': tensor([[-0.0...DivBackward0>)})
> /tmp/ipykernel_22396/3427015654.py(139)fuse_meta_into_embeddings()
    137                 embed[idx] += fused_embed
    138 
--> 139         return embed, meta_repr
    140 
2   141     def forward(



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(166)forward()
    164                                                                              data_attention_mask,
    165                                                                              meta_kwargs)
--> 166                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    167 
    168         if self.training:



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(168)forward()
    166                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    167 
--> 168         if self.training:
    169             self.distilbert.set_adapter('lbl2data')
    170             self._mark_entire_encoder_as_trainable()



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(169)forward()
    167 
    168         if self.training:
--> 169             self.distilbert.set_adapter('lbl2data')
    170             self._mark_entire_encoder_as_trainable()
    171 



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(170)forward()
    168         if self.training:
    169             self.distilbert.set_adapter('lbl2data')
--> 170             self._mark_entire_encoder_as_trainable()
    171 
    172         return EncoderOutput(



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(172)forward()
    170             self._mark_entire_encoder_as_trainable()
    171 
--> 172         return EncoderOutput(
    173             rep=data_repr,
    174             fused_rep=data_fused_repr,



ipdb>  xx = [n for n,p in self.distilbert.named_parameters() if p.requires_grad]
ipdb>  xx


['base_model.model.embeddings.word_embeddings.weight', 'base_model.model.embeddings.position_embeddings.weight', 'base_model.model.embeddings.LayerNorm.weight', 'base_model.model.embeddings.LayerNorm.bias', 'base_model.model.transformer.layer.0.attention.q_lin.base_layer.weight', 'base_model.model.transformer.layer.0.attention.q_lin.base_layer.bias', 'base_model.model.transformer.layer.0.attention.q_lin.lora_A.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_A.cat2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_B.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_B.cat2data.weight', 'base_model.model.transformer.layer.0.attention.k_lin.base_layer.weight', 'base_model.model.transformer.layer.0.attention.k_lin.base_layer.bias', 'base_model.model.transformer.layer.0.attention.k_lin.lora_A.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.k_lin.lora_A.cat2data.weight', 'base_model.model.transform

ipdb>  self.distilbert.active_adapters


['lbl2data']


ipdb>  c


> /tmp/ipykernel_22396/3427015654.py(150)forward()
    148         **kwargs
    149     ):
--> 150         data_o = self.encode(data_input_ids, data_attention_mask)
    151 
    152         if data_type is not None and data_type == "meta":



ipdb>  n


> /tmp/ipykernel_22396/3427015654.py(152)forward()
    150         data_o = self.encode(data_input_ids, data_attention_mask)
    151 
--> 152         if data_type is not None and data_type == "meta":
    153             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    154         else:



ipdb>  xx = [n for n,p in self.distilbert.named_parameters() if p.requires_grad]
ipdb>  xx


['base_model.model.embeddings.word_embeddings.weight', 'base_model.model.embeddings.position_embeddings.weight', 'base_model.model.embeddings.LayerNorm.weight', 'base_model.model.embeddings.LayerNorm.bias', 'base_model.model.transformer.layer.0.attention.q_lin.base_layer.weight', 'base_model.model.transformer.layer.0.attention.q_lin.base_layer.bias', 'base_model.model.transformer.layer.0.attention.q_lin.lora_A.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_A.cat2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_B.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.q_lin.lora_B.cat2data.weight', 'base_model.model.transformer.layer.0.attention.k_lin.base_layer.weight', 'base_model.model.transformer.layer.0.attention.k_lin.base_layer.bias', 'base_model.model.transformer.layer.0.attention.k_lin.lora_A.lbl2data.weight', 'base_model.model.transformer.layer.0.attention.k_lin.lora_A.cat2data.weight', 'base_model.model.transform

ipdb>  n


> /tmp/ipykernel_22396/3427015654.py(155)forward()
    153             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    154         else:
--> 155             data_repr = self.dr(data_o[0], data_attention_mask)
    156 
    157         if self.training: self._mark_only_adapters_as_trainable()



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(157)forward()
    155             data_repr = self.dr(data_o[0], data_attention_mask)
    156 
--> 157         if self.training: self._mark_only_adapters_as_trainable()
    158 
    159         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(159)forward()
    157         if self.training: self._mark_only_adapters_as_trainable()
    158 
--> 159         data_fused_repr = meta_repr = None
    160         if data_aug_meta_prefix is not None:
    161             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(160)forward()
    158 
    159         data_fused_repr = meta_repr = None
--> 160         if data_aug_meta_prefix is not None:
    161             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    162             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(168)forward()
    166                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    167 
--> 168         if self.training:
    169             self.distilbert.set_adapter('lbl2data')
    170             self._mark_entire_encoder_as_trainable()



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(169)forward()
    167 
    168         if self.training:
--> 169             self.distilbert.set_adapter('lbl2data')
    170             self._mark_entire_encoder_as_trainable()
    171 



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(170)forward()
    168         if self.training:
    169             self.distilbert.set_adapter('lbl2data')
--> 170             self._mark_entire_encoder_as_trainable()
    171 
    172         return EncoderOutput(



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(172)forward()
    170             self._mark_entire_encoder_as_trainable()
    171 
--> 172         return EncoderOutput(
    173             rep=data_repr,
    174             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(173)forward()
    171 
    172         return EncoderOutput(
--> 173             rep=data_repr,
    174             fused_rep=data_fused_repr,
    175             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(174)forward()
    172         return EncoderOutput(
    173             rep=data_repr,
--> 174             fused_rep=data_fused_repr,
    175             meta_repr=meta_repr,
    176         )



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(175)forward()
    173             rep=data_repr,
    174             fused_rep=data_fused_repr,
--> 175             meta_repr=meta_repr,
    176         )
    177 



ipdb>  


> /tmp/ipykernel_22396/3427015654.py(172)forward()
    170             self._mark_entire_encoder_as_trainable()
    171 
--> 172         return EncoderOutput(
    173             rep=data_repr,
    174             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...eta_repr=None)
> /tmp/ipykernel_22396/3427015654.py(172)forward()
    170             self._mark_entire_encoder_as_trainable()
    171 
--> 172         return EncoderOutput(
    173             rep=data_repr,
    174             fused_rep=data_fused_repr,



ipdb>  c


RADOutput(loss=tensor(0.0592, grad_fn=<AddBackward0>), logits=None, data_repr=tensor([[-0.0355, -0.0144, -0.0295,  ...,  0.0476,  0.0050, -0.0291],
        [-0.0242, -0.0198, -0.0323,  ...,  0.0314, -0.0336,  0.0125],
        [ 0.0014,  0.0163, -0.0380,  ...,  0.0723, -0.0179,  0.0352],
        [ 0.0704,  0.0167,  0.0023,  ...,  0.0334, -0.0045,  0.0316],
        [-0.0218, -0.0240, -0.0293,  ...,  0.0250, -0.0241, -0.0328]],
       grad_fn=<DivBackward0>), data_fused_repr=tensor([[-2.3747e-02, -1.5425e-02, -2.3315e-02,  ...,  2.7843e-02,
         -9.0777e-03, -2.3724e-02],
        [-1.9169e-02, -1.8406e-02, -2.0055e-02,  ...,  2.4618e-02,
         -1.9030e-02, -5.9703e-05],
        [-1.3714e-02,  3.5124e-03, -2.3359e-02,  ...,  1.6399e-02,
         -2.4352e-02,  3.2127e-03],
        [ 8.0003e-02,  7.5907e-03, -1.1098e-02,  ...,  4.2235e-02,
         -1.7850e-02,  3.7884e-02],
        [-2.4064e-02, -2.0427e-02, -1.9721e-02,  ...,  1.0924e-02,
         -1.3290e-02, -2.3074e-02]], grad_fn

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



## `RAD002`

In [None]:
#| export
class Encoder002(Encoder):
    """`Encoder` variant with its own `forward`.

    The pass encodes the input, pools it into a representation, optionally
    fuses metadata into the token embeddings, and re-activates the
    `'lbl2data'` adapter on `self.distilbert` only when fusion actually ran.
    """

    def __init__(self, config:PretrainedConfig, **kwargs):
        # Construction is delegated unchanged to the parent `Encoder`.
        super().__init__(config, **kwargs)

    def forward(
        self,
        data_input_ids: torch.Tensor,
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):
        """Encode the input and, when metadata is supplied, also return a fused representation.

        `data_type == "meta"` selects the `meta` pooling head (or
        `meta_unnormalized` when `data_unnormalized` is truthy); any other
        value uses `dr`. When `data_aug_meta_prefix` is given and
        `Parameters.from_meta_aug_prefix` yields a non-empty result from
        `kwargs`, the metadata is fused into the embeddings and pooled with
        `dr_fused`.

        Returns an `EncoderOutput` with `rep`, `fused_rep` and `meta_repr`;
        the latter two stay `None` when no fusion happens.
        """
        hidden = self.encode(data_input_ids, data_attention_mask)[0]

        # Choose the pooling head first, then apply it once.
        if data_type == "meta":
            pool = self.meta_unnormalized if data_unnormalized else self.meta
        else:
            pool = self.dr
        rep = pool(hidden, data_attention_mask)

        fused_rep, meta_repr = None, None
        # An empty mapping stands in for "no prefix given" so a single length check covers both cases.
        meta_kwargs = {} if data_aug_meta_prefix is None else Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
        if len(meta_kwargs) > 0:
            fused_embed, meta_repr = self.fuse_meta_into_embeddings(hidden, data_attention_mask, meta_kwargs)
            fused_rep = self.dr_fused(fused_embed, data_attention_mask)

            # Restore the `'lbl2data'` adapter; presumably the fusion step
            # activated a metadata-specific adapter - TODO confirm.
            self.distilbert.set_adapter('lbl2data')

        return EncoderOutput(rep=rep, fused_rep=fused_rep, meta_repr=meta_repr)
        

In [None]:
#| export
class RAD002(RAD001):
    """`RAD001` whose `encoder` attribute is an `Encoder002` instance.

    Everything is initialised by `RAD001.__init__`; afterwards `self.encoder`
    is assigned a freshly built `Encoder002`.
    """

    def __init__(
        self, config,

        base_model:nn.Module, 
        resize_length:Optional[int]=None,
        lora_r:Optional[int]=8,
        lora_alpha:Optional[int]=32,

        data_aug_meta_prefix:Optional[str]=None, 
        lbl2data_aug_meta_prefix:Optional[str]=None, 
        
        **kwargs
    ):
        # Arguments handed to both the parent constructor and the new encoder.
        encoder_kwargs = dict(
            base_model=base_model,
            resize_length=resize_length,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            data_aug_meta_prefix=data_aug_meta_prefix,
            lbl2data_aug_meta_prefix=lbl2data_aug_meta_prefix,
        )

        # Any remaining keyword arguments are forwarded to `RAD001` only.
        super().__init__(config, **encoder_kwargs, **kwargs)

        # Assigned after the parent constructor so it takes precedence over any
        # encoder the parent may have created.
        # NOTE(review): `base_model` is passed to two constructors; if `Encoder`
        # wraps it with LoRA in place it may be wrapped twice -- confirm.
        self.encoder = Encoder002(config, **encoder_kwargs)
        

### Example

In [None]:
from transformers import DistilBertConfig

In [None]:
# Load the pretrained sentence-transformers DistilBERT checkpoint that is passed
# to `RAD002` as `base_model` in the next cell.
base_model = DistilBertModel.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4')

In [None]:
# Build a RAD002 on top of `base_model`. Only the arguments named in
# `RAD002.__init__` are shared with `Encoder002`; every other keyword travels
# through `**kwargs` to the parent class.
model = RAD002(
    DistilBertConfig(),
    # Arguments named explicitly in RAD002.__init__.
    resize_length=5000,
    base_model=base_model,
    lora_r=8,
    lora_alpha=32,
    # Primary loss settings.
    batch_size=100,
    num_batch_labels=5000,
    margin=0.3,
    num_negatives=10,
    tau=0.1,
    apply_softmax=True,
    # Metadata prefixes: only `cat2data` augmentation is enabled here.
    data_aug_meta_prefix='cat2data',
    lbl2data_aug_meta_prefix=None,
    data_pred_meta_prefix=None,
    lbl2data_pred_meta_prefix=None,
    use_query_loss=True,
    # Calibration loss settings (switched off via `use_calib_loss=False`).
    calib_margin=0.3,
    calib_num_negatives=5,
    calib_tau=0.1,
    calib_apply_softmax=True,
    calib_loss_weight=0.1,
    use_calib_loss=False,
    # Meta / fusion loss settings (weights zero, fusion loss switched off).
    meta_loss_weight=0.0,
    fusion_loss_weight=0.0,
    use_fusion_loss=False,
    use_encoder_parallel=False,
)

# Initialise the heads before the first forward pass.
model.init_retrieval_head()
model.init_cross_head()

In [None]:
# Select the batch entries the model consumes, adding the listed metadata
# fields through `m_args`.
b = prepare_batch(
    model,
    batch,
    m_args=[
        # cat2data fields
        'pcat2data_idx',
        'pcat2data_data2ptr',
        'cat2data_idx',
        'cat2data_input_ids',
        'cat2data_attention_mask',
        'cat2data_data2ptr',
        # cat2lbl fields
        'pcat2lbl_idx',
        'pcat2lbl_lbl2data2ptr',
        'pcat2lbl_data2ptr',
        'cat2lbl_idx',
        'cat2lbl_input_ids',
        'cat2lbl_attention_mask',
        'cat2lbl_lbl2data2ptr',
        'cat2lbl_data2ptr',
    ],
)

In [None]:
# Forward pass: move the prepared batch to the model's device and unpack it as
# keyword arguments.
o = model(**b.to(model.device))

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
# Display the loss produced by the forward pass above.
o.loss

tensor(0.0643, grad_fn=<AddBackward0>)