In [None]:
#| default_exp models.radga

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import torch, re, inspect, pickle, os, torch.nn as nn, math
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, List, Tuple, Mapping, Any, Union
from transformers import (
    PretrainedConfig,
    DistilBertForMaskedLM,
    DistilBertModel,
    DistilBertPreTrainedModel,
)
from transformers.utils.generic import ModelOutput
from transformers.activations import get_activation

from fastcore.meta import *
from fastcore.utils import *

from xcai.losses import *
from xcai.core import store_attr
from xcai.learner import XCDataParallel

comet_ml is installed but `COMET_API_KEY` is not set.


In [None]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
from transformers import AutoConfig
from xcai.block import *

## Setup

In [None]:
# Expose only GPUs 0 and 1 to this kernel. Machine-specific: adjust to the local hardware.
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

In [None]:
# NOTE(review): absolute path on the author's machine; point this at the local dataset root before running.
data_dir = '/home/scai/phd/aiz218323/Projects/XC/data/'

In [None]:
# Build the dataset block used by the examples below from the 'data_metas' config,
# tokenised with distilbert-base-uncased.
# NOTE(review): the meaning of `tfm='rm'` and of the `smp_features` tuples is defined by
# `XCBlock.from_cfg` in xcai.block and is not visible here — confirm there.
block = XCBlock.from_cfg(data_dir, 'data_metas', tfm='rm', tokenizer='distilbert-base-uncased', 
                         smp_features=[('lbl2data|cat2lbl2data',1,(1,3)), ('cat2data',1,3)])



In [None]:
# Location of the pickled `block` cache.
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealso_data-meta_distilbert-base-uncased_rm_radga-cat.pkl'

In [None]:
# Cache the constructed block on disk. This overwrites any existing file at `pkl_file`.
with open(pkl_file, 'wb') as file: pickle.dump(block, file)

In [None]:
# Reload the cached block. `pickle.load` executes arbitrary code from the file,
# so only load pickles that were produced locally / by a trusted source.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [None]:
# Smoke-test batching: first draw one batch via `one_batch(5)`, then iterate the dataloader.
# The loop variable rebinds `batch`, so after this cell `batch` holds the last batch yielded
# by `block.train.dl` before the break (the one with i == 3, if the loader yields that many),
# not the result of `one_batch(5)`.
batch = block.train.one_batch(5)
for i,batch in enumerate(block.train.dl):
    if i > 2: break

## Helper

In [None]:
#| export
@dataclass
class RADOutput(ModelOutput):
    """Output container for the RAD models defined in this module.

    Every field is optional and defaults to `None`.
    """
    # Training loss; `RAD000.forward` leaves it `None` when no `lbl2data_input_ids` are passed.
    loss: Optional[torch.FloatTensor] = None
    # Not populated by `RAD000`; presumably filled by variants with a generation head — TODO confirm.
    logits: Optional[torch.FloatTensor] = None
    # `rep` of the encoder output for the data inputs.
    data_repr: Optional[torch.FloatTensor] = None
    # `fused_rep` of the encoder output for the data inputs (representation after metadata fusion).
    data_fused_repr: Optional[torch.FloatTensor] = None
    # `rep` of the encoder output for the label inputs.
    lbl2data_repr: Optional[torch.FloatTensor] = None
    # `fused_rep` of the encoder output for the label inputs.
    lbl2data_fused_repr: Optional[torch.FloatTensor] = None
        

In [None]:
#| export
@dataclass
class EncoderOutput(ModelOutput):
    """Output container returned by the encoders in this module.

    Every field is optional and defaults to `None`.
    """
    # Pooled representation of the input computed before any metadata fusion.
    rep: Optional[torch.FloatTensor] = None
    # Pooled representation computed from the metadata-fused token embeddings.
    fused_rep: Optional[torch.FloatTensor] = None
    # Not set by `Encoder.forward`; presumably used by encoders with a generation head — TODO confirm.
    logits: Optional[torch.FloatTensor] = None
    # Not set by `Encoder.forward` — TODO confirm which encoder variant populates it.
    fusion_weights: Optional[torch.FloatTensor] = None
    # Dict keyed by metadata name holding the normalised metadata representations used for fusion.
    meta_repr: Optional[torch.FloatTensor] = None
        

In [None]:
#| export
class Pooling:
    """Namespace for pooling helpers that collapse token embeddings into one vector per sequence."""

    @staticmethod
    def mean_pooling(data_embeds:torch.FloatTensor, data_attention_mask:torch.LongTensor):
        """Mask-weighted mean over the sequence axis (dim 1).

        `data_attention_mask` is broadcast over the embedding axis, so positions with a zero
        mask contribute nothing. The denominator is clamped to 1e-9, which makes a fully
        masked sequence pool to zeros instead of dividing by zero.
        """
        weights = data_attention_mask.unsqueeze(2).expand(data_embeds.size()).float()
        weighted_sum = (data_embeds * weights).sum(dim=1)
        token_count = weights.sum(dim=1).clamp(min=1e-9)
        return weighted_sum / token_count


## CrossAttention

In [None]:
#| export
class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention from a query sequence onto a separate key sequence.

    `n_heads`, `dim` and `attention_dropout` are read from `config`. The key sequence also
    serves as the value sequence.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config, self.n_h, self.dim = config, config.n_heads, config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        if self.dim % self.n_h != 0:
            raise ValueError(f"self.n_heads: {self.n_h} must divide self.dim: {self.dim} evenly.")

        self.q = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.o = nn.Linear(in_features=config.dim, out_features=config.dim)

    def post_init(self):
        """Reset the q/k/v/o projection weights to the identity matrix (biases are left as they are).

        Not called from `__init__`; it has to be invoked explicitly. The identity is created
        with the dtype *and device* of the weight it replaces, so calling this on a module that
        already lives on an accelerator does not silently move its weights back to the CPU.
        """
        for proj in (self.q, self.k, self.v, self.o):
            weight = proj.weight
            weight.data = torch.eye(proj.out_features, proj.in_features, dtype=weight.dtype, device=weight.device)

    def forward(
        self, 
        q: torch.Tensor,
        q_m: torch.Tensor,
        k: torch.Tensor, 
        k_m: torch.Tensor,
        output_attentions:Optional[bool] = False,
    ):
        """Attend from `q` onto `k`.

        Args:
            q: query states, (bs, q_len, dim).
            q_m: query mask with bs * q_len elements; zero marks a position to ignore.
            k: key states, (bs, k_len, dim); also used as the values.
            k_m: key mask with bs * k_len elements; zero marks a position to ignore.
            output_attentions: when truthy, also return the attention weights.

        Returns:
            `(o,)` or `(o, w)` with `o` of shape (bs, q_len, dim) and `w` of shape
            (bs, n_h, q_len, k_len).
        """
        bs = q.size(0)
        v = k
        h_dim = self.dim//self.n_h

        def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)

        def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)

        q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
        k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
        v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)

        q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)
        sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)

        # Outer product of the two masks: entry (i, j) is non-zero only if query i and key j are both kept.
        q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
        mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)

        # A plain Python float is enough as fill value; no per-call tensor allocation needed.
        # Rows that end up fully masked hold identical scores, so softmax makes them uniform.
        sc = sc.masked_fill(mask == 0, torch.finfo(sc.dtype).min)  # (bs, n_h, q_len, k_len)

        w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
        w = self.dropout(w)  # (bs, n_h, q_len, k_len)

        o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)

        if output_attentions: return (o, w)
        else: return (o,)
        

### Example

In [None]:
config = AutoConfig.from_pretrained('distilbert-base-uncased')
fuser = CrossAttention(config)

In [None]:
bsz, data_seq_len, n_meta, dim, dtype = 2, 3, 2, config.dim, torch.float32
data, meta = torch.randn(bsz, data_seq_len, dim, dtype=dtype), torch.randn(bsz, n_meta, dim, dtype=dtype)
data_mask = torch.randint(0, 2, size=(bsz,data_seq_len), dtype=dtype)
meta_mask = torch.randint(0, 2, size=(bsz,n_meta), dtype=dtype)

In [None]:
o = fuser(data, data_mask, meta, meta_mask)

In [None]:
o[0].shape

torch.Size([2, 3, 768])

## GatedCrossAttention

In [None]:
#| export
class GatedCrossAttention(nn.Module):
    """Cross-attention with a temperature and a hard gate on negative cosine similarity.

    Scores are divided by `tau` before the softmax. Query/key pairs whose per-head cosine
    similarity is not positive are removed from the attention, in addition to the positions
    removed by the input masks. `margin` is stored on the module but is not used by `forward`.
    """

    def __init__(self, config: PretrainedConfig, margin:Optional[float]=0.3, tau:Optional[float]=0.1):
        super().__init__()
        store_attr('margin,tau')
        self.config, self.n_h, self.dim = config, config.n_heads, config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)

        if self.dim % self.n_h != 0:
            raise ValueError(f"self.n_heads: {self.n_h} must divide self.dim: {self.dim} evenly.")

        self.q = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.o = nn.Linear(in_features=config.dim, out_features=config.dim)

    def post_init(self):
        """Reset the q/k/v/o projection weights to the identity matrix (biases are left as they are).

        Not called from `__init__`. The identity is created with the dtype and device of the
        weight it replaces, so a module already placed on an accelerator stays there.
        """
        for proj in (self.q, self.k, self.v, self.o):
            weight = proj.weight
            weight.data = torch.eye(proj.out_features, proj.in_features, dtype=weight.dtype, device=weight.device)

    def forward(
        self, 
        q: torch.Tensor,
        q_m: torch.Tensor,
        k: torch.Tensor, 
        k_m: torch.Tensor,
        output_attentions:Optional[bool] = False,
    ):
        """Attend from `q` onto `k` (which also provides the values).

        Args:
            q: query states, (bs, q_len, dim).
            q_m: query mask with bs * q_len elements; zero marks a position to ignore.
            k: key states, (bs, k_len, dim).
            k_m: key mask with bs * k_len elements; zero marks a position to ignore.
            output_attentions: when truthy, also return the attention weights.

        Returns:
            `(o,)` or `(o, w)`; `o` is (bs, q_len, dim), `w` is (bs, n_h, q_len, k_len).
            Unlike plain softmax weights, rows of `w` need not sum to one because masked and
            gated-out entries are zeroed after the softmax.
        """
        bs = q.size(0)
        v = k
        h_dim = self.dim//self.n_h

        def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)

        def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)

        q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
        k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
        v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)

        q = q / math.sqrt(h_dim)  # (bs, n_h, q_len, h_dim)
        sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
        sc = sc / self.tau

        # Python float fill value: avoids building a new tensor on every call.
        neg_inf = torch.finfo(sc.dtype).min

        # Outer product of the masks: non-zero only where both the query and the key are kept.
        q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
        mask = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
        sc = sc.masked_fill(mask == 0, neg_inf)  # (bs, n_h, q_len, k_len)

        # Gate: keep only pairs whose cosine similarity survives the ReLU (i.e. is positive).
        q_norm, k_norm = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
        gated_sc = torch.matmul(q_norm, k_norm.transpose(2, 3))
        gated_sc = F.relu(gated_sc)
        gated_mask = gated_sc != 0 
        sc = sc.masked_fill(gated_mask == 0, neg_inf)

        w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
        # Fully masked rows come out of the softmax uniform; zero them explicitly.
        w = w * mask * gated_mask
        w = self.dropout(w)  # (bs, n_h, q_len, k_len)

        o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)

        if output_attentions: return (o, w)
        else: return (o,)
        

### Example

In [None]:
config = AutoConfig.from_pretrained('distilbert-base-uncased')
fuser = GatedCrossAttention(config)

In [None]:
bsz, data_seq_len, n_meta, dim, dtype = 2, 3, 2, config.dim, torch.float32

data, meta = torch.randn(bsz, data_seq_len, dim, dtype=dtype), torch.randn(bsz, n_meta, dim, dtype=dtype)
data_mask = torch.randint(0, 2, size=(bsz,data_seq_len), dtype=dtype)
meta_mask = torch.randint(0, 2, size=(bsz,n_meta), dtype=dtype)

In [None]:
o = fuser(data, data_mask, meta, meta_mask)

In [None]:
o[0].shape

torch.Size([2, 3, 768])

## GatedCrossAttention2

In [None]:
#| export
class GatedCrossAttention2(nn.Module):
    """Cosine-similarity cross-attention with a learnable margin and temperature.

    Per-head scores are `relu(cos(q, k) - margin) / tau`; pairs whose score is exactly zero
    after the ReLU are excluded from the attention, as are positions removed by the input
    masks. `margin` and `tau` are `nn.Parameter`s initialised from the constructor arguments.
    """

    def __init__(self, config: PretrainedConfig, margin:Optional[float]=0.3, tau:Optional[float]=0.1, dropout:Optional[float]=0.1):
        super().__init__()
        self.margin = nn.Parameter(torch.tensor(margin, dtype=torch.float32))
        self.tau = nn.Parameter(torch.tensor(tau, dtype=torch.float32))

        self.config, self.n_h, self.dim = config, config.n_heads, config.dim
        self.dropout = nn.Dropout(p=dropout)

        if self.dim % self.n_h != 0:
            raise ValueError(f"self.n_heads: {self.n_h} must divide self.dim: {self.dim} evenly.")

        self.q = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.o = nn.Linear(in_features=config.dim, out_features=config.dim)

    def post_init(self):
        """Reset the q/k/v/o projection weights to the identity matrix (biases are left as they are).

        Not called from `__init__`. The identity is created with the dtype and device of the
        weight it replaces, so a module already placed on an accelerator stays there.
        """
        for proj in (self.q, self.k, self.v, self.o):
            weight = proj.weight
            weight.data = torch.eye(proj.out_features, proj.in_features, dtype=weight.dtype, device=weight.device)

    def forward(
        self, 
        q: torch.Tensor,
        q_m: torch.Tensor,
        k: torch.Tensor, 
        k_m: torch.Tensor,
        output_attentions:Optional[bool] = False,
    ):
        """Attend from `q` onto `k` (which also provides the values).

        Args:
            q: query states, (bs, q_len, dim).
            q_m: query mask with bs * q_len elements; zero marks a position to ignore.
            k: key states, (bs, k_len, dim).
            k_m: key mask with bs * k_len elements; zero marks a position to ignore.
            output_attentions: when truthy, also return the attention weights.

        Returns:
            `(o,)` or `(o, w)`; `o` is (bs, q_len, dim), `w` is (bs, n_h, q_len, k_len).
        """
        bs = q.size(0)
        v = k
        h_dim = self.dim//self.n_h

        def shape(x: torch.Tensor): return x.view(bs, -1, self.n_h, h_dim).transpose(1, 2)

        def unshape(x: torch.Tensor): return x.transpose(1, 2).contiguous().view(bs, -1, self.n_h * h_dim)

        q = shape(self.q(q))  # (bs, n_h, q_len, h_dim)
        k = shape(self.k(k))  # (bs, n_h, k_len, h_dim)
        v = shape(self.v(v))  # (bs, n_h, k_len, h_dim)

        # Cosine similarity per head, shifted by the margin and clipped at zero.
        q,k = F.normalize(q, dim=-1),F.normalize(k, dim=-1)
        sc = torch.matmul(q, k.transpose(2, 3))  # (bs, n_h, q_len, k_len)
        sc = F.relu(sc - self.margin)

        sc = sc / self.tau

        # Python float fill value: avoids building a new tensor on every call.
        neg_inf = torch.finfo(sc.dtype).min

        q_m, k_m = q_m.view(bs, 1, -1, 1).to(q.dtype), k_m.view(bs, 1, 1, -1).to(q.dtype)
        mask1 = torch.matmul(q_m, k_m).expand_as(sc)  # (bs, n_h, q_len, k_len)
        sc = sc.masked_fill(mask1 == 0, neg_inf)  # (bs, n_h, q_len, k_len)
        # Entries still exactly zero here are the ones the ReLU clipped (input-masked entries
        # were just overwritten with a non-zero fill value); exclude them as well.
        mask2 = sc != 0
        sc = sc.masked_fill(mask2 == 0, neg_inf)

        w = nn.functional.softmax(sc, dim=-1)  # (bs, n_h, q_len, k_len)
        # Zero the excluded entries explicitly: fully excluded rows are uniform after softmax.
        w = self.dropout(w * mask1 * mask2)  # (bs, n_h, q_len, k_len)

        o = self.o(unshape(torch.matmul(w, v))) # (bs, q_len, dim)

        if output_attentions: return (o, w)
        else: return (o,)
        

### Example

In [None]:
config = AutoConfig.from_pretrained('distilbert-base-uncased')
fuser = GatedCrossAttention2(config)

In [None]:
bsz, data_seq_len, n_meta, dim, dtype = 2, 3, 2, config.dim, torch.float32

data, meta = torch.randn(bsz, data_seq_len, dim, dtype=dtype), torch.randn(bsz, n_meta, dim, dtype=dtype)
data_mask = torch.randint(0, 2, size=(bsz,data_seq_len), dtype=dtype)
meta_mask = torch.randint(0, 2, size=(bsz,n_meta), dtype=dtype)

In [None]:
o = fuser(data, data_mask, meta, meta_mask)

In [None]:
o[0].shape

torch.Size([2, 3, 768])

## Blocks

In [None]:
#| export
class RepresentationHead(nn.Module):
    """Projection head: linear -> activation -> layer norm -> linear, all of width `config.dim`.

    Both linear weights start as the identity matrix (see `post_init`).
    """

    def __init__(self, config):
        super().__init__()
        self.transform = nn.Linear(config.dim, config.dim)
        self.layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.projector = nn.Linear(config.dim, config.dim)
        self.activation = get_activation(config.activation)

        self.post_init()

    def post_init(self):
        """Set the `transform` and `projector` weights to the identity (biases are untouched).

        The identity is built with the dtype and device of the weight it replaces, so the
        method can be re-run on a head that has already been moved to an accelerator without
        pulling its weights back to the CPU.
        """
        for layer in (self.transform, self.projector):
            weight = layer.weight
            weight.data = torch.eye(layer.out_features, layer.in_features,
                                    dtype=weight.dtype, device=weight.device)

    def forward(self, x:torch.Tensor):
        """Apply the head to the last dimension of `x`; the output has the same shape as `x`."""
        x = self.transform(x)
        x = self.activation(x)
        x = self.layer_norm(x)
        x = self.projector(x)
        return x
    

In [None]:
#| export
class GenerationHead(nn.Module):
    """Vocabulary head: linear -> activation -> layer norm -> projection to `config.vocab_size`."""

    def __init__(self, config):
        super().__init__()
        hidden = config.dim
        self.transform = nn.Linear(hidden, hidden)
        self.layer_norm = nn.LayerNorm(hidden, eps=1e-12)
        self.projector = nn.Linear(hidden, config.vocab_size)
        self.activation = get_activation(config.activation)

    def forward(self, x:torch.Tensor):
        """Map the last dimension of `x` from `config.dim` to `config.vocab_size` scores."""
        activated = self.activation(self.transform(x))
        return self.projector(self.layer_norm(activated))
    

### Example

In [None]:
config = AutoConfig.from_pretrained('distilbert-base-uncased')
x = torch.randn(10, 20, config.dim)

In [None]:
m = RepresentationHead(config)

In [None]:
m = GenerationHead(config)

## Parameters

In [None]:
#| export
class Parameters:
    """Helpers that pick metadata-related entries out of a flat batch dict by key prefix.

    Keys are expected to look like `<meta>_<param>` (for example `cat2data_input_ids`);
    the text before the first underscore is treated as the metadata name. `prefix` is
    interpolated into a regular expression as-is.
    """

    @staticmethod
    def from_meta_aug_prefix(prefix:str, **kwargs):
        """Group augmentation inputs by metadata name.

        Selects keys that start with `prefix` and end in `_input_ids`, `_attention_mask`,
        `_data2ptr`, `_meta_repr` or `_idx`, and returns `{meta: {param: value}}`.
        Returns an empty dict when `prefix` is None.
        """
        inputs = {}
        if prefix is None: return inputs

        pattern = re.compile(f'^{prefix}.*_(input_ids|attention_mask|data2ptr|meta_repr|idx)$')
        for arg, value in kwargs.items():
            if pattern.match(arg) is None: continue
            meta,param = arg.split('_', maxsplit=1)
            inputs.setdefault(meta, {})[param] = value
        return inputs

    @staticmethod
    def from_feat_meta_aug_prefix(feat:str, prefix:str, **kwargs):
        """Collect the inputs of one metadata (`prefix`) attached to feature `feat`.

        Copies `{prefix}_attention_mask`, `{prefix}_input_ids`, `{prefix}_meta_repr` and
        `{prefix}_idx` when present, and exposes `{prefix}_{feat}2ptr` under the
        feature-independent key `{prefix}_data2ptr`. Returns an empty dict when `prefix`
        is None (instead of probing for keys that literally start with "None_").
        """
        if prefix is None: return {}

        keys = ['attention_mask', 'input_ids', 'meta_repr', 'idx']
        inputs = {f'{prefix}_{k}': kwargs[f'{prefix}_{k}'] for k in keys if f'{prefix}_{k}' in kwargs}

        ptr_key = f'{prefix}_{feat}2ptr'
        if ptr_key in kwargs: inputs[f'{prefix}_data2ptr'] = kwargs[ptr_key]
        return inputs

    @staticmethod
    def from_meta_pred_prefix(prefix:str, **kwargs):
        """Group prediction inputs by metadata name, merging the "p"-prefixed variants.

        Keys starting with `prefix` are stored as `{meta: {param: value}}`; keys starting
        with `p` + `prefix` are stored under the same metadata name with the parameter
        renamed to `p<param>`. Keys without an underscore are ignored. Returns an empty
        dict when `prefix` is None.
        """
        inputs = {}
        if prefix is None: return inputs

        candidate = re.compile(f'^[p]?{prefix}.*')
        plain = re.compile(f'^{prefix}')
        for arg, value in kwargs.items():
            # A key without '_' cannot be split into (meta, param); skip it rather than crash.
            if candidate.match(arg) is None or '_' not in arg: continue
            meta,param = arg.split('_', maxsplit=1)
            # Decide by matching the prefix itself, not by looking at the first character:
            # a prefix that happens to start with 'p' must not be mistaken for a p-variant.
            if plain.match(arg):
                inputs.setdefault(meta, {})[param] = value
            else:
                inputs.setdefault(meta[1:], {})[f'p{param}'] = value
        return inputs

    @staticmethod
    def get_meta_loss_weights(lw:Union[float,List], n_meta:int):
        """Return one loss weight per metadata.

        A scalar `lw` (float or int) is split evenly, giving `[lw / n_meta] * n_meta`
        (an empty list when `n_meta` is 0). A sequence is returned unchanged after checking
        that its length equals `n_meta`.

        Raises:
            ValueError: if `lw` is a sequence whose length differs from `n_meta`.
        """
        if isinstance(lw, (int, float)) and not isinstance(lw, bool):
            lw = lw/n_meta if n_meta else None
            return [lw] * n_meta

        if len(lw) != n_meta: raise ValueError('length of `lw` should be equal to number of metadata.')
        return lw
        

### Example

In [None]:
b = next(iter(block.train.dl))

In [None]:
p = Parameters.from_meta_aug_prefix('cat', **b); p.keys()

dict_keys(['cat2lbl', 'cat2data'])

In [None]:
p = Parameters.from_feat_meta_aug_prefix('data', 'cat2lbl', **b); p.keys()

dict_keys(['cat2lbl_attention_mask', 'cat2lbl_input_ids', 'cat2lbl_data2ptr'])

In [None]:
p = Parameters.from_meta_pred_prefix('cat', **b); p.keys()

dict_keys(['cat2lbl', 'cat2data'])

## Encoder

In [None]:
#| export
class Encoder(DistilBertPreTrainedModel):
    """DistilBERT encoder with pooled representation heads and cross-attention metadata fusion.

    The token embeddings of an input are pooled into a unit-norm representation (`dr`).
    Optionally, metadata linked to each input is encoded, attended to from the input tokens
    through `cross_head`, added to the token embeddings, and pooled again (`dr_fused`).
    """

    def __init__(self, config:PretrainedConfig, use_noise:Optional[bool]=True, noise_pct:Optional[float]=0.5, resize_length:Optional[int]=None):
        """
        Args:
            config: model configuration passed to the base class and to every sub-module.
            use_noise: whether `fuse_meta_into_embeddings` shuffles metadata between inputs.
            noise_pct: per-entry probability of taking part in that shuffle.
            resize_length: if given, length of a cached vector of ones reused by `resize`.
        """
        super().__init__(config)
        store_attr('use_noise,noise_pct')
        self.distilbert = DistilBertModel(config)

        self.dr_head = RepresentationHead(config)
        self.dr_fused_head =  RepresentationHead(config)
        self.meta_head = RepresentationHead(config)
        self.cross_head = CrossAttention(config)

        # Plain attribute (not a registered buffer); `resize` moves it to the right device.
        self.ones = torch.ones(resize_length, dtype=torch.long, device=self.device) if resize_length is not None else None
        self.post_init()

    def get_position_embeddings(self) -> nn.Embedding:
        """Return the position embeddings of the wrapped DistilBERT model."""
        return self.distilbert.get_position_embeddings()

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """Resize the position embeddings of the wrapped DistilBERT model."""
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)

    def resize(self, inputs:torch.Tensor, mask:torch.Tensor, num_inputs:torch.Tensor):
        """Pad a ragged, flattened list of per-item rows so every item has the same row count.

        `inputs` / `mask` hold the rows of all items back to back and `num_inputs[i]` is the
        number of rows of item i. The last row of every item is repeated until the item has
        `num_inputs.max()` rows; the mask of the surplus copies is zeroed so that exactly
        `num_inputs[i]` rows per item stay active (the kept copy is the one in the item's
        final slot).

        Returns:
            `(resized_inputs, resized_mask)`, each with `bsz * num_inputs.max()` rows.

        Raises:
            ValueError: if any entry of `num_inputs` is zero.
        """
        if torch.any(num_inputs == 0): raise ValueError("`num_inputs` should be non-zero positive integer.")
        bsz, dim, total_num_inputs = num_inputs.shape[0], inputs.shape[-1], inputs.shape[0]

        # `self.ones` is None unless `resize_length` was given; only move it when it exists.
        if self.ones is not None: self.ones = self.ones.to(inputs.device)
        ones = (
            torch.ones(total_num_inputs, dtype=torch.long, device=inputs.device) 
            if self.ones is None or self.ones.shape[0] < total_num_inputs else self.ones[:total_num_inputs]
        )

        max_num_inputs = num_inputs.max()
        xnum_inputs = max_num_inputs-num_inputs+1

        # Index of the last row of each item; that row is the one that gets repeated.
        inputs_ptr = num_inputs.cumsum(dim=0)-1
        repeat_inputs = ones.scatter(0, inputs_ptr, xnum_inputs)

        resized_inputs = inputs.repeat_interleave(repeat_inputs, dim=0)
        resized_mask = mask.repeat_interleave(repeat_inputs, dim=0)

        # Zero every copy of an item's last row, then re-enable the copy in the final slot.
        ignore_mask_idx = ones.scatter(0, inputs_ptr, 0).repeat_interleave(repeat_inputs, dim=0).view(bsz, -1)
        ignore_mask_idx[:, -1] = 1; ignore_mask_idx = ignore_mask_idx.view(-1, 1)

        resized_mask *= ignore_mask_idx

        return resized_inputs,resized_mask

    def add_noise(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, n_meta:int):
        """Shuffle metadata entries between inputs, slot by slot.

        The flattened rows are viewed as (n_data, n_meta, seq_len). For every metadata slot a
        random subset of the inputs (each chosen with probability `noise_pct`) is drawn and
        the slot's entries are permuted among the chosen inputs. The tensors passed in are
        modified in place through the view.
        """
        n_data, dim = input_ids.shape[0]//n_meta, input_ids.shape[1]
        noise_mask = torch.rand(n_meta, n_data, device=input_ids.device) < self.noise_pct

        input_ids, attention_mask = input_ids.view(n_data, n_meta, -1), attention_mask.view(n_data, n_meta, -1)
        for i,mask in enumerate(noise_mask):
            rnd_idx = torch.randperm(mask.sum())
            input_ids[:,i][mask] = input_ids[:,i][mask][rnd_idx]
            attention_mask[:,i][mask] = attention_mask[:,i][mask][rnd_idx]
        return input_ids.view(-1, dim), attention_mask.view(-1, dim)

    def encode(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, **kwargs):
        """Run the wrapped DistilBERT model and return its output unchanged."""
        return self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

    def dr(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        """Unit-norm, mean-pooled representation computed with `dr_head`."""
        embed = self.dr_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def dr_fused(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        """Unit-norm, mean-pooled representation computed with `dr_fused_head`."""
        embed = self.dr_fused_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def meta(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        """Unit-norm, mean-pooled representation computed with `meta_head`."""
        embed = self.meta_head(embed)
        return F.normalize(Pooling.mean_pooling(embed, attention_mask), dim=1)

    def meta_unnormalized(self, embed:torch.Tensor, attention_mask:torch.Tensor):
        """Mean-pooled `meta_head` representation without the final normalisation."""
        embed = self.meta_head(embed)
        return Pooling.mean_pooling(embed, attention_mask)

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
        """Add cross-attended metadata information to the token embeddings.

        For every metadata in `meta_kwargs`, the inputs that own at least one metadata entry
        (`data2ptr > 0`) attend from their tokens onto their metadata representations, and the
        attention output is added to their token embeddings. Metadata representations are
        taken from `meta_repr` when supplied, otherwise computed by encoding `input_ids`.

        The caller's `embed` tensor is left untouched: the first update is applied to a copy.
        Writing into `embed` directly would invalidate tensors that autograd saved when the
        un-fused representation was computed from the same embeddings.

        Returns:
            `(embed, meta_repr)`: the (possibly updated) token embeddings and a dict mapping
            each metadata name to the unit-norm representations of its active entries (an
            empty `(0, dim)` tensor when no input has entries for that metadata).
        """
        meta_repr = {}
        is_private_copy = False

        for m_key, m_args in meta_kwargs.items():
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)

            if len(idx):
                if 'meta_repr' in m_args:
                    m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
                    m_repr,m_repr_mask = self.resize(m_repr, m_repr_mask, m_args['data2ptr'][idx])
                    m_repr_mask = m_repr_mask.bool()
                else:
                    m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
                                                                m_args['data2ptr'][idx])
                    n_meta = m_args['data2ptr'].max()

                    if self.use_noise:
                        m_input_ids, m_attention_mask = self.add_noise(m_input_ids, m_attention_mask, n_meta)

                    m_embed = self.encode(m_input_ids, m_attention_mask)[0]

                    m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
                    m_repr_mask = torch.any(m_attention_mask, dim=1)

                # One row per selected input, one column per (padded) metadata entry.
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)

                meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)

                fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]
                if not is_private_copy:
                    embed, is_private_copy = embed.clone(), True
                embed[idx] += fused_embed

        return embed, meta_repr

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):
        """Encode the inputs and, if requested, fuse their metadata.

        Args:
            data_input_ids, data_attention_mask: tokenised inputs.
            data_aug_meta_prefix: key prefix selecting the metadata inputs in `kwargs`;
                no fusion is attempted when it is None or nothing matches.
            data_type: `"meta"` pools with `meta_head`, anything else with `dr_head`.
            data_unnormalized: with `data_type="meta"`, skip the normalisation of `rep`.
            **kwargs: metadata tensors keyed `<meta>_<param>`.

        Returns:
            `EncoderOutput` with `rep`, plus `fused_rep` and `meta_repr` when fusion ran
            (both stay None otherwise).
        """
        data_o = self.encode(data_input_ids, data_attention_mask)

        if data_type is not None and data_type == "meta":
            data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
        else: 
            data_repr = self.dr(data_o[0], data_attention_mask)

        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
                                                                             data_attention_mask, 
                                                                             meta_kwargs)
                data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)

        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,
        )
        

## `RAD000`

In [None]:
#| export
class RAD000(nn.Module):
    """Training wrapper that scores inputs against labels with a triplet loss on encoder outputs.

    The class owns the loss function and the loss bookkeeping; `self.encoder` starts as None
    and has to be assigned before any method that uses it is called.

    NOTE(review): `__init__` forwards `config` to `super().__init__` and `forward` reads
    `self.config`, neither of which plain `nn.Module` provides — this looks designed to be
    combined with a config-aware base class in a subclass. Confirm against the subclasses.
    """
    
    def __init__(
        self, config,
        
        num_batch_labels:Optional[int]=None, 
        batch_size:Optional[int]=None,
        margin:Optional[float]=0.3,
        num_negatives:Optional[int]=5,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=True,
        
        data_aug_meta_prefix:Optional[str]=None, 
        lbl2data_aug_meta_prefix:Optional[str]=None, 

        data_pred_meta_prefix:Optional[str]=None,
        lbl2data_pred_meta_prefix:Optional[str]=None,
        
        meta_loss_weight:Optional[Union[List,float]]=0.3,
        use_fusion_loss:Optional[bool]=False,
        fusion_loss_weight:Optional[float]=0.15,
        
        use_encoder_parallel:Optional[bool]=True,
    ):
        """
        Args:
            config: forwarded to the parent initialiser.
            num_batch_labels, batch_size, margin, num_negatives, tau, apply_softmax:
                forwarded to `MultiTriplet` (as `tn_targ`, `bsz`, `margin`, `n_negatives`,
                `tau`, `apply_softmax`).
            data_aug_meta_prefix, lbl2data_aug_meta_prefix: key prefixes of the metadata
                that is fused into the data / label encodings.
            data_pred_meta_prefix, lbl2data_pred_meta_prefix: key prefixes of the metadata
                used by `compute_meta_loss`.
            meta_loss_weight: scalar (split evenly over the metadata) or one weight per metadata.
            use_fusion_loss: whether `forward` adds `compute_fusion_loss` terms.
            fusion_loss_weight: weight of every fusion-loss term.
            use_encoder_parallel: wrap the encoder in `XCDataParallel` on every call.
        """
        super().__init__(config)
        self.m_lw, self.f_lw = meta_loss_weight, fusion_loss_weight
        store_attr('data_pred_meta_prefix,lbl2data_pred_meta_prefix')
        store_attr('data_aug_meta_prefix,lbl2data_aug_meta_prefix')
        store_attr('use_fusion_loss,use_encoder_parallel')
        
        # Placeholder; the concrete encoder is expected to be assigned elsewhere.
        self.encoder = None
        self.rep_loss_fn = MultiTriplet(bsz=batch_size, tn_targ=num_batch_labels, margin=margin, n_negatives=num_negatives, 
                                        tau=tau, apply_softmax=apply_softmax, reduce='mean')
        
    def init_retrieval_head(self):
        """Re-run `post_init` on the encoder's `dr_head`, `dr_fused_head` and `meta_head`.

        Raises:
            ValueError: if `self.encoder` has not been set.
        """
        if self.encoder is None: raise ValueError('`self.encoder` is not initialized.')
        self.encoder.dr_head.post_init()
        self.encoder.dr_fused_head.post_init()
        self.encoder.meta_head.post_init()

    def init_cross_head(self):
        """Re-run `post_init` on the encoder's `cross_head`.

        Raises:
            ValueError: if `self.encoder` has not been set.
        """
        if self.encoder is None: raise ValueError('`self.encoder` is not initialized.')
        self.encoder.cross_head.post_init()
        
    
    def disable_noise(self):
        """Switch metadata noise off and return the previous `use_noise` value.

        Handles an encoder that is itself an `XCDataParallel` wrapper by going through `.module`.
        """
        if self.encoder is None: raise ValueError('`self.encoder` is not initialized.')
        use_noise = self.encoder.module.use_noise if isinstance(self.encoder, XCDataParallel) else self.encoder.use_noise
        if isinstance(self.encoder, XCDataParallel): self.encoder.module.use_noise = False
        else: self.encoder.use_noise = False
        return use_noise
    
    def set_noise(self, use_noise):
        """Set the encoder's `use_noise` flag (e.g. to restore the value from `disable_noise`)."""
        if self.encoder is None: raise ValueError('`self.encoder` is not initialized.')
        if isinstance(self.encoder, XCDataParallel): self.encoder.module.use_noise = use_noise
        else: self.encoder.use_noise = use_noise
            
    def get_noise(self):
        """Return the encoder's current `use_noise` flag."""
        if self.encoder is None: raise ValueError('`self.encoder` is not initialized.')
        return self.encoder.module.use_noise if isinstance(self.encoder, XCDataParallel) else self.encoder.use_noise

    

    def compute_loss(self, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        """Apply `self.rep_loss_fn` to the given representations and pointer/index tensors."""
        return self.rep_loss_fn(inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)
    
    def compute_meta_loss(self, data_repr, lbl2data_repr, **kwargs):
        """Weighted sum of triplet losses between data/label representations and their metadata.

        Metadata inputs are selected from `kwargs` with the two `*_pred_meta_prefix` values.
        A metadata group carrying `lbl2data2ptr` is scored against `lbl2data_repr`, one
        carrying `data2ptr` against `data_repr`; only rows with a non-zero pointer take part.
        The metadata is encoded with `data_type="meta"`.

        Returns:
            The accumulated loss, or the float `0.0` if no group contributed.

        Raises:
            ValueError: if a metadata group has neither pointer key.
        """
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
            
        data_meta_inputs = Parameters.from_meta_pred_prefix(self.data_pred_meta_prefix, **kwargs)
        lbl2data_meta_inputs = Parameters.from_meta_pred_prefix(self.lbl2data_pred_meta_prefix, **kwargs)
        # On a name clash the label-side group replaces the data-side one.
        meta_inputs = {**data_meta_inputs, **lbl2data_meta_inputs}

        m_lw = Parameters.get_meta_loss_weights(self.m_lw, len(meta_inputs)) if len(meta_inputs) else []
        
        loss = 0.0
        for inputs,lw in zip(meta_inputs.values(), m_lw):
            if 'lbl2data2ptr' in inputs:
                idx = torch.where(inputs['lbl2data2ptr'])[0]
                if len(idx) > 0:
                    inputs_o = encoder(data_input_ids=inputs['input_ids'], data_attention_mask=inputs['attention_mask'], 
                                       data_type="meta")
                    m_loss = self.rep_loss_fn(lbl2data_repr[idx], inputs_o.rep, inputs['lbl2data2ptr'][idx],
                                              inputs['idx'], inputs['plbl2data2ptr'][idx], inputs['pidx'])
                    loss += lw * m_loss

            elif 'data2ptr' in inputs:
                idx = torch.where(inputs['data2ptr'])[0]
                if len(idx) > 0:
                    inputs_o = encoder(data_input_ids=inputs['input_ids'], data_attention_mask=inputs['attention_mask'], 
                                       data_type="meta")
                    m_loss = self.rep_loss_fn(data_repr[idx], inputs_o.rep, inputs['data2ptr'][idx], inputs['idx'], 
                                              inputs['pdata2ptr'][idx], inputs['pidx'])
                    loss += lw * m_loss       

            else: raise ValueError('Invalid metadata input arguments.')
        return loss

    def compute_fusion_loss(self, data_repr, meta_repr:Dict, prefix:str, **kwargs):
        """Triplet losses between `data_repr` and the metadata representations used for fusion.

        `meta_repr` is the dict produced by the encoder (metadata name -> representations);
        the matching pointer/index tensors are looked up in `kwargs` via `prefix`. Every term
        is weighted by `self.f_lw`.

        Returns:
            The accumulated loss, or the float `0.0` if `meta_repr` is None or nothing contributed.

        Raises:
            ValueError: if a metadata group has neither `lbl2data2ptr` nor `data2ptr`.
            KeyError: if a key of `meta_repr` has no matching group in `kwargs`.
        """
        meta_inputs = Parameters.from_meta_pred_prefix(prefix, **kwargs)
        
        loss = 0.0
        if meta_repr is not None:
            for key,input_repr in meta_repr.items():
                inputs = meta_inputs[key]
                if 'lbl2data2ptr' in inputs:
                    idx = torch.where(inputs['lbl2data2ptr'])[0]
                    if len(idx) > 0:
                        m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['lbl2data2ptr'][idx],
                                                  inputs['idx'], inputs['plbl2data2ptr'][idx], inputs['pidx'])
                        loss += self.f_lw * m_loss
    
                elif 'data2ptr' in inputs:
                    idx = torch.where(inputs['data2ptr'])[0]
                    if len(idx) > 0:
                        m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['data2ptr'][idx], inputs['idx'], 
                                                  inputs['pdata2ptr'][idx], inputs['pidx'])
                        loss += self.f_lw * m_loss       
    
                else: raise ValueError('Invalid metadata input arguments.')
        return loss


    
    def get_meta_representation(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Encode the inputs as metadata (`data_type="meta"`) without normalising the result.

        `kwargs` are accepted but not forwarded to the encoder.
        """
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
            
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_unnormalized=True, data_type="meta")
        return RADOutput(
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
        )

    
        
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        """Encode data (and labels, if given) and compute the training loss.

        The data inputs are always encoded with the metadata selected by
        `data_aug_meta_prefix`. When `lbl2data_input_ids` is provided, the labels are encoded
        as well and the loss is the triplet loss between the fused data representation and the
        label representation, plus the metadata loss and, if `use_fusion_loss` is set, the two
        fusion losses. `output_attentions` and `output_hidden_states` are accepted but unused.

        Returns:
            `RADOutput`, or a tuple (prefixed by the loss when there is one) if `return_dict`
            resolves to False.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            
            # The data side uses the fused representation, the label side the un-fused one.
            # `fused_rep` is None when the encoder performed no fusion for the data inputs.
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)
            
            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
            
            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
            
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return RADOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        

## `RAD001`

In [None]:
#| export
class RAD001Encoder(Encoder):
    """`Encoder` subclass that owns a `GenerationHead` and returns its logits next to the representations."""

    def __init__(self, config:PretrainedConfig, use_noise:Optional[bool]=True, noise_pct:Optional[float]=0.5, resize_length:Optional[int]=None):
        super().__init__(config, use_noise, noise_pct, resize_length)
        # Additional head, built from the same config, whose output is reported as `logits`.
        self.gen_head = GenerationHead(config)

    def get_output_embeddings(self) -> nn.Module:
        """Return the `projector` module of the generation head."""
        head = self.gen_head
        return head.projector

    def set_output_embeddings(self, new_embeddings: nn.Module):
        """Replace the `projector` module of the generation head with `new_embeddings`."""
        head = self.gen_head
        head.projector = new_embeddings

    def gen(self, x:torch.Tensor):
        """Apply the generation head to `x` and return its output."""
        logits = self.gen_head(x)
        return logits

    def forward(
        self,
        data_input_ids: torch.Tensor,
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_gen_idx:Optional[torch.Tensor]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):
        """Encode the input and return representation, fused representation, logits and metadata representations.

        Args:
            data_input_ids, data_attention_mask: forwarded to `self.encode` and to the pooling heads.
            data_aug_meta_prefix: when not None, metadata arguments are extracted from `kwargs` with
                `Parameters.from_meta_aug_prefix` and, if any are found, fused into the embeddings.
            data_gen_idx: optional index applied to the embeddings before the generation head.
            data_type: `"meta"` selects the meta pooling head; anything else selects `self.dr`.
            data_unnormalized: with `data_type == "meta"`, selects `self.meta_unnormalized` over `self.meta`.

        Returns:
            `EncoderOutput` with `rep`, `fused_rep`, `logits` and `meta_repr`. When no metadata is fused,
            `fused_rep` is the same object as `rep`, the logits come from the plain embeddings and
            `meta_repr` is None.
        """
        token_embed = self.encode(data_input_ids, data_attention_mask)[0]

        # Pick the pooling head first, then apply it once.
        # NOTE(review): the parent's `fuse_meta_into_embeddings` appears to update the embeddings in place,
        # so this pooling has to stay ahead of the fusion step below -- confirm against `Encoder`.
        if data_type == "meta":
            pool = self.meta_unnormalized if data_unnormalized else self.meta
        else:
            pool = self.dr
        data_repr = pool(token_embed, data_attention_mask)

        def _generate(embed):
            # Restrict to `data_gen_idx` when one is supplied, then run the generation head.
            picked = embed if data_gen_idx is None else embed[data_gen_idx]
            return self.gen(picked)

        fused_repr, logits, meta_repr = None, None, None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs) > 0:
                fused_embed, meta_repr = self.fuse_meta_into_embeddings(token_embed, data_attention_mask, meta_kwargs)
                fused_repr = self.dr_fused(fused_embed, data_attention_mask)
                logits = _generate(fused_embed)

        # No fusion took place: reuse the plain representation and generate from the plain embeddings.
        if logits is None:
            fused_repr = data_repr
            logits = _generate(token_embed)

        return EncoderOutput(rep=data_repr, fused_rep=fused_repr, logits=logits, meta_repr=meta_repr)
        

In [None]:
#| export
class RAD001(RAD000, DistilBertForMaskedLM):
    """`RAD000` on top of `DistilBertForMaskedLM`, trained with a representation loss plus a weighted generation loss."""

    use_generation = True
    use_representation = True
    _tied_weights_keys = [
        "encoder.distilbert",
        "encoder.gen_head.transform",
        "encoder.gen_head.layer_norm",
        "encoder.gen_head.projector",
    ]

    @delegates(RAD000.__init__)
    def __init__(
        self,
        config,

        num_batch_labels:Optional[int]=None,
        ignore_token:Optional[int]=0,
        gen_loss_weight:Optional[float]=0.01,

        resize_length:Optional[int]=None,
        use_noise:Optional[bool]=False,
        noise_percent:Optional[float]=0.7,

        **kwargs
    ):
        """Build the encoder and the generation loss, then wire and initialise the heads.

        Args:
            config: model configuration, passed to the parent constructor and to `RAD001Encoder`.
            num_batch_labels: forwarded to the parent constructor and to `MultiCrossEntropy` as `tn_targ`.
            ignore_token: forwarded to `MultiCrossEntropy` as `ig_tok`.
            gen_loss_weight: factor applied to the generation loss in `compute_loss`.
            resize_length, use_noise, noise_percent: forwarded to `RAD001Encoder`.
            **kwargs: forwarded to the parent constructor.
        """
        super().__init__(config, num_batch_labels=num_batch_labels, **kwargs)
        self.lw = gen_loss_weight

        self.encoder = RAD001Encoder(
            config,
            use_noise=use_noise,
            noise_pct=noise_percent,
            resize_length=resize_length,
        )
        self.gen_loss_fn = MultiCrossEntropy(tn_targ=num_batch_labels, ig_tok=ignore_token, reduce='mean')

        # Order matters: weights are initialised, then modules are shared, then the heads are initialised.
        self.post_init()
        self.remap_post_init()
        self.init_retrieval_head()
        self.init_generation_head()

    def init_generation_head(self):
        """Overwrite the generation projector weights with a copy of the input embedding weights."""
        embedding_weights = self.get_input_embeddings().weight.data
        self.encoder.gen_head.projector.weight.data = embedding_weights.clone()

    def remap_post_init(self):
        """Point the encoder's backbone and generation-head submodules at this model's own modules."""
        enc = self.encoder
        enc.distilbert = self.distilbert
        head = enc.gen_head
        head.transform = self.vocab_transform
        head.layer_norm = self.vocab_layer_norm
        head.projector = self.vocab_projector

    def compute_loss(self, inp_logits, inp_repr, targ_logits, targ_repr,
                     inp_input_ids, targ_input_ids, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        """Return the representation loss plus `self.lw` times the two-way generation loss."""
        inp_to_targ = self.gen_loss_fn(inp_logits, targ_input_ids, targ_ptr)
        targ_to_inp = self.gen_loss_fn(targ_logits, inp_input_ids)
        gen_loss = inp_to_targ + targ_to_inp

        rep_loss = self.rep_loss_fn(inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)
        return rep_loss + self.lw * gen_loss

    def get_last_item_mask(self, num_input:torch.Tensor, input_sz:int):
        """Return a boolean mask of length `input_sz` that is True at the running-total-minus-one
        position of every positive entry of `num_input` (zero entries contribute nothing)."""
        positive = torch.where(num_input > 0)[0]
        last_positions = num_input[positive].cumsum(dim=0) - 1
        mask = torch.zeros(input_sz, dtype=torch.bool, device=num_input.device)
        return mask.scatter(0, last_positions, 1)

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Encode the data (and the labels when given) and compute the training loss.

        The loss is only computed when `lbl2data_input_ids` is provided; otherwise it stays None.
        Returns a `RADOutput` when `return_dict` resolves to true, else a tuple of
        `(data logits, data rep, data fused rep, label logits, label rep, label fused rep)`
        prefixed with the loss when there is one.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder = XCDataParallel(module=self.encoder) if self.use_encoder_parallel else self.encoder

        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(
            data_input_ids=data_input_ids,
            data_attention_mask=data_attention_mask,
            data_aug_meta_prefix=self.data_aug_meta_prefix,
            **data_meta_kwargs,
        )

        loss = None
        lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            # Mask handed to the encoder as `data_gen_idx` for the label side.
            lbl2data_gen_idx = self.get_last_item_mask(lbl2data_data2ptr, len(lbl2data_idx))
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)

            lbl2data_o = encoder(
                data_input_ids=lbl2data_input_ids,
                data_attention_mask=lbl2data_attention_mask,
                data_aug_meta_prefix=self.lbl2data_aug_meta_prefix,
                data_gen_idx=lbl2data_gen_idx,
                **lbl2data_meta_kwargs,
            )

            loss = self.compute_loss(
                data_o.logits, data_o.fused_rep, lbl2data_o.logits, lbl2data_o.fused_rep,
                data_input_ids, lbl2data_input_ids, lbl2data_data2ptr, lbl2data_idx,
                plbl2data_data2ptr, plbl2data_idx,
            )

            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.fused_rep, **kwargs)

            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.fused_rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)

        if not return_dict:
            o = (
                data_o.logits, data_o.rep, data_o.fused_rep,
                lbl2data_o.logits, lbl2data_o.rep, lbl2data_o.fused_rep,
            )
            if loss is None:
                return o
            return (loss,) + o

        return RADOutput(
            loss=loss,
            logits=data_o.logits,
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        

### Example

In [None]:
# Instantiate RAD001 from the `distilbert-base-uncased` checkpoint.
# `hlk2*` metadata is configured as the augmentation source and `cat2*` metadata as the prediction target.
model = RAD001.from_pretrained('distilbert-base-uncased', num_batch_labels=5000, ignore_token=0, batch_size=100,
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='hlk2data', lbl2data_aug_meta_prefix='hlk2lbl', 
                               data_pred_meta_prefix='cat2data', lbl2data_pred_meta_prefix='cat2lbl',
                               
                               resize_length=5000, use_noise=True, noise_percent=0.3,
                               
                               # Loss weights: generation, metadata and fusion terms.
                               gen_loss_weight=0.001, meta_loss_weight=0.3, fusion_loss_weight=0.1,
                               
                               tie_word_embeddings=False, use_fusion_loss=False,  use_encoder_parallel=False)

Some weights of RAD001 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.gen_head.projector.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 'encoder.meta_head.transform.bias', 'encoder.meta_head.transform.weight', 'vocab_projector.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Re-run the head initialisers after loading.
# NOTE(review): presumably needed because `from_pretrained` re-initialises the weights that are missing
# from the checkpoint (see the warning above) after `__init__` has already run -- confirm.
model.init_retrieval_head()
model.init_generation_head()

In [None]:
# Move the model to the GPU (requires a CUDA device).
model = model.to('cuda')

In [None]:
# Build the model input from `batch`, passing the metadata keys below as `m_args`.
# NOTE(review): presumably `m_args` lists extra batch keys to keep in addition to the model's own
# forward arguments -- confirm against `prepare_batch`.
b = prepare_batch(model, batch, m_args=[
    # `cat2*` metadata for data and labels.
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
    'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_input_ids', 
    'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr',
    
    # `hlk2*` metadata for data and labels.
    'phlk2data_idx', 'phlk2data_data2ptr', 'hlk2data_idx', 'hlk2data_input_ids', 'hlk2data_attention_mask', 
    'hlk2data_data2ptr',
    'phlk2lbl_idx', 'phlk2lbl_lbl2data2ptr', 'phlk2lbl_data2ptr', 'hlk2lbl_idx', 'hlk2lbl_input_ids', 
    'hlk2lbl_attention_mask', 'hlk2lbl_lbl2data2ptr', 'hlk2lbl_data2ptr',
])

In [None]:
# Run one forward pass with the prepared batch moved to the model's device.
o = model(**b.to(model.device))

In [None]:
# Show the loss returned by the forward pass.
o.loss

tensor(0.1000, device='cuda:0', grad_fn=<AddBackward0>)

## `RAD002`

In [None]:
#| export
class RAD002(RAD000, DistilBertPreTrainedModel):
    """Representation-only `RAD000` variant that uses the plain `Encoder` (no generation head)."""

    use_generation = False
    use_representation = True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(RAD000.__init__)
    def __init__(
        self,
        config,
        resize_length:Optional[int]=None,
        use_noise:Optional[bool]=False,
        noise_percent:Optional[float]=0.5,
        **kwargs
    ):
        """Create the encoder and initialise the retrieval and cross heads.

        `resize_length`, `use_noise` and `noise_percent` are forwarded to `Encoder`;
        every other keyword argument goes to the parent constructor.
        """
        super().__init__(config, **kwargs)
        self.encoder = Encoder(
            config,
            use_noise=use_noise,
            noise_pct=noise_percent,
            resize_length=resize_length,
        )
        # Order matters: weights are initialised, then the backbone is exposed, then the heads are initialised.
        self.post_init()
        self.remap_post_init()
        self.init_retrieval_head()
        self.init_cross_head()

    def remap_post_init(self):
        """Expose the encoder's backbone as `self.distilbert` on the model itself."""
        backbone = self.encoder.distilbert
        self.distilbert = backbone
        

### Example

In [None]:
# Instantiate RAD002 from the `distilbert-base-uncased` checkpoint.
# `cat2*` metadata is configured as the augmentation source; no prediction metadata is configured.
model = RAD002.from_pretrained('distilbert-base-uncased', num_batch_labels=5000, batch_size=100,
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix='cat2lbl', 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               resize_length=5000, use_noise=True, noise_percent=0.5,
                               
                               meta_loss_weight=0.3, fusion_loss_weight=0.1, 
                               use_fusion_loss=True,  use_encoder_parallel=False)
# Re-run the head initialisers after loading.
# NOTE(review): presumably needed because `from_pretrained` re-initialises weights missing from the
# checkpoint after `__init__` has already run -- confirm.
model.init_retrieval_head()
model.init_cross_head()

Some weights of RAD002 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 'encoder.meta_head.transform.bias', 'encoder.meta_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Move the model to the GPU (requires a CUDA device).
model = model.to('cuda')

In [None]:
# Build the model input with both the `cat2*` and `hlk2*` metadata keys.
# NOTE(review): the following cell assigns `b` again with only the `cat2*` keys, so on a
# top-to-bottom run this value is overwritten before it is used.
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
    'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_input_ids', 
    'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr',
    
    'phlk2data_idx', 'phlk2data_data2ptr', 'hlk2data_idx', 'hlk2data_input_ids', 'hlk2data_attention_mask', 
    'hlk2data_data2ptr',
    'phlk2lbl_idx', 'phlk2lbl_lbl2data2ptr', 'phlk2lbl_data2ptr', 'hlk2lbl_idx', 'hlk2lbl_input_ids', 
    'hlk2lbl_attention_mask', 'hlk2lbl_lbl2data2ptr', 'hlk2lbl_data2ptr',
])

In [None]:
# Build the model input with only the `cat2*` metadata keys, matching the `cat2data` / `cat2lbl`
# augmentation prefixes the RAD002 example model was created with.
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
    'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_input_ids', 
    'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr',
])

In [None]:
# Run one forward pass with the prepared batch moved to the model's device.
# NOTE(review): the output recorded for this cell is an interactive pdb session from a debugging run;
# clear the output (and make sure no `pdb.set_trace()` remains in the model code) before committing.
o = model(**b.to(model.device))

> /tmp/ipykernel_7476/2498770119.py(164)forward()
    162         #debug
    163 
--> 164         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    165 
    166         if self.use_encoder_parallel:



ipdb>  n


> /tmp/ipykernel_7476/2498770119.py(166)forward()
    164         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    165 
--> 166         if self.use_encoder_parallel:
    167             encoder = XCDataParallel(module=self.encoder)
    168         else: encoder = self.encoder



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(168)forward()
    166         if self.use_encoder_parallel:
    167             encoder = XCDataParallel(module=self.encoder)
--> 168         else: encoder = self.encoder
    169 
    170         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(170)forward()
    168         else: encoder = self.encoder
    169 
--> 170         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
    171         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    172                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(171)forward()
    169 
    170         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
--> 171         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    172                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    173 



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(172)forward()
    170         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
    171         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
--> 172                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    173 
    174 



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(171)forward()
    169 
    170         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
--> 171         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    172                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    173 



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(172)forward()
    170         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
    171         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
--> 172                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    173 
    174 



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(171)forward()
    169 
    170         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
--> 171         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
    172                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
    173 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(130)forward()
    128         import pdb; pdb.set_trace()
    129         #debug
--> 130         data_o = self.encode(data_input_ids, data_attention_mask)
    131 
    132         if data_type is not None and data_type == "meta":



ipdb>  n


> /tmp/ipykernel_7476/658384959.py(132)forward()
    130         data_o = self.encode(data_input_ids, data_attention_mask)
    131 
--> 132         if data_type is not None and data_type == "meta":
    133             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    134         else:



ipdb>  


> /tmp/ipykernel_7476/658384959.py(135)forward()
    133             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    134         else:
--> 135             data_repr = self.dr(data_o[0], data_attention_mask)
    136 
    137         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_7476/658384959.py(137)forward()
    135             data_repr = self.dr(data_o[0], data_attention_mask)
    136 
--> 137         data_fused_repr = meta_repr = None
    138         if data_aug_meta_prefix is not None:
    139             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(138)forward()
    136 
    137         data_fused_repr = meta_repr = None
--> 138         if data_aug_meta_prefix is not None:
    139             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    140             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_7476/658384959.py(139)forward()
    137         data_fused_repr = meta_repr = None
    138         if data_aug_meta_prefix is not None:
--> 139             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    140             if len(meta_kwargs):
    141                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(140)forward()
    138         if data_aug_meta_prefix is not None:
    139             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
--> 140             if len(meta_kwargs):
    141                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    142                                                                              data_attention_mask,



ipdb>  meta_kwargs.keys()


dict_keys(['cat2data'])


ipdb>  n


> /tmp/ipykernel_7476/658384959.py(141)forward()
    139             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    140             if len(meta_kwargs):
--> 141                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    142                                                                              data_attention_mask,
    143                                                                              meta_kwargs)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(142)forward()
    140             if len(meta_kwargs):
    141                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
--> 142                                                                              data_attention_mask,
    143                                                                              meta_kwargs)
    144                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(143)forward()
    141                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    142                                                                              data_attention_mask,
--> 143                                                                              meta_kwargs)
    144                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    145 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(141)forward()
    139             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    140             if len(meta_kwargs):
--> 141                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    142                                                                              data_attention_mask,
    143                                                                              meta_kwargs)



ipdb>  s


--Call--
> /tmp/ipykernel_7476/658384959.py(84)fuse_meta_into_embeddings()
     82 
     83 
---> 84     def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
     85         meta_repr = {}
     86 



ipdb>  n


> /tmp/ipykernel_7476/658384959.py(85)fuse_meta_into_embeddings()
     83 
     84     def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
---> 85         meta_repr = {}
     86 
     87         for m_key, m_args in meta_kwargs.items():



ipdb>  


> /tmp/ipykernel_7476/658384959.py(87)fuse_meta_into_embeddings()
     85         meta_repr = {}
     86 
---> 87         for m_key, m_args in meta_kwargs.items():
     88             idx = torch.where(m_args['data2ptr'] > 0)[0]
     89             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(88)fuse_meta_into_embeddings()
     86 
     87         for m_key, m_args in meta_kwargs.items():
---> 88             idx = torch.where(m_args['data2ptr'] > 0)[0]
     89             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     90 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(89)fuse_meta_into_embeddings()
     87         for m_key, m_args in meta_kwargs.items():
     88             idx = torch.where(m_args['data2ptr'] > 0)[0]
---> 89             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     90 
     91             if len(idx):



ipdb>  


> /tmp/ipykernel_7476/658384959.py(91)fuse_meta_into_embeddings()
     89             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     90 
---> 91             if len(idx):
     92                 if 'meta_repr' in m_args:
     93                     m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(92)fuse_meta_into_embeddings()
     90 
     91             if len(idx):
---> 92                 if 'meta_repr' in m_args:
     93                     m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
     94                     m_repr,m_repr_mask = self.resize(m_repr, m_repr_mask, m_args['data2ptr'][idx])



ipdb>  n


> /tmp/ipykernel_7476/658384959.py(97)fuse_meta_into_embeddings()
     95                     m_repr_mask = m_repr_mask.bool()
     96                 else:
---> 97                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
     98                                                                 m_args['data2ptr'][idx])
     99                     n_meta = m_args['data2ptr'].max()



ipdb>  


> /tmp/ipykernel_7476/658384959.py(98)fuse_meta_into_embeddings()
     96                 else:
     97                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
---> 98                                                                 m_args['data2ptr'][idx])
     99                     n_meta = m_args['data2ptr'].max()
    100 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(97)fuse_meta_into_embeddings()
     95                     m_repr_mask = m_repr_mask.bool()
     96                 else:
---> 97                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
     98                                                                 m_args['data2ptr'][idx])
     99                     n_meta = m_args['data2ptr'].max()



ipdb>  


> /tmp/ipykernel_7476/658384959.py(99)fuse_meta_into_embeddings()
     97                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
     98                                                                 m_args['data2ptr'][idx])
---> 99                     n_meta = m_args['data2ptr'].max()
    100 
    101                     if self.use_noise:



ipdb>  


> /tmp/ipykernel_7476/658384959.py(101)fuse_meta_into_embeddings()
     99                     n_meta = m_args['data2ptr'].max()
    100 
--> 101                     if self.use_noise:
    102                         m_input_ids, m_attention_mask = self.add_noise(m_input_ids, m_attention_mask, n_meta)
    103 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(102)fuse_meta_into_embeddings()
    100 
    101                     if self.use_noise:
--> 102                         m_input_ids, m_attention_mask = self.add_noise(m_input_ids, m_attention_mask, n_meta)
    103 
    104                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]



ipdb>  


> /tmp/ipykernel_7476/658384959.py(104)fuse_meta_into_embeddings()
    102                         m_input_ids, m_attention_mask = self.add_noise(m_input_ids, m_attention_mask, n_meta)
    103 
--> 104                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]
    105 
    106                     m_repr = self.meta_unnormalized(m_embed, m_attention_mask)



ipdb>  n


> /tmp/ipykernel_7476/658384959.py(106)fuse_meta_into_embeddings()
    104                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]
    105 
--> 106                     m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
    107                     m_repr_mask = torch.any(m_attention_mask, dim=1)
    108 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(107)fuse_meta_into_embeddings()
    105 
    106                     m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
--> 107                     m_repr_mask = torch.any(m_attention_mask, dim=1)
    108 
    109                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(109)fuse_meta_into_embeddings()
    107                     m_repr_mask = torch.any(m_attention_mask, dim=1)
    108 
--> 109                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
    110 
    111                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(111)fuse_meta_into_embeddings()
    109                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
    110 
--> 111                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)
    112 
    113                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]



ipdb>  m_repr.shape


torch.Size([5, 3, 768])


ipdb>  m_repr_mask.shape


torch.Size([5, 3])


ipdb>  n


> /tmp/ipykernel_7476/658384959.py(113)fuse_meta_into_embeddings()
    111                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)
    112 
--> 113                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]
    114                 embed[idx] += fused_embed
    115 



ipdb>  meta_repr[m_key].shape


torch.Size([13, 768])


ipdb>  embed[idx].shape


torch.Size([5, 5, 768])


ipdb>  attention_mask[idx].shape


torch.Size([5, 5])


ipdb>  m_repr.shape


torch.Size([5, 3, 768])


ipdb>  m_repr_mask.shape


torch.Size([5, 3])


ipdb>  n


> /tmp/ipykernel_7476/658384959.py(114)fuse_meta_into_embeddings()
    112 
    113                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]
--> 114                 embed[idx] += fused_embed
    115 
    116         return embed, meta_repr



ipdb>  fused_embed.shape


torch.Size([5, 5, 768])


ipdb>  n


> /tmp/ipykernel_7476/658384959.py(87)fuse_meta_into_embeddings()
     85         meta_repr = {}
     86 
---> 87         for m_key, m_args in meta_kwargs.items():
     88             idx = torch.where(m_args['data2ptr'] > 0)[0]
     89             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(116)fuse_meta_into_embeddings()
    114                 embed[idx] += fused_embed
    115 
--> 116         return embed, meta_repr
    117 
    118     def forward(



ipdb>  


--Return--
(tensor([[[-0....PutBackward0>), {'cat2data': tensor([[-0.0...DivBackward0>)})
> /tmp/ipykernel_7476/658384959.py(116)fuse_meta_into_embeddings()
    114                 embed[idx] += fused_embed
    115 
--> 116         return embed, meta_repr
    117 
    118     def forward(



ipdb>  


> /tmp/ipykernel_7476/658384959.py(144)forward()
    142                                                                              data_attention_mask,
    143                                                                              meta_kwargs)
--> 144                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    145 
    146         return EncoderOutput(



ipdb>  


> /tmp/ipykernel_7476/658384959.py(146)forward()
    144                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    145 
--> 146         return EncoderOutput(
    147             rep=data_repr,
    148             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_7476/658384959.py(147)forward()
    145 
    146         return EncoderOutput(
--> 147             rep=data_repr,
    148             fused_rep=data_fused_repr,
    149             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_7476/658384959.py(148)forward()
    146         return EncoderOutput(
    147             rep=data_repr,
--> 148             fused_rep=data_fused_repr,
    149             meta_repr=meta_repr,
    150         )



ipdb>  


> /tmp/ipykernel_7476/658384959.py(149)forward()
    147             rep=data_repr,
    148             fused_rep=data_fused_repr,
--> 149             meta_repr=meta_repr,
    150         )
    151 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(146)forward()
    144                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    145 
--> 146         return EncoderOutput(
    147             rep=data_repr,
    148             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...vBackward0>)})
> /tmp/ipykernel_7476/658384959.py(146)forward()
    144                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    145 
--> 146         return EncoderOutput(
    147             rep=data_repr,
    148             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...vBackward0>)})
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1520)_call_impl()
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:



ipdb>  


--Return--
EncoderOutput...vBackward0>)})
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1511)_wrapped_call_impl()
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(175)forward()
    173 
    174 
--> 175         loss = None; lbl2data_o = EncoderOutput()
    176         if lbl2data_input_ids is not None:
    177             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(176)forward()
    174 
    175         loss = None; lbl2data_o = EncoderOutput()
--> 176         if lbl2data_input_ids is not None:
    177             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
    178             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(177)forward()
    175         loss = None; lbl2data_o = EncoderOutput()
    176         if lbl2data_input_ids is not None:
--> 177             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
    178             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
    179                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(178)forward()
    176         if lbl2data_input_ids is not None:
    177             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
--> 178             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
    179                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    180 



ipdb>  lbl2data_meta_kwargs.keys()


dict_keys(['cat2lbl_attention_mask', 'cat2lbl_input_ids', 'cat2lbl_data2ptr'])


ipdb>  n


> /tmp/ipykernel_7476/2498770119.py(179)forward()
    177             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
    178             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
--> 179                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    180 
    181             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(178)forward()
    176         if lbl2data_input_ids is not None:
    177             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
--> 178             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
    179                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    180 



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(179)forward()
    177             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
    178             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
--> 179                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    180 
    181             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(178)forward()
    176         if lbl2data_input_ids is not None:
    177             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
--> 178             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
    179                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    180 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(130)forward()
    128         import pdb; pdb.set_trace()
    129         #debug
--> 130         data_o = self.encode(data_input_ids, data_attention_mask)
    131 
    132         if data_type is not None and data_type == "meta":



ipdb>  
ipdb>  
ipdb>  n


> /tmp/ipykernel_7476/658384959.py(132)forward()
    130         data_o = self.encode(data_input_ids, data_attention_mask)
    131 
--> 132         if data_type is not None and data_type == "meta":
    133             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    134         else:



ipdb>  


> /tmp/ipykernel_7476/658384959.py(135)forward()
    133             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    134         else:
--> 135             data_repr = self.dr(data_o[0], data_attention_mask)
    136 
    137         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_7476/658384959.py(137)forward()
    135             data_repr = self.dr(data_o[0], data_attention_mask)
    136 
--> 137         data_fused_repr = meta_repr = None
    138         if data_aug_meta_prefix is not None:
    139             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(138)forward()
    136 
    137         data_fused_repr = meta_repr = None
--> 138         if data_aug_meta_prefix is not None:
    139             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    140             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_7476/658384959.py(139)forward()
    137         data_fused_repr = meta_repr = None
    138         if data_aug_meta_prefix is not None:
--> 139             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    140             if len(meta_kwargs):
    141                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(140)forward()
    138         if data_aug_meta_prefix is not None:
    139             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
--> 140             if len(meta_kwargs):
    141                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    142                                                                              data_attention_mask,



ipdb>  meta_kwargs.keys()


dict_keys(['cat2lbl'])


ipdb>  n


> /tmp/ipykernel_7476/658384959.py(141)forward()
    139             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    140             if len(meta_kwargs):
--> 141                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    142                                                                              data_attention_mask,
    143                                                                              meta_kwargs)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(142)forward()
    140             if len(meta_kwargs):
    141                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
--> 142                                                                              data_attention_mask,
    143                                                                              meta_kwargs)
    144                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(143)forward()
    141                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    142                                                                              data_attention_mask,
--> 143                                                                              meta_kwargs)
    144                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    145 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(141)forward()
    139             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    140             if len(meta_kwargs):
--> 141                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    142                                                                              data_attention_mask,
    143                                                                              meta_kwargs)



ipdb>  s


--Call--
> /tmp/ipykernel_7476/658384959.py(84)fuse_meta_into_embeddings()
     82 
     83 
---> 84     def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
     85         meta_repr = {}
     86 



ipdb>  n


> /tmp/ipykernel_7476/658384959.py(85)fuse_meta_into_embeddings()
     83 
     84     def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
---> 85         meta_repr = {}
     86 
     87         for m_key, m_args in meta_kwargs.items():



ipdb>  


> /tmp/ipykernel_7476/658384959.py(87)fuse_meta_into_embeddings()
     85         meta_repr = {}
     86 
---> 87         for m_key, m_args in meta_kwargs.items():
     88             idx = torch.where(m_args['data2ptr'] > 0)[0]
     89             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(88)fuse_meta_into_embeddings()
     86 
     87         for m_key, m_args in meta_kwargs.items():
---> 88             idx = torch.where(m_args['data2ptr'] > 0)[0]
     89             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     90 



ipdb>  m_keys


*** NameError: name 'm_keys' is not defined


ipdb>  m_key


'cat2lbl'


ipdb>  n


> /tmp/ipykernel_7476/658384959.py(89)fuse_meta_into_embeddings()
     87         for m_key, m_args in meta_kwargs.items():
     88             idx = torch.where(m_args['data2ptr'] > 0)[0]
---> 89             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     90 
     91             if len(idx):



ipdb>  idx


tensor([0, 1, 3, 4], device='cuda:0')


ipdb>  n


> /tmp/ipykernel_7476/658384959.py(91)fuse_meta_into_embeddings()
     89             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     90 
---> 91             if len(idx):
     92                 if 'meta_repr' in m_args:
     93                     m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(92)fuse_meta_into_embeddings()
     90 
     91             if len(idx):
---> 92                 if 'meta_repr' in m_args:
     93                     m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
     94                     m_repr,m_repr_mask = self.resize(m_repr, m_repr_mask, m_args['data2ptr'][idx])



ipdb>  


> /tmp/ipykernel_7476/658384959.py(97)fuse_meta_into_embeddings()
     95                     m_repr_mask = m_repr_mask.bool()
     96                 else:
---> 97                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
     98                                                                 m_args['data2ptr'][idx])
     99                     n_meta = m_args['data2ptr'].max()



ipdb>  


> /tmp/ipykernel_7476/658384959.py(98)fuse_meta_into_embeddings()
     96                 else:
     97                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
---> 98                                                                 m_args['data2ptr'][idx])
     99                     n_meta = m_args['data2ptr'].max()
    100 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(97)fuse_meta_into_embeddings()
     95                     m_repr_mask = m_repr_mask.bool()
     96                 else:
---> 97                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
     98                                                                 m_args['data2ptr'][idx])
     99                     n_meta = m_args['data2ptr'].max()



ipdb>  m_args['data2ptr'][idx]


tensor([3, 1, 3, 1], device='cuda:0')


ipdb>  n


> /tmp/ipykernel_7476/658384959.py(99)fuse_meta_into_embeddings()
     97                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
     98                                                                 m_args['data2ptr'][idx])
---> 99                     n_meta = m_args['data2ptr'].max()
    100 
    101                     if self.use_noise:



ipdb>  m_attention_mask


tensor([[1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0]], device='cuda:0')


ipdb>  n


> /tmp/ipykernel_7476/658384959.py(101)fuse_meta_into_embeddings()
     99                     n_meta = m_args['data2ptr'].max()
    100 
--> 101                     if self.use_noise:
    102                         m_input_ids, m_attention_mask = self.add_noise(m_input_ids, m_attention_mask, n_meta)
    103 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(102)fuse_meta_into_embeddings()
    100 
    101                     if self.use_noise:
--> 102                         m_input_ids, m_attention_mask = self.add_noise(m_input_ids, m_attention_mask, n_meta)
    103 
    104                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]



ipdb>  


> /tmp/ipykernel_7476/658384959.py(104)fuse_meta_into_embeddings()
    102                         m_input_ids, m_attention_mask = self.add_noise(m_input_ids, m_attention_mask, n_meta)
    103 
--> 104                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]
    105 
    106                     m_repr = self.meta_unnormalized(m_embed, m_attention_mask)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(106)fuse_meta_into_embeddings()
    104                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]
    105 
--> 106                     m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
    107                     m_repr_mask = torch.any(m_attention_mask, dim=1)
    108 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(107)fuse_meta_into_embeddings()
    105 
    106                     m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
--> 107                     m_repr_mask = torch.any(m_attention_mask, dim=1)
    108 
    109                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(109)fuse_meta_into_embeddings()
    107                     m_repr_mask = torch.any(m_attention_mask, dim=1)
    108 
--> 109                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
    110 
    111                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(111)fuse_meta_into_embeddings()
    109                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
    110 
--> 111                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)
    112 
    113                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]



ipdb>  


> /tmp/ipykernel_7476/658384959.py(113)fuse_meta_into_embeddings()
    111                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)
    112 
--> 113                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]
    114                 embed[idx] += fused_embed
    115 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(114)fuse_meta_into_embeddings()
    112 
    113                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]
--> 114                 embed[idx] += fused_embed
    115 
    116         return embed, meta_repr



ipdb>  


> /tmp/ipykernel_7476/658384959.py(87)fuse_meta_into_embeddings()
     85         meta_repr = {}
     86 
---> 87         for m_key, m_args in meta_kwargs.items():
     88             idx = torch.where(m_args['data2ptr'] > 0)[0]
     89             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /tmp/ipykernel_7476/658384959.py(116)fuse_meta_into_embeddings()
    114                 embed[idx] += fused_embed
    115 
--> 116         return embed, meta_repr
    117 
    118     def forward(



ipdb>  


--Return--
(tensor([[[-0....PutBackward0>), {'cat2lbl': tensor([[-2.1...DivBackward0>)})
> /tmp/ipykernel_7476/658384959.py(116)fuse_meta_into_embeddings()
    114                 embed[idx] += fused_embed
    115 
--> 116         return embed, meta_repr
    117 
    118     def forward(



ipdb>  


> /tmp/ipykernel_7476/658384959.py(144)forward()
    142                                                                              data_attention_mask,
    143                                                                              meta_kwargs)
--> 144                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    145 
    146         return EncoderOutput(



ipdb>  


> /tmp/ipykernel_7476/658384959.py(146)forward()
    144                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    145 
--> 146         return EncoderOutput(
    147             rep=data_repr,
    148             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_7476/658384959.py(147)forward()
    145 
    146         return EncoderOutput(
--> 147             rep=data_repr,
    148             fused_rep=data_fused_repr,
    149             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_7476/658384959.py(148)forward()
    146         return EncoderOutput(
    147             rep=data_repr,
--> 148             fused_rep=data_fused_repr,
    149             meta_repr=meta_repr,
    150         )



ipdb>  


> /tmp/ipykernel_7476/658384959.py(149)forward()
    147             rep=data_repr,
    148             fused_rep=data_fused_repr,
--> 149             meta_repr=meta_repr,
    150         )
    151 



ipdb>  


> /tmp/ipykernel_7476/658384959.py(146)forward()
    144                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    145 
--> 146         return EncoderOutput(
    147             rep=data_repr,
    148             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...vBackward0>)})
> /tmp/ipykernel_7476/658384959.py(146)forward()
    144                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    145 
--> 146         return EncoderOutput(
    147             rep=data_repr,
    148             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...vBackward0>)})
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1520)_call_impl()
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:



ipdb>  


--Return--
EncoderOutput...vBackward0>)})
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1511)_wrapped_call_impl()
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(181)forward()
    179                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    180 
--> 181             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
    182                                      plbl2data_data2ptr,plbl2data_idx)
    183             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(182)forward()
    180 
    181             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
--> 182                                      plbl2data_data2ptr,plbl2data_idx)
    183             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
    184 



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(181)forward()
    179                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    180 
--> 181             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
    182                                      plbl2data_data2ptr,plbl2data_idx)
    183             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(183)forward()
    181             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
    182                                      plbl2data_data2ptr,plbl2data_idx)
--> 183             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
    184 
    185             if self.use_fusion_loss:



ipdb>  s


--Call--
> /tmp/ipykernel_7476/2498770119.py(67)compute_meta_loss()
     65         return self.rep_loss_fn(inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)
     66 
---> 67     def compute_meta_loss(self, data_repr, lbl2data_repr, **kwargs):
     68         if self.use_encoder_parallel:
     69             encoder = XCDataParallel(module=self.encoder)



ipdb>  n


> /tmp/ipykernel_7476/2498770119.py(68)compute_meta_loss()
     66 
     67     def compute_meta_loss(self, data_repr, lbl2data_repr, **kwargs):
---> 68         if self.use_encoder_parallel:
     69             encoder = XCDataParallel(module=self.encoder)
     70         else: encoder = self.encoder



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(70)compute_meta_loss()
     68         if self.use_encoder_parallel:
     69             encoder = XCDataParallel(module=self.encoder)
---> 70         else: encoder = self.encoder
     71 
     72         data_meta_inputs = Parameters.from_meta_pred_prefix(self.data_pred_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(72)compute_meta_loss()
     70         else: encoder = self.encoder
     71 
---> 72         data_meta_inputs = Parameters.from_meta_pred_prefix(self.data_pred_meta_prefix, **kwargs)
     73         lbl2data_meta_inputs = Parameters.from_meta_pred_prefix(self.lbl2data_pred_meta_prefix, **kwargs)
     74         meta_inputs = {**data_meta_inputs, **lbl2data_meta_inputs}



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(73)compute_meta_loss()
     71 
     72         data_meta_inputs = Parameters.from_meta_pred_prefix(self.data_pred_meta_prefix, **kwargs)
---> 73         lbl2data_meta_inputs = Parameters.from_meta_pred_prefix(self.lbl2data_pred_meta_prefix, **kwargs)
     74         meta_inputs = {**data_meta_inputs, **lbl2data_meta_inputs}
     75 



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(74)compute_meta_loss()
     72         data_meta_inputs = Parameters.from_meta_pred_prefix(self.data_pred_meta_prefix, **kwargs)
     73         lbl2data_meta_inputs = Parameters.from_meta_pred_prefix(self.lbl2data_pred_meta_prefix, **kwargs)
---> 74         meta_inputs = {**data_meta_inputs, **lbl2data_meta_inputs}
     75 
     76         m_lw = Parameters.get_meta_loss_weights(self.m_lw, len(meta_inputs)) if len(meta_inputs) else []



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(76)compute_meta_loss()
     74         meta_inputs = {**data_meta_inputs, **lbl2data_meta_inputs}
     75 
---> 76         m_lw = Parameters.get_meta_loss_weights(self.m_lw, len(meta_inputs)) if len(meta_inputs) else []
     77 
     78         loss = 0.0



ipdb>  len(meta_inputs)


0


ipdb>  n


> /tmp/ipykernel_7476/2498770119.py(78)compute_meta_loss()
     76         m_lw = Parameters.get_meta_loss_weights(self.m_lw, len(meta_inputs)) if len(meta_inputs) else []
     77 
---> 78         loss = 0.0
     79         for inputs,lw in zip(meta_inputs.values(), m_lw):
     80             if 'lbl2data2ptr' in inputs:



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(79)compute_meta_loss()
     77 
     78         loss = 0.0
---> 79         for inputs,lw in zip(meta_inputs.values(), m_lw):
     80             if 'lbl2data2ptr' in inputs:
     81                 idx = torch.where(inputs['lbl2data2ptr'])[0]



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(99)compute_meta_loss()
     97 
     98             else: raise ValueError('Invalid metadata input arguments.')
---> 99         return loss
    100 
    101     def compute_fusion_loss(self, data_repr, meta_repr:Dict, prefix:str, **kwargs):



ipdb>  


--Return--
0.0
> /tmp/ipykernel_7476/2498770119.py(99)compute_meta_loss()
     97 
     98             else: raise ValueError('Invalid metadata input arguments.')
---> 99         return loss
    100 
    101     def compute_fusion_loss(self, data_repr, meta_repr:Dict, prefix:str, **kwargs):



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(185)forward()
    183             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
    184 
--> 185             if self.use_fusion_loss:
    186                 loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
    187                 loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(186)forward()
    184 
    185             if self.use_fusion_loss:
--> 186                 loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
    187                 loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
    188 



ipdb>  s


--Call--
> /tmp/ipykernel_7476/2498770119.py(101)compute_fusion_loss()
     99         return loss
    100 
--> 101     def compute_fusion_loss(self, data_repr, meta_repr:Dict, prefix:str, **kwargs):
    102         meta_inputs = Parameters.from_meta_pred_prefix(prefix, **kwargs)
    103 



ipdb>  n


> /tmp/ipykernel_7476/2498770119.py(102)compute_fusion_loss()
    100 
    101     def compute_fusion_loss(self, data_repr, meta_repr:Dict, prefix:str, **kwargs):
--> 102         meta_inputs = Parameters.from_meta_pred_prefix(prefix, **kwargs)
    103 
    104         loss = 0.0



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(104)compute_fusion_loss()
    102         meta_inputs = Parameters.from_meta_pred_prefix(prefix, **kwargs)
    103 
--> 104         loss = 0.0
    105         for key,input_repr in meta_repr.items():
    106             inputs = meta_inputs[key]



ipdb>  meta_inputs.keys()


dict_keys(['cat2data'])


ipdb>  n


> /tmp/ipykernel_7476/2498770119.py(105)compute_fusion_loss()
    103 
    104         loss = 0.0
--> 105         for key,input_repr in meta_repr.items():
    106             inputs = meta_inputs[key]
    107             if 'lbl2data2ptr' in inputs:



ipdb>  meta_repr


{'cat2data': tensor([[-0.0077,  0.0311, -0.1009,  ..., -0.0220, -0.0373,  0.0008],
        [-0.0212, -0.0003, -0.0349,  ..., -0.0195, -0.0281,  0.0302],
        [-0.0238,  0.0017, -0.0369,  ..., -0.0190, -0.0277,  0.0301],
        ...,
        [-0.0352, -0.0064, -0.0240,  ..., -0.0247,  0.0242, -0.0059],
        [ 0.0044, -0.0203, -0.0313,  ..., -0.0202, -0.0038, -0.0124],
        [-0.0115,  0.0246, -0.0558,  ..., -0.0463,  0.0357, -0.0189]],
       device='cuda:0', grad_fn=<DivBackward0>)}


ipdb>  n


> /tmp/ipykernel_7476/2498770119.py(106)compute_fusion_loss()
    104         loss = 0.0
    105         for key,input_repr in meta_repr.items():
--> 106             inputs = meta_inputs[key]
    107             if 'lbl2data2ptr' in inputs:
    108                 idx = torch.where(inputs['lbl2data2ptr'])[0]



ipdb>  n


> /tmp/ipykernel_7476/2498770119.py(107)compute_fusion_loss()
    105         for key,input_repr in meta_repr.items():
    106             inputs = meta_inputs[key]
--> 107             if 'lbl2data2ptr' in inputs:
    108                 idx = torch.where(inputs['lbl2data2ptr'])[0]
    109                 if len(idx) > 0:



ipdb>  inputs


{'pidx': tensor([ 72729, 111911, 117344, 138296, 161845, 161846, 161847, 161848, 161849,
        161850, 161851, 161852, 161853, 161854, 161855, 161856, 161857, 161858,
        161859, 161860, 161861, 161862, 161863, 138151,  54422,  54425,  68239,
         69310,  76501,  79431,  81202, 101916, 102550, 102551, 102552, 102553,
        102554, 102555, 102556, 102557,  79395, 102791, 102792, 102793, 102794,
         84732,  84865,  84866,  84867,  84868,  84869,  84870,  84871,  84872,
         84873,  84874,  84875,  84876,  84877,  84878,  84879,  84880,  84881],
       device='cuda:0'), 'pdata2ptr': tensor([23,  1, 16,  5, 18], device='cuda:0'), 'idx': tensor([161854, 161862, 161850, 138151, 102554, 102557, 102556, 102794, 102791,
        102793,  84871,  84870,  84865], device='cuda:0'), 'input_ids': tensor([[  101, 13140,  4487,  8583,  2696, 16558, 21808,  2015,  1999,  1996,
          3803,  3400,   102],
        [  101,  3470,  1997,  1996,  4549, 27695,   102,     0,     0,     

ipdb>  n


> /tmp/ipykernel_7476/2498770119.py(114)compute_fusion_loss()
    112                     loss += self.f_lw * m_loss
    113 
--> 114             elif 'data2ptr' in inputs:
    115                 idx = torch.where(inputs['data2ptr'])[0]
    116                 if len(idx) > 0:



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(115)compute_fusion_loss()
    113 
    114             elif 'data2ptr' in inputs:
--> 115                 idx = torch.where(inputs['data2ptr'])[0]
    116                 if len(idx) > 0:
    117                     m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['data2ptr'][idx], inputs['idx'], 



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(116)compute_fusion_loss()
    114             elif 'data2ptr' in inputs:
    115                 idx = torch.where(inputs['data2ptr'])[0]
--> 116                 if len(idx) > 0:
    117                     m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['data2ptr'][idx], inputs['idx'], 
    118                                               inputs['pdata2ptr'][idx], inputs['pidx'])



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(117)compute_fusion_loss()
    115                 idx = torch.where(inputs['data2ptr'])[0]
    116                 if len(idx) > 0:
--> 117                     m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['data2ptr'][idx], inputs['idx'], 
    118                                               inputs['pdata2ptr'][idx], inputs['pidx'])
    119                     loss += self.f_lw * m_loss



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(118)compute_fusion_loss()
    116                 if len(idx) > 0:
    117                     m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['data2ptr'][idx], inputs['idx'], 
--> 118                                               inputs['pdata2ptr'][idx], inputs['pidx'])
    119                     loss += self.f_lw * m_loss
    120 



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(117)compute_fusion_loss()
    115                 idx = torch.where(inputs['data2ptr'])[0]
    116                 if len(idx) > 0:
--> 117                     m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['data2ptr'][idx], inputs['idx'], 
    118                                               inputs['pdata2ptr'][idx], inputs['pidx'])
    119                     loss += self.f_lw * m_loss



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(119)compute_fusion_loss()
    117                     m_loss = self.rep_loss_fn(data_repr[idx], input_repr, inputs['data2ptr'][idx], inputs['idx'], 
    118                                               inputs['pdata2ptr'][idx], inputs['pidx'])
--> 119                     loss += self.f_lw * m_loss
    120 
    121             else: raise ValueError('Invalid metadata input arguments.')



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(105)compute_fusion_loss()
    103 
    104         loss = 0.0
--> 105         for key,input_repr in meta_repr.items():
    106             inputs = meta_inputs[key]
    107             if 'lbl2data2ptr' in inputs:



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(122)compute_fusion_loss()
    120 
    121             else: raise ValueError('Invalid metadata input arguments.')
--> 122         return loss
    123 
    124 



ipdb>  


--Return--
tensor(0.0040...AddBackward0>)
> /tmp/ipykernel_7476/2498770119.py(122)compute_fusion_loss()
    120 
    121             else: raise ValueError('Invalid metadata input arguments.')
--> 122         return loss
    123 
    124 



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(187)forward()
    185             if self.use_fusion_loss:
    186                 loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
--> 187                 loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
    188 
    189 



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(190)forward()
    188 
    189 
--> 190         if not return_dict:
    191             o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
    192             return ((loss,) + o) if loss is not None else o



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(195)forward()
    193 
    194 
--> 195         return RADOutput(
    196             loss=loss,
    197 



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(196)forward()
    194 
    195         return RADOutput(
--> 196             loss=loss,
    197 
    198             data_repr=data_o.rep,



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(198)forward()
    196             loss=loss,
    197 
--> 198             data_repr=data_o.rep,
    199             data_fused_repr=data_o.fused_rep,
    200 



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(199)forward()
    197 
    198             data_repr=data_o.rep,
--> 199             data_fused_repr=data_o.fused_rep,
    200 
    201             lbl2data_repr=lbl2data_o.rep,



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(201)forward()
    199             data_fused_repr=data_o.fused_rep,
    200 
--> 201             lbl2data_repr=lbl2data_o.rep,
    202             lbl2data_fused_repr=lbl2data_o.fused_rep,
    203         )



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(202)forward()
    200 
    201             lbl2data_repr=lbl2data_o.rep,
--> 202             lbl2data_fused_repr=lbl2data_o.fused_rep,
    203         )
    204 



ipdb>  


> /tmp/ipykernel_7476/2498770119.py(195)forward()
    193 
    194 
--> 195         return RADOutput(
    196             loss=loss,
    197 



ipdb>  


--Return--
RADOutput(los...ivBackward0>))
> /tmp/ipykernel_7476/2498770119.py(195)forward()
    193 
    194 
--> 195         return RADOutput(
    196             loss=loss,
    197 



ipdb>  


--Return--
RADOutput(los...ivBackward0>))
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1520)_call_impl()
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:



ipdb>  


--Return--
RADOutput(los...ivBackward0>))
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1511)_wrapped_call_impl()
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):



ipdb>  


--Return--
None
> /tmp/ipykernel_7476/3352081931.py(1)<module>()
----> 1 o = model(**b.to(model.device))



ipdb>  


    [... skipped 1 hidden frame]

> /home/scai/phd/aiz218323/.local/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3553)run_code()
   3551             finally:
   3552                 # Reset our crash handler in place
-> 3553                 sys.excepthook = old_excepthook
   3554         except SystemExit as e:
   3555             if result is not None:



ipdb>  c


In [None]:
# Display the loss returned by the forward pass stepped through above.
o.loss

tensor(0.0634, device='cuda:0', grad_fn=<AddBackward0>)

## `RAD003`

In [None]:
#| export
class Encoder003(Encoder):
    """`Encoder` variant that adds cross-attended metadata to the embeddings through a scalar gate.

    `cross_gate` is a single learnable scalar created as zero, so at construction time the
    gated update `cross_gate * fused_embed` contributes nothing to the embeddings.
    """

    def __init__(self, config:PretrainedConfig, use_noise:Optional[bool]=True, noise_pct:Optional[float]=0.5, resize_length:Optional[int]=None):
        super().__init__(config, use_noise, noise_pct, resize_length)
        # Scalar multiplier applied to the cross-attention output before it is added to the embeddings.
        self.cross_gate = nn.Parameter(torch.zeros(1))

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
        """Add gated cross-attention over metadata to `embed`, one metadata source at a time.

        Args:
            embed: embeddings of the datapoints; rows that have metadata are updated IN PLACE.
            attention_mask: attention mask matching `embed`.
            meta_kwargs: mapping from metadata name to a dict holding `data2ptr` and either a
                precomputed `meta_repr` (plus `attention_mask`) or `input_ids`/`attention_mask`.

        Returns:
            `(embed, meta_repr)`; `meta_repr[name]` is the L2-normalised representation of the
            unmasked metadata, or an empty `(0, config.dim)` tensor when no datapoint has any.

        Raises:
            ValueError: if the datapoints that have metadata do not all have the same amount of it.
        """
        meta_repr = {}

        for m_key, m_args in meta_kwargs.items():
            # Datapoints with at least one metadata entry.
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)

            if len(idx):
                # Unlike the sibling encoders, nothing is resized here, so the later
                # `.view(len(idx), -1, dim)` needs a constant metadata count per datapoint.
                if not torch.all(m_args['data2ptr'][idx] == m_args['data2ptr'].max()):
                    raise ValueError('All datapoints should have same number of metadata.')

                if 'meta_repr' in m_args:
                    # Precomputed metadata representations: a metadata entry is valid when
                    # its attention mask has any non-zero position.
                    m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
                    m_repr_mask = m_repr_mask.bool()
                else:
                    m_input_ids, m_attention_mask = m_args['input_ids'], m_args['attention_mask']
                    m_embed = self.encode(m_input_ids, m_attention_mask)[0]

                    m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
                    m_repr_mask = torch.any(m_attention_mask, dim=1)

                # Group the flat metadata rows by the datapoint they belong to.
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)

                meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)

                # Only the fused embedding is needed (attention weights are not requested).
                # Sibling encoders use the head's first output directly as a tensor, so unpacking it
                # into two names would split the tensor along its first dimension (or fail).
                # NOTE(review): `cross_head` is defined outside this block; a `(fused, weights)`
                # pair is still accepted in case the head bundles them.
                cross_o = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask, output_attentions=False)[0]
                fused_embed = cross_o[0] if isinstance(cross_o, (tuple, list)) else cross_o
                embed[idx] += self.cross_gate * fused_embed

        return embed, meta_repr

    def forward(
        self,
        data_input_ids: torch.Tensor,
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):
        """Encode the input and, when augmentation metadata is supplied, also build a fused representation.

        Args:
            data_input_ids: token ids of the inputs.
            data_attention_mask: attention mask of the inputs.
            data_aug_meta_prefix: prefix used to select the augmentation metadata from `kwargs`.
            data_type: `"meta"` routes the input through the metadata heads instead of `self.dr`.
            data_unnormalized: with `data_type == "meta"`, use `self.meta_unnormalized` rather than `self.meta`.
            **kwargs: metadata tensors, filtered by `Parameters.from_meta_aug_prefix`.

        Returns:
            `EncoderOutput` with `rep`, `fused_rep` (None when no metadata was found) and `meta_repr`.
        """
        data_o = self.encode(data_input_ids, data_attention_mask)

        if data_type == "meta":
            data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
        else:
            data_repr = self.dr(data_o[0], data_attention_mask)

        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                # `data_o[0]` is modified in place by the fusion step; `data_repr` was computed before it.
                data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0],
                                                                             data_attention_mask,
                                                                             meta_kwargs)
                data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)

        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,
        )
        

In [None]:
#| export
class RAD003(RAD000, DistilBertPreTrainedModel):
    """Retrieval model built on `Encoder003` that adds a calibration loss between fused and plain representations."""
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    def __init__(
        self, config,

        calib_margin:Optional[float]=0.3,
        calib_num_negatives:Optional[int]=5,
        calib_tau:Optional[float]=0.1,
        calib_apply_softmax:Optional[bool]=True,

        calib_loss_weight:Optional[float]=0.3,
        **kwargs
    ):
        """Build the encoder and the calibration loss.

        Args:
            config: model configuration, forwarded to the base classes and to `Encoder003`.
            calib_margin, calib_num_negatives, calib_tau, calib_apply_softmax: forwarded to
                `Calibration` as `margin`, `n_negatives`, `tau` and `apply_softmax`.
            calib_loss_weight: multiplier applied to the calibration loss.
            **kwargs: forwarded to the base-class initialiser.
        """
        super().__init__(config, **kwargs)
        self.c_lw = calib_loss_weight
        self.encoder = Encoder003(config, use_noise=False)

        self.post_init()

        self.cab_loss_fn = Calibration(margin=calib_margin, tau=calib_tau, n_negatives=calib_num_negatives,
                                       apply_softmax=calib_apply_softmax, reduce='mean')

        self.remap_post_init()
        self.init_retrieval_head()
        self.init_cross_head()
        self.init_cross_gate()

    def init_cross_gate(self):
        """Reset the encoder's cross-attention gate to zero.

        The value is written in place so the parameter keeps its current device and dtype;
        rebinding `.data` to a freshly created tensor would silently move it back to CPU/float32
        when this is called after `.to(...)` or `.half()`.
        """
        nn.init.zeros_(self.encoder.cross_gate)

    def remap_post_init(self):
        """Alias the encoder's DistilBERT backbone as `self.distilbert`."""
        # NOTE(review): presumably lets checkpoint keys under `distilbert.` resolve to the
        # encoder backbone (see `_tied_weights_keys`) - confirm against the loading code.
        self.distilbert = self.encoder.distilbert

    def calibration_loss(self, einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        """Return the calibration loss from `self.cab_loss_fn`, scaled by `calib_loss_weight`."""
        return self.c_lw * self.cab_loss_fn(einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Encode data (and labels, when given) and compute the training loss.

        When `lbl2data_input_ids` is provided the loss is the sum of:
          * `compute_loss` between the plain data and label representations,
          * if a fused data representation exists, `compute_loss` on the fused representation
            plus the weighted calibration loss between fused and plain representations,
          * `compute_meta_loss` on the plain data and label representations.

        `output_attentions` and `output_hidden_states` are accepted but not used here.

        Returns:
            `RADOutput`, or a tuple (with the loss first when it was computed) if `return_dict` is falsy.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.use_encoder_parallel:
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder

        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask,
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)

        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask,
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)

            loss = self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            # The fused representation only exists when augmentation metadata reached the encoder.
            if data_o.fused_rep is not None:
                loss += self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)

                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)

            loss += self.compute_meta_loss(data_o.rep, lbl2data_o.rep, **kwargs)

        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o

        return RADOutput(
            loss=loss,

            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,

            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        

### Example

In [None]:
# Build RAD003 from the pretrained DistilBERT checkpoint. The keyword arguments set the
# loss hyper-parameters, the calibration-loss settings and the metadata prefixes.
model = RAD003.from_pretrained('distilbert-base-uncased', num_batch_labels=5000, batch_size=100,
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               calib_margin=0.3, calib_num_negatives=5, calib_tau=0.1, calib_apply_softmax=True,
                               
                               data_aug_meta_prefix='aug2data', lbl2data_aug_meta_prefix='aug2lbl', 
                               data_pred_meta_prefix='cat2data', lbl2data_pred_meta_prefix='cat2lbl',
                               
                               meta_loss_weight=[0.1, 0.1], calib_loss_weight=0.3,  
                               use_encoder_parallel=True)
# The heads are not part of the checkpoint (see the loading message below), so they are
# initialised again after loading.
model.init_retrieval_head()
model.init_cross_head()

Some weights of RAD003 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.cross_gate', 'encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 'encoder.meta_head.transform.bias', 'encoder.meta_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Prepare the batch for the model, passing the category-metadata fields (for both data
# and labels) through `m_args`.
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
    'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_input_ids', 
    'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr',
])

In [None]:
# List the fields present in the prepared batch.
b.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 'cat2data_data2ptr', 'data_input_ids', 'data_attention_mask'])

In [None]:
# Move the model to the GPU before the forward pass.
model = model.to('cuda')

In [None]:
# Forward pass with the batch moved to the model's device.
o = model(**b.to(model.device))

In [None]:
# Display the loss returned by the forward pass.
o.loss

tensor(0.0684, device='cuda:0', grad_fn=<AddBackward0>)

## `RAD004`

In [None]:
#| export
class Encoder004(Encoder):
    """`Encoder` whose cross-attention head is a `GatedCrossAttention` and whose fused update is scaled by `cross_gate`."""

    def __init__(self, config:PretrainedConfig, use_noise:Optional[bool]=True, noise_pct:Optional[float]=0.5, resize_length:Optional[int]=None,
                cross_margin:Optional[float]=0.3, cross_tau:Optional[float]=0.1):
        super().__init__(config, use_noise, noise_pct, resize_length)
        # Replace the cross-attention head set up by the base class and add a scalar gate (created as one).
        self.cross_head = GatedCrossAttention(config, margin=cross_margin, tau=cross_tau)
        self.cross_gate = nn.Parameter(torch.ones(1))

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
        """Add gated cross-attention over each metadata source to `embed` (updated in place).

        Returns `(embed, meta_repr)`, where `meta_repr[name]` holds the L2-normalised
        representations of the unmasked metadata, or an empty `(0, config.dim)` tensor
        when no datapoint has metadata of that kind.
        """
        hidden_dim = self.config.dim
        collected = {}

        for name, feats in meta_kwargs.items():
            ptr = feats['data2ptr']
            rows = torch.where(ptr > 0)[0]
            collected[name] = torch.empty(0, hidden_dim).to(embed)

            # Nothing to fuse when no datapoint carries this metadata.
            if len(rows) == 0:
                continue
            n_rows = len(rows)

            if 'meta_repr' in feats:
                # Precomputed representations: resize them together with a per-entry validity flag.
                valid = torch.any(feats['attention_mask'], dim=1).long().view(-1,1)
                reps, valid = self.resize(feats['meta_repr'], valid, ptr[rows])
                valid = valid.bool()
            else:
                # Raw tokens: resize, optionally perturb, then encode.
                tok_ids, tok_mask = self.resize(feats['input_ids'], feats['attention_mask'], ptr[rows])
                if self.use_noise:
                    tok_ids, tok_mask = self.add_noise(tok_ids, tok_mask, ptr.max())

                hidden = self.encode(tok_ids, tok_mask)[0]
                reps = self.meta_unnormalized(hidden, tok_mask)
                valid = torch.any(tok_mask, dim=1)

            # Group the flat metadata rows by datapoint.
            reps = reps.view(n_rows, -1, hidden_dim)
            valid = valid.view(n_rows, -1)

            collected[name] = F.normalize(reps[valid], dim=1)

            update = self.cross_head(embed[rows], attention_mask[rows], reps, valid)[0]
            embed[rows] += self.cross_gate * update

        return embed, collected
        

In [None]:
#| export
class RAD004(RAD000, DistilBertPreTrainedModel):
    """Retrieval model built on `Encoder004` (gated cross-attention over metadata)."""
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    def __init__(
        self, config,
        resize_length:Optional[int]=None,
        use_noise:Optional[bool]=True,
        noise_percent:Optional[float]=0.7,

        cross_margin:Optional[float]=0.3,
        cross_tau:Optional[float]=0.1,

        **kwargs
    ):
        """Build the encoder and initialise its heads.

        Args:
            config: model configuration, forwarded to the base classes and to `Encoder004`.
            resize_length, use_noise: forwarded unchanged to `Encoder004`.
            noise_percent: forwarded to `Encoder004` as `noise_pct`.
            cross_margin, cross_tau: forwarded to `Encoder004` for its cross-attention head.
            **kwargs: forwarded to the base-class initialiser.
        """
        super().__init__(config, **kwargs)
        self.encoder = Encoder004(config, use_noise=use_noise, noise_pct=noise_percent, resize_length=resize_length,
                                  cross_margin=cross_margin, cross_tau=cross_tau)

        self.post_init()
        self.remap_post_init()
        self.init_retrieval_head()
        self.init_cross_head()
        self.init_cross_gate()

    def init_cross_gate(self):
        """Reset the encoder's cross-attention gate to one.

        The value is written in place so the parameter keeps its current device and dtype;
        rebinding `.data` to a freshly created tensor would silently move it back to CPU/float32
        when this is called after `.to(...)` or `.half()`.
        """
        nn.init.ones_(self.encoder.cross_gate)

    def remap_post_init(self):
        """Alias the encoder's DistilBERT backbone as `self.distilbert`."""
        # NOTE(review): presumably lets checkpoint keys under `distilbert.` resolve to the
        # encoder backbone (see `_tied_weights_keys`) - confirm against the loading code.
        self.distilbert = self.encoder.distilbert
        

### Example

In [None]:
# Build RAD004 from the pretrained DistilBERT checkpoint. Only `cat2data` is used as
# augmentation metadata here; the label-side and prediction prefixes are disabled.
model = RAD004.from_pretrained('distilbert-base-uncased', num_batch_labels=5000, batch_size=100,
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,

                               resize_length=5000, use_noise=True, noise_percent=0.5,
                               
                               meta_loss_weight=0.3, fusion_loss_weight=0.1, 
                               use_fusion_loss=True,  use_encoder_parallel=False)

# The heads and the gate are not part of the checkpoint (see the loading message below),
# so they are initialised again after loading.
model.init_retrieval_head()
model.init_cross_head()
model.init_cross_gate()

Some weights of RAD004 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.cross_gate', 'encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 'encoder.meta_head.transform.bias', 'encoder.meta_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Prepare the batch for the model, passing the category-metadata fields (for both data
# and labels) through `m_args`.
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
    'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_input_ids', 
    'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr',
])

In [None]:
# List the fields present in the prepared batch.
b.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 'cat2data_data2ptr', 'data_input_ids', 'data_attention_mask'])

In [None]:
# Forward pass with the batch moved to the model's device.
o = model(**b.to(model.device))

In [None]:
# Display the loss returned by the forward pass.
o.loss

tensor(0.0585, grad_fn=<AddBackward0>)

## `RAD005`

In [None]:
#| export
class Encoder005(Encoder):
    """`Encoder` whose cross-attention head is a `GatedCrossAttention2` (with dropout) and whose fused update is scaled by `cross_gate`."""

    def __init__(self, config:PretrainedConfig, use_noise:Optional[bool]=True, noise_pct:Optional[float]=0.5, resize_length:Optional[int]=None,
                cross_margin:Optional[float]=0.3, cross_tau:Optional[float]=0.1, cross_dropout:Optional[float]=0.1):
        super().__init__(config, use_noise, noise_pct, resize_length)
        # Replace the cross-attention head set up by the base class and add a scalar gate (created as one).
        self.cross_head = GatedCrossAttention2(config, margin=cross_margin, tau=cross_tau, dropout=cross_dropout)
        self.cross_gate = nn.Parameter(torch.ones(1))

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
        """Fuse every metadata source into `embed` via the gated cross-attention head.

        `embed` is modified in place on the rows that have metadata. The second return value maps
        each metadata name to the L2-normalised representations of its unmasked entries (an empty
        `(0, config.dim)` tensor when there are none).
        """
        out_repr = {}

        for key, spec in meta_kwargs.items():
            counts = spec['data2ptr']
            has_meta = torch.where(counts > 0)[0]
            n_with_meta = len(has_meta)
            out_repr[key] = torch.empty(0, self.config.dim).to(embed)

            if n_with_meta:
                per_row_counts = counts[has_meta]

                if 'meta_repr' in spec:
                    # Precomputed representations plus a per-entry validity flag.
                    entry_mask = torch.any(spec['attention_mask'], dim=1).long().view(-1,1)
                    entry_repr, entry_mask = self.resize(spec['meta_repr'], entry_mask, per_row_counts)
                    entry_mask = entry_mask.bool()
                else:
                    # Raw tokens: resize, optionally add noise, then encode.
                    ids, mask = self.resize(spec['input_ids'], spec['attention_mask'], per_row_counts)
                    max_count = counts.max()

                    if self.use_noise:
                        ids, mask = self.add_noise(ids, mask, max_count)

                    token_embed = self.encode(ids, mask)[0]
                    entry_repr = self.meta_unnormalized(token_embed, mask)
                    entry_mask = torch.any(mask, dim=1)

                # Group the flat metadata rows by datapoint.
                entry_repr = entry_repr.view(n_with_meta, -1, self.config.dim)
                entry_mask = entry_mask.view(n_with_meta, -1)

                out_repr[key] = F.normalize(entry_repr[entry_mask], dim=1)

                attended = self.cross_head(embed[has_meta], attention_mask[has_meta], entry_repr, entry_mask)[0]
                embed[has_meta] += self.cross_gate * attended

        return embed, out_repr
        

In [None]:
#| export
class RAD005(RAD000, DistilBertPreTrainedModel):
    """Retrieval model built on `Encoder005` (gated cross-attention with dropout over metadata)."""
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]

    def __init__(
        self, config,
        resize_length:Optional[int]=None,
        use_noise:Optional[bool]=True,
        noise_percent:Optional[float]=0.7,

        cross_margin:Optional[float]=0.3,
        cross_tau:Optional[float]=0.1,
        cross_dropout:Optional[float]=0.1,

        **kwargs
    ):
        """Build the encoder and initialise its heads.

        Args:
            config: model configuration, forwarded to the base classes and to `Encoder005`.
            resize_length, use_noise: forwarded unchanged to `Encoder005`.
            noise_percent: forwarded to `Encoder005` as `noise_pct`.
            cross_margin, cross_tau, cross_dropout: forwarded to `Encoder005` for its cross-attention head.
            **kwargs: forwarded to the base-class initialiser.
        """
        super().__init__(config, **kwargs)
        self.encoder = Encoder005(config, use_noise=use_noise, noise_pct=noise_percent, resize_length=resize_length,
                                  cross_margin=cross_margin, cross_tau=cross_tau, cross_dropout=cross_dropout)

        self.post_init()
        self.remap_post_init()
        self.init_retrieval_head()
        self.init_cross_head()
        self.init_cross_gate()

    def init_cross_gate(self):
        """Reset the encoder's cross-attention gate to one.

        The value is written in place so the parameter keeps its current device and dtype;
        rebinding `.data` to a freshly created tensor would silently move it back to CPU/float32
        when this is called after `.to(...)` or `.half()`.
        """
        nn.init.ones_(self.encoder.cross_gate)

    def remap_post_init(self):
        """Alias the encoder's DistilBERT backbone as `self.distilbert`."""
        # NOTE(review): presumably lets checkpoint keys under `distilbert.` resolve to the
        # encoder backbone (see `_tied_weights_keys`) - confirm against the loading code.
        self.distilbert = self.encoder.distilbert
        

### Example

In [None]:
# Build RAD005 from the pretrained DistilBERT checkpoint. Only `cat2data` is used as
# augmentation metadata here, and the fusion loss is switched off.
model = RAD005.from_pretrained('distilbert-base-uncased', num_batch_labels=5000, batch_size=100,
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,

                               cross_margin=0.3, cross_tau=0.1, cross_dropout=0.1,

                               resize_length=5000, use_noise=True, noise_percent=0.5,
                               
                               meta_loss_weight=0.3, fusion_loss_weight=0.1, 
                               use_fusion_loss=False,  use_encoder_parallel=False)

# The heads and the gate are not part of the checkpoint (see the loading message below),
# so they are initialised again after loading.
model.init_retrieval_head()
model.init_cross_head()
model.init_cross_gate()

Some weights of RAD005 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.cross_gate', 'encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.margin', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.tau', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 'encoder.meta_head.transform.bias', 'encoder.meta_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Prepare the batch for the model, passing the category-metadata fields (for both data
# and labels) through `m_args`.
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
    'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_input_ids', 
    'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr',
])

In [None]:
# List the fields present in the prepared batch.
b.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_input_ids', 'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr', 'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 'cat2data_data2ptr', 'data_input_ids', 'data_attention_mask'])

In [None]:
# Forward pass with the batch moved to the model's device.
o = model(**b.to(model.device))

In [None]:
# Display the loss returned by the forward pass.
o.loss

tensor(0.0571, grad_fn=<AddBackward0>)

## `RAD006`

In [None]:
#| export
class Encoder006(Encoder):
    """`Encoder` variant that perturbs metadata at the representation level.

    Two kinds of noise are applied when `use_noise` is set: metadata representations are
    shuffled across datapoints (`add_noise`), and the fused update is dropped for a random
    subset of datapoints (`dropout_noise_pct`).
    """

    def __init__(
        self,
        config:PretrainedConfig,
        use_noise:Optional[bool]=True,
        shuffle_noise_pct:Optional[float]=0.5,
        dropout_noise_pct:Optional[float]=0.1,
        resize_length:Optional[int]=None
    ):
        # Keeps `dropout_noise_pct` on the instance; it is read back in `fuse_meta_into_embeddings`.
        store_attr('dropout_noise_pct')
        # `shuffle_noise_pct` is handed to the base class in its noise-percentage slot.
        # NOTE(review): `add_noise` reads `self.noise_pct`, presumably set from it by the base class - confirm.
        super().__init__(config, use_noise, shuffle_noise_pct, resize_length)

    def add_noise(self, m_repr:torch.Tensor, m_repr_mask:torch.Tensor):
        """Shuffle metadata representations between datapoints, slot by slot.

        For every metadata slot `i`, each datapoint is selected with probability `self.noise_pct`
        and the slot's representations (and mask entries) are permuted among the selected
        datapoints. Both tensors are modified in place and returned.

        Args:
            m_repr: metadata representations of shape `(n_data, n_meta, dim)`.
            m_repr_mask: validity mask indexed the same way on its first two dimensions.
        """
        n_data, n_meta, _ = m_repr.shape
        noise_mask = torch.rand(n_meta, n_data, device=m_repr.device) < self.noise_pct
        for i,mask in enumerate(noise_mask):
            rnd_idx = torch.randperm(mask.sum())
            # The same permutation is applied to values and mask so they stay aligned.
            m_repr[:,i][mask] = m_repr[:,i][mask][rnd_idx]
            m_repr_mask[:,i][mask] = m_repr_mask[:,i][mask][rnd_idx]
        return m_repr,m_repr_mask

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
        """Add cross-attention over each metadata source to `embed` (updated in place).

        Returns `(embed, meta_repr)`, where `meta_repr[name]` holds the L2-normalised
        representations of the unmasked metadata computed BEFORE any noise is applied, or an
        empty `(0, config.dim)` tensor when no datapoint has metadata of that kind.
        """
        meta_repr = {}

        for m_key, m_args in meta_kwargs.items():
            # Datapoints with at least one metadata entry.
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)

            if len(idx):
                if 'meta_repr' in m_args:
                    m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
                    m_repr,m_repr_mask = self.resize(m_repr, m_repr_mask, m_args['data2ptr'][idx])
                    m_repr_mask = m_repr_mask.bool()
                else:
                    m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'],
                                                                m_args['data2ptr'][idx])
                    m_embed = self.encode(m_input_ids, m_attention_mask)[0]

                    m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
                    m_repr_mask = torch.any(m_attention_mask, dim=1)

                # Group the flat metadata rows by datapoint.
                m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
                meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)

                # Shuffle on clones so the clean representations stored above are not disturbed.
                if self.use_noise: m_repr, m_repr_mask = self.add_noise(m_repr.clone(), m_repr_mask.clone())

                fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]

                if self.use_noise:
                    # Each datapoint keeps its fused update only when its draw exceeds `dropout_noise_pct`.
                    noise_mask = torch.rand(len(idx), device=fused_embed.device) > self.dropout_noise_pct
                    embed[idx[noise_mask]] += fused_embed[noise_mask]
                else:
                    embed[idx] += fused_embed

        return embed, meta_repr
        

In [None]:
#| export
class RAD006(RAD000, DistilBertPreTrainedModel):
    """DistilBERT-based model that wraps an `Encoder006` and assembles its training loss in `forward`.

    The loss is the sum of `compute_loss` on the fused data representation and,
    depending on the flags, a query term, a calibration term, the meta loss and
    the fusion loss.
    """
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]
    
    def __init__(
        self, config,
        resize_length:Optional[int]=None,
        use_noise:Optional[bool]=False,
        shuffle_noise_pct:Optional[float]=0.1,
        dropout_noise_pct:Optional[float]=0.1,

        calib_margin:Optional[float]=0.3,
        calib_num_negatives:Optional[int]=10,
        calib_tau:Optional[float]=0.1,
        calib_apply_softmax:Optional[bool]=True,
        calib_loss_weight:Optional[float]=0.1,
        use_calib_loss:Optional[bool]=False,

        use_query_loss:Optional[bool]=False,
        
        **kwargs
    ):
        """Build the encoder and the calibration loss.

        Args:
            config: model configuration, passed to the parent initialiser and to `Encoder006`.
            resize_length, use_noise, shuffle_noise_pct, dropout_noise_pct:
                forwarded unchanged to `Encoder006`.
            calib_margin, calib_num_negatives, calib_tau, calib_apply_softmax:
                forwarded to `Calibration` as `margin`, `n_negatives`, `tau`
                and `apply_softmax`.
            calib_loss_weight: multiplier applied to the calibration loss.
            use_calib_loss: when truthy, `forward` adds the calibration loss.
            use_query_loss: when truthy, `forward` adds `compute_loss` on the
                unfused data representation.
            **kwargs: passed to the parent initialiser.
        """
        super().__init__(config, **kwargs)
        self.c_lw, self.use_calib_loss, self.use_query_loss = calib_loss_weight, use_calib_loss, use_query_loss
        self.encoder = Encoder006(config, use_noise=use_noise, shuffle_noise_pct=shuffle_noise_pct, dropout_noise_pct=dropout_noise_pct,
                                  resize_length=resize_length)
        self.cab_loss_fn = Calibration(margin=calib_margin, tau=calib_tau, n_negatives=calib_num_negatives, 
                                       apply_softmax=calib_apply_softmax, reduce='mean')
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head()
        
    def calibration_loss(self, einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx):
        """Return the calibration loss scaled by `calib_loss_weight` (stored as `self.c_lw`)."""
        return self.c_lw * self.cab_loss_fn(einp_repr, inp_repr, targ_repr, targ_ptr, targ_idx, ptarg_ptr, ptarg_idx)
        
    def remap_post_init(self):
         """Expose the encoder's DistilBERT backbone as `self.distilbert`."""
         self.distilbert = self.encoder.distilbert

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):  
        """Encode the data (and, if given, the labels) and compute the training loss.

        Args:
            data_input_ids, data_attention_mask: data inputs passed to the encoder.
            lbl2data_input_ids, lbl2data_attention_mask: label inputs; when
                `lbl2data_input_ids` is None no loss is computed.
            lbl2data_data2ptr, lbl2data_idx, plbl2data_data2ptr, plbl2data_idx:
                forwarded unchanged to the loss functions.
            output_attentions, output_hidden_states: accepted but not read by this method.
            return_dict: falls back to `self.config.use_return_dict` when None.
            **kwargs: metadata inputs; filtered per feature by
                `Parameters.from_feat_meta_aug_prefix` and also forwarded to the
                meta and fusion losses.

        Returns:
            A `RADOutput` with `loss`, `data_repr`, `data_fused_repr`,
            `lbl2data_repr` and `lbl2data_fused_repr`; or, when `return_dict`
            is falsy, a tuple `(loss?, data logits, data rep, data fused rep,
            label logits, label rep, label fused rep)` where `loss` is only
            present if it was computed.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        # NOTE(review): the parallel wrapper is rebuilt on every call rather than cached.
        if self.use_encoder_parallel: 
            encoder = XCDataParallel(module=self.encoder)
        else: encoder = self.encoder
        
        data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
        
        
        # Without label inputs the label output stays a default `EncoderOutput` and `loss` stays None.
        loss = None; lbl2data_o = EncoderOutput()
        if lbl2data_input_ids is not None:
            lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
            lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
                                 data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
            
            # Main term: fused data representation against the label representation.
            loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                     plbl2data_data2ptr,plbl2data_idx)

            # Optional term: the same loss on the unfused data representation.
            if self.use_query_loss:
                loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                          plbl2data_data2ptr,plbl2data_idx)
            
            # Optional term: calibration between the fused and unfused data representations.
            if self.use_calib_loss:
                loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
                                              plbl2data_data2ptr,plbl2data_idx)
            
            # The meta loss is always added; it has no enabling flag in this method.
            loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
            
            # The data side uses the fused representation, the label side the plain one.
            if self.use_fusion_loss:
                loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
                loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
            
        if not return_dict:
            o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
            return ((loss,) + o) if loss is not None else o
        
        
        return RADOutput(
            loss=loss,
            
            data_repr=data_o.rep,
            data_fused_repr=data_o.fused_rep,
            
            lbl2data_repr=lbl2data_o.rep,
            lbl2data_fused_repr=lbl2data_o.fused_rep,
        )
        

### Example

In [None]:
# Load the pretrained DistilBERT sentence encoder into RAD006. Category metadata ('cat2data')
# is fused into the data side only; noise, calibration, meta and fusion losses are disabled.
model = RAD006.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,

                               resize_length=5000, use_noise=False, shuffle_noise_pct=0.5, dropout_noise_pct=0.1,
                               
                               use_query_loss=True,

                               calib_margin=0.3, calib_num_negatives=5, calib_tau=0.1, calib_apply_softmax=True, calib_loss_weight=0.1,
                               use_calib_loss=False,
                               
                               meta_loss_weight=0.0, fusion_loss_weight=0.0, use_fusion_loss=False,
                               use_encoder_parallel=False)

# Re-run the head initialisers after loading -- presumably because the head weights listed in
# the warning below are not part of the checkpoint (TODO confirm what these initialisers set).
model.init_retrieval_head()
model.init_cross_head()

Some weights of RAD006 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 'encoder.meta_head.transform.bias', 'encoder.meta_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Build the model inputs from `batch`; `m_args` names the category-metadata fields to pass
# along (presumably in addition to the arguments of `model.forward` -- TODO confirm in `prepare_batch`).
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
    'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_input_ids', 
    'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr',
])

In [None]:
# Forward pass on the prepared batch after moving it to the model's device.
o = model(**b.to(model.device))

In [None]:
# Display the loss returned by the forward pass.
o.loss

tensor(0.0087, grad_fn=<AddBackward0>)

## `RAD007`

In [None]:
#| export
class Encoder007(Encoder006):
    """`Encoder006` variant with a `GatedCrossAttention2` fusion head and a learnable scalar gate.

    The fused update produced by the cross-attention head is multiplied by
    `cross_gate` (a one-element parameter initialised to 1) before it is added
    to the data embeddings.
    """

    def __init__(self, config:PretrainedConfig, use_noise:Optional[bool]=True, shuffle_noise_pct:Optional[float]=0.5, 
                 dropout_noise_pct:Optional[float]=0.1, resize_length:Optional[int]=None,
                 cross_margin:Optional[float]=0.3, cross_tau:Optional[float]=0.1, cross_dropout:Optional[float]=0.1):
        """Build the base encoder, then replace its fusion head and add the gate.

        Args:
            config: model configuration, passed to `Encoder006` and `GatedCrossAttention2`.
            use_noise, shuffle_noise_pct, dropout_noise_pct, resize_length:
                forwarded to `Encoder006`.
            cross_margin, cross_tau, cross_dropout: forwarded to
                `GatedCrossAttention2` as `margin`, `tau` and `dropout`.
        """
        # Keyword arguments (the same names `RAD006` uses when it builds `Encoder006`) so the
        # call does not depend on the positional order of the parent signature.
        super().__init__(config, use_noise=use_noise, shuffle_noise_pct=shuffle_noise_pct,
                         dropout_noise_pct=dropout_noise_pct, resize_length=resize_length)
        self.cross_head = GatedCrossAttention2(config, margin=cross_margin, tau=cross_tau, dropout=cross_dropout)
        self.cross_gate = nn.Parameter(torch.ones(1))

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
        """Add gated, cross-attended metadata information to the rows of `embed` that own metadata.

        Args:
            embed: embeddings indexed by data point along dim 0. Updated **in place**
                and also returned.
            attention_mask: mask aligned with `embed` along dim 0.
            meta_kwargs: maps a metadata key to a dict containing `data2ptr` and
                either a precomputed `meta_repr` with its `attention_mask`, or
                `input_ids` / `attention_mask` that are encoded here.

        Returns:
            `(embed, meta_repr)`. `meta_repr[m_key]` holds the L2-normalised
            representations of the unmasked metadata items, or an empty
            `(0, config.dim)` tensor when no row of the batch owns metadata.
        """
        meta_repr = {}

        for m_key, m_args in meta_kwargs.items():
            # Only rows with a positive pointer entry take part in the fusion.
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)

            if len(idx) == 0: continue

            if 'meta_repr' in m_args:
                # Precomputed representations: an item is valid when any position of its mask is set.
                m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
                m_repr,m_repr_mask = self.resize(m_repr, m_repr_mask, m_args['data2ptr'][idx])
                m_repr_mask = m_repr_mask.bool()
            else:
                m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'],
                                                            m_args['data2ptr'][idx])
                m_embed = self.encode(m_input_ids, m_attention_mask)[0]

                m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
                m_repr_mask = torch.any(m_attention_mask, dim=1)

            # Regroup the flat list of metadata items into one row of slots per selected data point.
            m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
            meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)

            # Noise is applied to copies so the representations stored in `meta_repr` stay clean.
            if self.use_noise: m_repr, m_repr_mask = self.add_noise(m_repr.clone(), m_repr_mask.clone())

            fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]

            if self.use_noise:
                # Skip the fused update for the rows whose random draw falls at or below `dropout_noise_pct`.
                noise_mask = torch.rand(len(idx), device=fused_embed.device) > self.dropout_noise_pct
                embed[idx[noise_mask]] += self.cross_gate * fused_embed[noise_mask]
            else:
                embed[idx] += self.cross_gate * fused_embed

        return embed, meta_repr
        

In [None]:
#| export
class RAD007(RAD006, DistilBertPreTrainedModel):
    """`RAD006` whose encoder is an `Encoder007` (gated cross-attention fusion)."""
    use_generation,use_representation = False,True
    _tied_weights_keys = ["encoder.distilbert"]
    
    def __init__(
        self, config,
        resize_length:Optional[int]=None,
        use_noise:Optional[bool]=True,
        shuffle_noise_pct:Optional[float]=0.3,
        dropout_noise_pct:Optional[float]=0.3,
        
        cross_margin:Optional[float]=0.3,
        cross_tau:Optional[float]=0.1,
        cross_dropout:Optional[float]=0.1,
        
        **kwargs
    ):
        """Build the `RAD006` base and swap in an `Encoder007`.

        Args:
            config: model configuration.
            resize_length, use_noise, shuffle_noise_pct, dropout_noise_pct:
                forwarded to `Encoder007`.
            cross_margin, cross_tau, cross_dropout: forwarded to `Encoder007`.
            **kwargs: forwarded to `RAD006.__init__` (calibration / loss options etc.).
        """
        # The encoder options above are not forwarded to RAD006, so RAD006 first builds an
        # `Encoder006` with its own defaults; that encoder is replaced right below.
        super().__init__(config, **kwargs)
        self.encoder = Encoder007(config, use_noise=use_noise, shuffle_noise_pct=shuffle_noise_pct, dropout_noise_pct=dropout_noise_pct,
                                  resize_length=resize_length, cross_margin=cross_margin, cross_tau=cross_tau, cross_dropout=cross_dropout)
        
        # Re-run the post-init hooks so they apply to the new encoder.
        self.post_init(); self.remap_post_init(); self.init_retrieval_head(); self.init_cross_head(); self.init_cross_gate()

    def init_cross_gate(self):
        """Reset the encoder's fusion gate to 1."""
        # Fill in place: assigning a freshly created `torch.ones(1)` to `.data` would silently
        # move the parameter to the CPU / default dtype if the model was already moved or cast.
        with torch.no_grad():
            self.encoder.cross_gate.fill_(1.0)

    def remap_post_init(self):
        """Expose the encoder's DistilBERT backbone as `self.distilbert`."""
        self.distilbert = self.encoder.distilbert
        

### Example

In [None]:
# Load the pretrained DistilBERT sentence encoder into RAD007 with noise, calibration, meta and
# fusion losses enabled; 'cat2data' is fused into the data side and 'cat2lbl' is the label-side
# prediction metadata prefix.
model = RAD007.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix='cat2lbl',

                               resize_length=5000, use_noise=True, shuffle_noise_pct=0.5, dropout_noise_pct=0.1,
                               
                               use_query_loss=True,

                               cross_margin=0.3, cross_tau=0.1, cross_dropout=0.1,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=True, calib_loss_weight=1.0,
                               use_calib_loss= True,
                               
                               meta_loss_weight=0.1, fusion_loss_weight=0.1, use_fusion_loss=True,
                               use_encoder_parallel=False)

# Re-run the head and gate initialisers after loading -- presumably because the weights listed
# in the warning below (including `encoder.cross_gate`) are not part of the checkpoint.
model.init_retrieval_head()
model.init_cross_head()
model.init_cross_gate()

Some weights of RAD007 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_gate', 'encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.margin', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.tau', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 'encoder.meta_head.transform.bias', 'encoder.meta_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

In [None]:
# Build the model inputs from `batch`; `m_args` names the category-metadata fields to pass
# along (presumably in addition to the arguments of `model.forward` -- TODO confirm in `prepare_batch`).
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
    'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_input_ids', 
    'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr',
])

In [None]:
# Forward pass on the prepared batch after moving it to the model's device.
o = model(**b.to(model.device))

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
# Display the loss returned by the forward pass.
o.loss

tensor(0.1098, grad_fn=<AddBackward0>)

## `RAD008`

In [None]:
#| export
class Encoder008(Encoder):
    """Encoder that adds a learned per-item embedding to every metadata representation.

    Each metadata item is looked up in `meta_embeddings` through its index
    (`m_args['idx']`); the embedding is added to the item's representation
    before the metadata is fused into the data embeddings.
    """

    def __init__(
        self, 
        config:PretrainedConfig,
        num_metadata:int,
        use_noise:Optional[bool]=True, 
        shuffle_noise_pct:Optional[float]=0.5, 
        dropout_noise_pct:Optional[float]=0.1, 
        resize_length:Optional[int]=None,
    ):
        """Build the base encoder and the metadata embedding table.

        Args:
            config: model configuration; `config.dim` is the embedding width.
            num_metadata: number of rows of the metadata embedding table.
            use_noise, shuffle_noise_pct, resize_length: forwarded positionally
                to `Encoder.__init__`.
            dropout_noise_pct: stored on the instance through `store_attr`; read
                in `fuse_meta_into_embeddings` when `use_noise` is set.
        """
        store_attr('dropout_noise_pct')
        super().__init__(config, use_noise, shuffle_noise_pct, resize_length)
        self.meta_embeddings = nn.Embedding(num_metadata, config.dim)

    def freeze_meta_embeddings(self):
        """Stop gradient updates of the metadata embedding table."""
        self.meta_embeddings.requires_grad_(False)

    def unfreeze_meta_embeddings(self):
        """Re-enable gradient updates of the metadata embedding table."""
        self.meta_embeddings.requires_grad_(True)

    def set_meta_embeddings(self, embed:torch.Tensor):
        """Replace the embedding table's weight data with `embed`.

        NOTE(review): no shape, dtype or device check is done here; `embed` is
        presumably `(num_metadata, config.dim)` -- confirm at the call sites.
        """
        self.meta_embeddings.weight.data = embed

    def resize(self, inputs:torch.Tensor, mask:torch.Tensor, idx:torch.Tensor, num_inputs:torch.Tensor):
        """Pad every group of rows to the size of the largest group.

        Rows of `inputs`, `mask` and `idx` are grouped consecutively, group `i`
        having `num_inputs[i]` rows. Each group is padded to `num_inputs.max()`
        rows by repeating its last row; the mask of every repeated copy except
        the final one is zeroed, so each original row stays active exactly once.

        Args:
            inputs: rows to pad, grouped along dim 0.
            mask: mask aligned with `inputs` along dim 0.
            idx: per-row indices aligned with `inputs` along dim 0.
            num_inputs: group sizes; every entry must be non-zero. Presumably sums
                to `inputs.shape[0]` (not checked here).

        Returns:
            `(resized_inputs, resized_mask, resized_idx)`.

        Raises:
            ValueError: if any entry of `num_inputs` is zero.
        """
        if torch.any(num_inputs == 0): raise ValueError("`num_inputs` should be non-zero positive integer.")
        bsz, total_num_inputs = num_inputs.shape[0], inputs.shape[0]

        # The cached vector of ones is optional: only move it when it exists, so a missing
        # cache falls through to the on-the-fly allocation instead of failing on `None.to`.
        cached_ones = getattr(self, 'ones', None)
        if cached_ones is not None:
            self.ones = cached_ones = cached_ones.to(inputs.device)

        if cached_ones is None or cached_ones.shape[0] < total_num_inputs:
            ones = torch.ones(total_num_inputs, dtype=torch.long, device=inputs.device)
        else:
            ones = cached_ones[:total_num_inputs]

        max_num_inputs = num_inputs.max()
        # Number of copies needed for the last row of each group (1 when the group is already full).
        xnum_inputs = max_num_inputs-num_inputs+1

        # Position of the last row of each group.
        inputs_ptr = num_inputs.cumsum(dim=0)-1
        repeat_inputs = ones.scatter(0, inputs_ptr, xnum_inputs)

        resized_inputs = inputs.repeat_interleave(repeat_inputs, dim=0)
        resized_mask = mask.repeat_interleave(repeat_inputs, dim=0)
        resized_idx = idx.repeat_interleave(repeat_inputs, dim=0)

        # Zero every copy of a group's last row, then re-enable the copy in the final slot.
        ignore_mask_idx = ones.scatter(0, inputs_ptr, 0).repeat_interleave(repeat_inputs, dim=0).view(bsz, -1)
        ignore_mask_idx[:, -1] = 1; ignore_mask_idx = ignore_mask_idx.view(-1, 1)

        resized_mask *= ignore_mask_idx

        return resized_inputs,resized_mask,resized_idx

    def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
        """Add cross-attended metadata information to the rows of `embed` that own metadata.

        Args:
            embed: embeddings indexed by data point along dim 0. Updated **in place**
                and also returned.
            attention_mask: mask aligned with `embed` along dim 0.
            meta_kwargs: maps a metadata key to a dict containing `data2ptr`
                (group sizes per data point), `idx` (rows of `meta_embeddings`
                to look up) and either a precomputed `meta_repr` with its
                `attention_mask`, or `input_ids` / `attention_mask` to encode.

        Returns:
            `(embed, meta_repr)`. `meta_repr[m_key]` holds the L2-normalised
            representations of the unmasked metadata items, or an empty
            `(0, config.dim)` tensor when no row of the batch owns metadata.
        """
        meta_repr = {}

        for m_key, m_args in meta_kwargs.items():
            # Only rows with a positive pointer entry take part in the fusion.
            idx = torch.where(m_args['data2ptr'] > 0)[0]
            meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)

            if len(idx) == 0: continue

            if 'meta_repr' in m_args:
                # Precomputed representations: an item is valid when any position of its mask is set.
                m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
                # `resize` of this class also needs the item indices and returns them padded;
                # they are required below for the embedding lookup.
                m_repr, m_repr_mask, m_idx = self.resize(m_repr, m_repr_mask, m_args['idx'],
                                                         m_args['data2ptr'][idx])
                m_repr_mask = m_repr_mask.bool()
            else:
                m_input_ids, m_attention_mask, m_idx = self.resize(m_args['input_ids'], m_args['attention_mask'], m_args['idx'],
                                                                   m_args['data2ptr'][idx])
                m_embed = self.encode(m_input_ids, m_attention_mask)[0]

                m_repr = self.meta(m_embed, m_attention_mask)
                m_repr_mask = torch.any(m_attention_mask, dim=1)

            # Add the learned per-item embedding and renormalise.
            m_repr = F.normalize(m_repr + self.meta_embeddings(m_idx), dim=-1)

            # Regroup the flat list of metadata items into one row of slots per selected data point.
            m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
            meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)

            # Noise is applied to copies so the representations stored in `meta_repr` stay clean.
            if self.use_noise: m_repr, m_repr_mask = self.add_noise(m_repr.clone(), m_repr_mask.clone())

            fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]

            if self.use_noise:
                # Skip the fused update for the rows whose random draw falls at or below `dropout_noise_pct`.
                noise_mask = torch.rand(len(idx), device=fused_embed.device) > self.dropout_noise_pct
                embed[idx[noise_mask]] += fused_embed[noise_mask]
            else:
                embed[idx] += fused_embed

        return embed, meta_repr

    def forward(
        self, 
        data_input_ids: torch.Tensor, 
        data_attention_mask: torch.Tensor,
        data_aug_meta_prefix: Optional[str]=None,
        data_type:Optional[str]=None,
        data_unnormalized:Optional[bool]=False,
        **kwargs
    ):
        """Encode the inputs and, when metadata is supplied, compute a fused representation.

        Args:
            data_input_ids, data_attention_mask: inputs passed to `self.encode`.
            data_aug_meta_prefix: when not None, metadata arguments are extracted
                from `kwargs` with `Parameters.from_meta_aug_prefix` and fused.
            data_type: `"meta"` selects the `self.meta` head for the plain
                representation; any other value selects `self.dr`.
            data_unnormalized: accepted but not read by this method.
            **kwargs: metadata inputs.

        Returns:
            `EncoderOutput` with `rep`, `fused_rep` and `meta_repr`; the last two
            are None when no metadata was fused.
        """
        data_o = self.encode(data_input_ids, data_attention_mask)
        data_embed = F.normalize(data_o[0], dim=-1)

        if data_type == "meta":
            data_repr = self.meta(data_embed, data_attention_mask)
        else:
            data_repr = self.dr(data_embed, data_attention_mask)

        data_fused_repr = meta_repr = None
        if data_aug_meta_prefix is not None:
            meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
            if len(meta_kwargs):
                # `data_embed` is modified in place by the fusion; `data_repr` was computed before.
                data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_embed,
                                                                             data_attention_mask,
                                                                             meta_kwargs)
                data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)

        return EncoderOutput(
            rep=data_repr,
            fused_rep=data_fused_repr,
            meta_repr=meta_repr,
        )
        

In [None]:
#| export
class RAD008(RAD006):
    """`RAD006` whose encoder is an `Encoder008` (metadata embedding table of `num_metadata` rows)."""
    use_generation = False
    use_representation = True
    _tied_weights_keys = ["encoder.distilbert"]

    @delegates(RAD006.__init__)
    def __init__(
        self, config,
        num_metadata:int,
        resize_length:Optional[int]=None,
        use_noise:Optional[bool]=False,
        shuffle_noise_pct:Optional[float]=0.1,
        dropout_noise_pct:Optional[float]=0.1,
        **kwargs
    ):
        """Build the `RAD006` base, then swap in an `Encoder008`.

        The encoder options are handed to `Encoder008` only; everything in
        `kwargs` goes to `RAD006.__init__`.
        """
        super().__init__(config, **kwargs)

        # Options consumed by the replacement encoder.
        encoder_options = dict(
            num_metadata=num_metadata,
            use_noise=use_noise,
            shuffle_noise_pct=shuffle_noise_pct,
            dropout_noise_pct=dropout_noise_pct,
            resize_length=resize_length,
        )
        self.encoder = Encoder008(config, **encoder_options)

        # Re-run the post-init hooks against the new encoder, in the original order.
        self.post_init()
        self.remap_post_init()
        self.init_retrieval_head()
        self.init_cross_head()

    def remap_post_init(self):
        """Expose the encoder's DistilBERT backbone as `self.distilbert`."""
        backbone = self.encoder.distilbert
        self.distilbert = backbone
        

### Example

In [None]:
# Load the pretrained DistilBERT sentence encoder into RAD008. The metadata embedding table is
# sized from the category metadata of the training set; noise and the auxiliary losses are disabled.
model = RAD008.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=100, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,

                               num_metadata=block.train.dset.meta['cat_meta'].n_meta, 
                               resize_length=5000, use_noise=False, shuffle_noise_pct=0.5, dropout_noise_pct=0.1,
                               
                               use_query_loss=True,

                               calib_margin=0.3, calib_num_negatives=5, calib_tau=0.1, calib_apply_softmax=True, calib_loss_weight=0.1,
                               use_calib_loss=False,
                               
                               meta_loss_weight=0.0, fusion_loss_weight=0.0, use_fusion_loss=False,
                               use_encoder_parallel=False)

# Re-run the head initialisers after loading -- presumably because the head weights listed in
# the warning below are not part of the checkpoint (TODO confirm what these initialisers set).
model.init_retrieval_head()
model.init_cross_head()

Some weights of RAD008 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enc

In [None]:
# Build the model inputs from `batch`; `m_args` names the category-metadata fields to pass
# along. 'cat2data_idx' is required here: Encoder008 reads it for the metadata embedding lookup.
b = prepare_batch(model, batch, m_args=[
    'pcat2data_idx', 'pcat2data_data2ptr', 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask', 
    'cat2data_data2ptr',
    'pcat2lbl_idx', 'pcat2lbl_lbl2data2ptr', 'pcat2lbl_data2ptr', 'cat2lbl_idx', 'cat2lbl_input_ids', 
    'cat2lbl_attention_mask', 'cat2lbl_lbl2data2ptr', 'cat2lbl_data2ptr',
])

In [None]:
# Forward pass on the prepared batch after moving it to the model's device.
# NOTE(review): the saved output below is an ipdb session, i.e. this cell was last run with a
# breakpoint active in the model code; re-run on a fresh kernel to refresh the output.
o = model(**b.to(model.device))

> /tmp/ipykernel_32743/1454671095.py(53)forward()
     51         **kwargs
     52     ):  
---> 53         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     54 
     55         if self.use_encoder_parallel:



ipdb>  n


> /tmp/ipykernel_32743/1454671095.py(55)forward()
     53         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     54 
---> 55         if self.use_encoder_parallel:
     56             encoder = XCDataParallel(module=self.encoder)
     57         else: encoder = self.encoder



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(57)forward()
     55         if self.use_encoder_parallel:
     56             encoder = XCDataParallel(module=self.encoder)
---> 57         else: encoder = self.encoder
     58 
     59         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(59)forward()
     57         else: encoder = self.encoder
     58 
---> 59         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     60         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     61                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(60)forward()
     58 
     59         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 60         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     61                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     62 



ipdb>  c


> /tmp/ipykernel_32743/3514307414.py(100)forward()
     98         **kwargs
     99     ):
--> 100         data_o = self.encode(data_input_ids, data_attention_mask)
    101         data_embed = F.normalize(data_o[0], dim=-1)
    102 



ipdb>  n


> /tmp/ipykernel_32743/3514307414.py(101)forward()
     99     ):
    100         data_o = self.encode(data_input_ids, data_attention_mask)
--> 101         data_embed = F.normalize(data_o[0], dim=-1)
    102 
    103         if data_type is not None and data_type == "meta":



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(103)forward()
    101         data_embed = F.normalize(data_o[0], dim=-1)
    102 
--> 103         if data_type is not None and data_type == "meta":
    104             data_repr = self.meta(data_embed, data_attention_mask)
    105         else:



ipdb>  data_embed.norm(dim=-1)


tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  n


> /tmp/ipykernel_32743/3514307414.py(106)forward()
    104             data_repr = self.meta(data_embed, data_attention_mask)
    105         else:
--> 106             data_repr = self.dr(data_embed, data_attention_mask)
    107 
    108         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(108)forward()
    106             data_repr = self.dr(data_embed, data_attention_mask)
    107 
--> 108         data_fused_repr = meta_repr = None
    109         if data_aug_meta_prefix is not None:
    110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(109)forward()
    107 
    108         data_fused_repr = meta_repr = None
--> 109         if data_aug_meta_prefix is not None:
    110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    111             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(110)forward()
    108         data_fused_repr = meta_repr = None
    109         if data_aug_meta_prefix is not None:
--> 110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    111             if len(meta_kwargs):
    112                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_embed, 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(111)forward()
    109         if data_aug_meta_prefix is not None:
    110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
--> 111             if len(meta_kwargs):
    112                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_embed, 
    113                                                                              data_attention_mask,



ipdb>  n


> /tmp/ipykernel_32743/3514307414.py(112)forward()
    110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    111             if len(meta_kwargs):
--> 112                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_embed, 
    113                                                                              data_attention_mask,
    114                                                                              meta_kwargs)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(113)forward()
    111             if len(meta_kwargs):
    112                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_embed, 
--> 113                                                                              data_attention_mask,
    114                                                                              meta_kwargs)
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(114)forward()
    112                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_embed, 
    113                                                                              data_attention_mask,
--> 114                                                                              meta_kwargs)
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(112)forward()
    110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    111             if len(meta_kwargs):
--> 112                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_embed, 
    113                                                                              data_attention_mask,
    114                                                                              meta_kwargs)



ipdb>  s


--Call--
> /tmp/ipykernel_32743/3514307414.py(53)fuse_meta_into_embeddings()
     51         return resized_inputs,resized_mask,resized_idx
     52 
3--> 53     def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
     54         meta_repr = {}
     55 



ipdb>  n


> /tmp/ipykernel_32743/3514307414.py(54)fuse_meta_into_embeddings()
     52 
3    53     def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
---> 54         meta_repr = {}
     55 
     56         for m_key, m_args in meta_kwargs.items():



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(56)fuse_meta_into_embeddings()
     54         meta_repr = {}
     55 
---> 56         for m_key, m_args in meta_kwargs.items():
     57             idx = torch.where(m_args['data2ptr'] > 0)[0]
     58             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(57)fuse_meta_into_embeddings()
     55 
     56         for m_key, m_args in meta_kwargs.items():
---> 57             idx = torch.where(m_args['data2ptr'] > 0)[0]
     58             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     59 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(58)fuse_meta_into_embeddings()
     56         for m_key, m_args in meta_kwargs.items():
     57             idx = torch.where(m_args['data2ptr'] > 0)[0]
---> 58             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     59 
     60             if len(idx):



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(60)fuse_meta_into_embeddings()
     58             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     59 
---> 60             if len(idx):
     61                 if 'meta_repr' in m_args:
     62                     m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(61)fuse_meta_into_embeddings()
     59 
     60             if len(idx):
---> 61                 if 'meta_repr' in m_args:
     62                     m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
     63                     m_repr,m_repr_mask = self.resize(m_repr, m_repr_mask, m_args['data2ptr'][idx])



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(66)fuse_meta_into_embeddings()
     64                     m_repr_mask = m_repr_mask.bool()
     65                 else:
---> 66                     m_input_ids, m_attention_mask, m_idx = self.resize(m_args['input_ids'], m_args['attention_mask'], m_args['idx'],
     67                                                                        m_args['data2ptr'][idx])
     68                     n_meta = m_args['data2ptr'].max()



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(67)fuse_meta_into_embeddings()
     65                 else:
     66                     m_input_ids, m_attention_mask, m_idx = self.resize(m_args['input_ids'], m_args['attention_mask'], m_args['idx'],
---> 67                                                                        m_args['data2ptr'][idx])
     68                     n_meta = m_args['data2ptr'].max()
     69                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(66)fuse_meta_into_embeddings()
     64                     m_repr_mask = m_repr_mask.bool()
     65                 else:
---> 66                     m_input_ids, m_attention_mask, m_idx = self.resize(m_args['input_ids'], m_args['attention_mask'], m_args['idx'],
     67                                                                        m_args['data2ptr'][idx])
     68                     n_meta = m_args['data2ptr'].max()



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(68)fuse_meta_into_embeddings()
     66                     m_input_ids, m_attention_mask, m_idx = self.resize(m_args['input_ids'], m_args['attention_mask'], m_args['idx'],
     67                                                                        m_args['data2ptr'][idx])
---> 68                     n_meta = m_args['data2ptr'].max()
     69                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]
     70 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(69)fuse_meta_into_embeddings()
     67                                                                        m_args['data2ptr'][idx])
     68                     n_meta = m_args['data2ptr'].max()
---> 69                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]
     70 
     71                     m_repr = self.meta(m_embed, m_attention_mask)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(71)fuse_meta_into_embeddings()
     69                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]
     70 
---> 71                     m_repr = self.meta(m_embed, m_attention_mask)
     72                     m_repr_mask = torch.any(m_attention_mask, dim=1)
     73 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(72)fuse_meta_into_embeddings()
     70 
     71                     m_repr = self.meta(m_embed, m_attention_mask)
---> 72                     m_repr_mask = torch.any(m_attention_mask, dim=1)
     73 
     74                 m_repr = F.normalize(m_repr + self.meta_embeddings(m_idx), dim=-1)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(74)fuse_meta_into_embeddings()
     72                     m_repr_mask = torch.any(m_attention_mask, dim=1)
     73 
---> 74                 m_repr = F.normalize(m_repr + self.meta_embeddings(m_idx), dim=-1)
     75 
     76                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(76)fuse_meta_into_embeddings()
     74                 m_repr = F.normalize(m_repr + self.meta_embeddings(m_idx), dim=-1)
     75 
---> 76                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
     77                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)
     78 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(77)fuse_meta_into_embeddings()
     75 
     76                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
---> 77                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)
     78 
     79                 if self.use_noise: m_repr, m_repr_mask = self.add_noise(m_repr.clone(), m_repr_mask.clone())



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(79)fuse_meta_into_embeddings()
     77                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)
     78 
---> 79                 if self.use_noise: m_repr, m_repr_mask = self.add_noise(m_repr.clone(), m_repr_mask.clone())
     80 
     81                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(81)fuse_meta_into_embeddings()
     79                 if self.use_noise: m_repr, m_repr_mask = self.add_noise(m_repr.clone(), m_repr_mask.clone())
     80 
---> 81                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]
     82 
     83                 if self.use_noise:



ipdb>  embed[idx].norm(dim=-1)


tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  m_repr.norm(dim=-1)


tensor([[1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000]], grad_fn=<LinalgVectorNormBackward0>)


ipdb>  n


> /tmp/ipykernel_32743/3514307414.py(83)fuse_meta_into_embeddings()
     81                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]
     82 
---> 83                 if self.use_noise:
     84                     noise_mask = torch.rand(len(idx), device=fused_embed.device) > self.dropout_noise_pct
     85                     embed[idx[noise_mask]] += fused_embed[noise_mask]



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(87)fuse_meta_into_embeddings()
     85                     embed[idx[noise_mask]] += fused_embed[noise_mask]
     86                 else:
---> 87                     embed[idx] += fused_embed
     88 
     89         return embed, meta_repr



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(56)fuse_meta_into_embeddings()
     54         meta_repr = {}
     55 
---> 56         for m_key, m_args in meta_kwargs.items():
     57             idx = torch.where(m_args['data2ptr'] > 0)[0]
     58             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(89)fuse_meta_into_embeddings()
     87                     embed[idx] += fused_embed
     88 
---> 89         return embed, meta_repr
     90 
2    91     def forward(



ipdb>  


--Return--
(tensor([[[-0....PutBackward0>), {'cat2data': tensor([[-0.0...DivBackward0>)})
> /tmp/ipykernel_32743/3514307414.py(89)fuse_meta_into_embeddings()
     87                     embed[idx] += fused_embed
     88 
---> 89         return embed, meta_repr
     90 
2    91     def forward(



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(115)forward()
    113                                                                              data_attention_mask,
    114                                                                              meta_kwargs)
--> 115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 
    117         return EncoderOutput(



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(117)forward()
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 
--> 117         return EncoderOutput(
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(118)forward()
    116 
    117         return EncoderOutput(
--> 118             rep=data_repr,
    119             fused_rep=data_fused_repr,
    120             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(119)forward()
    117         return EncoderOutput(
    118             rep=data_repr,
--> 119             fused_rep=data_fused_repr,
    120             meta_repr=meta_repr,
    121         )



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(120)forward()
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,
--> 120             meta_repr=meta_repr,
    121         )
    122 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(117)forward()
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 
--> 117         return EncoderOutput(
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...vBackward0>)})
> /tmp/ipykernel_32743/3514307414.py(117)forward()
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 
--> 117         return EncoderOutput(
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,



ipdb>  c


> /tmp/ipykernel_32743/3514307414.py(100)forward()
     98         **kwargs
     99     ):
--> 100         data_o = self.encode(data_input_ids, data_attention_mask)
    101         data_embed = F.normalize(data_o[0], dim=-1)
    102 



ipdb>  n


> /tmp/ipykernel_32743/3514307414.py(101)forward()
     99     ):
    100         data_o = self.encode(data_input_ids, data_attention_mask)
--> 101         data_embed = F.normalize(data_o[0], dim=-1)
    102 
    103         if data_type is not None and data_type == "meta":



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(103)forward()
    101         data_embed = F.normalize(data_o[0], dim=-1)
    102 
--> 103         if data_type is not None and data_type == "meta":
    104             data_repr = self.meta(data_embed, data_attention_mask)
    105         else:



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(106)forward()
    104             data_repr = self.meta(data_embed, data_attention_mask)
    105         else:
--> 106             data_repr = self.dr(data_embed, data_attention_mask)
    107 
    108         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(108)forward()
    106             data_repr = self.dr(data_embed, data_attention_mask)
    107 
--> 108         data_fused_repr = meta_repr = None
    109         if data_aug_meta_prefix is not None:
    110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(109)forward()
    107 
    108         data_fused_repr = meta_repr = None
--> 109         if data_aug_meta_prefix is not None:
    110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    111             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(117)forward()
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 
--> 117         return EncoderOutput(
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(118)forward()
    116 
    117         return EncoderOutput(
--> 118             rep=data_repr,
    119             fused_rep=data_fused_repr,
    120             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(119)forward()
    117         return EncoderOutput(
    118             rep=data_repr,
--> 119             fused_rep=data_fused_repr,
    120             meta_repr=meta_repr,
    121         )



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(120)forward()
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,
--> 120             meta_repr=meta_repr,
    121         )
    122 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(117)forward()
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 
--> 117         return EncoderOutput(
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...eta_repr=None)
> /tmp/ipykernel_32743/3514307414.py(117)forward()
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 
--> 117         return EncoderOutput(
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,



ipdb>  c


In [None]:
# Display the `loss` field of the output object `o` (assigned in a cell outside this one).
o.loss

tensor(0.0850, grad_fn=<AddBackward0>)

In [None]:
def func():
    """Debug harness: enter the debugger, then call `model` on `b`.

    Reads the notebook globals `model` and `b`; neither is defined in this
    cell (presumably `b` is a batch built in an earlier cell -- confirm).
    Returns whatever `model(...)` returns.
    """
    # Leftover interactive breakpoint: execution pauses here and waits for
    # debugger commands, so this cell cannot run unattended (e.g. Run All).
    import pdb; pdb.set_trace()
    # `b.to(model.device)` is unpacked with `**`, i.e. passed as keyword arguments.
    return model(**b.to(model.device))
    

In [None]:
# Run the debug harness; this blocks at the debugger prompt until continued,
# then stores the model output in `o`.
o = func()

> /tmp/ipykernel_32743/3657616883.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b.to(model.device))
      4 



ipdb>  b model.forward


Breakpoint 1 at /tmp/ipykernel_32743/1454671095.py:38


ipdb>  b model.encoder.forward


Breakpoint 2 at /tmp/ipykernel_32743/3514307414.py:91


ipdb>  b model.encoder.fuse_meta_into_embeddings


Breakpoint 3 at /tmp/ipykernel_32743/3514307414.py:53


ipdb>  c


> /tmp/ipykernel_32743/1454671095.py(53)forward()
     51         **kwargs
     52     ):  
---> 53         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     54 
     55         if self.use_encoder_parallel:



ipdb>  n


> /tmp/ipykernel_32743/1454671095.py(55)forward()
     53         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
     54 
---> 55         if self.use_encoder_parallel:
     56             encoder = XCDataParallel(module=self.encoder)
     57         else: encoder = self.encoder



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(57)forward()
     55         if self.use_encoder_parallel:
     56             encoder = XCDataParallel(module=self.encoder)
---> 57         else: encoder = self.encoder
     58 
     59         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(59)forward()
     57         else: encoder = self.encoder
     58 
---> 59         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     60         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     61                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(60)forward()
     58 
     59         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 60         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     61                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     62 



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(61)forward()
     59         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     60         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 61                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     62 
     63 



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(60)forward()
     58 
     59         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 60         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     61                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     62 



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(61)forward()
     59         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
     60         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 61                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     62 
     63 



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(60)forward()
     58 
     59         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
---> 60         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     61                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
     62 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(100)forward()
     98         **kwargs
     99     ):
--> 100         data_o = self.encode(data_input_ids, data_attention_mask)
    101         data_embed = F.normalize(data_o[0], dim=-1)
    102 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(101)forward()
     99     ):
    100         data_o = self.encode(data_input_ids, data_attention_mask)
--> 101         data_embed = F.normalize(data_o[0], dim=-1)
    102 
    103         if data_type is not None and data_type == "meta":



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(103)forward()
    101         data_embed = F.normalize(data_o[0], dim=-1)
    102 
--> 103         if data_type is not None and data_type == "meta":
    104             data_repr = self.meta(data_embed, data_attention_mask)
    105         else:



ipdb>  data_embed.norm(dim=-1)


tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  n


> /tmp/ipykernel_32743/3514307414.py(106)forward()
    104             data_repr = self.meta(data_embed, data_attention_mask)
    105         else:
--> 106             data_repr = self.dr(data_embed, data_attention_mask)
    107 
    108         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(108)forward()
    106             data_repr = self.dr(data_embed, data_attention_mask)
    107 
--> 108         data_fused_repr = meta_repr = None
    109         if data_aug_meta_prefix is not None:
    110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(109)forward()
    107 
    108         data_fused_repr = meta_repr = None
--> 109         if data_aug_meta_prefix is not None:
    110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    111             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(110)forward()
    108         data_fused_repr = meta_repr = None
    109         if data_aug_meta_prefix is not None:
--> 110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    111             if len(meta_kwargs):
    112                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_embed, 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(111)forward()
    109         if data_aug_meta_prefix is not None:
    110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
--> 111             if len(meta_kwargs):
    112                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_embed, 
    113                                                                              data_attention_mask,



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(112)forward()
    110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    111             if len(meta_kwargs):
--> 112                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_embed, 
    113                                                                              data_attention_mask,
    114                                                                              meta_kwargs)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(113)forward()
    111             if len(meta_kwargs):
    112                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_embed, 
--> 113                                                                              data_attention_mask,
    114                                                                              meta_kwargs)
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(114)forward()
    112                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_embed, 
    113                                                                              data_attention_mask,
--> 114                                                                              meta_kwargs)
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(112)forward()
    110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    111             if len(meta_kwargs):
--> 112                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_embed, 
    113                                                                              data_attention_mask,
    114                                                                              meta_kwargs)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(54)fuse_meta_into_embeddings()
     52 
3    53     def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
---> 54         meta_repr = {}
     55 
     56         for m_key, m_args in meta_kwargs.items():



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(56)fuse_meta_into_embeddings()
     54         meta_repr = {}
     55 
---> 56         for m_key, m_args in meta_kwargs.items():
     57             idx = torch.where(m_args['data2ptr'] > 0)[0]
     58             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(57)fuse_meta_into_embeddings()
     55 
     56         for m_key, m_args in meta_kwargs.items():
---> 57             idx = torch.where(m_args['data2ptr'] > 0)[0]
     58             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     59 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(58)fuse_meta_into_embeddings()
     56         for m_key, m_args in meta_kwargs.items():
     57             idx = torch.where(m_args['data2ptr'] > 0)[0]
---> 58             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     59 
     60             if len(idx):



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(60)fuse_meta_into_embeddings()
     58             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
     59 
---> 60             if len(idx):
     61                 if 'meta_repr' in m_args:
     62                     m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(61)fuse_meta_into_embeddings()
     59 
     60             if len(idx):
---> 61                 if 'meta_repr' in m_args:
     62                     m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
     63                     m_repr,m_repr_mask = self.resize(m_repr, m_repr_mask, m_args['data2ptr'][idx])



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(66)fuse_meta_into_embeddings()
     64                     m_repr_mask = m_repr_mask.bool()
     65                 else:
---> 66                     m_input_ids, m_attention_mask, m_idx = self.resize(m_args['input_ids'], m_args['attention_mask'], m_args['idx'],
     67                                                                        m_args['data2ptr'][idx])
     68                     n_meta = m_args['data2ptr'].max()



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(67)fuse_meta_into_embeddings()
     65                 else:
     66                     m_input_ids, m_attention_mask, m_idx = self.resize(m_args['input_ids'], m_args['attention_mask'], m_args['idx'],
---> 67                                                                        m_args['data2ptr'][idx])
     68                     n_meta = m_args['data2ptr'].max()
     69                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(66)fuse_meta_into_embeddings()
     64                     m_repr_mask = m_repr_mask.bool()
     65                 else:
---> 66                     m_input_ids, m_attention_mask, m_idx = self.resize(m_args['input_ids'], m_args['attention_mask'], m_args['idx'],
     67                                                                        m_args['data2ptr'][idx])
     68                     n_meta = m_args['data2ptr'].max()



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(68)fuse_meta_into_embeddings()
     66                     m_input_ids, m_attention_mask, m_idx = self.resize(m_args['input_ids'], m_args['attention_mask'], m_args['idx'],
     67                                                                        m_args['data2ptr'][idx])
---> 68                     n_meta = m_args['data2ptr'].max()
     69                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]
     70 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(69)fuse_meta_into_embeddings()
     67                                                                        m_args['data2ptr'][idx])
     68                     n_meta = m_args['data2ptr'].max()
---> 69                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]
     70 
     71                     m_repr = self.meta(m_embed, m_attention_mask)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(71)fuse_meta_into_embeddings()
     69                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]
     70 
---> 71                     m_repr = self.meta(m_embed, m_attention_mask)
     72                     m_repr_mask = torch.any(m_attention_mask, dim=1)
     73 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(72)fuse_meta_into_embeddings()
     70 
     71                     m_repr = self.meta(m_embed, m_attention_mask)
---> 72                     m_repr_mask = torch.any(m_attention_mask, dim=1)
     73 
     74                 m_repr = F.normalize(m_repr + self.meta_embeddings(m_idx), dim=-1)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(74)fuse_meta_into_embeddings()
     72                     m_repr_mask = torch.any(m_attention_mask, dim=1)
     73 
---> 74                 m_repr = F.normalize(m_repr + self.meta_embeddings(m_idx), dim=-1)
     75 
     76                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)



ipdb>  m_repr_mask


tensor([ True,  True,  True, False, False,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True])


ipdb>  n


> /tmp/ipykernel_32743/3514307414.py(76)fuse_meta_into_embeddings()
     74                 m_repr = F.normalize(m_repr + self.meta_embeddings(m_idx), dim=-1)
     75 
---> 76                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
     77                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)
     78 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(77)fuse_meta_into_embeddings()
     75 
     76                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
---> 77                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)
     78 
     79                 if self.use_noise: m_repr, m_repr_mask = self.add_noise(m_repr.clone(), m_repr_mask.clone())



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(79)fuse_meta_into_embeddings()
     77                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)
     78 
---> 79                 if self.use_noise: m_repr, m_repr_mask = self.add_noise(m_repr.clone(), m_repr_mask.clone())
     80 
     81                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(81)fuse_meta_into_embeddings()
     79                 if self.use_noise: m_repr, m_repr_mask = self.add_noise(m_repr.clone(), m_repr_mask.clone())
     80 
---> 81                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]
     82 
     83                 if self.use_noise:



ipdb>  embed[idx].norm(dim=-1)


tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  m_repr.norm(dim=-1)


tensor([[1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 1.0000]], grad_fn=<LinalgVectorNormBackward0>)


ipdb>  n


> /tmp/ipykernel_32743/3514307414.py(83)fuse_meta_into_embeddings()
     81                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]
     82 
---> 83                 if self.use_noise:
     84                     noise_mask = torch.rand(len(idx), device=fused_embed.device) > self.dropout_noise_pct
     85                     embed[idx[noise_mask]] += fused_embed[noise_mask]



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(87)fuse_meta_into_embeddings()
     85                     embed[idx[noise_mask]] += fused_embed[noise_mask]
     86                 else:
---> 87                     embed[idx] += fused_embed
     88 
     89         return embed, meta_repr



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(56)fuse_meta_into_embeddings()
     54         meta_repr = {}
     55 
---> 56         for m_key, m_args in meta_kwargs.items():
     57             idx = torch.where(m_args['data2ptr'] > 0)[0]
     58             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(89)fuse_meta_into_embeddings()
     87                     embed[idx] += fused_embed
     88 
---> 89         return embed, meta_repr
     90 
2    91     def forward(



ipdb>  


--Return--
(tensor([[[-0....PutBackward0>), {'cat2data': tensor([[ 0.0...DivBackward0>)})
> /tmp/ipykernel_32743/3514307414.py(89)fuse_meta_into_embeddings()
     87                     embed[idx] += fused_embed
     88 
---> 89         return embed, meta_repr
     90 
2    91     def forward(



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(115)forward()
    113                                                                              data_attention_mask,
    114                                                                              meta_kwargs)
--> 115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 
    117         return EncoderOutput(



ipdb>  data_fused_embed.norm(dim=-1)


tensor([[1.3055, 1.2600, 1.2811, 1.2890, 1.2858],
        [1.8759, 1.8536, 1.8672, 1.8319, 1.8269],
        [1.3816, 1.3434, 1.3324, 1.3283, 1.3340],
        [1.4855, 1.4616, 1.4759, 1.4511, 1.4630],
        [1.4110, 1.3139, 1.3207, 1.3839, 1.3767]],
       grad_fn=<LinalgVectorNormBackward0>)


ipdb>  n


> /tmp/ipykernel_32743/3514307414.py(117)forward()
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 
--> 117         return EncoderOutput(
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(118)forward()
    116 
    117         return EncoderOutput(
--> 118             rep=data_repr,
    119             fused_rep=data_fused_repr,
    120             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(119)forward()
    117         return EncoderOutput(
    118             rep=data_repr,
--> 119             fused_rep=data_fused_repr,
    120             meta_repr=meta_repr,
    121         )



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(120)forward()
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,
--> 120             meta_repr=meta_repr,
    121         )
    122 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(117)forward()
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 
--> 117         return EncoderOutput(
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...vBackward0>)})
> /tmp/ipykernel_32743/3514307414.py(117)forward()
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 
--> 117         return EncoderOutput(
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,



ipdb>  c


> /tmp/ipykernel_32743/3514307414.py(100)forward()
     98         **kwargs
     99     ):
--> 100         data_o = self.encode(data_input_ids, data_attention_mask)
    101         data_embed = F.normalize(data_o[0], dim=-1)
    102 



ipdb>  n


> /tmp/ipykernel_32743/3514307414.py(101)forward()
     99     ):
    100         data_o = self.encode(data_input_ids, data_attention_mask)
--> 101         data_embed = F.normalize(data_o[0], dim=-1)
    102 
    103         if data_type is not None and data_type == "meta":



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(103)forward()
    101         data_embed = F.normalize(data_o[0], dim=-1)
    102 
--> 103         if data_type is not None and data_type == "meta":
    104             data_repr = self.meta(data_embed, data_attention_mask)
    105         else:



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(106)forward()
    104             data_repr = self.meta(data_embed, data_attention_mask)
    105         else:
--> 106             data_repr = self.dr(data_embed, data_attention_mask)
    107 
    108         data_fused_repr = meta_repr = None



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(108)forward()
    106             data_repr = self.dr(data_embed, data_attention_mask)
    107 
--> 108         data_fused_repr = meta_repr = None
    109         if data_aug_meta_prefix is not None:
    110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(109)forward()
    107 
    108         data_fused_repr = meta_repr = None
--> 109         if data_aug_meta_prefix is not None:
    110             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    111             if len(meta_kwargs):



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(117)forward()
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 
--> 117         return EncoderOutput(
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(118)forward()
    116 
    117         return EncoderOutput(
--> 118             rep=data_repr,
    119             fused_rep=data_fused_repr,
    120             meta_repr=meta_repr,



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(119)forward()
    117         return EncoderOutput(
    118             rep=data_repr,
--> 119             fused_rep=data_fused_repr,
    120             meta_repr=meta_repr,
    121         )



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(120)forward()
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,
--> 120             meta_repr=meta_repr,
    121         )
    122 



ipdb>  


> /tmp/ipykernel_32743/3514307414.py(117)forward()
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 
--> 117         return EncoderOutput(
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...eta_repr=None)
> /tmp/ipykernel_32743/3514307414.py(117)forward()
    115                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    116 
--> 117         return EncoderOutput(
    118             rep=data_repr,
    119             fused_rep=data_fused_repr,



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(70)forward()
     68                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
     69 
---> 70             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     71                                      plbl2data_data2ptr,plbl2data_idx)
     72 



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(71)forward()
     69 
     70             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
---> 71                                      plbl2data_data2ptr,plbl2data_idx)
     72 
     73             if self.use_query_loss:



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(70)forward()
     68                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
     69 
---> 70             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     71                                      plbl2data_data2ptr,plbl2data_idx)
     72 



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(73)forward()
     71                                      plbl2data_data2ptr,plbl2data_idx)
     72 
---> 73             if self.use_query_loss:
     74                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     75                                           plbl2data_data2ptr,plbl2data_idx)



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(74)forward()
     72 
     73             if self.use_query_loss:
---> 74                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     75                                           plbl2data_data2ptr,plbl2data_idx)
     76 



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(75)forward()
     73             if self.use_query_loss:
     74                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
---> 75                                           plbl2data_data2ptr,plbl2data_idx)
     76 
     77             if self.use_calib_loss:



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(74)forward()
     72 
     73             if self.use_query_loss:
---> 74                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     75                                           plbl2data_data2ptr,plbl2data_idx)
     76 



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(77)forward()
     75                                           plbl2data_data2ptr,plbl2data_idx)
     76 
---> 77             if self.use_calib_loss:
     78                 loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
     79                                               plbl2data_data2ptr,plbl2data_idx)



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(81)forward()
     79                                               plbl2data_data2ptr,plbl2data_idx)
     80 
---> 81             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
     82 
     83             if self.use_fusion_loss:



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(83)forward()
     81             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
     82 
---> 83             if self.use_fusion_loss:
     84                 loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
     85                 loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(87)forward()
     85                 loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
     86 
---> 87         if not return_dict:
     88             o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
     89             return ((loss,) + o) if loss is not None else o



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(92)forward()
     90 
     91 
---> 92         return RADOutput(
     93             loss=loss,
     94 



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(93)forward()
     91 
     92         return RADOutput(
---> 93             loss=loss,
     94 
     95             data_repr=data_o.rep,



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(95)forward()
     93             loss=loss,
     94 
---> 95             data_repr=data_o.rep,
     96             data_fused_repr=data_o.fused_rep,
     97 



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(96)forward()
     94 
     95             data_repr=data_o.rep,
---> 96             data_fused_repr=data_o.fused_rep,
     97 
     98             lbl2data_repr=lbl2data_o.rep,



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(98)forward()
     96             data_fused_repr=data_o.fused_rep,
     97 
---> 98             lbl2data_repr=lbl2data_o.rep,
     99             lbl2data_fused_repr=lbl2data_o.fused_rep,
    100         )



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(99)forward()
     97 
     98             lbl2data_repr=lbl2data_o.rep,
---> 99             lbl2data_fused_repr=lbl2data_o.fused_rep,
    100         )
    101 



ipdb>  


> /tmp/ipykernel_32743/1454671095.py(92)forward()
     90 
     91 
---> 92         return RADOutput(
     93             loss=loss,
     94 



ipdb>  


--Return--
RADOutput(los...sed_repr=None)
> /tmp/ipykernel_32743/1454671095.py(92)forward()
     90 
     91 
---> 92         return RADOutput(
     93             loss=loss,
     94 



ipdb>  c


    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



In [None]:
# Display the `loss` attribute of `o` as this cell's output.
# NOTE(review): `o` is not defined in this cell — presumably the model output
# produced by an earlier cell; confirm it survives Restart & Run All.
o.loss

tensor(0.0821, grad_fn=<AddBackward0>)