In [None]:
#| default_exp 77-distillation-for-wikiseealso-with-ramen-1-0

In [None]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, numpy as np
from transformers import DistilBertConfig

from xcai.basics import *
from xcai.models.PPP0XX import DBT021
from xcai.models.distillation import DTL002,TCH001

comet_ml is installed but `COMET_API_KEY` is not set.


In [None]:
# Interactive runs only (this cell is not exported): switch Weights & Biases off.
os.environ.update(WANDB_MODE='disabled')

In [None]:
#| export
# Restrict the process to GPUs 0 and 1 and set the W&B project that runs are logged under.
# NOTE(review): the project name says "69" while this notebook/output dir is "77" — confirm intentional.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',
    'WANDB_PROJECT': 'xc-nlg_69-distillation-for-wikiseealso',
})

In [None]:
data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'

# Build the WikiSeeAlso data block ('data_meta' config, 'rm' transform, DistilBERT tokenizer).
# The sampling spec given here is replaced later in the notebook once the 'lco' meta is attached.
block = XCBlock.from_cfg(
    data_dir,
    'data_meta',
    tfm='rm',
    tokenizer='distilbert-base-uncased',
    smp_features=[
        ('lbl2data|cat2lbl2data', 1, (2, 1)),
        ('cat2data', 1, 1),
    ],
)

In [None]:
#| export
# Location of the pickled XCBlock cache; processed artefacts live under `processed/`.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = '/'.join([
    pkl_dir,
    'processed',
    'wikiseealso_data-meta_distilbert-base-uncased_rm_distil-ramen-cat-2.pkl',
])

In [None]:
# Cache the freshly built block so it can be reloaded with pickle instead of being rebuilt.
# Create the parent directory first: open(..., 'wb') raises FileNotFoundError when
# `processed/` does not exist yet.
os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
with open(pkl_file, 'wb') as file:
    pickle.dump(block, file)

In [None]:
#| export
# Reload the cached XCBlock.
# NOTE: pickle.load can execute arbitrary code — only load caches this notebook produced.
with open(pkl_file, mode='rb') as fh:
    block = pickle.load(fh)

In [None]:
#| export
# Load the pre-computed correlation pair (presumably data-side and label-side — confirm).
# In this notebook only `lbl_corel` is consumed; `data_corel` is unused.
with open(f'{pkl_dir}/processed/corelations.pkl', mode='rb') as fh:
    data_corel, lbl_corel = pickle.load(fh)

In [None]:
#| export
from xcai.data import MetaXCDataset, XCDataset
from scipy import sparse

# Attach a label-correlation ('lco') meta dataset to the training split.
# The data->meta matrix is an empty CSR matrix of shape (n_data, n_lbl); the
# correlation information comes from `lbl_corel`.
# NOTE(review): `data_corel` is loaded but not used here — confirm an empty data-side matrix is intended.
_lco_shape = (block.train.dset.n_data, block.n_lbl)
lco_meta = MetaXCDataset(
    'lco',
    sparse.csr_matrix(_lco_shape),
    lbl_corel,
    block.train.dset.data.lbl_info,
)
block.train.dset.meta['lco_meta'] = lco_meta

In [None]:
#| export
# Extend the first collator transform's sampling spec so 'lco2lbl2data' is sampled
# together with the label and category features (this replaces whatever spec it held).
smp_features = [
    ('lbl2data|cat2lbl2data|lco2lbl2data', 1, (2, 1, 1)),
    ('cat2data', 1, 1),
]
block.collator.tfms.tfms[0].smp_features = smp_features

## Training

In [None]:
#| export
args = XCLearningArguments(
    # --- output / logging ---
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/77-distillation-for-wikiseealso-with-ramen-1-0',
    logging_first_step=True,
    # --- batching / hardware ---
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    use_encoder_parallel=True,
    fp16=True,
    # --- optimisation ---
    num_train_epochs=300,
    learning_rate=2e-4,
    warmup_steps=100,
    weight_decay=0.01,
    adam_epsilon=1e-6,
    max_grad_norm=None,  # presumably disables gradient clipping (HF Trainer convention) — confirm
    # --- checkpointing / evaluation ---
    # NOTE(review): recent transformers releases renamed `evaluation_strategy` to `eval_strategy`;
    # confirm against the pinned version / XCLearningArguments.
    save_strategy="steps",
    save_steps=3000,
    save_total_limit=5,
    evaluation_strategy="steps",
    eval_steps=3000,
    # --- representation-based prediction ---
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    representation_num_beams=200,
    representation_accumulation_steps=10,
    # --- cluster-based batching ---
    group_by_cluster=True,
    clustering_type='EXPO',
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    minimum_cluster_size=2,
    maximum_cluster_size=1600,
    # --- targets / label tensors forwarded to the model ---
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    label_names=['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 
                 'lco2lbl2data_idx', 'lco2lbl2data_input_ids', 'lco2lbl2data_attention_mask', 
                 'cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask'],
)

In [None]:
#| export
# Teacher: a TCH001 checkpoint stored in the 'teacher' subfolder of an earlier run's output directory.
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4'
_teacher_dir = '/'.join([model_output, 'teacher'])
m_teacher = TCH001.from_pretrained(
    _teacher_dir,
    n_data=block.train.dset.n_data,
    n_lbl=block.n_lbl,
)

In [None]:
#| export
# Global batch size: the larger per-device batch size times the number of visible GPUs.
# torch.cuda.device_count() is 0 when no GPU is visible, which would make bsz == 0; clamp
# the device count to 1 (the same convention as HF TrainingArguments.train_batch_size,
# which scales by max(1, n_gpu)). Unchanged whenever at least one GPU is present.
bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size) * max(1, torch.cuda.device_count())

# Student: DBT021 initialised from msmarco-distilbert, using the 'cat2data' and 'lco2lbl'
# meta prefixes that match the features sampled by the collator.
# NOTE(review): encoder parallelism is off here while `args.use_encoder_parallel=True` —
# presumably the learner handles it; confirm.
m_student = DBT021.from_pretrained(
    'sentence-transformers/msmarco-distilbert-base-v4',
    bsz=bsz,
    tn_targ=1000,
    margin=0.3,
    tau=0.1,
    apply_softmax=True,
    n_negatives=10,
    m_lw=0.2,
    data_meta_prefix='cat2data',
    lbl2data_meta_prefix='lco2lbl',
    use_encoder_parallel=False,
    task_repr_type='pool',
    meta_repr_type='pool',
)

# Presumably initialises the `encoder.dr_*` layers that the load warning reports as newly
# initialised — confirm in DBT021.init_dr_head.
m_student.init_dr_head()

Some weights of DBT021 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight', 'encoder.meta_layer_norm.bias', 'encoder.meta_layer_norm.weight', 'encoder.meta_projector.bias', 'encoder.meta_projector.weight', 'encoder.meta_transform.bias', 'encoder.meta_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
#| export
# Distillation wrapper combining the student and the teacher; the two loss weights
# balance the distillation and MSE terms.
model = DTL002(
    DistilBertConfig(),
    m_student=m_student,
    m_teacher=m_teacher,
    bsz=bsz,
    tn_targ=5000,
    margin=0.3,
    tau=0.1,
    n_negatives=10,
    apply_softmax=True,
    distil_loss_weight=1.0,
    mse_loss_weight=0.1,
)

In [None]:
#| export
# Precision/recall evaluation over all labels, using the test split's label filterer.
# `prop` receives the training data-label matrix (presumably for the propensity-scored
# PSP/PSN columns reported in the results).
metric = PrecRecl(
    block.n_lbl,
    block.test.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [None]:
#| export
# Assemble the trainer: model and arguments, the collator, the training split, the test split
# used for evaluation, and the metric.
learn = XCLearner(
    model=model,
    args=args,
    data_collator=block.collator,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    compute_metrics=metric,
)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
#| export
# Only train when run as a script/notebook, not when the exported module is imported.
if __name__ == '__main__':
    # Standard multiprocessing hook for frozen Windows executables (a no-op elsewhere);
    # it must be called before any worker processes are started, so keep it ahead of train().
    mp.freeze_support()
    learn.train()

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,P@1,P@10,P@3,P@5,N@1,N@10,N@3,N@5,Psp@1,Psp@10,Psp@3,Psp@5,Psn@1,Psn@10,Psn@3,Psn@5,R@200,R@10,R@100
10,0.0788,0.086276,0.175101,0.056812,0.115025,0.087132,0.175101,0.194849,0.173305,0.180583,0.163741,0.209324,0.168619,0.180018,0.163741,0.201765,0.174354,0.185467,0.429889,0.235065,0.384685


  0%|          | 0/15617 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)


  0%|          | 0/15617 [00:00<?, ?it/s]

In [None]:
# Predict on the test split; the returned object's `.metrics` are tabulated in the next cell.
o = learn.predict(block.test.dset)

  0%|          | 0/196 [00:00<?, ?it/s]

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


  self._set_arrayXarray(i, j, x)


In [None]:
# Render the test-set metrics as a table (values in the output appear scaled to percentages).
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,17.296,11.305,8.5469,5.5415,17.296,17.0792,17.7819,19.1372,16.3694,16.7477,17.8326,20.6162,16.3694,17.404,18.4949,20.0678,22.993,37.2897,41.666,0.9435,120.816,1469.3,0.919
