# `Miscellaneous`

In [1]:
#| default_exp misc

In [2]:
#| export
import os, torch ,json, torch.multiprocessing as mp, joblib, numpy as np, scipy.sparse as sp, argparse
from typing import Optional, Union, Callable, List
from tqdm.auto import tqdm

from xclib.utils.sparse import retain_topk

In [3]:
#| export
from xcai.core import *
from xcai.data import *
from xcai.sdata import SXCDataset, SMainXCDataset

from xcai.basics import *
from xcai.models.PPP0XX import DBT009, DBTConfig
from xcai.models.upma import UPA000, UPMAConfig
from xcai.models.PPP0XX import DBT023, DBTConfig

## Helper functions

In [5]:
#| export
def load_info(save_file:str, meta_file:str, mname:str, sequence_length:Optional[int]=32):
    """Load a tokenized `Info` object from a joblib cache, building and caching it on a miss.

    Args:
        save_file: Path of the joblib cache file. Its parent directory is created if needed.
        meta_file: Raw file passed to `Info.from_txt` when the cache does not exist yet.
        mname: Tokenizer name/path handed to `Info.from_txt(tokenizer=...)`.
        sequence_length: `max_sequence_length` used when building the object.

    Returns:
        The cached `Info` object if `save_file` exists, otherwise a freshly built one
        (which is also written to `save_file`).
    """
    # `os.path.dirname` returns "" for a bare filename and `os.makedirs("")` raises
    # FileNotFoundError, so only create the directory when there is one.
    save_dir = os.path.dirname(save_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    if os.path.exists(save_file):
        # NOTE: the cache is keyed only on `save_file`; changing `mname` or `sequence_length`
        # will silently reuse a stale cache. joblib.load unpickles -- only use trusted files.
        return joblib.load(save_file)

    meta_info = Info.from_txt(meta_file, max_sequence_length=sequence_length, padding=True, return_tensors='pt',
                              info_column_names=["identifier", "input_text"], tokenization_column="input_text",
                              use_tokenizer=True, tokenizer=mname)
    joblib.dump(meta_info, save_file)
    return meta_info
    

In [6]:
#| export
def collate_beir_metrics(metric_dir:str, datasets:Optional[List]=None):
    """Merge per-dataset metric JSON files into a single `beir.json` in `metric_dir`.

    Args:
        metric_dir: Directory containing `<dataset>.json` files, where any "/" in the
            dataset name is replaced by "-". `beir.json` is written here.
        datasets: Dataset names to collate. Defaults to `BEIR_DATASETS`; pass the same
            custom list given to an inference run so its results are included.

    Datasets without a metric file are skipped. Each file's top-level keys are merged
    into one dict, so later files overwrite identical keys from earlier ones.
    """
    datasets = BEIR_DATASETS if datasets is None else datasets

    beir_metrics = {}
    for dataset in datasets:
        fname = f"{metric_dir}/{dataset.replace('/', '-')}.json"
        if not os.path.exists(fname):
            continue
        with open(fname) as file:
            beir_metrics.update(json.load(file))

    with open(f"{metric_dir}/beir.json", "w") as file:
        json.dump(beir_metrics, file, indent=4)
        

In [7]:
#| export
def additional_args():
    """Parse the experiment-specific command-line flags from `sys.argv`.

    Unrecognised arguments are tolerated (`parse_known_args`), so these flags can be
    mixed with those of another parser. Returns only the parsed `argparse.Namespace`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pct", type=float, default=1.0)
    for flag in ("--use_all", "--use_task_specific_metadata", "--use_training_test_set"):
        parser.add_argument(flag, action="store_true")
    known, _unknown = parser.parse_known_args()
    return known

def get_random_idx(n_data:int, pct:float):
    """Return `int(pct * n_data)` distinct random indices drawn from `range(n_data)`.

    Uses NumPy's global RNG (`np.random.permutation`); seed it with `np.random.seed`
    for reproducible subsets.
    """
    shuffled = np.random.permutation(n_data)
    n_keep = int(pct * n_data)
    return shuffled[:n_keep]
    

In [8]:
#| export
def get_label_dataset(dataset:SXCDataset, mname:str, args:argparse.ArgumentParser):
    """Build a dataset whose queries are the labels of `dataset`, scored against those same labels.

    Args:
        dataset: Source dataset; only `dataset.data.lbl_info` is used.
        mname: Model name or path; its last path component becomes part of the suffix.
        args: Parsed argument namespace. Reads `use_pretrained` and is MUTATED:
            `args.prediction_suffix` becomes `labels_<model>` when `use_pretrained`
            is truthy, otherwise `labels`.

    Returns:
        A new `SXCDataset` with `data_info` and `lbl_info` both set to the label info.
    """
    lbl_info = dataset.data.lbl_info
    label_dset = SXCDataset(SMainXCDataset(data_info=lbl_info, lbl_info=lbl_info))

    # Use the last path component so both hub ids ("org/model") and local paths
    # ("/path/to/model/") give a meaningful name; `split("/")[1]` returned "path" for the latter.
    short_name = mname.rstrip("/").rsplit("/", 1)[-1]
    args.prediction_suffix = f"labels_{short_name}" if args.use_pretrained else "labels"
    return label_dset
    

## BeIR datasets

In [9]:
#| export
# Dataset identifiers iterated by the `*_beir_inference` helpers and `collate_beir_metrics`.
# A "/" separates a CQADupStack sub-forum; it is replaced by "-" in file names.
BEIR_DATASETS = [
    "arguana", "msmarco", "climate-fever", "dbpedia-entity", "fever", "fiqa", "hotpotqa",
    "nfcorpus", "nq", "quora", "scidocs", "scifact", "webis-touche2020", "trec-covid",
    *(f"cqadupstack/{forum}" for forum in (
        "android", "english", "gaming", "gis", "mathematica", "physics",
        "programmers", "stats", "tex", "unix", "webmasters", "wordpress",
    )),
]

## `Linker utils`

In [45]:
#| export
def load_linker_block(dataset:str, config_file:str, input_args:argparse.ArgumentParser, extra_args:Optional[argparse.ArgumentParser]=None, 
                      main_max_data_sequence_length:Optional[int]=32):
    """Build (or load the pickled) linker data block and return `(train_dset, test_dset)`.

    Args:
        dataset: Dataset name, used as a prefix of the pickle file name.
        config_file: Data config path, passed to `get_config_key` and `build_block`.
        input_args: Parsed main arguments. Reads `pickle_dir`, `use_sxc_sampler`,
            `exact`, `only_test` and `build_block`, and is passed to `check_inference_mode`.
        extra_args: Namespace with `pct` and `use_training_test_set` (see `additional_args`).
            When None, `pct=1.0` and `use_training_test_set=False` are used.
        main_max_data_sequence_length: Forwarded to `build_block`.

    Returns:
        In inference mode: `train_dset` is `block.train.dset` (None when the block has no
        train split), and `test_dset` is `block.test.dset`, passed through `get_valid_dset()`
        when `use_training_test_set` is set. Otherwise both are `get_valid_dset()` views, and
        the training set is randomly sub-sampled to a `pct` fraction when `pct < 1.0`.
    """
    # `extra_args` defaults to None but was dereferenced unconditionally (AttributeError);
    # fall back to the same defaults that `additional_args` declares.
    pct = getattr(extra_args, "pct", 1.0)
    use_training_test_set = getattr(extra_args, "use_training_test_set", False)

    config_key, fname = get_config_key(config_file)
    pkl_file = get_pkl_file(input_args.pickle_dir, f"{dataset}_{fname}_distilbert-base-uncased", input_args.use_sxc_sampler,
                            input_args.exact, input_args.only_test)

    os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
    block = build_block(pkl_file, config_file, input_args.use_sxc_sampler, config_key, do_build=input_args.build_block, only_test=input_args.only_test,
                        n_slbl_samples=1, main_oversample=False, main_max_data_sequence_length=main_max_data_sequence_length)

    do_inference = check_inference_mode(input_args)
    if do_inference:
        train_dset = None if block.train is None else block.train.dset
        test_dset = block.test.dset.get_valid_dset() if use_training_test_set else block.test.dset
    else:
        train_dset = block.train.dset.get_valid_dset()
        test_dset = block.test.dset.get_valid_dset()
        if pct < 1.0:
            train_dset = train_dset._getitems(get_random_idx(len(train_dset), pct))

    return train_dset, test_dset
    

In [46]:
#| export
def linker_run(output_dir:str, input_args:argparse.ArgumentParser, mname:str, test_dset:Union[XCDataset, SXCDataset],
               train_dset:Optional[Union[XCDataset, SXCDataset]]=None, label_dset:Optional[Union[XCDataset, SXCDataset]]=None,
               collator:Optional[Callable]=identity_collate_fn, save_dir_name:Optional[str]=None):
    """Build a DBT009 linker model, wrap it in an `XCLearner`, and hand it to `main`.

    Args:
        output_dir: `XCLearningArguments.output_dir`; also passed as the first argument of `load_model`.
        input_args: Parsed command-line namespace. Reads `use_pretrained` and
            `label_similarity`, and is passed to `check_inference_mode` and `main`.
            May be mutated: `get_label_dataset` overwrites `prediction_suffix` when
            `label_similarity` is set.
        mname: Model name/path given to `DBT009.from_pretrained` through `load_model`.
        test_dset: Evaluation dataset; `data.n_lbl` and `data.data_lbl_filterer` configure the metric.
        train_dset: Optional training dataset; its `data.data_lbl` is passed as `prop` to the metric.
        label_dset: Optional dataset forwarded to `main` as `label_dataset`.
        collator: Batch collator given to the learner.
        save_dir_name: Forwarded to `main`.

    Returns:
        Whatever `main` returns. Callers in this module unpack it as
        `(trn_repr, tst_repr, lbl_repr, trn_pred, tst_pred, trn_metric, tst_metric)`.
    """

    # Trainer settings: batch sizes, step-based eval/save, cluster grouping, and
    # best-checkpoint selection on N@10.
    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        eval_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        representation_search_type='BRUTEFORCE',
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,

        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        metric_for_best_model='N@10',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',

        use_encoder_parallel=True,
        max_grad_norm=None,
        fp16=True,

        use_cpu_for_searching=True,
        use_cpu_for_clustering=True,
    )

    # DBT009 hyper-parameters: triplet loss with margin 0.3 and 10 negatives;
    # normalize and use_layer_norm are enabled.
    config = DBTConfig(
        margin = 0.3,
        num_negatives = 10,
        tau = 0.1,
        apply_softmax = True,
        reduction = "mean",

        normalize = True,
        use_layer_norm = True,

        use_encoder_parallel = True,
        loss_function = "triplet"
    )

    # Factory handed to `load_model` together with the keyword-argument dict below.
    def model_fn(mname, config):
        return DBT009.from_pretrained(mname, config=config)

    do_inference = check_inference_mode(input_args)
    model = load_model(args.output_dir, model_fn, {"mname": mname, "config": config}, do_inference=do_inference,
                       use_pretrained=input_args.use_pretrained)

    # Ranking metrics; `prop` is the training label matrix when a training set is given.
    metric = PrecReclMrr(test_dset.data.n_lbl, test_dset.data.data_lbl_filterer, prop=None if train_dset is None else train_dset.data.data_lbl,
                         pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200], mk=[5, 10, 20])

    learn = XCLearner(
        model=model,
        args=args,
        train_dataset=train_dset,
        eval_dataset=test_dset,
        data_collator=collator,
        compute_metrics=metric,
    )

    # Optional label-to-label evaluation set built from the test labels.
    # NOTE: `get_label_dataset` also overwrites `input_args.prediction_suffix`.
    eval_dataset = get_label_dataset(test_dset, mname, input_args) if input_args.label_similarity else None

    return main(learn, input_args, n_lbl=test_dset.data.n_lbl, eval_dataset=eval_dataset, 
                eval_k=10, train_k=10, label_dataset=label_dset, label_k=10, save_dir_name=save_dir_name)
    

In [47]:
#| export
def linker_beir_inference(output_dir:str, input_args:argparse.ArgumentParser, mname:str, 
                          save_file_name:str, meta_file:str, datasets:Optional[List]=None, 
                          pred_dir_name:Optional[str]=None, use_task_specific_metadata:Optional[bool]=False, 
                          meta_sequence_length:Optional[int]=64, get_label_predictions:Optional[bool]=False,
                          beir_dir:Optional[str]="/data/datasets/beir"):
    """Run linker inference on BEIR datasets and write per-dataset and collated metrics.

    Args:
        output_dir: Model/output directory. Metrics go to `<output_dir>/metrics`.
        input_args: Parsed argument namespace. Mutated: test-inference flags are forced on,
            `prediction_suffix` is set per dataset, and label-inference flags are turned on
            when `get_label_predictions` is set.
        mname: Model/tokenizer name.
        save_file_name: Base name of the tokenized-metadata cache under `pickle_dir`.
        meta_file: Metadata file, relative to `<beir_dir>/<dataset>/XC/`.
        datasets: Datasets to evaluate; defaults to `BEIR_DATASETS`.
        pred_dir_name: Forwarded to `linker_run` as `save_dir_name`.
        use_task_specific_metadata: Load metadata per dataset instead of once from msmarco.
        meta_sequence_length: Tokenized metadata length.
        get_label_predictions: Also predict for each dataset's labels (`label.raw.csv`).
        beir_dir: Root of the BEIR data tree (previously hard-coded).

    Datasets whose test data (cache or raw file), or task-specific metadata file, is missing
    are skipped with a warning.
    """
    beir_dir = beir_dir.rstrip("/")
    metric_dir = f"{output_dir}/metrics"
    os.makedirs(metric_dir, exist_ok=True)

    input_args.only_test = input_args.do_test_inference = input_args.save_test_prediction = True

    # Shared metadata (taken from msmarco) unless each dataset provides its own.
    if not use_task_specific_metadata:
        meta_info = load_info(f"{input_args.pickle_dir}/{save_file_name}.joblib",
                              f"{beir_dir}/msmarco/XC/{meta_file}",
                              mname, sequence_length=meta_sequence_length)

    os.makedirs(f"{input_args.pickle_dir}/beir/", exist_ok=True)

    datasets = BEIR_DATASETS if datasets is None else datasets
    for dataset in tqdm(datasets):
        print(dataset)
        dataset_prefix = dataset.replace("/", "-")

        # test-data: usable when either the tokenized cache or the raw file exists.
        test_cache = f"{input_args.pickle_dir}/beir/{dataset_prefix}.joblib"
        test_file = f"{beir_dir}/{dataset}/XC/raw_data/test.raw.csv"
        if not (os.path.exists(test_cache) or os.path.exists(test_file)):
            print(f"WARNING:: Missing raw file at {test_file}. Dataset '{dataset_prefix}' will be skipped.")
            continue
        test_info = load_info(test_cache, test_file, mname, sequence_length=32)

        # meta-data
        if use_task_specific_metadata:
            fname = f"{beir_dir}/{dataset}/XC/{meta_file}"
            if not os.path.exists(fname):
                print(f"WARNING:: Missing raw file at {fname}. Dataset '{dataset_prefix}' will be skipped.")
                continue
            meta_info = load_info(f"{input_args.pickle_dir}/beir/{save_file_name}/{dataset_prefix}.joblib",
                                  fname, mname, sequence_length=meta_sequence_length)

        test_dset = SXCDataset(SMainXCDataset(data_info=test_info, lbl_info=meta_info))

        # label-data
        label_dset = None
        if get_label_predictions:
            input_args.do_label_inference = input_args.save_label_prediction = True
            lbl_info = load_info(f"{input_args.pickle_dir}/beir/{dataset_prefix}-label.joblib",
                                 f"{beir_dir}/{dataset}/XC/raw_data/label.raw.csv",
                                 mname, sequence_length=128)
            label_dset = SXCDataset(SMainXCDataset(data_info=lbl_info, lbl_info=meta_info))

        input_args.prediction_suffix = dataset_prefix
        *_, tst_metric = linker_run(output_dir, input_args, mname, test_dset,
                                    save_dir_name=pred_dir_name, label_dset=label_dset)

        with open(f"{metric_dir}/{dataset_prefix}.json", "w") as file:
            json.dump({dataset: tst_metric}, file, indent=4)

    collate_beir_metrics(metric_dir)
        

## `UPMA utils`

In [77]:
#| export
def upma_beir_inference(output_dir:str, input_args:argparse.ArgumentParser, mname:str, meta_save_fname:str, 
                        meta_file:str, linker_dir:str, n_data_lnk_samples:Optional[int]=5, n_lbl_lnk_samples:Optional[int]=5, 
                        data_lnk_topk:Optional[int]=5, lbl_lnk_topk:Optional[int]=5, eval_batch_size:Optional[int]=400, 
                        datasets:Optional[List]=None, pred_dir_name:Optional[str]=None, data_repr_pooling:Optional[bool]=True, 
                        memory_injection_layer:Optional[Union[int, List]]=6, memory_type:Optional[Union[str, List]]="embeddings", 
                        n_memory_layers:Optional[int]=3, use_label_memory:Optional[bool]=False, num_input_metadata:Optional[int]=5,
                        meta_sequence_length:Optional[int]=64, beir_dir:Optional[str]="/data/datasets/beir"):
    """Run UPMA inference on BEIR datasets, using linker predictions as metadata.

    Args:
        output_dir: Model/output directory. Metrics go to `<output_dir>/metrics`.
        input_args: Parsed argument namespace. Mutated: test-inference flags are forced on and
            `prediction_suffix` is set per dataset.
        mname: Model/tokenizer name.
        meta_save_fname: Base name of the tokenized-metadata cache under `pickle_dir`.
        meta_file: Raw metadata file (full path).
        linker_dir: Directory whose `predictions/` holds `test_predictions_<dataset>.npz`
            (and `label_predictions_<dataset>.npz` when `use_label_memory`).
        n_data_lnk_samples, n_lbl_lnk_samples: Metadata samples per query / label.
        data_lnk_topk, lbl_lnk_topk: Top-k linker predictions kept per query / label.
        eval_batch_size, pred_dir_name, data_repr_pooling, memory_injection_layer, memory_type,
        n_memory_layers, use_label_memory, num_input_metadata: Forwarded to `upma_run`.
        datasets: Datasets to evaluate; defaults to `BEIR_DATASETS`.
        meta_sequence_length: Tokenized metadata length (previously hard-coded to 64).
        beir_dir: Root of the BEIR data tree (previously hard-coded).

    Datasets with missing linker prediction files are skipped with a warning instead of
    crashing, consistent with the other BEIR inference helpers.
    """
    beir_dir = beir_dir.rstrip("/")
    metric_dir = f"{output_dir}/metrics"
    os.makedirs(metric_dir, exist_ok=True)

    input_args.only_test = input_args.do_test_inference = input_args.save_test_prediction = True

    meta_info = load_info(f"{input_args.pickle_dir}/{meta_save_fname}.joblib", meta_file, mname, 
                          sequence_length=meta_sequence_length)

    datasets = BEIR_DATASETS if datasets is None else datasets
    for dataset in tqdm(datasets):
        print(dataset)
        dataset_prefix = dataset.replace("/", "-")

        # Check the linker outputs first so a missing file does not trigger a block build.
        data_pred_file = f"{linker_dir}/predictions/test_predictions_{dataset_prefix}.npz"
        lbl_pred_file = f"{linker_dir}/predictions/label_predictions_{dataset_prefix}.npz"
        required = [data_pred_file, lbl_pred_file] if use_label_memory else [data_pred_file]
        missing = [fname for fname in required if not os.path.exists(fname)]
        if missing:
            print(f"WARNING:: Missing linker predictions at {missing[0]}. Dataset '{dataset_prefix}' will be skipped.")
            continue

        config_file = f"{beir_dir}/{dataset}/XC/configs/data.json"
        train_dset, test_dset = load_upma_block(dataset, config_file, input_args)

        data_meta = retain_topk(sp.load_npz(data_pred_file), k=data_lnk_topk)
        lbl_meta = retain_topk(sp.load_npz(lbl_pred_file), k=lbl_lnk_topk) if use_label_memory else None

        meta_kwargs = {
            "lnk_meta": SMetaXCDataset(prefix="lnk", data_meta=data_meta, lbl_meta=lbl_meta, meta_info=meta_info, n_sdata_meta_samples=n_data_lnk_samples,
                                       n_slbl_meta_samples=n_lbl_lnk_samples, return_scores=True, meta_oversample=True),
        }
        test_dset = SXCDataset(test_dset.data, **meta_kwargs)

        input_args.prediction_suffix = dataset_prefix
        *_, tst_metric = upma_run(output_dir, input_args, mname, test_dset, train_dset,
                                  eval_batch_size=eval_batch_size, save_dir_name=pred_dir_name,
                                  data_repr_pooling=data_repr_pooling,
                                  memory_injection_layer=memory_injection_layer,
                                  use_label_memory=use_label_memory, memory_type=memory_type,
                                  n_memory_layers=n_memory_layers,
                                  num_input_metadata=num_input_metadata)

        with open(f"{metric_dir}/{dataset_prefix}.json", "w") as file:
            json.dump({dataset_prefix: tst_metric}, file, indent=4)

    collate_beir_metrics(metric_dir)
        

In [78]:
#| export
def load_upma_block(dataset:str, config_file:str, input_args:argparse.ArgumentParser, n_data_lnk_samples:Optional[int]=5, 
                    n_lbl_lnk_samples:Optional[int]=5, n_neg_lnk_samples:Optional[int]=5, data_lnk_topk:Optional[int]=5, 
                    lbl_lnk_topk:Optional[int]=5, neg_lnk_topk:Optional[int]=5):
    """Build (or load the pickled) UPMA data block and return `(train_dset, test_dset)`.

    The `lnk_meta` metadata uses the given per-query/label/negative sample counts and top-k
    values, applied to both train and test splits. `neg_meta` always samples one item.
    `train_dset` is None when the block has no train split.
    """
    config_key, fname = get_config_key(config_file)
    pkl_file = get_pkl_file(
        input_args.pickle_dir, f"{dataset}_{fname}_distilbert-base-uncased",
        input_args.use_sxc_sampler, input_args.exact, input_args.only_test,
    )
    os.makedirs(os.path.dirname(pkl_file), exist_ok=True)

    # n_s{data,lbl,neg}_meta_samples -> {"lnk_meta": <count>, "neg_meta": 1}
    sample_counts = {
        "n_sdata_meta_samples": n_data_lnk_samples,
        "n_slbl_meta_samples": n_lbl_lnk_samples,
        "n_sneg_meta_samples": n_neg_lnk_samples,
    }
    sampling_kwargs = {key: {"lnk_meta": count, "neg_meta": 1} for key, count in sample_counts.items()}

    # {train,test}_{data,label,neg}_meta_topk -> {"lnk_meta": <k>}
    topk_kwargs = {}
    for kind, k in (("data", data_lnk_topk), ("label", lbl_lnk_topk), ("neg", neg_lnk_topk)):
        for split in ("train", "test"):
            topk_kwargs[f"{split}_{kind}_meta_topk"] = {"lnk_meta": k}

    block = build_block(
        pkl_file, config_file, input_args.use_sxc_sampler, config_key,
        do_build=input_args.build_block, only_test=input_args.only_test,
        main_oversample=True, meta_oversample=True, return_scores=True, n_slbl_samples=1,
        **sampling_kwargs, **topk_kwargs,
    )

    train_dset = block.train.dset if block.train is not None else None
    test_dset = block.test.dset
    return train_dset, test_dset
    

In [84]:
#| export
def upma_run(output_dir:str, input_args:argparse.ArgumentParser, mname:str, test_dset:Union[XCDataset, SXCDataset],
             train_dset:Optional[Union[XCDataset, SXCDataset]]=None, collator:Optional[Callable]=identity_collate_fn, 
             train_batch_size:Optional[int]=128, eval_batch_size:Optional[int]=400, save_dir_name:Optional[str]=None,
             data_repr_pooling:Optional[bool]=True, memory_injection_layer:Optional[Union[int, List]]=6, 
             memory_type:Optional[Union[str, List]]="embeddings", n_memory_layers:Optional[int]=3, 
             use_label_memory:Optional[bool]=False, num_input_metadata:Optional[int]=5):
    """Build a UPA000 model (configured via `UPMAConfig`), wrap it in an `XCLearner`, and run `main`.

    Args:
        output_dir: `XCLearningArguments.output_dir`; the metadata embedding file is read from
            `<output_dir>/metadata/gpt-substring.pth`.
        input_args: Parsed argument namespace; reads `use_pretrained` and is passed to
            `check_inference_mode` and `main`.
        mname: Model name. NOTE(review): `model_fn` below has its own `mname` parameter, and
            `load_model` is given no kwargs here, so this argument appears unused -- confirm.
        test_dset: Evaluation dataset with an `"lnk_meta"` metadata dataset.
        train_dset: Optional training dataset.
        collator: Batch collator for the learner.
        train_batch_size, eval_batch_size: Per-device batch sizes.
        save_dir_name: Forwarded to `main`.
        data_repr_pooling: Forwarded to `UPMAConfig.data_repr_pooling`.
        memory_injection_layer: Layer index or list of indices for memory injection.
        memory_type: Memory module name or list of names.
        n_memory_layers: Forwarded to `UPMAConfig.n_memory_layers` (was ignored before; 3 was hard-coded).
        use_label_memory: Also forward label/negative metadata keys and inject memory on those sides.
        num_input_metadata: Forwarded to `UPMAConfig.num_input_metadata`.

    Returns:
        Whatever `main` returns. Callers in this module unpack it as a 7-tuple ending in the test metrics.
    """
    # Accept a single value or a list for both memory options.
    memory_type = memory_type if isinstance(memory_type, list) else [memory_type]
    memory_injection_layer = memory_injection_layer if isinstance(memory_injection_layer, list) else [memory_injection_layer]

    # Token ids/masks of the linked metadata are presumably only needed by an "encoder"
    # memory module. The previous check, `"encoder" in label_names`, searched a list of key
    # names that never contains "encoder", so these keys were never forwarded.
    use_encoder_memory = "encoder" in memory_type

    label_names = ["plbl2data_idx", "plbl2data_data2ptr", "lnk2data_idx", "lnk2data_data2ptr", "lnk2data_scores"]
    if use_encoder_memory:
        label_names = label_names + ["lnk2data_input_ids", "lnk2data_attention_mask"]

    # Label/negative-side metadata keys; only forwarded when label memory is enabled.
    label_memory_names = ["lnk2lbl_idx", "lnk2lbl_data2ptr", "lnk2lbl_lbl2ptr", "lnk2lbl_scores", 
                          "lnk2neg_idx", "lnk2neg_data2ptr", "lnk2neg_neg2ptr", "lnk2neg_scores"]
    if use_encoder_memory:
        label_memory_names = label_memory_names + ["lnk2lbl_input_ids", "lnk2lbl_attention_mask", "lnk2neg_input_ids", "lnk2neg_attention_mask"]

    label_names = label_names + label_memory_names if use_label_memory else label_names
    use_label_metadata = lbl2data_inject_memory = neg2data_inject_memory = use_label_memory

    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        eval_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=50,
        predict_with_representation=True,
        representation_search_type='BRUTEFORCE',
        search_normalize=False,

        adam_epsilon=1e-6,
        warmup_steps=1000,
        weight_decay=0.01,
        learning_rate=6e-5,
        label_names=label_names,

        group_by_cluster=True,
        use_data_metadata_for_clustering=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        data_aug_meta_name="lnk",
        use_label_metadata=use_label_metadata,

        metric_for_best_model='N@10',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',

        use_encoder_parallel=True,
        max_grad_norm=None,
        fp16=True,

        use_cpu_for_searching=True,
        use_cpu_for_clustering=True,
    )

    config = UPMAConfig(
        memory_module_names = memory_type,
        memory_injection_layers = memory_injection_layer,

        num_total_metadata = test_dset.meta["lnk_meta"].n_meta,
        num_input_metadata = num_input_metadata,
        metadata_dropout = 0.1,

        # Previously hard-coded to 3, silently ignoring the argument.
        n_memory_layers = n_memory_layers,

        data_aug_meta_prefix="lnk2data",
        lbl2data_aug_meta_prefix="lnk2lbl",
        neg2data_aug_meta_prefix="lnk2neg",

        data_inject_memory=True,
        lbl2data_inject_memory=lbl2data_inject_memory,
        neg2data_inject_memory=neg2data_inject_memory,

        data_repr_pooling=data_repr_pooling,
        data_normalize=False,

        margin=0.3,
        num_negatives=10,
        tau=0.1,
        apply_softmax=True,
        reduction="mean",

        calib_margin=0.3,
        calib_num_negatives=10,
        calib_tau=0.1,
        calib_apply_softmax=False,

        calib_loss_weight=0.1,
        use_calib_loss=True,

        use_encoder_parallel=True,
        loss_function="margin",

        initialize_memory_embeddings_from_injection_layer_mean=True,
        metadata_embedding_file=f"{output_dir}/metadata/gpt-substring.pth",
    )

    def model_fn(mname:Optional[str]=None):
        # The model is built with the "lnk_meta" metadata dataset of the test set.
        meta_dset = test_dset.meta_dset("lnk_meta")
        model = UPA000.from_pretrained(config, mname=mname, meta_dset=meta_dset, batch_size=1000)
        return model

    metric = PrecReclMrr(test_dset.n_lbl, test_dset.data.data_lbl_filterer, pk=10, rk=200, rep_pk=[1, 3, 5, 10],
                         rep_rk=[10, 100, 200], mk=[5, 10, 20])

    model = load_model(args.output_dir, model_fn, do_inference=check_inference_mode(input_args), use_pretrained=input_args.use_pretrained)

    learn = XCLearner(
        model=model,
        args=args,
        train_dataset=train_dset,
        eval_dataset=test_dset,
        data_collator=collator,
        compute_metrics=metric,
    )

    return main(learn, input_args, n_lbl=test_dset.n_lbl, save_dir_name=save_dir_name)
    

## `Early-fusion utils`

In [67]:
#| export
def early_fusion_run(output_dir:str, input_args:argparse.ArgumentParser, mname:str, test_dset:Union[XCDataset, SXCDataset],
                     train_dset:Optional[Union[XCDataset, SXCDataset]]=None, collator:Optional[Callable]=identity_collate_fn, 
                     save_dir_name:Optional[str]=None):
    """Build a DBT023 model for early-fusion inputs, wrap it in an `XCLearner`, and run `main`.

    Args:
        output_dir: `XCLearningArguments.output_dir`; also the first argument to `load_model`.
        input_args: Parsed argument namespace; reads `use_pretrained` and is passed to
            `check_inference_mode` and `main`.
        mname: Model name/path given to `DBT023.from_pretrained` through `load_model`.
        test_dset: Evaluation dataset; `data.n_lbl` and `data.data_lbl_filterer` configure the metric.
        train_dset: Optional training dataset; its `data.data_lbl` is passed as `prop` to the metric.
        collator: Batch collator for the learner.
        save_dir_name: Forwarded to `main`.

    Returns:
        Whatever `main` returns. Callers in this module unpack it as a 7-tuple ending in the test metrics.
    """
    # Trainer settings, grouped by concern and merged into one XCLearningArguments call.
    schedule_kwargs = dict(
        per_device_train_batch_size=128,
        per_device_eval_batch_size=1600,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        eval_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=50,
        predict_with_representation=True,
        representation_search_type='BRUTEFORCE',
        search_normalize=False,
    )
    optimizer_kwargs = dict(
        adam_epsilon=1e-6,
        warmup_steps=1000,
        weight_decay=0.01,
        learning_rate=6e-5,
        max_grad_norm=None,
        fp16=True,
    )
    clustering_kwargs = dict(
        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,
        use_cpu_for_clustering=True,
    )
    # Best checkpoint is chosen on P@1 using the positive-label index/pointer keys.
    target_kwargs = dict(
        label_names=['plbl2data_idx', 'plbl2data_data2ptr'],
        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',
    )
    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        use_encoder_parallel=True,
        use_cpu_for_searching=True,
        **schedule_kwargs,
        **optimizer_kwargs,
        **clustering_kwargs,
        **target_kwargs,
    )

    config = DBTConfig(normalize=False, use_layer_norm=False, use_encoder_parallel=True)

    def model_fn(mname, config):
        return DBT023.from_pretrained(mname, config=config)

    model = load_model(
        args.output_dir, model_fn, {"mname": mname, "config": config},
        do_inference=check_inference_mode(input_args),
        use_pretrained=input_args.use_pretrained,
    )

    metric = PrecReclMrr(
        test_dset.data.n_lbl, test_dset.data.data_lbl_filterer,
        prop=train_dset.data.data_lbl if train_dset is not None else None,
        pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200], mk=[5, 10, 20],
    )

    learn = XCLearner(model=model, args=args, train_dataset=train_dset, eval_dataset=test_dset,
                      data_collator=collator, compute_metrics=metric)

    return main(learn, input_args, n_lbl=test_dset.data.n_lbl, eval_k=10, train_k=10, save_dir_name=save_dir_name)
    

In [68]:
#| export
def load_early_fusion_block(dataset:str, config_file:str, input_args:argparse.ArgumentParser):
    """Build (or load the pickled) early-fusion data block and return `(train_dset, test_dset)`.

    `train_dset` is None when the block has no train split.
    """
    config_key, fname = get_config_key(config_file)
    pkl_name = f"{dataset}_{fname}_distilbert-base-uncased"
    pkl_file = get_pkl_file(input_args.pickle_dir, pkl_name, input_args.use_sxc_sampler,
                            input_args.exact, input_args.only_test)
    os.makedirs(os.path.dirname(pkl_file), exist_ok=True)

    block_kwargs = dict(
        do_build=input_args.build_block,
        only_test=input_args.only_test,
        n_slbl_samples=1,
        main_oversample=False,
        n_sdata_meta_samples=1,
        meta_oversample=False,
        return_scores=True,
    )
    block = build_block(pkl_file, config_file, input_args.use_sxc_sampler, config_key, **block_kwargs)

    if block.train is None:
        return None, block.test.dset
    return block.train.dset, block.test.dset
    

In [69]:
#| export
def early_fusion_beir_inference(output_dir:str, input_args:argparse.ArgumentParser, mname:str, linker_dir:str, 
                                datasets:Optional[List]=None, raw_dir_name:Optional[str]="raw_data", 
                                metric_dir_name:Optional[str]="metrics", pred_dir_name:Optional[str]=None,
                                beir_dir:Optional[str]="/data/datasets/beir"):
    """Run early-fusion inference on BEIR datasets using the linker's fused raw test files.

    Args:
        output_dir: Model/output directory. Metrics go to `<output_dir>/<metric_dir_name>`.
        input_args: Parsed argument namespace. Mutated: test-inference flags are forced on and
            `prediction_suffix` is set per dataset.
        mname: Model/tokenizer name.
        linker_dir: Directory holding `<raw_dir_name>/test_<dataset>.raw.csv`.
        datasets: Datasets to evaluate; defaults to `BEIR_DATASETS`.
        raw_dir_name: Sub-directory of `linker_dir` with the raw files.
        metric_dir_name: Sub-directory of `output_dir` for metrics.
        pred_dir_name: Forwarded to `early_fusion_run` as `save_dir_name`.
        beir_dir: Root of the BEIR data tree holding `<dataset>/XC/configs/data.json`
            (previously hard-coded).

    Datasets without a raw test file are skipped with a warning before any block is built.
    """
    beir_dir = beir_dir.rstrip("/")
    metric_dir = f"{output_dir}/{metric_dir_name}"
    os.makedirs(metric_dir, exist_ok=True)

    input_args.only_test = input_args.do_test_inference = input_args.save_test_prediction = True

    # Cache sub-directory for tokenized inputs: the linker directory's basename, plus the
    # second component of `raw_dir_name` when it is nested. Loop-invariant, so computed once.
    linker_dir_name, cross_name = os.path.basename(linker_dir.rstrip("/")), raw_dir_name.rstrip("/")
    if "/" in cross_name:
        linker_dir_name = f"{linker_dir_name}/{cross_name.split('/')[1]}"

    datasets = BEIR_DATASETS if datasets is None else datasets
    for dataset in tqdm(datasets):
        print(dataset)
        dataset_prefix = dataset.replace("/", "-")

        data_file = f"{linker_dir}/{raw_dir_name}/test_{dataset_prefix}.raw.csv"
        if not os.path.exists(data_file):
            print(f"WARNING:: Missing raw file at {data_file}. Dataset '{dataset_prefix}' will be skipped.")
            continue

        config_file = f"{beir_dir}/{dataset}/XC/configs/data.json"
        train_dset, test_dset = load_early_fusion_block(dataset, config_file, input_args)

        # Replace the query side with the fused inputs while keeping the original labels.
        data_info = load_info(f"{input_args.pickle_dir}/{linker_dir_name}/{dataset_prefix}.joblib",
                              data_file, mname, sequence_length=128)
        test_dset = SXCDataset(SMainXCDataset(data_info=data_info, data_lbl=test_dset.data.data_lbl, lbl_info=test_dset.data.lbl_info))

        input_args.prediction_suffix = dataset_prefix
        *_, tst_metric = early_fusion_run(output_dir, input_args, mname, test_dset, train_dset,
                                          save_dir_name=pred_dir_name)
        with open(f"{metric_dir}/{dataset_prefix}.json", "w") as file:
            json.dump({dataset_prefix: tst_metric}, file, indent=4)

    collate_beir_metrics(metric_dir)
        