In [1]:
#| default_exp 24_distillation-for-wikiseealsotitles-with-oak-curriculum-learning

In [2]:
# Reload imported modules before each cell runs, so edits to the project
# library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# nbdev helpers; `nbdev_export()` exports the notebook cells tagged `#| export`
# to the module named by the `#| default_exp` directive.
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [3]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, numpy as np
from scipy import sparse
from transformers import DistilBertConfig

from xcai.basics import *
from xcai.data import MetaXCDataset
from xcai.models.oak import OAK003
from xcai.models.distillation import DTL004,TCH001,TCH002
from xcai.models.classifiers import CLS001

from xclib.utils.sparse import retain_topk

from fastcore.utils import *

comet_ml is installed but `COMET_API_KEY` is not set.


In [4]:
# Disable Weights & Biases for this interactive session. This cell is not
# tagged `#| export`, so the exported script is unaffected.
os.environ.update({'WANDB_MODE': 'disabled'})

In [5]:
#| export
# Restrict the process to GPUs 0 and 1 and name the W&B project for this run.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',
    'WANDB_PROJECT': 'medic_00-wikiseealsotitles',
})

## Load data

In [6]:
# Root folder of the cached datasets.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
# Location of the pickled data block. `pkl_dir` already ends with '/', so the
# parts are joined rather than formatted to avoid a '//' in the resulting path.
pkl_file = os.path.join(pkl_dir, 'processed', 'wikiseealsotitles_data-cat-lnk_distilbert-base-uncased_xcs.pkl')

In [None]:
# Build the data block from the raw dataset configuration. The next cell
# pickles the result to `pkl_file`.
data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'
block = XCBlock.from_cfg(
    data_dir,
    'data_cat_lnk',
    transform_type='xcs',
    tokenizer='distilbert-base-uncased',
    sampling_features=[('lbl2data', 4), ('lnk2data', 3)],
    oversample=False,
)

In [None]:
# Persist the block so it can be reloaded instead of rebuilt.
with open(pkl_file, 'wb') as file:
    pickle.dump(block, file)

In [7]:
# Restore the cached block.
# NOTE: unpickling can execute arbitrary code; only load files you created yourself.
with open(pkl_file, 'rb') as file:
    block = pickle.load(file)

In [8]:
# Trim the link-metadata matrix of each split with `retain_topk`
# (k=5 for train, k=3 for test); the label-side matrix is passed back unchanged.
for split, topk in ((block.train, 5), (block.test, 3)):
    lnk_meta = split.dset.meta.lnk_meta
    data_meta = retain_topk(lnk_meta.data_meta, k=topk)
    lbl_meta = lnk_meta.lbl_meta
    lnk_meta.update_meta_matrix(data_meta, lbl_meta)

# Configure the first collator transform: which features are sampled, with
# what count, and whether to oversample.
sampler = block.collator.tfms.tfms[0]
sampler.sampling_features = [('lbl2data',4),('lnk2data',3)]
sampler.oversample = False

In [9]:
# Clear `meta_info` on the link and category metadata of both splits.
for meta_name in ('lnk_meta', 'cat_meta'):
    for split in (block.train, block.test):
        split.dset.meta[meta_name].meta_info = None

In [10]:
# Register a 'hyb' metadata entry on each split: the training one starts as a
# copy of the category metadata, the test one as a copy of the link metadata.
train_meta = block.train.dset.meta
test_meta = block.test.dset.meta

train_src = train_meta['cat_meta']
train_meta['hyb_meta'] = MetaXCDataset('hyb', train_src.data_meta.copy(), train_src.lbl_meta.copy())

test_src = test_meta['lnk_meta']
test_meta['hyb_meta'] = MetaXCDataset('hyb', test_src.data_meta.copy(), test_src.lbl_meta.copy())

In [11]:
# Small learner setup used only to inspect a training batch interactively.

# Metadata-mixing schedule between the 'cat' and 'lnk' metadata.
mix_cfg = dict(
    mix_metadata=True,
    num_mix_metadata_epochs=5,
    num_mix_metadata_warmup_epochs=0,
    maximum_mix_metadata_epochs=50,
    mix_metadata_name_1='cat',
    mix_metadata_name_2='lnk',
    mix_metadata_k=3,
)
args = XCLearningArguments(
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/medic/14_distillation-for-wikititles-with-oak',
    num_train_epochs=10,
    label_names=['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'hyb2data_idx'],
    **mix_cfg,
)

# Teacher checkpoint, loaded with its embeddings frozen.
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4'
teacher_dir = f'{model_output}/teacher'
model = TCH001.from_pretrained(teacher_dir, n_data=block.train.dset.n_data, n_lbl=block.n_lbl)
model.freeze_embeddings()

learn = XCLearner(
    model=model,
    args=args,
    data_collator=block.collator,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
)

# Not referenced by any later cell in this notebook.
dataloader_config = DataLoaderConfiguration(
    dispatch_batches=None,
    split_batches=False,
    even_batches=True,
    use_seedable_sampler=True,
)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [12]:
# Show the training data-to-metadata matrices of the 'hyb', 'cat' and 'lnk' entries side by side.
tuple(block.train.dset.meta[f'{name}_meta'].data_meta for name in ('hyb', 'cat', 'lnk'))

(<693082x656086 sparse matrix of type '<class 'numpy.float32'>'
 	with 3390902 stored elements in Compressed Sparse Row format>,
 <693082x656086 sparse matrix of type '<class 'numpy.float32'>'
 	with 3390902 stored elements in Compressed Sparse Row format>,
 <693082x656086 sparse matrix of type '<class 'numpy.float32'>'
 	with 3465410 stored elements in Compressed Sparse Row format>)

In [13]:
# Mix the 'cat' and 'lnk' metadata of the training set with pct=1 and k=3.
# NOTE(review): presumably this rewrites the 'hyb' metadata in place; confirm in xcai.
block.train.dset.mix_meta_dataset('cat', 'lnk', pct=1, k=3)

In [14]:
# Pull a single batch from the training dataloader to inspect what the collator emits.
batch = next(iter(learn.get_train_dataloader()))

In [16]:
# List the fields present in the sampled batch.
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_idx', 'hyb2data_idx', 'hyb2data_data2ptr'])

## Driver

In [None]:
#| export
if __name__ == '__main__':
    # Training driver: loads the cached data block, prepares the metadata,
    # builds the frozen teacher, the OAK student and the distillation wrapper,
    # then trains with metadata mixing between 'cat' and 'lnk'.

    # The multiprocessing docs require this call straight after the main guard;
    # it has no effect unless the program is a frozen Windows executable.
    mp.freeze_support()

    # Set to True to rebuild the block from the raw data and refresh the pickle.
    build_block = False

    pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
    data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'
    
    # NOTE(review): this directory is named after experiment 14 (wikititles) while this
    # notebook is experiment 24 (wikiseealsotitles); confirm the checkpoints should share it.
    output_dir = '/home/scai/phd/aiz218323/scratch/outputs/medic/14_distillation-for-wikititles-with-oak'
    model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4'
    # Only used by the commented-out pretrained-embedding loader further down.
    meta_embed_file = '/home/aiscuser/scratch/OGB_Weights/LF-WikiSeeAlsoTitles-320K/emb_weights.npy'

    """ Load data """
    # `pkl_dir` already ends with '/', so join the parts to avoid a '//' in the path.
    pkl_file = os.path.join(pkl_dir, 'processed', 'wikiseealsotitles_data-cat-lnk_distilbert-base-uncased_xcs.pkl')
    if build_block:
        # The sampling features given here are overridden after loading (see below).
        block = XCBlock.from_cfg(data_dir, 'data_cat_lnk', transform_type='xcs', tokenizer='distilbert-base-uncased', 
                                 sampling_features=[('lbl2data',4), ('lnk2data',3)], oversample=True)
        with open(pkl_file, 'wb') as file: pickle.dump(block, file)
    else:
        # NOTE: unpickling can execute arbitrary code; only load files you created yourself.
        with open(pkl_file, 'rb') as file: block = pickle.load(file)

    """ Prune metadata """
    # Trim the link-metadata matrices with `retain_topk`: k=5 for train, k=3 for test.
    data_meta = retain_topk(block.train.dset.meta.lnk_meta.data_meta, k=5)
    lbl_meta = block.train.dset.meta.lnk_meta.lbl_meta
    block.train.dset.meta.lnk_meta.update_meta_matrix(data_meta, lbl_meta)
    
    data_meta = retain_topk(block.test.dset.meta.lnk_meta.data_meta, k=3)
    lbl_meta = block.test.dset.meta.lnk_meta.lbl_meta
    block.test.dset.meta.lnk_meta.update_meta_matrix(data_meta, lbl_meta)

    # Sample labels and the 'hyb' metadata (registered below) in every batch.
    block.collator.tfms.tfms[0].sampling_features = [('lbl2data',4),('hyb2data',3)]
    block.collator.tfms.tfms[0].oversample = True
    
    # Clear `meta_info` on the link and category metadata of both splits.
    block.train.dset.meta['lnk_meta'].meta_info = None
    block.test.dset.meta['lnk_meta'].meta_info = None
    
    block.train.dset.meta['cat_meta'].meta_info = None
    block.test.dset.meta['cat_meta'].meta_info = None

    # 'hyb' starts as a copy of the category metadata for training and as a copy
    # of the link metadata for evaluation.
    block.train.dset.meta['hyb_meta'] = MetaXCDataset('hyb', block.train.dset.meta['cat_meta'].data_meta.copy(), 
                                                      block.train.dset.meta['cat_meta'].lbl_meta.copy())
    block.test.dset.meta['hyb_meta'] = MetaXCDataset('hyb', block.test.dset.meta['lnk_meta'].data_meta.copy(), 
                                                     block.test.dset.meta['lnk_meta'].lbl_meta.copy())

    """ Training arguments """
    args = XCLearningArguments(
        # Optimisation, checkpointing and evaluation cadence.
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        evaluation_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,
        representation_search_type='BRUTEFORCE',
        
        # Which model output attribute is used for each kind of representation.
        output_representation_attribute='data_fused_repr',
        label_representation_attribute='data_repr',
        metadata_representation_attribute='data_repr',
        data_augmentation_attribute='data_repr',
        representation_attribute='data_fused_repr',
        clustering_representation_attribute='data_fused_repr',
    
        # Cluster-based batching schedule.
        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        use_data_metadata_for_clustering=True,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        # Model selection and the batch keys that hold the evaluation targets.
        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',
        
        use_distributional_representation=False,
        use_encoder_parallel=True,
        max_grad_norm=None, 
        fp16=True,
        
        label_names=['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'hyb2data_idx'],
        
        # Metadata pruning (switched off for this run).
        prune_metadata=False,
        num_metadata_prune_warmup_epochs=10,
        num_metadata_prune_epochs=5,
        metadata_prune_batch_size=2048,
        prune_metadata_names=['hyb_meta'],
        use_data_metadata_for_pruning=True,
        prune_metadata_threshold=0.0,
        prune_metadata_topk=3,
    
        predict_with_augmentation=False,
        use_augmentation_index_representation=True,
    
        data_aug_meta_name='hyb',
        augmentation_num_beams=None,
        data_aug_prefix='hyb',
        use_label_metadata=False,

        # Metadata augmentation (switched off for this run).
        augment_metadata=False,
        data_meta_batch_size=2048,
        num_metadata_augment_warmup_epochs=10,
        num_metadata_augment_epochs=5,
    
        use_cpu_for_searching=True,
        use_cpu_for_clustering=True,

        # Metadata-mixing schedule between the 'cat' and 'lnk' metadata.
        mix_metadata=True,
        num_mix_metadata_epochs=5,
        num_mix_metadata_warmup_epochs=10,
        maximum_mix_metadata_epochs=50,
        mix_metadata_name_1='cat',
        mix_metadata_name_2='lnk',
        mix_metadata_k=3,
    )

    """ Teacher model """
    m_teacher = TCH001.from_pretrained(f'{model_output}/teacher', n_data=block.train.dset.n_data, n_lbl=block.n_lbl)
    m_teacher.freeze_embeddings()

    """ Student model """
    # Largest per-device batch size multiplied by the number of visible GPUs.
    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()

    m_student = OAK003.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=bsz, num_batch_labels=5000,
                                       margin=0.3, num_negatives=5, tau=0.1, apply_softmax=True,
                                       
                                       data_aug_meta_prefix='hyb2data', lbl2data_aug_meta_prefix=None,
                                       data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                                       
                                       num_metadata=block.train.dset.meta['hyb_meta'].n_meta, resize_length=5000,
                                       
                                       calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False,
                                       calib_loss_weight=0.1, use_calib_loss=True,
                                       
                                       use_query_loss=True,
                                       
                                       meta_loss_weight=0.0,
                                       
                                       fusion_loss_weight=0.1, use_fusion_loss=False,
                                       
                                       use_encoder_parallel=True)
    m_student.init_retrieval_head()
    m_student.init_cross_head()
    m_student.init_meta_embeddings()
    
    # The pretrained metadata embeddings are set to zeros of shape (n_meta, dim)
    # and frozen; the commented-out lines load them from `meta_embed_file` instead.
    # meta_embeddings = np.load(meta_embed_file)
    # m_student.encoder.set_pretrained_meta_embeddings(torch.tensor(meta_embeddings, dtype=torch.float32))
    m_student.encoder.set_pretrained_meta_embeddings(torch.zeros(block.train.dset.meta['hyb_meta'].n_meta, m_student.config.dim))
    m_student.encoder.freeze_pretrained_meta_embeddings()

    """ Distillation model """
    model = DTL004(DistilBertConfig(), m_student=m_student, m_teacher=m_teacher, bsz=bsz, tn_targ=5000, margin=0.3, tau=0.1, 
                   n_negatives=10, apply_softmax=True, teacher_data_student_label_loss_weight=1.0, 
                   student_data_teacher_label_loss_weight=1.0, data_mse_loss_weight=0.1, label_mse_loss_weight=0.1)


    """ Training """
    metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                      pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

    learn = XCLearner(
        model=model, 
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )
    
    learn.train()
    

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,P@1,P@10,P@3,P@5,N@1,N@10,N@3,N@5,Psp@1,Psp@10,Psp@3,Psp@5,Psn@1,Psn@10,Psn@3,Psn@5,R@200,R@10,R@100
10,0.0788,0.086276,0.175101,0.056812,0.115025,0.087132,0.175101,0.194849,0.173305,0.180583,0.163741,0.209324,0.168619,0.180018,0.163741,0.201765,0.174354,0.185467,0.429889,0.235065,0.384685


  0%|          | 0/15617 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)


  0%|          | 0/15617 [00:00<?, ?it/s]

## Classifiers

In [None]:
# Turn on `use_data_metadata_for_representation` for the representation passes below.
# NOTE(review): presumably makes the learner feed data-side metadata to the model
# when computing representations; confirm in xcai.
learn.args.use_data_metadata_for_representation=True

In [None]:
# Resolve the best checkpoint of this run under the local outputs folder.
run_name = os.path.basename(args.output_dir)
output_dir = f"/home/aiscuser/scratch/Projects/xc_nlg/outputs/{run_name}"
best_ckpt = os.path.basename(get_best_model(output_dir))
mname = f'{output_dir}/{best_ckpt}'

In [None]:
# Reload the distillation wrapper from the best checkpoint, passing the same loss
# weights as the training-time constructor (the two weights on the last line
# were previously omitted here).
# NOTE(review): `learn` still holds the model it was created with; assign
# `learn.model = model` if the cells below are meant to use these reloaded weights.
model = DTL004.from_pretrained(mname, m_student=m_student, m_teacher=m_teacher, bsz=bsz, tn_targ=5000, margin=0.3, tau=0.1, 
                               n_negatives=10, apply_softmax=True, teacher_data_student_label_loss_weight=1.0, 
                               student_data_teacher_label_loss_weight=1.0, data_mse_loss_weight=0.1, label_mse_loss_weight=0.1)

In [None]:
# Compute the representations of the training data points and of the labels.
train_rep, lbl_rep = learn.get_data_and_lbl_representation(learn.train_dataset)

In [None]:
# Compute the representations of the evaluation data points (uses a private learner helper).
test_rep = learn._get_data_representation(learn.eval_dataset)

In [None]:
# Build the classifier model and initialise it with the train, test and
# label representations computed above.
model = CLS001(
    DistilBertConfig(),
    n_train=block.train.dset.n_data,
    n_test=block.test.dset.n_data,
    n_lbl=block.n_lbl,
    batch_size=100,
    num_batch_labels=5000,
    margin=0.3,
    num_negatives=10,
    tau=0.1,
    apply_softmax=True,
)
model.init_representation(train_rep, test_rep, lbl_rep)

In [None]:
# Save the initialised classifier in a 'representation' folder next to the best checkpoint.
ckpt_root = os.path.dirname(mname)
fname = f'{ckpt_root}/representation'
model.save_pretrained(fname)

In [None]:
# Run prediction on the test split; the metrics are read from `o.metrics` in the next cell.
o = learn.predict(block.test.dset)

In [None]:
# Render the metrics of the prediction run.
display_metric(o.metrics)