In [None]:
#| default_exp 69-distillation-for-wikiseealso-1-0

In [None]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, numpy as np
from transformers import DistilBertConfig

from xcai.basics import *
from xcai.models.PPP0XX import DBT010
from xcai.models.distillation import DTL001

In [None]:
# Notebook-only cell (it carries no export directive): sets WANDB_MODE to 'disabled'.
os.environ.update({'WANDB_MODE': 'disabled'})

In [None]:
#| export
# Process-wide environment: which GPUs are visible and the W&B project name.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',
    'WANDB_PROJECT': 'xc-nlg_69-distillation-for-wikiseealso',
})

In [None]:
#| export
# Root of the dataset directory and the full path of the pickled data block
# that the next cell loads.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = '/'.join((
    pkl_dir,
    'processed',
    'wikiseealso_data-metas_distilbert-base-uncased_rm_radga-aug-cat-hlk-block-032.pkl',
))

In [None]:
#| export
# Deserialize the data block from `pkl_file`.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point `pkl_file` at pickles you produced or otherwise trust.
with open(pkl_file, 'rb') as file:
    block = pickle.load(file)

In [None]:
# For both splits, bind the `input_ids` / `attention_mask` entries of
# `data_info` to the corresponding `*_aug_cat` entries.
for _split in (block.train, block.test):
    _info = _split.dset.data.data_info
    _info['input_ids'] = _info['input_ids_aug_cat']
    _info['attention_mask'] = _info['attention_mask_aug_cat']

In [None]:
#| export
# For both splits, expose the `*_aug_cat` entries of `data_info` under the
# `aug_input_ids` / `aug_attention_mask` keys as well.
for _split in (block.train, block.test):
    _data_info = _split.dset.data.data_info
    for _field in ('input_ids', 'attention_mask'):
        _data_info[f'aug_{_field}'] = _data_info[f'{_field}_aug_cat']

## Models

In [None]:
#| export
# Keyword arguments for XCLearningArguments, grouped by theme and merged in
# the single constructor call at the bottom. The groups only organise the
# values; every key/value pair is passed through unchanged.

# Output location, logging, checkpointing and evaluation cadence.
_output_cfg = dict(
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/69-distillation-for-wikiseealso-1-0',
    logging_first_step=True,
    save_strategy="steps",
    evaluation_strategy="steps",
    eval_steps=3000,
    save_steps=3000,
    save_total_limit=5,
)

# Batch sizes, run length and execution flags.
_run_cfg = dict(
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    num_train_epochs=300,
    use_encoder_parallel=True,
    fp16=True,
)

# `representation_*` prediction settings.
_representation_cfg = dict(
    representation_num_beams=200,
    representation_accumulation_steps=10,
    predict_with_representation=True,
    representation_search_type='INDEX',
)

# Optimiser-related values.
_optim_cfg = dict(
    adam_epsilon=1e-6,
    warmup_steps=100,
    weight_decay=0.01,
    learning_rate=2e-4,
    max_grad_norm=None,
)

# Cluster-based batching settings.
_cluster_cfg = dict(
    group_by_cluster=True,
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    clustering_type='EXPO',
    minimum_cluster_size=2,
    maximum_cluster_size=1600,
)

# Names of the batch fields used as targets / labels.
_target_cfg = dict(
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    label_names=['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask'],
)

args = XCLearningArguments(
    **_output_cfg,
    **_run_cfg,
    **_representation_cfg,
    **_optim_cfg,
    **_cluster_cfg,
    **_target_cfg,
)

In [None]:
#| export
from safetensors import safe_open

In [None]:
#| export
# Batch size handed to the model: the larger of the two per-device batch
# sizes multiplied by the number of CUDA devices torch reports.
_per_device_bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)
bsz = _per_device_bsz * torch.cuda.device_count()

# Resolve the checkpoint directory of the previously trained teacher run;
# `get_best_model` picks the checkpoint inside `output_dir`.
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-1'
output_dir = f"/home/scai/phd/aiz218323/scratch/outputs/{os.path.basename(model_output)}"
_checkpoint_name = os.path.basename(get_best_model(output_dir))
mname = f'{output_dir}/{_checkpoint_name}'

# Build the teacher from the public base checkpoint first ...
_teacher_kwargs = dict(
    bsz=bsz,
    tn_targ=5000,
    margin=0.3,
    tau=0.1,
    n_negatives=10,
    apply_softmax=True,
    use_encoder_parallel=True,
)
m_teacher = DBT010.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', **_teacher_kwargs)

# ... then read every tensor stored in the run's safetensors file.
model_weight_file = f'{mname}/model.safetensors'
model_weights = {}
with safe_open(model_weight_file, framework="pt") as file:
    model_weights.update((_key, file.get_tensor(_key)) for _key in file.keys())

# strict=False reports key mismatches in the returned value instead of raising;
# leaving the call as the last expression keeps that report as the cell output.
m_teacher.load_state_dict(model_weights, strict=False)

Some weights of DBT010 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


_IncompatibleKeys(missing_keys=['distilbert.embeddings.word_embeddings.weight', 'distilbert.embeddings.position_embeddings.weight', 'distilbert.embeddings.LayerNorm.weight', 'distilbert.embeddings.LayerNorm.bias', 'distilbert.transformer.layer.0.attention.q_lin.weight', 'distilbert.transformer.layer.0.attention.q_lin.bias', 'distilbert.transformer.layer.0.attention.k_lin.weight', 'distilbert.transformer.layer.0.attention.k_lin.bias', 'distilbert.transformer.layer.0.attention.v_lin.weight', 'distilbert.transformer.layer.0.attention.v_lin.bias', 'distilbert.transformer.layer.0.attention.out_lin.weight', 'distilbert.transformer.layer.0.attention.out_lin.bias', 'distilbert.transformer.layer.0.sa_layer_norm.weight', 'distilbert.transformer.layer.0.sa_layer_norm.bias', 'distilbert.transformer.layer.0.ffn.lin1.weight', 'distilbert.transformer.layer.0.ffn.lin1.bias', 'distilbert.transformer.layer.0.ffn.lin2.weight', 'distilbert.transformer.layer.0.ffn.lin2.bias', 'distilbert.transformer.layer.

In [None]:
#| export
# Same batch-size rule as for the teacher: larger per-device batch size
# times the number of CUDA devices torch reports.
_per_device_bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)
bsz = _per_device_bsz * torch.cuda.device_count()

# The student starts from the public base checkpoint only; no fine-tuned
# weights are loaded into it in this cell.
_student_kwargs = dict(
    bsz=bsz,
    tn_targ=5000,
    margin=0.3,
    tau=0.1,
    n_negatives=10,
    apply_softmax=True,
    use_encoder_parallel=True,
)
m_student = DBT010.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', **_student_kwargs)
m_student.init_dr_head()

Some weights of DBT010 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
#| export
# Metric object later passed to the learners as `compute_metrics`; it is built
# from the label count, the test split's filterer and the training
# data-label object.
_metric_kwargs = dict(
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)
metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, **_metric_kwargs)

In [None]:
# Notebook-only learner wrapped around the teacher model; the following
# cells use it to run prediction on the test split.
_teacher_learner_kwargs = dict(
    model=m_teacher,
    args=args,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    data_collator=block.collator,
    compute_metrics=metric,
)
learn = XCLearner(**_teacher_learner_kwargs)

In [None]:
# Run prediction over the test dataset with the current learner.
_test_dset = block.test.dset
o = learn.predict(_test_dset)

  0%|          | 0/196 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)


In [None]:
# Show the metrics attached to the latest prediction output `o`.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,17.2622,11.2798,8.5273,5.5284,17.2622,17.0362,17.7345,19.0837,16.3327,16.7056,17.7868,20.5618,16.3327,17.3546,18.4396,20.0053,22.9164,37.1221,41.469,0.0335,252.57,702.835,0.439


In [None]:
# Show the metrics attached to the latest prediction output `o`.
# NOTE(review): this repeats the previous cell's call, yet the two saved outputs
# differ, so `o` was presumably recomputed between the two executions — confirm
# which run each table belongs to.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,39.1849,25.1252,18.8724,11.9554,39.1849,38.8025,40.2272,42.7062,27.787,30.1188,32.8063,38.2856,27.787,31.084,33.3704,36.2891,49.0164,66.6496,70.3424,0.0198,702.0258,252.861,0.158


## `Distillation`

In [None]:
#| export
# Distillation wrapper pairing the student with the teacher; it is given a
# default-constructed DistilBertConfig.
_distill_config = DistilBertConfig()
model = DTL001(
    _distill_config,
    m_student=m_student,
    m_teacher=m_teacher,
    embed_sim_loss_weight=1.0,
)

In [None]:
#| export
# Learner for the distillation model; this rebinds `learn`, replacing the
# teacher-only learner created earlier in the notebook.
_distill_learner_kwargs = dict(
    model=model,
    args=args,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    data_collator=block.collator,
    compute_metrics=metric,
)
learn = XCLearner(**_distill_learner_kwargs)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Notebook-only training run (this cell carries no export directive); the saved
# output below ends in a KeyboardInterrupt, i.e. the run was stopped manually.
learn.train()

Step,Training Loss,Validation Loss,P@1,P@10,P@3,P@5,N@1,N@10,N@3,N@5,Psp@1,Psp@10,Psp@3,Psp@5,Psn@1,Psn@10,Psn@3,Psn@5,R@200,R@10,R@100
2,1.8858,1.822938,0.14,0.044,0.076667,0.064,0.14,0.138012,0.116614,0.124257,0.140339,0.171597,0.11926,0.141005,0.140339,0.153391,0.1274,0.139141,0.37594,0.167429,0.323345
4,1.8858,1.819432,0.14,0.046,0.08,0.064,0.14,0.141294,0.117651,0.122745,0.140339,0.177766,0.123546,0.141005,0.140339,0.155641,0.12775,0.137061,0.379607,0.180762,0.321643


  0%|          | 0/196 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)


  0%|          | 0/196 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)


  0%|          | 0/196 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Draw a 100-item sample from the test split.
_n_valid = 100
valid_dset = block.test.sample(n=_n_valid)

In [None]:
# Deliberately NOT exported: `valid_dset` is created only in a notebook-only
# cell, so exporting this cell would make the generated module fail with a
# NameError before it reaches its `__main__` guard.
# Rebuilds the metric against the sampled subset's label filterer.
# NOTE(review): `learn` received the earlier `metric` object at construction;
# rebinding the name here does not by itself change what `learn` uses —
# confirm whether the learner should be given this one.
metric = PrecRecl(block.n_lbl, valid_dset.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

In [None]:
# Run prediction on the sampled subset with the current learner.
_valid_data = valid_dset.dset
o = learn.predict(_valid_data)

  0%|          | 0/196 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)


In [None]:
#| export
def _run_training():
    """Call multiprocessing's freeze_support hook, then train with the module-level learner."""
    mp.freeze_support()
    learn.train()


# Train only when executed as a script (or run as a notebook cell), not on import.
if __name__ == '__main__':
    _run_training()