In [2]:
#| default_exp 02_oak-for-wikiseealsotitles-trained-with-groundtruth-metadata

In [3]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [4]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, numpy as np, math, transformers
from transformers import DistilBertConfig

from xcai.basics import *
from xcai.models.oakX import OAK001
from xcai.optimizers.oakX import MultipleOptimizer, MultipleScheduler


from xclib.utils.sparse import retain_topk

from fastcore.utils import *

In [5]:
# Notebook-only setting (this cell is not exported): switch Weights & Biases
# logging off for interactive runs.
os.environ.update(WANDB_MODE='disabled')

In [6]:
#| export
# Process-wide environment for the exported script: expose only GPUs 0 and 1
# to CUDA and name the Weights & Biases project that runs are logged under.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',
    'WANDB_PROJECT': 'oakVn_00-wikiseealsotitles',
})

## Optimizer

In [7]:
#| export
@patch
def create_optimizer_and_scheduler(self:XCLearner, num_training_steps: int):
    """Create the optimizer/scheduler pair used to train the learner's model.

    Trainable parameters are split by name into two groups:

    * parameters whose name contains ``"meta_embeddings"`` are optimised with
      ``torch.optim.SparseAdam`` at ``learning_rate * free_parameter_lr_coefficient``
      and follow a cosine schedule with ``free_parameter_warmup_steps`` warmup steps;
    * every other trainable parameter is optimised with ``torch.optim.AdamW`` at
      ``learning_rate`` (weight decay 0.01, except for bias / LayerNorm weights
      which get 0.0) and follows a linear schedule with ``warmup_steps`` warmup steps.

    Args:
        num_training_steps: total number of optimisation steps, forwarded to
            both schedulers.

    Side effects:
        Sets ``self.optimizer`` (a ``MultipleOptimizer``) and
        ``self.lr_scheduler`` (a ``MultipleScheduler``). Nothing is returned.
    """
    NO_DECAY = ['bias', 'LayerNorm.weight']

    # Read parameters from the learner's own model. The previous code used a
    # module-level global `model`, which only exists when the driver below has
    # run in the same namespace.
    dense, sparse = [], []
    for name, param in self.model.named_parameters():
        if not param.requires_grad: continue
        if "meta_embeddings" in name: sparse.append(param)
        else: dense.append((name, param))

    params = [
        {'params': [p for n, p in dense if not any(nd in n for nd in NO_DECAY)], 'weight_decay': 0.01},
        {'params': [p for n, p in dense if any(nd in n for nd in NO_DECAY)], 'weight_decay': 0.0},
    ]

    # NOTE(review): SparseAdam rejects an empty parameter list, so this assumes
    # at least one trainable "meta_embeddings" parameter exists -- confirm.
    optimizer_list = [torch.optim.AdamW(params, **{'lr': self.args.learning_rate, 'eps': 1e-6}),
                      torch.optim.SparseAdam(sparse, **{'lr': self.args.learning_rate * self.args.free_parameter_lr_coefficient, 'eps': 1e-6})]

    self.optimizer = MultipleOptimizer(optimizer_list)

    # Scheduler order mirrors optimizer order: [0] dense/AdamW, [1] sparse/SparseAdam.
    scheduler_list = [transformers.get_linear_schedule_with_warmup(self.optimizer.optimizers[0], num_warmup_steps=self.args.warmup_steps,
                                                                   num_training_steps=num_training_steps),
                        transformers.get_cosine_schedule_with_warmup(self.optimizer.optimizers[1],
                                                                     num_warmup_steps=self.args.free_parameter_warmup_steps,
                                                                     num_training_steps=num_training_steps)]

    self.lr_scheduler = MultipleScheduler(scheduler_list)
    

## Driver

In [None]:
#| export
if __name__ == '__main__':
    # When True, the data block is built from `data_dir`, pickled to `pkl_file`
    # and the script exits before training. Set to False to load the cached
    # pickle and run training.
    build_block = True
    pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
    data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'
    
    output_dir = '/home/scai/phd/aiz218323/scratch/outputs/mogic/02_oak-for-wikiseealsotitles-trained-with-groundtruth-metadata'
    meta_embed_file = '/home/aiscuser/scratch/OGB_Weights/LF-WikiSeeAlsoTitles-320K/emb_weights.npy'

    # ---- Load data ----
    pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data-lnk_distilbert-base-uncased_xcs.pkl'

    if build_block:
        block = XCBlock.from_cfg(data_dir, 'data_cat_lnk', transform_type='xcs', tokenizer='distilbert-base-uncased', 
                                 sampling_features=[('lbl2data',4), ('lnk2data',3), ('cat2data',3)], oversample=False)
        with open(pkl_file, 'wb') as file: pickle.dump(block, file)
        exit()
    else:
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # pickles produced by the build branch above.
        with open(pkl_file, 'rb') as file: block = pickle.load(file)
    
    # ---- Hybrid metadata: ground truth during training, linker predictions during inference ----
    # The dataset is stored under 'hyb_meta' because every later access
    # (meta_info reset, prune_metadata_names, num_metadata) reads that key,
    # matching the 'cat_meta' / 'lnk_meta' naming. It was previously stored
    # under 'hyb', which made the lookups below fail.
    train_cat_meta = block.train.dset.meta['cat_meta']
    block.train.dset.meta['hyb_meta'] = MetaXCDataset('hyb', train_cat_meta.data_meta, 
                                                      train_cat_meta.lbl_meta, train_cat_meta.meta_info)
    
    # At test time keep only the top-3 linker predictions per data point.
    data_meta = retain_topk(block.test.dset.meta['lnk_meta'].data_meta, k=3)
    lbl_meta = block.test.dset.meta['lnk_meta'].lbl_meta
    block.test.dset.meta['hyb_meta'] = MetaXCDataset('hyb', data_meta, lbl_meta, block.test.dset.meta['lnk_meta'].meta_info)
    
    # Sample labels and the hybrid metadata only; the original cat/lnk features are dropped below.
    block.collator.tfms.tfms[0].sampling_features = [('lbl2data',4),('hyb2data',3)]
    block.collator.tfms.tfms[0].oversample = False

    del block.train.dset.meta['lnk_meta']
    del block.test.dset.meta['lnk_meta']

    del block.train.dset.meta['cat_meta']
    del block.test.dset.meta['cat_meta']

    block.train.dset.meta['hyb_meta'].meta_info = None
    block.test.dset.meta['hyb_meta'].meta_info = None

    # ---- Training arguments ----
    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        evaluation_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,
        representation_search_type='BRUTEFORCE',
        
        output_representation_attribute='data_fused_repr',
        label_representation_attribute='data_repr',
        metadata_representation_attribute='data_repr',
        data_augmentation_attribute='data_repr',
        representation_attribute='data_fused_repr',
        clustering_representation_attribute='data_fused_repr',
    
        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        use_data_metadata_for_clustering=True,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',
        
        use_distributional_representation=False,
        use_encoder_parallel=True,
        max_grad_norm=None, 
        fp16=True,
        
        label_names=['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'hyb2data_idx'],
        
        prune_metadata=False,
        num_metadata_prune_warmup_epochs=10,
        num_metadata_prune_epochs=5,
        metadata_prune_batch_size=2048,
        prune_metadata_names=['hyb_meta'],
        use_data_metadata_for_pruning=True,
    
        predict_with_augmentation=False,
        use_augmentation_index_representation=True,
    
        data_aug_meta_name='hyb',
        augmentation_num_beams=None,
        data_aug_prefix='hyb',
        use_label_metadata=False,
        
        data_meta_batch_size=2048,
        augment_metadata=False,
        num_metadata_augment_warmup_epochs=10,
        num_metadata_augment_epochs=5,
    
        use_cpu_for_searching=False,
        use_cpu_for_clustering=True,
    )

    # ---- Model ----
    # Global batch size: the larger per-device batch size times the number of visible GPUs.
    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()
    model = OAK001.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=bsz, num_batch_labels=5000, 
                                   margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                                   data_aug_meta_prefix='hyb2data', lbl2data_aug_meta_prefix=None, 
                                   data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                                   
                                   num_metadata=block.train.dset.meta['hyb_meta'].n_meta, resize_length=5000,
                                   
                                   calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                                   calib_loss_weight=0.1, use_calib_loss=False,
    
                                   use_query_loss=True,
    
                                   meta_loss_weight=0.0, 
                                   
                                   fusion_loss_weight=0.0, use_fusion_loss=False,
                                   
                                   use_encoder_parallel=True)
    
    model.init_retrieval_head()
    model.init_cross_head()
    model.init_meta_embeddings()
    
    # Load the pretrained metadata embeddings from disk into the encoder, then freeze them.
    meta_embeddings = np.load(meta_embed_file)
    model.encoder.set_pretrained_meta_embeddings(torch.tensor(meta_embeddings, dtype=torch.float32))
    model.encoder.freeze_pretrained_meta_embeddings()
    
    # ---- Training ----
    metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                      pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

    learn = XCLearner(
        model=model, 
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )
    
    mp.freeze_support()
    learn.train()
    

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,P@1,P@10,P@3,P@5,N@1,N@10,N@3,N@5,Psp@1,Psp@10,Psp@3,Psp@5,Psn@1,Psn@10,Psn@3,Psn@5,R@200,R@10,R@100
10,0.0788,0.086276,0.175101,0.056812,0.115025,0.087132,0.175101,0.194849,0.173305,0.180583,0.163741,0.209324,0.168619,0.180018,0.163741,0.201765,0.174354,0.185467,0.429889,0.235065,0.384685


  0%|          | 0/15617 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)


  0%|          | 0/15617 [00:00<?, ?it/s]