In [11]:
#| default_exp 22_beir-config-file

In [12]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [13]:
#| export
import json, os, argparse

from typing import Optional
from xcai.config import PARAM

## `BeIR` config

In [19]:
#| export
def get_data_config(
    data_dir:str, 
    add_trn_cfg:Optional[bool]=True,
    suffix:Optional[str]='',
    **kwargs
):
    def get_raw_file(save_dir:str, type:str, suffix:str, excluded_suffix:List):
        return (
            f'{save_dir}/raw_data/{type}.raw.csv' 
            if suffix in excluded_suffix else 
            f'{save_dir}/raw_data/{type}{suffix}.raw.csv'
        )

    def get_mat_file(save_dir:str, type:str, suffix:str, excluded_suffix:List):
        return (
            f'{save_dir}/{type}_X_Y.npz' 
            if suffix in excluded_suffix else 
            f'{save_dir}/{type}_X_Y{suffix}.npz'
        )
        
    PARAM["main_max_data_sequence_length"] = 300
    PARAM["main_max_lbl_sequence_length"] = 512
    
    for k,v in kwargs.items():
        if k in PARAM and v is not None: PARAM[k] = v
            
    cfg = {
        "data": {
            "path": {},
            "parameters": PARAM,
        }
    }

    tst_file = get_mat_file(save_dir, 'tst', suffix, ['_generated'])
    tst_raw_file = get_raw_file(save_dir, 'test', suffix, ['_generated', '_exact', '_xc'])
    lbl_raw_file = get_raw_file(save_dir, 'label', suffix, ['_generated'])
    
    cfg["data"]["path"]["test"] = {"data_lbl": tst_file, "data_info": tst_raw_file, "lbl_info": lbl_raw_file}
    
    if add_trn_cfg:
        trn_file = get_mat_file(save_dir, 'trn', suffix, [])
        trn_raw_file = get_raw_file(save_dir, 'train', suffix, ['_exact', '_xc'])
        
        cfg["data"]["path"]["train"] = {"data_lbl": trn_file, "data_info": trn_raw_file,
                                        "lbl_info": lbl_raw_file}
        del cfg["data"]["path"]["test"]["lbl_info"]
    return cfg
    

In [20]:
#| export
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str)
    parser.add_argument('--suffix', type=str, default='')
    parser.add_argument('--add_train_config', action='store_true')
    parser.add_argument('--main_max_data_sequence_length', type=int, default=None)
    parser.add_argument('--main_max_lbl_sequence_length', type=int, default=None)
    return parser.parse_args()
    

## `__main__`

In [7]:
#| export
if __name__ == '__main__':
    args = parse_args()

    kwargs = {
        'main_max_data_sequence_length': args.main_max_data_sequence_length,
        'main_max_lbl_sequence_length': args.main_max_lbl_sequence_length,
    }

    assert os.path.exists(args.data_dir), f"Path does not exist: {args.data_dir}"
    config = get_data_config(args.data_dir, add_trn_cfg=args.add_train_config, suffix=args.suffix, **kwargs)
    
    os.makedirs(f'{args.data_dir}/configs/', exist_ok=True)
    with open(f'{args.data_dir}/configs/data.json', 'w') as file:
        json.dump(config, file, indent=4)
        

In [9]:
data_dir = "/home/scai/phd/aiz218323/scratch/datasets/nq/XC/"
config = get_data_config(data_dir, add_trn_cfg=False)

In [12]:
print(config)

{'data': {'path': {'test': {'data_lbl': '/home/scai/phd/aiz218323/scratch/datasets/nq/XC//tst_X_Y.npz', 'data_info': '/home/scai/phd/aiz218323/scratch/datasets/nq/XC//raw_data/test.raw.csv', 'lbl_info': '/home/scai/phd/aiz218323/scratch/datasets/nq/XC//raw_data/label.raw.csv'}}, 'parameters': {'transform_type': 'xc', 'smp_features': [('lbl2data', 1, 2), ('hlk2data', 1, 1), ('hlk2lbl2data', 2, 1)], 'pad_token': 0, 'oversample': False, 'keep_attention_mask': True, 'sampling_features': [('lbl2data', 2), ('hlk2data', 1), ('hlk2lbl2data', 1)], 'num_labels': 1, 'num_metadata': 1, 'metadata_name': None, 'info_column_names': ['identifier', 'input_text'], 'use_tokenizer': True, 'tokenizer': 'distilbert-base-uncased', 'tokenization_column': 'input_text', 'main_max_data_sequence_length': 300, 'main_max_lbl_sequence_length': 512, 'meta_max_sequence_length': 32, 'padding': False, 'return_tensors': None, 'sep': '->', 'data_prompt_func': None, 'lbl_prompt_func': None, 'meta_prompt_func': None, 'pad_s