In [3]:
#| default_exp 22_beir-config-file

In [2]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [3]:
#| export
import json, os, argparse

from typing import Optional
from xcai.config import PARAM

## `MSMARCO` config

In [12]:
#| export
def get_config_key(model:Optional[str]='', meta_type:Optional[str]='', suffix:Optional[str]=''):
    key = ""
    if len(model): key = f"{key}-{model}" if len(key) else model
    if len(meta_type): key = f"{key}-{meta_type}" if len(key) else meta_type
    if len(suffix): key = f"{key}_{suffix}"
    return f"data_lbl_{key}"
    

In [13]:
#| export
def get_dataset_config(data_dir:str, model:Optional[str]='', meta_prefix:Optional[str]='lnk', 
                       meta_type:Optional[str]='', suffix:Optional[str]='', add_trn_cfg:Optional[bool]=True, 
                       add_linker_cfg:Optional[bool]=True, **kwargs):
    cfg_key = get_config_key(model, meta_type, suffix)
    
    suffix = f'_{suffix}' if len(suffix) else ''
    
    meta_suffix = f'{meta_suffix}_' if len(meta_suffix) else ''
    if len(model): meta_suffix = f'{meta_suffix}{model}_'

    PARAM["main_max_lbl_sequence_length"] = 512
    PARAM["main_max_data_sequence_length"] = 300
    for k,v in kwargs.items():
        if k in PARAM and v is not None: PARAM[k] = v
            
    cfg = {
        cfg_key: {
            "path": {},
            "parameters": PARAM,
        }
    }
    
    if add_trn_cfg:
        trn_cfg = {
            "data_lbl": f"{data_dir}/trn_X_Y{suffix}.npz",
            "data_info": f"{data_dir}/raw_data/train.raw.csv",
            "lbl_info": f"{data_dir}/raw_data/label{suffix}.raw.csv",
        }
        if add_linker_cfg:
            trn_meta_cfg = {
                "prefix": meta_prefix,
                "data_meta": f"{data_dir}/{meta_suffix}trn_X_Y.npz",
                "lbl_meta": f"{data_dir}/{meta_suffix}lbl_X_Y{suffix}.npz",
                "meta_info": f"{data_dir}/raw_data/{meta_suffix[:-1]}.raw.csv"
            }
            trn_cfg[f"{meta_prefix}_meta"] = trn_meta_cfg
        cfg[cfg_key]["path"]["train"] = trn_cfg

    tst_cfg = {
        "data_lbl": f"{data_dir}/tst_X_Y{suffix}.npz",
        "data_info": f"{data_dir}/raw_data/test.raw.csv",
        "lbl_info": f"{data_dir}/raw_data/label{suffix}.raw.csv",
    }
    if add_linker_cfg:
        tst_meta_cfg = {
            "prefix": meta_prefix,
            "data_meta": f"{data_dir}/{meta_suffix}tst_X_Y.npz",
            "lbl_meta": f"{data_dir}/{meta_suffix}lbl_X_Y{suffix}.npz",
            "meta_info": f"{data_dir}/raw_data/{meta_suffix[:-1]}.raw.csv"
        }
        tst_cfg[f"{meta_prefix}_meta"] = tst_meta_cfg
    cfg[cfg_key]["path"]["test"] = tst_cfg
    
    return cfg
    

In [14]:
#| export
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--model', type=str, default='')
    parser.add_argument('--meta_prefix', type=str, default='')
    parser.add_argument('--meta_type', type=str, default='')
    parser.add_argument('--suffix', type=str, default='')
    parser.add_argument('--add_trn_cfg', type=int, default=1)
    parser.add_argument('--add_linker_cfg', type=int, default=1)
    parser.add_argument('--add_linker_cfg', type=int, default=1)
    parser.add_argument('--main_max_data_sequence_length', type=int, default=None)
    parser.add_argument('--main_max_lbl_sequence_length', type=int, default=None)
    return parser.parse_args()
    

## `__main__`

In [14]:
#| export
if __name__ == '__main__':
    args = parse_args()
    
    config = get_dataset_config(args.data_dir, args.model, args.meta_prefix, args.meta_type, args.suffix, 
                                add_trn_cfg=args.add_trn_cfg, add_linker_cfg=args.add_linker_cfg,
                                main_max_data_sequence_length=args.main_max_data_sequence_length,
                                main_max_lbl_sequence_length=args.main_max_lbl_sequence_length)
    os.makedirs(f'{args.data_dir}/configs/', exist_ok=True)

    fname = get_config_key(args.model, args.meta_type, args.suffix)
    
    with open(f'{args.data_dir}/configs/{fname}.json', 'w') as file:
        json.dump(config, file, indent=4)
        

In [53]:
data_dir = "/home/scai/phd/aiz218323/scratch/datasets/nq/XC/"
config = get_dataset_config(data_dir, add_trn_cfg=False, add_linker_cfg=False)

In [54]:
config

{'data': {'path': {'test': {'data_lbl': '/home/scai/phd/aiz218323/scratch/datasets/nq/XC//tst_X_Y.npz',
    'data_info': '/home/scai/phd/aiz218323/scratch/datasets/nq/XC//raw_data/test.raw.txt',
    'lbl_info': '/home/scai/phd/aiz218323/scratch/datasets/nq/XC//raw_data/label.raw.txt'}},
  'parameters': {'transform_type': 'xc',
   'smp_features': [('lbl2data', 1, 2),
    ('hlk2data', 1, 1),
    ('hlk2lbl2data', 2, 1)],
   'pad_token': 0,
   'oversample': False,
   'sampling_features': [('lbl2data', 2),
    ('hlk2data', 1),
    ('hlk2lbl2data', 1)],
   'num_labels': 1,
   'num_metadata': 1,
   'metadata_name': None,
   'info_column_names': ['identifier', 'input_text'],
   'use_tokenizer': True,
   'tokenizer': 'distilbert-base-uncased',
   'tokenization_column': 'input_text',
   'main_max_data_sequence_length': 32,
   'main_max_lbl_sequence_length': 32,
   'meta_max_sequence_length': 32,
   'padding': False,
   'return_tensors': None,
   'sep': '->',
   'prompt_func': None,
   'pad_side'