In [1]:
#| default_exp 20_msmarco-hard-negatives

In [2]:
%reload_ext autoreload
%autoreload 2

In [5]:
#| export
import pickle, scipy.sparse as sp
from tqdm.auto import tqdm

from sugar.core import *

## Setup

In [4]:
#| export
def load_msmarco_hard_negatives(fname, query_ids, n_cols=None):
    """Build a sparse query-by-passage matrix from a pickled score dictionary.

    The pickle at `fname` must contain a mapping of the form
    `{query_id: {passage_id: value}}`, where every `passage_id` is a
    non-negative integer (it is used directly as a column index).

    Args:
        fname: Path of the pickle file to read.
        query_ids: Iterable of query ids. Row `r` of the result belongs to the
            r-th id yielded; ids missing from the pickle give an empty row.
        n_cols: Optional number of columns of the result. When `None`
            (the default) the width is inferred as `largest passage_id + 1`,
            or 0 if no value was collected at all.

    Returns:
        A `scipy.sparse.csr_matrix` of shape `(number of query ids, n_cols)`.

    Raises:
        ValueError: If `n_cols` is given but is too small to hold the largest
            passage id that was collected.
    """
    # SECURITY: pickle.load can execute arbitrary code - only open trusted files.
    with open(fname, 'rb') as file:
        scores = pickle.load(file)

    data, indices, indptr = [], [], [0]
    for qid in tqdm(query_ids):
        if qid in scores:
            row = scores[qid]
            # keys() and values() of an unmodified dict iterate in matching order.
            indices.extend(row.keys())
            data.extend(row.values())
        indptr.append(len(data))

    n_rows = len(indptr) - 1
    min_cols = max(indices) + 1 if indices else 0
    if n_cols is None:
        n_cols = min_cols
    elif n_cols < min_cols:
        # The CSR constructor does not bounds-check column indices on its own.
        raise ValueError(
            f"n_cols={n_cols} is too small: stored passage ids need at least {min_cols} columns"
        )

    if not indices:
        # Without an explicit shape SciPy cannot infer the width of an empty
        # matrix and raises; return a well-formed empty matrix instead.
        return sp.csr_matrix((n_rows, n_cols), dtype=float)

    return sp.csr_matrix((data, indices, indptr), shape=(n_rows, n_cols))
    

In [6]:
# Directory holding the hard-negative files (absolute, machine-specific path).
data_dir = "/home/scai/phd/aiz218323/scratch/datasets/msmarco/negatives"
# Pickled score file that a later cell passes to `load_msmarco_hard_negatives`.
fname = f"{data_dir}/cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl"

In [None]:
# NOTE(review): `block` is not defined in any cell visible here; it must already
# exist in the kernel (possibly via `from sugar.core import *`) - confirm before Run All.
query_ids = [int(i) for i in block.train.dset.data.data_info['identifier']]
# Sparse matrix with one row per entry of `query_ids`, in the same order.
neg = load_msmarco_hard_negatives(fname, query_ids)

In [None]:
# Reorder the columns of `neg`: the ids listed in `lbl_info` come first (in that
# order), followed by every remaining column index in ascending order.
lbl_ids = [int(i) for i in block.train.dset.data.lbl_info['identifier']]
ids = set(lbl_ids)  # set gives fast membership tests in the comprehension below
meta_ids = [i for i in range(neg.shape[1]) if i not in ids]

# NOTE(review): this indexing fails if any id in `lbl_ids` is >= neg.shape[1];
# the loader does not pass an explicit shape, so confirm the width covers all ids.
n1 = neg[:, lbl_ids]
n2 = neg[:, meta_ids]

# neg_ids[j] is the original column id that ends up at column j after stacking.
neg_ids = lbl_ids + meta_ids
# NOTE: rebinds `neg`; re-running this cell alone would permute the already
# permuted matrix.
neg = sp.hstack([n1, n2])

sp.save_npz(f'{data_dir}/negatives_trn_X_Y.npz', neg)
# Saves an all-zero sparse matrix with a hard-coded shape.
# NOTE(review): presumably (number of labels, number of passages) - confirm.
# `np` is not imported explicitly in this notebook; presumably it comes from the
# `sugar.core` star import - verify.
sp.save_npz(f'{data_dir}/negatives_lbl_X_Y.npz', sp.csr_matrix((523598, 8841823), dtype=np.float64))

In [None]:
# NOTE: rebinds `fname` and `lbl_ids`, which earlier cells used for the score
# pickle path and the integer label ids.
fname = '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC/raw_data/label.raw.txt'
# `load_raw_file` is not defined in this notebook; presumably it comes from
# `sugar.core` and returns parallel sequences of ids and texts - confirm.
lbl_ids, lbl_txt = load_raw_file(fname)
lbl_info = {k:v for k,v in zip(lbl_ids, lbl_txt)}

# Text for every column of the reordered matrix, in column order. The lookup uses
# str(i), so the ids returned by `load_raw_file` are assumed to be strings - TODO confirm.
neg_txt = [lbl_info[str(i)] for i in neg_ids]
save_raw_file(f'{data_dir}/negatives.raw.txt', neg_ids, neg_txt)

## Driver

In [None]:
#| export
# Script entry point: repeats the notebook cells above end to end.
# NOTE(review): `block`, `np`, `load_raw_file` and `save_raw_file` are not defined
# or imported explicitly in the visible cells; presumably they come from
# `from sugar.core import *` - confirm, otherwise the exported script fails here.
if __name__ == '__main__':
    # Absolute, machine-specific locations of the hard-negative files.
    data_dir = "/home/scai/phd/aiz218323/scratch/datasets/msmarco/negatives"
    fname = f"{data_dir}/cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl"

    # One matrix row per query id, in the order of `query_ids`.
    query_ids = [int(i) for i in block.train.dset.data.data_info['identifier']]
    neg = load_msmarco_hard_negatives(fname, query_ids)

    # Column order: ids from `lbl_info` first, then all remaining column indices
    # in ascending order.
    lbl_ids = [int(i) for i in block.train.dset.data.lbl_info['identifier']]
    ids = set(lbl_ids)
    meta_ids = [i for i in range(neg.shape[1]) if i not in ids]
    
    # NOTE(review): fails if any id in `lbl_ids` is >= neg.shape[1] - confirm coverage.
    n1 = neg[:, lbl_ids]
    n2 = neg[:, meta_ids]
    
    # neg_ids[j] is the original column id found at column j of the stacked matrix.
    neg_ids = lbl_ids + meta_ids
    neg = sp.hstack([n1, n2])
    
    sp.save_npz(f'{data_dir}/negatives_trn_X_Y.npz', neg)
    # All-zero sparse matrix with a hard-coded shape.
    # NOTE(review): presumably (number of labels, number of passages) - confirm.
    sp.save_npz(f'{data_dir}/negatives_lbl_X_Y.npz', sp.csr_matrix((523598, 8841823), dtype=np.float64))

    # From here on `fname` and `lbl_ids` are rebound to the raw label file and its ids.
    fname = '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC/raw_data/label.raw.txt'
    lbl_ids, lbl_txt = load_raw_file(fname)
    lbl_info = {k:v for k,v in zip(lbl_ids, lbl_txt)}
    
    # Lookup uses str(i), so the ids read from the raw file are assumed to be
    # strings - TODO confirm.
    neg_txt = [lbl_info[str(i)] for i in neg_ids]
    save_raw_file(f'{data_dir}/negatives.raw.txt', neg_ids, neg_txt)
    