In [1]:
#| default_exp 25_sbert-for-msmarco-inference

In [2]:
%reload_ext autoreload
%autoreload 2

In [3]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [80]:
#| export
import os,torch,json, torch.multiprocessing as mp, joblib, numpy as np, scipy.sparse as sp
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm

from xcai.main import *
from xcai.metrics import mrr

In [39]:
os.environ['WANDB_MODE'] = 'disabled'

In [6]:
#| export
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
os.environ['WANDB_PROJECT'] = 'mogicX_00-msmarco'

## Setup

In [7]:
output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/00_ngame-for-msmarco'
pkl_dir = '/scratch/scai/phd/aiz218323/datasets/processed/'

config_file = '/scratch/scai/phd/aiz218323/datasets/msmarco/XC/configs/entity_gpt_exact.json'
config_key = 'data_entity-gpt_exact'

mname = 'sentence-transformers/msmarco-distilbert-dot-v5'

In [9]:
use_sxc_sampler, only_test = True, False

In [10]:
pkl_file = f'{pkl_dir}/mogicX/msmarco_data_distilbert-base-uncased'
pkl_file = f'{pkl_file}_sxc' if use_sxc_sampler else f'{pkl_file}_xcs'
if only_test: pkl_file = f'{pkl_file}_only-test'
pkl_file = f'{pkl_file}.joblib'

In [73]:
topk = 200
batch_size = 100
device = 'cpu'

In [16]:
%%time
os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
block = build_block(pkl_file, config_file, use_sxc_sampler, config_key, do_build=False, only_test=False)

CPU times: user 12min 38s, sys: 1min 47s, total: 14min 26s
Wall time: 4min 2s


In [26]:
model = SentenceTransformer(mname, device=device)

In [36]:
queries = block.test.dset.data.data_info['input_text'][:1000]
labels = block.test.dset.data.lbl_info['input_text'][:1000]

In [42]:
lbl_embed = [model.encode(labels[idx:idx+batch_size], convert_to_tensor=True, device=device) for idx in tqdm(range(0, len(labels), batch_size))]
lbl_embed = torch.cat(lbl_embed, dim=0)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:52<00:00, 17.24s/it]


In [77]:
scores, idxs = [], []

for idx in tqdm(range(0, len(queries), batch_size)):
    query_embed = model.encode(queries[idx:idx+batch_size], convert_to_tensor=True, device=device)
    sc = util.cos_sim(query_embed, lbl_embed)
    sc, idx = torch.topk(sc, k=topk, largest=True)

    scores.append(sc.to('cpu'))
    idxs.append(idx.to('cpu'))
    break

scores = torch.cat(scores, dim=0)
idxs = torch.cat(idxs, dim=0)
indptr = torch.arange(0, (scores.shape[0]+1) * topk, topk)

  0%|                                                                                                                                                      | 0/10 [00:02<?, ?it/s]


In [79]:
pred_mat = sp.csr_matrix((scores.flatten(), idxs.flatten(), indptr.flatten()))

In [82]:
print(mrr(pred_mat, block.test.dset.data.data_lbl, k=[10]))

## Driver

In [83]:
#| export
if __name__ == '__main__':
    config_file = '/scratch/scai/phd/aiz218323/datasets/msmarco/XC/configs/entity_gpt_exact.json'
    config_key = 'data_entity-gpt_exact'

    mname = 'sentence-transformers/msmarco-distilbert-dot-v5'

    topk = 200
    batch_size = 100
    device = 'cpu'

    input_args = parse_args()

    pkl_file = f'{input_args.pickle_dir}/mogicX/msmarco_data-entity-gpt_distilbert-base-uncased'
    pkl_file = f'{pkl_file}_sxc' if input_args.use_sxc_sampler else f'{pkl_file}_xcs'
    if input_args.only_test: pkl_file = f'{pkl_file}_only-test'
    pkl_file = f'{pkl_file}_exact'
    pkl_file = f'{pkl_file}.joblib'

    os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
    block = build_block(pkl_file, config_file, input_args.use_sxc_sampler, config_key, do_build=input_args.build_block, 
                        only_test=input_args.only_test)

    model = SentenceTransformer(mname, device=device)

    queries = block.test.dset.data.data_info['input_text']
    labels = block.test.dset.data.lbl_info['input_text']

    lbl_embed = [model.encode(labels[idx:idx+batch_size], convert_to_tensor=True, device=device) for idx in tqdm(range(0, len(labels), batch_size))]
    lbl_embed = torch.cat(lbl_embed, dim=0)

    scores, idxs = [], []
    for idx in tqdm(range(0, len(queries), batch_size)):
        query_embed = model.encode(queries[idx:idx+batch_size], convert_to_tensor=True, device=device)
        sc = util.cos_sim(query_embed, lbl_embed)
        sc, idx = torch.topk(sc, k=topk, largest=True)
    
        scores.append(sc.to('cpu'))
        idxs.append(idx.to('cpu'))
    
    scores = torch.cat(scores, dim=0)
    idxs = torch.cat(idxs, dim=0)
    indptr = torch.arange(0, (scores.shape[0]+1) * topk, topk)

    pred_mat = sp.csr_matrix((scores.flatten(), idxs.flatten(), indptr.flatten()))
    print(mrr(pred_mat, block.test.dset.data.data_lbl, k=[10]))
    