# RADGA DR training

In [None]:
#| default_exp 75-radga-dr-ep-for-wikititles

In [None]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import os,torch, torch.multiprocessing as mp, pickle
from xcai.basics import *
from xcai.analysis import *
from xcai.models.radgaX import RAD002

comet_ml is installed but `COMET_API_KEY` is not set.


In [None]:
# Interactive (non-exported) runs: switch Weights & Biases logging off.
os.environ.update(WANDB_MODE='disabled')

In [None]:
#| export
# Restrict the run to GPUs 0 and 1.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# W&B project name must match this notebook (`default_exp` / `output_dir` use
# 75-radga-dr-ep-for-wikititles), otherwise runs are filed under another experiment.
os.environ['WANDB_PROJECT'] = 'xc-nlg_75-radga-dr-ep-for-wikititles'

In [None]:
# Root folder of the raw XC_NLG datasets (used by `XCBlock.from_cfg`).
data_dir = os.path.join('/home/scai/phd/aiz218323', 'Projects', 'XC_NLG', 'data')

In [None]:
# Build the WikiTitles block with `hlk` and `lnk` metadata sampled alongside labels.
# NOTE(review): the recorded run of this cell raised
#   ValueError: `lbl_meta`(2148579) should have same number of columns as `data_meta`(3118594).
block = XCBlock.from_cfg(
    data_dir,
    'data_hlklnk',
    dset='wikititles',
    tfm='rm',
    tokenizer='distilbert-base-uncased',
    smp_features=[
        ('lbl2data|hlk2lbl2data|lnk2lbl2data', 1, (1, 3, 3)),
        ('hlk2data', 1, 3),
        ('lnk2data', 1, 3),
    ],
)



ValueError: `lbl_meta`(2148579) should have same number of columns as `data_meta`(3118594).

In [None]:
# Alternative block built from the plain `data` configuration with the `xcnlg` transform.
block = XCBlock.from_cfg(
    data_dir,
    'data',
    tfm='xcnlg',
    tokenizer='distilbert-base-uncased',
    smp_features=[('lbl2data', 1, 1)],
)

In [None]:
#| export
# Location of the pre-processed, pickled data block. The f-string below reads
# `pkl_dir`, so that is the name that must be defined here.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikititles_data-metas_distilbert-base-uncased_rm_radga-hlk-linker.pkl'

In [None]:
# Alternative pickle path (overrides the exported one) for the `xcnlg`/ngame block.
# `pkl_dir` must be defined because the f-string below reads it.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikititles_data_distilbert-base-uncased_xcnlg_ngame.pkl'

In [None]:
# Cache the built block at `pkl_file` so later runs can reload it instead of rebuilding.
with open(pkl_file, 'wb') as fh:
    pickle.dump(block, fh)

In [None]:
#| export
# Reload the cached data block.
# NOTE: `pickle.load` can execute arbitrary code; only point `pkl_file` at files you created.
with open(pkl_file, 'rb') as fh:
    block = pickle.load(fh)

## Training

In [None]:
#| export
# Training configuration, grouped by concern (keyword order is irrelevant).
args = XCLearningArguments(
    # --- output, logging, checkpointing, evaluation ---
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/75-radga-dr-ep-for-wikititles-1-0',
    logging_first_step=True,
    save_strategy="steps",
    save_steps=5000,
    save_total_limit=5,
    evaluation_strategy="steps",
    eval_steps=5000,
    metric_for_best_model='P@1',
    load_best_model_at_end=True,

    # --- optimisation ---
    num_train_epochs=300,
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    learning_rate=2e-4,
    weight_decay=0.0,
    adam_epsilon=1e-6,
    warmup_steps=100,
    max_grad_norm=None,
    fp16=True,
    use_encoder_parallel=True,

    # --- representation search ---
    predict_with_representation=True,
    representation_search_type='INDEX',
    representation_num_beams=200,
    representation_accumulation_steps=10,
    representation_attribute="data_fused_repr",
    output_representation_attribute='data_fused_repr',
    output_concatenation_weight=1.0,

    # --- generation ---
    predict_with_generation=True,
    generation_num_beams=10,
    generation_length_penalty=1.5,

    # --- clustering ---
    group_by_cluster=True,
    clustering_type='EXPO',
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    minimum_cluster_size=2,
    maximum_cluster_size=1600,

    # --- batch keys ---
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    label_names=['hlk2data_idx', 'hlk2data_input_ids', 'hlk2data_attention_mask'],
)

In [None]:
#| export
# Label-prediction evaluator on the test split; `prop` is the training label matrix
# (presumably used for the propensity-scored PSP/PSN columns).
metric = PrecRecl(
    block.n_lbl,
    block.test.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [None]:
#| export
# Global batch size = largest per-device size x number of visible GPUs. The device
# count is clamped to 1 so `batch_size` never becomes 0 on a machine with no GPU.
bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*max(torch.cuda.device_count(), 1)

# NOTE(review): `lbl2data_aug_meta_prefix='hlk2lbl'` while the sampled feature is named
# 'hlk2lbl2data'; confirm the model appends the '2data' suffix itself.
model = RAD002.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', num_batch_labels=5000, batch_size=bsz,
                               margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,

                               data_aug_meta_prefix='hlk2data', lbl2data_aug_meta_prefix='hlk2lbl',
                               resize_length=5000,

                               meta_loss_weight=0.3, pred_meta_prefix=None,

                               fusion_loss_weight=0.05, use_fusion_loss=False, use_noise=False, use_encoder_parallel=True)
# The retrieval and cross heads are newly initialised (see the checkpoint warning).
model.init_retrieval_head()
model.init_cross_head()

Some weights of RAD002 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 'encoder.meta_head.transform.bias', 'encoder.meta_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
#| export
# Trainer: fit on the training split, evaluate on the test split with `metric`.
learn = XCLearner(
    model=model,
    args=args,
    data_collator=block.collator,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    compute_metrics=metric,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Training is launched by the `if __name__ == '__main__'` cell below. That guard is
# also true inside a notebook kernel, so calling `learn.train()` here as well would
# run the full 300-epoch schedule twice on Restart & Run All.

In [None]:
#| export
# Script entry point for the exported module: start training.
# NOTE(review): inside an IPython kernel `__name__` is normally '__main__' as well,
# so running this cell in the notebook also starts training.
if __name__ == '__main__':
    mp.freeze_support()
    learn.train()

## Prediction

In [None]:
# Run directory to evaluate and the path of its best checkpoint.
# NOTE(review): this points at notebook 66's wikiseealso run, not this notebook's
# 75-radga-dr-ep-for-wikititles output_dir; confirm that is intended.
o_dir = '/home/scai/phd/aiz218323/scratch/outputs/66-radga-dr-ep-for-wikiseealso-5-2'
output_dir = os.path.join('/home/scai/phd/aiz218323/scratch/outputs', os.path.basename(o_dir))
mname = os.path.join(output_dir, os.path.basename(get_best_model(output_dir)))

In [None]:
# Reload the best checkpoint for inference.
# NOTE(review): kwargs differ from training (`data_aug_meta_prefix='lnk2data'`,
# `data_pred_meta_prefix`/`lbl2data_pred_meta_prefix` instead of `pred_meta_prefix`).
bsz = 1600
model = RAD002.from_pretrained(
    mname,
    batch_size=bsz,
    num_batch_labels=5000,
    margin=0.3,
    num_negatives=10,
    tau=0.1,
    apply_softmax=True,
    # `*_aug_meta_prefix` options
    data_aug_meta_prefix='lnk2data',
    lbl2data_aug_meta_prefix=None,
    resize_length=5000,
    # `*_pred_meta_prefix` options
    meta_loss_weight=0.3,
    data_pred_meta_prefix=None,
    lbl2data_pred_meta_prefix=None,
    # fusion options
    fusion_loss_weight=0.1,
    use_fusion_loss=False,
    use_noise=False,
    use_encoder_parallel=True,
)

### Metadata Prediction

In [None]:
# Inference configuration for `cat` metadata prediction using `data_repr`.
args = XCLearningArguments(
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/66-radga-dr-ep-for-wikiseealso-5-2',
    logging_first_step=True,
    fp16=True,
    use_encoder_parallel=True,
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    metric_for_best_model='P@1',

    # --- representation search ---
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    representation_num_beams=200,
    representation_accumulation_steps=10,
    representation_attribute="data_repr",
    output_representation_attribute="data_repr",

    # --- batch keys ---
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    label_names=['cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask'],

    # --- augmentation ---
    predict_with_augmentation=False,
    use_augmentation_index_representation=True,
    data_aug_meta_name='cat',
    data_aug_prefix='aug',
    augmentation_num_beams=3,
    use_label_metadata=False,
)

In [None]:
import numpy as np
from xcai.data import MainXCDataset, XCDataBlock, XCDataset, BaseXCDataBlock

# Test queries only: metadata is predicted for them, so no targets are attached.
test_meta_dset = MainXCDataset(block.test.dset.data.data_info)

# Training queries with the `cat` metadata matrix passed in the label position
# (presumably it becomes `data_lbl`); keep only rows with at least one entry.
dataset = MainXCDataset(
    block.train.dset.data.data_info,
    block.train.dset.meta['cat_meta'].data_meta,
    block.train.dset.meta['cat_meta'].meta_info,
)
idx = np.flatnonzero(dataset.data_lbl.getnnz(axis=1))
train_meta_dset = XCDataset(dataset._getitems(idx))

In [None]:
# Learner used only for metadata inference, so no metric function is attached.
learn = XCLearner(
    model=model,
    args=args,
    data_collator=block.collator,
    train_dataset=train_meta_dset,
    eval_dataset=test_meta_dset,
    compute_metrics=None,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Predict metadata for every test query. The preceding learner has `compute_metrics=None`,
# so presumably `o` carries predictions without metrics.
o = learn.predict(test_meta_dset)

  0%|          | 0/411 [00:00<?, ?it/s]

In [None]:
from xcai.data import MetaXCDataset
from scipy import sparse
from xclib.utils.sparse import retain_topk

In [None]:
# Keep the top-5 predicted metadata per test query and register them on the test
# split as an `aug` metadata source; the label-side matrix is an all-zero
# (n_lbl x n_meta) placeholder.
# NOTE(review): k=5 here while `augmentation_num_beams=3`; this also mutates `block` in place.
preds = retain_topk(
    get_pred_sparse(o, train_meta_dset.n_lbl),
    k=5,
)
aug_dset = MetaXCDataset(
    'aug',
    preds,
    sparse.csr_matrix((block.n_lbl, train_meta_dset.n_lbl)),
    train_meta_dset.lbl_info,
)
block.test.dset.meta['aug_meta'] = aug_dset

### Label Prediction

In [None]:
#| export
# Inference configuration for label prediction using `data_fused_repr` and the
# predicted `aug` metadata.
# NOTE(review): because this cell is exported, the module-level `args` ends up as this
# inference config rather than the training config defined earlier.
args = XCLearningArguments(
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/66-radga-dr-ep-for-wikiseealso-5-2',
    logging_first_step=True,
    fp16=True,
    use_encoder_parallel=True,
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    metric_for_best_model='P@1',

    # --- representation search ---
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    representation_num_beams=200,
    representation_accumulation_steps=10,
    representation_attribute="data_fused_repr",
    output_representation_attribute="data_fused_repr",

    # --- batch keys ---
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    label_names=['aug2data_idx', 'aug2data_input_ids', 'aug2data_attention_mask'],

    # --- augmentation ---
    predict_with_augmentation=False,
    use_augmentation_index_representation=True,
    data_aug_meta_name='cat',
    data_aug_prefix='aug',
    augmentation_num_beams=3,
    use_label_metadata=False,
)

In [None]:
# Label-prediction evaluator on the test split, with the training label matrix as `prop`.
metric = PrecRecl(
    block.n_lbl,
    block.test.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [None]:
# Learner for label prediction on the full test split, scored with `metric`.
learn = XCLearner(
    model=model,
    args=args,
    data_collator=block.collator,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    compute_metrics=metric,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Make the model read the `aug2data` fields (the predicted metadata registered as `aug_meta`).
setattr(learn.model, 'data_aug_meta_prefix', 'aug2data')

In [None]:
# Label prediction on the full test split (presumably using the `aug` metadata attached earlier).
o = learn.predict(block.test.dset)

  0%|          | 0/196 [00:00<?, ?it/s]

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


  self._set_arrayXarray(i, j, x)


In [None]:
# Render the metrics of the most recent `learn.predict` call as a table.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,26.5673,18.2133,14.0982,9.267,26.5673,26.6909,27.8684,29.9781,21.4784,24.0275,26.6204,31.7511,21.4784,24.0557,25.9068,28.3621,35.2704,52.1819,56.2612,0.0261,301.9623,587.871,0.368


__Three ground truth metadata__

In [None]:
# NOTE(review): same code as the previous metrics cell. The recorded table is from a separate
# run with three ground-truth metadata (see heading); after a single `predict`, all of
# these cells print the same table.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,40.0969,26.336,19.9893,12.7582,40.0969,40.4669,42.2028,44.9619,29.1719,31.9999,35.0768,41.136,29.1719,32.9694,35.5467,38.7476,52.1834,71.4731,75.3854,0.0178,253.7373,699.602,0.437


__Three predicted metadata__

In [None]:
# NOTE(review): same code as the previous metrics cell. The recorded table is from a separate
# run with three predicted metadata (see heading); after a single `predict`, all of
# these cells print the same table.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,25.9989,17.9363,13.9451,9.1716,25.9989,26.2267,27.4426,29.5332,21.0299,23.7623,26.4269,31.5032,21.0299,23.7023,25.5771,28.0035,34.8522,51.6796,55.6961,0.026,755.2368,235.045,0.147


__One predicted metadata__

In [None]:
# NOTE(review): same code as the previous metrics cell. The recorded table is from a separate
# run with one predicted metadata (see heading); after a single `predict`, all of
# these cells print the same table.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,24.4751,17.0679,13.3589,8.8409,24.4751,24.8718,26.1337,28.2051,19.7119,22.6864,25.3921,30.4427,19.7119,22.4863,24.3645,26.7577,33.5469,50.1171,53.9882,0.026,911.3489,194.783,0.122


__No metadata__

In [None]:
# NOTE(review): same code as the previous metrics cell. The recorded table is from a separate
# run without metadata (see heading); after a single `predict`, all of these cells
# print the same table.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,28.5841,19.2309,14.6871,9.498,28.5841,28.463,29.5066,31.4858,24.9028,26.5122,28.6209,33.1786,24.9028,26.9498,28.5921,30.846,36.362,52.0646,55.7828,0.0178,260.3383,681.863,0.426


__Linker__

In [None]:
# NOTE(review): same code as the previous metrics cell. The recorded table is from a separate
# run using linker metadata (see heading); after a single `predict`, all of these cells
# print the same table.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,36.3057,24.4207,18.719,12.0918,36.3057,36.9515,38.7413,41.5432,26.5191,29.6405,32.7938,38.877,26.5191,30.1972,32.7319,35.8922,49.0683,69.3868,73.6306,0.0191,247.9957,715.799,0.448


### Metadata Analysis

In [None]:
# Configuration for evaluating `cat` metadata prediction with `data_repr`.
args = XCLearningArguments(
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/66-radga-dr-ep-for-wikiseealso-5-2',
    logging_first_step=True,
    fp16=True,
    use_encoder_parallel=True,
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    metric_for_best_model='P@1',

    # --- representation search ---
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    representation_num_beams=200,
    representation_accumulation_steps=10,
    representation_attribute="data_repr",
    output_representation_attribute="data_repr",

    # --- batch keys ---
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    label_names=['cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask'],

    # --- augmentation ---
    predict_with_augmentation=False,
    use_augmentation_index_representation=True,
    data_aug_meta_name='cat',
    data_aug_prefix='aug',
    augmentation_num_beams=3,
    use_label_metadata=False,
)

In [None]:
import numpy as np
from xcai.data import MainXCDataset, XCDataBlock, XCDataset, BaseXCDataBlock

def _cat_meta_dset(split_dset):
    "Pair `split_dset`'s queries with its `cat` metadata as targets, keeping only queries that have at least one."
    # The `cat` metadata matrix goes in the label position (presumably becoming `data_lbl`).
    full = MainXCDataset(split_dset.data.data_info, split_dset.meta['cat_meta'].data_meta,
                         split_dset.meta['cat_meta'].meta_info)
    keep = np.flatnonzero(full.data_lbl.getnnz(axis=1))
    return XCDataset(full._getitems(keep))

# Built in the same order as before: test first, then train.
test_meta_dset = _cat_meta_dset(block.test.dset)
train_meta_dset = _cat_meta_dset(block.train.dset)

In [None]:
# Evaluator for metadata prediction: `cat` metadata act as the labels.
metric = PrecRecl(
    test_meta_dset.n_lbl,
    test_meta_dset.data.data_lbl_filterer,
    prop=train_meta_dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [None]:
# Learner for scoring metadata prediction on the filtered test queries.
learn = XCLearner(
    model=model,
    args=args,
    data_collator=block.collator,
    train_dataset=train_meta_dset,
    eval_dataset=test_meta_dset,
    compute_metrics=metric,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Predict and score metadata for test queries that have at least one `cat` entry.
o = learn.predict(test_meta_dset)

  0%|          | 0/411 [00:00<?, ?it/s]

In [None]:
# Render the metadata-prediction metrics from the `predict` call above.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,32.7731,17.3842,12.5425,7.8116,32.7731,24.2618,23.5385,24.4533,28.1833,20.58,19.7698,20.9801,28.1833,23.7536,24.3566,26.2213,25.119,39.0703,42.8998,0.0288,103.8435,1671.304,1.05


In [None]:
# Sparse score matrix built from the predictions in `o` (presumably queries x metadata).
meta_pred = get_pred_sparse(o, test_meta_dset.data.n_lbl)

In [None]:
# Cross-check precision@1..5 against xclib's implementation.
# NOTE(review): the recorded P@1 here (0.345) differs from `display_metric` (32.77);
# `PrecRecl` is given `data_lbl_filterer`, which may explain the gap. Confirm.
import xclib.evaluation.xc_metrics as xc_metrics
xc_metrics.precision(meta_pred, test_meta_dset.data.data_lbl, k=5)

array([0.34504535, 0.23478572, 0.18452662, 0.15389447, 0.1328912 ])

In [None]:
# Per-query precision@3 (`return_type='D'` presumably yields a dense per-query array; confirm).
evals = pointwise_eval(meta_pred, test_meta_dset.data.data_lbl, topk=3, metric='P', return_type='D')

In [None]:
# Keep the 3 highest-scoring metadata per query for qualitative inspection.
# Import `retain_topk` explicitly: no visible cell imports `xc_sparse`, so the
# original call depended on a name that may not exist in a fresh kernel.
from xclib.utils.sparse import retain_topk
topk_meta_pred = retain_topk(meta_pred, k=3)

In [None]:
# Keep only the query text (`data_input_text`) and target text (`lbl2data_input_text`) columns.
pattern = r'^(data|lbl2data)_input_text$'

# Predicted top-3 metadata vs. ground-truth metadata, restricted to those columns.
pred_dset = TextColumns(get_pred_dset(topk_meta_pred, test_meta_dset), pat=pattern)
test_dset = TextColumns(test_meta_dset, pat=pattern)

In [None]:
from IPython.display import HTML

In [None]:
# Indices of the 10 queries with the lowest per-query P@3 (ascending sort, first ten).
idx = np.argsort(evals)[:10]

In [None]:
# Render predicted vs. ground-truth metadata text for the selected queries.
HTML(display_text(pred_dset, test_dset, idx))

In [None]:
# Indices of the 10 queries with the highest per-query P@3 (ascending sort, last ten).
idx = np.argsort(evals)[-10:]

In [None]:
# Render predicted vs. ground-truth metadata text for the selected queries.
HTML(display_text(pred_dset, test_dset, idx))