In [None]:
#| default_exp 13_ngame-linker-for-wikititles-noise

In [None]:
# Re-import edited modules (e.g. the xcai package) automatically before each cell executes.
%reload_ext autoreload
%autoreload 2

In [None]:
from nbdev.showdoc import *
import nbdev

# Regenerate the library module(s) from this project's `#| export` cells.
nbdev.nbdev_export()

In [None]:
#| export
import os,torch,json, torch.multiprocessing as mp, joblib, numpy as np, scipy.sparse as sp

from xcai.basics import *
from xcai.models.PPP0XX import DBT009,DBT011

In [None]:
# Silence Weights & Biases for interactive runs (this cell is not exported, so the script still logs).
os.environ.update(WANDB_MODE='disabled')

In [None]:
#| export
# Restrict the run to GPUs 0 and 1 and name the W&B project; exported, so it also applies to the script.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been initialised yet in this process.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',
    'WANDB_PROJECT': 'mogicX_02-wikititles-linker',
})

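## Setup

Interactive configuration for exploring this notebook: paths, model name, inference flags and the cached data block. The exported driver below repeats the paths and reads the flags from command-line arguments (`parse_args()`).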

In [None]:
# Experiment output directory (checkpoints, predictions).
_outputs_root = '/scratch/scai/phd/aiz218323/outputs/mogicX'
output_dir = f'{_outputs_root}/13_ngame-linker-for-wikititles-noise'

# Dataset config for LF-WikiTitles-500K (the `noise-050` hyper-link variant) and the key selecting it.
_benchmark_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/(mapped)LF-WikiTitles-500K'
config_file = f'{_benchmark_dir}/configs/data_hyper_link_noise-050.json'
config_key = 'data_hyper_link'

# Pretrained encoder checkpoint name.
mname = 'sentence-transformers/msmarco-distilbert-base-v4'

In [None]:
# Which inference passes to run, and what to persist, when experimenting interactively.
do_train_inference, do_test_inference = False, True
save_train_inference, save_test_inference = False, False
save_representation = False

# Data-block options: SXC vs XCS sampler (selects the cache-file suffix) and the only-test variant.
use_sxc_sampler = True
only_test = False

# Root directory under which the joblib data-block cache lives.
pkl_dir = '/scratch/scai/phd/aiz218323/datasets/processed/'

In [None]:
# Cache file name encodes the sampler type and the only-test flag so the variants never collide.
pkl_file = f'{pkl_dir}/mogicX/wikititles-noise_data-hyper-link_distilbert-base-uncased'
pkl_file += '_sxc' if use_sxc_sampler else '_xcs'
if only_test:
    pkl_file += '_only-test'
pkl_file += '.joblib'

In [None]:
# True as soon as any inference pass or any save request is enabled.
do_inference = (
    do_train_inference
    or do_test_inference
    or save_train_inference
    or save_test_inference
    or save_representation
)

In [None]:
# Display the resolved data-block cache path.
pkl_file

'/scratch/scai/phd/aiz218323/datasets/processed//mogicX/wikititles-noise_data-hyper-link_distilbert-base-uncased_sxc.joblib'

In [None]:
%%time
os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
block = build_block(pkl_file, config_file, use_sxc_sampler, config_key)

CPU times: user 23min 45s, sys: 4min 15s, total: 28min
Wall time: 15min 57s


In [None]:
# Derive the linker dataset over the `hlk_meta` metadata; `remove_empty=True` presumably drops
# points with no linked metadata — TODO confirm against xcai.
_linker_meta = 'hlk_meta'
linker_block = block.linker_dset(_linker_meta, remove_empty=True)

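## Driver

Exported (`#| export`) entry point. It parses command-line arguments, builds the data and linker blocks, configures `XCLearningArguments`, loads the `DBT009` model, and hands an `XCLearner` to `main`.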

In [None]:
#| export
if __name__ == '__main__':
    # NOTE(review): a Jupyter kernel also runs with __name__ == '__main__', so executing this cell
    # interactively calls parse_args() on the kernel's argv — confirm that is intended.

    # --- Experiment constants (same values as the interactive Setup cells) ---
    output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/13_ngame-linker-for-wikititles-noise'
    config_file = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/(mapped)LF-WikiTitles-500K/configs/data_hyper_link_noise-050.json'
    config_key = 'data_hyper_link'
    mname = 'sentence-transformers/msmarco-distilbert-base-v4'

    input_args = parse_args()

    # --- Data-block cache path: the name encodes the sampler type and the only-test flag ---
    pkl_file = f'{input_args.pickle_dir}/mogicX/wikititles-noise_data-hyper-link_distilbert-base-uncased'
    pkl_file += '_sxc' if input_args.use_sxc_sampler else '_xcs'
    if input_args.only_test:
        pkl_file += '_only-test'
    pkl_file += '.joblib'

    # True as soon as any inference pass or any save request is enabled; forwarded to load_model.
    do_inference = (
        input_args.do_train_inference
        or input_args.do_test_inference
        or input_args.save_train_prediction
        or input_args.save_test_prediction
        or input_args.save_representation
    )

    os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
    block = build_block(pkl_file, config_file, input_args.use_sxc_sampler, config_key,
                        do_build=input_args.build_block, only_test=input_args.only_test)
    linker_block = block.linker_dset('hlk_meta', remove_empty=True)

    # --- Training / evaluation arguments ---
    args = XCLearningArguments(
        # output, batching, schedule and optimiser
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        evaluation_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        representation_search_type='BRUTEFORCE',
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,

        # clustering schedule used with group_by_cluster=True
        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        # model selection and target-index keys
        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',

        # hardware / precision
        use_encoder_parallel=True,
        max_grad_norm=None,
        fp16=True,
    )

    def model_fn(mname, bsz):
        """Load a DBT009 model from `mname` with this experiment's loss/sampling hyper-parameters."""
        return DBT009.from_pretrained(mname, bsz=bsz, tn_targ=5000, margin=0.3, tau=0.1, n_negatives=10,
                                      apply_softmax=True, use_encoder_parallel=True)

    def init_fn(model):
        """Initialisation hook handed to `load_model`: set up the model's DR head."""
        model.init_dr_head()

    metric = PrecReclMrr(linker_block.n_lbl, linker_block.test.data_lbl_filterer,
                         prop=linker_block.train.dset.data.data_lbl,
                         pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200], mk=[5, 10, 20])

    # Batch size across all visible GPUs: the larger per-device size times the device count.
    n_devices = torch.cuda.device_count()
    bsz = n_devices * max(args.per_device_train_batch_size, args.per_device_eval_batch_size)

    model = load_model(args.output_dir, model_fn, dict(mname=mname, bsz=bsz), init_fn,
                       do_inference=do_inference, use_pretrained=input_args.use_pretrained)

    learn = XCLearner(
        model=model,
        args=args,
        train_dataset=linker_block.train.dset,
        eval_dataset=linker_block.test.dset,
        data_collator=linker_block.collator,
        compute_metrics=metric,
    )

    main(learn, input_args, n_lbl=linker_block.n_lbl)
    

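## Metrics

Standalone evaluation of a precomputed sparse prediction matrix against test-set ground truth, using `PrecReclMrr`.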

In [None]:
import scipy.sparse as sp
import xclib.data.data_utils as du

from xcai.core import get_output

In [None]:
# NOTE(review): this points at LF-WikiSeeAlsoTitles-320K, not the LF-WikiTitles-500K data used above —
# looks like a leftover from another experiment; confirm the intended dataset.
_benchmarks_root = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks'
data_dir = f'{_benchmarks_root}/(mapped)LF-WikiSeeAlsoTitles-320K'

In [None]:
# Predicted test-set metadata (scipy .npz; the file name suggests Renee predictions) and the
# ground-truth test-set metadata (sparse text file read with xclib).
_pred_path = f'{data_dir}/category_renee_tst_X_Y.npz'
_true_path = f'{data_dir}/category_tst_X_Y.txt'
pred_meta = sp.load_npz(_pred_path)
data_meta = du.read_sparse_file(_true_path)



In [None]:
# Training-set ground-truth metadata; passed as `prop=` to the metric, presumably to derive
# propensities for the PSP/PSN scores — TODO confirm.
_trn_path = f'{data_dir}/category_trn_X_Y.txt'
trn_meta = du.read_sparse_file(_trn_path)

In [None]:
# Reports P@/N@ at rep_pk, R@ at rep_rk and MRR@ at mk (see the stored output below);
# `prop` presumably enables the propensity-scored PSP/PSN variants.
metric = PrecReclMrr(
    pred_meta.shape[1],
    prop=trn_meta,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
    mk=[5, 10, 20],
)

In [None]:
# Package ground truth (first argument) and predictions into the keyword arguments the metric expects.
_metric_inputs = get_output(data_meta, pred_meta)
m = metric(**_metric_inputs)

In [None]:
# End the cell with the metrics dict so Jupyter renders it as the cell output (no print needed).
m

{'P@1': 0.46157789482578937, 'P@10': 0.129043742782284, 'P@3': 0.2856228112179042, 'P@5': 0.2097163619976852, 'N@1': 0.4615779, 'N@10': 0.37416404, 'N@3': 0.37562373, 'N@5': 0.3647661, 'PSP@1': 0.28747141134753296, 'PSP@10': 0.26990502607123174, 'PSP@3': 0.2562446826222358, 'PSP@5': 0.2536693159141603, 'PSN@1': 0.28747144, 'PSN@10': 0.31653425, 'PSN@3': 0.27945155, 'PSN@5': 0.29261345, 'R@200': 0.43547917964633065, 'R@10': 0.3834404785658011, 'R@100': 0.43547917964633065, 'MRR@20': 0.5275643944810756, 'MRR@10': 0.5245715854998169, 'MRR@5': 0.5181164033086407}
