In [None]:
#| default_exp models.distillation

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import torch, numpy as np
from typing import Optional
import torch.nn as nn
from dataclasses import dataclass

from xcai.core import store_attr
from xcai.losses import Cosine, MultiTriplet
from xcai.models.PPP0XX import XCModelOutput
from xcai.models.oak import OAK001
from xcai.models.radga import RADOutput
from transformers import DistilBertPreTrainedModel,DistilBertConfig
from transformers.utils.generic import ModelOutput

comet_ml is installed but `COMET_API_KEY` is not set.


In [None]:
import os,torch, pickle, numpy as np

from xcai.block import *
from xcai.basics import *
from xcai.models.PPP0XX import DBT010
from xcai.models.radga import RAD006

## Setup

In [None]:
# Absolute, machine-specific location of the pickled dataset block loaded below.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealso_data-metas_distilbert-base-uncased_rm_radga-aug-cat-hlk-block-128.pkl'

In [None]:
!ls {pkl_file}

/home/scai/phd/aiz218323/scratch/datasets/processed/wikiseealso_data-metas_distilbert-base-uncased_rm_radga-aug-cat-hlk-block-128.pkl


In [None]:
# NOTE(review): `pickle.load` can execute arbitrary code; only load pickles from a trusted source.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [None]:
# Expose the `*_aug_cat` entries of `data_info` under the `aug_input_ids` / `aug_attention_mask`
# keys for both the train and the test split (plain assignment, no copy is made).
block.train.dset.data.data_info['aug_input_ids'] = block.train.dset.data.data_info['input_ids_aug_cat']
block.train.dset.data.data_info['aug_attention_mask'] = block.train.dset.data.data_info['attention_mask_aug_cat']
block.test.dset.data.data_info['aug_input_ids'] = block.test.dset.data.data_info['input_ids_aug_cat']
block.test.dset.data.data_info['aug_attention_mask'] = block.test.dset.data.data_info['attention_mask_aug_cat']

In [None]:
# Overwrite the `input_ids` / `attention_mask` entries of `data_info` with the `*_aug_cat`
# versions for both splits.
# NOTE(review): the previous entries are replaced in place; re-load the pickle to get them back.
block.train.dset.data.data_info['input_ids'] = block.train.dset.data.data_info['input_ids_aug_cat']
block.train.dset.data.data_info['attention_mask'] = block.train.dset.data.data_info['attention_mask_aug_cat']
block.test.dset.data.data_info['input_ids'] = block.test.dset.data.data_info['input_ids_aug_cat']
block.test.dset.data.data_info['attention_mask'] = block.test.dset.data.data_info['attention_mask_aug_cat']

In [None]:
# Point `pkl_file` at a different pickle (the `radga-cat-linker` block).
# NOTE(review): `pkl_dir` ends with '/', so the f-string below produces a path containing '//'.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
pkl_file = f'{pkl_dir}/processed/wikiseealso_data-metas_distilbert-base-uncased_rm_radga-cat-linker.pkl'

In [None]:
# Re-bind `block` to the object stored in the current `pkl_file`; whatever `block` referred to
# before (including any `data_info` edits made on it) is no longer reachable through this name.
# NOTE(review): `pickle.load` can execute arbitrary code; only load pickles from a trusted source.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

## Teacher

### Base

In [None]:
from safetensors import safe_open

In [None]:
# Learner configuration used for the teacher below.
# Evaluation and checkpointing share the same 3000-step interval (`eval_steps` / `save_steps`),
# batches are 800 per device for both training and evaluation, and fp16 is enabled.
# NOTE(review): `output_dir` is an absolute, machine-specific path.
args = XCLearningArguments(
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/69-distillation-for-wikiseealso-1-4',
    logging_first_step=True,
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    representation_num_beams=200,
    representation_accumulation_steps=10,
    save_strategy="steps",
    evaluation_strategy="steps",
    eval_steps=3000,
    save_steps=3000,
    save_total_limit=5,
    num_train_epochs=300,
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    adam_epsilon=1e-6,
    warmup_steps=100,
    weight_decay=0.01,
    learning_rate=2e-4,
    group_by_cluster=True,
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    clustering_type='EXPO',
    minimum_cluster_size=2,
    maximum_cluster_size=1600,
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    use_encoder_parallel=True,
    max_grad_norm=None,
    fp16=True,
    label_names=['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask'],
)

In [None]:
# Resolve the checkpoint directory: `output_dir` is rebuilt from the basename of `model_output`,
# and `mname` appends the basename of whatever `get_best_model(output_dir)` returns
# (presumably the best checkpoint's path — confirm against its definition).
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4'
output_dir = f"/home/scai/phd/aiz218323/scratch/outputs/{os.path.basename(model_output)}"
mname = f'{output_dir}/{os.path.basename(get_best_model(output_dir))}'

# Instantiate the architecture from the public sentence-transformers checkpoint first ...
m_teacher = DBT010.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', bsz=800, tn_targ=5000, margin=0.3, tau=0.1, 
                                   n_negatives=10, apply_softmax=True, use_encoder_parallel=True)

# ... then read every tensor from the checkpoint's safetensors file into a plain dict.
model_weight_file,model_weights = f'{mname}/model.safetensors',{}
with safe_open(model_weight_file, framework="pt") as file:
    for k in file.keys(): model_weights[k] = file.get_tensor(k)

# `strict=False` tolerates keys that exist on only one side; the returned `_IncompatibleKeys`
# (displayed as the cell output) lists them.
m_teacher.load_state_dict(model_weights, strict=False)

Some weights of DBT010 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


_IncompatibleKeys(missing_keys=['distilbert.embeddings.word_embeddings.weight', 'distilbert.embeddings.position_embeddings.weight', 'distilbert.embeddings.LayerNorm.weight', 'distilbert.embeddings.LayerNorm.bias', 'distilbert.transformer.layer.0.attention.q_lin.weight', 'distilbert.transformer.layer.0.attention.q_lin.bias', 'distilbert.transformer.layer.0.attention.k_lin.weight', 'distilbert.transformer.layer.0.attention.k_lin.bias', 'distilbert.transformer.layer.0.attention.v_lin.weight', 'distilbert.transformer.layer.0.attention.v_lin.bias', 'distilbert.transformer.layer.0.attention.out_lin.weight', 'distilbert.transformer.layer.0.attention.out_lin.bias', 'distilbert.transformer.layer.0.sa_layer_norm.weight', 'distilbert.transformer.layer.0.sa_layer_norm.bias', 'distilbert.transformer.layer.0.ffn.lin1.weight', 'distilbert.transformer.layer.0.ffn.lin1.bias', 'distilbert.transformer.layer.0.ffn.lin2.weight', 'distilbert.transformer.layer.0.ffn.lin2.bias', 'distilbert.transformer.layer.

In [None]:
# Metric object handed to `XCLearner` as `compute_metrics` below; it is built from the label
# count, the test split's data-label filterer and the training data-label matrix (`prop`).
metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

In [None]:
# Bundle the teacher model with its arguments, datasets, collator and metric.
# The learner is used below for prediction and for extracting representations.
learn = XCLearner(
    model=m_teacher, 
    args=args,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
o = learn.predict(block.test.dset)

  0%|          | 0/196 [00:00<?, ?it/s]

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


  self._set_arrayXarray(i, j, x)


In [None]:
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,44.9849,28.9636,21.6018,13.4459,44.9849,44.4991,45.9046,48.3028,31.7676,34.6033,37.4802,42.9854,31.7676,35.671,38.1608,41.1199,54.4304,70.9628,74.1619,0.0177,289.8157,612.51,0.383


In [None]:
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,39.1787,25.1231,18.8719,11.9556,39.1787,38.7991,40.2252,42.7056,27.7771,30.1143,32.8007,38.2846,27.7771,31.0783,33.3646,36.2856,49.0181,66.5973,70.1792,0.0198,343.8718,516.224,0.323


### Query-label representation

In [None]:
# Run the teacher over `train_dataset.data_dset` (through the learner's test dataloader) and
# collect the model output attribute named `data_repr`.
dataset = learn.train_dataset.data_dset
dataloader = learn.get_test_dataloader(dataset)
data_repr = learn.get_representation(dataloader, representation_attribute='data_repr')

  0%|          | 0/434 [00:00<?, ?it/s]

In [None]:
# Same procedure for `train_dataset.lbl_dset`. `dataset` and `dataloader` are re-bound here.
# NOTE(review): `representation_attribute` is still 'data_repr' — presumably because the label
# dataset is fed through the model as ordinary inputs; confirm this is intended.
dataset = learn.train_dataset.lbl_dset
dataloader = learn.get_test_dataloader(dataset)
lbl_repr = learn.get_representation(dataloader, representation_attribute='data_repr')

  0%|          | 0/196 [00:00<?, ?it/s]

### Helper

In [None]:
#| export
@dataclass
class TCHOutput(ModelOutput):
    """Output container returned by the `TCH*` teacher models defined in this module.

    Both fields default to `None`; `TCH003.forward` fills only `data_repr`.
    """
    # Result of looking up the `data_idx` passed to the teacher's `forward`.
    data_repr: Optional[torch.FloatTensor] = None
    # Result of looking up the `lbl2data_idx` passed to the teacher's `forward`.
    lbl2data_repr: Optional[torch.FloatTensor] = None
    

### `TCH001`

In [None]:
#| export
class TCH001(DistilBertPreTrainedModel):
    """Teacher that serves stored representations from two embedding tables.

    `data_repr` has `n_data` rows and `lbl_repr` has `n_lbl` rows, both of width
    `config.dim`. `forward` only performs index lookups into these tables.
    """

    def __init__(self, config, n_data:int, n_lbl:int, **kwargs):
        super().__init__(config, **kwargs)
        store_attr('n_data,n_lbl')
        width = config.dim
        self.data_repr = nn.Embedding(self.n_data, width)
        self.lbl_repr = nn.Embedding(self.n_lbl, width)

    def get_lbl_embeddings(self):
        """Return the whole weight matrix of the label table."""
        return self.lbl_repr.weight

    def get_data_embeddings(self):
        """Return the whole weight matrix of the data table."""
        return self.data_repr.weight

    def init_embeddings(self, data_repr:torch.Tensor, lbl_repr:torch.Tensor):
        """Install the given tensors as the `.data` of the two tables (assigned, not copied)."""
        for table, values in ((self.data_repr, data_repr), (self.lbl_repr, lbl_repr)):
            table.weight.data = values

    def freeze_embeddings(self):
        """Disable gradients for both tables."""
        for table in (self.data_repr, self.lbl_repr):
            table.requires_grad_(False)

    def freeze_data_embeddings(self):
        """Disable gradients for the data table only; the label table is left untouched."""
        self.data_repr.requires_grad_(False)

    def unfreeze_embeddings(self):
        """Enable gradients for both tables."""
        for table in (self.data_repr, self.lbl_repr):
            table.requires_grad_(True)

    def forward(
        self,
        data_idx:torch.Tensor,
        lbl2data_idx:torch.Tensor,
    ):
        """Look up `data_idx` in the data table and `lbl2data_idx` in the label table."""
        data_rows = self.data_repr(data_idx)
        lbl_rows = self.lbl_repr(lbl2data_idx)
        return TCHOutput(data_repr=data_rows, lbl2data_repr=lbl_rows)
        

In [None]:
block.train.dset.n_data, block.n_lbl

(693082, 312330)

#### Example

In [None]:
# Build the lookup teacher with a default `DistilBertConfig`; table sizes come from the loaded block.
model = TCH001(DistilBertConfig(), n_data=block.train.dset.n_data, n_lbl=block.n_lbl)

In [None]:
model.init_embeddings(data_repr, lbl_repr)

In [None]:
# Take the first batch of the training dataloader and pass it through `prepare_batch` with the
# model (presumably to keep only the inputs `model.forward` accepts — confirm).
batch = next(iter(block.train.dl))
b = prepare_batch(model, batch)

In [None]:
o = model(**b)

In [None]:
o.data_repr.shape, o.lbl2data_repr.shape

(torch.Size([10, 768]), torch.Size([10, 768]))

In [None]:
model.save_pretrained(f'{model_output}/teacher')

In [None]:
# Re-load the lookup teacher saved under `<model_output>/teacher`. `n_data` and `n_lbl` are
# constructor arguments of `TCH001`, so they are passed again here.
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4'
m_teacher = TCH001.from_pretrained(f'{model_output}/teacher', n_data=block.train.dset.n_data, n_lbl=block.n_lbl)

In [None]:
m_teacher.freeze_data_embeddings()

In [None]:
m_teacher.data_repr.weight.requires_grad,m_teacher.lbl_repr.weight.requires_grad

(False, True)

### `TCH002`

In [None]:
#| export
class TCH002(DistilBertPreTrainedModel):
    """Lookup teacher whose label output is a stored representation plus a per-label offset.

    `data_repr` (`n_data` rows) and `lbl_repr` (`n_lbl` rows) hold the stored
    representations; `lbl_embeddings` (`n_lbl` rows) is added to `lbl_repr`
    whenever label representations are produced. All tables have width `config.dim`.
    """

    def __init__(self, config, n_data:int, n_lbl:int, **kwargs):
        super().__init__(config, **kwargs)
        store_attr('n_data,n_lbl')
        width = config.dim
        self.data_repr = nn.Embedding(self.n_data, width)
        self.lbl_repr = nn.Embedding(self.n_lbl, width)

        # Offset table that is summed with `lbl_repr`.
        self.lbl_embeddings = nn.Embedding(self.n_lbl, width)

    def get_lbl_embeddings(self):
        """Return the full label matrix: stored representations plus offsets."""
        stored, offset = self.lbl_repr.weight, self.lbl_embeddings.weight
        return stored + offset

    def get_data_embeddings(self):
        """Return the whole weight matrix of the data table."""
        return self.data_repr.weight

    def init_representations(self, data_repr:torch.Tensor, lbl_repr:torch.Tensor):
        """Install the given tensors as the `.data` of `data_repr` / `lbl_repr` (assigned, not copied)."""
        for table, values in ((self.data_repr, data_repr), (self.lbl_repr, lbl_repr)):
            table.weight.data = values

    def init_lbl_embeddings(self):
        """Reset the offset table to float32 zeros shaped like the `lbl_repr` weights."""
        reference = self.lbl_repr.weight.data
        self.lbl_embeddings.weight.data = torch.zeros_like(reference, dtype=torch.float32)

    def freeze_representations(self):
        """Disable gradients for `data_repr` and `lbl_repr`; `lbl_embeddings` is not touched."""
        for table in (self.data_repr, self.lbl_repr):
            table.requires_grad_(False)

    def unfreeze_representations(self):
        """Enable gradients for `data_repr` and `lbl_repr`."""
        for table in (self.data_repr, self.lbl_repr):
            table.requires_grad_(True)

    def forward(
        self,
        data_idx:torch.Tensor,
        lbl2data_idx:torch.Tensor,
    ):
        """Look up data rows, and label rows summed with their offsets."""
        return TCHOutput(
            data_repr=self.data_repr(data_idx),
            lbl2data_repr=self.lbl_repr(lbl2data_idx) + self.lbl_embeddings(lbl2data_idx),
        )
        

#### Example

In [None]:
# Load the checkpoint saved under `<model_output>/teacher` into `TCH002`. That checkpoint has no
# `lbl_embeddings.weight`, so it is newly initialised (see the warning in the cell output).
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4'
m_teacher = TCH002.from_pretrained(f'{model_output}/teacher', n_data=block.train.dset.n_data, n_lbl=block.n_lbl)

Some weights of TCH002 were not initialized from the model checkpoint at /home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4/teacher and are newly initialized: ['lbl_embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
m_teacher.data_repr.weight.requires_grad, m_teacher.lbl_repr.weight.requires_grad

(True, True)

In [None]:
# Stop gradients for the stored data/label representations and reset the label offsets to zero,
# leaving `lbl_embeddings` as the only table that still requires gradients.
m_teacher.freeze_representations()
m_teacher.init_lbl_embeddings()

In [None]:
m_teacher.data_repr.weight.requires_grad,m_teacher.lbl_repr.weight.requires_grad, m_teacher.lbl_embeddings.weight.requires_grad

(False, False, True)

In [None]:
m_teacher.lbl_embeddings.weight

Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)

### `TCH003`

In [None]:
#| export
class TCH003(DistilBertPreTrainedModel):
    """Data-only lookup teacher: a single embedding table with `n_data` rows of width `config.dim`.

    `forward` returns a `TCHOutput` whose `data_repr` holds the looked-up rows;
    `lbl2data_repr` is left at its default (`None`).
    """

    def __init__(self, config, n_data:int, **kwargs):
        super().__init__(config, **kwargs)
        store_attr('n_data')
        self.data_repr = nn.Embedding(self.n_data, config.dim)

    def get_data_embeddings(self):
        """Return the whole weight matrix of the data table."""
        return self.data_repr.weight

    def init_embeddings(self, data_repr:torch.Tensor):
        """Install `data_repr` as the `.data` of the table weights (assigned, not copied)."""
        self.data_repr.weight.data = data_repr

    def freeze_embeddings(self):
        """Disable gradients for the data table."""
        self.data_repr.requires_grad_(False)

    def unfreeze_embeddings(self):
        """Enable gradients for the data table (counterpart of `freeze_embeddings`)."""
        self.data_repr.requires_grad_(True)

    def unfreeze_representations(self):
        """Alias of `unfreeze_embeddings`, kept so existing callers keep working."""
        self.unfreeze_embeddings()

    def forward(
        self,
        data_idx:torch.Tensor,
    ):
        """Look up `data_idx` in the data table."""
        data_repr = self.data_repr(data_idx)
        return TCHOutput(
            data_repr=data_repr,
        )
        

#### Example

In [None]:
# Load the checkpoint saved under `<model_output>/teacher` into the data-only `TCH003`.
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4'
m_teacher = TCH003.from_pretrained(f'{model_output}/teacher', n_data=block.train.dset.n_data)

In [None]:
# Replace the loaded data table with an all-zero (n_data, 768) tensor and stop its gradients.
# NOTE(review): 768 is hard-coded here rather than read from the model config.
m_teacher.init_embeddings(torch.zeros(block.train.dset.n_data, 768))
m_teacher.freeze_embeddings()

In [None]:
m_teacher.data_repr.weight

Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

## Distillation

### `DTL001`

In [None]:
#| export
class DTL001(DistilBertPreTrainedModel):
    """Distillation wrapper that pulls a student's embeddings towards a teacher's.

    The student is run on the plain inputs, the teacher (under `torch.no_grad()`)
    on the augmented inputs. The training loss is the student's own loss plus
    `embed_sim_loss_weight` times the sum of two `Cosine` loss terms: one on the
    data embeddings and one on the label embeddings.

    Args:
        config: configuration forwarded to `DistilBertPreTrainedModel`.
        m_student: model being trained; its output must expose `loss`, `data_embed`,
            `lbl2data_embed`, `data_repr` and `lbl2data_repr`.
        m_teacher: model providing target embeddings; its output must expose
            `data_embed` and `lbl2data_embed`.
        embed_sim_loss_weight: multiplier applied to the embedding-similarity loss.
    """
    use_representation,use_generation = True,False
    # One pattern per entry. A single comma-joined string would be treated as one pattern
    # and match no parameter name, so student and teacher are listed separately.
    _tied_weights_keys = ["m_student.encoder.distilbert", "m_teacher.encoder.distilbert"]

    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        embed_sim_loss_weight:Optional[float]=1.0,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        store_attr('m_student,m_teacher')
        self.s_lw = embed_sim_loss_weight

        self.loss_fn = Cosine(reduce='mean')

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        data_aug_input_ids:Optional[torch.Tensor]=None,
        data_aug_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student and, when possible, add the distillation loss.

        The distillation branch runs only if `data_aug_input_ids` is given and the
        student produced a loss; it then requires `kwargs['lbl2data_attention_mask']`.
        Otherwise the returned `loss` is `None`.

        Returns:
            `XCModelOutput` with the combined loss (or `None`) and the student's
            `data_repr` / `lbl2data_repr`.
        """
        student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, **kwargs)

        loss = None
        if data_aug_input_ids is not None and student_o.loss is not None:
            # The teacher only supplies targets, so no gradients are tracked for it.
            with torch.no_grad():
                teacher_o = self.m_teacher(data_input_ids=data_aug_input_ids, data_attention_mask=data_aug_attention_mask, **kwargs)

            # Data side: student sees the plain input, teacher the augmented one, hence two masks.
            data_loss = self.loss_fn(student_o.data_embed, data_attention_mask, teacher_o.data_embed, data_aug_attention_mask)
            # Label side: both models receive the same label inputs, so the mask is shared.
            lbl_mask = kwargs['lbl2data_attention_mask']
            lbl_loss = self.loss_fn(student_o.lbl2data_embed, lbl_mask, teacher_o.lbl2data_embed, lbl_mask)

            loss = student_o.loss + self.s_lw * (data_loss + lbl_loss)

        return XCModelOutput(
            loss=loss,
            data_repr=student_o.data_repr,
            lbl2data_repr=student_o.lbl2data_repr,
        )
        

#### Example

In [None]:
from safetensors import safe_open

In [None]:
# Resolve the checkpoint directory of the `...-1-1` run: `mname` appends the basename of whatever
# `get_best_model(output_dir)` returns (presumably the best checkpoint's path — confirm).
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-1'
output_dir = f"/home/scai/phd/aiz218323/scratch/outputs/{os.path.basename(model_output)}"
mname = f'{output_dir}/{os.path.basename(get_best_model(output_dir))}'

# Instantiate the architecture from the public sentence-transformers checkpoint first ...
m_teacher = DBT010.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', bsz=800, tn_targ=5000, margin=0.3, tau=0.1, 
                                   n_negatives=10, apply_softmax=True, use_encoder_parallel=True)

# ... then read every tensor from the checkpoint's safetensors file into a plain dict.
model_weight_file,model_weights = f'{mname}/model.safetensors',{}
with safe_open(model_weight_file, framework="pt") as file:
    for k in file.keys(): model_weights[k] = file.get_tensor(k)

# `strict=False` tolerates keys that exist on only one side; the returned `_IncompatibleKeys`
# (displayed as the cell output) lists them.
m_teacher.load_state_dict(model_weights, strict=False)

Some weights of DBT010 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


_IncompatibleKeys(missing_keys=['distilbert.embeddings.word_embeddings.weight', 'distilbert.embeddings.position_embeddings.weight', 'distilbert.embeddings.LayerNorm.weight', 'distilbert.embeddings.LayerNorm.bias', 'distilbert.transformer.layer.0.attention.q_lin.weight', 'distilbert.transformer.layer.0.attention.q_lin.bias', 'distilbert.transformer.layer.0.attention.k_lin.weight', 'distilbert.transformer.layer.0.attention.k_lin.bias', 'distilbert.transformer.layer.0.attention.v_lin.weight', 'distilbert.transformer.layer.0.attention.v_lin.bias', 'distilbert.transformer.layer.0.attention.out_lin.weight', 'distilbert.transformer.layer.0.attention.out_lin.bias', 'distilbert.transformer.layer.0.sa_layer_norm.weight', 'distilbert.transformer.layer.0.sa_layer_norm.bias', 'distilbert.transformer.layer.0.ffn.lin1.weight', 'distilbert.transformer.layer.0.ffn.lin1.bias', 'distilbert.transformer.layer.0.ffn.lin2.weight', 'distilbert.transformer.layer.0.ffn.lin2.bias', 'distilbert.transformer.layer.

In [None]:
# The student starts from the public sentence-transformers checkpoint with the same loss settings
# as the teacher; no fine-tuned weights are loaded into it.
m_student = DBT010.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', bsz=800, tn_targ=5000, margin=0.3, tau=0.1, 
                                       n_negatives=10, apply_softmax=True, use_encoder_parallel=True)
# Presumably initialises the `encoder.dr_*` layers that the warning below reports as newly
# initialised — confirm against `DBT010.init_dr_head`.
m_student.init_dr_head()

Some weights of DBT010 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Wrap student and teacher; the embedding-similarity term is added to the student loss with weight 1.0.
model = DTL001(DistilBertConfig(), m_student=m_student, m_teacher=m_teacher, embed_sim_loss_weight=1.0)

In [None]:
# Take the first training batch. `m_args` names extra batch keys handed to `prepare_batch`
# (presumably keys to keep in addition to the model's named `forward` parameters — confirm).
batch = next(iter(block.train.dl))
b = prepare_batch(model, batch, m_args=['lbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 
                                        'lbl2data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx'])

In [None]:
m,b = model.to('cuda'), b.to('cuda')

In [None]:
o = m(**b)

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
o.loss

tensor(1.6006, device='cuda:0', grad_fn=<AddBackward0>)

### `DTL002`

In [None]:
#| export
class DTL002(DistilBertPreTrainedModel):
    """Distillation wrapper that trains a student against a lookup teacher's representations.

    The training loss is
    `student_loss + distil_loss_weight * dloss + mse_loss_weight * mloss`, where

    * `dloss` is the sum of two `MultiTriplet` terms that mix the two models:
      teacher data vs. student labels, and student data vs. teacher labels;
    * `mloss` is the sum of two MSE terms between teacher and student
      representations (data side and label side).

    Args:
        config: configuration forwarded to `DistilBertPreTrainedModel`.
        m_student: model being trained; its output must expose `loss`, `data_repr`
            and `lbl2data_repr`.
        m_teacher: model called with `data_idx` and `lbl2data_idx`; its output must
            expose `data_repr` and `lbl2data_repr`.
        bsz, tn_targ, margin, tau, apply_softmax, n_negatives: forwarded unchanged
            to `MultiTriplet`.
        distil_loss_weight: multiplier applied to `dloss`.
        mse_loss_weight: multiplier applied to `mloss`.
    """
    use_representation,use_generation = True,False
    _tied_weights_keys = ["m_student.encoder.distilbert"]

    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        bsz:Optional[int]=None,
        tn_targ:Optional[int]=None,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        distil_loss_weight:Optional[float]=1.0,
        mse_loss_weight:Optional[float]=0.1,
        **kwargs
    ):
        super().__init__(config, **kwargs)
        # Loss weights are read in `forward` as `self.d_lw` / `self.m_lw`.
        self.d_lw,self.m_lw = distil_loss_weight,mse_loss_weight
        store_attr('m_student,m_teacher')
        self.mse_loss_fn = nn.MSELoss()
        self.rep_loss_fn = MultiTriplet(bsz=bsz, tn_targ=tn_targ, margin=margin, n_negatives=n_negatives, tau=tau,
                                        apply_softmax=apply_softmax, reduce='mean')

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,

        data_idx:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student and, when labels are present, add the distillation losses.

        The distillation branch runs only if `lbl2data_idx` is given and the student
        produced a loss. It then requires `kwargs` to contain `lbl2data_data2ptr`,
        `plbl2data_data2ptr` and `plbl2data_idx`. Otherwise the returned `loss` is `None`.

        Returns:
            `XCModelOutput` with the combined loss (or `None`) and the student's
            `data_repr` / `lbl2data_repr`.
        """
        student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask,
                                   lbl2data_idx=lbl2data_idx, **kwargs)

        loss = None
        if lbl2data_idx is not None and student_o.loss is not None:
            # The teacher only supplies targets, so no gradients are tracked for it.
            with torch.no_grad():
                teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)

            # Cross term 1: teacher data representations against student label representations.
            dloss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx,
                                     kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)

            # Cross term 2: student data representations against teacher label representations.
            dloss += self.rep_loss_fn(student_o.data_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx,
                                      kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)

            # Direct regression of student representations onto the teacher's, on both sides.
            mloss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_repr) + self.mse_loss_fn(teacher_o.lbl2data_repr, student_o.lbl2data_repr)

            loss = student_o.loss + self.d_lw * dloss + self.m_lw * mloss

        return XCModelOutput(
            loss=loss,
            data_repr=student_o.data_repr,
            lbl2data_repr=student_o.lbl2data_repr,
        )
        

#### Example

In [None]:
# Load the lookup teacher saved under `<model_output>/teacher` of the `...-1-1` run.
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-1'
m_teacher = TCH001.from_pretrained(f'{model_output}/teacher', n_data=block.train.dset.n_data, n_lbl=block.n_lbl)

In [None]:
# The student starts from the public sentence-transformers checkpoint; no fine-tuned weights are loaded.
m_student = DBT010.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', bsz=800, tn_targ=5000, margin=0.3, tau=0.1, 
                                   n_negatives=10, apply_softmax=True, use_encoder_parallel=True)
# Presumably initialises the `encoder.dr_*` layers that the warning below reports as newly
# initialised — confirm against `DBT010.init_dr_head`.
m_student.init_dr_head()

Some weights of DBT010 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Wrap student and lookup teacher. `tn_targ` is not passed, so `DTL002` uses its default (`None`).
# NOTE(review): `bsz=1024` here differs from the `bsz=800` given to the student — confirm intended.
model = DTL002(DistilBertConfig(), m_student=m_student, m_teacher=m_teacher, bsz=1024, margin=0.3, tau=0.1, 
               n_negatives=10, apply_softmax=True, distil_loss_weight=1.0, mse_loss_weight=0.1)

In [None]:
# Take the first training batch. `m_args` names extra batch keys handed to `prepare_batch`
# (presumably keys to keep in addition to the model's named `forward` parameters — confirm).
batch = next(iter(block.train.dl))
b = prepare_batch(model, batch, m_args=['lbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 
                                        'lbl2data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx'])

In [None]:
b.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_input_ids', 'data_attention_mask', 'data_idx'])

In [None]:
m,b = model.to('cuda'), b.to('cuda')

In [None]:
o = m(**b)

In [None]:
o.loss

tensor(0.6356, device='cuda:0', grad_fn=<AddBackward0>)

### `DTL003`

In [None]:
#| export
class DTL003(DistilBertPreTrainedModel):
    """Distillation wrapper that trains a student against a frozen-gradient teacher.

    The student is run on the tokenised inputs; the teacher is queried only by
    `data_idx` / `lbl2data_idx` under `torch.no_grad()`. The total loss is the
    student's own loss plus four weighted terms:

    * a `MultiTriplet` loss between teacher data and student label representations,
    * a `MultiTriplet` loss between student fused-data and teacher label representations,
    * an MSE loss between teacher data and student fused-data representations,
    * an MSE loss between teacher label and student label representations.
    """
    use_representation,use_generation = True,False
    _tied_weights_keys = ["m_student.encoder.distilbert"]

    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        bsz:Optional[int]=None,
        tn_targ:Optional[int]=None,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        teacher_data_student_label_loss_weight:Optional[float]=1.0,
        student_data_teacher_label_loss_weight:Optional[float]=1.0,
        data_mse_loss_weight:Optional[float]=0.1,
        label_mse_loss_weight:Optional[float]=0.1,
        **kwargs
    ):
        """Store the two sub-models, the loss weights, and build the loss functions.

        Args:
            config: configuration forwarded to `DistilBertPreTrainedModel.__init__`.
            m_student: model that is trained; called with the tokenised inputs.
            m_teacher: model that provides target representations; called with indices only.
            bsz, tn_targ, margin, tau, apply_softmax, n_negatives: forwarded unchanged
                to `MultiTriplet` (which is built with `reduce='mean'`).
            teacher_data_student_label_loss_weight: weight of the teacher-data /
                student-label triplet term.
            student_data_teacher_label_loss_weight: weight of the student-data /
                teacher-label triplet term.
            data_mse_loss_weight: weight of the data-representation MSE term.
            label_mse_loss_weight: weight of the label-representation MSE term.
            **kwargs: forwarded to `DistilBertPreTrainedModel.__init__`.
        """
        super().__init__(config, **kwargs)
        store_attr('m_student,m_teacher')
        store_attr('teacher_data_student_label_loss_weight,student_data_teacher_label_loss_weight')
        store_attr('data_mse_loss_weight,label_mse_loss_weight')
        self.mse_loss_fn = nn.MSELoss()
        self.rep_loss_fn = MultiTriplet(bsz=bsz, tn_targ=tn_targ, margin=margin, n_negatives=n_negatives, tau=tau, 
                                        apply_softmax=apply_softmax, reduce='mean')

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,

        data_idx:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student and, when labels are available, add the distillation losses.

        Args:
            data_input_ids, data_attention_mask: tokenised data inputs for the student.
            data_idx: data indices passed to the teacher.
            lbl2data_idx: label indices passed to both student and teacher.
            **kwargs: forwarded to the student and to the triplet loss. When the
                loss is computed it must contain `lbl2data_data2ptr`,
                `plbl2data_data2ptr` and `plbl2data_idx`.

        Returns:
            RADOutput carrying the student's representations and the combined loss.
            `loss` is None when `lbl2data_idx` is None or the student produced no loss.

        Raises:
            KeyError: if the loss is computed and one of the required pointer/index
                entries is missing from `kwargs`.
        """
        student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
                                   lbl2data_idx=lbl2data_idx, **kwargs)

        loss = None
        if lbl2data_idx is not None and student_o.loss is not None:
            # The teacher only supplies targets, so no graph is recorded for it.
            with torch.no_grad(): 
                teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)

            # Teacher data representation against student label representation.
            tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
                                         kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)

            # Student fused data representation against teacher label representation.
            sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
                                         kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)

            dm_loss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_fused_repr)
            lm_loss = self.mse_loss_fn(teacher_o.lbl2data_repr, student_o.lbl2data_repr)

            # Build the total out of place: `loss = student_o.loss; loss += ...` would
            # modify the student's own loss tensor in place through the alias.
            # The grouping mirrors the original accumulation order exactly.
            loss = student_o.loss + self.teacher_data_student_label_loss_weight * tdsl_loss
            loss = loss + self.student_data_teacher_label_loss_weight * sdtl_loss
            loss = loss + (self.data_mse_loss_weight * dm_loss + self.label_mse_loss_weight * lm_loss)

        return RADOutput(
            loss=loss,

            data_repr=student_o.data_repr,
            data_fused_repr=student_o.data_fused_repr,

            lbl2data_repr=student_o.lbl2data_repr,
            lbl2data_fused_repr=student_o.lbl2data_fused_repr,
        )
        

#### Example

In [None]:
# Same checkpoint path and loading call as in the `DTL002` example above.
# NOTE(review): absolute, machine-specific path; the teacher is read from its `teacher` subfolder.
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-1'
m_teacher = TCH001.from_pretrained(f'{model_output}/teacher', n_data=block.train.dset.n_data, n_lbl=block.n_lbl)

In [None]:
# Build the RAD006 student from the pretrained msmarco DistilBERT checkpoint.
# `data_aug_meta_prefix='lnk2data'` matches the `lnk2data_*` batch fields requested further down;
# no label-side augmentation or prediction prefixes are set.
m_student = RAD006.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=5000, num_batch_labels=5000, 
                                   margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                                   data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                                   data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
    
                                   resize_length=5000, use_noise=False, shuffle_noise_pct=0.5, dropout_noise_pct=0.1,
                                   
                                   use_query_loss=True,
    
                                   calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, calib_loss_weight=0.1,
                                   use_calib_loss=True,
                                   
                                   meta_loss_weight=0.0, fusion_loss_weight=0.0, use_fusion_loss=False,
                                   use_encoder_parallel=False)

# Presumably initialise the head weights that the warning below reports as missing
# from the checkpoint -- confirm in RAD006.
m_student.init_retrieval_head()
m_student.init_cross_head()

Some weights of RAD006 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 'encoder.meta_head.transform.bias', 'encoder.meta_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Wrap the student and teacher in DTL003. The student-data/teacher-label term is weighted 0.1
# here (the constructor default is 1.0); the other weights equal their defaults.
model = DTL003(DistilBertConfig(), m_student=m_student, m_teacher=m_teacher, bsz=1024, margin=0.3, tau=0.1, n_negatives=10, apply_softmax=True, 
               teacher_data_student_label_loss_weight=1.0, student_data_teacher_label_loss_weight=0.1, 
               data_mse_loss_weight=0.1,label_mse_loss_weight=0.1)

In [None]:
# Take the first batch from the training dataloader.
batch = next(iter(block.train.dl))
# Pass it through `prepare_batch`; besides the `lbl2data_*` / `plbl2data_*` fields, the
# `lnk2data_*` / `plnk2data_*` fields are requested for the student's `lnk2data` prefix.
b = prepare_batch(model, batch, m_args=['lbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 
                                        'lbl2data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx',
                                        'lnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 
                                        'lnk2data_attention_mask', 'plnk2data_data2ptr', 'plnk2data_idx'
                                       ])

In [None]:
# Move both the model and the prepared batch to the GPU (requires a CUDA device).
m,b = model.to('cuda'), b.to('cuda')

In [None]:
# Forward pass on the prepared batch.
# NOTE(review): the recorded output of this cell is an interactive ipdb session from a
# `forward` that contained `import pdb; pdb.set_trace()`, which the exported `DTL003`
# does not have; re-run and clear the stale output before sharing.
o = m(**b)

> /tmp/ipykernel_40248/676308714.py(42)forward()
     40         import pdb; pdb.set_trace()
     41 
---> 42         student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     43                                    lbl2data_idx=lbl2data_idx, **kwargs)
     44 



ipdb>  self.m_student.encoder


Encoder006(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear

ipdb>  b self.m_student.forward


Breakpoint 1 at /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py:1310


ipdb>  b self.m_student.encoder


*** The specified object 'self.m_student.encoder' is not a function or was not found along sys.path.


ipdb>  b self.m_student.encoder.forward


Breakpoint 2 at /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py:457


ipdb>  n


> /tmp/ipykernel_40248/676308714.py(43)forward()
     41 
     42         student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 43                                    lbl2data_idx=lbl2data_idx, **kwargs)
     44 
     45         loss = None



ipdb>  n


> /tmp/ipykernel_40248/676308714.py(42)forward()
     40         import pdb; pdb.set_trace()
     41 
---> 42         student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     43                                    lbl2data_idx=lbl2data_idx, **kwargs)
     44 



ipdb>  s


> /tmp/ipykernel_40248/676308714.py(43)forward()
     41 
     42         student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
---> 43                                    lbl2data_idx=lbl2data_idx, **kwargs)
     44 
     45         loss = None



ipdb>  n


> /tmp/ipykernel_40248/676308714.py(42)forward()
     40         import pdb; pdb.set_trace()
     41 
---> 42         student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
     43                                    lbl2data_idx=lbl2data_idx, **kwargs)
     44 



ipdb>  s


--Call--
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1507)_wrapped_call_impl()
   1505         return result
   1506 
-> 1507     def _wrapped_call_impl(self, *args, **kwargs):
   1508         if self._compiled_call_impl is not None:
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]



ipdb>  c


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1325)forward()
   1323         **kwargs
   1324     ):  
-> 1325         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1326 
   1327         if self.use_encoder_parallel:



ipdb>  kwargs.keys()


dict_keys(['plnk2data_idx', 'plnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask', 'lnk2data_data2ptr'])


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1327)forward()
   1325         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1326 
-> 1327         if self.use_encoder_parallel:
   1328             encoder = XCDataParallel(module=self.encoder)
   1329         else: encoder = self.encoder



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1329)forward()
   1327         if self.use_encoder_parallel:
   1328             encoder = XCDataParallel(module=self.encoder)
-> 1329         else: encoder = self.encoder
   1330 
   1331         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1331)forward()
   1329         else: encoder = self.encoder
   1330 
-> 1331         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
   1332         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
   1333                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1332)forward()
   1330 
   1331         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
-> 1332         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
   1333                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
   1334 



ipdb>  data_meta_kwargs.keys()


dict_keys(['lnk2data_attention_mask', 'lnk2data_input_ids', 'lnk2data_data2ptr'])


ipdb>  data_meta_kwargs['lnk2data_attention_mask'].shape


torch.Size([30, 10])


ipdb>  data_meta_kwargs['lnk2data_input_ids'].shape


torch.Size([30, 10])


ipdb>  data_meta_kwargs['lnk2data_data2ptr'].shape


torch.Size([10])


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1333)forward()
   1331         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
   1332         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
-> 1333                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
   1334 
   1335 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1332)forward()
   1330 
   1331         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
-> 1332         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
   1333                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
   1334 



ipdb>  s


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1333)forward()
   1331         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
   1332         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
-> 1333                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
   1334 
   1335 



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1332)forward()
   1330 
   1331         data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('data', self.data_aug_meta_prefix, **kwargs)
-> 1332         data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
   1333                          data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
   1334 



ipdb>  s


--Call--
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1507)_wrapped_call_impl()
   1505         return result
   1506 
-> 1507     def _wrapped_call_impl(self, *args, **kwargs):
   1508         if self._compiled_call_impl is not None:
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]



ipdb>  c


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(466)forward()
    464         **kwargs
    465     ):
--> 466         data_o = self.encode(data_input_ids, data_attention_mask)
    467 
    468         if data_type is not None and data_type == "meta":



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(468)forward()
    466         data_o = self.encode(data_input_ids, data_attention_mask)
    467 
--> 468         if data_type is not None and data_type == "meta":
    469             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    470         else:



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(471)forward()
    469             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    470         else:
--> 471             data_repr = self.dr(data_o[0], data_attention_mask)
    472 
    473         data_fused_repr = meta_repr = None



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(473)forward()
    471             data_repr = self.dr(data_o[0], data_attention_mask)
    472 
--> 473         data_fused_repr = meta_repr = None
    474         if data_aug_meta_prefix is not None:
    475             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  data_repr.shape


torch.Size([10, 768])


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(474)forward()
    472 
    473         data_fused_repr = meta_repr = None
--> 474         if data_aug_meta_prefix is not None:
    475             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    476             if len(meta_kwargs):



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(475)forward()
    473         data_fused_repr = meta_repr = None
    474         if data_aug_meta_prefix is not None:
--> 475             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    476             if len(meta_kwargs):
    477                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(476)forward()
    474         if data_aug_meta_prefix is not None:
    475             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
--> 476             if len(meta_kwargs):
    477                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    478                                                                              data_attention_mask,



ipdb>  meta_kwargs.keys()


dict_keys(['lnk2data'])


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(477)forward()
    475             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    476             if len(meta_kwargs):
--> 477                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    478                                                                              data_attention_mask,
    479                                                                              meta_kwargs)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(478)forward()
    476             if len(meta_kwargs):
    477                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
--> 478                                                                              data_attention_mask,
    479                                                                              meta_kwargs)
    480                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(479)forward()
    477                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    478                                                                              data_attention_mask,
--> 479                                                                              meta_kwargs)
    480                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    481 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(477)forward()
    475             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    476             if len(meta_kwargs):
--> 477                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    478                                                                              data_attention_mask,
    479                                                                              meta_kwargs)



ipdb>  s


--Call--
> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1236)fuse_meta_into_embeddings()
   1234         return m_repr,m_repr_mask
   1235 
-> 1236     def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
   1237         meta_repr = {}
   1238 



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1237)fuse_meta_into_embeddings()
   1235 
   1236     def fuse_meta_into_embeddings(self, embed:torch.Tensor, attention_mask:torch.Tensor, meta_kwargs:Dict):
-> 1237         meta_repr = {}
   1238 
   1239         for m_key, m_args in meta_kwargs.items():



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1239)fuse_meta_into_embeddings()
   1237         meta_repr = {}
   1238 
-> 1239         for m_key, m_args in meta_kwargs.items():
   1240             idx = torch.where(m_args['data2ptr'] > 0)[0]
   1241             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1240)fuse_meta_into_embeddings()
   1238 
   1239         for m_key, m_args in meta_kwargs.items():
-> 1240             idx = torch.where(m_args['data2ptr'] > 0)[0]
   1241             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
   1242 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1241)fuse_meta_into_embeddings()
   1239         for m_key, m_args in meta_kwargs.items():
   1240             idx = torch.where(m_args['data2ptr'] > 0)[0]
-> 1241             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
   1242 
   1243             if len(idx):



ipdb>  idx.shape


torch.Size([10])


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1243)fuse_meta_into_embeddings()
   1241             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)
   1242 
-> 1243             if len(idx):
   1244                 if 'meta_repr' in m_args:
   1245                     m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)



ipdb>  meta_repr[m_key]


tensor([], device='cuda:0', size=(0, 768))


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1244)fuse_meta_into_embeddings()
   1242 
   1243             if len(idx):
-> 1244                 if 'meta_repr' in m_args:
   1245                     m_repr,m_repr_mask = m_args['meta_repr'],torch.any(m_args['attention_mask'], dim=1).long().view(-1,1)
   1246                     m_repr,m_repr_mask = self.resize(m_repr, m_repr_mask, m_args['data2ptr'][idx])



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1249)fuse_meta_into_embeddings()
   1247                     m_repr_mask = m_repr_mask.bool()
   1248                 else:
-> 1249                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
   1250                                                                 m_args['data2ptr'][idx])
   1251                     n_meta = m_args['data2ptr'].max()



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1250)fuse_meta_into_embeddings()
   1248                 else:
   1249                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
-> 1250                                                                 m_args['data2ptr'][idx])
   1251                     n_meta = m_args['data2ptr'].max()
   1252                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1249)fuse_meta_into_embeddings()
   1247                     m_repr_mask = m_repr_mask.bool()
   1248                 else:
-> 1249                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
   1250                                                                 m_args['data2ptr'][idx])
   1251                     n_meta = m_args['data2ptr'].max()



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1251)fuse_meta_into_embeddings()
   1249                     m_input_ids, m_attention_mask = self.resize(m_args['input_ids'], m_args['attention_mask'], 
   1250                                                                 m_args['data2ptr'][idx])
-> 1251                     n_meta = m_args['data2ptr'].max()
   1252                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]
   1253 



ipdb>  m_input_ids.shape


torch.Size([30, 10])


ipdb>  m_attention_mask.shape


torch.Size([30, 10])


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1252)fuse_meta_into_embeddings()
   1250                                                                 m_args['data2ptr'][idx])
   1251                     n_meta = m_args['data2ptr'].max()
-> 1252                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]
   1253 
   1254                     m_repr = self.meta_unnormalized(m_embed, m_attention_mask)



ipdb>  n_meta


tensor(3, device='cuda:0')


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1254)fuse_meta_into_embeddings()
   1252                     m_embed = self.encode(m_input_ids, m_attention_mask)[0]
   1253 
-> 1254                     m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
   1255                     m_repr_mask = torch.any(m_attention_mask, dim=1)
   1256 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1255)fuse_meta_into_embeddings()
   1253 
   1254                     m_repr = self.meta_unnormalized(m_embed, m_attention_mask)
-> 1255                     m_repr_mask = torch.any(m_attention_mask, dim=1)
   1256 
   1257                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1257)fuse_meta_into_embeddings()
   1255                     m_repr_mask = torch.any(m_attention_mask, dim=1)
   1256 
-> 1257                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
   1258                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)
   1259 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1258)fuse_meta_into_embeddings()
   1256 
   1257                 m_repr, m_repr_mask = m_repr.view(len(idx), -1, self.config.dim), m_repr_mask.view(len(idx), -1)
-> 1258                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)
   1259 
   1260                 if self.use_noise: m_repr, m_repr_mask = self.add_noise(m_repr.clone(), m_repr_mask.clone())



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1260)fuse_meta_into_embeddings()
   1258                 meta_repr[m_key] = F.normalize(m_repr[m_repr_mask], dim=1)
   1259 
-> 1260                 if self.use_noise: m_repr, m_repr_mask = self.add_noise(m_repr.clone(), m_repr_mask.clone())
   1261 
   1262                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1262)fuse_meta_into_embeddings()
   1260                 if self.use_noise: m_repr, m_repr_mask = self.add_noise(m_repr.clone(), m_repr_mask.clone())
   1261 
-> 1262                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]
   1263 
   1264                 if self.use_noise:



ipdb>  embed.shape


torch.Size([10, 5, 768])


ipdb>  m_repr.shape


torch.Size([10, 3, 768])


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1264)fuse_meta_into_embeddings()
   1262                 fused_embed = self.cross_head(embed[idx], attention_mask[idx], m_repr, m_repr_mask)[0]
   1263 
-> 1264                 if self.use_noise:
   1265                     noise_mask = torch.rand(len(idx), device=fused_embed.device) > self.dropout_noise_pct
   1266                     embed[idx[noise_mask]] += fused_embed[noise_mask]



ipdb>  self.use_noise


False


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1268)fuse_meta_into_embeddings()
   1266                     embed[idx[noise_mask]] += fused_embed[noise_mask]
   1267                 else:
-> 1268                     embed[idx] += fused_embed
   1269 
   1270         return embed, meta_repr



ipdb>  fused_embed.shape


torch.Size([10, 5, 768])


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1239)fuse_meta_into_embeddings()
   1237         meta_repr = {}
   1238 
-> 1239         for m_key, m_args in meta_kwargs.items():
   1240             idx = torch.where(m_args['data2ptr'] > 0)[0]
   1241             meta_repr[m_key] = torch.empty(0, self.config.dim).to(embed)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1270)fuse_meta_into_embeddings()
   1268                     embed[idx] += fused_embed
   1269 
-> 1270         return embed, meta_repr
   1271 
   1272 



ipdb>  


--Return--
(tensor([[[-0....PutBackward0>), {'lnk2data': tensor([[-0.0...DivBackward0>)})
> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1270)fuse_meta_into_embeddings()
   1268                     embed[idx] += fused_embed
   1269 
-> 1270         return embed, meta_repr
   1271 
   1272 



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(480)forward()
    478                                                                              data_attention_mask,
    479                                                                              meta_kwargs)
--> 480                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    481 
    482         return EncoderOutput(



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(482)forward()
    480                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    481 
--> 482         return EncoderOutput(
    483             rep=data_repr,
    484             fused_rep=data_fused_repr,



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(483)forward()
    481 
    482         return EncoderOutput(
--> 483             rep=data_repr,
    484             fused_rep=data_fused_repr,
    485             meta_repr=meta_repr,



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(484)forward()
    482         return EncoderOutput(
    483             rep=data_repr,
--> 484             fused_rep=data_fused_repr,
    485             meta_repr=meta_repr,
    486         )



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(485)forward()
    483             rep=data_repr,
    484             fused_rep=data_fused_repr,
--> 485             meta_repr=meta_repr,
    486         )
    487 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(482)forward()
    480                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    481 
--> 482         return EncoderOutput(
    483             rep=data_repr,
    484             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...vBackward0>)})
> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(482)forward()
    480                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    481 
--> 482         return EncoderOutput(
    483             rep=data_repr,
    484             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...vBackward0>)})
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1511)_wrapped_call_impl()
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1336)forward()
   1334 
   1335 
-> 1336         loss = None; lbl2data_o = EncoderOutput()
   1337         if lbl2data_input_ids is not None:
   1338             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1337)forward()
   1335 
   1336         loss = None; lbl2data_o = EncoderOutput()
-> 1337         if lbl2data_input_ids is not None:
   1338             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
   1339             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1338)forward()
   1336         loss = None; lbl2data_o = EncoderOutput()
   1337         if lbl2data_input_ids is not None:
-> 1338             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
   1339             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
   1340                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1339)forward()
   1337         if lbl2data_input_ids is not None:
   1338             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
-> 1339             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
   1340                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
   1341 



ipdb>  lbl2data_meta_kwargs.keys()


dict_keys([])


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1340)forward()
   1338             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
   1339             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
-> 1340                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
   1341 
   1342             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1339)forward()
   1337         if lbl2data_input_ids is not None:
   1338             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
-> 1339             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
   1340                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
   1341 



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1340)forward()
   1338             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
   1339             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
-> 1340                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
   1341 
   1342             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1339)forward()
   1337         if lbl2data_input_ids is not None:
   1338             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
-> 1339             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
   1340                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
   1341 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(466)forward()
    464         **kwargs
    465     ):
--> 466         data_o = self.encode(data_input_ids, data_attention_mask)
    467 
    468         if data_type is not None and data_type == "meta":



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(468)forward()
    466         data_o = self.encode(data_input_ids, data_attention_mask)
    467 
--> 468         if data_type is not None and data_type == "meta":
    469             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    470         else:



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(471)forward()
    469             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    470         else:
--> 471             data_repr = self.dr(data_o[0], data_attention_mask)
    472 
    473         data_fused_repr = meta_repr = None



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(473)forward()
    471             data_repr = self.dr(data_o[0], data_attention_mask)
    472 
--> 473         data_fused_repr = meta_repr = None
    474         if data_aug_meta_prefix is not None:
    475             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(474)forward()
    472 
    473         data_fused_repr = meta_repr = None
--> 474         if data_aug_meta_prefix is not None:
    475             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    476             if len(meta_kwargs):



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(482)forward()
    480                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    481 
--> 482         return EncoderOutput(
    483             rep=data_repr,
    484             fused_rep=data_fused_repr,



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(483)forward()
    481 
    482         return EncoderOutput(
--> 483             rep=data_repr,
    484             fused_rep=data_fused_repr,
    485             meta_repr=meta_repr,



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(484)forward()
    482         return EncoderOutput(
    483             rep=data_repr,
--> 484             fused_rep=data_fused_repr,
    485             meta_repr=meta_repr,
    486         )



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(485)forward()
    483             rep=data_repr,
    484             fused_rep=data_fused_repr,
--> 485             meta_repr=meta_repr,
    486         )
    487 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(482)forward()
    480                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    481 
--> 482         return EncoderOutput(
    483             rep=data_repr,
    484             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...eta_repr=None)
> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(482)forward()
    480                 data_fused_repr = self.dr(data_fused_embed, data_attention_mask)
    481 
--> 482         return EncoderOutput(
    483             rep=data_repr,
    484             fused_rep=data_fused_repr,



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1342)forward()
   1340                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
   1341 
-> 1342             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
   1343                                      plbl2data_data2ptr,plbl2data_idx)
   1344 



ipdb>  data_o.fused_rep.shape


torch.Size([10, 768])


ipdb>  lbl2data_o.rep.shape


torch.Size([10, 768])


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1343)forward()
   1341 
   1342             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
-> 1343                                      plbl2data_data2ptr,plbl2data_idx)
   1344 
   1345             if self.use_query_loss:



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1342)forward()
   1340                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
   1341 
-> 1342             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
   1343                                      plbl2data_data2ptr,plbl2data_idx)
   1344 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1345)forward()
   1343                                      plbl2data_data2ptr,plbl2data_idx)
   1344 
-> 1345             if self.use_query_loss:
   1346                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
   1347                                           plbl2data_data2ptr,plbl2data_idx)



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1346)forward()
   1344 
   1345             if self.use_query_loss:
-> 1346                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
   1347                                           plbl2data_data2ptr,plbl2data_idx)
   1348 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1347)forward()
   1345             if self.use_query_loss:
   1346                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
-> 1347                                           plbl2data_data2ptr,plbl2data_idx)
   1348 
   1349             if self.use_calib_loss:



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1346)forward()
   1344 
   1345             if self.use_query_loss:
-> 1346                 loss += self.compute_loss(data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
   1347                                           plbl2data_data2ptr,plbl2data_idx)
   1348 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1349)forward()
   1347                                           plbl2data_data2ptr,plbl2data_idx)
   1348 
-> 1349             if self.use_calib_loss:
   1350                 loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
   1351                                               plbl2data_data2ptr,plbl2data_idx)



ipdb>  self.use_calib_loss


True


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1350)forward()
   1348 
   1349             if self.use_calib_loss:
-> 1350                 loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
   1351                                               plbl2data_data2ptr,plbl2data_idx)
   1352 



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1351)forward()
   1349             if self.use_calib_loss:
   1350                 loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
-> 1351                                               plbl2data_data2ptr,plbl2data_idx)
   1352 
   1353             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1350)forward()
   1348 
   1349             if self.use_calib_loss:
-> 1350                 loss += self.calibration_loss(data_o.fused_rep, data_o.rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,
   1351                                               plbl2data_data2ptr,plbl2data_idx)
   1352 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1353)forward()
   1351                                               plbl2data_data2ptr,plbl2data_idx)
   1352 
-> 1353             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
   1354 
   1355             if self.use_fusion_loss:



ipdb>  data_o.fused_rep.shape


torch.Size([10, 768])


ipdb>  lbl2data_o.rep.shape


torch.Size([10, 768])


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1355)forward()
   1353             loss += self.compute_meta_loss(data_o.fused_rep, lbl2data_o.rep, **kwargs)
   1354 
-> 1355             if self.use_fusion_loss:
   1356                 loss += self.compute_fusion_loss(data_o.fused_rep, data_o.meta_repr, self.data_aug_meta_prefix, **kwargs)
   1357                 loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1359)forward()
   1357                 loss += self.compute_fusion_loss(lbl2data_o.rep, lbl2data_o.meta_repr, self.lbl2data_aug_meta_prefix, **kwargs)
   1358 
-> 1359         if not return_dict:
   1360             o = (data_o.logits,data_o.rep,data_o.fused_rep,lbl2data_o.logits,lbl2data_o.rep,lbl2data_o.fused_rep)
   1361             return ((loss,) + o) if loss is not None else o



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1364)forward()
   1362 
   1363 
-> 1364         return RADOutput(
   1365             loss=loss,
   1366 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1365)forward()
   1363 
   1364         return RADOutput(
-> 1365             loss=loss,
   1366 
   1367             data_repr=data_o.rep,



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1367)forward()
   1365             loss=loss,
   1366 
-> 1367             data_repr=data_o.rep,
   1368             data_fused_repr=data_o.fused_rep,
   1369 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1368)forward()
   1366 
   1367             data_repr=data_o.rep,
-> 1368             data_fused_repr=data_o.fused_rep,
   1369 
   1370             lbl2data_repr=lbl2data_o.rep,



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1370)forward()
   1368             data_fused_repr=data_o.fused_rep,
   1369 
-> 1370             lbl2data_repr=lbl2data_o.rep,
   1371             lbl2data_fused_repr=lbl2data_o.fused_rep,
   1372         )



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1371)forward()
   1369 
   1370             lbl2data_repr=lbl2data_o.rep,
-> 1371             lbl2data_fused_repr=lbl2data_o.fused_rep,
   1372         )
   1373 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1364)forward()
   1362 
   1363 
-> 1364         return RADOutput(
   1365             loss=loss,
   1366 



ipdb>  


--Return--
RADOutput(los...sed_repr=None)
> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga.py(1364)forward()
   1362 
   1363 
-> 1364         return RADOutput(
   1365             loss=loss,
   1366 



ipdb>  


--Return--
RADOutput(los...sed_repr=None)
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1511)_wrapped_call_impl()
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):



ipdb>  


> /tmp/ipykernel_40248/676308714.py(45)forward()
     43                                    lbl2data_idx=lbl2data_idx, **kwargs)
     44 
---> 45         loss = None
     46         if lbl2data_idx is not None and student_o.loss is not None:
     47             with torch.no_grad():



ipdb>  


> /tmp/ipykernel_40248/676308714.py(46)forward()
     44 
     45         loss = None
---> 46         if lbl2data_idx is not None and student_o.loss is not None:
     47             with torch.no_grad():
     48                 teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)



ipdb>  


> /tmp/ipykernel_40248/676308714.py(47)forward()
     45         loss = None
     46         if lbl2data_idx is not None and student_o.loss is not None:
---> 47             with torch.no_grad():
     48                 teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)
     49 



ipdb>  


> /tmp/ipykernel_40248/676308714.py(48)forward()
     46         if lbl2data_idx is not None and student_o.loss is not None:
     47             with torch.no_grad():
---> 48                 teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)
     49 
     50             tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 



ipdb>  


> /tmp/ipykernel_40248/676308714.py(50)forward()
     48                 teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)
     49 
---> 50             tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
     51                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     52 



ipdb>  teacher_o.data_repr.shape


torch.Size([10, 768])


ipdb>  student_o.lbl2data_repr.shape


torch.Size([10, 768])


ipdb>  n


> /tmp/ipykernel_40248/676308714.py(51)forward()
     49 
     50             tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
---> 51                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     52 
     53             sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 



ipdb>  


> /tmp/ipykernel_40248/676308714.py(50)forward()
     48                 teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)
     49 
---> 50             tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
     51                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     52 



ipdb>  


> /tmp/ipykernel_40248/676308714.py(51)forward()
     49 
     50             tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
---> 51                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     52 
     53             sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 



ipdb>  


> /tmp/ipykernel_40248/676308714.py(50)forward()
     48                 teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)
     49 
---> 50             tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
     51                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     52 



ipdb>  


> /tmp/ipykernel_40248/676308714.py(53)forward()
     51                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     52 
---> 53             sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
     54                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     55 



ipdb>  tdsl_loss


tensor(0.1338, device='cuda:0', grad_fn=<DivBackward0>)


ipdb>  n


> /tmp/ipykernel_40248/676308714.py(54)forward()
     52 
     53             sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
---> 54                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     55 
     56             dm_loss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_fused_repr)



ipdb>  student_o.data_fused_repr.shape


torch.Size([10, 768])


ipdb>  teacher_o.lbl2data_repr.shape


torch.Size([10, 768])


ipdb>  n


> /tmp/ipykernel_40248/676308714.py(53)forward()
     51                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     52 
---> 53             sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
     54                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     55 



ipdb>  


> /tmp/ipykernel_40248/676308714.py(54)forward()
     52 
     53             sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
---> 54                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     55 
     56             dm_loss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_fused_repr)



ipdb>  


> /tmp/ipykernel_40248/676308714.py(53)forward()
     51                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     52 
---> 53             sdtl_loss = self.rep_loss_fn(student_o.data_fused_repr, teacher_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], lbl2data_idx, 
     54                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     55 



ipdb>  


> /tmp/ipykernel_40248/676308714.py(56)forward()
     54                                          kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)
     55 
---> 56             dm_loss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_fused_repr)
     57             lm_loss = self.mse_loss_fn(teacher_o.lbl2data_repr, student_o.lbl2data_repr)
     58 



ipdb>  


> /tmp/ipykernel_40248/676308714.py(57)forward()
     55 
     56             dm_loss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_fused_repr)
---> 57             lm_loss = self.mse_loss_fn(teacher_o.lbl2data_repr, student_o.lbl2data_repr)
     58 
     59             loss = student_o.loss



ipdb>  


> /tmp/ipykernel_40248/676308714.py(59)forward()
     57             lm_loss = self.mse_loss_fn(teacher_o.lbl2data_repr, student_o.lbl2data_repr)
     58 
---> 59             loss = student_o.loss
     60             loss += self.teacher_data_student_label_loss_weight * tdsl_loss
     61             loss += self.student_data_teacher_label_loss_weight * sdtl_loss



ipdb>  student_o.loss


tensor(0.0424, device='cuda:0', grad_fn=<AddBackward0>)


ipdb>  n


> /tmp/ipykernel_40248/676308714.py(60)forward()
     58 
     59             loss = student_o.loss
---> 60             loss += self.teacher_data_student_label_loss_weight * tdsl_loss
     61             loss += self.student_data_teacher_label_loss_weight * sdtl_loss
     62             loss += self.data_mse_loss_weight * dm_loss + self.label_mse_loss_weight * lm_loss



ipdb>  


> /tmp/ipykernel_40248/676308714.py(61)forward()
     59             loss = student_o.loss
     60             loss += self.teacher_data_student_label_loss_weight * tdsl_loss
---> 61             loss += self.student_data_teacher_label_loss_weight * sdtl_loss
     62             loss += self.data_mse_loss_weight * dm_loss + self.label_mse_loss_weight * lm_loss
     63 



ipdb>  self.teacher_data_student_label_loss_weight


1.0


ipdb>  tdsl_loss


tensor(0.1338, device='cuda:0', grad_fn=<DivBackward0>)


ipdb>  sdtl_loss


tensor(0.2057, device='cuda:0', grad_fn=<DivBackward0>)


ipdb>  n


> /tmp/ipykernel_40248/676308714.py(62)forward()
     60             loss += self.teacher_data_student_label_loss_weight * tdsl_loss
     61             loss += self.student_data_teacher_label_loss_weight * sdtl_loss
---> 62             loss += self.data_mse_loss_weight * dm_loss + self.label_mse_loss_weight * lm_loss
     63 
     64 



ipdb>  dm_loss


tensor(1.0163, device='cuda:0', grad_fn=<MseLossBackward0>)


ipdb>  lm_loss


tensor(1.0100, device='cuda:0', grad_fn=<MseLossBackward0>)


ipdb>  n


> /tmp/ipykernel_40248/676308714.py(65)forward()
     63 
     64 
---> 65         return RADOutput(
     66             loss=loss,
     67 



ipdb>  


> /tmp/ipykernel_40248/676308714.py(66)forward()
     64 
     65         return RADOutput(
---> 66             loss=loss,
     67 
     68             data_repr=student_o.data_repr,



ipdb>  


> /tmp/ipykernel_40248/676308714.py(68)forward()
     66             loss=loss,
     67 
---> 68             data_repr=student_o.data_repr,
     69             data_fused_repr=student_o.data_fused_repr,
     70 



ipdb>  


> /tmp/ipykernel_40248/676308714.py(69)forward()
     67 
     68             data_repr=student_o.data_repr,
---> 69             data_fused_repr=student_o.data_fused_repr,
     70 
     71             lbl2data_repr=student_o.lbl2data_repr,



ipdb>  


> /tmp/ipykernel_40248/676308714.py(71)forward()
     69             data_fused_repr=student_o.data_fused_repr,
     70 
---> 71             lbl2data_repr=student_o.lbl2data_repr,
     72             lbl2data_fused_repr=student_o.lbl2data_fused_repr,
     73         )



ipdb>  


> /tmp/ipykernel_40248/676308714.py(72)forward()
     70 
     71             lbl2data_repr=student_o.lbl2data_repr,
---> 72             lbl2data_fused_repr=student_o.lbl2data_fused_repr,
     73         )
     74 



ipdb>  


> /tmp/ipykernel_40248/676308714.py(65)forward()
     63 
     64 
---> 65         return RADOutput(
     66             loss=loss,
     67 



ipdb>  


--Return--
RADOutput(los...sed_repr=None)
> /tmp/ipykernel_40248/676308714.py(65)forward()
     63 
     64 
---> 65         return RADOutput(
     66             loss=loss,
     67 



ipdb>  


--Return--
RADOutput(los...sed_repr=None)
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1520)_call_impl()
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:



ipdb>  


--Return--
RADOutput(los...sed_repr=None)
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/nn/modules/module.py(1511)_wrapped_call_impl()
   1509             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510         else:
-> 1511             return self._call_impl(*args, **kwargs)
   1512 
   1513     def _call_impl(self, *args, **kwargs):



ipdb>  


--Return--
None
> /tmp/ipykernel_40248/1822690529.py(1)<module>()
----> 1 o = m(**b)



ipdb>  


    [... skipped 1 hidden frame]

> /home/scai/phd/aiz218323/.local/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3553)run_code()
   3551             finally:
   3552                 # Reset our crash handler in place
-> 3553                 sys.excepthook = old_excepthook
   3554         except SystemExit as e:
   3555             if result is not None:



ipdb>  


    [... skipped 1 hidden frame]

> /home/scai/phd/aiz218323/.local/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3574)run_code()
   3572             self.showtraceback(running_compiled_code=True)
   3573         else:
-> 3574             outflag = False
   3575         return outflag
   3576 



ipdb>  c


    [... skipped 1 hidden frame]

    [... skipped 1 hidden frame]



In [None]:
# Display the loss returned by the forward pass `o = m(**b)` stepped through above.
o.loss

### `DTL004`

In [None]:
#| export
class DTL004(DistilBertPreTrainedModel):
    """Distillation model that trains `m_student` with guidance from `m_teacher`.

    The forward pass starts from the student's own loss and adds four weighted terms:
    a `MultiTriplet` loss between teacher data and student label representations,
    a `MultiTriplet` loss between student fused-data and teacher label representations,
    and two MSE losses comparing teacher and student representations on the data side
    and on the label side.
    """
    use_representation,use_generation = True,False
    _tied_weights_keys = ["m_student.encoder.distilbert"]

    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        bsz:Optional[int]=None,
        tn_targ:Optional[int]=None,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        teacher_data_student_label_loss_weight:Optional[float]=1.0,
        student_data_teacher_label_loss_weight:Optional[float]=1.0,
        data_mse_loss_weight:Optional[float]=0.1,
        label_mse_loss_weight:Optional[float]=0.1,
        **kwargs
    ):
        """Store the two sub-models, the loss weights, and build the loss functions.

        Args:
            config: configuration handed to `DistilBertPreTrainedModel.__init__`.
            m_student: model being trained; called with the tokenised inputs in `forward`.
            m_teacher: model queried by `data_idx` / `lbl2data_idx` in `forward`.
            bsz, tn_targ, margin, tau, apply_softmax, n_negatives: forwarded unchanged
                to the `MultiTriplet` representation loss.
            teacher_data_student_label_loss_weight: multiplier of the triplet loss between
                teacher data and student label representations.
            student_data_teacher_label_loss_weight: multiplier of the triplet loss between
                student fused-data and teacher label representations.
            data_mse_loss_weight: multiplier of the data-side MSE loss.
            label_mse_loss_weight: multiplier of the label-side MSE loss.
            **kwargs: forwarded to `DistilBertPreTrainedModel.__init__`.
        """
        super().__init__(config, **kwargs)

        # store_attr is called directly from __init__ (not from a helper or comprehension).
        store_attr('m_student,m_teacher')
        store_attr('teacher_data_student_label_loss_weight,student_data_teacher_label_loss_weight')
        store_attr('data_mse_loss_weight,label_mse_loss_weight')

        self.mse_loss_fn = nn.MSELoss()
        triplet_cfg = dict(bsz=bsz, tn_targ=tn_targ, margin=margin, n_negatives=n_negatives, tau=tau,
                           apply_softmax=apply_softmax, reduce='mean')
        self.rep_loss_fn = MultiTriplet(**triplet_cfg)

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,

        data_idx:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student and, when a loss is available, add the distillation terms.

        Args:
            data_input_ids, data_attention_mask: passed to the student.
            data_idx: passed to the teacher only.
            lbl2data_idx: passed to both the student and the teacher, and to the triplet losses.
            **kwargs: forwarded to the student and to the triplet losses. When a loss is
                computed it must contain 'lbl2data_data2ptr', 'plbl2data_data2ptr' and
                'plbl2data_idx'.

        Returns:
            `RADOutput` carrying the student's representations. `loss` is None when
            `lbl2data_idx` is None or the student returned no loss.

        Raises:
            KeyError: if a loss is computed and one of the required `kwargs` entries is missing.
        """
        s_out = self.m_student(
            data_input_ids=data_input_ids,
            data_attention_mask=data_attention_mask,
            lbl2data_idx=lbl2data_idx,
            **kwargs,
        )

        total = None
        if not (lbl2data_idx is None or s_out.loss is None):
            t_out = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)

            # Pointer/index arguments shared by both triplet losses.
            pair_args = (kwargs['lbl2data_data2ptr'], lbl2data_idx,
                         kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'])

            # Teacher data vs. student labels, then student fused data vs. teacher labels.
            tdsl_loss = self.rep_loss_fn(t_out.data_repr, s_out.lbl2data_repr, *pair_args, **kwargs)
            sdtl_loss = self.rep_loss_fn(s_out.data_fused_repr, t_out.lbl2data_repr, *pair_args, **kwargs)

            # MSE between teacher and student representations on each side.
            dm_loss = self.mse_loss_fn(t_out.data_repr, s_out.data_fused_repr)
            lm_loss = self.mse_loss_fn(t_out.lbl2data_repr, s_out.lbl2data_repr)

            # Accumulates in place on the object returned as the student's loss.
            total = s_out.loss
            total += self.teacher_data_student_label_loss_weight * tdsl_loss
            total += self.student_data_teacher_label_loss_weight * sdtl_loss
            mse_term = self.data_mse_loss_weight * dm_loss + self.label_mse_loss_weight * lm_loss
            total += mse_term

        # Only the student's representations are exposed to the caller.
        return RADOutput(
            loss=total,
            data_repr=s_out.data_repr,
            data_fused_repr=s_out.data_fused_repr,
            lbl2data_repr=s_out.lbl2data_repr,
            lbl2data_fused_repr=s_out.lbl2data_fused_repr,
        )
        

#### Example

In [None]:
# Output directory of a previous run; the teacher checkpoint is read from its `teacher/` sub-folder.
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4'
# The teacher is sized with the number of training data points and the number of labels in `block`.
m_teacher = TCH002.from_pretrained(f'{model_output}/teacher', n_data=block.train.dset.n_data, n_lbl=block.n_lbl)

# NOTE(review): presumably disables gradient updates on the teacher's embedding tables — confirm in TCH002.
m_teacher.freeze_embeddings()

Some weights of TCH002 were not initialized from the model checkpoint at /home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4/teacher and are newly initialized: ['data_embeddings.weight', 'lbl_embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Build the student from the pretrained sentence-transformers DistilBERT checkpoint.
# Only the data side is augmented with metadata (prefix 'lnk2data'); the label side has no prefix.
m_student = OAK001.from_pretrained('sentence-transformers/msmarco-distilbert-cos-v5', batch_size=1000, num_batch_labels=5000, 
                               margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=True,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)

# Run the head initialisers on the student that was just created. `model` is only assigned
# further down (to the DTL004 wrapper), so calling them on `model` here would either fail on a
# fresh kernel or silently initialise a stale object left over from an earlier cell.
m_student.init_retrieval_head()
m_student.init_cross_head()

# Set the student's metadata embeddings to an all-zero (n_meta, config.dim) tensor.
m_student.encoder.set_meta_embeddings(torch.zeros(block.train.dset.meta['lnk_meta'].n_meta, m_student.config.dim))

Some weights of OAK001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-cos-v5 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enco

In [None]:
# Combine the student and teacher into the DTL004 distillation model.
# bsz/margin/tau/n_negatives/apply_softmax configure its MultiTriplet loss; the four `*_loss_weight`
# values scale the distillation terms that DTL004.forward adds to the student's own loss.
model = DTL004(DistilBertConfig(), m_student=m_student, m_teacher=m_teacher, bsz=1024, margin=0.3, tau=0.1, n_negatives=10, 
               apply_softmax=True, teacher_data_student_label_loss_weight=1.0, student_data_teacher_label_loss_weight=0.1, 
               data_mse_loss_weight=0.1,label_mse_loss_weight=0.1)

In [None]:
# Take the first batch from the training dataloader and turn it into model inputs.
# NOTE(review): `m_args` presumably names the extra batch fields `prepare_batch` keeps in addition to
# the model's named forward arguments — confirm against `prepare_batch`. DTL004.forward reads
# 'lbl2data_data2ptr', 'plbl2data_data2ptr' and 'plbl2data_idx' from `**kwargs`, so they must be listed.
batch = next(iter(block.train.dl))
b = prepare_batch(model, batch, m_args=['lbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 
                                        'lbl2data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx',
                                        'lnk2data_data2ptr', 'lnk2data_idx', 'lnk2data_input_ids', 
                                        'lnk2data_attention_mask', 'plnk2data_data2ptr', 'plnk2data_idx'
                                       ])

In [None]:
# Move both the model and the prepared batch to the GPU.
m,b = model.to('cuda'), b.to('cuda')

In [None]:
# Run one forward pass of the distillation model on the prepared batch.
o = m(**b)

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
# Display the loss returned by the forward pass above.
o.loss

tensor(0.1536, device='cuda:0', grad_fn=<AddBackward0>)

### `DTL005`

In [None]:
#| export
class DTL005(DistilBertPreTrainedModel):
    """Distillation wrapper that trains `m_student` with extra supervision from `m_teacher`.

    When the student produces a loss, the total loss is the sum of:

    * the student's own loss,
    * `teacher_data_student_label_loss_weight` times a `MultiTriplet` loss between the
      teacher's `data_repr` and the student's `lbl2data_repr`,
    * `data_mse_loss_weight` times an MSE loss between the teacher's `data_repr` and
      the student's `data_fused_repr`.
    """
    use_representation,use_generation = True,False
    _tied_weights_keys = ["m_student.encoder.distilbert"]
    
    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        bsz:Optional[int]=None,
        tn_targ:Optional[int]=None,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        teacher_data_student_label_loss_weight:Optional[float]=1.0,
        data_mse_loss_weight:Optional[float]=0.1,
        **kwargs
    ):
        """Store the two sub-models and build the loss functions.

        Args:
            config: configuration handed to `DistilBertPreTrainedModel.__init__`.
            m_student: module called in `forward` with the data inputs; its output is
                read for `loss`, `data_repr`, `data_fused_repr`, `lbl2data_repr` and
                `lbl2data_fused_repr`.
            m_teacher: module called in `forward` with `data_idx`; its output is read
                for `data_repr`.
            bsz, tn_targ, margin, tau, apply_softmax, n_negatives: forwarded unchanged
                to `MultiTriplet` (which is always built with `reduce='mean'`).
            teacher_data_student_label_loss_weight: multiplier for the triplet term.
            data_mse_loss_weight: multiplier for the MSE term.
            **kwargs: forwarded to the base-class constructor.
        """
        super().__init__(config, **kwargs)
        store_attr('m_student,m_teacher')
        store_attr('teacher_data_student_label_loss_weight,data_mse_loss_weight')
        self.mse_loss_fn = nn.MSELoss()
        self.rep_loss_fn = MultiTriplet(bsz=bsz, tn_targ=tn_targ, margin=margin, n_negatives=n_negatives, tau=tau, 
                                        apply_softmax=apply_softmax, reduce='mean')
        
    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student and, if it returned a loss, add the distillation terms.

        Args:
            data_input_ids: passed to the student as `data_input_ids`.
            data_attention_mask: passed to the student as `data_attention_mask`.
            data_idx: passed to the teacher as `data_idx` (only used when the student
                returns a loss).
            **kwargs: forwarded to the student and to the triplet loss. When the student
                returns a loss, `kwargs` must contain `lbl2data_data2ptr`,
                `lbl2data_idx`, `plbl2data_data2ptr` and `plbl2data_idx`.

        Returns:
            RADOutput carrying the combined loss (or `None` when the student returned no
            loss) and the student's data / label representations.

        Raises:
            KeyError: if the student returns a loss and one of the required
                `lbl2data_*` / `plbl2data_*` entries is missing from `kwargs`.
        """
        student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, **kwargs)

        loss = None
        if student_o.loss is not None:
            teacher_o = self.m_teacher(data_idx=data_idx)

            # Teacher data representations vs. student label representations.
            tdsl_loss = self.rep_loss_fn(teacher_o.data_repr, student_o.lbl2data_repr, kwargs['lbl2data_data2ptr'], kwargs['lbl2data_idx'], 
                                         kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'], **kwargs)

            # Pull the student's fused data representation towards the teacher's.
            dm_loss = self.mse_loss_fn(teacher_o.data_repr, student_o.data_fused_repr)

            # Combine out-of-place: `loss = student_o.loss; loss += ...` would modify the
            # tensor stored on `student_o` in place, silently turning `student_o.loss`
            # into the total. Summation order is kept so the value is unchanged.
            loss = (
                student_o.loss
                + self.teacher_data_student_label_loss_weight * tdsl_loss
                + self.data_mse_loss_weight * dm_loss
            )

        return RADOutput(
            loss=loss,
            
            data_repr=student_o.data_repr,
            data_fused_repr=student_o.data_fused_repr,
            
            lbl2data_repr=student_o.lbl2data_repr,
            lbl2data_fused_repr=student_o.lbl2data_fused_repr,
        )
        

#### Example

In [None]:
# Directory holding the outputs of a previous run; the teacher checkpoint lives in
# its `teacher` sub-directory.
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4'

# Load the teacher, sized to the training set and the label set of `block`.
m_teacher = TCH003.from_pretrained(
    f'{model_output}/teacher',
    n_data=block.train.dset.n_data,
    n_lbl=block.n_lbl,
)

m_teacher.freeze_embeddings()

In [None]:
# Build the OAK001 student from the pretrained sentence-transformers checkpoint.
m_student = OAK001.from_pretrained(
    'sentence-transformers/msmarco-distilbert-cos-v5',
    batch_size=1000,
    num_batch_labels=5000,
    margin=0.3,
    num_negatives=5,
    tau=0.1,
    apply_softmax=True,
    # metadata prefixes: only the data side is augmented, with `lnk2data`
    data_aug_meta_prefix='lnk2data',
    lbl2data_aug_meta_prefix=None,
    data_pred_meta_prefix=None,
    lbl2data_pred_meta_prefix=None,
    # metadata table size
    num_metadata=block.train.dset.meta['lnk_meta'].n_meta,
    resize_length=5000,
    # `calib_*` settings
    calib_margin=0.3,
    calib_num_negatives=10,
    calib_tau=0.1,
    calib_apply_softmax=False,
    calib_loss_weight=0.1,
    use_calib_loss=True,
    # remaining loss switches / weights
    use_query_loss=True,
    meta_loss_weight=0.0,
    fusion_loss_weight=0.1,
    use_fusion_loss=False,
    use_encoder_parallel=False,
)

# Initialise the retrieval and cross heads.
m_student.init_retrieval_head()
m_student.init_cross_head()

# Start from an all-zero metadata embedding table, one row per `lnk_meta` entry.
m_student.encoder.set_meta_embeddings(
    torch.zeros(block.train.dset.meta['lnk_meta'].n_meta, m_student.config.dim)
)

Some weights of OAK001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-cos-v5 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_embeddings.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'enco

In [None]:
# Wrap the student and teacher built above in a DTL005 distillation model.
model = DTL005(
    DistilBertConfig(),
    m_student=m_student,
    m_teacher=m_teacher,
    bsz=1024,
    margin=0.3,
    tau=0.1,
    n_negatives=10,
    apply_softmax=True,
    teacher_data_student_label_loss_weight=1.0,
    data_mse_loss_weight=0.1,
)

In [None]:
# Take the first batch from the training dataloader.
batch = next(iter(block.train.dl))

# Prepare the batch for `model`, explicitly listing the label (`lbl2data`) and
# link-metadata (`lnk2data`) fields via `m_args`.
b = prepare_batch(
    model,
    batch,
    m_args=[
        'lbl2data_data2ptr',
        'lbl2data_idx',
        'lbl2data_input_ids',
        'lbl2data_attention_mask',
        'plbl2data_data2ptr',
        'plbl2data_idx',
        'lnk2data_data2ptr',
        'lnk2data_idx',
        'lnk2data_input_ids',
        'lnk2data_attention_mask',
        'plnk2data_data2ptr',
        'plnk2data_idx',
    ],
)

In [None]:
# Move both the model and the prepared batch to the GPU.
m = model.to('cuda')
b = b.to('cuda')

In [None]:
# Forward pass: the prepared batch is unpacked into keyword arguments.
o = m(**b)

In [None]:
# Display the loss attached to the forward-pass output.
# NOTE(review): the output recorded for this cell is an AttributeError on a `None`
# `o`, yet `DTL005.forward` as written always returns a `RADOutput` — the recorded
# output looks stale; re-run the cells above on a fresh kernel to confirm.
o.loss

AttributeError: 'NoneType' object has no attribute 'loss'