In [1]:
#| default_exp 09_msmarco-dataset

In [2]:
#| hide
from nbdev.showdoc import *
# Write the `#| export` cells out to the library module(s) via nbdev.
import nbdev; nbdev.nbdev_export()

In [184]:
#| export
import os, json, argparse
import numpy as np, pandas as pd, scipy.sparse as sp
from tqdm.auto import tqdm
from huggingface_hub import snapshot_download

## Download dataset

In [138]:
#| export 
def download_msmarco(data_dir=None):
    """Download the `mteb/msmarco` dataset snapshot from the Hugging Face Hub.

    Args:
        data_dir: directory to download into; created if it does not exist.
            If None (or empty), no directory is created and `local_dir=None`
            is passed through to `snapshot_download`.

    Returns:
        Whatever `snapshot_download` returns (the local snapshot path).
    """
    # `os.path.exists(None)` raises TypeError, so only touch the filesystem when a
    # directory is actually given; `exist_ok=True` makes a separate exists-check redundant.
    if data_dir: os.makedirs(data_dir, exist_ok=True)
    return snapshot_download(repo_id="mteb/msmarco", repo_type="dataset", local_dir=data_dir)
    

## Load and construct dataset

In [154]:
#| export
def load_queries(fname):
    """Load a JSONL query file into a dict mapping integer query id -> query text.

    Each non-blank line must be a JSON object with an int-convertible ``_id``
    and a ``text`` field. Blank lines (e.g. a trailing empty line) are skipped.
    """
    queries = dict()
    # JSONL is UTF-8; don't depend on the platform's default encoding.
    with open(fname, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip(): continue
            data = json.loads(line)
            queries[int(data['_id'])] = data['text']
    return queries
        

In [153]:
#| export
def load_passages(fname):
    """Load a JSONL corpus into ``(title, text)`` tuples plus a pid -> row-index map.

    Each non-blank line must be a JSON object with an int-convertible ``_id``
    and ``title``/``text`` fields. Blank lines are skipped.

    Returns:
        passages: list of ``(title, text)`` in file order.
        pid_to_idx: dict mapping integer ``_id`` to that passage's position in
            ``passages``. If an ``_id`` repeats, the later occurrence wins.
    """
    passages, pid_to_idx = [], dict()
    # JSONL is UTF-8; don't depend on the platform's default encoding.
    with open(fname, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip(): continue
            data = json.loads(line)
            # Row index == number of passages stored so far.
            pid_to_idx[int(data['_id'])] = len(passages)
            passages.append((data['title'], data['text']))
    return passages, pid_to_idx
        

In [155]:
#| export
def load_qrels(fname):
    """Read a tab-separated qrels file into ``{query-id: [corpus-id, ...]}``.

    The file must have ``query-id``, ``corpus-id`` and ``score`` columns, and
    every ``score`` must equal 1. Corpus ids keep their file order within each query.
    """
    qrels = pd.read_table(fname)
    assert (qrels['score'] == 1).all(), 'Score should contain all 1s'

    grouped = {}
    pairs = zip(qrels['query-id'], qrels['corpus-id'])
    for query, passage in tqdm(pairs, total=len(qrels)):
        if query not in grouped:
            grouped[query] = []
        grouped[query].append(passage)
    return grouped
    

In [156]:
#| export
def construct_matrix(query_to_passage, pid_to_idx, n_lbl=None):
    """Build a binary query x passage relevance matrix in CSR format.

    Args:
        query_to_passage: mapping ``qid -> list of pids`` (as returned by `load_qrels`).
        pid_to_idx: mapping ``pid -> column index`` (as returned by `load_passages`).
        n_lbl: number of columns. Defaults to ``max(pid_to_idx.values()) + 1``,
            which is one column per row of the passage index.

    Returns:
        ``(mat, query_id)``: an int64 CSR matrix and the list of qids, where row
        ``i`` of ``mat`` corresponds to ``query_id[i]``.

    Raises:
        KeyError: if a qrel references a pid missing from ``pid_to_idx``.
    """
    # The column count is derived from the passage index passed in, not from any global.
    if n_lbl is None:
        n_lbl = max(pid_to_idx.values()) + 1 if pid_to_idx else 0

    query_id, indices, indptr = [], [], [0]
    for qid, pids in tqdm(query_to_passage.items()):
        query_id.append(qid)
        indices.extend(pid_to_idx[o] for o in pids)
        indptr.append(len(indices))

    data = np.ones(len(indices), dtype=np.int64)
    mat = sp.csr_matrix((data, indices, indptr), shape=(len(query_id), n_lbl), dtype=np.int64)
    return mat, query_id


In [157]:
#| export
def load_qrel_matrix(fname, pid_to_idx):
    """Read the qrels file ``fname`` and return ``(csr_matrix, query_ids)`` from `construct_matrix`."""
    q2p = load_qrels(fname)
    mat_and_ids = construct_matrix(q2p, pid_to_idx)
    return mat_and_ids
    

In [176]:
#| export
def save_raw_txt(fname, ids, texts):
    """Write ``{id}->{text}`` records, one per line, to ``fname`` as UTF-8.

    Args:
        fname: output path; overwritten if it already exists.
        ids: sized sequence of identifiers (its ``len`` drives the progress bar).
        texts: iterable of strings aligned with ``ids``.
    """
    # The texts come from UTF-8 JSONL and may contain characters the platform
    # default encoding cannot represent, so the encoding is explicit.
    with open(fname, 'w', encoding='utf-8') as file:
        for k, v in tqdm(zip(ids, texts), total=len(ids)):
            file.write(f'{k}->{v}\n')
        

In [200]:
#| export
def save_qrel_matrix(qrel_fname, queries, pid_to_idx, save_dir, data_type='train'):
    """Save a split's relevance matrix and its query texts under ``save_dir``.

    ``data_type='train'`` writes ``trn_X_Y.npz`` and ``raw_data/train.raw.txt``;
    ``data_type='test'`` writes ``tst_X_Y.npz`` and ``raw_data/test.raw.txt``.
    Any other value raises ValueError (after the qrels have been loaded).
    """
    mat, query_ids = load_qrel_matrix(qrel_fname, pid_to_idx)

    for split, mat_tag, raw_tag in (('train', 'trn', 'train'), ('test', 'tst', 'test')):
        if data_type == split:
            break
    else:
        raise ValueError(f'Invalid data type: {data_type}')

    sp.save_npz(f'{save_dir}/{mat_tag}_X_Y.npz', mat)
    query_texts = [queries[qid] for qid in query_ids]
    save_raw_txt(f'{save_dir}/raw_data/{raw_tag}.raw.txt', query_ids, query_texts)
    

In [196]:
#| export
def construct_msmarco(data_dir:str, save_dir:str, query_fname:str, passage_fname:str, 
                      train_qrel_fname:str, test_qrel_fname:str, is_download:bool=False):
    """Convert the MS MARCO JSONL/TSV files into XC-format files under ``save_dir``.

    Writes ``raw_data/titles.raw.txt`` and ``raw_data/passages.raw.txt`` (one line per
    passage, in corpus order). It then writes ``trn_X_Y.npz`` with ``raw_data/train.raw.txt``
    and ``tst_X_Y.npz`` with ``raw_data/test.raw.txt`` via `save_qrel_matrix`.

    Args:
        data_dir: download target, used only when ``is_download`` is True.
        save_dir: output directory; it and its ``raw_data`` subdirectory are created if missing.
        query_fname, passage_fname: JSONL query / corpus files.
        train_qrel_fname, test_qrel_fname: qrels TSV files for the train / test split.
        is_download: if True, fetch the dataset into ``data_dir`` first.

    Raises:
        ValueError: if the corpus contains duplicate passage ids. The id and text
            files would otherwise be silently misaligned.
    """
    if is_download: download_msmarco(data_dir)

    # Creating the leaf directory also creates `save_dir` itself.
    os.makedirs(f'{save_dir}/raw_data', exist_ok=True)

    queries = load_queries(query_fname)
    passages, pid_to_idx = load_passages(passage_fname)

    passage_ids = sorted(pid_to_idx, key=pid_to_idx.get)
    if len(passage_ids) != len(passages):
        raise ValueError(f'{len(passages) - len(passage_ids)} duplicate passage ids in {passage_fname}')
    save_raw_txt(f'{save_dir}/raw_data/titles.raw.txt', passage_ids, [o[0] for o in passages])
    save_raw_txt(f'{save_dir}/raw_data/passages.raw.txt', passage_ids, [o[1] for o in passages])

    # Use this function's own parameters (not same-named notebook globals).
    save_qrel_matrix(train_qrel_fname, queries, pid_to_idx, save_dir=save_dir, data_type='train')
    save_qrel_matrix(test_qrel_fname, queries, pid_to_idx, save_dir=save_dir, data_type='test')
    

In [197]:
#| export
def parse_args(argv=None):
    """Parse command-line options for `construct_msmarco`.

    Args:
        argv: list of argument strings. None (the default) parses ``sys.argv[1:]``.
            Passing an explicit list makes the parser usable from a notebook or a test.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--download', action='store_true',
                        help='download mteb/msmarco into --data_dir before converting')

    parser.add_argument('--data_dir', type=str, required=True, help='download directory')
    parser.add_argument('--save_dir', type=str, required=True, help='output directory for XC files')

    parser.add_argument('--query_filename', type=str, required=True, help='queries JSONL file')
    parser.add_argument('--passage_filename', type=str, required=True, help='corpus JSONL file')

    parser.add_argument('--train_qrel_filename', type=str, required=True, help='train qrels TSV file')
    parser.add_argument('--test_qrel_filename', type=str, required=True, help='test qrels TSV file')
    return parser.parse_args(argv)
    

In [202]:
#| export
if __name__ == '__main__':
    import sys
    # A Jupyter kernel also runs cells with __name__ == '__main__', but there sys.argv
    # holds the kernel's own arguments. Only parse the command line when run as a script.
    if 'ipykernel' not in sys.modules:
        args = parse_args()
        construct_msmarco(args.data_dir, args.save_dir, args.query_filename, args.passage_filename, 
                          args.train_qrel_filename, args.test_qrel_filename, is_download=args.download)
    

In [198]:
# Root of the raw mteb/msmarco snapshot; override with the MSMARCO_DATA_DIR env var.
# The trailing '/' is stripped so the joined paths below contain no '//'.
data_dir = os.environ.get('MSMARCO_DATA_DIR', '/home/scai/phd/aiz218323/scratch/datasets/msmarco-data').rstrip('/')
# XC-format outputs go in a subfolder of the raw-data root.
save_dir = f'{data_dir}/XC'

is_download = False

query_fname, passage_fname = f'{data_dir}/queries.jsonl', f'{data_dir}/corpus.jsonl'
# The MS MARCO dev qrels serve as the test split.
train_qrel_fname, test_qrel_fname = f'{data_dir}/qrels/train.tsv', f'{data_dir}/qrels/dev.tsv'

In [199]:
construct_msmarco(
    data_dir=data_dir,
    save_dir=save_dir,
    query_fname=query_fname,
    passage_fname=passage_fname,
    train_qrel_fname=train_qrel_fname,
    test_qrel_fname=test_qrel_fname,
    is_download=is_download,
)

  0%|          | 0/8841823 [00:00<?, ?it/s]

  0%|          | 0/8841823 [00:00<?, ?it/s]

  0%|          | 0/532751 [00:00<?, ?it/s]

  0%|          | 0/502939 [00:00<?, ?it/s]

  0%|          | 0/502939 [00:00<?, ?it/s]

  0%|          | 0/7437 [00:00<?, ?it/s]

  0%|          | 0/6980 [00:00<?, ?it/s]

  0%|          | 0/6980 [00:00<?, ?it/s]

In [2]:
from xcai.config import PARAM
from xcai.block import SXCBlock

In [14]:
# Folder holding the XC-format files written by `construct_msmarco` (the `save_dir` used
# above). Note: this rebinds `data_dir` from the raw-data root to that output folder.
data_dir = '/home/scai/phd/aiz218323/scratch/datasets/msmarco-data/XC'

In [18]:
# One entry per split. Queries are the data side, passages are the label side,
# and both splits read the same passage text file.
config = {
    'data': {
        'path': {
            split: {
                'data_lbl': f'{data_dir}/{mat_tag}_X_Y.npz',
                'data_info': f'{data_dir}/raw_data/{raw_tag}.raw.txt',
                'lbl_info': f'{data_dir}/raw_data/passages.raw.txt',
            }
            for split, mat_tag, raw_tag in (('train', 'trn', 'train'), ('test', 'tst', 'test'))
        },
        'parameters': PARAM,
    }
}

In [None]:
# Maximum token lengths for the data (query) text and the label (passage) text.
MAX_QUERY_TOKENS, MAX_PASSAGE_TOKENS = 32, 128

block = SXCBlock.from_cfg(
    config['data'],
    main_max_data_sequence_length=MAX_QUERY_TOKENS,
    main_max_lbl_sequence_length=MAX_PASSAGE_TOKENS,
    padding=True,
    return_tensors='pt',
    tokenizer='distilbert-base-uncased',
)