In [1]:
#| default_exp 29_distillation-for-wikiseealsotitles-with-dexa

In [2]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [21]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, numpy as np
from tqdm.auto import tqdm
from transformers import DistilBertConfig

from xcai.basics import *
from xcai.data import XCDataBlock
from xcai.models.PPP0XX import DBT009
from xcai.models.dexa import DEX001, DEX002
from xcai.clustering.cluster import BalancedClusters

from xclib.utils.sparse import retain_topk

from fastcore.utils import *

In [4]:
# Disable Weights & Biases logging for interactive runs (this cell is not exported).
os.environ.update(WANDB_MODE='disabled')

In [5]:
#| export
# Keep a GPU selection already made by the launcher/scheduler; otherwise default to GPUs 0 and 1.
# NOTE: CUDA reads this variable when it initialises, so this must run before any CUDA call in the process.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1')
os.environ['WANDB_PROJECT']='medic_00-wikiseealsotitles'

## Load data

In [6]:
build_block = False
output_dir = '/home/scai/phd/aiz218323/scratch/outputs/medic/28_dexa-for-wikiseealsotitles'

# Load data: the tokenised block is cached as a pickle; set `build_block=True` to rebuild it from the raw data.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
pkl_file = os.path.join(pkl_dir, 'processed', 'wikiseealsotitles_data_distilbert-base-uncased_xcs.pkl')

if build_block:
    data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'
    block = XCBlock.from_cfg(data_dir, 'data', transform_type='xcs', tokenizer='distilbert-base-uncased', 
                             sampling_features=[('lbl2data',4)], oversample=False)
    # The `processed/` cache directory may not exist yet; open(..., 'wb') would fail without it.
    os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
    with open(pkl_file, 'wb') as file: pickle.dump(block, file)
else:
    # NOTE: pickle.load can execute arbitrary code -- only load caches produced by this pipeline.
    with open(pkl_file, 'rb') as file: block = pickle.load(file)

# Override the sampling settings stored in the (possibly cached) collator's first transform.
block.collator.tfms.tfms[0].sampling_features = [('lbl2data',4)]
block.collator.tfms.tfms[0].oversample = True

## Model

In [7]:
# Training/evaluation arguments for the learner built below.
args = XCLearningArguments(
    # Output and checkpointing
    output_dir=output_dir,
    logging_first_step=True,
    save_strategy="steps",
    save_steps=5000,
    save_total_limit=5,

    # Evaluation / representation-based prediction
    evaluation_strategy="steps",
    eval_steps=5000,
    predict_with_representation=True,
    representation_num_beams=200,
    representation_accumulation_steps=10,
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',

    # Batching and optimisation
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    num_train_epochs=300,
    max_grad_norm=None,
    fp16=True,
    use_encoder_parallel=True,
)

In [8]:
# Global batch size: the larger per-device batch size times the number of visible GPUs.
n_devices = torch.cuda.device_count()
bsz = n_devices * max(args.per_device_eval_batch_size, args.per_device_train_batch_size)

model = DBT009.from_pretrained(
    'sentence-transformers/msmarco-distilbert-base-v4',
    bsz=bsz, tn_targ=5000, margin=0.3, tau=0.1,
    n_negatives=10, apply_softmax=True, use_encoder_parallel=True,
)
# The `dr_*` head weights are not in the checkpoint (see the warning below); initialise them explicitly.
model.init_dr_head()

Some weights of DBT009 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# Learner wrapping the DBT009 model; in this section it is used to embed labels.
learn = XCLearner(
    args=args,
    model=model,
    data_collator=block.collator,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [10]:
# Compute representations for the labels of the training split with the learner's DBT009 model.
# `to_cpu=True` presumably returns them on CPU (later cells cluster them) -- confirm against XCLearner.
lbl_repr = learn._get_lbl_representation(block.train.dset, to_cpu=True)

  0%|          | 0/196 [00:00<?, ?it/s]

In [11]:
# Balanced clustering of the label representations, cast to float16 (presumably to save memory).
clusters = BalancedClusters.proc(lbl_repr.to(torch.float16),
                                 min_cluster_sz=3)

Updating clusters with size 3
Tree depth = 17
doing random split
lengths: [156165, 156165]
remaining levels for GPU split=16
==> gpu splitting random clusters 0 to 2
 rank=0 => Total clusters 2	Avg. Cluster size                 78082.50	Time to split nodes on this level 0.21 sec
 rank=0 => Total clusters 4	Avg. Cluster size                 39041.25	Time to split nodes on this level 0.04 sec
 rank=0 => Total clusters 8	Avg. Cluster size                 19520.62	Time to split nodes on this level 0.03 sec
 rank=0 => Total clusters 16	Avg. Cluster size                 9760.31	Time to split nodes on this level 0.04 sec
 rank=0 => Total clusters 32	Avg. Cluster size                 4880.16	Time to split nodes on this level 0.07 sec
 rank=0 => Total clusters 64	Avg. Cluster size                 2440.08	Time to split nodes on this level 0.16 sec
 rank=1 => Total clusters 2	Avg. Cluster size                 78082.50	Time to split nodes on this level 0.17 sec
 rank=1 => Total clusters 4	Avg. Clu

In [12]:
# Map each label id to the index of the cluster containing it. Start from -1 so a label missed by
# the clustering is detected instead of silently falling into cluster 0.
lbl_remap = torch.full((block.n_lbl,), -1, dtype=torch.int64)
for cluster_idx, lbl_idxs in enumerate(clusters): lbl_remap[lbl_idxs] = cluster_idx
assert (lbl_remap >= 0).all(), 'some labels were not assigned to any cluster'

In [13]:
# DEX001 built from the same checkpoint, sized by the label count and the clusters found above.
# NOTE: num_negatives=5 and use_encoder_parallel=False differ from the driver cell (10 / True).
model = DEX001.from_pretrained(
    'sentence-transformers/msmarco-distilbert-base-v4',
    batch_size=bsz,
    num_batch_labels=5000,
    margin=0.3,
    num_negatives=5,
    tau=0.1,
    apply_softmax=True,
    use_encoder_parallel=False,
    n_labels=block.n_lbl,
    n_clusters=len(clusters),
)
model.init_retrieval_head()
model.init_label_embeddings()

Some weights of DEX001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'label_embeddings.weight', 'label_remap']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [22]:
# Hand the label -> cluster-id mapping (`lbl_remap`) to the DEX001 model; the load warning lists
# `label_remap` among the model's newly initialised weights.
model.set_label_remap(lbl_remap)

## Driver

In [31]:
#| export
def get_label_remap(output_dir:str, block:XCDataBlock, cluster_sz:int=3,
                    model_name:str='sentence-transformers/msmarco-distilbert-base-v4'):
    """Cluster the training labels and return a label -> cluster-id remapping.

    A `DBT009` model is loaded from `model_name` (its retrieval head is set up with
    `init_dr_head`). It is used through an `XCLearner` to compute representations for
    the labels of `block.train.dset`, and those are partitioned with
    `BalancedClusters.proc`.

    Args:
        output_dir: Output directory handed to `XCLearningArguments`. The learner is only
            used to compute label representations here, not to train.
        block: Data block whose train/test splits and collator build the learner.
        cluster_sz: Forwarded as `min_cluster_sz` to `BalancedClusters.proc`.
        model_name: Checkpoint used to build the `DBT009` model.

    Returns:
        `(lbl_remap, n_clusters)`. `lbl_remap` is an int64 tensor of length `block.n_lbl`
        whose entry `i` is the index of the cluster containing label `i`. `n_clusters` is
        the number of clusters.

    Raises:
        ValueError: if some label was not assigned to any cluster.
    """
    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        evaluation_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',
        
        use_encoder_parallel=True,
        max_grad_norm=None,
        fp16=True,
    )

    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()
    model = DBT009.from_pretrained(model_name, bsz=bsz, tn_targ=5000, margin=0.3, tau=0.1, 
                                   n_negatives=10, apply_softmax=True, use_encoder_parallel=True)
    model.init_dr_head()

    learn = XCLearner(
        model=model, 
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
    )

    lbl_repr = learn._get_lbl_representation(block.train.dset, to_cpu=True)

    # The encoder is not needed any more. Release it before clustering (which may itself use the GPU)
    # and before the caller builds its own models.
    del learn, model
    if torch.cuda.is_available(): torch.cuda.empty_cache()

    clusters = BalancedClusters.proc(lbl_repr.half(), min_cluster_sz=cluster_sz)

    # -1 marks "unassigned" so a label missed by the clustering is reported instead of
    # silently landing in cluster 0.
    lbl_remap = torch.full((block.n_lbl,), -1, dtype=torch.int64)
    for i,o in enumerate(clusters): lbl_remap[o] = i

    n_missing = int((lbl_remap < 0).sum())
    if n_missing:
        raise ValueError(f'{n_missing} of {block.n_lbl} labels were not assigned to any cluster.')

    return lbl_remap, len(clusters)
    

In [None]:
#| export
if __name__ == '__main__':
    # The multiprocessing docs require freeze_support() immediately after the `__main__` guard.
    # It only has an effect in frozen Windows executables.
    mp.freeze_support()

    build_block = False
    output_dir = '/home/scai/phd/aiz218323/scratch/outputs/medic/29_distillation-for-wikiseealsotitles-with-dexa'
    # Earlier run whose `teacher/` sub-directory is loaded as the teacher model below.
    model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4'
    
    # ---- Load data ----
    # The tokenised block is cached as a pickle; set `build_block=True` to rebuild it from the raw data.
    pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
    pkl_file = os.path.join(pkl_dir, 'processed', 'wikiseealsotitles_data_distilbert-base-uncased_xcs.pkl')

    if build_block:
        data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'
        block = XCBlock.from_cfg(data_dir, 'data', transform_type='xcs', tokenizer='distilbert-base-uncased', 
                                 sampling_features=[('lbl2data',4)], oversample=False)
        # The `processed/` cache directory may not exist yet; open(..., 'wb') would fail without it.
        os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
        with open(pkl_file, 'wb') as file: pickle.dump(block, file)
    else:
        # NOTE: pickle.load can execute arbitrary code -- only load caches produced by this pipeline.
        with open(pkl_file, 'rb') as file: block = pickle.load(file)

    # Override the sampling settings stored in the (possibly cached) collator's first transform.
    block.collator.tfms.tfms[0].sampling_features = [('lbl2data',4)]
    block.collator.tfms.tfms[0].oversample = False

    # ---- Label clustering ----
    # Label -> cluster-id mapping and cluster count, used to build the DEX001 student below.
    lbl_remap, n_clusters = get_label_remap(output_dir, block, cluster_sz=3)

    # ---- Training arguments ----
    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        evaluation_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,
        representation_search_type='BRUTEFORCE',
        
        output_representation_attribute='data_repr',
        label_representation_attribute='data_repr',
        metadata_representation_attribute='data_repr',
        data_augmentation_attribute='data_repr',
        representation_attribute='data_repr',
        clustering_representation_attribute='data_repr',
    
        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        use_data_metadata_for_clustering=True,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',
        
        use_distributional_representation=False,
        use_encoder_parallel=True,
        max_grad_norm=None, 
        fp16=True,
    
        use_cpu_for_searching=True,
        use_cpu_for_clustering=True,
    )

    # ---- Teacher model ----
    # NOTE(review): TCH001 and DTL004 are not imported explicitly in the visible imports; presumably
    # they come from `from xcai.basics import *` -- TODO confirm and import them explicitly.
    m_teacher = TCH001.from_pretrained(f'{model_output}/teacher', n_data=block.train.dset.n_data, n_lbl=block.n_lbl)
    # Freeze the teacher's embeddings (presumably so that only the student is trained).
    m_teacher.freeze_embeddings()
    
    # ---- Student model ----
    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()
    m_student = DEX001.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=bsz, num_batch_labels=5000,
                                       margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True, use_encoder_parallel=True,
                                       n_labels=block.n_lbl, n_clusters=n_clusters)
    m_student.init_retrieval_head()
    m_student.init_label_embeddings()
    m_student.set_label_remap(lbl_remap)

    # ---- Distillation model (student + teacher with weighted auxiliary losses) ----
    model = DTL004(DistilBertConfig(), m_student=m_student, m_teacher=m_teacher, bsz=bsz, tn_targ=5000, margin=0.3, tau=0.1, 
                   n_negatives=10, apply_softmax=True, teacher_data_student_label_loss_weight=1.0, 
                   student_data_teacher_label_loss_weight=1.0, data_mse_loss_weight=0.1, label_mse_loss_weight=0.1)

    # ---- Training ----
    metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                      pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

    learn = XCLearner(
        model=model, 
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )
    
    learn.train()
    

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,P@1,P@10,P@3,P@5,N@1,N@10,N@3,N@5,Psp@1,Psp@10,Psp@3,Psp@5,Psn@1,Psn@10,Psn@3,Psn@5,R@200,R@10,R@100
10,0.0788,0.086276,0.175101,0.056812,0.115025,0.087132,0.175101,0.194849,0.173305,0.180583,0.163741,0.209324,0.168619,0.180018,0.163741,0.201765,0.174354,0.185467,0.429889,0.235065,0.384685


  0%|          | 0/15617 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)


  0%|          | 0/15617 [00:00<?, ?it/s]