In [1]:
#| default_exp models.distillation

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
#| export
import torch, numpy as np, torch.nn.functional as F, torch.nn as nn
from typing import Optional
from dataclasses import dataclass
from types import MethodType

from xcai.core import store_attr
from xcai.losses import Cosine, MultiTriplet
from xcai.models.PPP0XX import XCModelOutput
from xcai.models.oak import OAK001
from xcai.models.radga import RADOutput
from xcai.bandits import *

from transformers import DistilBertPreTrainedModel,DistilBertConfig
from transformers.utils.generic import ModelOutput

## Setup

In [114]:
from xcai.main import *
from xcai.basics import *
from xcai.models.oak import OAK003

In [5]:
data_dir = '/Users/suchith720/Projects/data/'

config_key = "data_meta"
config_dir = "/Users/suchith720/Projects/mogicX/configs"
pkl_dir = f"{data_dir}/processed/mogicX"
pkl_file = get_pkl_file(pkl_dir, 'wikiseealsotitles_data-oak-for-msmarco-with-hard-negatives-test_distilbert-base-uncased', 
                        True, False, False)

In [6]:
config_file = f"{config_dir}/39_oak-for-msmarco-with-hard-negatives_test.json"

In [7]:
pkl_file

'/Users/suchith720/Projects/data//processed/mogicX/wikiseealsotitles_data-oak-for-msmarco-with-hard-negatives-test_distilbert-base-uncased_sxc.joblib'

In [8]:
block = build_block(pkl_file, config_file, True, config_key=config_key, only_test=False, main_oversample=True, 
                    meta_oversample={"cat_meta":False, "lnk_meta":True}, n_slbl_samples=5, 
                    n_sdata_meta_samples={"cat_meta":2, "lnk_meta":4}, do_build=False, 
                    train_meta_topk={"lnk_meta":10}, test_meta_topk={"lnk_meta":10}, return_scores=True)

## Teacher

### Helper

In [9]:
#| export
@dataclass
class TCHOutput(ModelOutput):
    """Output container returned by the `forward` of the teacher (`TCH*`) models in this module."""
    # Representation the teacher produced for the `data_idx` it was called with.
    data_repr: Optional[torch.FloatTensor] = None
    # Representation the teacher produced for `lbl2data_idx`; left as None by teachers
    # whose `forward` only fills `data_repr` (e.g. `TCH003`, `TCH004`).
    lbl2data_repr: Optional[torch.FloatTensor] = None
    

### Configuration

In [18]:
#| export
class TCHConfig(DistilBertConfig):
    """DistilBERT configuration extended with the sizes used by the teacher (`TCH*`) models.

    Extra fields:
        n_data: number of rows of the data embedding table.
        n_lbl: number of rows of the label embedding table.
        embed_dim: width of the stored embeddings when it differs from `dim`
            (used by `TCH004`, which projects `embed_dim -> dim`).
        normalize: whether `TCH003`/`TCH004` L2-normalise their output.

    Any remaining keyword arguments are forwarded to `DistilBertConfig`.
    """

    def __init__(
        self,
        n_data:Optional[int]=None,
        n_lbl:Optional[int]=None,
        embed_dim:Optional[int]=None,
        normalize:Optional[bool]=True,
        **kwargs,
    ):
        # Teacher-specific fields are set before the parent initialiser runs,
        # exactly as the parent expects the remaining kwargs only.
        self.n_data = n_data
        self.n_lbl = n_lbl
        self.embed_dim = embed_dim
        self.normalize = normalize
        super().__init__(**kwargs)
        

### `TCH001`

In [62]:
#| export
class TCH001(DistilBertPreTrainedModel):
    """Teacher that is nothing but two lookup tables: one for data points, one for labels.

    Both tables have `config.dim` columns; `config.n_data` and `config.n_lbl`
    give their number of rows.
    """

    def __init__(self, config:TCHConfig, **kwargs):
        super().__init__(config, **kwargs)
        width = config.dim
        self.data_repr = nn.Embedding(config.n_data, width)
        self.lbl_repr = nn.Embedding(config.n_lbl, width)

    def get_lbl_embeddings(self):
        """Return the whole label table (the weight of `lbl_repr`)."""
        table = self.lbl_repr
        return table.weight

    def get_data_embeddings(self):
        """Return the whole data table (the weight of `data_repr`)."""
        table = self.data_repr
        return table.weight

    @torch.no_grad()
    def init_embeddings(self, data_repr:torch.Tensor, lbl_repr:torch.Tensor):
        """Copy the given tensors into the data and label tables, in that order."""
        pairs = ((self.data_repr, data_repr), (self.lbl_repr, lbl_repr))
        for table, values in pairs:
            table.weight.copy_(values)

    def freeze_embeddings(self):
        """Turn gradients off for both tables."""
        for table in (self.data_repr, self.lbl_repr):
            table.requires_grad_(False)

    def freeze_data_embeddings(self):
        """Turn gradients off for the data table only."""
        table = self.data_repr
        table.requires_grad_(False)

    def unfreeze_embeddings(self):
        """Turn gradients back on for both tables."""
        for table in (self.data_repr, self.lbl_repr):
            table.requires_grad_(True)

    def forward(
        self,
        data_idx:torch.Tensor,
        lbl2data_idx:torch.Tensor,
        **kwargs,
    ):
        """Look up the rows for `data_idx` and `lbl2data_idx`; extra kwargs are ignored."""
        data_rows = self.data_repr(data_idx)
        lbl_rows = self.lbl_repr(lbl2data_idx)
        return TCHOutput(data_repr=data_rows, lbl2data_repr=lbl_rows)
        

#### Example

In [37]:
config = TCHConfig(n_data=block.train.dset.n_data, n_lbl=block.n_lbl)

In [38]:
model = TCH001(config)

In [11]:
data_repr, lbl_repr = torch.randn(block.train.dset.n_data, 768), torch.randn(block.n_lbl, 768)

In [12]:
data_repr.shape, lbl_repr.shape

(torch.Size([693082, 768]), torch.Size([312330, 768]))

In [14]:
model.init_embeddings(data_repr, lbl_repr)

In [15]:
batch = next(iter(block.train.dl))
b = prepare_batch(model, batch)

In [16]:
o = model(**b)

In [17]:
o.data_repr.shape, o.lbl2data_repr.shape

(torch.Size([1, 768]), torch.Size([5, 768]))

### `TCH002`

In [61]:
#| export
class TCH002(DistilBertPreTrainedModel):
    """Teacher with stored data/label representations plus an additive label table.

    The label representation served by this model is the sum of `lbl_repr`
    and `lbl_embeddings`; all three tables have `config.dim` columns.
    """

    def __init__(self, config:TCHConfig, **kwargs):
        super().__init__(config, **kwargs)
        width, n_labels = config.dim, config.n_lbl
        self.data_repr = nn.Embedding(config.n_data, width)
        self.lbl_repr = nn.Embedding(n_labels, width)
        self.lbl_embeddings = nn.Embedding(n_labels, width)

    def get_lbl_embeddings(self):
        """Return the full label table: `lbl_repr` weight plus `lbl_embeddings` weight."""
        base = self.lbl_repr.weight
        offset = self.lbl_embeddings.weight
        return base + offset

    def get_data_embeddings(self):
        """Return the whole data table (the weight of `data_repr`)."""
        table = self.data_repr
        return table.weight

    @torch.no_grad()
    def init_representations(self, data_repr:torch.Tensor, lbl_repr:torch.Tensor):
        """Copy the given tensors into `data_repr` and `lbl_repr`, in that order."""
        pairs = ((self.data_repr, data_repr), (self.lbl_repr, lbl_repr))
        for table, values in pairs:
            table.weight.copy_(values)

    @torch.no_grad()
    def init_lbl_embeddings(self):
        """Zero the additive label table."""
        offset = self.lbl_embeddings.weight
        nn.init.zeros_(offset)

    def freeze_representations(self):
        """Turn gradients off for `data_repr` and `lbl_repr` (not `lbl_embeddings`)."""
        for table in (self.data_repr, self.lbl_repr):
            table.requires_grad_(False)

    def unfreeze_representations(self):
        """Turn gradients back on for `data_repr` and `lbl_repr`."""
        for table in (self.data_repr, self.lbl_repr):
            table.requires_grad_(True)

    def forward(
        self,
        data_idx:torch.Tensor,
        lbl2data_idx:torch.Tensor,
        **kwargs,
    ):
        """Look up data rows and summed label rows; extra kwargs are ignored."""
        return TCHOutput(
            data_repr=self.data_repr(data_idx),
            lbl2data_repr=self.lbl_repr(lbl2data_idx) + self.lbl_embeddings(lbl2data_idx),
        )
        

#### Example

In [42]:
m_teacher = TCH002(config)

In [43]:
m_teacher.data_repr.weight.requires_grad, m_teacher.lbl_repr.weight.requires_grad

(True, True)

In [44]:
m_teacher.freeze_representations()
m_teacher.init_lbl_embeddings()

In [45]:
m_teacher.data_repr.weight.requires_grad,m_teacher.lbl_repr.weight.requires_grad, m_teacher.lbl_embeddings.weight.requires_grad

(False, False, True)

In [46]:
m_teacher.lbl_embeddings.weight

Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)

### `TCH003`

In [59]:
#| export
class TCH003(DistilBertPreTrainedModel):
    """Teacher that serves stored data representations from a single embedding table.

    The table has `config.n_data` rows and `config.dim` columns. When
    `config.normalize` is truthy the looked-up rows are L2-normalised.
    """

    def __init__(self, config:TCHConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.data_repr = nn.Embedding(config.n_data, config.dim)

    def get_data_embeddings(self):
        """Return the whole data table (the weight of `data_repr`)."""
        return self.data_repr.weight

    @torch.no_grad()
    def init_embeddings(self, data_repr:torch.Tensor):
        """Copy `data_repr` into the data table."""
        self.data_repr.weight.copy_(data_repr)

    def freeze_embeddings(self):
        """Turn gradients off for the data table."""
        self.data_repr.requires_grad_(False)

    def unfreeze_embeddings(self):
        """Turn gradients back on for the data table (counterpart of `freeze_embeddings`)."""
        self.data_repr.requires_grad_(True)

    def unfreeze_representations(self):
        """Same as `unfreeze_embeddings`; kept so existing callers keep working."""
        self.unfreeze_embeddings()

    def forward(
        self,
        data_idx:torch.Tensor,
        **kwargs,
    ):
        """Look up the rows for `data_idx`; extra kwargs are ignored.

        Returns a `TCHOutput` with only `data_repr` set.
        """
        data_repr = self.data_repr(data_idx)
        # The embedding lookup appends the feature axis last, so normalising over
        # dim=-1 is correct for any shape of `data_idx` (identical to dim=1 for 1-D indices).
        data_repr = F.normalize(data_repr, dim=-1) if self.config.normalize else data_repr
        return TCHOutput(
            data_repr=data_repr,
        )
        

#### Example

In [33]:
config = TCHConfig(n_data=1000, embed_dim=4096, normalize=False)
model = TCH003(config)

In [34]:
data_repr, data_idx = torch.randn(1000, 768), torch.randint(0, 1000, size=(10,))

In [35]:
model.init_embeddings(data_repr)
model.freeze_embeddings()

In [36]:
o = model(data_idx)

In [37]:
o.data_repr.norm(dim=1)

tensor([26.6562, 28.0896, 28.5650, 27.4916, 27.0713, 27.2589, 28.2481, 27.8620,
        27.5937, 27.8447])

### `TCH004`

In [60]:
#| export
class TCH004(DistilBertPreTrainedModel):
    """Teacher that stores `embed_dim`-wide data representations and projects them to `dim`.

    `data_repr` is a `config.n_data x config.embed_dim` table; `transform` is a
    linear layer mapping `config.embed_dim -> config.dim`. When
    `config.normalize` is truthy the projected rows are L2-normalised.
    """

    def __init__(self, config:TCHConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.data_repr = nn.Embedding(config.n_data, config.embed_dim)
        self.transform = nn.Linear(config.embed_dim, config.dim)

    def get_data_embeddings(self):
        """Return the whole (un-projected) data table, i.e. the weight of `data_repr`."""
        return self.data_repr.weight

    @torch.no_grad()
    def init_embeddings(self, data_repr:torch.Tensor):
        """Copy `data_repr` into the data table."""
        self.data_repr.weight.copy_(data_repr)

    def freeze_embeddings(self):
        """Turn gradients off for the data table (the projection is left untouched)."""
        self.data_repr.requires_grad_(False)

    def unfreeze_embeddings(self):
        """Turn gradients back on for the data table (counterpart of `freeze_embeddings`)."""
        self.data_repr.requires_grad_(True)

    def unfreeze_representations(self):
        """Same as `unfreeze_embeddings`; kept so existing callers keep working."""
        self.unfreeze_embeddings()

    @torch.no_grad()
    def init_transform(self, embed:Optional[torch.Tensor]=None):
        """Initialise the projection.

        With no argument the weight is filled with `nn.init.eye_`; otherwise
        `embed` is copied into the weight. The bias is zeroed in both cases.
        """
        if embed is None: nn.init.eye_(self.transform.weight)
        else: self.transform.weight.copy_(embed)
        nn.init.zeros_(self.transform.bias)

    def forward(
        self,
        data_idx:torch.Tensor,
        **kwargs,
    ):
        """Look up and project the rows for `data_idx`; extra kwargs are ignored.

        Returns a `TCHOutput` with only `data_repr` set.
        """
        data_repr = self.transform(self.data_repr(data_idx))
        # The feature axis is always last after the lookup + linear layer, so
        # dim=-1 is correct for any shape of `data_idx` (identical to dim=1 for 1-D indices).
        data_repr = F.normalize(data_repr, dim=-1) if self.config.normalize else data_repr
        return TCHOutput(
            data_repr=data_repr,
        )
        

#### Example

In [51]:
config = TCHConfig(n_data=1000, embed_dim=4096, normalize=False)
model = TCH004(config)

In [52]:
data_repr, data_idx = torch.randn(1000, 4096), torch.randint(0, 1000, size=(10,))

In [53]:
model.init_embeddings(data_repr)
model.init_transform(torch.eye(config.dim, config.embed_dim))

In [54]:
o = model(data_idx)

In [55]:
o.data_repr.norm(dim=1)

tensor([28.8158, 28.0094, 26.3769, 27.1537, 28.1371, 28.5593, 27.6033, 28.7885,
        28.6897, 27.2464], grad_fn=<LinalgVectorNormBackward0>)

## Distillation

### Helper

In [66]:
#| export
@dataclass
class DTLOutput(ModelOutput):
    """Output container returned by `DTL001.forward`."""
    # Student loss plus the weighted distillation terms; None when the student returned no loss.
    loss:Optional[torch.FloatTensor]=None
    # The four fields below are copied from the student's output attributes of the same name.
    data_repr: Optional[torch.FloatTensor] = None
    data_fused_repr:Optional[torch.FloatTensor] = None
    lbl2data_repr: Optional[torch.FloatTensor] = None
    lbl2data_fused_repr:Optional[torch.FloatTensor] = None
    

### Configuration

In [67]:
#| export
class DTLConfig(DistilBertConfig):
    """DistilBERT configuration extended with the distillation settings read by `DTL001`.

    Extra fields:
        margin, tau, apply_softmax, n_negatives: passed to the `MultiTriplet`
            loss that `DTL001` builds.
        teacher_data_student_label_loss_weight: weight of the triplet term
            between teacher data and student label representations.
        student_data_teacher_label_loss_weight: weight of the triplet term
            between student data and teacher label representations.
        data_mse_loss_weight, label_mse_loss_weight: weights of the MSE terms
            between teacher and student data / label representations.
        teacher_data_repr_name, student_data_repr_name,
        teacher_label_repr_name, student_label_repr_name: attribute names that
            `DTL001` looks up on the model outputs to pick each representation.

    Any remaining keyword arguments are forwarded to `DistilBertConfig`.
    """

    def __init__(
        self,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        teacher_data_student_label_loss_weight:Optional[float]=1.0,
        student_data_teacher_label_loss_weight:Optional[float]=1.0,
        data_mse_loss_weight:Optional[float]=0.1,
        label_mse_loss_weight:Optional[float]=0.1,
        teacher_data_repr_name:Optional[str]='data_repr',
        student_data_repr_name:Optional[str]='data_fused_repr',
        teacher_label_repr_name:Optional[str]='data_repr',
        student_label_repr_name:Optional[str]='data_repr',
        **kwargs,
    ):
        # Triplet-loss hyper-parameters.
        self.margin = margin
        self.tau = tau
        self.apply_softmax = apply_softmax
        self.n_negatives = n_negatives

        # Weights of the individual distillation terms.
        self.teacher_data_student_label_loss_weight = teacher_data_student_label_loss_weight
        self.student_data_teacher_label_loss_weight = student_data_teacher_label_loss_weight
        self.data_mse_loss_weight = data_mse_loss_weight
        self.label_mse_loss_weight = label_mse_loss_weight

        # Output-attribute names used to select the representations.
        # NOTE(review): both label names default to 'data_repr', which looks like a
        # copy-paste of the data names ('lbl2data_repr' may have been intended) - confirm.
        self.teacher_data_repr_name = teacher_data_repr_name
        self.student_data_repr_name = student_data_repr_name
        self.teacher_label_repr_name = teacher_label_repr_name
        self.student_label_repr_name = student_label_repr_name

        super().__init__(**kwargs)
        

### `DTL001`

In [115]:
#| export
class DTL001(DistilBertPreTrainedModel):
    """Distillation wrapper that trains `m_student` against a teacher `m_teacher`.

    The final loss is the student's own loss plus up to four weighted terms:
    two `MultiTriplet` terms (teacher-data vs student-label, student-data vs
    teacher-label) and two MSE terms (data vs data, label vs label). Which
    output attribute plays each role is chosen by the `*_repr_name` fields of
    the config; a term is skipped (contributes 0.0) when either of its
    representations is missing from the corresponding output.
    """
    use_representation,use_generation = True,False
    _tied_weights_keys = ["m_student.encoder.distilbert"]
    
    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        **kwargs
    ):
        """Store both models and build the loss functions from `config`."""
        super().__init__(config, **kwargs)
        store_attr('m_student,m_teacher')
        self.mse_loss_fn = nn.MSELoss()
        self.rep_loss_fn = MultiTriplet(margin=config.margin, n_negatives=config.n_negatives, tau=config.tau, 
                                        apply_softmax=config.apply_softmax, reduce='mean')

        # Expose `get_label_representation` only when the student provides it,
        # by binding a thin delegating method onto this instance.
        if hasattr(m_student, 'get_label_representation'):
            def get_label_representation(
                self,
                data_idx:Optional[torch.Tensor]=None,
                data_input_ids:Optional[torch.Tensor]=None,
                data_attention_mask:Optional[torch.Tensor]=None,
                **kwargs
            ):
                return self.m_student.get_label_representation(data_idx, data_input_ids, data_attention_mask, **kwargs)
            self.get_label_representation = MethodType(get_label_representation, self)
        
    def forward(
        self,
        data_idx:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student and, when it produced a loss, add the distillation terms.

        `data_idx` is only given to the teacher; `lbl2data_idx` goes to both
        models; all other keyword arguments go to the student and to the
        triplet loss. When a loss is computed, `kwargs` must contain
        `lbl2data_data2ptr`, `plbl2data_data2ptr` and `plbl2data_idx` for any
        triplet term that is active.

        Returns a `DTLOutput` carrying the combined loss (or None) and the
        student's representations.
        """
        student_o = self.m_student(lbl2data_idx=lbl2data_idx, **kwargs)
        student_data_repr = getattr(student_o, self.config.student_data_repr_name, None)
        student_label_repr = getattr(student_o, self.config.student_label_repr_name, None)

        loss = None
        if student_o.loss is not None:
            # The teacher is evaluated without gradient tracking.
            with torch.no_grad(): teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)
            # Teacher representations must be read from the teacher's output; reading them
            # from `student_o` would compare the student with itself and ignore the teacher.
            teacher_data_repr = getattr(teacher_o, self.config.teacher_data_repr_name, None)
            teacher_label_repr = getattr(teacher_o, self.config.teacher_label_repr_name, None)

            tdsl_loss = 0.0
            if teacher_data_repr is not None and student_label_repr is not None:
                tdsl_loss = self.rep_loss_fn(teacher_data_repr, student_label_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
                                             kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)

            sdtl_loss = 0.0
            if student_data_repr is not None and teacher_label_repr is not None:
                sdtl_loss = self.rep_loss_fn(student_data_repr, teacher_label_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
                                             kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)

            dm_loss = 0.0
            if teacher_data_repr is not None and student_data_repr is not None:
                dm_loss = self.mse_loss_fn(teacher_data_repr, student_data_repr)

            lm_loss = 0.0
            if teacher_label_repr is not None and student_label_repr is not None:
                lm_loss = self.mse_loss_fn(teacher_label_repr, student_label_repr)
            
            # Built out of place so the tensor held by `student_o.loss` is not mutated;
            # the grouping of the additions matches the original accumulation order.
            loss = (
                student_o.loss
                + self.config.teacher_data_student_label_loss_weight * tdsl_loss
                + self.config.student_data_teacher_label_loss_weight * sdtl_loss
                + (self.config.data_mse_loss_weight * dm_loss + self.config.label_mse_loss_weight * lm_loss)
            )

        return DTLOutput(
            loss=loss,
            data_repr=student_o.data_repr,
            data_fused_repr=student_o.data_fused_repr,
            lbl2data_repr=student_o.lbl2data_repr,
            lbl2data_fused_repr=student_o.lbl2data_fused_repr,
        )
        

#### Example

In [116]:
config = TCHConfig(n_data=1000, embed_dim=4096, normalize=False)
m_teacher = TCH004(config)

In [118]:
mname = 'distilbert-base-uncased'
meta_name = 'lnk'

m_student = OAK003.from_pretrained(mname, margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                                   
                                   data_aug_meta_prefix=f'{meta_name}2data', lbl2data_aug_meta_prefix=None,
                               
                                   num_metadata=block.train.dset.meta[f'{meta_name}_meta'].n_meta,
                                   
                                   calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                                   calib_loss_weight=0.1, use_calib_loss=True,
                                   
                                   use_query_loss=True,
                                   
                                   use_encoder_parallel=True, normalize=True)


Some weights of OAK003 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.

In [119]:
config = DTLConfig(margin=0.3, tau=0.1, apply_softmax=False, n_negatives=5, 
                   teacher_data_student_label_loss_weight=1.0, student_data_teacher_label_loss_weight=1.0,
                   data_mse_loss_weight=0.1, label_mse_loss_weight=0.1,
                   teacher_data_repr_name='data_repr', student_data_repr_name='data_fused_repr',
                   teacher_label_repr_name='data_repr', student_label_repr_name='data_repr')

In [120]:
model = DTL001(config, m_student=m_student, m_teacher=m_teacher)

In [126]:
batch = block.train.one_batch(10)
b = prepare_batch(model, batch, m_args=['data_input_ids', 'data_attention_mask', 'lbl2data_data2ptr', 'lbl2data_input_ids', 
                                        'lbl2data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx',
                                        'lnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 
                                        'lnk2data_attention_mask', 'plnk2data_data2ptr', 'plnk2data_idx',
                                       ])

In [127]:
o = model.m_student(**b)

In [128]:
o.loss

tensor(0.0989, grad_fn=<AddBackward0>)

In [129]:
o = model(**b)

IndexError: index 43 is out of bounds for dimension 0 with size 10

In [97]:
o.loss

### `DTL005`

In [None]:
#| export
class DTL005(DistilBertPreTrainedModel):
    """Distillation wrapper using only the teacher's data representation.

    Adds to the student's loss a `MultiTriplet` term between the teacher data
    representation and the student label representation, and an MSE term
    between the teacher data representation and the student's fused data
    representation.
    """
    use_representation = True
    use_generation = False
    _tied_weights_keys = ["m_student.encoder.distilbert"]

    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        bsz:Optional[int]=None,
        tn_targ:Optional[int]=None,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        teacher_data_student_label_loss_weight:Optional[float]=1.0,
        data_mse_loss_weight:Optional[float]=0.1,
        **kwargs
    ):
        """Store the two models and loss weights, then build the loss functions."""
        super().__init__(config, **kwargs)
        store_attr('m_student,m_teacher')
        store_attr('teacher_data_student_label_loss_weight,data_mse_loss_weight')
        self.mse_loss_fn = nn.MSELoss()
        self.rep_loss_fn = MultiTriplet(
            bsz=bsz,
            tn_targ=tn_targ,
            margin=margin,
            n_negatives=n_negatives,
            tau=tau,
            apply_softmax=apply_softmax,
            reduce='mean',
        )

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student; when it returns a loss, add the two distillation terms.

        `data_idx` is only passed to the teacher. When a loss is computed,
        `kwargs` must contain `lbl2data_data2ptr`, `lbl2data_idx`,
        `plbl2data_data2ptr` and `plbl2data_idx`.
        """
        student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, **kwargs)

        if student_o.loss is None:
            loss = None
        else:
            teacher_o = self.m_teacher(data_idx=data_idx)
            teacher_data = teacher_o.data_repr

            tdsl_loss = self.rep_loss_fn(
                teacher_data, student_o.lbl2data_repr,
                kwargs['lbl2data_data2ptr'], kwargs['lbl2data_idx'],
                kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'],
                **kwargs,
            )
            dm_loss = self.mse_loss_fn(teacher_data, student_o.data_fused_repr)

            # Accumulated in place on the student's loss, term by term.
            loss = student_o.loss
            loss += self.teacher_data_student_label_loss_weight * tdsl_loss
            loss += self.data_mse_loss_weight * dm_loss

        return RADOutput(
            loss=loss,
            data_repr=student_o.data_repr,
            data_fused_repr=student_o.data_fused_repr,
            lbl2data_repr=student_o.lbl2data_repr,
            lbl2data_fused_repr=student_o.lbl2data_fused_repr,
        )
        

#### Example

In [None]:
m_student = OAK001.from_pretrained('sentence-transformers/msmarco-distilbert-cos-v5', batch_size=1000, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=True,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)

m_student.init_retrieval_head()
m_student.init_cross_head()

m_student.encoder.set_meta_embeddings(torch.zeros(block.train.dset.meta['lnk_meta'].n_meta, m_student.config.dim))

Some weights of OAK001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-cos-v5 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enco

In [None]:
model = DTL005(DistilBertConfig(), m_student=m_student, m_teacher=m_teacher, bsz=1024, margin=0.3, tau=0.1, n_negatives=10, 
               apply_softmax=True, teacher_data_student_label_loss_weight=1.0, data_mse_loss_weight=0.1)

In [None]:
batch = next(iter(block.train.dl))
b = prepare_batch(model, batch, m_args=['lbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 
                                        'lbl2data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx',
                                        'lnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 
                                        'lnk2data_attention_mask', 'plnk2data_data2ptr', 'plnk2data_idx'
                                       ])

In [None]:
m,b = model.to('cuda'), b.to('cuda')

In [None]:
o = m(**b)

In [None]:
o.loss

AttributeError: 'NoneType' object has no attribute 'loss'

### `DTL006`

In [None]:
#| export
class DTL006(DistilBertPreTrainedModel):
    """Distillation wrapper using both the teacher's data and label representations.

    The student receives `data_idx` and `lbl2data_idx` directly. On top of the
    student's loss it adds two `MultiTriplet` terms (teacher-data vs
    student-label, student-data vs teacher-label) and two MSE terms (data vs
    data, label vs label).
    """
    use_representation = True
    use_generation = False
    _tied_weights_keys = ["m_student.encoder.distilbert"]

    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        bsz:Optional[int]=None,
        tn_targ:Optional[int]=None,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        teacher_data_student_label_loss_weight:Optional[float]=1.0,
        student_data_teacher_label_loss_weight:Optional[float]=1.0,
        data_mse_loss_weight:Optional[float]=0.1,
        label_mse_loss_weight:Optional[float]=0.1,
        **kwargs
    ):
        """Store the two models and loss weights, then build the loss functions."""
        super().__init__(config, **kwargs)
        store_attr('m_student,m_teacher')
        store_attr('teacher_data_student_label_loss_weight,student_data_teacher_label_loss_weight')
        store_attr('data_mse_loss_weight,label_mse_loss_weight')
        self.mse_loss_fn = nn.MSELoss()
        self.rep_loss_fn = MultiTriplet(
            bsz=bsz,
            tn_targ=tn_targ,
            margin=margin,
            n_negatives=n_negatives,
            tau=tau,
            apply_softmax=apply_softmax,
            reduce='mean',
        )

    def get_label_representation(self, data_idx:torch.Tensor, **kwargs):
        """Delegate to the student's `get_label_representation`."""
        student = self.m_student
        return student.get_label_representation(data_idx, **kwargs)

    def forward(
        self,
        data_idx:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student; when labels and a student loss exist, add the distillation terms.

        When a loss is computed, `kwargs` must contain `lbl2data_data2ptr`,
        `plbl2data_data2ptr` and `plbl2data_idx`.
        """
        student_o = self.m_student(data_idx=data_idx, lbl2data_idx=lbl2data_idx, **kwargs)

        has_targets = lbl2data_idx is not None and student_o.loss is not None
        if not has_targets:
            loss = None
        else:
            teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)

            tdsl_loss = self.rep_loss_fn(
                teacher_o.data_repr, student_o.lbl2data_repr,
                kwargs['lbl2data_data2ptr'], lbl2data_idx,
                kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'],
                **kwargs,
            )
            sdtl_loss = self.rep_loss_fn(
                student_o.data_repr, teacher_o.lbl2data_repr,
                kwargs['lbl2data_data2ptr'], lbl2data_idx,
                kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'],
                **kwargs,
            )

            dm_loss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_repr)
            lm_loss = self.mse_loss_fn(teacher_o.lbl2data_repr, student_o.lbl2data_repr)

            # Accumulated in place on the student's loss; the two MSE terms are summed first.
            mse_term = self.data_mse_loss_weight * dm_loss + self.label_mse_loss_weight * lm_loss
            loss = student_o.loss
            loss += self.teacher_data_student_label_loss_weight * tdsl_loss
            loss += self.student_data_teacher_label_loss_weight * sdtl_loss
            loss += mse_term

        return RADOutput(
            loss=loss,
            data_repr=student_o.data_repr,
            lbl2data_repr=student_o.lbl2data_repr,
        )
        

### `DTL007`

In [None]:
#| export
class DTL007(DistilBertPreTrainedModel):
    """Distillation wrapper using both teacher representations, with a text-input student.

    The student receives `data_input_ids` / `data_attention_mask` and
    `lbl2data_idx` (not `data_idx`). On top of the student's loss it adds two
    `MultiTriplet` terms (teacher-data vs student-label, student-data vs
    teacher-label) and two MSE terms (data vs data, label vs label).
    """
    use_representation = True
    use_generation = False
    _tied_weights_keys = ["m_student.encoder.distilbert"]

    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        bsz:Optional[int]=None,
        tn_targ:Optional[int]=None,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        teacher_data_student_label_loss_weight:Optional[float]=1.0,
        student_data_teacher_label_loss_weight:Optional[float]=1.0,
        data_mse_loss_weight:Optional[float]=0.1,
        label_mse_loss_weight:Optional[float]=0.1,
        **kwargs
    ):
        """Store the two models and loss weights, then build the loss functions."""
        super().__init__(config, **kwargs)
        store_attr('m_student,m_teacher')
        store_attr('teacher_data_student_label_loss_weight,student_data_teacher_label_loss_weight')
        store_attr('data_mse_loss_weight,label_mse_loss_weight')
        self.mse_loss_fn = nn.MSELoss()
        self.rep_loss_fn = MultiTriplet(
            bsz=bsz,
            tn_targ=tn_targ,
            margin=margin,
            n_negatives=n_negatives,
            tau=tau,
            apply_softmax=apply_softmax,
            reduce='mean',
        )

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        
        data_idx:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student; when labels and a student loss exist, add the distillation terms.

        `data_idx` is only passed to the teacher. When a loss is computed,
        `kwargs` must contain `lbl2data_data2ptr`, `plbl2data_data2ptr` and
        `plbl2data_idx`.
        """
        student_o = self.m_student(
            data_input_ids=data_input_ids,
            data_attention_mask=data_attention_mask,
            lbl2data_idx=lbl2data_idx,
            **kwargs,
        )

        has_targets = lbl2data_idx is not None and student_o.loss is not None
        if not has_targets:
            loss = None
        else:
            teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)

            tdsl_loss = self.rep_loss_fn(
                teacher_o.data_repr, student_o.lbl2data_repr,
                kwargs['lbl2data_data2ptr'], lbl2data_idx,
                kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'],
                **kwargs,
            )
            sdtl_loss = self.rep_loss_fn(
                student_o.data_repr, teacher_o.lbl2data_repr,
                kwargs['lbl2data_data2ptr'], lbl2data_idx,
                kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'],
                **kwargs,
            )

            dm_loss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_repr)
            lm_loss = self.mse_loss_fn(teacher_o.lbl2data_repr, student_o.lbl2data_repr)

            # Accumulated in place on the student's loss; the two MSE terms are summed first.
            mse_term = self.data_mse_loss_weight * dm_loss + self.label_mse_loss_weight * lm_loss
            loss = student_o.loss
            loss += self.teacher_data_student_label_loss_weight * tdsl_loss
            loss += self.student_data_teacher_label_loss_weight * sdtl_loss
            loss += mse_term

        return RADOutput(
            loss=loss,
            data_repr=student_o.data_repr,
            lbl2data_repr=student_o.lbl2data_repr,
        )
        

### `DTL008`

In [None]:
#| export
class DTL008(DistilBertPreTrainedModel):
    """Distillation wrapper using only the teacher's data representation.

    Adds to the student's loss a `MultiTriplet` term between the teacher data
    representation and the student label representation, and an MSE term
    between the teacher and student data representations.
    """
    use_representation = True
    use_generation = False
    _tied_weights_keys = ["m_student.encoder.distilbert"]

    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        bsz:Optional[int]=None,
        tn_targ:Optional[int]=None,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        teacher_data_student_label_loss_weight:Optional[float]=1.0,
        data_mse_loss_weight:Optional[float]=0.1,
        **kwargs
    ):
        """Store the two models and loss weights, then build the loss functions."""
        super().__init__(config, **kwargs)
        store_attr('m_student,m_teacher')
        store_attr('teacher_data_student_label_loss_weight,data_mse_loss_weight')
        self.mse_loss_fn = nn.MSELoss()
        self.rep_loss_fn = MultiTriplet(
            bsz=bsz,
            tn_targ=tn_targ,
            margin=margin,
            n_negatives=n_negatives,
            tau=tau,
            apply_softmax=apply_softmax,
            reduce='mean',
        )

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student; when it returns a loss, add the two distillation terms.

        `data_idx` is only passed to the teacher. When a loss is computed,
        `kwargs` must contain `lbl2data_data2ptr`, `lbl2data_idx`,
        `plbl2data_data2ptr` and `plbl2data_idx`.
        """
        student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, **kwargs)

        if student_o.loss is None:
            loss = None
        else:
            teacher_o = self.m_teacher(data_idx=data_idx)
            teacher_data = teacher_o.data_repr

            tdsl_loss = self.rep_loss_fn(
                teacher_data, student_o.lbl2data_repr,
                kwargs['lbl2data_data2ptr'], kwargs['lbl2data_idx'],
                kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'],
                **kwargs,
            )
            dm_loss = self.mse_loss_fn(teacher_data, student_o.data_repr)

            # Accumulated in place on the student's loss, term by term.
            loss = student_o.loss
            loss += self.teacher_data_student_label_loss_weight * tdsl_loss
            loss += self.data_mse_loss_weight * dm_loss

        return RADOutput(
            loss=loss,
            data_repr=student_o.data_repr,
            lbl2data_repr=student_o.lbl2data_repr,
        )

    

### `DTL009`

In [None]:
#| export
class DTL009(DistilBertPreTrainedModel):
    """Student/teacher distillation wrapper with fixed, user-supplied loss weights.

    The student's own loss is combined with two ``MultiTriplet`` terms that mix
    student and teacher representations, and with two MSE terms between the
    teacher's and the student's data / label representations.
    """
    use_representation = True
    use_generation = False
    _tied_weights_keys = ["m_student.encoder.distilbert"]

    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        bsz:Optional[int]=None,
        tn_targ:Optional[int]=None,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        student_loss_weight:Optional[float]=1.0,
        teacher_data_student_label_loss_weight:Optional[float]=1.0,
        student_data_teacher_label_loss_weight:Optional[float]=1.0,
        data_mse_loss_weight:Optional[float]=0.1,
        label_mse_loss_weight:Optional[float]=0.1,
        **kwargs
    ):
        """Keep both models and the five loss weights, and build the loss modules.

        Args:
            config: configuration handed to the parent constructor.
            m_student: model called with the text inputs in ``forward``.
            m_teacher: model called with ``data_idx`` / ``lbl2data_idx`` in ``forward``.
            bsz, tn_targ, margin, tau, apply_softmax, n_negatives: passed through
                to ``MultiTriplet`` (which is built with ``reduce='mean'``).
            student_loss_weight: multiplier of the student's own loss.
            teacher_data_student_label_loss_weight: multiplier of the triplet term
                on teacher data / student label representations.
            student_data_teacher_label_loss_weight: multiplier of the triplet term
                on student fused data / teacher label representations.
            data_mse_loss_weight: multiplier of the data-representation MSE.
            label_mse_loss_weight: multiplier of the label-representation MSE.
            **kwargs: forwarded to the parent constructor.
        """
        super().__init__(config, **kwargs)
        store_attr('m_student,m_teacher')
        store_attr('student_loss_weight,teacher_data_student_label_loss_weight,student_data_teacher_label_loss_weight')
        store_attr('data_mse_loss_weight,label_mse_loss_weight')

        self.mse_loss_fn = nn.MSELoss()
        self.rep_loss_fn = MultiTriplet(
            bsz=bsz,
            tn_targ=tn_targ,
            margin=margin,
            n_negatives=n_negatives,
            tau=tau,
            apply_softmax=apply_softmax,
            reduce='mean',
        )

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,

        data_idx:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student and, if labels and a student loss exist, add the distillation terms.

        ``kwargs`` is forwarded to the student and to both triplet-loss calls; when
        a loss is computed it must contain ``lbl2data_data2ptr``,
        ``plbl2data_data2ptr`` and ``plbl2data_idx``.

        Returns:
            ``RADOutput`` carrying the student's plain and fused representations;
            ``loss`` is ``None`` when ``lbl2data_idx`` is ``None`` or the student
            returned no loss.
        """
        student_o = self.m_student(
            data_input_ids=data_input_ids,
            data_attention_mask=data_attention_mask,
            lbl2data_idx=lbl2data_idx,
            **kwargs,
        )

        loss = None
        if not (lbl2data_idx is None or student_o.loss is None):
            teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)

            # Pointer / index arguments shared by both triplet-loss calls.
            targets = (
                kwargs['lbl2data_data2ptr'],
                lbl2data_idx,
                kwargs['plbl2data_data2ptr'],
                kwargs['plbl2data_idx'],
            )

            # Teacher data vs. student labels, then student (fused) data vs. teacher labels.
            tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, *targets, **kwargs)
            sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, *targets, **kwargs)

            # Direct regression between teacher and student representations.
            dm_loss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_fused_repr)
            lm_loss = self.mse_loss_fn(teacher_o.lbl2data_repr, student_o.lbl2data_repr)

            loss = (
                self.student_loss_weight * student_o.loss
                + self.teacher_data_student_label_loss_weight * tdsl_loss
                + self.student_data_teacher_label_loss_weight * sdtl_loss
                + (self.data_mse_loss_weight * dm_loss + self.label_mse_loss_weight * lm_loss)
            )

        return RADOutput(
            loss=loss,
            data_repr=student_o.data_repr,
            data_fused_repr=student_o.data_fused_repr,
            lbl2data_repr=student_o.lbl2data_repr,
            lbl2data_fused_repr=student_o.lbl2data_fused_repr,
        )
        

### `DTL010`

In [None]:
#| export
class DTL010(DistilBertPreTrainedModel):
    """Student/teacher distillation wrapper whose five loss weights come from a bandit.

    Same loss terms as the fixed-weight variant, but every term (including the
    student's own loss) is scaled by a weight obtained from
    ``RLLossWeightsCumuluative``; the user-supplied weights only seed it.
    """
    use_representation = True
    use_generation = False
    _tied_weights_keys = ["m_student.encoder.distilbert"]

    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        bsz:Optional[int]=None,
        tn_targ:Optional[int]=None,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        student_loss_weight:Optional[float]=1.0,
        teacher_data_student_label_loss_weight:Optional[float]=1.0,
        student_data_teacher_label_loss_weight:Optional[float]=1.0,
        data_mse_loss_weight:Optional[float]=0.1,
        label_mse_loss_weight:Optional[float]=0.1,

        bandit_learning_rate:Optional[float]=0.01,
        bandit_minimum_value:Optional[float]=0.1,
        bandit_collector:Optional[int]=20,
        **kwargs
    ):
        """Keep both models, create the weight bandit and build the loss modules.

        Args:
            config: configuration handed to the parent constructor.
            m_student: model called with the text inputs in ``forward``.
            m_teacher: model called with ``data_idx`` / ``lbl2data_idx`` in ``forward``.
            bsz, tn_targ, margin, tau, apply_softmax, n_negatives: passed through
                to ``MultiTriplet`` (which is built with ``reduce='mean'``).
            student_loss_weight, teacher_data_student_label_loss_weight,
            student_data_teacher_label_loss_weight, data_mse_loss_weight,
            label_mse_loss_weight: passed, in this order, as ``rest_init`` to the bandit.
            bandit_learning_rate: passed as ``lr`` to the bandit.
            bandit_minimum_value: passed as ``min`` to the bandit.
            bandit_collector: passed as ``collector`` to the bandit.
            **kwargs: forwarded to the parent constructor.
        """
        super().__init__(config, **kwargs)
        store_attr('m_student,m_teacher')

        # Order matters: forward() indexes the sampled weights positionally (0..4).
        initial_weights = [
            student_loss_weight,
            teacher_data_student_label_loss_weight,
            student_data_teacher_label_loss_weight,
            data_mse_loss_weight,
            label_mse_loss_weight,
        ]
        self.loss_weights = RLLossWeightsCumuluative(
            num_samples=5,
            reward_func=AccMiniBatch,
            lr=bandit_learning_rate,
            collector=bandit_collector,
            std=0.1,
            min=bandit_minimum_value,
            rest_init=initial_weights,
        )

        self.mse_loss_fn = nn.MSELoss()
        self.rep_loss_fn = MultiTriplet(
            bsz=bsz,
            tn_targ=tn_targ,
            margin=margin,
            n_negatives=n_negatives,
            tau=tau,
            apply_softmax=apply_softmax,
            reduce='mean',
        )

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,

        data_idx:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student and, if labels and a student loss exist, add bandit-weighted distillation terms.

        ``kwargs`` is forwarded to the student and to both triplet-loss calls; when
        a loss is computed it must contain ``lbl2data_data2ptr``,
        ``plbl2data_data2ptr`` and ``plbl2data_idx``. Weights are sampled on every
        loss computation; the bandit's ``step`` is only called in training mode.

        Returns:
            ``RADOutput`` carrying the student's plain and fused representations;
            ``loss`` is ``None`` when ``lbl2data_idx`` is ``None`` or the student
            returned no loss.
        """
        student_o = self.m_student(
            data_input_ids=data_input_ids,
            data_attention_mask=data_attention_mask,
            lbl2data_idx=lbl2data_idx,
            **kwargs,
        )

        loss = None
        if not (lbl2data_idx is None or student_o.loss is None):
            teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)

            # Pointer / index arguments shared by the triplet losses and the bandit update.
            targets = (
                kwargs['lbl2data_data2ptr'],
                lbl2data_idx,
                kwargs['plbl2data_data2ptr'],
                kwargs['plbl2data_idx'],
            )

            # Teacher data vs. student labels, then student (fused) data vs. teacher labels.
            tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, *targets, **kwargs)
            sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, *targets, **kwargs)

            # Direct regression between teacher and student representations.
            dm_loss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_fused_repr)
            lm_loss = self.mse_loss_fn(teacher_o.lbl2data_repr, student_o.lbl2data_repr)

            # Sample first, then update: the weights used below are the ones drawn before `step`.
            ws = self.loss_weights.sample(lbl2data_idx.device)

            if self.training:
                self.loss_weights.step(student_o.data_fused_repr, student_o.lbl2data_repr, *targets)

            # ws[0] scales the student loss; ws[1..4] scale the distillation terms in order.
            loss = ws[0] * student_o.loss
            for weight_pos, term in enumerate((tdsl_loss, sdtl_loss, dm_loss, lm_loss), start=1):
                loss = loss + ws[weight_pos] * term

        return RADOutput(
            loss=loss,
            data_repr=student_o.data_repr,
            data_fused_repr=student_o.data_fused_repr,
            lbl2data_repr=student_o.lbl2data_repr,
            lbl2data_fused_repr=student_o.lbl2data_fused_repr,
        )
        

#### Example

In [None]:
# NOTE(review): machine-specific absolute path (and not under the `data_dir` used in the setup cells);
# this cell only runs where that checkpoint directory exists.
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4'
# Load the teacher from the `teacher` sub-directory, sized by the training-set and label counts of `block`.
m_teacher = TCH001.from_pretrained(f'{model_output}/teacher', n_data=block.train.dset.n_data, n_lbl=block.n_lbl)

# Presumably keeps the teacher's embeddings fixed while the student trains — confirm in TCH001.
m_teacher.freeze_embeddings()

In [None]:
# Student: OAK001 initialised from the msmarco DistilBERT checkpoint. Only the `lnk2data` prefix is
# enabled (as `data_aug_meta_prefix`); the other three meta prefixes are None.
m_student = OAK001.from_pretrained('sentence-transformers/msmarco-distilbert-cos-v5', batch_size=1000, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=True,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)

# The heads are not part of the checkpoint (see the "newly initialized" warning in this cell's output),
# so they are initialised explicitly here.
m_student.init_retrieval_head()
m_student.init_cross_head()

# Start every `lnk_meta` metadata embedding at zero: tensor of shape (n_meta, config.dim).
m_student.encoder.set_meta_embeddings(torch.zeros(block.train.dset.meta['lnk_meta'].n_meta, m_student.config.dim))

Some weights of OAK001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-cos-v5 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enco

In [None]:
# Wrap student and teacher in DTL010 with a default DistilBertConfig. `student_loss_weight` and
# `bandit_collector` keep their defaults; the four distillation weights given here seed the bandit.
model = DTL010(DistilBertConfig(), m_student=m_student, m_teacher=m_teacher, bsz=1024, margin=0.3, tau=0.1, n_negatives=10, 
               apply_softmax=True, bandit_learning_rate=0.001, bandit_minimum_value=0.01, 
               teacher_data_student_label_loss_weight=1.0, student_data_teacher_label_loss_weight=0.1, 
               data_mse_loss_weight=0.1,label_mse_loss_weight=0.1)

In [None]:
# Grab a single training batch for a smoke test of the forward pass.
batch = next(iter(block.train.dl))
# `m_args` names the label (`lbl2data_*`, `plbl2data_*`) and link-metadata (`lnk2data_*`, `plnk2data_*`)
# fields to pass along; presumably prepare_batch keeps these in addition to the arguments that
# model.forward declares — confirm in xcai.
b = prepare_batch(model, batch, m_args=['lbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 
                                        'lbl2data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx',
                                        'lnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 
                                        'lnk2data_attention_mask', 'plnk2data_data2ptr', 'plnk2data_idx'
                                       ])

In [None]:
# Requires a CUDA device: both the model and the prepared batch are moved to the GPU.
model,b = model.to('cuda'), b.to('cuda')

In [None]:
o = model(**b)

In [None]:
def func():
    """Debug helper: open the debugger, then run one forward pass of the global `model` on the global batch `b`."""
    # NOTE(review): interactive breakpoint — this blocks "Restart & Run All"; drop it once debugging is done.
    import pdb; pdb.set_trace()
    return model(**b)
    

In [None]:
o = func()

> /tmp/ipykernel_18051/2066784281.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b)
      4 



ipdb>  q


In [None]:
o.loss

tensor(0.1139, device='cuda:0', grad_fn=<AddBackward0>)

### `DTL011`

In [None]:
#| export
class DTL011(DistilBertPreTrainedModel):
    """Student/teacher distillation wrapper with four bandit-weighted distillation terms.

    The student's own loss enters the total unweighted; the two triplet terms
    and the two MSE terms are scaled by weights obtained from
    ``RLLossWeightsCumuluative``.
    """
    use_representation = True
    use_generation = False
    _tied_weights_keys = ["m_student.encoder.distilbert"]

    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        bsz:Optional[int]=None,
        tn_targ:Optional[int]=None,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,

        bandit_learning_rate:Optional[float]=0.01,
        bandit_minimum_value:Optional[float]=0.1,
        bandit_collector:Optional[int]=20,
        **kwargs
    ):
        """Keep both models, create the weight bandit and build the loss modules.

        Args:
            config: configuration handed to the parent constructor.
            m_student: model called with the text inputs in ``forward``.
            m_teacher: model called with ``data_idx`` / ``lbl2data_idx`` in ``forward``.
            bsz, tn_targ, margin, tau, apply_softmax, n_negatives: passed through
                to ``MultiTriplet`` (which is built with ``reduce='mean'``).
            bandit_learning_rate: passed as ``lr`` to the bandit.
            bandit_minimum_value: passed as both ``min`` and ``rest_init`` to the bandit.
            bandit_collector: passed as ``collector`` to the bandit.
            **kwargs: forwarded to the parent constructor.
        """
        super().__init__(config, **kwargs)
        store_attr('m_student,m_teacher')

        # Four weights, one per distillation term; the floor value doubles as the initial value.
        self.loss_weights = RLLossWeightsCumuluative(
            num_samples=4,
            reward_func=AccMiniBatch,
            lr=bandit_learning_rate,
            collector=bandit_collector,
            std=0.1,
            min=bandit_minimum_value,
            rest_init=bandit_minimum_value,
        )

        self.mse_loss_fn = nn.MSELoss()
        self.rep_loss_fn = MultiTriplet(
            bsz=bsz,
            tn_targ=tn_targ,
            margin=margin,
            n_negatives=n_negatives,
            tau=tau,
            apply_softmax=apply_softmax,
            reduce='mean',
        )

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,

        data_idx:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student and, if labels and a student loss exist, add bandit-weighted distillation terms.

        ``kwargs`` is forwarded to the student and to both triplet-loss calls; when
        a loss is computed it must contain ``lbl2data_data2ptr``,
        ``plbl2data_data2ptr`` and ``plbl2data_idx``. Weights are sampled on every
        loss computation; the bandit's ``step`` is only called in training mode.

        Returns:
            ``RADOutput`` carrying the student's plain and fused representations;
            ``loss`` is ``None`` when ``lbl2data_idx`` is ``None`` or the student
            returned no loss.
        """
        student_o = self.m_student(
            data_input_ids=data_input_ids,
            data_attention_mask=data_attention_mask,
            lbl2data_idx=lbl2data_idx,
            **kwargs,
        )

        loss = None
        if not (lbl2data_idx is None or student_o.loss is None):
            teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)

            # Pointer / index arguments shared by the triplet losses and the bandit update.
            targets = (
                kwargs['lbl2data_data2ptr'],
                lbl2data_idx,
                kwargs['plbl2data_data2ptr'],
                kwargs['plbl2data_idx'],
            )

            # Teacher data vs. student labels, then student (fused) data vs. teacher labels.
            tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, *targets, **kwargs)
            sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, *targets, **kwargs)

            # Direct regression between teacher and student representations.
            dm_loss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_fused_repr)
            lm_loss = self.mse_loss_fn(teacher_o.lbl2data_repr, student_o.lbl2data_repr)

            # Sample first, then update: the weights used below are the ones drawn before `step`.
            ws = self.loss_weights.sample(lbl2data_idx.device)

            if self.training:
                self.loss_weights.step(student_o.data_fused_repr, student_o.lbl2data_repr, *targets)

            # The student loss is unweighted; ws[0..3] scale the distillation terms in order.
            loss = student_o.loss
            for weight_pos, term in enumerate((tdsl_loss, sdtl_loss, dm_loss, lm_loss)):
                loss = loss + ws[weight_pos] * term

        return RADOutput(
            loss=loss,
            data_repr=student_o.data_repr,
            data_fused_repr=student_o.data_fused_repr,
            lbl2data_repr=student_o.lbl2data_repr,
            lbl2data_fused_repr=student_o.lbl2data_fused_repr,
        )
        

#### Example

In [None]:
# NOTE(review): machine-specific absolute path (and not under the `data_dir` used in the setup cells);
# this cell only runs where that checkpoint directory exists.
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4'
# Load the teacher from the `teacher` sub-directory, sized by the training-set and label counts of `block`.
m_teacher = TCH001.from_pretrained(f'{model_output}/teacher', n_data=block.train.dset.n_data, n_lbl=block.n_lbl)

# Presumably keeps the teacher's embeddings fixed while the student trains — confirm in TCH001.
m_teacher.freeze_embeddings()

In [None]:
# Student: OAK001 initialised from the msmarco DistilBERT checkpoint. Only the `lnk2data` prefix is
# enabled (as `data_aug_meta_prefix`); the other three meta prefixes are None.
m_student = OAK001.from_pretrained('sentence-transformers/msmarco-distilbert-cos-v5', batch_size=1000, num_batch_labels=5000, 
                               margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=True,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)

# The heads are not part of the checkpoint (see the "newly initialized" warning in this cell's output),
# so they are initialised explicitly here.
m_student.init_retrieval_head()
m_student.init_cross_head()

# Start every `lnk_meta` metadata embedding at zero: tensor of shape (n_meta, config.dim).
m_student.encoder.set_meta_embeddings(torch.zeros(block.train.dset.meta['lnk_meta'].n_meta, m_student.config.dim))

Some weights of OAK001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-cos-v5 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enco

In [None]:
# Wrap student and teacher in DTL011 with a default DistilBertConfig. DTL011 takes no per-term initial
# weights: `bandit_minimum_value` is used both as the bandit's `min` and as its `rest_init`.
model = DTL011(DistilBertConfig(), m_student=m_student, m_teacher=m_teacher, bsz=1024, margin=0.3, tau=0.1, 
               n_negatives=10, apply_softmax=True, bandit_learning_rate=0.01, bandit_minimum_value=0.05)

In [None]:
# Grab a single training batch for a smoke test of the forward pass.
batch = next(iter(block.train.dl))
# `m_args` names the label (`lbl2data_*`, `plbl2data_*`) and link-metadata (`lnk2data_*`, `plnk2data_*`)
# fields to pass along; presumably prepare_batch keeps these in addition to the arguments that
# model.forward declares — confirm in xcai.
b = prepare_batch(model, batch, m_args=['lbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 
                                        'lbl2data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx',
                                        'lnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 
                                        'lnk2data_attention_mask', 'plnk2data_data2ptr', 'plnk2data_idx'
                                       ])

In [None]:
# Requires a CUDA device: both the model and the prepared batch are moved to the GPU.
model,b = model.to('cuda'), b.to('cuda')

In [None]:
o = model(**b)

In [None]:
def func():
    """Debug helper: open the debugger, then run one forward pass of the global `model` on the global batch `b`."""
    # NOTE(review): interactive breakpoint — this blocks "Restart & Run All"; drop it once debugging is done.
    import pdb; pdb.set_trace()
    return model(**b)
    

In [None]:
o = func()

> /tmp/ipykernel_18051/2066784281.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return model(**b)
      4 



ipdb>  b model.forward


Breakpoint 1 at /tmp/ipykernel_18051/802618658.py:28


ipdb>  c


> /tmp/ipykernel_18051/802618658.py(37)forward()
     35         **kwargs
     36     ):
---> 37         student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     38                                    lbl2data_idx=lbl2data_idx, **kwargs)
     39 



ipdb>  n


> /tmp/ipykernel_18051/802618658.py(38)forward()
     36     ):
     37         student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 38                                    lbl2data_idx=lbl2data_idx, **kwargs)
     39 
     40         loss = None



ipdb>  


> /tmp/ipykernel_18051/802618658.py(37)forward()
     35         **kwargs
     36     ):
---> 37         student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     38                                    lbl2data_idx=lbl2data_idx, **kwargs)
     39 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(38)forward()
     36     ):
     37         student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 38                                    lbl2data_idx=lbl2data_idx, **kwargs)
     39 
     40         loss = None



ipdb>  


> /tmp/ipykernel_18051/802618658.py(37)forward()
     35         **kwargs
     36     ):
---> 37         student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     38                                    lbl2data_idx=lbl2data_idx, **kwargs)
     39 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(40)forward()
     38                                    lbl2data_idx=lbl2data_idx, **kwargs)
     39 
---> 40         loss = None
     41         if lbl2data_idx is not None and student_o.loss is not None:
     42             teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)



ipdb>  student_o.loss


tensor(0.0890, grad_fn=<AddBackward0>)


ipdb>  n


> /tmp/ipykernel_18051/802618658.py(41)forward()
     39 
     40         loss = None
---> 41         if lbl2data_idx is not None and student_o.loss is not None:
     42             teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)
     43 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(42)forward()
     40         loss = None
     41         if lbl2data_idx is not None and student_o.loss is not None:
---> 42             teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)
     43 
     44             tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(44)forward()
     42             teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)
     43 
---> 44             tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
     45                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     46 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(45)forward()
     43 
     44             tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
---> 45                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     46 
     47             sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(44)forward()
     42             teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)
     43 
---> 44             tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
     45                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     46 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(45)forward()
     43 
     44             tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
---> 45                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     46 
     47             sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(44)forward()
     42             teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)
     43 
---> 44             tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
     45                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     46 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(47)forward()
     45                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     46 
---> 47             sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
     48                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     49 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(48)forward()
     46 
     47             sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
---> 48                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     49 
     50             dm_loss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_fused_repr)



ipdb>  


> /tmp/ipykernel_18051/802618658.py(47)forward()
     45                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     46 
---> 47             sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
     48                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     49 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(48)forward()
     46 
     47             sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
---> 48                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     49 
     50             dm_loss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_fused_repr)



ipdb>  


> /tmp/ipykernel_18051/802618658.py(47)forward()
     45                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     46 
---> 47             sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
     48                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     49 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(50)forward()
     48                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     49 
---> 50             dm_loss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_fused_repr)
     51             lm_loss = self.mse_loss_fn(teacher_o.lbl2data_repr, student_o.lbl2data_repr)
     52 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(51)forward()
     49 
     50             dm_loss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_fused_repr)
---> 51             lm_loss = self.mse_loss_fn(teacher_o.lbl2data_repr, student_o.lbl2data_repr)
     52 
     53             ws = self.loss_weights.sample(lbl2data_idx.device)



ipdb>  


> /tmp/ipykernel_18051/802618658.py(53)forward()
     51             lm_loss = self.mse_loss_fn(teacher_o.lbl2data_repr, student_o.lbl2data_repr)
     52 
---> 53             ws = self.loss_weights.sample(lbl2data_idx.device)
     54 
     55             if self.training:



ipdb>  


> /tmp/ipykernel_18051/802618658.py(55)forward()
     53             ws = self.loss_weights.sample(lbl2data_idx.device)
     54 
---> 55             if self.training:
     56                 self.loss_weights.step(student_o.data_fused_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx,
     57                                        kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'])



ipdb>  ws


tensor([0.0500, 0.0845, 0.0500, 0.0500])


ipdb>  self.loss_weight.std


*** AttributeError: 'DTL011' object has no attribute 'loss_weight'


ipdb>  self.loss_weights.std


Parameter containing:
tensor([0.1000, 0.1000, 0.1000, 0.1000])


ipdb>  self.loss_weights.mu


Parameter containing:
tensor([0.0500, 0.0500, 0.0500, 0.0500], requires_grad=True)


ipdb>  self.loss_weights.w


tensor([0.0500, 0.0845, 0.0500, 0.0500])


ipdb>  n


> /tmp/ipykernel_18051/802618658.py(56)forward()
     54 
     55             if self.training:
---> 56                 self.loss_weights.step(student_o.data_fused_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx,
     57                                        kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'])
     58 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(57)forward()
     55             if self.training:
     56                 self.loss_weights.step(student_o.data_fused_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx,
---> 57                                        kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'])
     58 
     59             loss = student_o.loss + ws[0] * tdsl_loss + ws[1] * sdtl_loss + ws[2] * dm_loss + ws[3] * lm_loss



ipdb>  


> /tmp/ipykernel_18051/802618658.py(56)forward()
     54 
     55             if self.training:
---> 56                 self.loss_weights.step(student_o.data_fused_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx,
     57                                        kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'])
     58 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(59)forward()
     57                                        kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'])
     58 
---> 59             loss = student_o.loss + ws[0] * tdsl_loss + ws[1] * sdtl_loss + ws[2] * dm_loss + ws[3] * lm_loss
     60 
     61 



ipdb>  len(ws)


4


ipdb>  n


> /tmp/ipykernel_18051/802618658.py(62)forward()
     60 
     61 
---> 62         return RADOutput(
     63             loss=loss,
     64 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(63)forward()
     61 
     62         return RADOutput(
---> 63             loss=loss,
     64 
     65             data_repr=student_o.data_repr,



ipdb>  loss


tensor(0.0924, grad_fn=<AddBackward0>)


ipdb>  n


> /tmp/ipykernel_18051/802618658.py(65)forward()
     63             loss=loss,
     64 
---> 65             data_repr=student_o.data_repr,
     66             data_fused_repr=student_o.data_fused_repr,
     67 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(66)forward()
     64 
     65             data_repr=student_o.data_repr,
---> 66             data_fused_repr=student_o.data_fused_repr,
     67 
     68             lbl2data_repr=student_o.lbl2data_repr,



ipdb>  


> /tmp/ipykernel_18051/802618658.py(68)forward()
     66             data_fused_repr=student_o.data_fused_repr,
     67 
---> 68             lbl2data_repr=student_o.lbl2data_repr,
     69             lbl2data_fused_repr=student_o.lbl2data_fused_repr,
     70         )



ipdb>  


> /tmp/ipykernel_18051/802618658.py(69)forward()
     67 
     68             lbl2data_repr=student_o.lbl2data_repr,
---> 69             lbl2data_fused_repr=student_o.lbl2data_fused_repr,
     70         )
     71 



ipdb>  


> /tmp/ipykernel_18051/802618658.py(62)forward()
     60 
     61 
---> 62         return RADOutput(
     63             loss=loss,
     64 



ipdb>  


--Return--
RADOutput(los...sed_repr=None)
> /tmp/ipykernel_18051/802618658.py(62)forward()
     60 
     61 
---> 62         return RADOutput(
     63             loss=loss,
     64 



ipdb>  c


    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



In [None]:
o.loss

tensor(0.0607, device='cuda:0', grad_fn=<AddBackward0>)

### `DTL012`

In [None]:
#| export
class DTL012(DistilBertPreTrainedModel):
    """Student/teacher distillation wrapper with a learned student-to-teacher projection.

    The student's own loss is combined with two extra terms computed against the
    teacher's `data_repr`:

    * a `MultiTriplet` loss between the teacher data representation and the
      projected, L2-normalised student label representation, and
    * an MSE loss between the teacher data representation and the projected,
      L2-normalised student fused data representation.

    Parameters
    ----------
    config : model configuration; `config.dim` is the input width of the projection.
    m_student : student module; its output must expose `loss`, `data_repr`,
        `data_fused_repr`, `lbl2data_repr` and `lbl2data_fused_repr`.
    m_teacher : teacher module, called as `m_teacher(data_idx=...)`; must expose
        `data_repr.embedding_dim` (presumably an embedding table - TODO confirm).
    bsz, tn_targ, margin, tau, apply_softmax, n_negatives : forwarded unchanged to
        `MultiTriplet` (constructed with `reduce='mean'`).
    teacher_data_student_label_loss_weight : weight of the triplet term.
    data_mse_loss_weight : weight of the MSE term.
    """
    use_representation,use_generation = True,False
    _tied_weights_keys = ["m_student.encoder.distilbert"]

    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        bsz:Optional[int]=None,
        tn_targ:Optional[int]=None,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        teacher_data_student_label_loss_weight:Optional[float]=1.0,
        data_mse_loss_weight:Optional[float]=0.1,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        store_attr('m_student,m_teacher')
        store_attr('teacher_data_student_label_loss_weight,data_mse_loss_weight')
        self.mse_loss_fn = nn.MSELoss()
        self.rep_loss_fn = MultiTriplet(bsz=bsz, tn_targ=tn_targ, margin=margin, n_negatives=n_negatives, tau=tau, 
                                        apply_softmax=apply_softmax, reduce='mean')
        # Maps student representations (config.dim) into the teacher's embedding space.
        self.transform = nn.Linear(config.dim, m_teacher.data_repr.embedding_dim)

    def init_transform(self, embed:Optional[torch.Tensor]=None):
        """Initialise the weight of the student-to-teacher projection.

        With `embed=None` the weight becomes a (possibly rectangular) identity
        matrix; otherwise `embed` is used and must have shape
        `(out_features, in_features)`.

        The values are copied into the existing parameter, so the parameter keeps
        its device and dtype and never shares storage with the caller's tensor.
        The bias is left untouched.
        NOTE(review): an identity initialisation usually goes with a zero bias -
        confirm whether the bias should be reset here as well.

        Raises
        ------
        ValueError : if `embed` does not match the shape of the projection weight.
        """
        weight = self.transform.weight
        if embed is None:
            init = torch.eye(self.transform.out_features, self.transform.in_features,
                             dtype=weight.dtype, device=weight.device)
        else:
            if self.transform.in_features != embed.shape[1] or self.transform.out_features != embed.shape[0]:
                raise ValueError(f'Shape mismatch, input embedding: {embed.shape[0]}X{embed.shape[1]}')
            init = embed
        # copy_ (rather than rebinding `.data`) keeps the parameter on the module's device/dtype.
        with torch.no_grad():
            weight.copy_(init)

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student and, when it reports a loss, add the distillation terms.

        `kwargs` are forwarded to the student. When a loss is computed they must
        contain `lbl2data_data2ptr`, `lbl2data_idx`, `plbl2data_data2ptr` and
        `plbl2data_idx` (a missing key raises `KeyError`).

        Returns
        -------
        RADOutput carrying the combined loss (`None` if the student produced no
        loss) and the student's un-projected representations.
        """
        student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, **kwargs)

        loss = None
        if student_o.loss is not None:
            teacher_o = self.m_teacher(data_idx=data_idx)

            # Project student representations into the teacher space, then L2-normalise along dim 1.
            lbl2data_repr = F.normalize(self.transform(student_o.lbl2data_repr), dim=1)
            data_fused_repr = F.normalize(self.transform(student_o.data_fused_repr), dim=1)

            # Teacher data representation against projected student label representation.
            tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, lbl2data_repr, kwargs['lbl2data_data2ptr'], kwargs['lbl2data_idx'], 
                                         kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)

            # Teacher data representation against projected student fused data representation.
            dm_loss = self.mse_loss_fn(teacher_o.data_repr, data_fused_repr)

            # Build the total out-of-place: `loss += ...` on `student_o.loss` would mutate the
            # student's output tensor in place. Evaluation order matches the original accumulation.
            loss = (
                student_o.loss
                + self.teacher_data_student_label_loss_weight * tdsl_loss
                + self.data_mse_loss_weight * dm_loss
            )

        return RADOutput(
            loss=loss,

            data_repr=student_o.data_repr,
            data_fused_repr=student_o.data_fused_repr,

            lbl2data_repr=student_o.lbl2data_repr,
            lbl2data_fused_repr=student_o.lbl2data_fused_repr,
        )
        

#### Example

In [None]:
# Number of link-metadata entries; used for both the model config and the embedding table below.
n_lnk_meta = block.train.dset.meta['lnk_meta'].n_meta

# Keyword arguments for the student, grouped by the part of the model they configure.
student_kwargs = dict(
    # main representation loss
    batch_size=1000, num_batch_labels=5000,
    margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,

    # metadata prefixes
    data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None,
    data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,

    # metadata table
    num_metadata=n_lnk_meta, resize_length=5000,

    # calibration loss
    calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False,
    calib_loss_weight=0.1, use_calib_loss=True,

    # remaining loss switches / weights
    use_query_loss=True,
    meta_loss_weight=0.0,
    fusion_loss_weight=0.1, use_fusion_loss=False,

    use_encoder_parallel=False,
)

m_student = OAK001.from_pretrained('sentence-transformers/msmarco-distilbert-cos-v5', **student_kwargs)

# Initialise the heads that are not part of the pretrained checkpoint.
m_student.init_retrieval_head()
m_student.init_cross_head()

# Start the metadata embeddings from an all-zero table.
zero_meta_embeddings = torch.zeros(n_lnk_meta, m_student.config.dim)
m_student.encoder.set_meta_embeddings(zero_meta_embeddings)

Some weights of OAK001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-cos-v5 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enco

In [None]:
# Wrap the student and the teacher in the distillation model.
model = DTL012(
    DistilBertConfig(),
    m_student=m_student,
    m_teacher=m_teacher,
    bsz=1024,
    margin=0.3,
    tau=0.1,
    n_negatives=10,
    apply_softmax=True,
    teacher_data_student_label_loss_weight=1.0,
    data_mse_loss_weight=0.1,
)

In [None]:
# Batch fields handed to `prepare_batch` via `m_args`: label fields first, then link-metadata fields.
batch_fields = [
    'lbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids',
    'lbl2data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx',
    'lnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids',
    'lnk2data_attention_mask', 'plnk2data_data2ptr', 'plnk2data_idx',
]

# Take the first training batch and prepare it for the model.
batch = next(iter(block.train.dl))
b = prepare_batch(model, batch, m_args=batch_fields)

In [None]:
# Use the GPU when one is available; fall back to CPU so the example also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
m,b = model.to(device), b.to(device)

In [None]:
# Forward pass: the prepared batch `b` is unpacked into keyword arguments of the model.
o = m(**b)

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
# Display the loss stored on the model output `o`.
o.loss

tensor(0.1215, device='cuda:0', grad_fn=<AddBackward0>)