In [48]:
#| default_exp 22_beir-config-file

In [49]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [57]:
#| export
import json, os, argparse
from xcai.config import PARAM

## `MSMARCO` config

In [60]:
#| export
def get_dataset_config(data_dir, model='', entity_type='', suffix='', add_trn_cfg=True, add_linker_cfg=True):
    mat_suffix = f'_{suffix}' if len(suffix) else ''
    raw_suffix = f'.{suffix}' if len(suffix) else ''

    cfg_key = f"data_{entity_type}" if len(entity_type) else "data"
    if len(model): cfg_key = f"{cfg_key}-{model}"
    if len(mat_suffix): cfg_key = f"{cfg_key}{mat_suffix}"
    
    entity_suffix = f'{entity_type}_' if len(entity_type) else ''
    if len(model): entity_suffix = f'{entity_suffix}{model}_'
    
    cfg = {
        cfg_key: {
            "path": {},
            "parameters": PARAM,
        }
    }
    cfg[cfg_key]["parameters"]["main_max_lbl_sequence_length"] = 128
    
    if add_trn_cfg:
        trn_cfg = {
            "data_lbl": f"{data_dir}/trn_X_Y{mat_suffix}.npz",
            "data_info": f"{data_dir}/raw_data/train.raw.csv",
            "lbl_info": f"{data_dir}/raw_data/label{raw_suffix}.raw.csv",
        }
        if add_linker_cfg:
            trn_meta_cfg = {
                "prefix": "ent",
                "data_meta": f"{data_dir}/{entity_suffix}trn_X_Y.npz",
                "lbl_meta": f"{data_dir}/{entity_suffix}lbl_X_Y{mat_suffix}.npz",
                "meta_info": f"{data_dir}/raw_data/{entity_suffix[:-1]}.raw.txt"
            }
            trn_cfg["ent_meta"] = trn_meta_cfg
        cfg[cfg_key]["path"]["train"] = trn_cfg

    tst_cfg = {
        "data_lbl": f"{data_dir}/tst_X_Y{mat_suffix}.npz",
        "data_info": f"{data_dir}/raw_data/test.raw.csv",
        "lbl_info": f"{data_dir}/raw_data/label{raw_suffix}.raw.csv",
    }
    if add_linker_cfg:
        tst_meta_cfg = {
            "prefix": "ent",
            "data_meta": f"{data_dir}/{entity_suffix}tst_X_Y.npz",
            "lbl_meta": f"{data_dir}/{entity_suffix}lbl_X_Y{mat_suffix}.npz",
            "meta_info": f"{data_dir}/raw_data/{entity_suffix[:-1]}.raw.txt"
        }
        tst_cfg["ent_meta"] = tst_meta_cfg
    cfg[cfg_key]["path"]["test"] = tst_cfg
    
    return cfg
    

In [61]:
#| export
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--model', type=str, default='')
    parser.add_argument('--entity_type', type=str, default='')
    parser.add_argument('--suffix', type=str, default='')
    parser.add_argument('--add_trn_cfg', type=int, default=1)
    parser.add_argument('--add_linker_cfg', type=int, default=1)
    return parser.parse_args()
    

## `__main__`

In [14]:
#| export
if __name__ == '__main__':
    args = parse_args()
    
    config = get_dataset_config(args.data_dir, args.model, args.entity_type, args.suffix, add_trn_cfg=args.add_trn_cfg, 
                                add_linker_cfg=args.add_linker_cfg)
    os.makedirs(f'{args.data_dir}/configs/', exist_ok=True)

    fname = f"data_{args.entity_type}" if len(args.entity_type) else "data"
    if len(args.model): fname = f"{fname}_{args.model}"
    if len(args.suffix): fname = f"{fname}_{args.suffix}"
    
    with open(f'{args.data_dir}/configs/{fname}.json', 'w') as file:
        json.dump(config, file, indent=4)
        

In [53]:
data_dir = "/home/scai/phd/aiz218323/scratch/datasets/nq/XC/"
config = get_dataset_config(data_dir, add_trn_cfg=False, add_linker_cfg=False)

In [54]:
config

{'data': {'path': {'test': {'data_lbl': '/home/scai/phd/aiz218323/scratch/datasets/nq/XC//tst_X_Y.npz',
    'data_info': '/home/scai/phd/aiz218323/scratch/datasets/nq/XC//raw_data/test.raw.txt',
    'lbl_info': '/home/scai/phd/aiz218323/scratch/datasets/nq/XC//raw_data/label.raw.txt'}},
  'parameters': {'transform_type': 'xc',
   'smp_features': [('lbl2data', 1, 2),
    ('hlk2data', 1, 1),
    ('hlk2lbl2data', 2, 1)],
   'pad_token': 0,
   'oversample': False,
   'sampling_features': [('lbl2data', 2),
    ('hlk2data', 1),
    ('hlk2lbl2data', 1)],
   'num_labels': 1,
   'num_metadata': 1,
   'metadata_name': None,
   'info_column_names': ['identifier', 'input_text'],
   'use_tokenizer': True,
   'tokenizer': 'distilbert-base-uncased',
   'tokenization_column': 'input_text',
   'main_max_data_sequence_length': 32,
   'main_max_lbl_sequence_length': 32,
   'meta_max_sequence_length': 32,
   'padding': False,
   'return_tensors': None,
   'sep': '->',
   'prompt_func': None,
   'pad_side'