In [2]:
#| default_exp 11_msmarco-config-file

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [1]:
#| export
import json, os, argparse
from typing import Optional

from xcai.config import PARAM

## `MSMARCO` config

In [5]:
#| export
def get_config_key(model:Optional[str]='', entity_type:Optional[str]='', suffix:Optional[str]=''):
    """Build the top-level key under which a dataset config is stored.

    The key starts with `data` and grows by one segment for every non-empty
    argument, in this order: `_{entity_type}`, `-{model}`, `_{suffix}`.
    For example `get_config_key('llama', 'entity', 'exact')` returns
    `'data_entity-llama_exact'`.

    Empty strings and `None` both mean "segment absent".
    """
    key = 'data'
    # Truthiness (rather than `len(...)`) keeps `None` from raising a TypeError,
    # which the `Optional[str]` annotations advertise as a valid input.
    if entity_type: key = f'{key}_{entity_type}'
    if model: key = f'{key}-{model}'
    if suffix: key = f'{key}_{suffix}'
    return key
    

In [6]:
#| export
def get_msmarco_config(data_dir, model:Optional[str]='', entity_type:Optional[str]='', suffix:Optional[str]=''):
    """Build the MSMARCO dataset config dictionary.

    Returns a dict with a single entry keyed by
    `get_config_key(model, entity_type, suffix)`. Its value holds:

    * `"path"`: file paths for the `"train"` and `"test"` splits, each with
      `data_lbl`, `data_info`, `lbl_info` and an `ent_meta` sub-dict
      (`prefix`, `data_meta`, `lbl_meta`, `meta_info`).
    * `"parameters"`: the `PARAM` object imported from `xcai.config`
      (the same object, not a copy).

    `data_dir` is interpolated verbatim in front of every path. A non-empty
    `suffix` is appended as `_{suffix}` to the label matrices and as
    `.{suffix}` to the raw label file name. `entity_type` and `model` form the
    `{entity_type}_{model}` stem of the metadata files. `None` for any of the
    three optional arguments is treated like the empty string.
    """
    # The annotations allow `None`; normalise it here so neither the key
    # construction nor the path formatting below has to special-case it.
    model, entity_type, suffix = model or '', entity_type or '', suffix or ''
    config_key = get_config_key(model, entity_type, suffix)

    mat_suffix = f'_{suffix}' if suffix else ''
    raw_suffix = f'.{suffix}' if suffix else ''
    meta_stem = f'{entity_type}_{model}'

    def split_paths(mat_prefix, raw_name):
        # Paths of one split; `mat_prefix` is the matrix file prefix (trn/tst)
        # and `raw_name` the raw text file stem (train/test).
        return {
            "data_lbl": f"{data_dir}/{mat_prefix}_X_Y{mat_suffix}.npz",
            "data_info": f"{data_dir}/raw_data/{raw_name}.raw.txt",
            "lbl_info": f"{data_dir}/raw_data/label{raw_suffix}.raw.txt",
            "ent_meta": {
                "prefix": "ent",
                # The data-side metadata matrix carries no suffix; only the
                # label-side one does.
                "data_meta": f"{data_dir}/{meta_stem}_{mat_prefix}_X_Y.npz",
                "lbl_meta": f"{data_dir}/{meta_stem}_lbl_X_Y{mat_suffix}.npz",
                "meta_info": f"{data_dir}/raw_data/{meta_stem}.raw.txt"
            }
        }

    return {
        config_key: {
            "path": {
                "train": split_paths("trn", "train"),
                "test": split_paths("tst", "test"),
            },
            "parameters": PARAM,
        }
    }

In [14]:
#| export
def parse_args(argv=None):
    """Parse the command-line options of the config generator.

    Options: `--data_dir`, `--model` and `--entity_type` are required strings;
    `--suffix` is an optional string defaulting to `''`.

    `argv` is an optional list of argument strings. When it is `None` (the
    default) argparse reads `sys.argv[1:]`, so existing `parse_args()` calls
    behave as before. Passing a list makes the function usable from tests or
    from a notebook, where `sys.argv` belongs to the kernel.

    Returns the populated `argparse.Namespace`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--entity_type', type=str, required=True)
    parser.add_argument('--suffix', type=str, default='')
    return parser.parse_args(argv)
    

## `__main__`

In [10]:
#| export
if __name__ == '__main__':
    # Script entry point: read the CLI options, build the config and write it
    # to `<data_dir>/configs/<config_key>.json`.
    args = parse_args()

    config = get_msmarco_config(args.data_dir, args.model, args.entity_type, args.suffix)

    # Create the output directory if it does not exist yet.
    config_dir = f'{args.data_dir}/configs/'
    os.makedirs(config_dir, exist_ok=True)

    # The file is named after the same key the config is stored under.
    config_key = get_config_key(args.model, args.entity_type, args.suffix)
    config_path = f'{config_dir}{config_key}.json'
    with open(config_path, 'w') as out_file:
        json.dump(config, out_file, indent=4)
        

In [15]:
# Example (not exported): build the config with model='llama',
# entity_type='entity' and suffix='exact', then display it as the cell output.
data_dir = "/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC/"
config = get_msmarco_config(data_dir, 'llama', 'entity', suffix='exact')

config

{'data_entity-llama_exact': {'path': {'train': {'data_lbl': '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC//trn_X_Y_exact.npz',
    'data_info': '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC//raw_data/train.raw.txt',
    'lbl_info': '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC//raw_data/label.exact.raw.txt',
    'ent_meta': {'prefix': 'ent',
     'data_meta': '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC//entity_llama_trn_X_Y.npz',
     'lbl_meta': '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC//entity_llama_lbl_X_Y_exact.npz',
     'meta_info': '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC//raw_data/entity_llama.raw.txt'}},
   'test': {'data_lbl': '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC//tst_X_Y_exact.npz',
    'data_info': '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC//raw_data/test.raw.txt',
    'lbl_info': '/home/scai/phd/aiz218323/scratch/datasets/msmarco/XC//raw_data/label.exact.raw.txt',
    'ent_meta': {'prefi