# `Miscellaneous`

In [12]:
#| default_exp misc

In [29]:
#| export
import os, torch ,json, torch.multiprocessing as mp, joblib, numpy as np, scipy.sparse as sp, argparse
from typing import Optional, Union, Callable, List
from tqdm.auto import tqdm

In [27]:
#| export
from xcai.core import *
from xcai.data import *

from xcai.basics import *
from xcai.models.PPP0XX import DBT009, DBTConfig

## Helper functions

In [23]:
#| export
def load_info(save_file:str, meta_file:str, mname:str, sequence_length:Optional[int]=32):
    """Return the tokenized metadata for `meta_file`, cached on disk at `save_file`.

    If `save_file` already exists it is loaded with joblib and returned as-is
    (the other arguments are then ignored). Otherwise the metadata is built
    with `Info.from_txt`, written to `save_file`, and returned.

    Args:
        save_file: Path of the joblib cache file. Its parent directory is
            created when missing; a bare file name (no directory part) is
            also accepted and refers to the current working directory.
        meta_file: Text file handed to `Info.from_txt`; read with the columns
            "identifier" and "input_text", tokenizing "input_text".
        mname: Forwarded to `Info.from_txt` as `tokenizer`.
        sequence_length: Forwarded to `Info.from_txt` as `max_sequence_length`.

    Returns:
        Whatever `joblib.load` / `Info.from_txt` produced.
    """
    # os.path.dirname() returns "" for a bare file name and os.makedirs("")
    # raises FileNotFoundError, so only create the directory when there is one.
    cache_dir = os.path.dirname(save_file)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    if os.path.exists(save_file):
        # NOTE(review): joblib.load unpickles the file, so `save_file` must come
        # from a trusted location.
        meta_info = joblib.load(save_file)
    else:
        meta_info = Info.from_txt(meta_file, max_sequence_length=sequence_length, padding=True, return_tensors='pt',
                                  info_column_names=["identifier", "input_text"], tokenization_column="input_text",
                                  use_tokenizer=True, tokenizer=mname)
        joblib.dump(meta_info, save_file)
    return meta_info
    

## BeIR datasets

In [22]:
#| export
# Dataset identifiers for the BeIR collection, in a fixed order.
DATASETS = [
    # Single-corpus datasets.
    "arguana", "msmarco", "climate-fever", "dbpedia-entity",
    "fever", "fiqa", "hotpotqa", "nfcorpus",
    "nq", "quora", "scidocs", "scifact",
    "webis-touche2020", "trec-covid",
    # Entries that live under the shared "cqadupstack/" prefix.
    "cqadupstack/android", "cqadupstack/english", "cqadupstack/gaming",
    "cqadupstack/gis", "cqadupstack/mathematica", "cqadupstack/physics",
    "cqadupstack/programmers", "cqadupstack/stats", "cqadupstack/tex",
    "cqadupstack/unix", "cqadupstack/webmasters", "cqadupstack/wordpress",
]

## `Linker utils`

In [36]:
#| export
def additional_args():
    """Parse the linker-specific command-line options from `sys.argv`.

    Recognized options:
        --pct: float, defaults to 1.0.
        --use_all: boolean flag, False unless given.
        --use_training_test_set: boolean flag, False unless given.

    Options this parser does not know are ignored rather than rejected, so
    the same command line can also carry arguments meant for other parsers.

    Returns:
        The `argparse.Namespace` holding the three options above.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--pct", default=1.0, type=float)
    for switch in ("--use_all", "--use_training_test_set"):
        cli.add_argument(switch, action="store_true")

    # parse_known_args yields (namespace, leftover_argv); the leftovers are dropped.
    known, _leftover = cli.parse_known_args()
    return known


def get_random_idx(n_data:int, pct:float, seed:Optional[int]=None):
    """Return a random subset of the indices `0 .. n_data-1` without repeats.

    Args:
        n_data: Number of items to sample indices from.
        pct: Fraction of `n_data` to keep. The sample size is
            `int(pct * n_data)` clamped to `[0, n_data]`, so values above 1
            select everything and negative values select nothing.
        seed: When given, the permutation comes from a dedicated
            `np.random.default_rng(seed)` generator, so the result is
            reproducible and the global NumPy random state is left untouched.
            When None (the default), the global `np.random` state is used.

    Returns:
        A 1-D integer array holding the selected indices in random order.
    """
    # Clamp at 0: a negative bound would be read by the slice below as
    # "everything except the last k items" instead of an empty selection.
    n_trn = min(max(int(pct * n_data), 0), n_data)

    if seed is None:
        shuffled = np.random.permutation(n_data)
    else:
        shuffled = np.random.default_rng(seed).permutation(n_data)
    return shuffled[:n_trn]
    

In [37]:
#| export
def load_linker_block(dataset:str, config_file:str, input_args:argparse.Namespace,
                      extra_args:Optional[argparse.Namespace]=None):
    """Build (or load) the data block for `dataset` and return its train/test datasets.

    Args:
        dataset: Dataset name; used as the prefix of the pickle file name.
        config_file: Configuration file forwarded to `get_config_key` and
            `build_block`.
        input_args: Parsed arguments. This function reads `pickle_dir`,
            `use_sxc_sampler`, `exact`, `only_test` and `build_block` from it,
            and passes it to `check_inference_mode`.
        extra_args: Optional parsed arguments providing `pct` and
            `use_training_test_set`. When omitted (or when an attribute is
            missing) the values `pct=1.0` and `use_training_test_set=False`
            are used.

    Returns:
        A `(train_dset, test_dset)` tuple. In inference mode `train_dset` is
        None when the block has no training split, and `test_dset` is passed
        through `get_valid_dset()` only if `use_training_test_set` is set.
        Otherwise both datasets are passed through `get_valid_dset()`, and the
        training dataset is randomly subsampled when `pct < 1.0`.
    """
    # `extra_args` is optional, so never dereference it directly: getattr on
    # None simply yields the fallback value.
    pct = getattr(extra_args, "pct", 1.0)
    use_training_test_set = getattr(extra_args, "use_training_test_set", False)

    config_key, fname = get_config_key(config_file)
    pkl_file = get_pkl_file(input_args.pickle_dir, f"{dataset}_{fname}_distilbert-base-uncased", input_args.use_sxc_sampler,
                            input_args.exact, input_args.only_test)

    # os.makedirs("") raises FileNotFoundError, which would happen for a pickle
    # path without a directory component.
    pkl_dir = os.path.dirname(pkl_file)
    if pkl_dir:
        os.makedirs(pkl_dir, exist_ok=True)
    block = build_block(pkl_file, config_file, input_args.use_sxc_sampler, config_key, do_build=input_args.build_block, only_test=input_args.only_test,
                        n_slbl_samples=1, main_oversample=False)

    do_inference = check_inference_mode(input_args)
    if do_inference:
        train_dset = None if block.train is None else block.train.dset
        test_dset = block.test.dset.get_valid_dset() if use_training_test_set else block.test.dset
    else:
        train_dset = block.train.dset.get_valid_dset()
        test_dset = block.test.dset.get_valid_dset()
        if pct < 1.0:
            # Keep only a random `pct` fraction of the training examples.
            train_dset = train_dset._getitems(get_random_idx(len(train_dset), pct))

    return train_dset, test_dset
    

In [40]:
#| export
def linker_run(output_dir:str, input_args:argparse.ArgumentParser, mname:str, test_dset:Union[XCDataset, SXCDataset],
               train_dset:Optional[Union[XCDataset, SXCDataset]]=None, collator:Optional[Callable]=identity_collate_fn):
    """Assemble the learner for the linker model and hand it to `main`.

    The function builds fixed `XCLearningArguments` and `DBTConfig` objects,
    obtains a `DBT009` model through `load_model`, sets up a `PrecReclMrr`
    metric and an `XCLearner`, and returns the result of
    `main(learn, input_args, ...)`.

    Args:
        output_dir: Output directory stored in the learning arguments; also
            the directory given to `load_model`.
        input_args: Parsed arguments; `use_pretrained` is read here and the
            object is forwarded to `check_inference_mode` and `main`.
        mname: Model name given to `DBT009.from_pretrained`.
        test_dset: Evaluation dataset; also supplies `n_lbl` and the label
            filterer for the metric.
        train_dset: Optional training dataset. When present, its
            `data.data_lbl` is passed to the metric as `prop`.
        collator: Data collator passed to the learner.

    Returns:
        Whatever `main` returns.
    """
    # --- learning arguments, grouped by theme -------------------------------
    schedule_kwargs = dict(
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        save_strategy="steps",
        eval_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        metric_for_best_model='N@10',
        load_best_model_at_end=True,
    )
    representation_kwargs = dict(
        representation_num_beams=200,
        representation_accumulation_steps=10,
        predict_with_representation=True,
        representation_search_type='BRUTEFORCE',
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',
        use_cpu_for_searching=True,
    )
    optimizer_kwargs = dict(
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,
        max_grad_norm=None,
        fp16=True,
    )
    clustering_kwargs = dict(
        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,
        use_cpu_for_clustering=True,
    )
    args = XCLearningArguments(
        output_dir=output_dir,
        use_encoder_parallel=True,
        **schedule_kwargs,
        **representation_kwargs,
        **optimizer_kwargs,
        **clustering_kwargs,
    )

    # --- model configuration ------------------------------------------------
    config_kwargs = {
        "margin": 0.3,
        "num_negatives": 10,
        "tau": 0.1,
        "apply_softmax": True,
        "reduction": "mean",
        "normalize": True,
        "use_layer_norm": True,
        "use_encoder_parallel": True,
        "loss_function": "triplet",
    }
    config = DBTConfig(**config_kwargs)

    def model_fn(mname, config):
        # Factory handed to load_model together with its keyword arguments below.
        return DBT009.from_pretrained(mname, config=config)

    do_inference = check_inference_mode(input_args)
    model_fn_kwargs = {"mname": mname, "config": config}
    model = load_model(args.output_dir, model_fn, model_fn_kwargs, do_inference=do_inference,
                       use_pretrained=input_args.use_pretrained)

    # --- metric -------------------------------------------------------------
    if train_dset is None:
        prop = None
    else:
        prop = train_dset.data.data_lbl
    metric = PrecReclMrr(test_dset.data.n_lbl, test_dset.data.data_lbl_filterer, prop=prop,
                         pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200], mk=[5, 10, 20])

    # --- learner ------------------------------------------------------------
    learn = XCLearner(
        model=model,
        args=args,
        train_dataset=train_dset,
        eval_dataset=test_dset,
        data_collator=collator,
        compute_metrics=metric,
    )

    return main(learn, input_args, n_lbl=test_dset.data.n_lbl, eval_k=10, train_k=10)
    