In [1]:
#| default_exp 07_sandwich-for-wikiseealsotitles

In [2]:
#| hide
from nbdev.showdoc import *
# Hidden helper cell: runs nbdev's export step each time this cell executes
# (per the nbdev docs, `nbdev_export` writes the `#| export` cells out as Python modules).
import nbdev; nbdev.nbdev_export()

In [3]:
%reload_ext autoreload
%autoreload 2

In [5]:
#| export
import os
# Expose only GPUs 0 and 1 to this process. Deliberately assigned before the
# `torch` import below so the restriction is in place before CUDA is initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

import torch,json, torch.multiprocessing as mp, joblib, numpy as np, scipy.sparse as sp

from xcai.basics import *
from xcai.models.sandwich import SAW001

In [6]:
# Notebook-only setting (this cell carries no `#| export` directive): turn
# Weights & Biases logging off while working interactively.
os.environ['WANDB_MODE'] = 'disabled'

In [7]:
#| export
# Weights & Biases project name used for runs launched from the exported script.
os.environ['WANDB_PROJECT'] = 'sandwich_00-wikiseealsotitles'

## Driver

In [None]:
#| export
if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Fixed experiment settings
    # ------------------------------------------------------------------
    # NOTE(review): this directory is named after the `06_ngame-...` experiment while the
    # notebook exports `07_sandwich-for-wikiseealsotitles`; confirm that checkpoints are
    # really meant to be read from / written to this location.
    output_dir = '/scratch/scai/phd/aiz218323/outputs/cachew/06_ngame-for-wikiseealsotitles-with-augmented-training'
    data_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/'
    config_file, config_key = 'wikiseealsotitles', 'data_meta'
    mname = 'sentence-transformers/msmarco-distilbert-base-v4'

    input_args = parse_args()

    # ------------------------------------------------------------------
    # Cached data block
    # Path layout: <pickle_dir>/sandwich/<stem>_<sxc|xcs>[_only-test].joblib
    # ------------------------------------------------------------------
    pkl_stem = f'{input_args.pickle_dir}/sandwich/wikiseealsotitles_data-meta_distilbert-base-uncased'
    sampler_tag = 'sxc' if input_args.use_sxc_sampler else 'xcs'
    split_tag = '_only-test' if input_args.only_test else ''
    pkl_file = f'{pkl_stem}_{sampler_tag}{split_tag}.joblib'

    # Any of the inference / dump flags puts the run in inference mode.
    do_inference = (
        input_args.do_train_inference
        or input_args.do_test_inference
        or input_args.save_train_prediction
        or input_args.save_test_prediction
        or input_args.save_representation
    )

    # The directory that will hold the pickle must exist before the block is built.
    pkl_dir = os.path.dirname(pkl_file)
    os.makedirs(pkl_dir, exist_ok=True)
    block = build_block(
        pkl_file,
        config_file,
        input_args.use_sxc_sampler,
        config_key,
        do_build=input_args.build_block,
        only_test=input_args.only_test,
        data_dir=data_dir,
    )

    # ------------------------------------------------------------------
    # Learner arguments, grouped by concern and merged in a single call
    # ------------------------------------------------------------------
    loop_kwargs = dict(
        logging_first_step=True,
        per_device_train_batch_size=1024,
        per_device_eval_batch_size=1024,
        save_strategy="steps",
        evaluation_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        metric_for_best_model='P@1',
        load_best_model_at_end=True,
    )
    optim_kwargs = dict(
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,
        max_grad_norm=None,
        fp16=True,
    )
    representation_kwargs = dict(
        representation_num_beams=200,
        representation_accumulation_steps=10,
        predict_with_representation=True,
        representation_search_type='BRUTEFORCE',
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',
    )
    clustering_kwargs = dict(
        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,
    )
    args = XCLearningArguments(
        output_dir=output_dir,
        use_encoder_parallel=True,
        **loop_kwargs,
        **optim_kwargs,
        **representation_kwargs,
        **clustering_kwargs,
    )

    # Batch size handed to the model config: the larger per-device size times the
    # number of CUDA devices visible to this process.
    per_device_bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)
    bsz = per_device_bsz * torch.cuda.device_count()

    # ------------------------------------------------------------------
    # Model configuration
    # ------------------------------------------------------------------
    augmentation_kwargs = dict(
        data_aug_meta_prefix='cat2data',
        lbl2data_aug_meta_prefix='cat2lbl',
        data_enrich=True,
        lbl2data_enrich=True,
    )
    main_loss_kwargs = dict(
        batch_size=bsz,
        num_batch_labels=5000,
        margin=0.3,
        num_negatives=5,
        tau=0.1,
        apply_softmax=True,
    )
    calib_loss_kwargs = dict(
        use_calib_loss=True,
        calib_loss_weight=0.1,
        calib_margin=0.05,
        calib_num_negatives=10,
        calib_tau=0.1,
        calib_apply_softmax=False,
    )
    extra_loss_kwargs = dict(
        use_query_loss=True,
        use_meta_loss=False,
        meta_loss_weight=0.1,
    )
    # NOTE(review): `use_encoder_parallel` is True in the learner arguments above but
    # False here; confirm the two flags are meant to differ.
    config = SandwichConfig(
        **augmentation_kwargs,
        **main_loss_kwargs,
        **calib_loss_kwargs,
        **extra_loss_kwargs,
        use_encoder_parallel=False,
    )

    def model_fn(mname):
        """Build a `SAW001` from the pretrained checkpoint `mname` using the `config` defined above."""
        return SAW001.from_pretrained(mname, config=config)

    def init_fn(model):
        """Call the model's three initialisation hooks, in this fixed order."""
        model.init_heads_to_identity()
        model.init_meta_distilbert()
        model.init_combiner_to_last_layer()

    # ------------------------------------------------------------------
    # Metric, model and learner
    # ------------------------------------------------------------------
    metric = PrecReclMrr(
        block.n_lbl,
        block.test.data_lbl_filterer,
        prop=block.train.dset.data.data_lbl,
        pk=10,
        rk=200,
        rep_pk=[1, 3, 5, 10],
        rep_rk=[10, 100, 200],
        mk=[5, 10, 20],
    )

    model = load_model(
        args.output_dir,
        model_fn,
        {"mname": mname},
        init_fn,
        do_inference=do_inference,
        use_pretrained=input_args.use_pretrained,
    )

    learn = XCLearner(
        model=model,
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )

    # Hand control to the shared entry point with the parsed command-line flags.
    main(learn, input_args, n_lbl=block.n_lbl)
    
    