# RADGA DR training

In [None]:
#| default_exp 66-radga-dr-ep-for-wikiseealso

In [None]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev

# Trigger nbdev's notebook export step (this notebook marks its exportable cells with
# `#| export` and names its target with `#| default_exp`).
nbdev.nbdev_export()

In [None]:
#| export
import os,torch, torch.multiprocessing as mp, pickle
from xcai.basics import *
from xcai.models.radga import RAD002

comet_ml is installed but `COMET_API_KEY` is not set.


In [None]:
# Scratch-only override (this cell carries no `#| export`, so the exported script does not
# include it): put Weights & Biases into 'disabled' mode for interactive runs.
os.environ.update(WANDB_MODE='disabled')

In [None]:
#| export
# Process environment for the exported training script: restrict CUDA to devices 0 and 1
# and set the Weights & Biases project name.
# NOTE(review): this runs after `import torch`; presumably still effective because CUDA is
# initialised lazily — confirm nothing imported earlier touches CUDA.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',
    'WANDB_PROJECT': 'xc-nlg_66-radga-dr-ep-for-wikiseealso',
})

In [None]:
# Directory passed to `XCBlock.from_cfg` when building `block` in this notebook (scratch cell, not exported).
# NOTE(review): `data_dir` is reassigned to a different directory by a later `#| export` cell;
# re-running the XCBlock cell after that one would silently use the other path.
data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'

In [None]:
# Build the dataset block: `from_cfg` receives `data_dir` and the config name 'data_metas'.
block = XCBlock.from_cfg(
    data_dir,
    'data_metas',
    tfm='rm',
    tokenizer='distilbert-base-uncased',
    # One tuple per sampled feature group — TODO confirm the meaning of the 2nd/3rd fields in XCBlock.
    smp_features=[
        ('lbl2data|cat2lbl2data|hlk2lbl2data', 1, (1, 1, 3)),
        ('cat2data', 1, 1),
        ('hlk2data', 1, 3),
    ],
    n_data_meta_samples=50,
    n_lbl_meta_samples=50,
)



In [None]:
#| export
# Location of the pickled dataset block.
# `data_dir` already ends with a separator, so the paths are assembled with `os.path.join`
# instead of f-strings that add another '/', which produced doubled separators
# ('.../datasets//processed//...').
data_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
pkl_dir = os.path.join(data_dir, 'processed', '')  # empty last component keeps the trailing '/'
pkl_file = os.path.join(pkl_dir, 'wikiseealso_data-metas_distilbert-base-uncased_rm_radga-final.pkl')

In [None]:
# Cache the freshly built `block` on disk (scratch cell, not exported).
# Create the target directory first: without it, `open(..., 'wb')` raises FileNotFoundError
# when the folder does not exist yet, after `block` has already been built.
os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
with open(pkl_file, 'wb') as file:
    pickle.dump(block, file)

In [None]:
#| export
# Restore `block` from the pickle cache at `pkl_file`.
# NOTE(review): unpickling can execute arbitrary code — only load cache files you produced yourself.
with open(pkl_file, mode='rb') as file: block = pickle.load(file)

## Training

In [None]:
#| export
# Training configuration. Keywords are grouped by concern; the values themselves are unchanged.
args = XCLearningArguments(
    # -- run bookkeeping --
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/66-radga-dr-ep-for-wikiseealso-1-0',
    logging_first_step=True,

    # -- per-device batch sizes --
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,

    # -- optimisation --
    num_train_epochs=300,
    learning_rate=2e-4,
    warmup_steps=100,
    weight_decay=0.0,
    adam_epsilon=1e-6,
    max_grad_norm=None,
    fp16=True,

    # -- checkpointing / evaluation: both on a 5000-step schedule, best model selected by P@1 --
    save_strategy='steps',
    evaluation_strategy='steps',
    save_steps=5000,
    eval_steps=5000,
    save_total_limit=5,
    metric_for_best_model='P@1',
    load_best_model_at_end=True,

    # -- representation_* prediction options --
    predict_with_representation=True,
    representation_attribute='data_fused_repr',
    representation_search_type='INDEX',
    representation_num_beams=200,
    representation_accumulation_steps=10,

    # -- generation_* prediction options --
    predict_with_generation=True,
    generation_num_beams=10,
    generation_length_penalty=1.5,

    # -- cluster options (names suggest cluster-grouped batching — confirm in XCLearningArguments) --
    group_by_cluster=True,
    clustering_type='EXPO',
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    minimum_cluster_size=2,
    maximum_cluster_size=1600,

    # -- model wiring and batch keys --
    use_encoder_parallel=True,
    output_concatenation_weight=1.0,
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    label_names=['hlk2data_idx', 'hlk2data_input_ids', 'hlk2data_attention_mask'],
    # Alternative label set, currently disabled:
    # label_names=['hlk2data_idx', 'hlk2data_input_ids', 'hlk2data_attention_mask',
    #              'hlk2lbl2data_idx', 'hlk2lbl2data_input_ids', 'hlk2lbl2data_attention_mask'],
)

In [None]:
#| export
# `PrecRecl` metric object; it is handed to `XCLearner` as `compute_metrics`.
metric = PrecRecl(
    block.n_lbl,
    block.test.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [None]:
#| export
# Batch size handed to the model: the larger of the per-device train/eval batch sizes times
# the number of visible CUDA devices. `torch.cuda.device_count()` returns 0 when no GPU is
# visible, which used to make `bsz` 0; count at least one device so the value stays positive.
bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size) * max(1, torch.cuda.device_count())

# Instantiate RAD002 from the 'sentence-transformers/msmarco-distilbert-base-v4' checkpoint.
model = RAD002.from_pretrained(
    'sentence-transformers/msmarco-distilbert-base-v4',
    num_batch_labels=5000,
    batch_size=bsz,
    margin=0.3,
    num_negatives=10,
    tau=0.1,
    apply_softmax=True,
    # metadata prefixes used for data / label augmentation
    data_aug_meta_prefix='hlk2data',
    lbl2data_aug_meta_prefix='hlk2lbl',
    resize_length=5000,
    # metadata prediction loss
    meta_loss_weight=0.3,
    pred_meta_prefix='cat',
    # fusion loss is configured but switched off
    fusion_loss_weight=0.05,
    use_fusion_loss=False,
    use_noise=False,
    use_encoder_parallel=True,
)

# The load warning recorded for this cell reports the dr_head / meta_head / cross_head weights
# as newly initialised; these calls presumably set those heads up — confirm in RAD002.
model.init_retrieval_head()
model.init_cross_head()

Some weights of RAD002 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 'encoder.meta_head.transform.bias', 'encoder.meta_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
#| export
# Wire model, arguments, train/test datasets, collator and metric into the learner.
learn = XCLearner(model=model, args=args,
                  train_dataset=block.train.dset, eval_dataset=block.test.dset,
                  data_collator=block.collator, compute_metrics=metric)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Training is deliberately not launched from this scratch cell. In a notebook kernel
# `__name__ == '__main__'`, so the exported `if __name__ == '__main__':` cell that follows
# already calls `learn.train()`; calling it here as well made "Run All" train the model twice.

In [None]:
#| export
# Script entry point: train only when this file is executed as the top-level program,
# not when it is imported as a module.
# NOTE: `__name__` is also '__main__' inside a notebook kernel, so running this cell
# interactively starts training as well.
if __name__ == '__main__':
    # `mp` is torch.multiprocessing; presumably this is the stdlib freeze_support hook for
    # frozen executables — confirm it is needed here.
    mp.freeze_support()
    learn.train()