In [1]:
#| default_exp 39_oak-for-msmarco-with-hard-negatives

In [2]:
# Enable IPython's autoreload so edits to imported modules (e.g. xcai) are picked up
# before each cell runs, without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [3]:
# Regenerate the library modules from the `#| export` cells (nbdev); this cell itself is not exported.
from nbdev.showdoc import *
import nbdev

nbdev.nbdev_export()

In [4]:
#| export
import os,torch,json, torch.multiprocessing as mp, joblib, numpy as np, scipy.sparse as sp
from transformers import DistilBertConfig

from xcai.main import *
from xcai.basics import *
from xcai.clustering.cluster import get_cluster_size

from xcai.models.oak import OAK015

In [5]:
# Interactive runs only (not exported): keep Weights & Biases from logging.
os.environ.update(WANDB_MODE='disabled')

In [6]:
#| export
# Exported with the script: pin the visible GPUs and the Weights & Biases project name.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',
    'WANDB_PROJECT': 'mogicX_00-msmarco',
})

## Setup

In [7]:
from xcai.config import WIKISEEALSOTITLES

In [7]:
# Local paths for the interactive checks in this notebook
# (the exported driver defines its own cluster paths).
data_dir = '/Users/suchith720/Projects/data/'
config_dir = "/Users/suchith720/Projects/mogicX/configs"
config_key = "data_meta"

In [None]:
# Build the WikiSeeAlsoTitles config, attach the category files as the 'lnk' metadata of the
# train split, and write it to a JSON config consumed by `build_block` below.
config = WIKISEEALSOTITLES(data_dir)[config_key]

# Derive the metadata paths from `data_dir` instead of hardcoding a copy of it, so changing
# `data_dir` keeps them in sync (the resulting strings are unchanged for the current value).
lnk_data_dir = f"{data_dir}/(mapped)LF-WikiSeeAlsoTitles-320K"
lnk_meta = {
    'prefix': 'lnk',
    'data_meta': f'{lnk_data_dir}/category_renee_trn_X_Y.npz',
    'lbl_meta': f'{lnk_data_dir}/category_renee_lbl_X_Y.npz',
    'meta_info': f'{lnk_data_dir}/raw_data/category.raw.txt',
}
# Only the train split receives the lnk metadata paths here.
config["path"]["train"]["lnk_meta"] = lnk_meta

config_file = f"{config_dir}/39_oak-for-msmarco-with-hard-negatives_test.json"
os.makedirs(config_dir, exist_ok=True)
with open(config_file, 'w') as file:
    # Store under `config_key` so the file is read back with the same key passed to build_block.
    json.dump({config_key: config}, file, indent=4)

In [8]:
# JSON config for the interactive test block.
config_file = "{}/39_oak-for-msmarco-with-hard-negatives_test.json".format(config_dir)

In [9]:
# Cache location for the pickled data block.
pkl_dir = f"{data_dir}/processed/mogicX"
pkl_file = get_pkl_file(
    pkl_dir,
    'wikiseealsotitles_data-oak-for-msmarco-with-hard-negatives-test_distilbert-base-uncased',
    True,   # same position as `use_sxc_sampler` in the driver's get_pkl_file call
    False,  # same position as `exact`
    False,  # same position as `only_test`
)

In [62]:
%%time
block = build_block(pkl_file, config_file, True, config_key=config_key, only_test=False, main_oversample=True, 
                    meta_oversample={"cat_meta":False, "lnk_meta":True}, n_slbl_samples=5, 
                    n_sdata_meta_samples={"cat_meta":2, "lnk_meta":4}, do_build=False, 
                    train_meta_topk={"lnk_meta":10}, test_meta_topk={"lnk_meta":10}, return_scores=True)


CPU times: user 1min 6s, sys: 11.1 s, total: 1min 17s
Wall time: 1min 25s


In [63]:
# Fetch two training examples through the dataset's batched accessor to inspect the returned fields.
sample_idxs = [10, 20]
batch = block.train.dset.__getitems__(sample_idxs)

In [64]:
# Input texts of the two sampled data points.
batch['data_input_text']

['Austroasiatic languages', 'Albania']

In [65]:
# Per-example label counts: (5, 5) partitions the 10 entries of `lbl2data_input_text`
# (presumably n_slbl_samples=5 per example — confirm).
batch['lbl2data_data2ptr']

tensor([5, 5])

In [67]:
# Label texts of both examples, concatenated in example order.
# NOTE(review): the repeated entries presumably come from main_oversample=True filling each
# example up to n_slbl_samples=5 — confirm.
batch['lbl2data_input_text']

['Austroasiatic languages',
 'Austric languages',
 'Austric languages',
 'Austric languages',
 'Austric languages',
 'Index of Albania-related articles',
 'Albania',
 'Index of Albania-related articles',
 'Albania',
 'Index of Albania-related articles']

In [55]:
# Mean number of stored entries per row of each data_meta matrix
# (presumably one row per training data point).
# NOTE(review): the saved output is from In[55], before build_block ran at In[62]; re-run to refresh.
tuple(block.train.dset.meta[name].data_meta.getnnz(axis=1).mean() for name in ('cat_meta', 'lnk_meta'))

(np.float64(4.8924975688302395), np.float64(2.0))

In [43]:
# `meta_oversample` flag of each metadata dataset.
# NOTE(review): the saved output (False, False) is from In[43], before build_block at In[62] passed
# lnk_meta=True, so it is likely stale.
tuple(block.train.dset.meta[name].meta_oversample for name in ('cat_meta', 'lnk_meta'))

(False, False)

In [44]:
# `n_sdata_meta_samples` of each metadata dataset.
# NOTE(review): the saved output (10, 20) is from In[44], before build_block at In[62] passed
# {"cat_meta": 2, "lnk_meta": 4}, so it is likely stale.
tuple(block.train.dset.meta[name].n_sdata_meta_samples for name in ('cat_meta', 'lnk_meta'))

(10, 20)

In [45]:
# Category-metadata pointer of `batch` (presumably per-example counts, like lbl2data_data2ptr).
# NOTE(review): the saved output is from In[45], before `batch` was rebuilt at In[63]; re-run to refresh.
batch['cat2data_data2ptr']

tensor([ 4, 10])

In [46]:
# Link-metadata pointer of `batch` (presumably per-example counts, like lbl2data_data2ptr).
# NOTE(review): the saved output (10, 10) is from In[46], before the block/batch were rebuilt at
# In[62]/In[63] with n_sdata_meta_samples lnk_meta=4; re-run to refresh.
batch['lnk2data_data2ptr']

tensor([10, 10])

## Metadata clusters

In [12]:
# Map the category metadata to clusters via get_cluster_mapping; the results are only inspected here.
cat_meta_info = block.train.dset.meta['cat_meta'].meta_info
metadata_idx2cluster, meta_repr, num_meta_cluster = get_cluster_mapping(
    cluster_sz=3,
    mname='distilbert-base-uncased',
    meta_info=cat_meta_info,
    collator=block.collator,
    normalize=True,
    use_layer_norm=True,
)


## Driver

In [None]:
#| export
# Training / inference driver for the exported script: builds the MS MARCO block, the OAK015
# model and the XCLearner, then hands control to `main` with the parsed CLI arguments.
if __name__ == '__main__':
    output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/39_oak-for-msmarco-with-hard-negatives'

    input_args = parse_args()

    # Only the exact-label configuration exists for this experiment.
    if not input_args.exact:
        raise NotImplementedError('Create a configuration file for using all the labels.')
    config_file = '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC/configs/oak-for-msmarco-with-hard-negatives.json'
    config_key = 'data'

    mname, meta_name = 'distilbert-base-uncased', 'lnk'
    # Optional file of pretrained metadata embeddings loaded with np.load in init_fn;
    # None keeps the embeddings produced by init_meta_embeddings().
    meta_embed_init_file = None

    pkl_file = get_pkl_file(input_args.pickle_dir, 'msmarco_data-oak-for-msmarco-with-hard-negatives_distilbert-base-uncased',
                            input_args.use_sxc_sampler, input_args.exact, input_args.only_test)

    # Truthy when any inference / prediction-saving flag is set; forwarded to load_model.
    do_inference = (input_args.do_train_inference or input_args.do_test_inference
                    or input_args.save_train_prediction or input_args.save_test_prediction
                    or input_args.save_representation)

    os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
    # `config_key` is passed by keyword, matching how build_block is called in the notebook cells.
    block = build_block(pkl_file, config_file, input_args.use_sxc_sampler, config_key=config_key,
                        do_build=input_args.build_block, only_test=input_args.only_test, main_oversample=False,
                        meta_oversample={'lnk_meta':False, 'neg_meta':True},
                        n_slbl_samples=1, n_sdata_meta_samples={'lnk_meta':5, 'neg_meta':10},
                        train_meta_topk={"lnk_meta":5}, test_meta_topk={"lnk_meta":5})

    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        eval_strategy="steps",
        eval_steps=500,
        save_steps=500,
        save_total_limit=5,
        num_train_epochs=30,
        predict_with_representation=True,
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=6e-6,
        representation_search_type='BRUTEFORCE',

        output_representation_attribute='data_fused_repr',
        label_representation_attribute='data_repr',
        metadata_representation_attribute='data_repr',
        data_augmentation_attribute='data_repr',
        representation_attribute='data_fused_repr',
        clustering_representation_attribute='data_fused_repr',

        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        use_data_metadata_for_clustering=True,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',

        use_distributional_representation=False,
        use_encoder_parallel=True,
        max_grad_norm=None,
        fp16=True,

        label_names=[f'{meta_name}2data_idx', f'{meta_name}2data_data2ptr'],

        prune_metadata=False,
        num_metadata_prune_warmup_epochs=10,
        num_metadata_prune_epochs=5,
        metadata_prune_batch_size=2048,
        prune_metadata_names=[f'{meta_name}_meta'],
        use_data_metadata_for_pruning=True,

        predict_with_augmentation=False,
        use_augmentation_index_representation=True,

        data_aug_meta_name=meta_name,
        augmentation_num_beams=None,
        data_aug_prefix=meta_name,
        use_label_metadata=False,

        data_meta_batch_size=2048,
        augment_metadata=False,
        num_metadata_augment_warmup_epochs=10,
        num_metadata_augment_epochs=5,

        use_cpu_for_searching=True,
        use_cpu_for_clustering=True,
    )

    def model_fn(mname, bsz=None):
        """Build an OAK015 model from the `mname` checkpoint.

        `bsz` is accepted but unused: the kwargs dict handed to `load_model` below carries a
        "bsz" entry next to "mname". NOTE(review): confirm how load_model forwards these kwargs.
        """
        model = OAK015.from_pretrained(mname, margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,

                                       data_aug_meta_prefix=f'{meta_name}2data', lbl2data_aug_meta_prefix=None,
                                       neg2data_aug_meta_prefix=None,

                                       num_metadata=block.train.dset.meta[f'{meta_name}_meta'].n_meta, resize_length=5000,

                                       calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False,
                                       calib_loss_weight=0.1, use_calib_loss=True,

                                       use_query_loss=True,

                                       use_encoder_parallel=True, normalize=True)
        return model

    def init_fn(model):
        """Initialise the retrieval head and metadata embeddings of a newly built model.

        When `meta_embed_init_file` is set, the metadata embeddings are replaced by the array
        loaded from that file (as float32) and then frozen; otherwise they keep the values from
        `init_meta_embeddings()`.
        """
        model.init_retrieval_head()
        # model.init_cross_head()
        model.init_meta_embeddings()

        # np.load(None) raises TypeError, and meta_embed_init_file is None in this configuration.
        if meta_embed_init_file is not None:
            meta_embeddings = torch.tensor(np.load(meta_embed_init_file), dtype=torch.float32)
            model.encoder.set_pretrained_meta_embeddings(meta_embeddings)
            model.encoder.freeze_pretrained_meta_embeddings()

    # Batch size across all visible GPUs; count at least one device so bsz is never 0 without CUDA.
    n_devices = max(torch.cuda.device_count(), 1)
    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*n_devices
    model = load_model(args.output_dir, model_fn, {"mname": mname, "bsz": bsz}, init_fn, do_inference=do_inference, use_pretrained=input_args.use_pretrained)

    metric = PrecReclMrr(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                         pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200], mk=[5, 10, 20])

    learn = XCLearner(
        model=model,
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )

    main(learn, input_args, n_lbl=block.n_lbl)
    