In [2]:
#| default_exp 21_beir-dataset

In [3]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [4]:
#| export
import os, json, pandas as pd, scipy.sparse as sp, numpy as np, argparse

from tqdm.auto import tqdm
from datasets import load_dataset
from dataclasses import dataclass
from huggingface_hub import snapshot_download

from sugar.core import *

## Download data

In [4]:
#| export 
def download_mteb(dset:str, data_dir=None):
    """Download the `mteb/{dset}` dataset repository from the Hugging Face Hub.

    Args:
        dset: Dataset name under the `mteb` organisation, e.g. `'nq'`.
        data_dir: Local directory to download into; created if missing. When `None`,
            `snapshot_download` falls back to its default location (the Hub cache).

    Returns:
        The local path returned by `snapshot_download`.
    """
    # `os.path.exists(None)` raises TypeError, so only create a directory when one was given;
    # `exist_ok=True` already makes a separate existence check unnecessary.
    if data_dir is not None: os.makedirs(data_dir, exist_ok=True)
    return snapshot_download(repo_id=f"mteb/{dset}", repo_type="dataset", local_dir=data_dir)
    

In [5]:
# Directory holding the MTEB `nq` download; the download and loading cells read from it.
# NOTE(review): absolute, machine-specific path — consider making this configurable.
data_dir = '/home/scai/phd/aiz218323/scratch/datasets/nq/'

In [7]:
download_mteb(dset='nq', data_dir=data_dir)  # fetch mteb/nq into data_dir

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

.gitattributes:   0%|          | 0.00/2.41k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/7.41k [00:00<?, ?B/s]

## Load data

In [14]:
#| export
def load_queries(fname):
    """Read a BEIR-style queries `.jsonl` file into a `{query_id: query_text}` dict.

    Each non-blank line must be a JSON object with `_id` and `text` keys. Blank lines are
    skipped. If an `_id` repeats, the last occurrence wins.
    """
    queries = dict()
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's default encoding.
    with open(fname, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip(): continue  # tolerate blank / trailing empty lines
            data = json.loads(line)
            queries[data['_id']] = data['text']
    return queries
        

In [20]:
#| export
def load_labels(fname):
    """Read a BEIR-style corpus `.jsonl` file.

    Only the `text` field of each record is kept; other fields are ignored. Blank lines are skipped.

    Returns:
        labels: list of document texts, in file order.
        lid_to_idx: dict mapping each document `_id` to its position in `labels`
            (if an `_id` repeats, the last occurrence wins).
    """
    labels, lid_to_idx = [], dict()
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's default encoding.
    with open(fname, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip(): continue
            data = json.loads(line)
            # Index by position in `labels` (not line number) so skipped lines can't leave gaps.
            lid_to_idx[data['_id']] = len(labels)
            labels.append(data['text'])
    return labels, lid_to_idx
        

In [93]:
#| export
def load_qrels(fname, lbl_id2idx=None):
    """Read a qrels TSV (`query-id`, `corpus-id`, `score` columns, with header) into a query->labels mapping.

    Only rows with `score > 0` are kept.

    Args:
        fname: Path to the qrels `.tsv` file.
        lbl_id2idx: Optional `{corpus_id: label_index}` mapping. When given, corpus ids missing from it
            are dropped and the rest are converted to their indices. When `None`, every positive
            corpus id is kept as a string.

    Returns:
        dict mapping query id (str) to a list of label indices (or corpus ids when `lbl_id2idx` is None).
    """
    # Read ids as strings: numeric-looking ids would otherwise be parsed as ints (losing leading zeros)
    # and stop matching the string `_id`s loaded from the jsonl files.
    qrels = pd.read_table(fname, dtype={'query-id': str, 'corpus-id': str})

    query_to_labels = dict()
    rows = zip(qrels['query-id'], qrels['corpus-id'], qrels['score'])
    for qid, lid, sc in tqdm(rows, total=qrels.shape[0]):
        if sc <= 0: continue
        lid = str(lid)
        if lbl_id2idx is None:
            query_to_labels.setdefault(qid, []).append(lid)
        elif lid in lbl_id2idx:
            query_to_labels.setdefault(qid, []).append(lbl_id2idx[lid])
    return query_to_labels
    

In [40]:
# Display the configured data directory (shown as this cell's output).
data_dir

'/home/scai/phd/aiz218323/scratch/datasets/nq/'

In [23]:
queries = load_queries(fname=f'{data_dir}/queries.jsonl')  # {query_id: text}

In [24]:
labels, lid_to_idx = load_labels(fname=f'{data_dir}/corpus.jsonl')  # texts + {doc_id: index}

In [43]:
tst_qrels = load_qrels(f'{data_dir}/qrels/test.tsv', lbl_id2idx=lid_to_idx)  # {query_id: [label indices]}

  0%|          | 0/4201 [00:00<?, ?it/s]

## Construct matrix

In [45]:
# Build the test relevance matrix, sized with one column per loaded corpus text.
# NOTE(review): `get_matrix_from_item2idx` is not defined in this notebook; presumably it comes from
# the `sugar.core` star import and returns (matrix, query_ids) — confirm.
tst_mat, tst_ids = get_matrix_from_item2idx(tst_qrels, len(labels))

  0%|          | 0/3452 [00:00<?, ?it/s]

In [58]:
#| export
def get_matrix(fname, lbl_id2idx):
    """Build the query x label relevance matrix for one qrels file.

    Args:
        fname: Path to a qrels `.tsv` file (see `load_qrels`).
        lbl_id2idx: `{corpus_id: label_index}` mapping used to convert corpus ids to columns.

    Returns:
        Whatever `get_matrix_from_item2idx` returns; callers unpack it as `(matrix, query_ids)`.
    """
    mapping = load_qrels(fname, lbl_id2idx)
    # Size the label axis by the largest index rather than the number of ids: if the mapping's
    # indices are non-contiguous (e.g. a repeated corpus `_id` was overwritten), `len()` would
    # undercount the columns. For a contiguous 0..n-1 mapping both give n.
    n_lbl = max(lbl_id2idx.values()) + 1 if lbl_id2idx else 0
    return get_matrix_from_item2idx(mapping, n_lbl)
    

In [59]:
#| export
@dataclass
class QueryInfo:
    """Queries of one split: a relevance matrix plus row-aligned query ids and texts."""
    mat: sp.csr_matrix  # rows align with `ids`/`txt`; columns are label indices
    ids: list
    txt: list

    def sample_labels(self, lbl_idx:list):
        """Restrict to the label columns in `lbl_idx`, then drop queries left with no labels.

        Mutates `mat`, `ids` and `txt` in place, keeping them row-aligned.
        """
        mat = self.mat[:, lbl_idx]
        # Decide which rows survive *after* the column restriction: a query whose labels were all
        # removed now has an empty row and must be dropped as well.
        data_idx = np.where(mat.getnnz(axis=1) > 0)[0]

        self.mat = mat[data_idx, :]
        self.ids = [self.ids[i] for i in data_idx]
        self.txt = [self.txt[i] for i in data_idx]

@dataclass
class LabelInfo:
    """Label (corpus document) ids and texts, kept index-aligned."""
    ids: list
    txt: list

    def sample(self, valid_idx:list):
        """Keep only the entries at positions `valid_idx`, in that order (mutates in place)."""
        def _take(seq):
            return [seq[k] for k in valid_idx]

        self.ids = _take(self.ids)
        self.txt = _take(self.txt)
    

In [73]:
#| export
def get_dataset(query_file:str, lbl_file:str, tst_file:str, trn_file:str=None):
    """Load queries, labels and qrels into `QueryInfo`/`LabelInfo` containers.

    Args:
        query_file: queries `.jsonl` (e.g. `queries.jsonl`).
        lbl_file: corpus `.jsonl` with the label texts (e.g. `corpus.jsonl`).
        tst_file: qrels `.tsv` for the test split.
        trn_file: optional qrels `.tsv` for the train split.

    Returns:
        `(trn_info, tst_info, lbl_info)`; `trn_info` is `None` when `trn_file` is not given.
    """
    queries = load_queries(query_file)

    lbl_txt, lbl_id2idx = load_labels(lbl_file)
    # Invert the id -> index mapping into an index-ordered list of ids.
    lbl_info = LabelInfo(sorted(lbl_id2idx, key=lbl_id2idx.get), lbl_txt)

    def _split_info(qrels_file):
        # Relevance matrix for one split plus the text of each of its queries.
        mat, ids = get_matrix(qrels_file, lbl_id2idx)
        return QueryInfo(mat, ids, [queries[qid] for qid in ids])

    tst_info = _split_info(tst_file)
    trn_info = None if trn_file is None else _split_info(trn_file)
    return trn_info, tst_info, lbl_info
    

In [83]:
#| export
def save_dataset(save_dir, tst_info, lbl_info, trn_info=None, suffix=''):
    """Write a dataset's relevance matrices and raw id/text files under `save_dir`.

    Files written:
      - `trn_X_Y[_suffix].npz` (only when `trn_info` is given) and `tst_X_Y[_suffix].npz`.
      - `raw_data/train.raw.csv` (only when `trn_info` is given) and `raw_data/test.raw.csv`.
      - `raw_data/label[.suffix].raw.csv`.
    Raw files are written via `save_raw_file`.
    """
    os.makedirs(save_dir, exist_ok=True)

    has_suffix = len(suffix) > 0
    mat_tag = f'_{suffix}' if has_suffix else ''
    lbl_tag = f'.{suffix}' if has_suffix else ''

    if trn_info is not None:
        sp.save_npz(f'{save_dir}/trn_X_Y{mat_tag}.npz', trn_info.mat)
    sp.save_npz(f'{save_dir}/tst_X_Y{mat_tag}.npz', tst_info.mat)

    raw_dir = f'{save_dir}/raw_data'
    os.makedirs(raw_dir, exist_ok=True)
    if trn_info is not None:
        save_raw_file(f'{raw_dir}/train.raw.csv', trn_info.ids, trn_info.txt)
    save_raw_file(f'{raw_dir}/test.raw.csv', tst_info.ids, tst_info.txt)
    save_raw_file(f'{raw_dir}/label{lbl_tag}.raw.csv', lbl_info.ids, lbl_info.txt)
    

In [72]:
#| export
def sample_dataset(tst_info, lbl_info, trn_info=None, sampling_type=None):
    """Restrict a dataset to a subset of labels, in place.

    Args:
        tst_info: `QueryInfo`-like object (needs `.mat` and `.sample_labels`).
        lbl_info: `LabelInfo`-like object (needs `.sample`).
        trn_info: optional `QueryInfo`-like object for the train split.
        sampling_type:
            - `'exact'`: keep every label with at least one stored entry in the train or test matrix.
            - `'xc'`: keep only labels with at least one stored entry in the train matrix
              (requires `trn_info`).

    Raises:
        ValueError: for an unknown `sampling_type`, or `'xc'` without `trn_info`.
    """
    def _used_labels(info):
        # Column indices (sorted) with at least one stored entry.
        return np.where(info.mat.getnnz(axis=0) > 0)[0]

    if sampling_type == 'exact':
        valid_idx = _used_labels(tst_info)
        if trn_info is not None: valid_idx = np.union1d(_used_labels(trn_info), valid_idx)
    elif sampling_type == 'xc':
        if trn_info is None: raise ValueError("sampling_type='xc' requires trn_info.")
        valid_idx = _used_labels(trn_info)
    else:
        raise ValueError(f'Invalid sampling value: {sampling_type}.')

    if trn_info is not None: trn_info.sample_labels(valid_idx)
    tst_info.sample_labels(valid_idx)
    lbl_info.sample(valid_idx)
    

In [77]:
#| export
def get_and_save_dataset(query_file:str, lbl_file:str, tst_file:str, trn_file:str=None, save_dir:str=None, sampling_type=None, suffix=''):
    """Load a dataset, optionally sample its labels and optionally save it.

    Sampling runs only when `sampling_type` is given; saving runs only when `save_dir` is given.

    Returns:
        `(trn_info, tst_info, lbl_info)` as produced by `get_dataset`, after any sampling.
    """
    splits = get_dataset(query_file, lbl_file, tst_file, trn_file)
    trn_info, tst_info, lbl_info = splits

    if sampling_type is not None:
        sample_dataset(tst_info, lbl_info, trn_info=trn_info, sampling_type=sampling_type)
    if save_dir is not None:
        save_dataset(save_dir, tst_info, lbl_info, trn_info=trn_info, suffix=suffix)

    return trn_info, tst_info, lbl_info
    

In [None]:
#| export
def parse_args(argv=None):
    """Parse command-line options for downloading or converting an MTEB dataset.

    Args:
        argv: Argument list to parse; `None` (the default) makes argparse read `sys.argv[1:]`.

    Returns:
        `argparse.Namespace` with `download`, `dataset`, `data_dir` and `save_dir`.
    """
    parser = argparse.ArgumentParser(description='Download an MTEB dataset or convert a downloaded one.')
    parser.add_argument('--download', action='store_true', help='download mteb/<dataset> into --data_dir')
    parser.add_argument('--dataset', type=str, help='MTEB dataset name, e.g. nq (required with --download)')

    parser.add_argument('--data_dir', type=str, required=True, help='directory holding (or receiving) the raw dataset')
    parser.add_argument('--save_dir', type=str, help='where to write the converted dataset')

    args = parser.parse_args(argv)
    # Fail early with a usage message instead of trying to download the repo `mteb/None`.
    if args.download and args.dataset is None: parser.error('--dataset is required with --download')
    return args
    

In [None]:
#| export
if __name__ == '__main__':
    args = parse_args()
    if not args.download:
        # Convert the already-downloaded test split found under --data_dir.
        root = args.data_dir
        get_and_save_dataset(f'{root}/queries.jsonl', f'{root}/corpus.jsonl', f'{root}/qrels/test.tsv',
                             save_dir=args.save_dir)
    else:
        download_mteb(args.dataset, args.data_dir)
                          

In [92]:
# Convert the nq test split and save it under `<data_dir>XC`.
data_dir = '/home/scai/phd/aiz218323/scratch/datasets/nq/'
save_dir = f'{data_dir}XC'

query_file = f'{data_dir}/queries.jsonl'
lbl_file = f'{data_dir}/corpus.jsonl'
tst_file = f'{data_dir}/qrels/test.tsv'
trn_info, tst_info, lbl_info = get_and_save_dataset(query_file, lbl_file, tst_file, save_dir=save_dir)

  0%|          | 0/4201 [00:00<?, ?it/s]

  0%|          | 0/3452 [00:00<?, ?it/s]