In [None]:
#| default_exp models.distillation

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import torch, numpy as np
from typing import Optional
import torch.nn as nn
from dataclasses import dataclass

from xcai.core import store_attr
from xcai.losses import Cosine, MultiTriplet
from xcai.models.PPP0XX import XCModelOutput
from transformers import DistilBertPreTrainedModel,DistilBertConfig
from transformers.utils.generic import ModelOutput

comet_ml is installed but `COMET_API_KEY` is not set.


In [None]:
import os,torch, pickle, numpy as np

from xcai.block import *
from xcai.basics import *
from xcai.models.PPP0XX import DBT010

## Setup

In [None]:
# Location of the pre-processed dataset pickle.
# NOTE(review): absolute, user-specific paths; adjust `pkl_dir` when running elsewhere.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealso_data-metas_distilbert-base-uncased_rm_radga-aug-cat-hlk-block-032.pkl'

In [None]:
# Load the pickled block object; `pickle.load` executes arbitrary code, so only use trusted files.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [None]:
# Expose the concatenated-augmentation tensors under the `aug_*` keys,
# for both the train and the test split.
for _split in (block.train, block.test):
    _info = _split.dset.data.data_info
    _info['aug_input_ids'] = _info['input_ids_aug_cat']
    _info['aug_attention_mask'] = _info['attention_mask_aug_cat']

In [None]:
# Replace the plain inputs with the concatenated-augmentation tensors on both splits.
# NOTE(review): this overwrites `input_ids`/`attention_mask` in place; the original
# un-augmented inputs are no longer reachable under those keys afterwards.
for _split in (block.train, block.test):
    _info = _split.dset.data.data_info
    _info['input_ids'] = _info['input_ids_aug_cat']
    _info['attention_mask'] = _info['attention_mask_aug_cat']

## Teacher

### Base

In [None]:
from safetensors import safe_open

In [None]:
# Learner arguments for this notebook: step-based evaluation and checkpointing
# every 3000 steps (at most 5 checkpoints kept), fp16, batch size 800 per device.
# NOTE(review): `output_dir` is an absolute, user-specific path.
args = XCLearningArguments(
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/69-distillation-for-wikiseealso-1-1',
    logging_first_step=True,
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    representation_num_beams=200,
    representation_accumulation_steps=10,
    save_strategy="steps",
    evaluation_strategy="steps",
    eval_steps=3000,
    save_steps=3000,
    save_total_limit=5,
    num_train_epochs=300,
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    adam_epsilon=1e-6,
    warmup_steps=100,
    weight_decay=0.01,
    learning_rate=2e-4,
    group_by_cluster=True,
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    clustering_type='EXPO',
    minimum_cluster_size=2,
    maximum_cluster_size=1600,
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    use_encoder_parallel=True,
    max_grad_norm=None,
    fp16=True,
    label_names=['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask'],
)

In [None]:
# Resolve the best checkpoint directory of the earlier teacher run.
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-1'
output_dir = f"/home/scai/phd/aiz218323/scratch/outputs/{os.path.basename(model_output)}"
mname = f'{output_dir}/{os.path.basename(get_best_model(output_dir))}'

# Build the teacher architecture from the public base checkpoint ...
m_teacher = DBT010.from_pretrained(
    'sentence-transformers/msmarco-distilbert-base-v4',
    bsz=800, tn_targ=5000, margin=0.3, tau=0.1,
    n_negatives=10, apply_softmax=True, use_encoder_parallel=True,
)

# ... then read every tensor stored in the fine-tuned safetensors file.
model_weight_file = f'{mname}/model.safetensors'
with safe_open(model_weight_file, framework="pt") as file:
    model_weights = {name: file.get_tensor(name) for name in file.keys()}

# Non-strict load: the returned report of missing/unexpected keys is the cell output.
m_teacher.load_state_dict(model_weights, strict=False)

Some weights of DBT010 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


_IncompatibleKeys(missing_keys=['distilbert.embeddings.word_embeddings.weight', 'distilbert.embeddings.position_embeddings.weight', 'distilbert.embeddings.LayerNorm.weight', 'distilbert.embeddings.LayerNorm.bias', 'distilbert.transformer.layer.0.attention.q_lin.weight', 'distilbert.transformer.layer.0.attention.q_lin.bias', 'distilbert.transformer.layer.0.attention.k_lin.weight', 'distilbert.transformer.layer.0.attention.k_lin.bias', 'distilbert.transformer.layer.0.attention.v_lin.weight', 'distilbert.transformer.layer.0.attention.v_lin.bias', 'distilbert.transformer.layer.0.attention.out_lin.weight', 'distilbert.transformer.layer.0.attention.out_lin.bias', 'distilbert.transformer.layer.0.sa_layer_norm.weight', 'distilbert.transformer.layer.0.sa_layer_norm.bias', 'distilbert.transformer.layer.0.ffn.lin1.weight', 'distilbert.transformer.layer.0.ffn.lin1.bias', 'distilbert.transformer.layer.0.ffn.lin2.weight', 'distilbert.transformer.layer.0.ffn.lin2.bias', 'distilbert.transformer.layer.

In [None]:
# Evaluation metric built from the label count, the test-split filterer and the
# training data-label matrix (passed as `prop`).
metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

In [None]:
# Wrap the teacher model in a learner so it can be evaluated and used to
# extract representations below.
learn = XCLearner(
    model=m_teacher, 
    args=args,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Run the teacher on the test split; the result's metrics are displayed in the next cell.
o = learn.predict(block.test.dset)

  0%|          | 0/196 [00:00<?, ?it/s]

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


  self._set_arrayXarray(i, j, x)


In [None]:
# Show the metrics computed by the prediction run above.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,39.1787,25.1231,18.8719,11.9556,39.1787,38.7991,40.2252,42.7056,27.7771,30.1143,32.8007,38.2846,27.7771,31.0783,33.3646,36.2856,49.0181,66.5973,70.1792,0.0198,343.8718,516.224,0.323


### Query-label representation

In [None]:
# Teacher representations for the training data points.
dataset = learn.train_dataset.data_dset
dataloader = learn.get_test_dataloader(dataset)
data_repr = learn.get_representation(dataloader, representation_attribute='data_repr')

  0%|          | 0/434 [00:00<?, ?it/s]

In [None]:
# Teacher representations for the labels. The label dataset is read through the
# same 'data_repr' output attribute as the data points above.
# NOTE(review): presumably labels are fed to the model as data inputs here; confirm.
dataset = learn.train_dataset.lbl_dset
dataloader = learn.get_test_dataloader(dataset)
lbl_repr = learn.get_representation(dataloader, representation_attribute='data_repr')

  0%|          | 0/196 [00:00<?, ?it/s]

### `TCH001`

In [None]:
#| export
@dataclass
class TCHOutput(ModelOutput):
    """Output container returned by `TCH001.forward`."""
    # Representations looked up for the requested data indices.
    data_repr: Optional[torch.FloatTensor] = None
    # Representations looked up for the requested label indices.
    lbl2data_repr: Optional[torch.FloatTensor] = None
    

In [None]:
#| export
class TCH001(DistilBertPreTrainedModel):
    """Lookup-table teacher: serves stored data/label representations by index.

    The model holds two embedding tables, one row per data point and one row per
    label, each of width `config.dim`. `forward` only indexes into these tables.
    """

    def __init__(self, config, n_data:int, n_lbl:int, **kwargs):
        """Create the two embedding tables.

        Args:
            config: model configuration; `config.dim` sets the representation width.
            n_data: number of rows in the data table.
            n_lbl: number of rows in the label table.
        """
        super().__init__(config, **kwargs)
        store_attr('n_data,n_lbl')
        self.data_repr = nn.Embedding(self.n_data, config.dim)
        self.lbl_repr = nn.Embedding(self.n_lbl, config.dim)

    @staticmethod
    def _fill_table(table:nn.Embedding, values:torch.Tensor, name:str):
        """Copy `values` into `table.weight`, rejecting mismatched shapes."""
        values = torch.as_tensor(values)
        if tuple(values.shape) != tuple(table.weight.shape):
            raise ValueError(
                f"`{name}` has shape {tuple(values.shape)}, expected {tuple(table.weight.shape)}."
            )
        # Copy in place so the registered parameter (and hence the state dict that
        # `save_pretrained` writes) holds the supplied values.
        with torch.no_grad():
            table.weight.copy_(values)

    def init_embeddings(self, data_repr:torch.Tensor, lbl_repr:torch.Tensor):
        """Load precomputed representations into the embedding tables.

        Assigning to `.data` on the `nn.Embedding` module itself would only create
        an unrelated attribute and leave `.weight` untouched, so the values are
        written into `.weight` explicitly.

        Args:
            data_repr: tensor of shape `(n_data, config.dim)`.
            lbl_repr: tensor of shape `(n_lbl, config.dim)`.

        Raises:
            ValueError: if a tensor does not match the shape of its table.
        """
        self._fill_table(self.data_repr, data_repr, 'data_repr')
        self._fill_table(self.lbl_repr, lbl_repr, 'lbl_repr')

    def forward(
        self,
        data_idx:torch.Tensor,
        lbl2data_idx:torch.Tensor,
    ):
        """Look up the stored representations for the given data and label indices."""
        return TCHOutput(
            data_repr=self.data_repr(data_idx),
            lbl2data_repr= self.lbl_repr(lbl2data_idx),
        )
        

#### Example

In [None]:
# One table row per training data point and per label.
model = TCH001(DistilBertConfig(), n_data=block.train.dset.n_data, n_lbl=block.n_lbl)

In [None]:
# Pass the representations extracted from the base teacher to the lookup model.
model.init_embeddings(data_repr, lbl_repr)

In [None]:
# Take the first training batch and prepare it for this model.
# NOTE(review): presumably `prepare_batch` keeps only the keys `forward` accepts; confirm.
batch = next(iter(block.train.dl))
b = prepare_batch(model, batch)

In [None]:
# Forward pass: index lookups only.
o = model(**b)

In [None]:
# Sanity-check the shapes of the looked-up representations.
o.data_repr.shape, o.lbl2data_repr.shape

(torch.Size([10, 768]), torch.Size([10, 768]))

In [None]:
# Persist the lookup teacher under `<model_output>/teacher`.
model.save_pretrained(f'{model_output}/teacher')

In [None]:
# Reload the saved lookup teacher; this rebinds `m_teacher`, which previously held the DBT010 teacher.
m_teacher = TCH001.from_pretrained(f'{model_output}/teacher', n_data=block.train.dset.n_data, n_lbl=block.n_lbl)

## Distillation

### `DTL001`

In [None]:
#| export
class DTL001(DistilBertPreTrainedModel):
    """Student/teacher wrapper that adds an embedding-similarity distillation loss.

    The student is run on the regular inputs and the teacher on the augmented
    inputs. The training loss is the student's own loss plus a weighted `Cosine`
    loss between student and teacher embeddings (data side and label side).
    """
    use_representation,use_generation = True,False
    # One pattern per sub-model. Previously both prefixes were fused into a single
    # comma-containing string, which does not name either prefix.
    _tied_weights_keys = ["m_student.encoder.distilbert", "m_teacher.encoder.distilbert"]
    
    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        embed_sim_loss_weight:Optional[float]=1.0,
        **kwargs
    ):
        """Store both sub-models and the weight of the distillation term.

        Args:
            config: configuration forwarded to the pretrained-model base class.
            m_student: model being trained; must return `loss`, `data_embed`,
                `lbl2data_embed`, `data_repr` and `lbl2data_repr`.
            m_teacher: model providing target `data_embed` / `lbl2data_embed`.
            embed_sim_loss_weight: multiplier applied to the distillation loss.
        """
        super().__init__(config, **kwargs)
        store_attr('m_student,m_teacher')
        self.s_lw = embed_sim_loss_weight
        
        self.loss_fn = Cosine(reduce='mean')

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        data_aug_input_ids:Optional[torch.Tensor]=None,
        data_aug_attention_mask:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student and, when possible, add the distillation loss.

        The combined loss is only computed when augmented inputs are supplied and
        the student itself produced a loss; otherwise `loss` is `None`. The label
        side of the distillation term reads `kwargs['lbl2data_attention_mask']`.

        Returns:
            `XCModelOutput` with the combined loss and the student's representations.
        """
        student_o = self.m_student(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, **kwargs)

        loss = None
        if data_aug_input_ids is not None and student_o.loss is not None:
            # The teacher only provides targets, so no gradients are tracked for it.
            # NOTE(review): the teacher is a registered sub-module and follows this
            # module's train/eval mode; confirm whether it should be forced to eval.
            with torch.no_grad(): 
                teacher_o = self.m_teacher(data_input_ids=data_aug_input_ids, data_attention_mask=data_aug_attention_mask, **kwargs)

            # Data side: student sees the plain input, teacher the augmented one.
            dloss = self.loss_fn(student_o.data_embed, data_attention_mask, teacher_o.data_embed, data_aug_attention_mask)
            # Label side: both models receive the same label inputs and mask.
            lbl_mask = kwargs['lbl2data_attention_mask']
            dloss += self.loss_fn(student_o.lbl2data_embed, lbl_mask, teacher_o.lbl2data_embed, lbl_mask)
            loss = student_o.loss + self.s_lw * dloss
            
        return XCModelOutput(
            loss=loss,
            data_repr=student_o.data_repr,
            lbl2data_repr=student_o.lbl2data_repr,
        )
        

#### Example

In [None]:
from safetensors import safe_open

In [None]:
# Resolve the best checkpoint directory of the earlier teacher run.
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-1'
output_dir = f"/home/scai/phd/aiz218323/scratch/outputs/{os.path.basename(model_output)}"
mname = f'{output_dir}/{os.path.basename(get_best_model(output_dir))}'

# Build the teacher architecture from the public base checkpoint ...
m_teacher = DBT010.from_pretrained(
    'sentence-transformers/msmarco-distilbert-base-v4',
    bsz=800, tn_targ=5000, margin=0.3, tau=0.1,
    n_negatives=10, apply_softmax=True, use_encoder_parallel=True,
)

# ... then read every tensor stored in the fine-tuned safetensors file.
model_weight_file = f'{mname}/model.safetensors'
with safe_open(model_weight_file, framework="pt") as file:
    model_weights = {name: file.get_tensor(name) for name in file.keys()}

# Non-strict load: the returned report of missing/unexpected keys is the cell output.
m_teacher.load_state_dict(model_weights, strict=False)

Some weights of DBT010 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


_IncompatibleKeys(missing_keys=['distilbert.embeddings.word_embeddings.weight', 'distilbert.embeddings.position_embeddings.weight', 'distilbert.embeddings.LayerNorm.weight', 'distilbert.embeddings.LayerNorm.bias', 'distilbert.transformer.layer.0.attention.q_lin.weight', 'distilbert.transformer.layer.0.attention.q_lin.bias', 'distilbert.transformer.layer.0.attention.k_lin.weight', 'distilbert.transformer.layer.0.attention.k_lin.bias', 'distilbert.transformer.layer.0.attention.v_lin.weight', 'distilbert.transformer.layer.0.attention.v_lin.bias', 'distilbert.transformer.layer.0.attention.out_lin.weight', 'distilbert.transformer.layer.0.attention.out_lin.bias', 'distilbert.transformer.layer.0.sa_layer_norm.weight', 'distilbert.transformer.layer.0.sa_layer_norm.bias', 'distilbert.transformer.layer.0.ffn.lin1.weight', 'distilbert.transformer.layer.0.ffn.lin1.bias', 'distilbert.transformer.layer.0.ffn.lin2.weight', 'distilbert.transformer.layer.0.ffn.lin2.bias', 'distilbert.transformer.layer.

In [None]:
# Student starts from the public base checkpoint (no fine-tuned weights loaded).
# NOTE(review): `init_dr_head` presumably initialises the newly created `dr_*` layers; confirm.
m_student = DBT010.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', bsz=800, tn_targ=5000, margin=0.3, tau=0.1, 
                                       n_negatives=10, apply_softmax=True, use_encoder_parallel=True)
m_student.init_dr_head()

Some weights of DBT010 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Combine student and teacher; the distillation term is weighted 1.0.
model = DTL001(DistilBertConfig(), m_student=m_student, m_teacher=m_teacher, embed_sim_loss_weight=1.0)

In [None]:
# First training batch; the label-related keys are listed explicitly in `m_args`
# because `forward` receives them only through `**kwargs`.
batch = next(iter(block.train.dl))
b = prepare_batch(model, batch, m_args=['lbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 
                                        'lbl2data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx'])

In [None]:
# Move model and batch to the GPU (requires a CUDA device).
m,b = model.to('cuda'), b.to('cuda')

In [None]:
# Forward pass through student and teacher.
o = m(**b)

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
# Combined loss: student loss plus weighted distillation term.
o.loss

tensor(1.6006, device='cuda:0', grad_fn=<AddBackward0>)

### `DTL002`

In [None]:
#| export
class DTL002(DistilBertPreTrainedModel):
    """Distillation wrapper whose teacher is queried by data/label index.

    On top of the student's own loss, two terms are added: a `MultiTriplet`
    ranking loss computed across student and teacher representations, and an
    MSE loss pulling student representations towards the teacher's.
    """
    use_representation,use_generation = True,False
    _tied_weights_keys = ["m_student.encoder.distilbert"]

    def __init__(
        self,
        config,
        m_student:nn.Module,
        m_teacher:nn.Module,
        bsz:Optional[int]=None,
        tn_targ:Optional[int]=None,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=False,
        n_negatives:Optional[int]=5,
        distil_loss_weight:Optional[float]=1.0,
        mse_loss_weight:Optional[float]=0.1,
        **kwargs
    ):
        """Keep both sub-models, the two loss weights and build the loss modules.

        `bsz`, `tn_targ`, `margin`, `tau`, `apply_softmax` and `n_negatives` are
        handed unchanged to `MultiTriplet`.
        """
        super().__init__(config, **kwargs)
        self.d_lw = distil_loss_weight
        self.m_lw = mse_loss_weight
        store_attr('m_student,m_teacher')
        self.mse_loss_fn = nn.MSELoss()
        self.rep_loss_fn = MultiTriplet(
            bsz=bsz,
            tn_targ=tn_targ,
            margin=margin,
            n_negatives=n_negatives,
            tau=tau,
            apply_softmax=apply_softmax,
            reduce='mean',
        )

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,

        data_idx:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        **kwargs
    ):
        """Run the student; add ranking and MSE distillation terms when labels are given.

        The loss is `None` unless `lbl2data_idx` is supplied and the student
        returned a loss of its own.
        """
        student_o = self.m_student(
            data_input_ids=data_input_ids,
            data_attention_mask=data_attention_mask,
            lbl2data_idx=lbl2data_idx,
            **kwargs
        )

        total = None
        if lbl2data_idx is not None and student_o.loss is not None:
            # Teacher representations are plain lookups; no gradient is tracked.
            with torch.no_grad():
                teacher_o = self.m_teacher(data_idx=data_idx, lbl2data_idx=lbl2data_idx)

            def _cross_ranking(inp_repr, targ_repr):
                # Same pointer/index arguments for both directions of the cross term.
                return self.rep_loss_fn(
                    inp_repr, targ_repr,
                    kwargs['lbl2data_data2ptr'], lbl2data_idx,
                    kwargs['plbl2data_data2ptr'], kwargs['plbl2data_idx'],
                    **kwargs
                )

            # Teacher data vs. student labels, then student data vs. teacher labels.
            ranking = (
                _cross_ranking(teacher_o.data_repr, student_o.lbl2data_repr)
                + _cross_ranking(student_o.data_repr, teacher_o.lbl2data_repr)
            )

            data_mse = self.mse_loss_fn(teacher_o.data_repr, student_o.data_repr)
            lbl_mse = self.mse_loss_fn(teacher_o.lbl2data_repr, student_o.lbl2data_repr)
            matching = data_mse + lbl_mse

            total = student_o.loss + self.d_lw * ranking + self.m_lw * matching

        return XCModelOutput(
            loss=total,
            data_repr=student_o.data_repr,
            lbl2data_repr=student_o.lbl2data_repr,
        )
        

#### Example

In [None]:
# Load the lookup teacher saved under `<model_output>/teacher`.
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-1'
m_teacher = TCH001.from_pretrained(f'{model_output}/teacher', n_data=block.train.dset.n_data, n_lbl=block.n_lbl)

In [None]:
# Student starts from the public base checkpoint (no fine-tuned weights loaded).
# NOTE(review): `init_dr_head` presumably initialises the newly created `dr_*` layers; confirm.
m_student = DBT010.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', bsz=800, tn_targ=5000, margin=0.3, tau=0.1, 
                                   n_negatives=10, apply_softmax=True, use_encoder_parallel=True)
m_student.init_dr_head()

Some weights of DBT010 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Ranking term weighted 1.0, MSE term weighted 0.1.
# NOTE(review): `bsz=1024` here differs from the `bsz=800` given to the student; confirm this is intended.
model = DTL002(DistilBertConfig(), m_student=m_student, m_teacher=m_teacher, bsz=1024, margin=0.3, tau=0.1, 
               n_negatives=10, apply_softmax=True, distil_loss_weight=1.0, mse_loss_weight=0.1)

In [None]:
# First training batch; the label-related keys are listed explicitly in `m_args`
# because `forward` receives most of them only through `**kwargs`.
batch = next(iter(block.train.dl))
b = prepare_batch(model, batch, m_args=['lbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 
                                        'lbl2data_attention_mask', 'plbl2data_data2ptr', 'plbl2data_idx'])

In [None]:
# Inspect which keys the prepared batch carries.
b.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_input_ids', 'data_attention_mask', 'data_idx'])

In [None]:
# Move model and batch to the GPU (requires a CUDA device).
m,b = model.to('cuda'), b.to('cuda')

In [None]:
# Forward pass through student and lookup teacher.
o = m(**b)

In [None]:
# Combined loss: student loss plus weighted ranking and MSE terms.
o.loss

tensor(0.6356, device='cuda:0', grad_fn=<AddBackward0>)