# RADGA DR training

In [None]:
#| default_exp 66-radga-dr-ep-for-wikiseealso

In [None]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import os,torch, torch.multiprocessing as mp, pickle
from xcai.basics import *
from xcai.models.radga import RAD002

comet_ml is installed but `COMET_API_KEY` is not set.


In [None]:
# Notebook-only (not exported): switch Weights & Biases logging off for this kernel.
os.environ.update(WANDB_MODE='disabled')

In [None]:
#| export
# Process-wide settings: which GPUs are visible, and the W&B project name for this run.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',
    'WANDB_PROJECT': 'xc-nlg_66-radga-dr-ep-for-wikiseealso',
})

In [None]:
# Dataset root passed as the first argument to `XCBlock.from_cfg` in the cells that follow.
# NOTE(review): absolute, user-specific path — adjust when running elsewhere.
data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'

In [None]:
# Build the block from the 'data_metas' config entry with the 'rm' transform,
# sampling both label-side and 'cat' metadata features.
block = XCBlock.from_cfg(
    data_dir,
    'data_metas',
    tfm='rm',
    tokenizer='distilbert-base-uncased',
    smp_features=[
        ('lbl2data|cat2lbl2data', 1, (1, 3)),
        ('cat2data', 1, 3),
    ],
)

In [None]:
# Alternative block built from the plain 'data' config entry with the 'xcnlg' transform.
# Running this cell rebinds `block`, replacing whatever an earlier cell assigned.
block = XCBlock.from_cfg(
    data_dir,
    'data',
    tfm='xcnlg',
    tokenizer='distilbert-base-uncased',
    smp_features=[
        ('lbl2data', 1, 1),
    ],
)

In [None]:
#| export
# Directory that holds the cached, pre-processed block pickles.
# Fix: this value used to be assigned to `data_dir`, which left `pkl_dir` (read by the
# f-string below) undefined, so the exported script failed with NameError.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
# Pickle of the metadata-augmented block loaded by the exported script.
pkl_file = f'{pkl_dir}/processed/wikiseealso_data-metas_distilbert-base-uncased_rm_radga-cat-final.pkl'

In [None]:
# Notebook-only alternative: point `pkl_file` at the plain 'data' / 'xcnlg' block pickle.
# Fix: the directory was assigned to `data_dir` while the f-string reads `pkl_dir`,
# which raised NameError on a fresh kernel.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
pkl_file = f'{pkl_dir}/processed/wikiseealso_data_distilbert-base-uncased_xcnlg_ngame.pkl'

In [None]:
# Notebook-only: serialise the current `block` to `pkl_file` (overwrites an existing file).
with open(pkl_file, 'wb') as file:
    pickle.dump(block, file)

In [None]:
#| export
# Restore the pre-processed block from `pkl_file`.
# Security: unpickling runs code embedded in the file — only load pickles you produced yourself.
with open(pkl_file, 'rb') as file:
    block = pickle.load(file)

## Training

In [None]:
#| export
# Training configuration; keyword arguments are grouped by concern.
args = XCLearningArguments(
    # -- run bookkeeping --
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/66-radga-dr-ep-for-wikiseealso-1-0',
    logging_first_step=True,

    # -- batching --
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,

    # -- optimisation --
    num_train_epochs=300,
    learning_rate=2e-4,
    warmup_steps=100,
    weight_decay=0.0,
    adam_epsilon=1e-6,
    max_grad_norm=None,
    fp16=True,

    # -- checkpointing and evaluation cadence (both every 5000 steps) --
    save_strategy="steps",
    evaluation_strategy="steps",
    eval_steps=5000,
    save_steps=5000,
    save_total_limit=5,
    metric_for_best_model='P@1',
    load_best_model_at_end=True,

    # -- representation-based prediction --
    predict_with_representation=True,
    representation_attribute="data_fused_repr",
    representation_num_beams=200,
    representation_accumulation_steps=10,
    representation_search_type='INDEX',
    output_concatenation_weight=1.0,

    # -- generation-based prediction --
    predict_with_generation=True,
    generation_num_beams=10,
    generation_length_penalty=1.5,

    # -- clustering settings --
    group_by_cluster=True,
    clustering_type='EXPO',
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    minimum_cluster_size=2,
    maximum_cluster_size=1600,

    # -- batch keys and encoder placement --
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    use_encoder_parallel=True,
    label_names=['cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask'],
    # label_names=['hlk2data_idx', 'hlk2data_input_ids', 'hlk2data_attention_mask',
    #              'hlk2lbl2data_idx', 'hlk2lbl2data_input_ids', 'hlk2lbl2data_attention_mask'],
)

In [None]:
#| export
# Evaluation metric over the block's labels: filtered with the test split's
# `data_lbl_filterer`, with `prop` taken from the training split's `data_lbl`.
metric = PrecRecl(
    block.n_lbl,
    block.test.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [None]:
#| export
# Batch size across all visible CUDA devices, based on the larger per-device size.
bsz = torch.cuda.device_count() * max(args.per_device_train_batch_size, args.per_device_eval_batch_size)

# Start from the pretrained sentence-transformers checkpoint; the retrieval and
# cross heads are initialised explicitly right after construction.
model = RAD002.from_pretrained(
    'sentence-transformers/msmarco-distilbert-base-v4',
    num_batch_labels=5000,
    batch_size=bsz,
    margin=0.3,
    num_negatives=10,
    tau=0.1,
    apply_softmax=True,

    data_aug_meta_prefix='cat2data',
    lbl2data_aug_meta_prefix='cat2lbl',
    resize_length=5000,

    meta_loss_weight=0.3,
    # NOTE(review): 'hlk' here, while args.label_names lists only cat2data_* keys — confirm this is intended.
    pred_meta_prefix='hlk',

    fusion_loss_weight=0.05,
    use_fusion_loss=False,
    use_noise=False,
    use_encoder_parallel=True,
)
model.init_retrieval_head()
model.init_cross_head()

Some weights of RAD002 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 'encoder.meta_head.transform.bias', 'encoder.meta_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
#| export
# Training learner: ties the model and arguments to the block's train/test
# datasets, the block's collator, and `metric` as the compute_metrics callback.
learn = XCLearner(
    model=model, 
    args=args,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Notebook-only training run (this cell is not exported).
# NOTE(review): the exported `__main__` cell that follows also calls learn.train(); inside a
# notebook kernel `__name__` is '__main__', so a top-to-bottom run would train twice.
learn.train()

In [None]:
#| export
# Script entry point: train only when the module is executed directly, not when imported.
if __name__ == '__main__':
    # Multiprocessing hook for frozen executables; called before any training work starts.
    mp.freeze_support()
    learn.train()

## Prediction

In [None]:
#| export
# Prediction-time configuration; rebinds `args` (the training arguments are no longer reachable).
args = XCLearningArguments(
    # -- run bookkeeping --
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/66-radga-dr-ep-for-wikiseealso-1-6',
    logging_first_step=True,

    # -- batching and precision --
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    fp16=True,

    # -- representation-based prediction --
    predict_with_representation=True,
    representation_attribute="data_fused_repr",
    output_representation_attribute="data_fused_repr",
    representation_num_beams=200,
    representation_accumulation_steps=10,
    representation_search_type='INDEX',

    # -- batch keys, model selection and encoder placement --
    metric_for_best_model='P@1',
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    use_encoder_parallel=True,
    label_names=['cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask'],
    # label_names=['hlk2data_idx', 'hlk2data_input_ids', 'hlk2data_attention_mask',
    #              'hlk2lbl2data_idx', 'hlk2lbl2data_input_ids', 'hlk2lbl2data_attention_mask'],

    # -- augmentation settings --
    predict_with_augmentation=False,
    use_augmentation_index_representation=True,
    data_aug_meta_name='cat',
    augmentation_num_beams=3,
    data_aug_prefix='aug',
    use_label_metadata=False,
)

In [None]:
# Rebuild the run directory from the basename of args.output_dir, then point `mname`
# at the sub-directory whose basename `get_best_model(output_dir)` returns.
output_dir = f"/home/scai/phd/aiz218323/scratch/outputs/{os.path.basename(args.output_dir)}"
mname = f'{output_dir}/{os.path.basename(get_best_model(output_dir))}'

In [None]:
# Manual override: rebinds `mname` to one fixed checkpoint directory under `output_dir`.
# NOTE(review): when run after a cell that selects the best model, this discards that choice — skip it to keep the best model.
mname = f'{output_dir}/checkpoint-130200'

In [None]:
# Batch size across all visible CUDA devices, based on the larger per-device size.
bsz = torch.cuda.device_count() * max(args.per_device_train_batch_size, args.per_device_eval_batch_size)

# Load the checkpoint at `mname` for prediction; unlike the training cell, the
# augmentation prefixes are 'aug2data'/'aug2lbl' and no pred_meta_prefix is set.
model = RAD002.from_pretrained(
    mname,
    num_batch_labels=5000,
    batch_size=bsz,
    margin=0.3,
    num_negatives=10,
    tau=0.1,
    apply_softmax=True,

    data_aug_meta_prefix='aug2data',
    lbl2data_aug_meta_prefix='aug2lbl',
    resize_length=5000,

    meta_loss_weight=0.3,
    pred_meta_prefix=None,  # pred_meta_prefix='cat',

    fusion_loss_weight=0.1,
    use_fusion_loss=False,
    use_noise=False,
    use_encoder_parallel=True,
)

In [None]:
import numpy as np
from xcai.data import MainXCDataset, XCDataBlock, XCDataset, BaseXCDataBlock


def _nonempty_meta_dataset(dset, meta_name='cat_meta'):
    """Build an `XCDataset` over one split using a metadata matrix in place of its labels.

    A `MainXCDataset` is built from the split's `data_info` and the chosen
    metadata entry's `data_meta` / `meta_info`. Only rows whose `data_lbl` has
    at least one stored entry are kept.

    Args:
        dset: a split dataset exposing `.data.data_info` and `.meta[meta_name]`
            (e.g. `block.test.dset`).
        meta_name: key into `dset.meta`; defaults to 'cat_meta'.

    Returns:
        `XCDataset` wrapping the selected rows.
    """
    meta = dset.meta[meta_name]
    # presumably the metadata matrix becomes the dataset's `data_lbl` — confirm against MainXCDataset.
    dataset = MainXCDataset(dset.data.data_info, meta.data_meta, meta.meta_info)
    # Indices of rows with one or more stored entries in `data_lbl`.
    idx = np.where(dataset.data_lbl.getnnz(axis=1) > 0)[0]
    return XCDataset(dataset._getitems(idx))


# One helper serves both splits, replacing two copy-pasted constructions that
# also left the temporaries `dataset` and `idx` behind at notebook scope.
test_meta_dset = _nonempty_meta_dataset(block.test.dset)
train_meta_dset = _nonempty_meta_dataset(block.train.dset)

In [None]:
# Metric over the block's labels, built with the same arguments as the training-time metric.
metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

In [None]:
# Prediction learner over the block's own train/test datasets, using the current `metric`.
learn = XCLearner(
    model=model, 
    args=args,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

In [None]:
# Metric over the metadata datasets (`test_meta_dset` / `train_meta_dset`).
# Rebinds `metric`: any learner created after this cell receives this object.
metric = PrecRecl(test_meta_dset.n_lbl, test_meta_dset.data.data_lbl_filterer, prop=train_meta_dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

In [None]:
# Learner over the metadata datasets. Rebinds `learn`, so later `learn.predict(...)`
# calls go through this learner and whatever `metric` was current when it was built.
learn = XCLearner(
    model=model, 
    args=args,
    train_dataset=train_meta_dset,
    eval_dataset=test_meta_dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Run prediction on the metadata test dataset; the result is kept in `o` for the display cell.
o = learn.predict(test_meta_dset)

  0%|          | 0/411 [00:00<?, ?it/s]

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [None]:
# Show the metrics stored on the most recent prediction output `o`.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,34.5028,18.4519,13.2887,8.2554,34.5028,26.0972,25.4834,26.5477,28.9366,21.3815,20.5157,21.7502,28.9366,24.7859,25.5341,27.5692,27.6437,41.7672,45.6261,0.0283,663.5434,261.556,0.164


In [None]:
# Run prediction on the block's label test dataset; rebinds `o`.
# NOTE(review): in a top-to-bottom run `learn` was last bound to the learner built on
# train_meta_dset/test_meta_dset with the metadata metric — confirm that is the learner
# intended here, or re-run the label-level metric/learner cells first.
o = learn.predict(block.test.dset)

  0%|          | 0/196 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)


In [None]:
# Show the metrics stored on the most recent prediction output `o`.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,37.145,24.8182,18.7853,11.9354,37.145,37.347,38.7788,41.2028,28.5487,31.1488,33.7948,39.0724,28.5487,31.8077,33.9975,36.7601,47.5127,65.2832,69.2098,0.0217,641.0539,276.911,0.173


In [None]:
# Show the metrics stored on `o`.
# NOTE(review): no predict call sits between this cell and the previous display, yet the saved
# table differs — presumably it came from a different run or checkpoint; confirm before citing it.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,28.5024,19.3702,14.9239,9.6855,28.5024,28.7246,29.9571,32.0412,22.5878,25.1417,27.7646,32.6998,22.5878,25.2912,27.1844,29.5799,37.3778,53.7569,57.6496,0.0347,1248.3147,142.204,0.089


In [None]:
# Show the metrics stored on `o`.
# NOTE(review): the saved table again differs from the preceding ones without a new predict
# call in between — presumably from another run or checkpoint; confirm its provenance.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,30.8261,20.5389,15.5689,9.8958,30.8261,30.6954,31.7409,33.6136,25.9538,27.6247,29.6802,33.9528,25.9538,28.1304,29.7845,31.94,38.3747,53.0829,56.5055,0.0258,515.3512,344.454,0.215
