In [1]:
#| default_exp 09_msmarco-dataset

In [2]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [3]:
#| export
import os, json, pandas as pd, scipy.sparse as sp, numpy as np, argparse, gzip, shutil

from typing import Dict, Optional, Tuple

from tqdm.auto import tqdm
from dataclasses import dataclass
from huggingface_hub import snapshot_download

from sugar.core import *

## Download data

In [4]:
#| export
def unzip_file(input_file, output_file):
    """Decompress the gzip archive ``input_file`` into the plain file ``output_file``.

    The data is streamed in chunks by ``shutil.copyfileobj``, so the whole archive
    is never held in memory; ``output_file`` is overwritten if it exists.
    """
    with gzip.open(input_file, "rb") as compressed, open(output_file, "wb") as plain:
        shutil.copyfileobj(compressed, plain)
            

In [5]:
#| export 
def download_msmarco(data_dir=None):
    """Download BEIR MS MARCO from the Hugging Face Hub into ``data_dir``.

    Fetches the ``BeIR/msmarco`` dataset repo into ``data_dir``, decompresses
    ``corpus.jsonl.gz`` and ``queries.jsonl.gz`` next to their archives, and fetches
    the ``BeIR/msmarco-qrels`` repo into ``{data_dir}/qrels``.

    Raises:
        ValueError: if ``data_dir`` is not given. Without this check the ``None``
            default failed later inside ``os.path.exists`` with an unhelpful TypeError.
    """
    if data_dir is None:
        raise ValueError("`data_dir` is required: pass the directory to download MS MARCO into.")
    os.makedirs(data_dir, exist_ok=True)  # no-op when the directory already exists

    snapshot_download(repo_id="BeIR/msmarco", repo_type="dataset", local_dir=data_dir)
    for name in ("corpus", "queries"):
        unzip_file(os.path.join(data_dir, f"{name}.jsonl.gz"), os.path.join(data_dir, f"{name}.jsonl"))

    snapshot_download(repo_id="BeIR/msmarco-qrels", repo_type="dataset", local_dir=os.path.join(data_dir, "qrels"))
    

In [30]:
# One-off download of the raw BEIR files; the path is machine-specific -- point it at your own data directory.
download_msmarco('/Users/suchith720/Projects/data/beir/msmarco/')

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

## Load data

In [6]:
#| export
def load_queries(fname):
    """Read a JSON-lines queries file into a ``{_id: text}`` dict.

    Each non-blank line must be a JSON object with ``_id`` and ``text`` fields.
    The file is decoded as UTF-8 explicitly, so the result does not depend on the
    platform's default locale encoding. Blank lines are skipped instead of making
    ``json.loads`` fail.
    """
    queries = dict()
    with open(fname, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                continue
            data = json.loads(line)
            queries[data['_id']] = data['text']
    return queries
        

In [7]:
#| export
def load_labels(fname:str):
    """Read a JSON-lines corpus file.

    Returns:
        lbl_txt: list of passage ``text`` values, in file order.
        lbl_id2idx: dict mapping each ``_id`` to its position in ``lbl_txt``.
            If an ``_id`` repeats, it maps to its last occurrence.

    The file is decoded as UTF-8 explicitly, independent of the locale. Blank lines
    are skipped and do not consume an index, so positions stay contiguous with ``lbl_txt``.
    """
    lbl_txt, lbl_id2idx = [], dict()
    with open(fname, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                continue
            data = json.loads(line)
            # Index = position the text is about to occupy in `lbl_txt`.
            lbl_id2idx[data['_id']] = len(lbl_txt)
            lbl_txt.append(data['text'])
    return lbl_txt, lbl_id2idx
        

In [8]:
#| export
def load_qrels(fname:str, lbl_id2idx:Dict):
    """Parse a qrels TSV with columns ``query-id``, ``corpus-id`` and ``score``.

    Only rows with a positive score whose corpus id is a key of ``lbl_id2idx`` are kept.

    Returns:
        ``{"scores": {qry: [score, ...]}, "labels": {qry: [label_idx, ...]}}``. The two
        inner lists are aligned per query, and ``label_idx`` comes from ``lbl_id2idx``.
    """
    # Read corpus ids as strings so they match the string `_id` keys of `lbl_id2idx`
    # exactly. Otherwise pandas' type inference would turn e.g. "007" into 7 and
    # silently map it to the passage with id "7".
    # `query-id` keeps pandas' inferred dtype because callers use these values as row ids.
    qrels = pd.read_table(fname, dtype={'corpus-id': str})

    qry2lbls = {"scores": {}, "labels": {}}
    for qry, lbl, sc in tqdm(zip(qrels['query-id'], qrels['corpus-id'], qrels['score']), total=qrels.shape[0]):
        if sc > 0 and str(lbl) in lbl_id2idx:
            qry2lbls["labels"].setdefault(qry, []).append(lbl_id2idx[str(lbl)])
            qry2lbls["scores"].setdefault(qry, []).append(sc)

    return qry2lbls
    

In [13]:
# Interactive exploration: load the raw queries and corpus once (machine-specific path).
data_dir = '/Users/suchith720/Projects/data/beir/msmarco/'

# {query _id: query text}
queries = load_queries(f'{data_dir}/queries.jsonl')
# passage texts in file order, and {corpus _id: position in lbl_txt}
lbl_txt, lbl_id2idx = load_labels(f'{data_dir}/corpus.jsonl')

In [14]:
# Positive-score judgements per query, restricted to corpus ids present in `lbl_id2idx`.
trn_qrels = load_qrels(f'{data_dir}/qrels/train.tsv', lbl_id2idx)
tst_qrels = load_qrels(f'{data_dir}/qrels/dev.tsv', lbl_id2idx)

  0%|          | 0/532751 [00:00<?, ?it/s]

  0%|          | 0/7437 [00:00<?, ?it/s]

In [15]:
# Number of queries that kept at least one judgement, per split.
len(trn_qrels['labels']), len(tst_qrels['labels'])

(502939, 6980)

## Construct matrix

In [9]:
#| export
def get_matrix_from_qry2lbl(qry2lbl:Dict, num_labels:int):
    """Turn the ``load_qrels`` output into a CSR query-by-label relevance matrix.

    Row ``r`` holds the judgements of ``qry_ids[r]``, in the iteration order of
    ``qry2lbl["labels"]``. Its stored entries are the label indices, with their
    scores as ``float32`` values.

    Returns:
        ``(mat, qry_ids)`` where ``mat`` has shape ``(len(qry_ids), num_labels)``.
    """
    label_map, score_map = qry2lbl["labels"], qry2lbl["scores"]

    qry_ids = []
    values, columns, row_ptr = [], [], [0]
    for qid, cols in tqdm(label_map.items(), total=len(label_map)):
        values.extend(score_map[qid])
        columns.extend(cols)
        row_ptr.append(len(values))  # end offset of this row
        qry_ids.append(qid)

    shape = (len(qry_ids), num_labels)
    return sp.csr_matrix((values, columns, row_ptr), shape=shape, dtype=np.float32), qry_ids

def get_matrix(fname:str, lbl_id2idx:Dict):
    """Load the qrels file ``fname`` and return ``(csr_matrix, qry_ids)``.

    The matrix has one column per entry of ``lbl_id2idx``.
    """
    return get_matrix_from_qry2lbl(load_qrels(fname, lbl_id2idx), len(lbl_id2idx))
    

In [16]:
# Train split as a CSR (query x label) matrix plus the row-aligned query ids.
trn_mat, trn_ids = get_matrix_from_qry2lbl(trn_qrels, len(lbl_id2idx))

  0%|          | 0/502939 [00:00<?, ?it/s]

In [17]:
# Dev split (used as the test set) as a CSR (query x label) matrix plus row-aligned query ids.
tst_mat, tst_ids = get_matrix_from_qry2lbl(tst_qrels, len(lbl_id2idx))

  0%|          | 0/6980 [00:00<?, ?it/s]

In [10]:
#| export
def load_query_info(fname:str, queries:Dict, lbl_id2idx:Dict):
    """Build the relevance matrix, query ids and query texts for one qrels split.

    Returns ``(None, None, None)`` when ``fname`` is ``None`` or does not exist on
    disk, so an optional split can be skipped. Raises ``KeyError`` if a qrels query
    id (converted with ``str``) is missing from ``queries``.
    """
    if fname is None or not os.path.exists(fname):
        return None, None, None
    mat, ids = get_matrix(fname, lbl_id2idx)
    txt = [queries[str(qid)] for qid in ids]
    return mat, ids, txt

def get_msmarco(qry_file:str, lbl_file:str, tst_file:str, trn_file:Optional[str]=None):
    """Load queries, corpus and qrels splits.

    Returns:
        ``(lbl_ids, lbl_txt)``: label ids ordered by column index, and their texts.
        ``(tst_mat, tst_ids, tst_txt)``: the test split.
        ``(trn_mat, trn_ids, trn_txt)``: the train split, or all ``None`` when
        ``trn_file`` is absent.
    """
    queries = load_queries(qry_file)

    lbl_txt, lbl_id2idx = load_labels(lbl_file)
    # Invert the id->column mapping into a list of ids ordered by column.
    lbl_ids = sorted(lbl_id2idx, key=lbl_id2idx.get)

    tst_info, trn_info = [load_query_info(split, queries, lbl_id2idx) for split in (tst_file, trn_file)]

    return (lbl_ids, lbl_txt), tst_info, trn_info


In [None]:
# Relies on `data_dir` from the earlier loading cell.
# NOTE(review): this uses qrels/test.tsv as the test split, whereas the __main__ script uses qrels/dev.tsv.
lbl_info, tst_info, trn_info = get_msmarco(f'{data_dir}/queries.jsonl', f'{data_dir}/corpus.jsonl', 
                                           trn_file=f'{data_dir}/qrels/train.tsv', tst_file=f'{data_dir}/qrels/test.tsv')

  0%|          | 0/9260 [00:00<?, ?it/s]

  0%|          | 0/43 [00:00<?, ?it/s]

  0%|          | 0/532751 [00:00<?, ?it/s]

In [54]:
#| export
def _get_valid_lbl_idx(tst_mat:sp.csr_matrix, trn_mat:sp.csr_matrix, type:Optional[str]=None):
    if type == 'exact':
        trn_valid_idx = np.array([], dtype=tst_mat.dtype) if trn_mat is None else np.where(trn_mat.getnnz(axis=0) > 0)[0]
        tst_valid_idx = np.where(tst_mat.getnnz(axis=0) > 0)[0]
        return np.union1d(trn_valid_idx, tst_valid_idx)
    elif type == "xc":
        return np.where(trn_mat.getnnz(axis=0) > 0)[0]
    else:
        raise ValueError(f"Invalid sampling type: {type}.")

def sample_msmarco(lbl_info:Dict, tst_info:Dict, trn_info:Dict, sampling_type:Optional[str]=None):
    idx = _get_valid_lbl_idx(tst_info[0], trn_info[0], type=sampling_type)
    
    lbl_ids = [lbl_info[0] for i in idx]
    lbl_txt = [lbl_info[1] for i in idx]

    tst_mat = tst_info[0][:, idx]
    trn_mat = trn_info[0][:, idx]

    return (lbl_ids, lbl_txt), (tst_mat, tst_info[1], tst_info[2]), (trn_mat, trn_info[1], trn_info[2])
    

In [62]:
#| export
def save_msmarco(save_dir:str, lbl_info:Tuple, tst_info:Tuple, trn_info:Optional[Tuple]=None, suffix:Optional[str]=''):
    """Write label, test and (optional) train data produced by ``get_msmarco``/``sample_msmarco``.

    With ``sfx = '_' + suffix`` when ``suffix`` is non-empty (else ``''``), this writes:

    - ``raw_data/label{sfx}.raw.csv``: label ids and texts.
    - ``tst_X_Y{sfx}.npz``: the test query-by-label matrix.
    - ``raw_data/test.raw.csv``: test query ids and texts.
    - ``trn_X_Y{sfx}.npz`` and ``raw_data/train.raw.csv``: only when a train matrix is given.

    The query raw files carry no suffix. ``sample_msmarco`` passes query ids and texts
    through unchanged, so saves with different suffixes rewrite the same content there.
    """
    # Treat None like '' instead of crashing on len(None).
    suffix = f'_{suffix}' if suffix else ''

    raw_dir = f'{save_dir}/raw_data'
    os.makedirs(raw_dir, exist_ok=True)  # also creates `save_dir`
    save_raw_file(f'{raw_dir}/label{suffix}.raw.csv', lbl_info[0], lbl_info[1])

    sp.save_npz(f'{save_dir}/tst_X_Y{suffix}.npz', tst_info[0])
    save_raw_file(f'{raw_dir}/test.raw.csv', tst_info[1], tst_info[2])

    # `trn_info` defaults to None, and get_msmarco yields (None, None, None) when
    # there is no train qrels file; skip the train outputs in both cases.
    if trn_info is not None and trn_info[0] is not None:
        sp.save_npz(f'{save_dir}/trn_X_Y{suffix}.npz', trn_info[0])
        save_raw_file(f'{raw_dir}/train.raw.csv', trn_info[1], trn_info[2])
    

In [58]:
#| export
def get_and_save_msmarco(
    qry_file:str, 
    lbl_file:str, 
    tst_file:str, 
    trn_file:Optional[str]=None,
    save_dir:Optional[str]=None, 
    sampling_type:Optional[str]=None, 
    suffix:Optional[str]=''
):
    """Load MS MARCO, optionally restrict its label space, optionally save it, and return it.

    Args:
        qry_file: queries JSON-lines file.
        lbl_file: corpus JSON-lines file.
        tst_file: qrels TSV for the test split. Positionally it comes BEFORE ``trn_file``.
        trn_file: qrels TSV for the train split (optional).
        save_dir: if given, the result is written with ``save_msmarco``.
        sampling_type: ``None`` keeps every label. Otherwise it is passed to
            ``sample_msmarco`` (``'exact'`` or ``'xc'``) and appended to the file suffix.
        suffix: extra file-name tag. Combined as ``{suffix}_{sampling_type}``, or just
            ``{sampling_type}`` when ``suffix`` is empty.

    Returns:
        ``(trn_info, tst_info, lbl_info)``. Note the order is reversed relative to ``get_msmarco``.
    """
    lbl_info, tst_info, trn_info = get_msmarco(qry_file, lbl_file, tst_file, trn_file)

    if sampling_type is not None:
        # No leading '_' when `suffix` is empty: save_msmarco already prefixes the
        # suffix with '_', which previously produced names like `tst_X_Y__xc.npz`.
        suffix = f"{suffix}_{sampling_type}" if suffix else sampling_type
        lbl_info, tst_info, trn_info = sample_msmarco(lbl_info, tst_info, trn_info, sampling_type=sampling_type)

    if save_dir is not None:
        # save_msmarco creates `save_dir` itself; normalise a None suffix to ''.
        save_msmarco(save_dir, lbl_info, tst_info, trn_info, suffix=suffix or '')

    return trn_info, tst_info, lbl_info
    
    

In [56]:
#| export
def parse_args(argv=None):
    """Parse the command-line options of the MS MARCO preparation script.

    Args:
        argv: list of argument strings. The default ``None`` makes argparse read
            ``sys.argv[1:]``, as before. Passing a list allows programmatic use and tests.
    """
    parser = argparse.ArgumentParser(description="Download BEIR MS MARCO or convert it into sparse query-by-label matrices.")
    parser.add_argument('--download', action='store_true',
                        help="download and unpack the raw dataset into --data_dir instead of converting it")

    parser.add_argument('--data_dir', type=str, required=True,
                        help="directory holding (or receiving) queries.jsonl, corpus.jsonl and qrels/")
    parser.add_argument('--save_dir', type=str,
                        help="output directory for the converted files; nothing is saved if omitted")

    return parser.parse_args(argv)
    

In [None]:
#| export
if __name__ == '__main__':
    args = parse_args()
    if args.download:
        download_msmarco(args.data_dir)
    else:
        qry_file, lbl_file = f'{args.data_dir}/queries.jsonl', f'{args.data_dir}/corpus.jsonl'
        tst_file, trn_file = f'{args.data_dir}/qrels/dev.tsv', f'{args.data_dir}/qrels/train.tsv'

        # One export per label-space variant:
        #   'xc'    -> labels with a training judgement
        #   'exact' -> labels with a training or test judgement
        #   None    -> every corpus passage
        for sampling_type in ('xc', 'exact', None):
            get_and_save_msmarco(qry_file, lbl_file, tst_file, trn_file,
                                 save_dir=args.save_dir, sampling_type=sampling_type)
        

In [None]:
# Machine-specific paths for the cluster run.
data_dir = '/home/scai/phd/aiz218323/scratch/datasets/msmarco/'
save_dir = '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC'

sampling_type = 'xc'

query_file, lbl_file, trn_file, tst_file = f'{data_dir}/queries.jsonl', f'{data_dir}/corpus.jsonl', f'{data_dir}/qrels/train.tsv', f'{data_dir}/qrels/dev.tsv'
# The signature is (qry_file, lbl_file, tst_file, trn_file): pass the qrels splits by
# keyword so train.tsv is not treated as the test split (and dev.tsv as the train split).
trn_info, tst_info, lbl_info = get_and_save_msmarco(query_file, lbl_file, tst_file=tst_file, trn_file=trn_file,
                                                    save_dir=save_dir, sampling_type=sampling_type, suffix='')

  0%|          | 0/532751 [00:00<?, ?it/s]

  0%|          | 0/502939 [00:00<?, ?it/s]

  0%|          | 0/7437 [00:00<?, ?it/s]

  0%|          | 0/6980 [00:00<?, ?it/s]