In [1]:
#| default_exp models.oak

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
#| export
import torch, re, inspect, pickle, os, torch.nn as nn, math
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any, Union
from transformers import (
    PretrainedConfig,
    DistilBertForMaskedLM,
    DistilBertModel,
    DistilBertPreTrainedModel,
)
from transformers.utils.generic import ModelOutput
from transformers.activations import get_activation

from fastcore.meta import *
from fastcore.utils import *

from xcai.losses import *
from xcai.core import store_attr
from xcai.learner import XCDataParallel
from xcai.models.modeling_utils import *

comet_ml is installed but `COMET_API_KEY` is not set.


In [4]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [5]:
from transformers import AutoConfig
from xcai.block import *

## Setup

In [6]:
# Expose only GPUs 0 and 1 to this process.
# NOTE(review): this is set after `torch` is imported; it presumably still takes effect because
# CUDA has not been initialised yet at this point — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

In [7]:
# Root of the dataset folder, and the pickle that caches the processed data block.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = '/'.join((pkl_dir, 'processed', 'wikiseealsotitles_data-meta_distilbert-base-uncased_xcs.pkl'))

In [None]:
# Write `block` to the pickle cache.
# `open(..., 'wb')` truncates the file before anything is written, so only open it when a
# `block` actually exists in this kernel; otherwise a fresh "Run All" would wipe the cache
# and then fail with a NameError.
if 'block' in globals():
    with open(pkl_file, 'wb') as file: pickle.dump(block, file)

In [8]:
# Restore the cached data block.
# NOTE: unpickling can execute arbitrary code — only load pickle files you created yourself.
with open(pkl_file, 'rb') as file:
    block = pickle.load(file)

In [9]:
# Fetch a batch through `one_batch(5)`.
# NOTE(review): this value is overwritten by the loop below; the call is presumably kept for a
# side effect such as configuring the batch size — confirm against `one_batch`.
batch = block.train.one_batch(5)
# Walk the training dataloader and stop as soon as `i` exceeds 2, which leaves `batch` bound
# to the item at index 3 (assuming the loader yields at least four batches).
for i,batch in enumerate(block.train.dl):
    if i > 2: break

In [15]:
# List the fields available in the selected batch.
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'cat2data_idx', 'cat2data_identifier', 'cat2data_input_text', 'cat2data_input_ids', 'cat2data_attention_mask', 'cat2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx', 'cat2lbl2data_idx', 'cat2lbl2data_identifier', 'cat2lbl2data_input_text', 'cat2lbl2data_input_ids', 'cat2lbl2data_attention_mask', 'cat2lbl2data_data2ptr', 'cat2lbl2data_lbl2data2ptr'])

## CrossAttention

In [11]:
#| export
class CrossAttention(nn.Module):
    """Multi-head cross-attention from a query sequence to a key/value sequence.

    Keys and values are both derived from the `k` input. The four projections
    (`q`, `k`, `v`, `o`) are `config.dim -> config.dim` linear layers.
    """

    def __init__(self, config: PretrainedConfig):
        """Build the projections from `config.n_heads`, `config.dim` and `config.attention_dropout`.

        Raises:
            ValueError: if `config.dim` is not divisible by `config.n_heads`.
        """
        super().__init__()
        self.config, self.n_h, self.dim = config, config.n_heads, config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        if self.dim % self.n_h != 0:
            raise ValueError(f"self.n_heads: {self.n_h} must divide self.dim: {self.dim} evenly.")

        self.q = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.o = nn.Linear(in_features=config.dim, out_features=config.dim)

    def post_init(self):
        """Reset the weight of every projection to the identity matrix (biases are not touched)."""
        for proj in (self.q, self.k, self.v, self.o):
            weight = proj.weight
            # Create the identity on the parameter's own device: assigning a CPU tensor to
            # `.data` would otherwise pull an already-placed weight off its device.
            weight.data = torch.eye(proj.out_features, proj.in_features, dtype=weight.dtype, device=weight.device)

    def forward(
        self,
        q: torch.Tensor,
        q_m: torch.Tensor,
        k: torch.Tensor,
        k_m: torch.Tensor,
        output_attentions:Optional[bool] = False,
    ):
        """Attend from `q` to `k`.

        Args:
            q: query states, unpacked as (bs, q_len, dim).
            q_m: query mask, reshaped to (bs, 1, q_len, 1); zeros mark ignored queries.
            k: key states, also used as the values; reshaped to (bs, k_len, dim).
            k_m: key mask, reshaped to (bs, 1, 1, k_len); zeros mark ignored keys.
            output_attentions: when truthy, also return the attention weights.

        Returns:
            `(o,)` or `(o, w)`, with `o` of shape (bs, q_len, dim) and `w` of shape
            (bs, n_h, q_len, k_len).
        """
        bs, _, _ = q.size()  # also enforces a 3-D query tensor
        v = k

        h_dim = self.dim//self.n_h

        def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)

        def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)

        q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
        k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
        v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)

        q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)
        sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)

        # Outer product of the two masks: an entry is kept only if both its query and key are kept.
        q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
        mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)

        # Masked scores get the most negative finite value so softmax gives them zero weight.
        # A row that is masked everywhere ends up with uniform weights.
        sc = sc.masked_fill(mask == 0, torch.finfo(sc.dtype).min)  # (bs, n_h, q_len, k_len)

        w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
        w = self.dropout(w)  # (bs, n_h, q_len, k_len)

        o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)

        if output_attentions: return (o, w)
        else: return (o,)
        

### Example

In [None]:
# Build a cross-attention module from the stock DistilBERT configuration.
config = AutoConfig.from_pretrained('distilbert-base-uncased')
fuser = CrossAttention(config)

In [None]:
# Toy sizes: two examples, three query positions and two metadata vectors per example.
bsz, data_seq_len, n_meta = 2, 3, 2
dim, dtype = config.dim, torch.float32

# Random states followed by random 0/1 masks (no seed is set, so values differ between runs).
data = torch.randn(bsz, data_seq_len, dim, dtype=dtype)
meta = torch.randn(bsz, n_meta, dim, dtype=dtype)
data_mask = torch.randint(0, 2, size=(bsz, data_seq_len), dtype=dtype)
meta_mask = torch.randint(0, 2, size=(bsz, n_meta), dtype=dtype)

In [None]:
# `data` supplies the queries and `meta` the keys/values; the result is a tuple whose first
# item is the attended output.
o = fuser(data, data_mask, meta, meta_mask)

In [None]:
# The attended output has the shape (bsz, data_seq_len, dim).
o[0].shape

torch.Size([2, 3, 768])

## NormCrossAttention

In [None]:
#| export
class NormCrossAttention(nn.Module):
    """Multi-head cross-attention whose scores are multiplied by a learnable scalar `tau`.

    Keys and values are both derived from the `k` input. The four projections
    (`q`, `k`, `v`, `o`) are `config.dim -> config.dim` linear layers.
    """

    def __init__(self, config: PretrainedConfig, tau:Optional[float]=0.1, dropout:Optional[float]=0.1):
        """
        Args:
            config: provides `n_heads` and `dim`.
            tau: initial value of the learnable score multiplier.
            dropout: dropout probability applied to the attention weights.

        Raises:
            ValueError: if `config.dim` is not divisible by `config.n_heads`.
        """
        super().__init__()
        self.tau = nn.Parameter(torch.tensor(tau, dtype=torch.float32))

        self.config, self.n_h, self.dim = config, config.n_heads, config.dim
        self.dropout = nn.Dropout(p=dropout)

        if self.dim % self.n_h != 0:
            raise ValueError(f"self.n_heads: {self.n_h} must divide self.dim: {self.dim} evenly.")

        self.q = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.o = nn.Linear(in_features=config.dim, out_features=config.dim)

    def post_init(self):
        """Reset the weight of every projection to the identity matrix (biases and `tau` are not touched)."""
        for proj in (self.q, self.k, self.v, self.o):
            weight = proj.weight
            # Create the identity on the parameter's own device: assigning a CPU tensor to
            # `.data` would otherwise pull an already-placed weight off its device.
            weight.data = torch.eye(proj.out_features, proj.in_features, dtype=weight.dtype, device=weight.device)

    def forward(
        self,
        q: torch.Tensor,
        q_m: torch.Tensor,
        k: torch.Tensor,
        k_m: torch.Tensor,
        output_attentions:Optional[bool] = False,
    ):
        """Attend from `q` to `k`, scaling the scores by `self.tau` before masking.

        Args:
            q: query states, unpacked as (bs, q_len, dim).
            q_m: query mask, reshaped to (bs, 1, q_len, 1); zeros mark ignored queries.
            k: key states, also used as the values; reshaped to (bs, k_len, dim).
            k_m: key mask, reshaped to (bs, 1, 1, k_len); zeros mark ignored keys.
            output_attentions: when truthy, also return the attention weights.

        Returns:
            `(o,)` or `(o, w)`, with `o` of shape (bs, q_len, dim) and `w` of shape
            (bs, n_h, q_len, k_len).
        """
        bs, _, _ = q.size()  # also enforces a 3-D query tensor
        v = k

        h_dim = self.dim//self.n_h

        def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)

        def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)

        q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
        k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
        v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)

        q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)
        sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
        sc = sc * self.tau

        # Outer product of the two masks: an entry is kept only if both its query and key are kept.
        q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
        mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)

        # Masked scores get the most negative finite value so softmax gives them zero weight.
        # A row that is masked everywhere ends up with uniform weights.
        sc = sc.masked_fill(mask == 0, torch.finfo(sc.dtype).min)  # (bs, n_h, q_len, k_len)

        w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
        w = self.dropout(w)  # (bs, n_h, q_len, k_len)

        o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)

        if output_attentions: return (o, w)
        else: return (o,)
        

### Example

In [None]:
# Build the tau-scaled cross-attention module with its default `tau` and `dropout`.
config = AutoConfig.from_pretrained('distilbert-base-uncased')
fuser = NormCrossAttention(config)

In [None]:
# Toy sizes: two examples, three query positions and two metadata vectors per example.
bsz, data_seq_len, n_meta = 2, 3, 2
dim, dtype = config.dim, torch.float32

# Random states followed by random 0/1 masks (no seed is set, so values differ between runs).
data = torch.randn(bsz, data_seq_len, dim, dtype=dtype)
meta = torch.randn(bsz, n_meta, dim, dtype=dtype)
data_mask = torch.randint(0, 2, size=(bsz, data_seq_len), dtype=dtype)
meta_mask = torch.randint(0, 2, size=(bsz, n_meta), dtype=dtype)

In [None]:
# `data` supplies the queries and `meta` the keys/values; the result is a tuple whose first
# item is the attended output.
o = fuser(data, data_mask, meta, meta_mask)

In [None]:
# The attended output has the shape (bsz, data_seq_len, dim).
o[0].shape

torch.Size([2, 3, 768])

## Encoder

In [12]:
#| export
class Encoder(DistilBertPreTrainedModel):
    """DistilBERT encoder that returns a pooled representation and, optionally, a metadata-fused one.

    `dr_head` and `meta_head` project the DistilBERT output before mean pooling,
    `dr_fused_head` projects the fused vector, and `cross_head` attends from each pooled
    vector to rows of the `meta_embeddings` table.
    """

    def __init__(
        self, 
        config:PretrainedConfig, 
        num_metadata:int,
        resize_length:Optional[int]=None,
    ):
        """
        Args:
            config: model configuration; `config.dim` is the width of the metadata table.
            num_metadata: number of rows in the metadata embedding table.
            resize_length: when given, a vector of that many ones is pre-allocated and reused
                by `resize`; when `None`, `resize` allocates on demand.
        """
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        
        self.dr_head = RepresentationHead(config)
        self.dr_fused_head = RepresentationHead(config)
        self.meta_head = RepresentationHead(config)
        self.cross_head = CrossAttention(config)
        self.meta_embeddings = nn.Embedding(num_metadata, config.dim)

        self.ones = torch.ones(resize_length, dtype=torch.long, device=self.device) if resize_length is not None else None
        self.post_init()

    def freeze_meta_embeddings(self):
        """Stop gradient updates to the metadata embedding table."""
        self.meta_embeddings.requires_grad_(False)

    def unfreeze_meta_embeddings(self):
        """Re-enable gradient updates to the metadata embedding table."""
        self.meta_embeddings.requires_grad_(True)

    def set_meta_embeddings(self, embed:torch.Tensor):
        """Replace the metadata embedding weights with `embed` (assigned as-is, no shape check)."""
        self.meta_embeddings.weight.data = embed
        
    def get_position_embeddings(self) -> nn.Embedding:
        """Return the position embeddings of the wrapped DistilBERT model."""
        return self.distilbert.get_position_embeddings()
    
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """Delegate position-embedding resizing to the wrapped DistilBERT model."""
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)
    
    def encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        """Run the wrapped DistilBERT model and return its output unchanged."""
        return self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
    
    def dr(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        """Project with `dr_head`, mean-pool under `attention_mask`, then L2-normalise along dim 1."""
        embed = self.dr_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def dr_fused(self, embed:torch.Tensor):
        """Project an already pooled vector with `dr_fused_head` and L2-normalise along dim 1."""
        embed = self.dr_fused_head(embed)
        return F.normalize(embed, dim=1)

    def meta(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        """Project with `meta_head`, mean-pool under `attention_mask`, then L2-normalise along dim 1."""
        embed = self.meta_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)
    
    def meta_unnormalized(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        """Same as `meta` but without the final normalisation."""
        embed = self.meta_head(embed)
        return Pooling.mean_pooling(embed, attention_mask)

    def resize(self, idx:torch.Tensor, num_inputs:torch.Tensor):
        """Pad variable-length groups of indices to a common length.

        `idx` is the concatenation of one group per example and `num_inputs[i]` is the size of
        group `i`. Every group is extended to the largest size by repeating its last index.

        Args:
            idx: 1-D tensor holding all groups back to back.
            num_inputs: 1-D tensor of group sizes; every entry must be non-zero.

        Returns:
            `(resized_idx, mask)`, both flat with `len(num_inputs) * max(num_inputs)` entries;
            `mask` is 1 for real entries and 0 for padding. If all groups already have the
            same size, `idx` is returned unchanged with an all-ones mask.

        Raises:
            ValueError: if any group size is zero.
        """
        if torch.any(num_inputs == 0): raise ValueError("`num_inputs` should be non-zero positive integer.")
        bsz, total_num_inputs = num_inputs.shape[0], idx.shape[0]
        
        # The cache only exists when `resize_length` was given; move it only in that case.
        if self.ones is not None: self.ones = self.ones.to(idx.device)
        ones = (
            torch.ones(total_num_inputs, dtype=torch.long, device=idx.device) 
            if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
        )

        max_num_inputs = num_inputs.max()
        if (num_inputs == max_num_inputs).all():
            return idx,ones
        
        # Number of times the last element of each group must appear to reach the max size.
        xnum_inputs = max_num_inputs-num_inputs+1

        # Position of the last element of each group inside `idx`.
        inputs_ptr = num_inputs.cumsum(dim=0)-1
        repeat_inputs = ones.scatter(0, inputs_ptr, xnum_inputs)
        
        resized_idx = idx.repeat_interleave(repeat_inputs, dim=0)
        # Zero every copy of a group's last element, then switch the final column back on so
        # that exactly one copy of it counts as real.
        ignore_mask = ones.scatter(0, inputs_ptr, 0).repeat_interleave(repeat_inputs, dim=0).view(bsz, -1)
        ignore_mask[:, -1] = 1; ignore_mask = ignore_mask.flatten()
        
        return resized_idx,ignore_mask

    
    def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
        """Add cross-attended metadata embeddings to each data vector.

        Args:
            data_repr: pooled data vectors, viewed as (-1, 1, config.dim).
            data_mask: one flag per data vector, viewed as (-1, 1).
            meta_kwargs: maps a metadata name to a dict with `idx` (metadata ids of all
                examples back to back) and `data2ptr` (number of ids per example).

        Returns:
            `(fused, meta_repr)`: `fused` has shape (n, config.dim); `meta_repr[name]` holds
            the normalised embeddings of the real (non-padding) ids, or an empty
            (0, config.dim) tensor when no example has that metadata.
        """
        meta_repr = {}
        
        data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
        for m_key, m_args in meta_kwargs.items():
            # Only examples that have at least one metadata id take part in the fusion.
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
            
            if len(idx):
                m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
                m_repr = F.normalize(self.meta_embeddings(m_idx), dim=1)
                
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
                meta_repr[m_key] = m_repr[m_repr_mask]
                
                fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
                data_fused_repr[idx] += fused_repr
                
        # Drop only the singleton sequence axis; a bare `squeeze()` would also drop the batch
        # axis for a single example and break the dim=1 normalisation in `dr_fused`.
        return data_fused_repr.squeeze(1), meta_repr

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):  
        """Encode the inputs and, when requested, fuse metadata into the representation.

        Args:
            data_input_ids, data_attention_mask: tokenised inputs for DistilBERT.
            data_aug_meta_prefix: when not `None`, metadata arguments are extracted from
                `kwargs` with `Parameters.from_meta_aug_prefix` and fused into the output.
            data_type: `"meta"` selects `meta_head`; any other value selects `dr_head`.
            data_unnormalized: with `data_type="meta"`, skip the final normalisation.

        Returns:
            `EncoderOutput` with `rep`, `fused_rep` and `meta_repr`; the last two are `None`
            when no metadata was fused.
        """
        data_o = self.encode(data_input_ids, data_attention_mask)
        
        if data_type is not None and data_type == "meta":
            data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
        else: 
            data_repr = self.dr(data_o[0], data_attention_mask)
        
        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                # An example counts as valid when its attention mask has any non-zero entry.
                data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
                                                                            torch.any(data_attention_mask, dim=1), 
                                                                            meta_kwargs)
                data_fused_repr = self.dr_fused(data_fused_repr)
                
        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,
        )
        

## `OAK000`

In [13]:
#| export
class OAK000(nn.Module):
    """Loss and forward logic shared by the OAK models.

    The class does not build an encoder itself (`self.encoder` starts as `None`), its
    constructor forwards `config` to the next class in the MRO, and `forward` reads
    `self.config`. It is therefore meant to be combined with a config-accepting base class
    by a subclass that also assigns `self.encoder`.
    """
    
    def __init__(
        self, config,

        data_aug_meta_prefix:Optional[str]=None, 
        lbl2data_aug_meta_prefix:Optional[str]=None, 

        data_pred_meta_prefix:Optional[str]=None,
        lbl2data_pred_meta_prefix:Optional[str]=None,
        
        num_batch_labels:Optional[int]=None, 
        batch_size:Optional[int]=None,
        margin:Optional[float]=0.3,
        num_negatives:Optional[int]=5,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=True,

        calib_margin:Optional[float]=0.3,
        calib_num_negatives:Optional[int]=10,
        calib_tau:Optional[float]=0.1,
        calib_apply_softmax:Optional[bool]=False,
        calib_loss_weight:Optional[float]=0.1,
        use_calib_loss:Optional[bool]=False,
        
        meta_loss_weight:Optional[Union[List,float]]=0.3,
        
        use_fusion_loss:Optional[bool]=False,
        fusion_loss_weight:Optional[float]=0.15,

        use_query_loss:Optional[bool]=True,
        
        use_encoder_parallel:Optional[bool]=True,
    ):
        """
        Args:
            config: forwarded to the next `__init__` in the MRO.
            data_aug_meta_prefix, lbl2data_aug_meta_prefix: prefixes selecting the metadata
                that the encoder fuses into the data / label representations.
            data_pred_meta_prefix, lbl2data_pred_meta_prefix: prefixes selecting the metadata
                used by `compute_meta_loss`.
            num_batch_labels, batch_size, margin, num_negatives, tau, apply_softmax:
                passed to the `MultiTriplet` representation loss.
            calib_margin, calib_num_negatives, calib_tau, calib_apply_softmax:
                passed to the `Calibration` loss.
            calib_loss_weight: multiplier applied to the calibration loss.
            use_calib_loss: add the calibration loss in `forward`.
            meta_loss_weight: weight(s) handed to `Parameters.get_meta_loss_weights`.
            use_fusion_loss: add the fusion losses in `forward`.
            fusion_loss_weight: multiplier applied to every fusion loss term.
            use_query_loss: also add the loss computed on the unfused data representation.
            use_encoder_parallel: wrap the encoder in `XCDataParallel` for every call.
        """
        super().__init__(config)
        store_attr('meta_loss_weight,fusion_loss_weight,calib_loss_weight')
        store_attr('data_pred_meta_prefix,lbl2data_pred_meta_prefix')
        store_attr('data_aug_meta_prefix,lbl2data_aug_meta_prefix')
        store_attr('use_fusion_loss,use_query_loss,use_calib_loss,use_encoder_parallel')
        
        self.encoder = None
        self.rep_loss_fn = MultiTriplet(bsz=batch_size, tn_targ=num_batch_labels, margin=margin, n_negatives=num_negatives, 
                                        tau=tau, apply_softmax=apply_softmax, reduce='mean')
        self.cab_loss_fn = Calibration(margin=calib_margin, tau=calib_tau, n_negatives=calib_num_negatives, 
                                       apply_softmax=calib_apply_softmax, reduce='mean')

    def _get_encoder(self):
        """Return the encoder to call, wrapped in `XCDataParallel` when `use_encoder_parallel` is set."""
        if self.use_encoder_parallel:
            return XCDataParallel(module=self.encoder)
        return self.encoder
        
    def init_retrieval_head(self):
        """Run `post_init` on the three representation heads of the encoder.

        Raises:
            ValueError: if `self.encoder` has not been assigned yet.
        """
        if self.encoder is None: raise ValueError('`self.encoder` is not initialized.')
        self.encoder.dr_head.post_init()
        self.encoder.meta_head.post_init()
        self.encoder.dr_fused_head.post_init()

    def init_cross_head(self):
        """Run `post_init` on the cross-attention head of the encoder.

        Raises:
            ValueError: if `self.encoder` has not been assigned yet.
        """
        if self.encoder is None: raise ValueError('`self.encoder` is not initialized.')
        self.encoder.cross_head.post_init()
        

    def compute_loss(self, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        """Representation loss between inputs and targets (thin wrapper around `rep_loss_fn`)."""
        return self.rep_loss_fn(inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def calibration_loss(self, einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        """Calibration loss from `cab_loss_fn`, scaled by `calib_loss_weight`."""
        return self.calib_loss_weight * self.cab_loss_fn(einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)
    
    def compute_meta_loss(self, data_repr, lbl2data_repr, **kwargs):
        """Weighted sum of representation losses against encoded metadata.

        Metadata inputs are extracted from `kwargs` with the two `*_pred_meta_prefix` values.
        Inputs carrying `lbl2data2ptr` are matched against `lbl2data_repr`, inputs carrying
        only `data2ptr` against `data_repr`.

        Returns:
            The accumulated loss, or `0.0` when no metadata input contributed.

        Raises:
            ValueError: if a metadata input has neither pointer key.
        """
        encoder = self._get_encoder()
            
        data_meta_inputs = Parameters.from_meta_pred_prefix(self.data_pred_meta_prefix, **kwargs)
        lbl2data_meta_inputs = Parameters.from_meta_pred_prefix(self.lbl2data_pred_meta_prefix, **kwargs)
        meta_inputs = {**data_meta_inputs, **lbl2data_meta_inputs}

        m_lw = Parameters.get_meta_loss_weights(self.meta_loss_weight, len(meta_inputs)) if len(meta_inputs) else []
        
        loss = 0.0
        for inputs,lw in zip(meta_inputs.values(), m_lw):
            # `lbl2data2ptr` takes precedence when both pointer keys are present.
            if 'lbl2data2ptr' in inputs: ptr_key, pptr_key, inp_repr = 'lbl2data2ptr', 'plbl2data2ptr', lbl2data_repr
            elif 'data2ptr' in inputs: ptr_key, pptr_key, inp_repr = 'data2ptr', 'pdata2ptr', data_repr
            else: raise ValueError('Invalid metadata input arguments.')

            # Rows with a non-zero pointer are the ones that own at least one metadata item.
            idx = torch.where(inputs[ptr_key])[0]
            if len(idx) == 0: continue

            inputs_o = encoder(data_input_ids=inputs['input_ids'], data_attention_mask=inputs['attention_mask'], 
                               data_type="meta")
            m_loss = self.rep_loss_fn(inp_repr[idx], inputs_o.rep, inputs[ptr_key][idx], inputs['idx'], 
                                      inputs[pptr_key][idx], inputs['pidx'])
            loss += lw * m_loss
        return loss

    def compute_fusion_loss(self, data_repr, meta_repr:Dict, prefix:str, **kwargs):
        """Sum of representation losses between `data_repr` and the fused metadata embeddings.

        Args:
            data_repr: representations indexed by the rows that own metadata.
            meta_repr: maps a metadata name to its embeddings, or `None` when nothing was fused.
            prefix: prefix used to extract the matching pointer/index tensors from `kwargs`.

        Returns:
            The accumulated loss (each term scaled by `fusion_loss_weight`), or `0.0`.

        Raises:
            ValueError: if a metadata input has neither pointer key.
        """
        # Nothing was fused, so there is nothing to look up or score.
        if meta_repr is None: return 0.0

        meta_inputs = Parameters.from_meta_pred_prefix(prefix, **kwargs)
        
        loss = 0.0
        for key,input_repr in meta_repr.items():
            inputs = meta_inputs[key]
            # `lbl2data2ptr` takes precedence when both pointer keys are present.
            if 'lbl2data2ptr' in inputs: ptr_key, pptr_key = 'lbl2data2ptr', 'plbl2data2ptr'
            elif 'data2ptr' in inputs: ptr_key, pptr_key = 'data2ptr', 'pdata2ptr'
            else: raise ValueError('Invalid metadata input arguments.')

            idx = torch.where(inputs[ptr_key])[0]
            if len(idx) == 0: continue

            m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs[ptr_key][idx], inputs['idx'], 
                                      inputs[pptr_key][idx], inputs['pidx'])
            loss += self.fusion_loss_weight * m_loss
        return loss


    def get_meta_representation(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Encode inputs with the metadata head, without normalisation.

        Returns:
            `XCModelOutput` with `data_repr` and `data_fused_repr` taken from the encoder output.
        """
        encoder = self._get_encoder()
            
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_unnormalized=True, data_type="meta")
        return XCModelOutput(
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
        )

    
        
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        """Encode data (and labels, when given) and compute the training loss.

        The loss is only computed when `lbl2data_input_ids` is provided. It is the
        representation loss of the fused data representation against the label
        representation, plus the optional query, calibration, metadata and fusion terms.

        Returns:
            `XCModelOutput` with the loss and the four representations, or a tuple when
            `return_dict` resolves to a falsy value.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        encoder = self._get_encoder()
        
        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            
            # NOTE(review): `data_o.fused_rep` is None when the encoder fused no metadata;
            # this call presumably relies on fusion always happening — confirm.
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
            
            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
            
            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
            
            
        if not return_dict:
            # NOTE(review): assumes `EncoderOutput` exposes a `logits` attribute — confirm.
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return XCModelOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        

## `OAK001`

In [14]:
#| export
class OAK001(OAK000, DistilBertPreTrainedModel):
    """`OAK000` training logic on top of a DistilBERT-based `Encoder`."""
    use_generation = False
    use_representation = True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(self, config, num_metadata:int, resize_length:Optional[int]=None, **kwargs):
        """Build the encoder, then initialise the retrieval and cross-attention heads.

        `num_metadata` and `resize_length` go to `Encoder`; remaining keyword arguments go
        to `OAK000.__init__`.
        """
        super().__init__(config, **kwargs)
        self.encoder = Encoder(config, num_metadata=num_metadata, resize_length=resize_length)
        self.post_init()
        self.remap_post_init()
        self.init_retrieval_head()
        self.init_cross_head()

    def remap_post_init(self):
        """Expose the encoder's DistilBERT module as `self.distilbert`."""
        self.distilbert = self.encoder.distilbert
        

### Example

In [None]:
# Load the pretrained DistilBERT weights into an OAK001 model and configure its losses.
# NOTE(review): the augmentation prefix is 'lnk2data' while `num_metadata` is read from the
# 'cat_meta' entry of the dataset — confirm both refer to the same metadata.
model = OAK001.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=True, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=True,
                               
                               use_encoder_parallel=False)
# NOTE(review): the heads are initialised again after loading, presumably because
# `from_pretrained` re-initialises weights that are missing from the checkpoint — confirm.
model.init_retrieval_head()
model.init_cross_head()

In [None]:
# Zero the metadata table. The shape is taken from the table itself instead of hard-coded
# numbers, so the embedding keeps the size it was built with whatever dataset is loaded.
model.encoder.set_meta_embeddings(torch.zeros_like(model.encoder.meta_embeddings.weight.data))

In [None]:
# Move the model to the GPU.
model = model.to('cuda')

In [None]:
# Build the model inputs from `batch`.
# NOTE(review): `m_args` presumably lists the metadata fields to pass through — confirm
# against `prepare_batch`.
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
# Forward pass with the batch moved to the model's device.
o = model(**b.to(model.device))

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
def func():
    """Run one forward pass of `model` on the prepared batch `b`, moved to the model's device."""
    # The leftover `pdb.set_trace()` was removed: it halted every run, including "Run All".
    # Use `%debug` after a failure if interactive inspection is needed.
    return model(**b.to(model.device))
    

In [None]:
# Run the forward pass through the `func` wrapper.
o = func()

In [None]:
# Total training loss returned by the model.
o.loss

tensor(0.0670, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK002`

In [None]:
#| export
class Encoder002(Encoder):
    """`Encoder` variant whose cross-attention head is a `NormCrossAttention`."""

    def __init__(self, config, cross_tau:Optional[float]=0.1, cross_dropout:Optional[float]=0.1, **kwargs):
        """
        Args:
            config: model configuration, also forwarded to `Encoder`.
            cross_tau: initial `tau` of the `NormCrossAttention` head.
            cross_dropout: dropout probability of the `NormCrossAttention` head.
            **kwargs: remaining `Encoder` arguments.
        """
        super().__init__(config, **kwargs)
        # Swap out the head created by the parent constructor.
        norm_head = NormCrossAttention(config, tau=cross_tau, dropout=cross_dropout)
        self.cross_head = norm_head
        self.post_init()
    

In [None]:
#| export
class OAK002(OAK000, DistilBertPreTrainedModel):
    """`OAK000` training logic on top of `Encoder002` (tau-scaled cross-attention)."""
    use_generation = False
    use_representation = True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(
        self,
        config,
        num_metadata:int,
        resize_length:Optional[int]=None,
        cross_tau:Optional[float]=0.1,
        cross_dropout:Optional[float]=0.1,
        **kwargs
    ):
        """Build the encoder, then initialise the retrieval and cross-attention heads.

        `num_metadata`, `resize_length`, `cross_tau` and `cross_dropout` go to `Encoder002`;
        remaining keyword arguments go to `OAK000.__init__`.
        """
        super().__init__(config, **kwargs)
        encoder_args = dict(cross_tau=cross_tau, cross_dropout=cross_dropout,
                            num_metadata=num_metadata, resize_length=resize_length)
        self.encoder = Encoder002(config, **encoder_args)
        self.post_init()
        self.remap_post_init()
        self.init_retrieval_head()
        self.init_cross_head()

    def remap_post_init(self):
        """Expose the encoder's DistilBERT module as `self.distilbert`."""
        self.distilbert = self.encoder.distilbert
        

### Example

In [None]:
# Load the pretrained DistilBERT weights into an OAK002 model; `cross_tau` and
# `cross_dropout` configure its `NormCrossAttention` head.
# NOTE(review): the augmentation prefix is 'lnk2data' while `num_metadata` is read from the
# 'cat_meta' entry of the dataset — confirm both refer to the same metadata.
model = OAK002.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=True, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               cross_tau=1.0, cross_dropout=0.1,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=True,
                               
                               use_encoder_parallel=False)
# NOTE(review): the heads are initialised again after loading, presumably because
# `from_pretrained` re-initialises weights that are missing from the checkpoint — confirm.
model.init_retrieval_head()
model.init_cross_head()

Some weights of OAK002 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.tau', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_

In [None]:
# Zero the metadata table. The shape is taken from the table itself instead of hard-coded
# numbers, so the embedding keeps the size it was built with whatever dataset is loaded.
model.encoder.set_meta_embeddings(torch.zeros_like(model.encoder.meta_embeddings.weight.data))

In [None]:
# Move the model to the GPU.
model = model.to('cuda')

In [None]:
# Build the model inputs from `batch`.
# NOTE(review): `m_args` presumably lists the metadata fields to pass through — confirm
# against `prepare_batch`.
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
# Forward pass with the batch moved to the model's device.
o = model(**b.to(model.device))

> /tmp/ipykernel_36176/2795290057.py(33)forward()
     31         output_attentions:Optional[bool] = False,
     32     ):
---> 33         bs, q_len, dim = q.size()
     34         v, k_len = k, k.size(1)
     35 



ipdb>  c


> /tmp/ipykernel_36176/2795290057.py(34)forward()
     32     ):
     33         bs, q_len, dim = q.size()
---> 34         v, k_len = k, k.size(1)
     35 
     36         h_dim = self.dim//self.n_h



ipdb>  c


In [None]:
# Total training loss returned by the model.
o.loss

tensor(0.0638, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK003`

In [17]:
#| export
class Encoder003(Encoder):
    """`Encoder` variant that keeps a second metadata table, `pretrained_meta_embeddings`.

    The metadata vector used for fusion is the normalised sum of the matching rows of
    `meta_embeddings` and `pretrained_meta_embeddings`.
    """

    def __init__(
        self, 
        config,
        num_metadata:int,
        **kwargs
    ):
        """
        Args:
            config: model configuration; `config.dim` is the width of both metadata tables.
            num_metadata: number of rows in both metadata tables.
            **kwargs: remaining `Encoder` arguments.
        """
        super().__init__(config, num_metadata=num_metadata, **kwargs)
        self.pretrained_meta_embeddings = nn.Embedding(num_metadata, config.dim)
        self.post_init()

    def freeze_pretrained_meta_embeddings(self):
        """Stop gradient updates to the pretrained metadata table."""
        self.pretrained_meta_embeddings.requires_grad_(False)

    def unfreeze_pretrained_meta_embeddings(self):
        """Re-enable gradient updates to the pretrained metadata table."""
        self.pretrained_meta_embeddings.requires_grad_(True)

    def set_pretrained_meta_embeddings(self, embed:torch.Tensor):
        """Replace the pretrained metadata weights with `embed` (assigned as-is, no shape check)."""
        self.pretrained_meta_embeddings.weight.data = embed

    def init_meta_embeddings(self):
        """Reset the `meta_embeddings` table to all zeros."""
        self.meta_embeddings.weight.data = torch.zeros_like(self.meta_embeddings.weight.data)

    def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
        """Add cross-attended metadata embeddings to each data vector.

        Args:
            data_repr: pooled data vectors, viewed as (-1, 1, config.dim).
            data_mask: one flag per data vector, viewed as (-1, 1).
            meta_kwargs: maps a metadata name to a dict with `idx` (metadata ids of all
                examples back to back) and `data2ptr` (number of ids per example).

        Returns:
            `(fused, meta_repr)`: `fused` has shape (n, config.dim); `meta_repr[name]` holds
            the normalised embeddings of the real (non-padding) ids, or an empty
            (0, config.dim) tensor when no example has that metadata.
        """
        meta_repr = {}
        
        data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
        for m_key, m_args in meta_kwargs.items():
            # Only examples that have at least one metadata id take part in the fusion.
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
            
            if len(idx):
                m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
                # Sum the two tables before normalising.
                m_repr = F.normalize(self.meta_embeddings(m_idx) + self.pretrained_meta_embeddings(m_idx), dim=1)
                
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
                meta_repr[m_key] = m_repr[m_repr_mask]
                
                fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
                data_fused_repr[idx] += fused_repr
                
        # Drop only the singleton sequence axis; a bare `squeeze()` would also drop the batch
        # axis for a single example and break the later dim=1 normalisation.
        return data_fused_repr.squeeze(1), meta_repr
    

In [18]:
#| export
class OAK003(OAK000, DistilBertPreTrainedModel):
    """`OAK000` model whose encoder is an `Encoder003`.

    The encoder combines a learnable metadata embedding table with a second,
    externally provided ("pretrained") table.
    """

    use_generation = False
    use_representation = True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(
        self,
        config,
        num_metadata:int,
        resize_length:Optional[int]=None,
        **kwargs
    ):
        """Build the base model, then attach an `Encoder003` and initialise its heads."""
        super().__init__(config, **kwargs)
        self.encoder = Encoder003(
            config,
            num_metadata=num_metadata,
            resize_length=resize_length,
        )
        self.post_init()
        self.remap_post_init()
        self.init_retrieval_head()
        self.init_cross_head()

    def init_meta_embeddings(self):
        """Delegate metadata-embedding initialisation to the encoder."""
        self.encoder.init_meta_embeddings()

    def remap_post_init(self):
        """Expose the encoder's DistilBERT backbone as `self.distilbert`."""
        self.distilbert = self.encoder.distilbert
        

### Example

In [None]:
# Build OAK003 on top of the msmarco DistilBERT checkpoint. Only the query
# side is augmented (`data_aug_meta_prefix='cat2data'`); labels get no metadata.
model = OAK003.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.3, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=True,
                               
                               use_encoder_parallel=False)

# The retrieval/cross heads and metadata embeddings are not in the checkpoint
# (they are reported as newly initialized), so initialise them explicitly.
model.init_retrieval_head()
model.init_cross_head()
# Zeroes the learnable metadata embedding table.
model.init_meta_embeddings()

Some weights of OAK003 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [None]:
# Placeholder: load an all-zero pretrained metadata table and freeze it.
# NOTE(review): 656086 is hard-coded; presumably it should equal the `num_metadata` used above - confirm.
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

In [None]:
# Move the model to the GPU.
model = model.to('cuda')

In [None]:
# Prepare a batch carrying the `lnk2data` metadata fields.
# NOTE(review): the model above was built with `data_aug_meta_prefix='cat2data'`, and `b` is
# overwritten by the `cat2data` variant in the next cell - this cell looks like a leftover.
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
# Prepare a batch carrying the `cat2data` metadata fields, matching the model's `data_aug_meta_prefix`.
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [None]:
# Forward pass on the prepared batch, moved to the model's device first.
o = model(**b.to(model.device))

> /tmp/ipykernel_26893/4002658968.py(65)resize()
     63         #debug
     64 
---> 65         if torch.any(num_inputs == 0): raise ValueError("`num_inputs` should be non-zero positive integer.")
     66         bsz, total_num_inputs = num_inputs.shape[0], idx.shape[0]
     67 



ipdb>  num_inputs


tensor([3, 3, 3, 3, 3])


ipdb>  n


> /tmp/ipykernel_26893/4002658968.py(66)resize()
     64 
     65         if torch.any(num_inputs == 0): raise ValueError("`num_inputs` should be non-zero positive integer.")
---> 66         bsz, total_num_inputs = num_inputs.shape[0], idx.shape[0]
     67 
     68         self.ones = self.ones.to(idx.device)



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(68)resize()
     66         bsz, total_num_inputs = num_inputs.shape[0], idx.shape[0]
     67 
---> 68         self.ones = self.ones.to(idx.device)
     69         ones = (
     70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(71)resize()
     69         ones = (
     70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
---> 71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
     72         )
     73 



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(70)resize()
     68         self.ones = self.ones.to(idx.device)
     69         ones = (
---> 70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
     71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
     72         )



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(71)resize()
     69         ones = (
     70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
---> 71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
     72         )
     73 



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(70)resize()
     68         self.ones = self.ones.to(idx.device)
     69         ones = (
---> 70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
     71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
     72         )



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(71)resize()
     69         ones = (
     70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
---> 71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
     72         )
     73 



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(69)resize()
     67 
     68         self.ones = self.ones.to(idx.device)
---> 69         ones = (
     70             torch.ones(total_num_inputs, dtype=torch.long, device=idx.device)
     71             if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(74)resize()
     72         )
     73 
---> 74         max_num_inputs = num_inputs.max()
     75         if (num_inputs == max_num_inputs).all():
     76             return idx,ones



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(75)resize()
     73 
     74         max_num_inputs = num_inputs.max()
---> 75         if (num_inputs == max_num_inputs).all():
     76             return idx,ones
     77 



ipdb>  


> /tmp/ipykernel_26893/4002658968.py(76)resize()
     74         max_num_inputs = num_inputs.max()
     75         if (num_inputs == max_num_inputs).all():
---> 76             return idx,ones
     77 
     78         xnum_inputs = max_num_inputs-num_inputs+1



ipdb>  idx.shape


torch.Size([15])


ipdb>  ones.shape


torch.Size([15])


ipdb>  n


--Return--
(tensor([16184...4875,  84879]), tensor([1, 1,..., 1, 1, 1, 1]))
> /tmp/ipykernel_26893/4002658968.py(76)resize()
     74         max_num_inputs = num_inputs.max()
     75         if (num_inputs == max_num_inputs).all():
---> 76             return idx,ones
     77 
     78         xnum_inputs = max_num_inputs-num_inputs+1



ipdb>  n


> /tmp/ipykernel_26893/3123686291.py(36)fuse_meta_into_embeddings()
     34             if len(idx):
     35                 m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
---> 36                 m_repr = F.normalize(self.meta_embeddings(m_idx) + self.pretrained_meta_embeddings(m_idx), dim=1)
     37 
     38                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)



ipdb>  m_idx.shape


torch.Size([15])


ipdb>  m_repr_mask.shape


torch.Size([15])


ipdb>  n


> /tmp/ipykernel_26893/3123686291.py(38)fuse_meta_into_embeddings()
     36                 m_repr = F.normalize(self.meta_embeddings(m_idx) + self.pretrained_meta_embeddings(m_idx), dim=1)
     37 
---> 38                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
     39                 meta_repr[m_key] = m_repr[m_repr_mask]
     40 



ipdb>  c


  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
# Show the loss returned by the forward pass.
o.loss

tensor(0.0675, grad_fn=<AddBackward0>)

## `OAK004`

In [None]:
#| export
class Encoder004(Encoder003):
    """`Encoder003` whose cross-attention head is a `NormCrossAttention`."""

    def __init__(
        self,
        config,
        cross_tau:Optional[float]=0.1,
        cross_dropout:Optional[float]=0.1,
        **kwargs
    ):
        """Build the parent encoder, then swap in a `NormCrossAttention` head.

        `cross_tau` and `cross_dropout` are forwarded to the head as `tau` and `dropout`.
        """
        super().__init__(config, **kwargs)
        norm_cross_head = NormCrossAttention(config, tau=cross_tau, dropout=cross_dropout)
        self.cross_head = norm_cross_head
        self.post_init()
    

In [None]:
#| export
class OAK004(OAK003, DistilBertPreTrainedModel):
    """`OAK003` variant whose encoder is an `Encoder004` (normalised cross-attention head)."""

    use_generation = False
    use_representation = True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK003.__init__)
    def __init__(
        self,
        config,
        num_metadata:int,
        resize_length:Optional[int]=None,

        cross_tau:Optional[float]=1.0,
        cross_dropout:Optional[float]=0.1,

        **kwargs
    ):
        """Build the `OAK003` base, then replace its encoder with an `Encoder004`."""
        super().__init__(config, num_metadata=num_metadata, resize_length=resize_length, **kwargs)
        # The encoder created by the parent constructor is discarded in favour of this one.
        self.encoder = Encoder004(
            config,
            cross_tau=cross_tau,
            cross_dropout=cross_dropout,
            num_metadata=num_metadata,
            resize_length=resize_length,
        )
        self.post_init()
        self.remap_post_init()
        self.init_retrieval_head()
        self.init_cross_head()

    def remap_post_init(self):
        """Expose the encoder's DistilBERT backbone as `self.distilbert`."""
        self.distilbert = self.encoder.distilbert

### Example

In [None]:
# Build OAK004 on top of the msmarco DistilBERT checkpoint. The query side is
# augmented with `lnk2data` metadata; the meta, calibration and fusion losses are disabled.
model = OAK004.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               cross_tau=1.0, cross_dropout=0.1,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)
# The retrieval and cross heads are not in the checkpoint; initialise them explicitly.
model.init_retrieval_head()
model.init_cross_head()

Some weights of OAK004 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.tau', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_

In [None]:
# Placeholder: load an all-zero pretrained metadata table and freeze it.
# NOTE(review): 656086 is hard-coded; presumably it should equal the `num_metadata` used above - confirm.
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

# Zero the learnable metadata embedding table.
model.init_meta_embeddings()

In [None]:
# Move the model to the GPU.
model = model.to('cuda')

In [None]:
# Prepare a batch carrying the `lnk2data` metadata fields, matching the model's `data_aug_meta_prefix`.
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
# Forward pass on the prepared batch, moved to the model's device first.
o = model(**b.to(model.device))

In [None]:
def func():
    """Debug harness: drop into pdb, then run the model on the notebook-level batch `b`."""
    import pdb
    pdb.set_trace()
    return model(**b.to(model.device))
    

In [None]:
# Show the loss returned by the forward pass.
o.loss

tensor(0.0618, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK005`

In [None]:
#| export
class Encoder005(Encoder003):
    """`Encoder003` with an additional token-level generation head.

    Besides the pooled (and optionally metadata-fused) representations, the
    forward pass also returns logits produced by `gen_head` from the backbone's
    hidden states.
    """

    def __init__(
        self,
        config,
        **kwargs
    ):
        """Build the parent encoder and attach a `GenerationHead`."""
        super().__init__(config, **kwargs)
        self.gen_head = GenerationHead(config)

    def get_output_embeddings(self) -> nn.Module:
        """Return the generation head's projector."""
        return self.gen_head.projector

    def set_output_embeddings(self, new_embeddings: nn.Module):
        """Replace the generation head's projector with `new_embeddings`."""
        self.gen_head.projector = new_embeddings

    def gen(self, x:torch.Tensor):
        """Apply the generation head to `x`."""
        return self.gen_head(x)

    def forward(
        self,
        data_input_ids: torch.Tensor,
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,

        data_gen_idx:Optional[torch.Tensor]=None,

        **kwargs
    ):
        """Encode the inputs and return representations plus generation logits.

        Args:
            data_input_ids: token ids passed to `self.encode`.
            data_attention_mask: attention mask passed to `self.encode` and the pooling heads.
            data_aug_meta_prefix: when given, metadata found under this prefix in
                `kwargs` is fused into the representation.
            data_type: `"meta"` selects the metadata pooling head; anything else
                selects the dense-retrieval head.
            data_unnormalized: with `data_type == "meta"`, use the unnormalized
                metadata head instead of the normalized one.
            data_gen_idx: optional index/mask applied to the first axis of the
                hidden states before the generation head.

        Returns:
            `EncoderOutput` with `rep`, `fused_rep`, `meta_repr` and `logits`;
            `fused_rep` and `meta_repr` are `None` when no metadata was fused.
        """
        data_o = self.encode(data_input_ids, data_attention_mask)
        hidden_states = data_o[0]

        # Choose the pooling head for the representation.
        if data_type == "meta":
            meta_pool = self.meta_unnormalized if data_unnormalized else self.meta
            data_repr = meta_pool(hidden_states, data_attention_mask)
        else:
            data_repr = self.dr(hidden_states, data_attention_mask)

        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                # A row counts as valid when it has at least one attended token.
                row_mask = torch.any(data_attention_mask, dim=1)
                data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, row_mask, meta_kwargs)
                data_fused_repr = self.dr_fused(data_fused_repr)

        gen_input = hidden_states if data_gen_idx is None else hidden_states[data_gen_idx]
        data_logits = self.gen(gen_input)

        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,

            logits=data_logits,
        )
    

In [None]:
#| export
class OAK005(OAK003, DistilBertPreTrainedModel):
    """`OAK003` variant that adds a generation loss on top of the retrieval losses.

    The encoder is an `Encoder005`, which returns token-level logits alongside
    the representations. When `use_gen_loss` is set, a `MultiCrossEntropy`
    term weighted by `gen_loss_weight` is added to the training loss.
    """

    use_generation = True
    use_representation = True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK003.__init__)
    def __init__(
        self,
        config,
        num_metadata:int,
        resize_length:Optional[int]=None,

        num_batch_labels:Optional[int]=None,
        ignore_token:Optional[int]=0,
        gen_loss_weight:Optional[float]=1.0,
        use_gen_loss:Optional[bool]=True,

        **kwargs
    ):
        """Build the `OAK003` base, swap in an `Encoder005` and set up the generation loss."""
        super().__init__(config, **kwargs, num_metadata=num_metadata, resize_length=resize_length, num_batch_labels=num_batch_labels)
        store_attr('use_gen_loss')
        self.g_lw = gen_loss_weight
        self.encoder = Encoder005(config, num_metadata=num_metadata, resize_length=resize_length)

        self.gen_loss_fn = MultiCrossEntropy(tn_targ=num_batch_labels, ig_tok=ignore_token, reduce='mean')
        self.post_init()
        self.remap_post_init()
        self.init_retrieval_head()
        self.init_cross_head()

    def remap_post_init(self):
        """Expose the encoder's DistilBERT backbone as `self.distilbert`."""
        self.distilbert = self.encoder.distilbert

    def compute_gen_loss(self, inp_logits, targ_logits, inp_input_ids, targ_input_ids, targ_ptr):
        """Return the weighted sum of the two generation loss terms.

        One term scores `inp_logits` against `targ_input_ids` (with `targ_ptr`),
        the other scores `targ_logits` against `inp_input_ids`.
        """
        inp_to_targ = self.gen_loss_fn(inp_logits, targ_input_ids, targ_ptr)
        targ_to_inp = self.gen_loss_fn(targ_logits, inp_input_ids)
        return self.g_lw * (inp_to_targ + targ_to_inp)

    def init_generation_head(self):
        """Copy the input embedding matrix into the generation head's projector weights."""
        embedding_weight = self.get_input_embeddings().weight.data
        self.encoder.gen_head.projector.weight.data = embedding_weight.clone()

    def get_last_item_mask(self, num_input:torch.Tensor, input_sz:int):
        """Return a boolean mask of length `input_sz` marking the last item of each non-empty group.

        `num_input` holds the group sizes; groups of size zero are skipped.
        """
        nonempty = torch.where(num_input > 0)[0]
        # Cumulative sizes minus one give the flat position of each group's final item.
        last_positions = num_input[nonempty].cumsum(dim=0) - 1
        mask = torch.zeros(input_sz, dtype=torch.bool, device=num_input.device)
        return mask.scatter(0, last_positions, 1)

    def freeze(self):
        """Disable gradients for every parameter of the model."""
        for param in self.parameters():
            param.requires_grad_(False)

    def unfreeze(self):
        """Enable gradients for every parameter of the model."""
        for param in self.parameters():
            param.requires_grad_(True)

    def unfreeze_head(self):
        """Enable gradients for the generation head only."""
        for param in self.encoder.gen_head.parameters():
            param.requires_grad_(True)

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Encode the data (and, when given, the labels) and compute the training loss.

        The loss is only computed when `lbl2data_input_ids` is provided. It
        starts from the retrieval loss on the fused data representation and
        then adds, in this order: the generation loss (`use_gen_loss`), the
        query loss (`use_query_loss`), the calibration loss (`use_calib_loss`),
        the meta loss (always) and the fusion losses (`use_fusion_loss`).

        Returns:
            `XCModelOutput` when `return_dict` resolves to true, otherwise a
            tuple of logits/representations, prefixed by the loss if one was computed.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder = XCDataParallel(module=self.encoder) if self.use_encoder_parallel else self.encoder

        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(
            data_input_ids=data_input_ids,
            data_attention_mask=data_attention_mask,
            data_aug_meta_prefix=self.data_aug_meta_prefix,
            **data_meta_kwargs,
        )

        loss = None
        lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            # Generation logits on the label side are computed only for the last label of each data point.
            lbl2data_gen_idx = self.get_last_item_mask(lbl2data_data2ptr, len(lbl2data_idx))
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)

            lbl2data_o = encoder(
                data_input_ids=lbl2data_input_ids,
                data_attention_mask=lbl2data_attention_mask,
                data_aug_meta_prefix=self.lbl2data_aug_meta_prefix,
                data_gen_idx=lbl2data_gen_idx,
                **lbl2data_meta_kwargs,
            )

            # Label bookkeeping shared by every representation-level loss below.
            label_args = (lbl2data_data2ptr, lbl2data_idx, plbl2data_data2ptr, plbl2data_idx)

            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, *label_args)

            if self.use_gen_loss:
                loss += self.compute_gen_loss(
                    data_o.logits,
                    lbl2data_o.logits,
                    data_input_ids,
                    lbl2data_input_ids,
                    lbl2data_data2ptr,
                )

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, *label_args)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep, *label_args)

            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)

            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)

        if not return_dict:
            outputs = (
                data_o.logits,
                data_o.rep,
                data_o.fused_rep,
                lbl2data_o.logits,
                lbl2data_o.rep,
                lbl2data_o.fused_rep,
            )
            if loss is not None:
                return (loss,) + outputs
            return outputs

        return XCModelOutput(
            loss=loss,

            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,

            logits=data_o.logits,

            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
    

### Example

In [None]:
# Build OAK005 on top of the msmarco DistilBERT checkpoint with the generation
# loss enabled; the query side is augmented with `lnk2data` metadata.
model = OAK005.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,
                               
                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,

                               ignore_token=0, gen_loss_weight=1.0, use_gen_loss=True,
                               
                               use_encoder_parallel=False)

# Copy the input embedding matrix into the generation head's projector.
model.init_generation_head()

Some weights of OAK005 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.gen_head.layer_norm.bias', 'encoder.gen_head.layer_norm.weight', 'encoder.gen_head.projector.bias', 'encoder.gen_head.projector.weight', 'enc

In [None]:
# Freeze every parameter, then re-enable gradients for the generation head only.
model.freeze()
model.unfreeze_head()

In [None]:
# Initialise the heads that are not part of the checkpoint and zero the learnable metadata table.
model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()

# Placeholder: load an all-zero pretrained metadata table and freeze it.
# NOTE(review): 656086 is hard-coded; presumably it should equal the `num_metadata` used above - confirm.
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

In [None]:
# Move the model to the GPU.
model = model.to('cuda')

In [None]:
# Prepare a batch carrying the `lnk2data` metadata fields, matching the model's `data_aug_meta_prefix`.
b = prepare_batch(model, batch, m_args=[
    'plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 
    'lnk2data_data2ptr',
])

In [None]:
# Forward pass on the prepared batch, moved to the model's device first.
o = model(**b.to(model.device))

> /tmp/ipykernel_3071/2330200620.py(58)forward()
     56         **kwargs
     57     ):  
---> 58         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     59 
     60         if self.use_encoder_parallel:



ipdb>  c


> /tmp/ipykernel_3071/1501785349.py(33)forward()
     31         **kwargs
     32     ):  
---> 33         data_o = self.encode(data_input_ids, data_attention_mask)
     34 
     35         if data_type is not None and data_type == "meta":



ipdb>  c


> /tmp/ipykernel_3071/1501785349.py(33)forward()
     31         **kwargs
     32     ):  
---> 33         data_o = self.encode(data_input_ids, data_attention_mask)
     34 
     35         if data_type is not None and data_type == "meta":



ipdb>  c


In [None]:
def func():
    """Debug harness: drop into pdb, then run the model on the notebook-level batch `b`."""
    import pdb
    pdb.set_trace()
    return model(**b.to(model.device))
    

In [None]:
# Run the forward pass under the interactive debugger (requires manual input; not re-runnable unattended).
o = func()

> /tmp/ipykernel_3071/3657616883.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))
      4 



ipdb>  b model.forward


Breakpoint 1 at /tmp/ipykernel_3071/2330200620.py:43


ipdb>  b model.encoder.forward


Breakpoint 2 at /tmp/ipykernel_3071/1501785349.py:21


ipdb>  c


> /tmp/ipykernel_3071/2330200620.py(58)forward()
     56         **kwargs
     57     ):  
---> 58         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     59 
     60         if self.use_encoder_parallel:



ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(60)forward()
     58         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     59 
---> 60         if self.use_encoder_parallel:
     61             encoder = XCDataParallel(module=self.encoder)
     62         else: encoder = self.encoder



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(62)forward()
     60         if self.use_encoder_parallel:
     61             encoder = XCDataParallel(module=self.encoder)
---> 62         else: encoder = self.encoder
     63 
     64         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(64)forward()
     62         else: encoder = self.encoder
     63 
---> 64         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     65         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     66                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(65)forward()
     63 
     64         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 65         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     66                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     67 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(66)forward()
     64         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     65         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 66                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     67 
     68 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(65)forward()
     63 
     64         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 65         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     66                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     67 



ipdb>  s


> /tmp/ipykernel_3071/2330200620.py(66)forward()
     64         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     65         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 66                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     67 
     68 



ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(65)forward()
     63 
     64         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 65         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     66                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     67 



ipdb>  s


--Call--
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1507)_wrapped_call_impl()
   1505         return result
   1506 
-> 1507     def _wrapped_call_impl(self, *args, **kwargs):
   1508         if self._compiled_call_impl is not None:
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]



ipdb>  c


> /tmp/ipykernel_3071/1501785349.py(33)forward()
     31         **kwargs
     32     ):  
---> 33         data_o = self.encode(data_input_ids, data_attention_mask)
     34 
     35         if data_type is not None and data_type == "meta":



ipdb>  n


> /tmp/ipykernel_3071/1501785349.py(35)forward()
     33         data_o = self.encode(data_input_ids, data_attention_mask)
     34 
---> 35         if data_type is not None and data_type == "meta":
     36             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
     37         else:



ipdb>  data_o.shape


*** AttributeError: 'BaseModelOutput' object has no attribute 'shape'


ipdb>  data_o[0].shape


torch.Size([5, 5, 768])


ipdb>  n


> /tmp/ipykernel_3071/1501785349.py(38)forward()
     36             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
     37         else:
---> 38             data_repr = self.dr(data_o[0], data_attention_mask)
     39 
     40         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(40)forward()
     38             data_repr = self.dr(data_o[0], data_attention_mask)
     39 
---> 40         data_fused_repr = meta_repr = None
     41         if data_aug_meta_prefix is not None:
     42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(41)forward()
     39 
     40         data_fused_repr = meta_repr = None
---> 41         if data_aug_meta_prefix is not None:
     42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     43             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(42)forward()
     40         data_fused_repr = meta_repr = None
     41         if data_aug_meta_prefix is not None:
---> 42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     43             if len(meta_kwargs):
     44                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(43)forward()
     41         if data_aug_meta_prefix is not None:
     42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
---> 43             if len(meta_kwargs):
     44                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
     45                                                                             torch.any(data_attention_mask, dim=1),



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(44)forward()
     42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     43             if len(meta_kwargs):
---> 44                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
     45                                                                             torch.any(data_attention_mask, dim=1),
     46                                                                             meta_kwargs)



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(45)forward()
     43             if len(meta_kwargs):
     44                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
---> 45                                                                             torch.any(data_attention_mask, dim=1),
     46                                                                             meta_kwargs)
     47                 data_fused_repr = self.dr_fused(data_fused_repr)



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(46)forward()
     44                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
     45                                                                             torch.any(data_attention_mask, dim=1),
---> 46                                                                             meta_kwargs)
     47                 data_fused_repr = self.dr_fused(data_fused_repr)
     48 



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(44)forward()
     42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     43             if len(meta_kwargs):
---> 44                 data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
     45                                                                             torch.any(data_attention_mask, dim=1),
     46                                                                             meta_kwargs)



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(47)forward()
     45                                                                             torch.any(data_attention_mask, dim=1),
     46                                                                             meta_kwargs)
---> 47                 data_fused_repr = self.dr_fused(data_fused_repr)
     48 
     49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(49)forward()
     47                 data_fused_repr = self.dr_fused(data_fused_repr)
     48 
---> 49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
     51         return EncoderOutput(



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(51)forward()
     49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
---> 51         return EncoderOutput(
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,



ipdb>  data_logits.shape


torch.Size([5, 5, 30522])


ipdb>  data_gen_idx
ipdb>  data_gen_idx is None


True


ipdb>  n


> /tmp/ipykernel_3071/1501785349.py(52)forward()
     50 
     51         return EncoderOutput(
---> 52             rep=data_repr,
     53             fused_rep=data_fused_repr,
     54             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(53)forward()
     51         return EncoderOutput(
     52             rep=data_repr,
---> 53             fused_rep=data_fused_repr,
     54             meta_repr=meta_repr,
     55 



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(54)forward()
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,
---> 54             meta_repr=meta_repr,
     55 
     56             logits=data_logits,



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(56)forward()
     54             meta_repr=meta_repr,
     55 
---> 56             logits=data_logits,
     57         )
     58 



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(51)forward()
     49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
---> 51         return EncoderOutput(
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...xBackward0>)})
> /tmp/ipykernel_3071/1501785349.py(51)forward()
     49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
---> 51         return EncoderOutput(
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,



ipdb>  n


--Return--
EncoderOutput...xBackward0>)})
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1511)_wrapped_call_impl()
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(69)forward()
     67 
     68 
---> 69         loss = None; lbl2data_o = EncoderOutput()
     70         if lbl2data_input_ids is not None:
     71             lbl2data_gen_idx = self.get_last_item_mask(lbl2data_data2ptr, len(lbl2data_idx))



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(70)forward()
     68 
     69         loss = None; lbl2data_o = EncoderOutput()
---> 70         if lbl2data_input_ids is not None:
     71             lbl2data_gen_idx = self.get_last_item_mask(lbl2data_data2ptr, len(lbl2data_idx))
     72             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)



ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(71)forward()
     69         loss = None; lbl2data_o = EncoderOutput()
     70         if lbl2data_input_ids is not None:
---> 71             lbl2data_gen_idx = self.get_last_item_mask(lbl2data_data2ptr, len(lbl2data_idx))
     72             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     73 



ipdb>  lbl2data_data2ptr


tensor([1, 1, 1, 1, 1])


ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(72)forward()
     70         if lbl2data_input_ids is not None:
     71             lbl2data_gen_idx = self.get_last_item_mask(lbl2data_data2ptr, len(lbl2data_idx))
---> 72             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     73 
     74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 



ipdb>  lbl2data_gen_idx


tensor([True, True, True, True, True])


ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(74)forward()
     72             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     73 
---> 74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, data_gen_idx=lbl2data_gen_idx,
     76                                  **lbl2data_meta_kwargs)



ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(75)forward()
     73 
     74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
---> 75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, data_gen_idx=lbl2data_gen_idx,
     76                                  **lbl2data_meta_kwargs)
     77 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(74)forward()
     72             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     73 
---> 74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, data_gen_idx=lbl2data_gen_idx,
     76                                  **lbl2data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(76)forward()
     74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, data_gen_idx=lbl2data_gen_idx,
---> 76                                  **lbl2data_meta_kwargs)
     77 
     78             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(74)forward()
     72             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
     73 
---> 74             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
     75                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, data_gen_idx=lbl2data_gen_idx,
     76                                  **lbl2data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(33)forward()
     31         **kwargs
     32     ):  
---> 33         data_o = self.encode(data_input_ids, data_attention_mask)
     34 
     35         if data_type is not None and data_type == "meta":



ipdb>  n


> /tmp/ipykernel_3071/1501785349.py(35)forward()
     33         data_o = self.encode(data_input_ids, data_attention_mask)
     34 
---> 35         if data_type is not None and data_type == "meta":
     36             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
     37         else:



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(38)forward()
     36             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
     37         else:
---> 38             data_repr = self.dr(data_o[0], data_attention_mask)
     39 
     40         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(40)forward()
     38             data_repr = self.dr(data_o[0], data_attention_mask)
     39 
---> 40         data_fused_repr = meta_repr = None
     41         if data_aug_meta_prefix is not None:
     42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(41)forward()
     39 
     40         data_fused_repr = meta_repr = None
---> 41         if data_aug_meta_prefix is not None:
     42             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
     43             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(49)forward()
     47                 data_fused_repr = self.dr_fused(data_fused_repr)
     48 
---> 49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
     51         return EncoderOutput(



ipdb>  data_gen_idx


tensor([True, True, True, True, True])


ipdb>  n


> /tmp/ipykernel_3071/1501785349.py(51)forward()
     49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
---> 51         return EncoderOutput(
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,



ipdb>  data_logits.shape


torch.Size([5, 11, 30522])


ipdb>  n


> /tmp/ipykernel_3071/1501785349.py(52)forward()
     50 
     51         return EncoderOutput(
---> 52             rep=data_repr,
     53             fused_rep=data_fused_repr,
     54             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(53)forward()
     51         return EncoderOutput(
     52             rep=data_repr,
---> 53             fused_rep=data_fused_repr,
     54             meta_repr=meta_repr,
     55 



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(54)forward()
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,
---> 54             meta_repr=meta_repr,
     55 
     56             logits=data_logits,



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(56)forward()
     54             meta_repr=meta_repr,
     55 
---> 56             logits=data_logits,
     57         )
     58 



ipdb>  


> /tmp/ipykernel_3071/1501785349.py(51)forward()
     49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
---> 51         return EncoderOutput(
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...eta_repr=None)
> /tmp/ipykernel_3071/1501785349.py(51)forward()
     49         data_logits = self.gen(data_o[0] if data_gen_idx is None else data_o[0][data_gen_idx])
     50 
---> 51         return EncoderOutput(
     52             rep=data_repr,
     53             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(78)forward()
     76                                  **lbl2data_meta_kwargs)
     77 
---> 78             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     79                                      plbl2data_data2ptr,plbl2data_idx)
     80 



ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(79)forward()
     77 
     78             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
---> 79                                      plbl2data_data2ptr,plbl2data_idx)
     80 
     81             if self.use_gen_loss:



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(78)forward()
     76                                  **lbl2data_meta_kwargs)
     77 
---> 78             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     79                                      plbl2data_data2ptr,plbl2data_idx)
     80 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(81)forward()
     79                                      plbl2data_data2ptr,plbl2data_idx)
     80 
---> 81             if self.use_gen_loss:
     82                 loss += self.compute_gen_loss(data_o.logits, lbl2data_o.logits, data_input_ids,lbl2data_input_ids,
     83                                               lbl2data_data2ptr)



ipdb>  self.use_gen_loss


True


ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(82)forward()
     80 
     81             if self.use_gen_loss:
---> 82                 loss += self.compute_gen_loss(data_o.logits, lbl2data_o.logits, data_input_ids,lbl2data_input_ids,
     83                                               lbl2data_data2ptr)
     84 



ipdb>  data_input_ids.shape


torch.Size([5, 5])


ipdb>  lbl2data_input_ids.shape


torch.Size([5, 11])


ipdb>  lbl2data_data2ptr.shape


torch.Size([5])


ipdb>  n


> /tmp/ipykernel_3071/2330200620.py(83)forward()
     81             if self.use_gen_loss:
     82                 loss += self.compute_gen_loss(data_o.logits, lbl2data_o.logits, data_input_ids,lbl2data_input_ids,
---> 83                                               lbl2data_data2ptr)
     84 
     85             if self.use_query_loss:



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(82)forward()
     80 
     81             if self.use_gen_loss:
---> 82                 loss += self.compute_gen_loss(data_o.logits, lbl2data_o.logits, data_input_ids,lbl2data_input_ids,
     83                                               lbl2data_data2ptr)
     84 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(85)forward()
     83                                               lbl2data_data2ptr)
     84 
---> 85             if self.use_query_loss:
     86                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     87                                           plbl2data_data2ptr,plbl2data_idx)



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(86)forward()
     84 
     85             if self.use_query_loss:
---> 86                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     87                                           plbl2data_data2ptr,plbl2data_idx)
     88 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(87)forward()
     85             if self.use_query_loss:
     86                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
---> 87                                           plbl2data_data2ptr,plbl2data_idx)
     88 
     89             if self.use_calib_loss:



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(86)forward()
     84 
     85             if self.use_query_loss:
---> 86                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     87                                           plbl2data_data2ptr,plbl2data_idx)
     88 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(89)forward()
     87                                           plbl2data_data2ptr,plbl2data_idx)
     88 
---> 89             if self.use_calib_loss:
     90                 loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     91                                               plbl2data_data2ptr,plbl2data_idx)



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(93)forward()
     91                                               plbl2data_data2ptr,plbl2data_idx)
     92 
---> 93             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
     94 
     95             if self.use_fusion_loss:



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(95)forward()
     93             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
     94 
---> 95             if self.use_fusion_loss:
     96                 loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
     97                 loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(100)forward()
     98 
     99 
--> 100         if not return_dict:
    101             o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
    102             return ((loss,) + o) if loss is not None else o



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(105)forward()
    103 
    104 
--> 105         return XCModelOutput(
    106             loss=loss,
    107 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(106)forward()
    104 
    105         return XCModelOutput(
--> 106             loss=loss,
    107 
    108             data_repr=data_o.rep,



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(108)forward()
    106             loss=loss,
    107 
--> 108             data_repr=data_o.rep,
    109             data_fused_repr=data_o.fused_rep,
    110 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(109)forward()
    107 
    108             data_repr=data_o.rep,
--> 109             data_fused_repr=data_o.fused_rep,
    110 
    111             logits=data_o.logits,



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(111)forward()
    109             data_fused_repr=data_o.fused_rep,
    110 
--> 111             logits=data_o.logits,
    112 
    113             lbl2data_repr=lbl2data_o.rep,



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(113)forward()
    111             logits=data_o.logits,
    112 
--> 113             lbl2data_repr=lbl2data_o.rep,
    114             lbl2data_fused_repr=lbl2data_o.fused_rep,
    115         )



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(114)forward()
    112 
    113             lbl2data_repr=lbl2data_o.rep,
--> 114             lbl2data_fused_repr=lbl2data_o.fused_rep,
    115         )
    116 



ipdb>  


> /tmp/ipykernel_3071/2330200620.py(105)forward()
    103 
    104 
--> 105         return XCModelOutput(
    106             loss=loss,
    107 



ipdb>  


--Return--
XCModelOutput...sed_repr=None)
> /tmp/ipykernel_3071/2330200620.py(105)forward()
    103 
    104 
--> 105         return XCModelOutput(
    106             loss=loss,
    107 



ipdb>  


--Return--
XCModelOutput...sed_repr=None)
> /tmp/ipykernel_3071/3657616883.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))
      4 



ipdb>  


--Return--
None
> /tmp/ipykernel_3071/492731717.py(1)<module>()
----> 1 o = func()



ipdb>  c


    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



In [None]:
# Display the loss of the output returned by the stepped-through `o = func()` call above.
o.loss

tensor(21.9783, grad_fn=<AddBackward0>)

## `OAK006`

In [None]:
#| export
class Encoder006(DistilBertPreTrainedModel):
    """DistilBERT encoder that can fuse metadata embeddings into the pooled text representation.

    The encoder produces a pooled, L2-normalised representation of the input text and,
    when metadata arguments are supplied, a second "fused" representation obtained by
    cross-attending from the text representation to per-metadata embeddings looked up
    from two tables (`meta_embeddings` + `pretrained_meta_embeddings`).
    """
    
    def __init__(
        self, 
        config:PretrainedConfig, 
        num_metadata:int,
        resize_length:Optional[int]=None,
    ):
        """Build the backbone, the projection heads and the metadata embedding tables.

        Args:
            config: DistilBERT configuration; `config.dim` is used as the embedding width.
            num_metadata: number of rows in both metadata embedding tables.
            resize_length: if given, length of a cached tensor of ones reused by `resize`;
                if None, `resize` allocates a fresh tensor on every call.
        """
        super().__init__(config)
        self.distilbert = DistilBertModel(config)
        
        self.dr_head = RepresentationHead(config)
        self.dr_fused_head = RepresentationHead(config)
        self.meta_head = RepresentationHead(config)
        self.cross_head = CrossAttention(config)
        self.meta_embeddings = nn.Embedding(num_metadata, config.dim, sparse=True)
        self.pretrained_meta_embeddings = nn.Embedding(num_metadata, config.dim)

        # Plain tensor attribute (not a registered buffer); `resize` moves it to the right device lazily.
        self.ones = torch.ones(resize_length, dtype=torch.long, device=self.device) if resize_length is not None else None
        self.post_init()

    def freeze_meta_embeddings(self):
        """Disable gradients for the trainable metadata embedding table."""
        self.meta_embeddings.requires_grad_(False)

    def unfreeze_meta_embeddings(self):
        """Enable gradients for the trainable metadata embedding table."""
        self.meta_embeddings.requires_grad_(True)

    def set_meta_embeddings(self, embed:torch.Tensor):
        """Replace the weights of `meta_embeddings` with `embed` (assigned as-is, no copy or shape check)."""
        self.meta_embeddings.weight.data = embed

    def init_meta_embeddings(self):
        """Reset `meta_embeddings` to all zeros."""
        self.meta_embeddings.weight.data = torch.zeros_like(self.meta_embeddings.weight.data)

    def freeze_pretrained_meta_embeddings(self):
        """Disable gradients for the pretrained metadata embedding table."""
        self.pretrained_meta_embeddings.requires_grad_(False)

    def unfreeze_pretrained_meta_embeddings(self):
        """Enable gradients for the pretrained metadata embedding table."""
        self.pretrained_meta_embeddings.requires_grad_(True)

    def set_pretrained_meta_embeddings(self, embed:torch.Tensor):
        """Replace the weights of `pretrained_meta_embeddings` with `embed` (assigned as-is, no copy or shape check)."""
        self.pretrained_meta_embeddings.weight.data = embed
        
    def get_position_embeddings(self) -> nn.Embedding:
        """Return the position embeddings of the wrapped DistilBERT model."""
        return self.distilbert.get_position_embeddings()
    
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """Resize the position embeddings of the wrapped DistilBERT model."""
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)
    
    def encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        """Run the DistilBERT backbone and return its output; extra kwargs are forwarded unchanged."""
        return self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
    
    def dr(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        """Project token embeddings with `dr_head`, mean-pool them under `attention_mask` and L2-normalise along dim 1."""
        embed = self.dr_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def dr_fused(self, embed:torch.Tensor):
        """Project an already pooled representation with `dr_fused_head` and L2-normalise along dim 1."""
        embed = self.dr_fused_head(embed)
        return F.normalize(embed, dim=1)

    def meta(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        """Project token embeddings with `meta_head`, mean-pool them under `attention_mask` and L2-normalise along dim 1."""
        embed = self.meta_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)
    
    def meta_unnormalized(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        """Same as `meta` but without the final normalisation."""
        embed = self.meta_head(embed)
        return Pooling.mean_pooling(embed, attention_mask)

    def resize(self, idx:torch.Tensor, num_inputs:torch.Tensor):
        """Pad a ragged, flattened list of indices so that every row has the same length.

        `idx` is the concatenation of the rows and `num_inputs[i]` is the length of row `i`.
        Shorter rows are padded by repeating their last index until they reach the longest
        row's length.

        Args:
            idx: 1-D tensor with the concatenated per-row indices.
            num_inputs: 1-D tensor with the (strictly positive) number of indices per row.

        Returns:
            A tuple `(resized_idx, mask)` of flat long tensors. `mask` holds 1 for original
            entries and 0 for padding duplicates. When all rows already share the same
            length, `idx` itself and an all-ones mask are returned.

        Raises:
            ValueError: if any entry of `num_inputs` is zero.
        """
        if torch.any(num_inputs == 0): raise ValueError("`num_inputs` should be non-zero positive integer.")
        bsz, total_num_inputs = num_inputs.shape[0], idx.shape[0]
        
        # The cache only exists when `resize_length` was given; move it only in that case
        # so the default (`resize_length=None`) path falls through to a fresh allocation.
        if self.ones is not None:
            self.ones = self.ones.to(idx.device)
        ones = (
            torch.ones(total_num_inputs, dtype=torch.long, device=idx.device) 
            if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
        )

        max_num_inputs = num_inputs.max()
        if (num_inputs == max_num_inputs).all():
            return idx,ones
        
        # Number of copies required of each row's last element (1 == no padding).
        xnum_inputs = max_num_inputs-num_inputs+1

        # Position of the last element of every row inside the flat `idx`.
        inputs_ptr = num_inputs.cumsum(dim=0)-1
        repeat_inputs = ones.scatter(0, inputs_ptr, xnum_inputs)
        
        resized_idx = idx.repeat_interleave(repeat_inputs, dim=0)
        # Zero every copy of each row's last element, then re-enable exactly one of them
        # (the final column) so each row keeps as many ones as it had original entries.
        keep_mask = ones.scatter(0, inputs_ptr, 0).repeat_interleave(repeat_inputs, dim=0).view(bsz, -1)
        keep_mask[:, -1] = 1
        keep_mask = keep_mask.flatten()
        
        return resized_idx,keep_mask

    def fuse_meta_into_embeddings(self, data_repr:torch.Tensor, data_mask:torch.Tensor, meta_kwargs:Dict):
        """Add cross-attended metadata information to the pooled data representation.

        Args:
            data_repr: pooled data representation; reshaped here to `(-1, 1, config.dim)`.
            data_mask: per-row mask for `data_repr`; reshaped here to `(-1, 1)`.
            meta_kwargs: mapping from metadata name to a dict that provides at least
                `'idx'` (flat metadata indices) and `'data2ptr'` (number of metadata
                entries per data row).

        Returns:
            A tuple `(data_fused_repr, meta_repr)`: the fused representation with the
            singleton sequence axis removed, and a dict mapping each metadata name to the
            embeddings of its un-padded entries (an empty `(0, config.dim)` tensor when no
            row has metadata of that kind).
        """
        meta_repr = {}
        
        # Work on a copy so the caller's `data_repr` is not modified by the in-place update below.
        data_fused_repr, data_mask = data_repr.clone().view(-1, 1, self.config.dim), data_mask.view(-1, 1)
        for m_key, m_args in meta_kwargs.items():
            # Rows that have at least one metadata entry of this kind.
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(data_repr)
            
            if len(idx):
                m_idx,m_repr_mask = self.resize(m_args['idx'], m_args['data2ptr'][idx])
                m_repr = F.normalize(self.meta_embeddings(m_idx) + self.pretrained_meta_embeddings(m_idx), dim=1)
                
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.bool().view(len(idx), -1)
                meta_repr[m_key] = m_repr[m_repr_mask]
                
                fused_repr = self.cross_head(data_fused_repr[idx], data_mask[idx], m_repr, m_repr_mask)[0]
                data_fused_repr[idx] += fused_repr
                
        # Drop only the singleton sequence axis; a bare `squeeze()` would also drop the
        # batch axis for a batch of one and break the `dim=1` normalisation in `dr_fused`.
        return data_fused_repr.squeeze(1), meta_repr

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):  
        """Encode the input and optionally fuse metadata into its representation.

        Args:
            data_input_ids: token ids passed to the DistilBERT backbone.
            data_attention_mask: attention mask matching `data_input_ids`.
            data_aug_meta_prefix: prefix used by `Parameters.from_meta_aug_prefix` to pick
                the metadata arguments out of `kwargs`; no fusion is attempted when None.
            data_type: when equal to `"meta"`, the metadata head is used for pooling
                instead of the data head.
            data_unnormalized: with `data_type == "meta"`, skip the final normalisation.
            **kwargs: metadata tensors consumed through `data_aug_meta_prefix`.

        Returns:
            `EncoderOutput` with `rep` (pooled representation), `fused_rep` (fused
            representation, or None when no metadata was fused) and `meta_repr` (dict of
            metadata embeddings, or None when no metadata was fused).
        """
        data_o = self.encode(data_input_ids, data_attention_mask)
        
        if data_type == "meta":
            data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
        else: 
            data_repr = self.dr(data_o[0], data_attention_mask)
        
        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                # A row takes part in fusion if it attends to at least one token.
                data_fused_repr, meta_repr = self.fuse_meta_into_embeddings(data_repr, 
                                                                            torch.any(data_attention_mask, dim=1), 
                                                                            meta_kwargs)
                data_fused_repr = self.dr_fused(data_fused_repr)
                
        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,
        )
        

In [None]:
#| export
class OAK006(OAK000, DistilBertPreTrainedModel):
    """`OAK000` model whose encoder is an `Encoder006` (DistilBERT plus metadata embedding tables)."""
    use_generation = False
    use_representation = True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(OAK000.__init__)
    def __init__(
        self, 
        config,
        num_metadata:int,
        resize_length:Optional[int]=None,
        **kwargs
    ):
        """Create the model; `num_metadata` and `resize_length` go to `Encoder006`, the rest to `OAK000`."""
        super().__init__(config, **kwargs)

        encoder_kwargs = dict(num_metadata=num_metadata, resize_length=resize_length)
        self.encoder = Encoder006(config, **encoder_kwargs)

        # Same order as a single chained statement: generic init first, then the custom hooks.
        self.post_init()
        self.remap_post_init()
        self.init_retrieval_head()
        self.init_cross_head()

    def init_meta_embeddings(self):
        """Delegate to the encoder, which zeroes its trainable metadata embeddings."""
        self.encoder.init_meta_embeddings()

    def remap_post_init(self):
        """Expose the encoder's DistilBERT backbone as `self.distilbert`."""
        backbone = self.encoder.distilbert
        self.distilbert = backbone
        

### Example

In [None]:
# Build an OAK006 model from the msmarco DistilBERT sentence-transformers checkpoint.
# Only the data side is augmented with metadata (prefix 'cat2data'); the label side has no
# augmentation prefix. `num_metadata` sizes the encoder's metadata embedding tables and
# `resize_length` the cached ones tensor used when padding metadata indices.
# NOTE(review): the loss-related keyword arguments are presumably consumed by `OAK000.__init__` — confirm.
model = OAK006.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=False,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)

# Re-run the custom initialisers after loading: the heads and metadata embeddings are not in
# the checkpoint (see the warning printed below). `init_meta_embeddings` zeroes the trainable
# metadata table; presumably the other two reset the retrieval and cross-attention heads — TODO confirm.
model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()

Some weights of OAK006 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [None]:
# Start from an all-zero pretrained-metadata table and keep it frozen. The shape is read
# from the model's own embedding table rather than hard-coded, so it always matches the
# `num_metadata` and `config.dim` the model was built with.
model.encoder.set_pretrained_meta_embeddings(torch.zeros(model.encoder.pretrained_meta_embeddings.weight.shape))
model.encoder.freeze_pretrained_meta_embeddings()

In [None]:
# Move the model to the GPU before running the forward pass.
model = model.to('cuda')

In [None]:
# Prepare the sampled `batch` for the model, requesting the category-metadata fields by name.
# NOTE(review): presumably `m_args` lists extra keys to keep in addition to the model's own
# forward arguments — confirm against `prepare_batch`.
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [None]:
# Forward pass: move the prepared batch to the model's device and unpack it as keyword arguments.
o = model(**b.to(model.device))

In [None]:
# Display the loss returned by the forward pass above.
o.loss

tensor(0.0160, device='cuda:0', grad_fn=<AddBackward0>)

## `OAK007`

In [20]:
#| export
class OAK007(OAK003, DistilBertPreTrainedModel):
    """`OAK003` extended with one learnable embedding vector per label.

    Wherever a label is encoded (in `get_label_representation` and on the label
    branch of `forward`), the embedding looked up for its index is added to the
    encoder representation and the sum is L2-normalised along `dim=1`.
    """

    @delegates(OAK003.__init__)
    def __init__(
        self,
        config,
        n_labels:int,
        **kwargs
    ):
        """Build the parent model, then attach an `n_labels` x `config.dim` embedding table.

        `kwargs` are forwarded unchanged to `OAK003.__init__`.
        """
        super().__init__(config, **kwargs)
        self.label_embeddings = nn.Embedding(n_labels, config.dim)

        # Initialisation hooks, run once the new table is registered on the module.
        self.post_init()
        self.remap_post_init()
        self.init_retrieval_head()
        self.init_cross_head()

    def init_label_embeddings(self):
        """Replace the label-embedding weights with zeros of the same shape, dtype and device."""
        table = self.label_embeddings.weight
        table.data = torch.zeros_like(table.data)

    def get_label_representation(
        self,
        data_idx:Optional[torch.Tensor]=None,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Encode label inputs and shift each representation by its label embedding.

        `data_idx` indexes `self.label_embeddings`; although it defaults to `None`
        it is used unconditionally. Extra `kwargs` are accepted and ignored.

        Returns an `XCModelOutput` with `data_repr` (normalised sum of encoder
        representation and label embedding) and `data_fused_repr` (the encoder's
        `fused_rep`, passed through as-is).
        """
        encoder = XCDataParallel(module=self.encoder) if self.use_encoder_parallel else self.encoder

        encoded = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask)
        shifted = encoded.rep + self.label_embeddings(data_idx)
        encoded.rep = F.normalize(shifted, dim=1)

        return XCModelOutput(data_repr=encoded.rep, data_fused_repr=encoded.fused_rep)

    def forward(
        self,
        data_idx:Optional[torch.Tensor]=None,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Encode the data side, and, when label inputs are given, the label side plus the loss.

        The data inputs are always encoded. If `lbl2data_input_ids` is not `None`,
        the labels are encoded, their representations are shifted by
        `self.label_embeddings(lbl2data_idx)` and re-normalised, and the loss is
        accumulated in this order: main loss on the fused data representation,
        optional query loss, optional calibration loss, meta loss, optional fusion
        losses. Otherwise `loss` stays `None` and the label outputs come from an
        empty `EncoderOutput`.

        Metadata inputs for each side are picked out of `kwargs` by
        `Parameters.from_feat_meta_aug_prefix`. `data_idx`, `output_attentions`
        and `output_hidden_states` are accepted but not read here.

        Returns an `XCModelOutput` when `return_dict` resolves to true (falling back
        to `self.config.use_return_dict` when `None`); otherwise a tuple of
        `(logits, rep, fused_rep)` for data then labels, prefixed with the loss when
        one was computed.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder = XCDataParallel(module=self.encoder) if self.use_encoder_parallel else self.encoder

        # --- data side ---
        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(
            data_input_ids=data_input_ids,
            data_attention_mask=data_attention_mask,
            data_aug_meta_prefix=self.data_aug_meta_prefix,
            **data_meta_kwargs
        )

        # --- label side and losses (only when label inputs are present) ---
        loss, lbl2data_o = None, EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(
                data_input_ids=lbl2data_input_ids,
                data_attention_mask=lbl2data_attention_mask,
                data_aug_meta_prefix=self.lbl2data_aug_meta_prefix,
                **lbl2data_meta_kwargs
            )
            shifted = lbl2data_o.rep + self.label_embeddings(lbl2data_idx)
            lbl2data_o.rep = F.normalize(shifted, dim=1)

            # Pointer / index arguments shared by every pairwise loss call below.
            pair_args = (lbl2data_data2ptr, lbl2data_idx, plbl2data_data2ptr, plbl2data_idx)

            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep, *pair_args)

            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep, *pair_args)

            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep, *pair_args)

            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)

            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr,
                                                 self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr,
                                                 self.lbl2data_aug_meta_prefix, **kwargs)

        if return_dict:
            return XCModelOutput(
                loss=loss,
                data_repr=data_o.rep,
                data_fused_repr=data_o.fused_rep,
                lbl2data_repr=lbl2data_o.rep,
                lbl2data_fused_repr=lbl2data_o.fused_rep,
            )

        outputs = (
            data_o.logits, data_o.rep, data_o.fused_rep,
            lbl2data_o.logits, lbl2data_o.rep, lbl2data_o.fused_rep,
        )
        return outputs if loss is None else (loss,) + outputs
        

### Example

In [22]:
# Instantiate `OAK007` from the msmarco DistilBERT checkpoint with the example settings.
model = OAK007.from_pretrained(
    'sentence-transformers/msmarco-distilbert-base-v4',
    # batch sizing and main-loss settings
    batch_size=100,
    num_batch_labels=5000,
    margin=0.3,
    num_negatives=5,
    tau=0.1,
    apply_softmax=True,
    # metadata prefixes: only the data side is augmented, with `cat2data`
    data_aug_meta_prefix='cat2data',
    lbl2data_aug_meta_prefix=None,
    data_pred_meta_prefix=None,
    lbl2data_pred_meta_prefix=None,
    # sizes read from the loaded `block`
    num_metadata=block.train.dset.meta['cat_meta'].n_meta,
    resize_length=5000,
    n_labels=block.n_lbl,
    # calib_* settings (calibration loss switched off here)
    calib_margin=0.3,
    calib_num_negatives=10,
    calib_tau=0.1,
    calib_apply_softmax=False,
    calib_loss_weight=0.1,
    use_calib_loss=False,
    # remaining loss switches / weights
    use_query_loss=True,
    meta_loss_weight=0.0,
    fusion_loss_weight=0.1,
    use_fusion_loss=False,
    use_encoder_parallel=False,
)

# Initializers run after loading; `init_label_embeddings` zeroes the per-label table.
model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()
model.init_label_embeddings()

Some weights of OAK007 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [None]:
# Hand the encoder an all-zero (656086, 768) tensor as its pretrained meta embeddings,
# then call the encoder's freeze hook for them.
# NOTE(review): 656086 presumably matches the `num_metadata` passed at load time and
# 768 the encoder hidden size — confirm, and consider deriving both instead of
# hard-coding the literals.
model.encoder.set_pretrained_meta_embeddings(torch.zeros(656086, 768))
model.encoder.freeze_pretrained_meta_embeddings()

In [None]:
# Move the model to CUDA; the forward cell below sends the batch to `model.device`.
model = model.to('cuda')

In [None]:
# Build the model inputs from `batch`, passing the listed `cat2data` / `pcat2data`
# field names through `m_args`.
# NOTE(review): the `pcat2data_*` names do not appear in the `batch.keys()` listing near
# the top of the notebook — confirm `batch` carries them (or that `prepare_batch`
# tolerates missing names) before relying on this cell.
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
])

In [30]:
# Forward pass: move the prepared batch onto the model's device and unpack it as keyword arguments.
o = model(**b.to(model.device))

In [31]:
# Display the loss returned by the forward pass (a scalar tensor carrying a grad_fn in the saved output).
o.loss

tensor(0.0609, device='cuda:0', grad_fn=<AddBackward0>)