In [3]:
#| default_exp 21_beir-dataset

In [4]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [5]:
#| export
import os, json, pandas as pd, scipy.sparse as sp, numpy as np, argparse, gzip, shutil

from tqdm.auto import tqdm
from datasets import load_dataset
from dataclasses import dataclass
from huggingface_hub import snapshot_download
from typing import Optional, Dict, Tuple, List

from sugar.core import *

## Download data

In [6]:
#| export
def unzip_file(input_file, output_file):
    """Decompress the gzip archive `input_file` and write its raw bytes to `output_file`.

    The output file is opened in binary write mode, so an existing file at
    that path is overwritten.
    """
    with gzip.open(input_file, "rb") as compressed, open(output_file, "wb") as plain:
        shutil.copyfileobj(compressed, plain)
            

In [7]:
#| export 
def download_beir(dset:str, data_dir:Optional[str]=None):
    """Download the BEIR dataset `dset` from the Hugging Face Hub into `data_dir`.

    Three dataset repos are fetched independently: `BeIR/{dset}` (corpus and
    queries, which are then gunzipped), `BeIR/{dset}-qrels` (into
    `{data_dir}/qrels`) and `BeIR/{dset}-generated-queries` (into
    `{data_dir}/generated-queries`, then gunzipped). A failure in any one step
    is printed and does not stop the remaining steps.

    Raises:
        ValueError: if `data_dir` is not provided.
    """
    # The signature allows `None`, but every path below is built from `data_dir`,
    # so fail early with a clear message instead of an opaque TypeError.
    if data_dir is None:
        raise ValueError("`data_dir` must be provided to download a BEIR dataset.")
    os.makedirs(data_dir, exist_ok=True)

    try:
        snapshot_download(repo_id=f"BeIR/{dset}", repo_type="dataset", local_dir=data_dir)
        unzip_file(f'{data_dir}/corpus.jsonl.gz', f'{data_dir}/corpus.jsonl')
        unzip_file(f'{data_dir}/queries.jsonl.gz', f'{data_dir}/queries.jsonl')
    except Exception as e:
        print(f"Failed to download '{dset}': {e}")

    try:
        snapshot_download(repo_id=f"BeIR/{dset}-qrels", repo_type="dataset", local_dir=f"{data_dir}/qrels")
    except Exception as e:
        print(f"Failed to download '{dset}-qrels': {e}")

    try:
        snapshot_download(repo_id=f"BeIR/{dset}-generated-queries", repo_type="dataset", local_dir=f"{data_dir}/generated-queries")
        unzip_file(f'{data_dir}/generated-queries/train.jsonl.gz', f'{data_dir}/generated-queries/train.jsonl')
    except Exception as e:
        print(f"Failed to download '{dset}-generated-queries': {e}")
    

In [118]:
# Example: fetch the `nq` BEIR dataset into a local folder.
# NOTE: this absolute path is machine-specific; point `data_dir` at your own location.
data_dir = '/Users/suchith720/Projects/data/beir/nq'
download_beir('nq', data_dir)

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

## Load data

In [8]:
#| export
def load_queries(fname:str):
    """Read a JSON-lines query file into a `{query_id: query_text}` dict.

    Each non-blank line must be a JSON object with `_id` and `text` keys.
    The file is decoded as UTF-8 regardless of the platform default, and
    blank lines (e.g. a trailing newline) are skipped.
    """
    queries = dict()
    with open(fname, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip(): continue
            data = json.loads(line)
            queries[data['_id']] = data['text']
    return queries
        

In [29]:
# Example: load the `nq` queries (machine-specific absolute path).
fname = '/Users/suchith720/Projects/data/beir/nq/queries.jsonl'
queries = load_queries(fname)

In [9]:
#| export
def load_labels(fname:str):
    """Read a JSON-lines corpus file into label texts and an id-to-index map.

    Each non-blank line must be a JSON object with `_id` and `text` keys.
    Returns `(lbl_txt, lbl_id2idx)` where `lbl_txt[i]` is the text of the
    i-th record and `lbl_id2idx[_id]` is that position `i`. The file is
    decoded as UTF-8 regardless of the platform default; blank lines are
    skipped without leaving gaps in the indices.
    """
    lbl_txt, lbl_id2idx = [], dict()
    with open(fname, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip(): continue
            data = json.loads(line)
            # Index by position in `lbl_txt` so the two structures always agree.
            lbl_id2idx[data['_id']] = len(lbl_txt)
            lbl_txt.append(data['text'])
    return lbl_txt, lbl_id2idx
        

In [148]:
# Example: load the `nq` corpus texts and the corpus-id -> row-index map (machine-specific path).
fname = '/Users/suchith720/Projects/data/beir/nq/corpus.jsonl'
lbl_txt, lbl_id2idx = load_labels(fname)

In [10]:
#| export
def load_qrels(fname:str, lbl_id2idx:Dict):
    """Read a tab-separated qrels file and group relevant labels per query.

    The table must have `query-id`, `corpus-id` and `score` columns. Only rows
    with a positive score whose corpus id (as a string) is a key of
    `lbl_id2idx` are kept. Returns `{"scores": {...}, "labels": {...}}`, each
    mapping a query id to a list: label indices under "labels", and the
    matching scores under "scores".
    """
    qrel_table = pd.read_table(fname)

    scores_by_qry, labels_by_qry = {}, {}
    rows = zip(qrel_table['query-id'], qrel_table['corpus-id'], qrel_table['score'])
    for query_id, corpus_id, score in tqdm(rows, total=qrel_table.shape[0]):
        corpus_key = str(corpus_id)
        # Written as `not (score > 0)` so NaN scores are dropped as well.
        if not (score > 0) or corpus_key not in lbl_id2idx:
            continue
        labels_by_qry.setdefault(query_id, []).append(lbl_id2idx[corpus_key])
        scores_by_qry.setdefault(query_id, []).append(score)

    return {"scores": scores_by_qry, "labels": labels_by_qry}
    

In [80]:
# Example: load the test qrels. Relies on `data_dir` and `lbl_id2idx` set by earlier cells.
tst_qrels = load_qrels(f'{data_dir}/qrels/test.tsv', lbl_id2idx)

  0%|          | 0/4201 [00:00<?, ?it/s]

In [11]:
#| export
def load_generated_queries(fname:str, lbl_txt:List, lbl_id2idx:Dict):
    """Read a JSON-lines file of generated queries, one (query, document) pair per line.

    Each non-blank line must be a JSON object with `_id` (the corpus id),
    `text` (the document text) and `query` keys. The query on line `i` is
    assigned the synthetic id `gen{i}` and linked to its document with score 1.0.

    Returns `(queries, qry2lbls)`: `queries` maps `gen{i}` to the query text and
    `qry2lbls` is `{"scores": {...}, "labels": {...}}` keyed by the same ids.

    The file is decoded as UTF-8 regardless of the platform default and blank
    lines are skipped. A corpus id missing from `lbl_id2idx` raises `KeyError`;
    a document text that differs from `lbl_txt` fails the assertion.
    """
    queries = dict()
    qry2lbls = {"scores": {}, "labels": {}}

    with open(fname, encoding='utf-8') as file:
        for i,line in enumerate(file):
            if not line.strip(): continue
            data = json.loads(line)

            qry = f'gen{i}'
            # Look the corpus id up once, always as a string, so the sanity
            # check and the stored label index cannot disagree.
            lbl_idx = lbl_id2idx[str(data['_id'])]
            queries[qry] = data['query']
            assert lbl_txt[lbl_idx] == data['text'], "Label mismatch in generated queries."

            qry2lbls["labels"].setdefault(qry, []).append(lbl_idx)
            qry2lbls["scores"].setdefault(qry, []).append(1.0)

    return queries, qry2lbls
    

In [162]:
# Example: load the generated training queries (machine-specific absolute path).
# The corpus texts are bound to `lbl_txt` by the `load_labels` cell; no `labels`
# variable is defined in this notebook.
fname = '/Users/suchith720/Projects/data/beir/nq/generated-queries/train.jsonl'
gen_queries, gen_trn_qrels = load_generated_queries(fname, lbl_txt, lbl_id2idx)

## Construct matrix

In [12]:
#| export
def get_matrix_from_qry2lbl(qry2lbl:Dict, num_labels:int):
    """Assemble a query-by-label CSR matrix from per-query label/score lists.

    `qry2lbl["labels"]` maps a query id to its label column indices and
    `qry2lbl["scores"]` maps the same id to the matching values. Rows follow
    the iteration order of `qry2lbl["labels"]`.

    Returns `(mat, qry_ids)`: a float32 `csr_matrix` of shape
    `(len(qry_ids), num_labels)` and the query id of each row.
    """
    label_map, score_map = qry2lbl["labels"], qry2lbl["scores"]

    qry_ids, values, columns, row_ptr = [], [], [], [0]
    for qry_id, lbl_idxs in tqdm(label_map.items(), total=len(label_map)):
        qry_ids.append(qry_id)
        values += score_map[qry_id]
        columns += lbl_idxs
        # Each row ends where the running value list currently ends.
        row_ptr.append(len(values))

    mat = sp.csr_matrix(
        (values, columns, row_ptr),
        shape=(len(qry_ids), num_labels),
        dtype=np.float32,
    )
    return mat, qry_ids
    

In [167]:
# Example: build the test query-by-label matrix. Relies on `tst_qrels` and `lbl_txt` from earlier cells.
tst_mat, tst_ids = get_matrix_from_qry2lbl(tst_qrels, len(lbl_txt))

  0%|          | 0/3452 [00:00<?, ?it/s]

In [163]:
# Example: build the matrix for the generated training queries. Relies on `gen_trn_qrels` and `lbl_txt` from earlier cells.
gen_trn_mat, gen_trn_ids = get_matrix_from_qry2lbl(gen_trn_qrels, len(lbl_txt))

  0%|          | 0/7866640 [00:00<?, ?it/s]

In [13]:
#| export
def get_matrix(fname:str, lbl_id2idx:Dict):
    """Load the qrels file `fname` and turn it into a query-by-label matrix.

    The matrix has one column per entry of `lbl_id2idx`. Returns whatever
    `get_matrix_from_qry2lbl` returns, i.e. `(mat, qry_ids)`.
    """
    relevance = load_qrels(fname, lbl_id2idx)
    num_labels = len(lbl_id2idx)
    return get_matrix_from_qry2lbl(relevance, num_labels)

def get_gen_matrix(fname:str, lbl_txt:List, lbl_id2idx:Dict):
    """Load generated queries from `fname` and build their query-by-label matrix.

    Returns `(queries, mat, qry_ids)`: the generated query texts keyed by
    synthetic id, the matrix with one column per entry of `lbl_id2idx`, and
    the query id of each matrix row.
    """
    queries, relevance = load_generated_queries(fname, lbl_txt, lbl_id2idx)
    mat, qry_ids = get_matrix_from_qry2lbl(relevance, len(lbl_id2idx))
    return queries, mat, qry_ids


In [85]:
# Example: load the test qrels and build the matrix in one step. Relies on `data_dir` and `lbl_id2idx` from earlier cells.
tst_mat, tst_ids = get_matrix(f'{data_dir}/qrels/test.tsv', lbl_id2idx)

  0%|          | 0/4201 [00:00<?, ?it/s]

  0%|          | 0/3452 [00:00<?, ?it/s]

In [51]:
#| export
def load_query_info(fname:str, queries:Dict, lbl_id2idx:Dict):
    """Load the matrix, query ids and query texts for one qrels split.

    Returns `(None, None, None)` when `fname` is `None` or does not exist on
    disk. Otherwise returns `(mat, ids, txt)` where `txt[i]` is
    `queries[str(ids[i])]`.
    """
    if fname is None or not os.path.exists(fname):
        return None, None, None

    mat, ids = get_matrix(fname, lbl_id2idx)
    # Ids are converted to strings before looking up their text in `queries`.
    txt = [queries[str(qry_id)] for qry_id in ids]
    return mat, ids, txt

def load_generated_query_info(fname:str, lbl_txt:List, lbl_id2idx:Dict):
    """Load the matrix, query ids and query texts for a generated-queries file.

    Returns `(None, None, None)` when `fname` is `None` or does not exist on
    disk. Otherwise returns `(mat, ids, txt)` where `txt[i]` is the generated
    query text for `ids[i]`.
    """
    if fname is None or not os.path.exists(fname):
        return None, None, None

    queries, mat, ids = get_gen_matrix(fname, lbl_txt, lbl_id2idx)
    txt = [queries[str(qry_id)] for qry_id in ids]
    return mat, ids, txt
    

In [50]:
#| export
def combine_info(info_1:Tuple, info_2:Tuple):
    """Merge two `(mat, ids, txt)` triples into one.

    If the first triple has no matrix, the second is returned unchanged (even
    if it is empty as well); if only the second has no matrix, the first is
    returned unchanged. When both are present, the matrices are stacked
    vertically and the id and text lists are concatenated, first triple first.
    """
    if info_1[0] is None:
        return info_2
    if info_2[0] is None:
        return info_1

    stacked = sp.vstack([info_1[0], info_2[0]])
    return stacked, info_1[1] + info_2[1], info_1[2] + info_2[2]
        

In [49]:
#| export
def get_dataset(qry_file:str, lbl_file:str, tst_file:str, dev_file:Optional[str]=None, trn_file:Optional[str]=None, 
                gen_trn_file:Optional[str]=None):
    """Load queries, labels and every available split into matrix form.

    Returns four tuples, in this order:
    `(lbl_ids, lbl_txt)`, `(tst_mat, tst_ids, tst_txt)`,
    `(dev_mat, dev_ids, dev_txt)`, `(trn_mat, trn_ids, trn_txt)`.
    `lbl_ids` lists the corpus ids ordered by their row index. The train
    tuple is the combination of the qrels train split and the generated
    queries.
    """
    queries = load_queries(qry_file)

    lbl_txt, lbl_id2idx = load_labels(lbl_file)
    # Order the corpus ids by their assigned row index so they align with `lbl_txt`.
    lbl_ids = sorted(lbl_id2idx, key=lbl_id2idx.__getitem__)

    tst_mat, tst_ids, tst_txt = load_query_info(tst_file, queries, lbl_id2idx)
    dev_mat, dev_ids, dev_txt = load_query_info(dev_file, queries, lbl_id2idx)

    qrel_mat, qrel_ids, qrel_txt = load_query_info(trn_file, queries, lbl_id2idx)
    gen_mat, gen_ids, gen_txt = load_generated_query_info(gen_trn_file, lbl_txt, lbl_id2idx)
    trn_mat, trn_ids, trn_txt = combine_info((qrel_mat, qrel_ids, qrel_txt), (gen_mat, gen_ids, gen_txt))

    lbl_info = (lbl_ids, lbl_txt)
    tst_info = (tst_mat, tst_ids, tst_txt)
    dev_info = (dev_mat, dev_ids, dev_txt)
    trn_info = (trn_mat, trn_ids, trn_txt)
    return lbl_info, tst_info, dev_info, trn_info
    

In [61]:
#| export
def save_dataset(save_dir:str, lbl_info:Tuple, tst_info:Tuple, dev_info:Optional[Tuple]=None, 
                 trn_info:Optional[Tuple]=None, suffix:Optional[str]=''):
    """Write the label file and every available split under `save_dir`.

    Args:
        save_dir: output directory; created (with a `raw_data` subfolder) if missing.
        lbl_info: `(ids, texts)` of the labels, written to `raw_data/label*.raw.csv`.
        tst_info / dev_info / trn_info: `(mat, ids, texts)` triples. A split is
            skipped when its triple is `None` or its matrix is `None`.
        suffix: optional tag appended (as `_{suffix}`) to file names, except
            for the files that explicitly exclude that tag below. `None` is
            treated like the empty string.

    Matrices are saved with `scipy.sparse.save_npz`; id/text pairs are written
    through `save_raw_file`, which is defined outside this block (presumably
    provided by `sugar.core` - confirm).
    """

    def get_raw_file(save_dir:str, type:str, suffix:str, excluded_suffix:List):
        # Drop the tag from the file name when it is in the exclusion list.
        return (
            f'{save_dir}/raw_data/{type}.raw.csv' 
            if suffix in excluded_suffix else 
            f'{save_dir}/raw_data/{type}{suffix}.raw.csv'
        )

    def get_mat_file(save_dir:str, type:str, suffix:str, excluded_suffix:List):
        return (
            f'{save_dir}/{type}_X_Y.npz' 
            if suffix in excluded_suffix else 
            f'{save_dir}/{type}_X_Y{suffix}.npz'
        )

    def _save_split(info, tag:str, mat_type:str, raw_type:str, mat_excluded:List, raw_excluded:List):
        # One split = one matrix file + one raw id/text file; absent splits are skipped.
        if info is None or info[0] is None: return
        sp.save_npz(get_mat_file(save_dir, mat_type, tag, mat_excluded), info[0])
        save_raw_file(get_raw_file(save_dir, raw_type, tag, raw_excluded), info[1], info[2])

    # Creating the nested folder also creates `save_dir` itself.
    os.makedirs(f'{save_dir}/raw_data', exist_ok=True)
    tag = f'_{suffix}' if suffix else ''

    lbl_raw_file = get_raw_file(save_dir, 'label', tag, ['_generated'])
    save_raw_file(lbl_raw_file, lbl_info[0], lbl_info[1])

    _save_split(tst_info, tag, 'tst', 'test', ['_generated'], ['_generated', '_exact', '_xc'])
    _save_split(dev_info, tag, 'dev', 'dev', ['_generated'], ['_generated', '_exact', '_xc'])
    # The train matrix always carries the tag; only its raw file has exclusions.
    _save_split(trn_info, tag, 'trn', 'train', [], ['_exact', '_xc'])
    

In [58]:
#| export
def _get_valid_lbl_idx(tst_mat:sp.csr_matrix, trn_mat:sp.csr_matrix, type:Optional[str]=None):
    if type == 'exact':
        trn_valid_idx = np.array([], dtype=tst_mat.dtype) if trn_mat is None else np.where(trn_mat.getnnz(axis=0) > 0)[0]
        tst_valid_idx = np.where(tst_mat.getnnz(axis=0) > 0)[0]
        return np.union1d(trn_valid_idx, tst_valid_idx)
    elif type == "xc":
        return np.where(trn_mat.getnnz(axis=0) > 0)[0]
    else:
        raise ValueError(f"Invalid sampling type: {type}.")
    

In [59]:
#| export
def get_and_save_dataset(
    qry_file:str, 
    lbl_file:str, 
    tst_file:str, 
    dev_file:Optional[str]=None, 
    trn_file:Optional[str]=None,
    gen_trn_file:Optional[str]=None,
    save_dir:Optional[str]=None, 
    sampling_type:Optional[str]=None,
    suffix:Optional[str]=''
):
    """Build the dataset from the given files and optionally persist it.

    When `save_dir` is given, the directory is created if needed and all
    splits are written there with the given `suffix`. `sampling_type` is
    accepted but not used by this function.

    Returns `(trn_info, dev_info, tst_info, lbl_info)` - note that this order
    differs from the order returned by `get_dataset`.
    """
    loaded = get_dataset(qry_file, lbl_file, tst_file, dev_file, trn_file, gen_trn_file=gen_trn_file)
    lbl_info, tst_info, dev_info, trn_info = loaded

    should_save = save_dir is not None
    if should_save:
        os.makedirs(save_dir, exist_ok=True)
        save_dataset(save_dir, lbl_info, tst_info, dev_info, trn_info, suffix)

    return trn_info, dev_info, tst_info, lbl_info
    

In [40]:
#| export
def parse_args():
    """Parse the command-line options of the BEIR dataset script.

    Options:
        --download: download the dataset instead of converting it.
        --dataset: BEIR dataset name; required together with `--download`.
        --data_dir: directory holding (or receiving) the raw BEIR files; always required.
        --use_generated_queries: flag for also using the generated training queries.
        --save_dir: directory for the converted output.

    Exits with a usage error (via `parser.error`) when `--download` is given
    without `--dataset`, since the name is needed to build the repo id.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--download', action='store_true')
    parser.add_argument('--dataset', type=str)

    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--use_generated_queries', action='store_true')
    parser.add_argument('--save_dir', type=str)

    args = parser.parse_args()
    if args.download and not args.dataset:
        parser.error('--dataset is required when --download is set')
    return args
    

In [None]:
#| export
if __name__ == '__main__':
    # Script entry point: either download the raw BEIR files, or convert
    # already-downloaded files under `--data_dir` and save them to `--save_dir`.
    args = parse_args()
    if not args.download:
        root = args.data_dir
        with_generated = args.use_generated_queries

        get_and_save_dataset(
            qry_file=f'{root}/queries.jsonl', 
            lbl_file=f'{root}/corpus.jsonl', 
            tst_file=f'{root}/qrels/test.tsv',
            dev_file=f'{root}/qrels/dev.tsv',
            trn_file=f'{root}/qrels/train.tsv',
            # Generated queries are only loaded (and the output tagged) on request.
            gen_trn_file=f'{root}/generated-queries/train.jsonl' if with_generated else None,
            save_dir=args.save_dir,
            suffix='generated' if with_generated else '',
        )
    else:
        download_beir(args.dataset, args.data_dir)
                          

In [194]:
# Example: convert the downloaded `nq` files and save them (machine-specific absolute paths).
data_dir = '/Users/suchith720/Projects/data/beir/nq'
save_dir = '/Users/suchith720/Projects/data/beir/nq/XC'

query_file, lbl_file = f'{data_dir}/queries.jsonl', f'{data_dir}/corpus.jsonl'
tst_file, dev_file, trn_file = f'{data_dir}/qrels/test.tsv', f'{data_dir}/qrels/dev.tsv', f'{data_dir}/qrels/train.tsv'
gen_trn_file = f'{data_dir}/generated-queries/train.jsonl'

# `get_and_save_dataset` returns (trn_info, dev_info, tst_info, lbl_info); unpack in
# that order so the dev and test splits are not swapped.
trn_info, dev_info, tst_info, lbl_info = get_and_save_dataset(query_file, lbl_file, tst_file=tst_file, dev_file=dev_file, 
                                                              trn_file=trn_file, gen_trn_file=gen_trn_file, save_dir=save_dir)

  0%|          | 0/4201 [00:00<?, ?it/s]

  0%|          | 0/3452 [00:00<?, ?it/s]

  0%|          | 0/7866640 [00:00<?, ?it/s]