In [1]:
#| default_exp models.distillation

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
#| export
import torch, numpy as np, torch.nn.functional as F, torch.nn as nn
from typing import Optional
from dataclasses import dataclass
from types import MethodType

from xcai.core import store_attr
from xcai.losses import Cosine, MultiTriplet, MarginMSEWithNegatives, MultiTripletWithNegatives
from xcai.models.PPP0XX import XCModelOutput
from xcai.models.oak import OAK001
from xcai.models.radga import RADOutput
from xcai.bandits import *

from transformers import DistilBertPreTrainedModel,DistilBertConfig
from transformers.utils.generic import ModelOutput

## Setup

In [4]:
from xcai.main import *
from xcai.basics import *
from xcai.models.oak import OAK003
from xcai.models.PPP0XX import DBT023

In [5]:
data_dir = '/Users/suchith720/Projects/data/'

config_key = "data_meta"
config_dir = "/Users/suchith720/Projects/mogicX/configs"
pkl_dir = f"{data_dir}/processed/mogicX"
pkl_file = get_pkl_file(pkl_dir, 'wikiseealsotitles_data-oak-for-msmarco-with-hard-negatives-test_distilbert-base-uncased', 
                        True, False, False)

In [6]:
config_file = f"{config_dir}/39_oak-for-msmarco-with-hard-negatives_test.json"

In [7]:
pkl_file

'/Users/suchith720/Projects/data//processed/mogicX/wikiseealsotitles_data-oak-for-msmarco-with-hard-negatives-test_distilbert-base-uncased_sxc.joblib'

In [8]:
block = build_block(pkl_file, config_file, True, config_key=config_key, only_test=False, main_oversample=True, 
                    meta_oversample={"cat_meta":False, "lnk_meta":True}, n_slbl_samples=5, 
                    n_sdata_meta_samples={"cat_meta":2, "lnk_meta":4}, do_build=False, 
                    train_meta_topk={"lnk_meta":10}, test_meta_topk={"lnk_meta":10}, return_scores=True)

In [9]:
block.train.dset.meta['neg_meta'] = SMetaXCDataset(prefix='neg', data_meta=block.train.dset.meta['cat_meta'].data_meta, 
                                                   lbl_meta=block.train.dset.meta['cat_meta'].lbl_meta, 
                                                   meta_info=block.train.dset.meta['cat_meta'].meta_info, 
                                                   return_scores=True)

## Teacher

### Helper

In [10]:
#| export
@dataclass
class TCHOutput(ModelOutput):
    data_repr: Optional[torch.FloatTensor] = None
    lbl2data_repr: Optional[torch.FloatTensor] = None
    neg2data_repr: Optional[torch.FloatTensor] = None
    

### Configuration

In [12]:
#| export
class TCHConfig(DistilBertConfig):

    def __init__(
        self,
        n_data:Optional[int]=None,
        n_lbl:Optional[int]=None,
        n_neg:Optional[int]=None,
        embed_dim:Optional[int]=None,
        normalize:Optional[bool]=True,
        **kwargs,
    ):
        self.n_data, self.n_lbl, self.n_neg, self.embed_dim, self.normalize = n_data, n_lbl, n_neg, embed_dim, normalize
        super().__init__(**kwargs)
        

### `TCH001`

In [23]:
#| export
class TCH001(DistilBertPreTrainedModel):

    def __init__(self, config:TCHConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.data_repr = nn.Embedding(config.n_data, config.dim)
        self.lbl_repr = nn.Embedding(config.n_lbl, config.dim)
        self.register_buffer("neg2lbl_idx", None if config.n_neg is None else torch.arange(config.n_neg), persistent=True)

    @torch.no_grad()
    def set_neg2lbl_idx_mapping(self, neg2lbl_idx:torch.Tensor):
        assert neg2lbl_idx.shape[0] == self.neg2lbl_idx.shape[0], f"Shape mismatch, `neg2lbl_idx` should have {self.neg2lbl_idx.shape[0]} elements."
        self.neg2lbl_idx.copy_(neg2lbl_idx)

    def get_lbl_embeddings(self):
        return self.lbl_repr.weight

    def get_data_embeddings(self):
        return self.data_repr.weight

    @torch.no_grad()
    def init_embeddings(self, data_repr:torch.Tensor, lbl_repr:torch.Tensor):
        self.data_repr.weight.copy_(data_repr)
        self.lbl_repr.weight.copy_(lbl_repr)

    def freeze_embeddings(self):
        self.data_repr.requires_grad_(False)
        self.lbl_repr.requires_grad_(False)

    def freeze_data_embeddings(self):
        self.data_repr.requires_grad_(False)

    def unfreeze_embeddings(self):
        self.data_repr.requires_grad_(True)
        self.lbl_repr.requires_grad_(True)

    def forward(
        self,
        data_idx:torch.Tensor,
        lbl2data_idx:torch.Tensor,
        neg2data_idx:Optional[torch.Tensor]=None,
        **kwargs,
    ):
        return TCHOutput(
            data_repr=self.data_repr(data_idx),
            lbl2data_repr= self.lbl_repr(lbl2data_idx),
            neg2data_repr=None if neg2data_idx is None else self.lbl_repr(self.neg2lbl_idx[neg2data_idx]) 
        )
        

#### Example

In [24]:
config = TCHConfig(n_data=block.train.dset.n_data, n_lbl=block.n_lbl)

In [25]:
model = TCH001(config)

In [26]:
model.neg2lbl_idx

In [11]:
data_repr, lbl_repr = torch.randn(block.train.dset.n_data, 768), torch.randn(block.n_lbl, 768)

In [12]:
data_repr.shape, lbl_repr.shape

(torch.Size([693082, 768]), torch.Size([312330, 768]))

In [14]:
model.init_embeddings(data_repr, lbl_repr)

In [15]:
batch = next(iter(block.train.dl))
b = prepare_batch(model, batch)

In [16]:
o = model(**b)

In [17]:
o.data_repr.shape, o.lbl2data_repr.shape

(torch.Size([1, 768]), torch.Size([5, 768]))

### `TCH002`

In [12]:
#| export
class TCH002(DistilBertPreTrainedModel):

    def __init__(self, config:TCHConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.data_repr = nn.Embedding(config.n_data, config.dim)
        self.lbl_repr = nn.Embedding(config.n_lbl, config.dim)
        self.lbl_embeddings = nn.Embedding(config.n_lbl, config.dim)

    def get_lbl_embeddings(self):
        return self.lbl_repr.weight + self.lbl_embeddings.weight

    def get_data_embeddings(self):
        return self.data_repr.weight

    @torch.no_grad()
    def init_representations(self, data_repr:torch.Tensor, lbl_repr:torch.Tensor):
        self.data_repr.weight.copy_(data_repr)
        self.lbl_repr.weight.copy_(lbl_repr)

    @torch.no_grad()
    def init_lbl_embeddings(self):
        nn.init.zeros_(self.lbl_embeddings.weight)

    def freeze_representations(self):
        self.data_repr.requires_grad_(False)
        self.lbl_repr.requires_grad_(False)

    def unfreeze_representations(self):
        self.data_repr.requires_grad_(True)
        self.lbl_repr.requires_grad_(True)

    def forward(
        self,
        data_idx:torch.Tensor,
        lbl2data_idx:torch.Tensor,
        neg2data_idx:Optional[torch.Tensor]=None,
        **kwargs,
    ):
        data_repr = self.data_repr(data_idx)
        lbl2data_repr = self.lbl_repr(lbl2data_idx) + self.lbl_embeddings(lbl2data_idx)
        neg2data_repr = None if neg2data_idx is None else self.lbl_repr(neg2data_idx) + self.lbl_embeddings(neg2data_idx)
        return TCHOutput(
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
            neg2data_repr=neg2data_repr,
        )
        

#### Example

In [42]:
m_teacher = TCH002(config)

In [43]:
m_teacher.data_repr.weight.requires_grad, m_teacher.lbl_repr.weight.requires_grad

(True, True)

In [44]:
m_teacher.freeze_representations()
m_teacher.init_lbl_embeddings()

In [45]:
m_teacher.data_repr.weight.requires_grad,m_teacher.lbl_repr.weight.requires_grad, m_teacher.lbl_embeddings.weight.requires_grad

(False, False, True)

In [46]:
m_teacher.lbl_embeddings.weight

Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)

### `TCH003`

In [59]:
#| export
class TCH003(DistilBertPreTrainedModel):

    def __init__(self, config:TCHConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.data_repr = nn.Embedding(config.n_data, config.dim)

    def get_data_embeddings(self):
        return self.data_repr.weight

    @torch.no_grad()
    def init_embeddings(self, data_repr:torch.Tensor):
        self.data_repr.weight.copy_(data_repr)

    def freeze_embeddings(self):
        self.data_repr.requires_grad_(False)

    def unfreeze_representations(self):
        self.data_repr.requires_grad_(True)

    def forward(
        self,
        data_idx:torch.Tensor,
        **kwargs,
    ):
        data_repr = self.data_repr(data_idx)
        data_repr = F.normalize(data_repr, dim=1) if self.config.normalize else data_repr
        return TCHOutput(
            data_repr=data_repr,
        )
        

#### Example

In [33]:
config = TCHConfig(n_data=1000, embed_dim=4096, normalize=False)
model = TCH003(config)

In [34]:
data_repr, data_idx = torch.randn(1000, 768), torch.randint(0, 1000, size=(10,))

In [35]:
model.init_embeddings(data_repr)
model.freeze_embeddings()

In [36]:
o = model(data_idx)

In [37]:
o.data_repr.norm(dim=1)

tensor([26.6562, 28.0896, 28.5650, 27.4916, 27.0713, 27.2589, 28.2481, 27.8620,
        27.5937, 27.8447])

### `TCH004`

In [11]:
#| export
class TCH004(DistilBertPreTrainedModel):

    def __init__(self, config:TCHConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.data_repr = nn.Embedding(config.n_data, config.embed_dim)
        self.transform = nn.Linear(config.embed_dim, config.dim)

    def get_data_embeddings(self):
        return self.data_repr.weight

    @torch.no_grad()
    def init_embeddings(self, data_repr:torch.Tensor):
        self.data_repr.weight.copy_(data_repr)

    def freeze_embeddings(self):
        self.data_repr.requires_grad_(False)

    def unfreeze_representations(self):
        self.data_repr.requires_grad_(True)

    @torch.no_grad()
    def init_transform(self, embed:Optional[torch.Tensor]=None):
        if embed is None: nn.init.eye_(self.transform.weight)
        else: self.transform.weight.copy_(embed)
        nn.init.zeros_(self.transform.bias)

    def forward(
        self,
        data_idx:torch.Tensor,
        **kwargs,
    ):
        data_repr = self.transform(self.data_repr(data_idx))
        data_repr = F.normalize(data_repr, dim=1) if self.config.normalize else data_repr
        return TCHOutput(
            data_repr=data_repr,
        )
        

#### Example

In [51]:
config = TCHConfig(n_data=1000, embed_dim=4096, normalize=False)
model = TCH004(config)

In [52]:
data_repr, data_idx = torch.randn(1000, 4096), torch.randint(0, 1000, size=(10,))

In [53]:
model.init_embeddings(data_repr)
model.init_transform(torch.eye(config.dim, config.embed_dim))

In [54]:
o = model(data_idx)

In [55]:
o.data_repr.norm(dim=1)

tensor([28.8158, 28.0094, 26.3769, 27.1537, 28.1371, 28.5593, 27.6033, 28.7885,
        28.6897, 27.2464], grad_fn=<LinalgVectorNormBackward0>)

## Distillation

### Helper

In [12]:
#| export
@dataclass
class DTLOutput(ModelOutput):
    loss:Optional[torch.FloatTensor]=None
    data_repr: Optional[torch.FloatTensor] = None
    data_fused_repr:Optional[torch.FloatTensor] = None
    lbl2data_repr: Optional[torch.FloatTensor] = None
    lbl2data_fused_repr:Optional[torch.FloatTensor] = None
    

### Configuration

In [13]:
#| export
class DTLConfig(DistilBertConfig):

    def __init__(
        self,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        teacher_data_student_label_loss_weight:Optional[float]=1.0,
        student_data_teacher_label_loss_weight:Optional[float]=1.0,
        data_mse_loss_weight:Optional[float]=0.1,
        label_mse_loss_weight:Optional[float]=0.1,
        teacher_data_repr_name:Optional[str]='data_repr',
        student_data_repr_name:Optional[str]='data_fused_repr',
        teacher_lbl2data_repr_name:Optional[str]='lbl2data_repr',
        student_lbl2data_repr_name:Optional[str]='lbl2data_repr',
        teacher_neg2data_repr_name:Optional[str]='neg2data_repr',
        student_neg2data_repr_name:Optional[str]='neg2data_repr',
        bandit_learning_rate:Optional[float]=0.01,
        bandit_minimum_value:Optional[float]=0.1,
        bandit_collector:Optional[int]=20,
        **kwargs,
    ):
        self.margin, self.tau, self.apply_softmax, self.n_negatives = margin, tau, apply_softmax, n_negatives
        self.teacher_data_student_label_loss_weight = teacher_data_student_label_loss_weight
        self.student_data_teacher_label_loss_weight = student_data_teacher_label_loss_weight
        self.data_mse_loss_weight, self.label_mse_loss_weight = data_mse_loss_weight, label_mse_loss_weight
        self.teacher_data_repr_name, self.student_data_repr_name = teacher_data_repr_name, student_data_repr_name
        self.teacher_lbl2data_repr_name, self.student_lbl2data_repr_name = teacher_lbl2data_repr_name, student_lbl2data_repr_name
        self.teacher_neg2data_repr_name, self.student_neg2data_repr_name = teacher_neg2data_repr_name, student_neg2data_repr_name
        self.bandit_learning_rate, self.bandit_minimum_value = bandit_learning_rate, bandit_minimum_value
        self.bandit_collector = bandit_collector
        super().__init__(**kwargs)
        

### `DTL001`

In [36]:
#| export
class DTL001(DistilBertPreTrainedModel):
    use_representation,use_generation = True,False
    _tied_weights_keys = ["m_student.encoder.distilbert"]
    
    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        store_attr('m_student,m_teacher')
        self.mse_loss_fn = nn.MSELoss()
        self.rep_loss_fn = MultiTriplet(margin=config.margin, n_negatives=config.n_negatives, tau=config.tau, 
                                        apply_softmax=config.apply_softmax, reduce='mean')

        if hasattr(m_student, 'get_label_representation'):
            def get_label_representation(
                self,
                data_idx:Optional[torch.Tensor]=None,
                data_input_ids:Optional[torch.Tensor]=None,
                data_attention_mask:Optional[torch.Tensor]=None,
                **kwargs
            ):
                return self.m_student.get_label_representation(data_idx, data_input_ids, data_attention_mask, **kwargs)
            self.get_label_representation = MethodType(get_label_representation, self)

    def combine_losses(self, student_loss:float, tdsl_loss:float, sdtl_loss:float, dm_loss:float, lm_loss:float, 
                       student_data_repr:Optional[torch.Tensor]=None, student_lbl2data_repr:Optional[torch.Tensor]=None, **kwargs):
        loss = student_loss
        loss += self.config.teacher_data_student_label_loss_weight * tdsl_loss
        loss += self.config.student_data_teacher_label_loss_weight * sdtl_loss
        loss += self.config.data_mse_loss_weight * dm_loss + self.config.label_mse_loss_weight * lm_loss
        return loss
        
    def forward(
        self,
        data_idx:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        student_o = self.m_student(lbl2data_idx=lbl2data_idx, **kwargs)
        student_data_repr = getattr(student_o, self.config.student_data_repr_name, None)
        student_lbl2data_repr = getattr(student_o, self.config.student_lbl2data_repr_name, None)

        loss = None
        if student_o.loss is not None:
            with torch.no_grad(): teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)
            teacher_data_repr = getattr(student_o, self.config.teacher_data_repr_name, None)
            teacher_lbl2data_repr = getattr(student_o, self.config.teacher_lbl2data_repr_name, None)

            tdsl_loss = 0.0
            if teacher_data_repr is not None and student_lbl2data_repr is not None and self.config.teacher_data_student_label_loss_weight > 0:
                tdsl_loss = self.rep_loss_fn(teacher_data_repr, student_lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
                                             kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)

            sdtl_loss = 0.0
            if student_data_repr is not None and teacher_lbl2data_repr is not None and self.config.student_data_teacher_label_loss_weight > 0:
                sdtl_loss = self.rep_loss_fn(student_data_repr, teacher_lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
                                             kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)

            dm_loss = 0.0
            if teacher_data_repr is not None and student_data_repr is not None and self.config.data_mse_loss_weight > 0:
                dm_loss = self.mse_loss_fn(teacher_data_repr, student_data_repr)

            lm_loss = 0.0
            if teacher_lbl2data_repr is not None and student_lbl2data_repr is not None and self.config.label_mse_loss_weight > 0:
                lm_loss = self.mse_loss_fn(teacher_lbl2data_repr, student_lbl2data_repr)

            loss = self.combine_losses(student_o.loss, tdsl_loss, sdtl_loss, dm_loss, lm_loss, 
                                       student_data_repr, student_lbl2data_repr, lbl2data_idx=lbl2data_idx, **kwargs)

        return DTLOutput(
            loss=loss,
            data_repr=getattr(student_o, 'data_repr', None),
            data_fused_repr=getattr(student_o, 'data_fused_repr', None),
            lbl2data_repr=getattr(student_o, 'lbl2data_repr', None),
            lbl2data_fused_repr=getattr(student_o, 'lbl2data_fused_repr', None),
        )
        

#### Example

In [56]:
config = TCHConfig(n_data=1000, embed_dim=4096, normalize=False)
m_teacher = TCH004(config)

In [57]:
mname = 'distilbert-base-uncased'
meta_name = 'lnk'

m_student = OAK003.from_pretrained(mname, margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                                   
                                   data_aug_meta_prefix=f'{meta_name}2data', lbl2data_aug_meta_prefix=None,
                               
                                   num_metadata=block.train.dset.meta[f'{meta_name}_meta'].n_meta,
                                   
                                   calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                                   calib_loss_weight=0.1, use_calib_loss=True,
                                   
                                   use_query_loss=True,
                                   
                                   use_encoder_parallel=True, normalize=True)


Some weights of OAK003 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.

In [58]:
config = DTLConfig(margin=0.3, tau=0.1, apply_softmax=False, n_negatives=5, 
                   teacher_data_student_label_loss_weight=1.0, student_data_teacher_label_loss_weight=1.0,
                   data_mse_loss_weight=0.1, label_mse_loss_weight=0.1,
                   teacher_data_repr_name='data_repr', student_data_repr_name='data_fused_repr',
                   teacher_label_repr_name='lbl2data_repr', student_label_repr_name='lbl2data_repr')

In [59]:
model = DTL001(config, m_student=m_student, m_teacher=m_teacher)

In [60]:
batch = block.train.one_batch(10)
b = prepare_batch(model, batch, m_args=['data_input_ids', 'data_attention_mask', 'lbl2data_data2ptr', 'lbl2data_input_ids', 
                                        'lbl2data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx',
                                        'lnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 
                                        'lnk2data_attention_mask', 'plnk2data_data2ptr', 'plnk2data_idx',
                                       ])

In [61]:
o = model(**b)

In [62]:
o.loss

tensor(0.6609, grad_fn=<AddBackward0>)

### `DTL002`

In [15]:
#| export
class DTL002(DTL001):
    use_representation,use_generation = True,False
    _tied_weights_keys = ["m_student.encoder.distilbert"]
    
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        rest_init = [self.config.teacher_data_student_label_loss_weight, self.config.student_data_teacher_label_loss_weight, 
                     self.config.data_mse_loss_weight, self.config.label_mse_loss_weight]
        self.loss_weights = RLLossWeightsCumuluative(num_samples=len(rest_init), reward_func=AccMiniBatch, lr=self.config.bandit_learning_rate, 
                                                     collector=self.config.bandit_collector, std=0.1, min=self.config.bandit_minimum_value,
                                                     rest_init=rest_init)

    def combine_losses(self, student_loss:float, tdsl_loss:float, sdtl_loss:float, dm_loss:float, lm_loss:float, 
                       student_data_repr:Optional[torch.Tensor]=None, student_lbl2data_repr:Optional[torch.Tensor]=None, **kwargs):
        ws = self.loss_weights.sample(kwargs['lbl2data_idx'].device)
        if self.training:
            self.loss_weights.step(student_data_repr, student_lbl2data_repr, kwargs['lbl2data_data2ptr'], 
                                   kwargs['lbl2data_idx'], kwargs['plbl2data_data2ptr'], 
                                   kwargs['plbl2data_idx'])
        loss = student_loss + ws[0]*tdsl_loss + ws[1]*sdtl_loss + ws[2]*dm_loss + ws[3]*lm_loss
        return loss
        

#### Example

In [16]:
config = TCHConfig(n_data=1000, embed_dim=4096, normalize=False)
m_teacher = TCH004(config)

In [17]:
mname = 'distilbert-base-uncased'
meta_name = 'lnk'

m_student = OAK003.from_pretrained(mname, margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                                   
                                   data_aug_meta_prefix=f'{meta_name}2data', lbl2data_aug_meta_prefix=None,
                               
                                   num_metadata=block.train.dset.meta[f'{meta_name}_meta'].n_meta,
                                   
                                   calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                                   calib_loss_weight=0.1, use_calib_loss=True,
                                   
                                   use_query_loss=True,
                                   
                                   use_encoder_parallel=True, normalize=True)


Some weights of OAK003 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.

In [18]:
config = DTLConfig(margin=0.3, tau=0.1, apply_softmax=False, n_negatives=5, 
                   teacher_data_student_label_loss_weight=1.0, student_data_teacher_label_loss_weight=1.0,
                   data_mse_loss_weight=0.1, label_mse_loss_weight=0.1,
                   teacher_data_repr_name='data_repr', student_data_repr_name='data_fused_repr',
                   teacher_label_repr_name='lbl2data_repr', student_label_repr_name='lbl2data_repr', 
                   bandit_learning_rate=0.01, bandit_minimum_value=0.1, bandit_collector=20,)


In [19]:
model = DTL001(config, m_student=m_student, m_teacher=m_teacher)

In [20]:
batch = block.train.one_batch(10)
b = prepare_batch(model, batch, m_args=['data_input_ids', 'data_attention_mask', 'lbl2data_data2ptr', 'lbl2data_input_ids', 
                                        'lbl2data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx',
                                        'lnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 
                                        'lnk2data_attention_mask', 'plnk2data_data2ptr', 'plnk2data_idx',
                                       ])

In [21]:
o = model(**b)

  torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device, size=size)


In [22]:
o.loss

tensor(0.6511, grad_fn=<AddBackward0>)

### `DTL003`

In [47]:
#| export
class DTL003(DTL001):
    use_representation,use_generation = True,False
    _tied_weights_keys = ["m_student.encoder.distilbert"]
    
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.rep_loss_fn = self.rep_loss_fn = MarginMSEWithNegatives()

    def forward(
        self,
        data_idx:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        neg2data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        student_o = self.m_student(lbl2data_idx=lbl2data_idx, **kwargs)
        student_data_repr = getattr(student_o, self.config.student_data_repr_name, None)
        student_lbl2data_repr = getattr(student_o, self.config.student_lbl2data_repr_name, None)
        student_neg2data_repr = getattr(student_o, self.config.student_neg2data_repr_name, None)

        loss = None
        if student_o.loss is not None:
            with torch.no_grad(): 
                teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx, neg2data_idx=neg2data_idx)
            teacher_data_repr = getattr(student_o, self.config.teacher_data_repr_name, None)
            teacher_lbl2data_repr = getattr(student_o, self.config.teacher_lbl2data_repr_name, None)
            teacher_neg2data_repr = getattr(student_o, self.config.teacher_neg2data_repr_name, None)

            tdsl_loss = 0.0
            if (
                teacher_data_repr is not None and student_lbl2data_repr is not None and 
                student_neg2data_repr is not None and self.config.teacher_data_student_label_loss_weight > 0
            ):
                tdsl_loss = self.rep_loss_fn(teacher_data_repr, student_lbl2data_repr, kwargs['lbl2data_scores'], 
                                             student_neg2data_repr, kwargs['neg2data_scores'], **kwargs)
            
            sdtl_loss = 0.0
            if (
                student_data_repr is not None and teacher_lbl2data_repr is not None and 
                teacher_neg2data_repr is not None and self.config.student_data_teacher_label_loss_weight > 0
            ):
                sdtl_loss = self.rep_loss_fn(student_data_repr, teacher_lbl2data_repr, kwargs['lbl2data_scores'], 
                                             teacher_neg2data_repr, kwargs['neg2data_scores'], **kwargs)
                
            dm_loss = 0.0
            if teacher_data_repr is not None and student_data_repr is not None and self.config.data_mse_loss_weight > 0:
                dm_loss = self.mse_loss_fn(teacher_data_repr, student_data_repr)

            lm_loss = 0.0
            if teacher_lbl2data_repr is not None and student_lbl2data_repr is not None and self.config.label_mse_loss_weight > 0:
                lm_loss += self.mse_loss_fn(teacher_lbl2data_repr, student_lbl2data_repr)
                
            if teacher_neg2data_repr is not None and student_neg2data_repr is not None and self.config.label_mse_loss_weight > 0:
                lm_loss += self.mse_loss_fn(teacher_neg2data_repr, student_neg2data_repr)

            loss = self.combine_losses(student_o.loss, tdsl_loss, sdtl_loss, dm_loss, lm_loss, 
                                       student_data_repr, student_lbl2data_repr, lbl2data_idx=lbl2data_idx, **kwargs)
            
        return DTLOutput(
            loss=loss,
            data_repr=getattr(student_o, 'data_repr', None),
            data_fused_repr=getattr(student_o, 'data_fused_repr', None),
            lbl2data_repr=getattr(student_o, 'lbl2data_repr', None),
            lbl2data_fused_repr=getattr(student_o, 'lbl2data_fused_repr', None),
        )
        

#### Example

In [48]:
config = TCHConfig(n_data=1000, embed_dim=4096, normalize=False)
m_teacher = TCH004(config)

In [30]:
mname = 'distilbert-base-uncased'
meta_name = 'lnk'

m_student = DBT023.from_pretrained(mname, use_encoder_parallel=True, normalize=False, use_layer_norm=False)

Some weights of DBT023 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.projector.bias', 'encoder.projector.weight', 'encoder.transform.bias', 'encoder.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [50]:
config = DTLConfig(margin=0.3, tau=0.1, apply_softmax=False, n_negatives=5, 
                   teacher_data_student_label_loss_weight=1.0, student_data_teacher_label_loss_weight=1.0,
                   data_mse_loss_weight=0.1, label_mse_loss_weight=0.1,
                   teacher_data_repr_name='data_repr', student_data_repr_name='data_fused_repr',
                   teacher_lbl2data_repr_name='lbl2data_repr', student_lbl2data_repr_name='lbl2data_repr', 
                   teacher_neg2data_repr_name='neg2data_repr', student_neg2data_repr_name='neg2data_repr',
                   bandit_learning_rate=0.01, bandit_minimum_value=0.1, bandit_collector=20,)


In [51]:
model = DTL003(config, m_student=m_student, m_teacher=m_teacher)

In [73]:
batch = block.train.one_batch(10)
b = prepare_batch(model, batch, m_args=['data_input_ids', 'data_attention_mask', 'lbl2data_data2ptr', 'lbl2data_input_ids', 
                                        'lbl2data_attention_mask', 'lbl2data_scores', 'plbl2data_data2ptr', 'plbl2data_idx',
                                        'neg2data_data2ptr', 'neg2data_idx', 'neg2data_input_ids', 'neg2data_attention_mask', 
                                        'neg2data_scores', 'pneg2data_data2ptr', 'pneg2data_idx',
                                       ])

In [75]:
o = model(**b)

In [76]:
o.loss

tensor(2.4043, grad_fn=<AddBackward0>)

### `DTL004`

In [77]:
#| export
class DTL004(DTL001):
    use_representation,use_generation = True,False
    _tied_weights_keys = ["m_student.encoder.distilbert"]
    
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.rep_loss_fn = MultiTripletWithNegatives(margin=config.margin, n_negatives=config.n_negatives, 
                                                     tau=config.tau, apply_softmax=config.apply_softmax, 
                                                     reduce='mean')
        
    def forward(
        self,
        data_idx:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        neg2data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        student_o = self.m_student(lbl2data_idx=lbl2data_idx, **kwargs)
        student_data_repr = getattr(student_o, self.config.student_data_repr_name, None)
        student_lbl2data_repr = getattr(student_o, self.config.student_lbl2data_repr_name, None)
        student_neg2data_repr = getattr(student_o, self.config.student_neg2data_repr_name, None)

        loss = None
        if student_o.loss is not None:
            with torch.no_grad(): 
                teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx, neg2data_idx=neg2data_idx)
            teacher_data_repr = getattr(student_o, self.config.teacher_data_repr_name, None)
            teacher_lbl2data_repr = getattr(student_o, self.config.teacher_lbl2data_repr_name, None)
            teacher_neg2data_repr = getattr(student_o, self.config.teacher_neg2data_repr_name, None)

            tdsl_loss = 0.0
            if (
                teacher_data_repr is not None and student_lbl2data_repr is not None and 
                student_neg2data_repr is not None and self.config.teacher_data_student_label_loss_weight > 0
            ):
                tdsl_loss = self.rep_loss_fn(teacher_data_repr, pos_targ=student_lbl2data_repr, 
                                             n_pos=kwargs['lbl2data_data2ptr'], pos_idx=kwargs['lbl2data_idx'], 
                                             neg_targ=student_neg2data_repr, n_neg=kwargs['neg2data_data2ptr'], 
                                             neg_idx=kwargs['neg2data_idx'], n_ppos=kwargs['plbl2data_data2ptr'], 
                                             ppos_idx=kwargs['plbl2data_idx'], **kwargs)
                
            sdtl_loss = 0.0
            if (
                student_data_repr is not None and teacher_lbl2data_repr is not None and 
                teacher_neg2data_repr is not None and self.config.student_data_teacher_label_loss_weight > 0
            ):
                sdtl_loss = self.rep_loss_fn(student_data_repr, pos_targ=teacher_lbl2data_repr, 
                                             n_pos=kwargs['lbl2data_data2ptr'], pos_idx=kwargs['lbl2data_idx'], 
                                             neg_targ=teacher_neg2data_repr, n_neg=kwargs['neg2data_data2ptr'], 
                                             neg_idx=kwargs['neg2data_idx'], n_ppos=kwargs['plbl2data_data2ptr'], 
                                             ppos_idx=kwargs['plbl2data_idx'], **kwargs)
                
            dm_loss = 0.0
            if teacher_data_repr is not None and student_data_repr is not None and self.config.data_mse_loss_weight > 0:
                dm_loss = self.mse_loss_fn(teacher_data_repr, student_data_repr)

            lm_loss = 0.0
            if teacher_lbl2data_repr is not None and student_lbl2data_repr is not None and self.config.label_mse_loss_weight > 0:
                lm_loss += self.mse_loss_fn(teacher_lbl2data_repr, student_lbl2data_repr)
                
            if teacher_neg2data_repr is not None and student_neg2data_repr is not None and self.config.label_mse_loss_weight > 0:
                lm_loss += self.mse_loss_fn(teacher_neg2data_repr, student_neg2data_repr)

            loss = self.combine_losses(student_o.loss, tdsl_loss, sdtl_loss, dm_loss, lm_loss, 
                                       student_data_repr, student_lbl2data_repr, lbl2data_idx=lbl2data_idx, **kwargs)

        return DTLOutput(
            loss=loss,
            data_repr=getattr(student_o, 'data_repr', None),
            data_fused_repr=getattr(student_o, 'data_fused_repr', None),
            lbl2data_repr=getattr(student_o, 'lbl2data_repr', None),
            lbl2data_fused_repr=getattr(student_o, 'lbl2data_fused_repr', None),
        )
        

#### Example

In [78]:
config = TCHConfig(n_data=1000, embed_dim=4096, normalize=False)
m_teacher = TCH004(config)

In [79]:
mname = 'distilbert-base-uncased'
meta_name = 'lnk'

m_student = DBT023.from_pretrained(mname, use_encoder_parallel=True, normalize=False, use_layer_norm=False)

Some weights of DBT023 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.projector.bias', 'encoder.projector.weight', 'encoder.transform.bias', 'encoder.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [80]:
config = DTLConfig(margin=0.3, tau=0.1, apply_softmax=False, n_negatives=5, 
                   teacher_data_student_label_loss_weight=1.0, student_data_teacher_label_loss_weight=1.0,
                   data_mse_loss_weight=0.1, label_mse_loss_weight=0.1,
                   teacher_data_repr_name='data_repr', student_data_repr_name='data_fused_repr',
                   teacher_lbl2data_repr_name='lbl2data_repr', student_lbl2data_repr_name='lbl2data_repr', 
                   teacher_neg2data_repr_name='neg2data_repr', student_neg2data_repr_name='neg2data_repr',
                   bandit_learning_rate=0.01, bandit_minimum_value=0.1, bandit_collector=20,)


In [81]:
model = DTL004(config, m_student=m_student, m_teacher=m_teacher)

In [82]:
batch = block.train.one_batch(10)
b = prepare_batch(model, batch, m_args=['data_input_ids', 'data_attention_mask', 'lbl2data_data2ptr', 'lbl2data_input_ids', 
                                        'lbl2data_attention_mask', 'lbl2data_scores', 'plbl2data_data2ptr', 'plbl2data_idx',
                                        'neg2data_data2ptr', 'neg2data_idx', 'neg2data_input_ids', 'neg2data_attention_mask', 
                                        'neg2data_scores', 'pneg2data_data2ptr', 'pneg2data_idx',
                                       ])

In [83]:
o = model(**b)

In [84]:
o.loss

tensor(1.6274, grad_fn=<AddBackward0>)