In [1]:
#| default_exp 37_training-msmarco-distilbert-from-scratch

In [2]:
%reload_ext autoreload
%autoreload 2

In [2]:
#| export
import os,torch,json, torch.multiprocessing as mp, joblib, numpy as np, scipy.sparse as sp

from transformers import DistilBertConfig

from xcai.basics import *
from xcai.models.PPP0XX import DBT023

In [3]:
# Interactive-only setting: this cell has no `#| export`, so it is not part of
# the exported module. Sets Weights & Biases to its 'disabled' mode.
os.environ['WANDB_MODE'] = 'disabled'

In [4]:
#| export
# Expose only GPUs 0 and 1 to CUDA for this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# Weights & Biases project name used for runs launched from this script.
os.environ['WANDB_PROJECT'] = 'mogicX_00-msmarco'

In [45]:
import pickle, scipy.sparse as sp, numpy as np
from tqdm.auto import tqdm
from typing import Optional, List

from sugar.core import *

In [19]:
# Directory holding the MS MARCO hard-negative artifacts.
# NOTE(review): absolute, user-specific path — consider making it configurable.
data_dir = "/home/scai/phd/aiz218323/scratch/datasets/msmarco/negatives"
# Pickle of cross-encoder scores that the cells below read.
fname = f"{data_dir}/cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl"

In [20]:
# Load the raw pickled scores for inspection.
# NOTE(review): `hard_negatives` is not referenced by any later cell shown here;
# `load_msmarco_hard_negatives` re-reads the same file on its own.
# pickle.load can execute arbitrary code — only open files from a trusted source.
with open(fname, 'rb') as file:
    hard_negatives = pickle.load(file)

In [22]:
# Integer identifiers of the train / test data points, in dataset order.
# NOTE(review): `block` is only created in the Driver section further down, so
# this cell relies on it already being present in the kernel.
trn_ids = list(map(int, block.train.dset.data.data_info['identifier']))
tst_ids = list(map(int, block.test.dset.data.data_info['identifier']))

In [42]:
def load_msmarco_hard_negatives(fname:str, data_ids:Optional[List]=None):
    """Load pickled hard-negative scores and pack them into a CSR matrix.

    The pickle at `fname` must contain a mapping of the form
    `{data_id: {label_id: score}}`: it is looked up by data id, and each entry
    is iterated as label-id/score pairs.

    Args:
        fname: Path to the pickle file.
        data_ids: Ids that define the row order of the returned matrix. An id
            that is absent from the pickle yields an empty row. When `None`,
            every id in the pickle is used, in the mapping's iteration order.

    Returns:
        A tuple `(data_ids, lbl_ids, matrix)` where `lbl_ids[j]` is the label
        id of column `j` (columns are numbered in order of first appearance)
        and `matrix` is a float32 `scipy.sparse.csr_matrix` of shape
        `(number of data ids, len(lbl_ids))` holding the scores.
    """
    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(fname, 'rb') as file:
        negatives = pickle.load(file)

    data_ids = list(negatives) if data_ids is None else data_ids

    lbl_id2idx = dict()
    data, indices, indptr = [], [], [0]
    for idx in tqdm(data_ids):
        if idx in negatives:
            for lbl_id, score in negatives[idx].items():
                # First time a label is seen it receives the next free column.
                indices.append(lbl_id2idx.setdefault(lbl_id, len(lbl_id2idx)))
                data.append(score)
        # Close the row even when it is empty so rows stay aligned with data_ids.
        indptr.append(len(data))

    # Dicts keep insertion order, which is exactly the column order assigned above.
    lbl_ids = list(lbl_id2idx)

    # Pass the shape explicitly: without it scipy infers the column count from
    # `indices.max()`, which raises ValueError when there are no entries at all
    # (empty `data_ids`, or none of the requested ids has negatives).
    shape = (len(indptr) - 1, len(lbl_id2idx))
    matrix = sp.csr_matrix(
        (
            np.asarray(data, dtype=np.float32),
            np.asarray(indices, dtype=np.int64),
            np.asarray(indptr, dtype=np.int64),
        ),
        shape=shape,
        dtype=np.float32,
    )
    return data_ids, lbl_ids, matrix
    

In [44]:
# Build the train-side negatives matrix: rows follow `trn_ids`, columns follow
# `neg_ids` (the label ids returned by the loader).
data_ids, neg_ids, data_neg = load_msmarco_hard_negatives(fname, trn_ids)

  0%|          | 0/502939 [00:00<?, ?it/s]

In [47]:
# All-zero sparse matrix with `block.n_lbl` rows and the same number of columns
# as `data_neg` (no entries are stored).
# NOTE(review): depends on `block`, which is only built in the Driver section.
lbl_neg = sp.csr_matrix((block.n_lbl, data_neg.shape[1]), dtype=np.float32)

In [54]:
# Persist both sparse matrices next to the source pickle in scipy's .npz format.
sp.save_npz(f'{data_dir}/negatives_trn_X_Y.npz', data_neg)
sp.save_npz(f'{data_dir}/negatives_lbl_X_Y_exact.npz', lbl_neg)

In [55]:
# Read the raw label file and index its texts by label id.
# NOTE(review): this rebinds `fname`, which earlier pointed at the score pickle;
# re-running the loader cell after this one would pass it the label text file.
fname = '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC/raw_data/label.raw.txt'
lbl_ids, lbl_txt = load_raw_file(fname)
lbl_map = dict(zip(lbl_ids, lbl_txt))

In [57]:
# Text of every negative label, in column order of `data_neg`. Ids are
# stringified before the lookup; an id missing from `lbl_map` raises KeyError.
neg_txt = [lbl_map[str(i)] for i in neg_ids]

In [64]:
# Write the negative label ids and their texts to a raw file.
# NOTE(review): presumably `{data_dir}/raw_data/` must already exist — confirm
# whether `save_raw_file` creates missing directories.
save_raw_file(f'{data_dir}/raw_data/negatives.raw.txt', neg_ids, neg_txt)

## Driver

In [None]:
#| export
if __name__ == '__main__':
    # Script entry point: read CLI options, build the data block, configure the
    # learner and hand everything to `main`.
    input_args = parse_args()
    output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/37_training-msmarco-distilbert-from-scratch'

    # Only the "exact" label configuration is available so far.
    if not input_args.exact:
        raise NotImplementedError('Create a configuration file for using all the labels.')
    config_file = '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC/configs/data_ce_exact.json'
    config_key = 'data'

    mname = 'distilbert-base-uncased'

    pkl_file = get_pkl_file(
        input_args.pickle_dir,
        'msmarco_data-ce_distilbert-base-uncased',
        input_args.use_sxc_sampler,
        input_args.exact,
        input_args.only_test,
    )

    # Truthy as soon as any inference / prediction-saving option is requested.
    do_inference = (
        input_args.do_train_inference
        or input_args.do_test_inference
        or input_args.save_train_prediction
        or input_args.save_test_prediction
        or input_args.save_representation
    )

    # The pickle's parent directory must exist before the block is built.
    os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
    block = build_block(
        pkl_file,
        config_file,
        input_args.use_sxc_sampler,
        config_key,
        do_build=input_args.build_block,
        only_test=input_args.only_test,
        main_oversample=True,
        meta_oversample=True,
        return_scores=True,
        n_slbl_samples=1,
        n_sdata_meta_samples=10,
    )

    args = XCLearningArguments(
        # Output and logging
        output_dir=output_dir,
        logging_first_step=True,

        # Batch sizes
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,

        # Representation-based prediction
        representation_num_beams=200,
        representation_accumulation_steps=10,
        predict_with_representation=True,
        representation_search_type='BRUTEFORCE',

        # Checkpointing and evaluation cadence
        save_strategy="steps",
        eval_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,

        # Optimiser
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-5,

        # Cluster-based batching
        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        # Model selection and target keys
        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',

        # Parallelism and numerics
        use_encoder_parallel=True,
        max_grad_norm=None,
        fp16=True,
    )

    def model_fn(mname):
        # Factory handed to `load_model`; builds the DBT023 model from `mname`.
        return DBT023.from_pretrained(mname, normalize=True, use_encoder_parallel=True)

    def init_fn(model):
        # Initialisation hook handed to `load_model`.
        model.init_dr_head()

    metric = PrecReclMrr(
        block.test.dset.n_lbl,
        block.test.data_lbl_filterer,
        pk=10,
        rk=200,
        rep_pk=[1, 3, 5, 10],
        rep_rk=[10, 100, 200],
        mk=[5, 10, 20],
    )

    # Larger per-device batch size times the number of visible GPUs.
    # NOTE(review): `bsz` is not referenced again within this block.
    larger_batch = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)
    bsz = larger_batch*torch.cuda.device_count()

    model = load_model(
        args.output_dir,
        model_fn,
        {"mname": mname},
        init_fn,
        do_inference=do_inference,
        use_pretrained=input_args.use_pretrained,
    )

    learn = XCLearner(
        model=model,
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )

    main(learn, input_args, n_lbl=block.test.dset.n_lbl)
    