In [1]:
#| default_exp 30_ngame-for-msmarco-with-hard-negatives

In [2]:
# Re-import edited library modules (e.g. xcai) automatically before each cell executes.
%reload_ext autoreload
%autoreload 2

In [3]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [4]:
#| export
import os,torch,json, torch.multiprocessing as mp, joblib, numpy as np, scipy.sparse as sp

from xcai.basics import *
from xcai.models.PPP0XX import DBT009,DBT011

In [5]:
os.environ.update(WANDB_MODE='disabled')  # switch Weights & Biases off for interactive runs

In [6]:
#| export
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',         # expose only GPUs 0 and 1 to this process
    'WANDB_PROJECT': 'mogicX_00-msmarco',  # W&B project name used when logging is enabled
})

## Setup

In [10]:
# Where checkpoints/logs are written.
output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/30_ngame-for-msmarco-with-hard-negatives'
# Root for cached (joblib) data blocks. No trailing slash: cells append '/mogicX/...', which
# would otherwise produce a '//' inside every cache path.
pkl_dir = '/scratch/scai/phd/aiz218323/datasets/processed'

# Data config used to build the first (metadata) block.
config_file = '/scratch/scai/phd/aiz218323/datasets/msmarco/XC/configs/entity_gpt_exact.json'
config_key = 'data_entity-gpt_exact'

# Hugging Face checkpoint the model is initialised from.
mname = 'sentence-transformers/msmarco-distilbert-dot-v5'

In [11]:
# Inference / saving switches: all off for this interactive session.
do_train_inference = do_test_inference = False
save_train_inference = save_test_inference = False
save_representation = False

# Sampler flavour (sxc vs xcs) and whether to build only the test split; both feed the cache file name.
use_sxc_sampler = True
only_test = False

In [12]:
# Cache path for the metadata block; the sampler type and the only-test flag are encoded in the name.
pkl_file = (
    f'{pkl_dir}/mogicX/msmarco_data-meta_distilbert-base-uncased'
    + ('_sxc' if use_sxc_sampler else '_xcs')
    + ('_only-test' if only_test else '')
    + '.joblib'
)

In [178]:
# True as soon as any inference or saving switch is on.
do_inference = (
    do_train_inference or do_test_inference
    or save_train_inference or save_test_inference
    or save_representation
)

In [12]:
%%time
os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
# Forward the `only_test` flag so the built block matches the `_only-test` suffix baked into `pkl_file`.
block = build_block(pkl_file, config_file, use_sxc_sampler, config_key, do_build=True, only_test=only_test)

CPU times: user 11min 49s, sys: 1min 15s, total: 13min 5s
Wall time: 5min 21s


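### Hard negatives

Convert the cross-encoder score pickle into a sparse query × passage matrix, reorder its columns so the training label passages come first, and save it together with the passage texts.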

In [7]:
import pickle, scipy.sparse as sp
from tqdm.auto import tqdm

In [8]:
# Directory holding the MS MARCO hard-negative inputs; the files produced in this section are written here too.
# NOTE(review): this uses the /home/.../scratch root while the setup cell uses /scratch/... — confirm both are the same storage.
data_dir = "/home/scai/phd/aiz218323/scratch/datasets/msmarco/negatives"
# Pickled cross-encoder scores; presumably {query_id: {passage_id: score}} (see `load_msmarco_hard_negatives`).
fname = f"{data_dir}/cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl"

In [9]:
def load_msmarco_hard_negatives(fname, query_ids, n_cols=None):
    """Build a sparse query x passage score matrix from a pickled score dictionary.

    Args:
        fname: Path to a pickle holding a mapping ``query_id -> {passage_id: score}``.
            SECURITY: ``pickle.load`` can execute arbitrary code; only load trusted files.
        query_ids: Query ids in the desired row order. Queries absent from the pickle get an
            empty row.
        n_cols: Number of columns (passages). When ``None`` it is inferred as
            ``max(passage_id) + 1`` over the selected queries (0 when none of them has scores),
            so passages whose ids exceed every id seen here are not represented.

    Returns:
        ``scipy.sparse.csr_matrix`` with one row per query id, where entry ``[r, p]`` is the
        score of passage ``p`` for ``query_ids[r]``.

    Raises:
        ValueError: if an explicit ``n_cols`` is smaller than a passage id that occurs.
    """
    with open(fname, 'rb') as file:
        scores = pickle.load(file)

    data, indices, indptr = [], [], [0]
    for qid in tqdm(query_ids):
        if qid in scores:
            row = scores[qid]
            indices.extend(row.keys())
            data.extend(row.values())
        indptr.append(len(data))

    if n_cols is None:
        # Infer the width explicitly: scipy cannot infer it when there are no scores at all.
        n_cols = max(indices) + 1 if indices else 0
    mat = sp.csr_matrix((data, indices, indptr), shape=(len(indptr) - 1, n_cols))
    # Dict order leaves column indices unsorted within each row; canonicalise once.
    mat.sort_indices()
    return mat
    

In [66]:
# Integer query ids in training-row order; they fix the row order of the negatives matrix.
query_ids = list(map(int, block.train.dset.data.data_info['identifier']))

In [69]:
# Rows follow `query_ids`; with no explicit width, the column count is inferred from the largest passage id present.
neg = load_msmarco_hard_negatives(fname, query_ids)

  0%|          | 0/502939 [00:00<?, ?it/s]

In [89]:
# Reorder passage columns: training label passages first (in `lbl_ids` order), then every other passage
# id ("meta" negatives) in ascending order. Afterwards `neg_ids[j]` is the passage id of column j.
# NOTE: this cell overwrites `neg`; re-running it without reloading `neg` would permute the columns twice.
lbl_ids = [int(i) for i in block.train.dset.data.lbl_info['identifier']]

# The loaded matrix only spans up to its largest negative passage id. Widen it so every label id is a
# valid column; otherwise indexing `neg[:, ...]` with a larger label id raises IndexError.
n_passages = max(neg.shape[1], max(lbl_ids, default=-1) + 1)
if neg.shape[1] < n_passages: neg.resize((neg.shape[0], n_passages))

ids = set(lbl_ids)
meta_ids = [i for i in range(n_passages) if i not in ids]

neg_ids = lbl_ids + meta_ids
# A single column fancy-index; equivalent to stacking neg[:, lbl_ids] and neg[:, meta_ids].
# NOTE(review): the label x negative matrix saved later hard-codes 8841823 columns; confirm it equals len(neg_ids).
neg = neg[:, neg_ids]

In [93]:
# Query x passage hard-negative scores; columns are ordered as `neg_ids`.
sp.save_npz(f'{data_dir}/negatives_trn_X_Y.npz', neg)

In [101]:
from sugar.core import *

In [102]:
# Raw passage texts. Dedicated names keep this cell from clobbering `fname` (the negatives pickle path)
# and `lbl_ids` (the training label ids) defined in earlier cells.
raw_lbl_file = '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC/raw_data/label.raw.txt'
raw_lbl_ids, raw_lbl_txt = load_raw_file(raw_lbl_file)
# Identifier -> text, keyed exactly as `load_raw_file` returns the identifiers (looked up via `str(...)`).
lbl_info = dict(zip(raw_lbl_ids, raw_lbl_txt))

In [108]:
# Text for every column of the reordered negatives matrix (lbl_info keys are strings).
neg_txt = list(map(lambda pid: lbl_info[str(pid)], neg_ids))

In [112]:
# Write the id/text pair of every negative column, in the same order as `neg_ids`.
save_raw_file(f'{data_dir}/negatives.raw.txt', neg_ids, neg_txt)

In [118]:
# All-zero label x negative matrix, presumably a placeholder required by the data config.
# NOTE(review): 523598 and 8841823 are hard-coded; presumably the label count and len(neg_ids) — confirm or derive them.
sp.save_npz(f'{data_dir}/negatives_lbl_X_Y.npz', sp.csr_matrix((523598, 8841823), dtype=np.float64))

### Block

Build the data block from the hard-negatives config.

In [13]:
# Data config for the hard-negatives setup (presumably referencing the negatives_* files written above).
config_file = "/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC/configs/negatives_exact.json"
config_key = 'data_negatives_exact'

In [14]:
# Cache path for the hard-negatives block; the sampler type and the only-test flag are encoded in the name.
pkl_file = (
    f'{pkl_dir}/mogicX/msmarco_data-neg_distilbert-base-uncased'
    + ('_sxc' if use_sxc_sampler else '_xcs')
    + ('_only-test' if only_test else '')
    + '.joblib'
)

In [16]:
# Show the cache path used by the next cell.
pkl_file

'/scratch/scai/phd/aiz218323/datasets/processed//mogicX/msmarco_data-neg_distilbert-base-uncased_sxc.joblib'

In [None]:
os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
# do_build=False presumably loads the cached block from `pkl_file`. `only_test` is forwarded so the block
# matches the `_only-test` suffix in `pkl_file`. n_sdata_meta_samples=10 lines up with n_negatives=10 on
# the model; its forward asserts an equal number of negatives per datapoint.
block = build_block(pkl_file, config_file, use_sxc_sampler, config_key, do_build=False, only_test=only_test,
                    use_meta_distribution=True, meta_oversample=True, n_sdata_meta_samples=10)

In [None]:
# NOTE(review): leftover debug marker; safe to delete.
print('HI')

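### Model

`DBT021` extends `DBT009` so that each query is also scored against its own hard negatives (`neg2data_*`), not just against the in-batch labels.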

In [167]:
from typing import Optional
from xcai.losses import PKMMultiTripletFromScores
from xcai.models.PPP0XX import XCModelOutput

In [168]:
class DBT021(DBT009):
    """DBT009 variant that additionally scores each datapoint against its own hard negatives.

    Every datapoint is scored against all label passages in the batch (``lbl2data_*``) and, when
    provided, against a fixed number of its own negative passages (``neg2data_*``). Scores and the
    matching passage indices are concatenated column-wise and passed to ``PKMMultiTripletFromScores``.
    """

    def __init__(
        self,
        config,
        margin:Optional[float]=0.3,
        tau:Optional[float]=0.1,
        apply_softmax:Optional[bool]=True,
        n_negatives:Optional[int]=10,
        **kwargs
    ):
        super().__init__(config, margin=margin, tau=tau, apply_softmax=apply_softmax, n_negatives=n_negatives, **kwargs)
        # `loss_fn` is called with a (datapoints x candidates) score matrix plus the candidates' passage indices.
        self.loss_fn = PKMMultiTripletFromScores(margin=margin, n_negatives=n_negatives, tau=tau, apply_softmax=apply_softmax,
                                                 reduce='mean')

    def _get_scores(self, data_repr:torch.Tensor, lbl2data_repr:torch.Tensor, neg2data_repr:Optional[torch.Tensor]=None):
        """Score every datapoint against all batch labels and, optionally, against its own negatives.

        Args:
            data_repr: ``(bsz, dim)`` datapoint embeddings.
            lbl2data_repr: ``(n_lbl, dim)`` label embeddings shared by the whole batch.
            neg2data_repr: optional ``(bsz * n_neg, dim)`` negative embeddings grouped per datapoint
                (rows ``0..n_neg-1`` belong to datapoint 0, and so on).

        Returns:
            ``(bsz, n_lbl)`` scores, or ``(bsz, n_lbl + n_neg)`` when negatives are given.
        """
        lbl_scores = data_repr @ lbl2data_repr.T
        if neg2data_repr is None:
            return lbl_scores

        bsz = data_repr.shape[0]
        n_neg = neg2data_repr.shape[0] // bsz
        neg_repr = neg2data_repr.reshape(bsz, n_neg, -1)
        # (bsz, 1, dim) @ (bsz, dim, n_neg) -> (bsz, 1, n_neg): each datapoint against its own negatives only.
        neg_scores = (data_repr.unsqueeze(1) @ neg_repr.transpose(1, 2)).squeeze(1)
        return torch.hstack([lbl_scores, neg_scores])

    def _get_indices(self, lbl2data_idx:torch.Tensor, neg2data_idx:Optional[torch.Tensor]=None, bsz:Optional[int]=None):
        """Build the passage-index matrix aligned column-for-column with `_get_scores`.

        Args:
            lbl2data_idx: ``(n_lbl,)`` indices of the batch labels; every datapoint sees all of them.
            neg2data_idx: optional ``(bsz * n_neg,)`` negative indices grouped per datapoint.
            bsz: number of datapoints, i.e. rows of the score matrix. Defaults to ``len(lbl2data_idx)``,
                which is only correct when the batch holds as many labels as datapoints.

        Returns:
            ``(bsz, n_lbl)`` or ``(bsz, n_lbl + n_neg)`` index tensor.
        """
        if bsz is None: bsz = len(lbl2data_idx)
        lbl_idx = torch.repeat_interleave(lbl2data_idx.unsqueeze(0), bsz, 0)
        if neg2data_idx is None:
            return lbl_idx

        n_neg = len(neg2data_idx) // bsz
        return torch.hstack([lbl_idx, neg2data_idx.reshape(bsz, n_neg)])

    def forward(
        self,
        data_input_ids:Optional[torch.Tensor]=None,
        data_attention_mask:Optional[torch.Tensor]=None,
        lbl2data_data2ptr:Optional[torch.Tensor]=None,
        lbl2data_idx:Optional[torch.Tensor]=None,
        lbl2data_input_ids:Optional[torch.Tensor]=None,
        lbl2data_attention_mask:Optional[torch.Tensor]=None,
        plbl2data_data2ptr:Optional[torch.Tensor]=None,
        plbl2data_idx:Optional[torch.Tensor]=None,

        neg2data_data2ptr:Optional[torch.Tensor]=None,
        neg2data_idx:Optional[torch.Tensor]=None,
        neg2data_input_ids:Optional[torch.Tensor]=None,
        neg2data_attention_mask:Optional[torch.Tensor]=None,

        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Encode the batch and, when label inputs are given, compute the hard-negative loss.

        ``data_*`` inputs are always encoded. ``lbl2data_*`` inputs (optionally together with
        ``neg2data_*``) trigger the loss; ``neg2data_data2ptr`` must hold the same count for
        every datapoint.

        Returns:
            ``XCModelOutput(loss, data_repr, lbl2data_repr)``, or the tuple
            ``([loss,] data_repr, lbl2data_repr)`` when ``return_dict`` is falsy.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder = torch.nn.DataParallel(module=self.encoder) if self.use_encoder_parallel else self.encoder

        def _encode(input_ids, attention_mask):
            # Shared encoder call for datapoints, labels and negatives; yields (outputs, representation).
            return encoder(input_ids, attention_mask,
                           output_attentions=output_attentions,
                           output_hidden_states=output_hidden_states,
                           return_dict=return_dict)

        _, data_repr = _encode(data_input_ids, data_attention_mask)

        loss, lbl2data_repr = None, None
        if lbl2data_input_ids is not None:
            _, lbl2data_repr = _encode(lbl2data_input_ids, lbl2data_attention_mask)

            neg2data_repr = None
            if neg2data_input_ids is not None:
                _, neg2data_repr = _encode(neg2data_input_ids, neg2data_attention_mask)
                # The per-datapoint reshapes in `_get_scores` / `_get_indices` need equal counts everywhere.
                assert torch.all(neg2data_data2ptr == neg2data_data2ptr.max()), 'All datapoints should have equal negatives'

            # Only index negatives that were actually scored, so scores and indices stay column-aligned.
            neg_idx = neg2data_idx if neg2data_repr is not None else None
            scores = self._get_scores(data_repr, lbl2data_repr, neg2data_repr)
            idx = self._get_indices(lbl2data_idx, neg_idx, bsz=data_repr.shape[0])
            loss = self.loss_fn(scores, idx, plbl2data_data2ptr, plbl2data_idx)

        if not return_dict:
            o = (data_repr, lbl2data_repr)
            return ((loss,) + o) if loss is not None else o

        return XCModelOutput(
            loss=loss,
            data_repr=data_repr,
            lbl2data_repr=lbl2data_repr,
        )
        

In [169]:
# Interactive model: hard-negative variant without DataParallel (use_encoder_parallel=False).
# n_negatives=10 lines up with n_sdata_meta_samples=10 used when building the block.
model = DBT021.from_pretrained(mname, bsz=1000, tn_targ=5000, margin=0.3, tau=0.1, n_negatives=10, 
                               apply_softmax=True, use_encoder_parallel=False)

Some weights of DBT021 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-dot-v5 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Training

In [105]:
args = XCLearningArguments(
    # Output location and logging.
    output_dir=output_dir,
    logging_first_step=True,

    # Batching and checkpoint cadence (small values, evaluated/saved every 10 steps).
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    eval_strategy="steps",
    eval_steps=10,
    save_strategy="steps",
    save_steps=10,
    save_total_limit=5,
    num_train_epochs=300,

    # Optimisation.
    learning_rate=2e-5,
    weight_decay=0.01,
    adam_epsilon=1e-6,
    warmup_steps=100,
    max_grad_norm=None,
    fp16=True,

    # Representation-based prediction and model selection.
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    representation_num_beams=200,
    representation_accumulation_steps=10,
    metric_for_best_model='P@1',
    load_best_model_at_end=True,
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',

    # Cluster grouping options.
    group_by_cluster=True,
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    clustering_type='EXPO',
    minimum_cluster_size=2,
    maximum_cluster_size=1600,

    use_encoder_parallel=True,
)

In [106]:
metric = PrecReclMrr(
    block.test.dset.n_lbl,
    block.test.data_lbl_filterer,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
    mk=[5, 10, 20],
)

In [170]:
# Trainer wiring the model, the train/test splits, the collator and the metric together.
learn = XCLearner(
    model=model,
    args=args,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [171]:
# Pull a single training batch to smoke-test the forward pass.
dataloader = learn.get_train_dataloader()
batch = next(iter(dataloader))

In [172]:
# Inspect the batch fields; the model's forward consumes the `neg2data_*` ones as hard negatives.
batch.keys()

dict_keys(['data_input_ids', 'data_attention_mask', 'plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'neg2data_idx', 'neg2data_data2ptr', 'neg2data_input_ids', 'neg2data_attention_mask'])

In [79]:
def func(use_debugger=True):
    """Run `model` on the current global `batch`, optionally stopping in pdb first.

    Args:
        use_debugger: when True (the default) drop into pdb before the forward pass.

    Returns:
        The model output, so the helper is also useful without the debugger.
    """
    if use_debugger:
        import pdb; pdb.set_trace()
    return model(**batch)
    

In [173]:
# Move the batch to the model's device and run one forward pass to check the loss end-to-end.
batch = batch.to(model.device)
o = model(**batch)

In [176]:
# Display the model output (loss and representations).
o

XCModelOutput(loss=tensor(0.0275, device='cuda:0', grad_fn=<MeanBackward0>), logits=None, lm_loss=None, dr_loss=None, data_repr=tensor([[ 0.0210, -0.0089, -0.0497,  ...,  0.0260,  0.0203,  0.0314],
        [ 0.0080, -0.0119, -0.0603,  ...,  0.0212,  0.0167,  0.0730],
        [ 0.0382,  0.0116, -0.0456,  ...,  0.0267, -0.0068,  0.0473],
        ...,
        [ 0.0305, -0.0044, -0.0332,  ...,  0.0222,  0.0502,  0.0704],
        [ 0.0059,  0.0152, -0.0177,  ...,  0.0201, -0.0265,  0.0305],
        [ 0.0344, -0.0250, -0.0396,  ...,  0.0449,  0.0415,  0.0571]],
       device='cuda:0', grad_fn=<DivBackward0>), lbl2data_repr=tensor([[ 0.0232, -0.0060, -0.0577,  ...,  0.0383,  0.0307,  0.0272],
        [ 0.0512, -0.0025, -0.0409,  ...,  0.0580,  0.0047,  0.0745],
        [ 0.0427,  0.0073, -0.0446,  ...,  0.0403,  0.0536,  0.0543],
        ...,
        [ 0.0419, -0.0041, -0.0188,  ...,  0.0378,  0.0434,  0.0771],
        [-0.0034,  0.0121,  0.0002,  ...,  0.0340,  0.0228,  0.0703],
        [ 0.

## Driver

In [None]:
#| export
if __name__ == '__main__':
    output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/30_ngame-for-msmarco-with-hard-negatives'

    # Same hard-negatives config as the notebook's block cell (single '.json' extension).
    config_file = '/scratch/scai/phd/aiz218323/datasets/msmarco/XC/configs/negatives_exact.json'
    config_key = 'data_negatives_exact'

    # NOTE(review): the notebook experiments use 'msmarco-distilbert-dot-v5'; confirm 'cos-v5' is intended here.
    mname = 'sentence-transformers/msmarco-distilbert-cos-v5'

    input_args = parse_args()

    # Cache path encodes the sampler type and the only-test flag.
    pkl_file = f'{input_args.pickle_dir}/mogicX/msmarco_data-neg_distilbert-base-uncased'
    pkl_file = f'{pkl_file}_sxc' if input_args.use_sxc_sampler else f'{pkl_file}_xcs'
    if input_args.only_test: pkl_file = f'{pkl_file}_only-test'
    pkl_file = f'{pkl_file}.joblib'

    do_inference = (input_args.do_train_inference or input_args.do_test_inference or input_args.save_train_prediction
                    or input_args.save_test_prediction or input_args.save_representation)

    os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
    block = build_block(pkl_file, config_file, input_args.use_sxc_sampler, config_key, do_build=input_args.build_block,
                        only_test=input_args.only_test, use_meta_distribution=True, meta_oversample=True, n_sdata_meta_samples=10)

    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        # `eval_strategy` (not the deprecated `evaluation_strategy`), matching the notebook's arguments.
        eval_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        representation_search_type='BRUTEFORCE',
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,

        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',

        use_encoder_parallel=True,
        max_grad_norm=None,
        fp16=True,
    )

    # NOTE(review): DBT021 (hard-negative scoring) is not exported to this script, so the driver builds DBT009;
    # confirm whether DBT009 actually uses the `neg2data_*` batch fields produced by this block.
    def model_fn(mname, bsz):
        model = DBT009.from_pretrained(mname, bsz=bsz, tn_targ=5000, margin=0.3, tau=0.1, n_negatives=10,
                                       apply_softmax=True, use_encoder_parallel=True)
        return model

    def init_fn(model):
        model.init_dr_head()

    metric = PrecReclMrr(block.test.dset.n_lbl, block.test.data_lbl_filterer,
                         pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200], mk=[5, 10, 20])

    # Global batch size across all visible GPUs.
    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()

    model = load_model(args.output_dir, model_fn, {"mname": mname, "bsz": bsz}, init_fn, do_inference=do_inference, use_pretrained=input_args.use_pretrained)

    learn = XCLearner(
        model=model,
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )

    main(learn, input_args, n_lbl=block.test.dset.n_lbl)
    