In [7]:
#| default_exp 30_bert-ngame-for-wikiseealsotitles-with-input-concatenation

In [8]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, numpy as np
from xcai.basics import *
from xcai.models.BBB0XX import BRT009

In [10]:
# Disable Weights & Biases logging for runs started from this notebook cell.
os.environ.update({'WANDB_MODE': 'disabled'})

In [11]:
#| export
# Process-wide environment for the exported script:
#   - CUDA_VISIBLE_DEVICES limits the run to GPU indices 0 and 1.
#   - WANDB_PROJECT names the Weights & Biases project for this experiment.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',
    'WANDB_PROJECT': 'medic_00-wikiseealsotitles',
})

## Driver

In [None]:
#| export
if __name__ == '__main__':
    # The multiprocessing docs require freeze_support() to be called straight after
    # the __main__ guard; it has no effect outside a frozen Windows executable.
    mp.freeze_support()

    # Set to True to rebuild the tokenized block from the raw dataset and refresh the
    # pickle cache; False loads the previously cached block.
    build_block = False

    # ---- Load data ----
    pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
    pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data-meta_bert-base-uncased_xcs_cat-128.pkl'

    if build_block:
        data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'
        block = XCBlock.from_cfg(data_dir, 'data_meta', transform_type='xcs', tokenizer='bert-base-uncased', 
                                 sampling_features=[('lbl2data',1)], oversample=True)

        # Apply the 'cat_meta' augmentation to both the data side and the label side.
        # Presumably 128 is the maximum token length (it matches 'cat-128' in the cache
        # file name) -- TODO confirm against AugmentMetaInputIdsTfm.
        for side in ('data', 'lbl'):
            block = AugmentMetaInputIdsTfm.apply(block, 'cat_meta', side, 128, True)

        # Make sure the cache directory exists so an expensive build is not lost at dump time.
        os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
        with open(pkl_file, 'wb') as file: pickle.dump(block, file)
    else:
        # NOTE: pickle.load can execute arbitrary code. Only load cache files that were
        # written by this script on a trusted filesystem.
        with open(pkl_file, 'rb') as file: block = pickle.load(file)

    # Replace the plain token fields with their '_aug_cat' counterparts so the model is
    # fed the concatenated inputs. Order matches the original script: data_info for
    # train then test, followed by lbl_info for train then test.
    for info_name in ('data_info', 'lbl_info'):
        for split in (block.train, block.test):
            info = getattr(split.dset.data, info_name)
            info['input_ids'] = info['input_ids_aug_cat']
            info['attention_mask'] = info['attention_mask_aug_cat']

    # Clear the per-split metadata dictionaries on the datasets.
    for split in (block.train, block.test):
        split.dset.meta = {}

    # ---- Training arguments ----
    args = XCLearningArguments(
        output_dir='/home/scai/phd/aiz218323/scratch/outputs/medic/30_bert-ngame-for-wikiseealsotitles-with-input-concatenation',
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        evaluation_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        representation_search_type='BRUTEFORCE',
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,

        # Cluster-based batching configuration.
        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        # Model selection and target keys.
        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',

        use_encoder_parallel=True,
        max_grad_norm=None,
        fp16=True,
    )

    # Evaluation metric built from the label count, the test-split filterer and the
    # training data-label matrix.
    metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                      pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

    # ---- Model ----
    # bsz is the larger per-device batch size multiplied by the number of visible CUDA devices.
    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()
    model = BRT009.from_pretrained('sentence-transformers/msmarco-bert-base-dot-v5', bsz=bsz, tn_targ=5000, margin=0.3, tau=0.1, 
                                   n_negatives=10, apply_softmax=True, use_encoder_parallel=True)
    model.init_dr_head()

    # ---- Train ----
    learn = XCLearner(
        model=model, 
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )

    learn.train()
    