In [7]:
#| default_exp 01_cachew-for-wikiseealsotitles-with-meta-loss

In [2]:
#| hide
from nbdev.showdoc import *
# Export the notebooks' `#| export` cells into the library modules (nbdev).
import nbdev; nbdev.nbdev_export()

In [3]:
# Automatically reload edited modules (e.g. the `xcai` package) before each cell executes.
%reload_ext autoreload
%autoreload 2

In [4]:
#| export
import os,torch,json, torch.multiprocessing as mp, joblib, numpy as np, scipy.sparse as sp

from xcai.main import *
from xcai.basics import *

from xcai.models.cachew import CAW002, CachewConfig

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Turn off Weights & Biases logging for interactive notebook sessions (this cell is not exported).
os.environ.update({'WANDB_MODE': 'disabled'})

In [8]:
#| export
# Restrict the run to GPUs 0 and 1 and name the W&B project for this experiment family.
# NOTE(review): torch is imported before this cell; the device mask only takes effect if
# nothing has initialised CUDA yet -- consider setting it before the imports.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',
    'WANDB_PROJECT': 'cachew_00-wikiseealsotitles',
})

## Driver: train / evaluate Cachew with the metadata loss

In [None]:
#| export
if __name__ == '__main__':
    # Train and/or evaluate Cachew (CAW002) on WikiSeeAlsoTitles with the auxiliary
    # metadata loss enabled (`use_meta_loss=True` below). Command-line switches
    # (pickle dir, sampler type, build/test/inference flags) come from `parse_args()`.
    output_dir = '/scratch/scai/phd/aiz218323/outputs/cachew/01_cachew-for-wikiseealsotitles-with-meta-loss'

    data_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/'
    config_file = 'wikiseealsotitles'
    config_key = 'data_meta'

    # Pre-computed initial embeddings for the metadata memory (loaded into the model below).
    meta_embed_init_file = '/data/OGB_Weights/LF-WikiSeeAlsoTitles-320K/emb_weights.npy'

    # Metadata type name. Every metadata-related key below is derived from it, so switching
    # the metadata type only requires changing this one value.
    meta_name = 'cat'
    meta_key = f'{meta_name}_meta'

    input_args = parse_args()

    # Cache file for the built data block; its name encodes the sampler and test-only mode.
    pkl_file = f'{input_args.pickle_dir}/cachew/wikiseealsotitles_data-meta_distilbert-base-uncased'
    pkl_file = f'{pkl_file}_sxc' if input_args.use_sxc_sampler else f'{pkl_file}_xcs'
    if input_args.only_test: pkl_file = f'{pkl_file}_only-test'
    pkl_file = f'{pkl_file}.joblib'

    # Any prediction/representation export request counts as an inference run.
    do_inference = any((
        input_args.do_train_inference,
        input_args.do_test_inference,
        input_args.save_train_prediction,
        input_args.save_test_prediction,
        input_args.save_representation,
    ))

    # Disable W&B for inference runs. This is done before the training arguments and the
    # learner are created, so nothing built below can see the previous value.
    if do_inference: os.environ['WANDB_MODE'] = 'disabled'

    os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
    block = build_block(pkl_file, config_file, input_args.use_sxc_sampler, config_key, do_build=input_args.build_block, only_test=input_args.only_test, 
                        n_slbl_samples=4, main_oversample=False, n_sdata_meta_samples=5, meta_oversample=False, train_meta_topk=5, test_meta_topk=3, 
                        data_dir=data_dir)
    # Evaluation runs without any test-side metadata.
    # NOTE(review): reason not visible here -- presumably so test inputs are not enriched
    # with metadata; confirm.
    block.test.dset.meta = {}

    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=512,
        per_device_eval_batch_size=512,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        eval_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,
        representation_search_type='BRUTEFORCE',

        output_representation_attribute='data_enriched_repr',
        label_representation_attribute='data_repr',
        metadata_representation_attribute='data_repr',
        data_augmentation_attribute='data_repr',
        representation_attribute='data_enriched_repr',
        clustering_representation_attribute='data_enriched_repr',

        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        use_data_metadata_for_clustering=True,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',

        use_distributional_representation=False,
        use_encoder_parallel=True,
        max_grad_norm=None,
        fp16=True,

        # Metadata targets used by the meta loss, keyed by the metadata type name.
        label_names=[f'{meta_name}2data_idx', f'{meta_name}2data_data2ptr'],

        prune_metadata=False,
        num_metadata_prune_warmup_epochs=10,
        num_metadata_prune_epochs=5,
        metadata_prune_batch_size=2048,
        prune_metadata_names=[meta_key],
        use_data_metadata_for_pruning=True,

        predict_with_augmentation=False,
        use_augmentation_index_representation=True,

        data_aug_meta_name=meta_name,
        augmentation_num_beams=None,
        data_aug_prefix=meta_name,
        use_label_metadata=False,

        data_meta_batch_size=2048,
        augment_metadata=False,
        num_metadata_augment_warmup_epochs=10,
        num_metadata_augment_epochs=5,

        use_cpu_for_searching=True,
        use_cpu_for_clustering=True,
    )

    config = CachewConfig(
        top_k_metadata = 5,
        # Memory size follows the number of items of the chosen metadata type.
        num_metadata=block.train.dset.meta[meta_key].n_meta,

        data_aug_meta_prefix=f'{meta_name}2data',
        lbl2data_aug_meta_prefix=None,

        data_enrich=True,
        lbl2data_enrich=False,

        margin=0.3,
        num_negatives=10,
        tau=0.1,
        apply_softmax=True,

        calib_margin=0.3,
        calib_num_negatives=10,
        calib_tau=0.1,
        calib_apply_softmax=False,
        calib_loss_weight=0.1,
        use_calib_loss=True,

        meta_loss_weight=0.1,
        use_meta_loss=True,

        use_query_loss=True, 
        use_encoder_parallel=True
    )

    model = CAW002.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', config=config)
    model.init_combiner_to_last_layer()

    # Initialise the metadata memory from the pre-computed embedding matrix.
    meta_embeddings = torch.tensor(np.load(meta_embed_init_file), dtype=torch.float32)
    model.set_memory_embeddings(meta_embeddings)

    metric = PrecReclMrr(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                         pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200], mk=[5, 10, 20])

    learn = XCLearner(
        model=model,
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )

    main(learn, input_args, n_lbl=block.n_lbl)
    

## Analysis: Cachew vs. OAK test predictions

In [40]:
import scipy.sparse as sp

from xcai.analysis import *
from xcai.main import *
from xclib.utils.sparse import retain_topk

In [61]:
from xcai.data import MetaXCDataset

In [62]:
def get_pred_dset(pred, block):
    """Wrap a prediction matrix as a test-split `XCDataset` for side-by-side display.

    The test split's data info, label info and label filterer are reused, with
    `pred` passed in the slot that otherwise holds the labels (presumably the
    data-label matrix). Every metadata entry of the test dataset is
    re-initialised through `MetaXCDataset._initialize`.

    Args:
        pred: prediction scores (in this notebook a top-k-pruned scipy sparse
            matrix loaded from disk).
        block: data block whose `.test.dset` supplies everything except `pred`.

    Returns:
        An `XCDataset` wrapping the prediction-backed `MainXCDataset`.
    """
    test_dset = block.test.dset
    src = test_dset.data
    main_dset = MainXCDataset(src.data_info, pred, src.lbl_info, src.data_lbl_filterer)

    meta_dsets = {}
    for meta_key, meta_val in test_dset.meta.items():
        meta_dsets[meta_key] = MetaXCDataset._initialize(meta_val)
    return XCDataset(main_dset, **meta_dsets)

In [21]:
# Benchmark data location and the data/metadata configuration to load
# (same settings as the driver above).
data_dir = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/'
config_file, config_key = 'wikiseealsotitles', 'data_meta'

# Cached (joblib) data block built with the SXC sampler.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/processed/'
pkl_file = f'{pkl_dir}/cachew/wikiseealsotitles_data-meta_distilbert-base-uncased_sxc.joblib'

block = build_block(
    pkl_file,
    config_file,
    True,  # same positional slot the driver fills with `input_args.use_sxc_sampler`
    config_key,
    n_slbl_samples=4,
    main_oversample=False,
    n_sdata_meta_samples=5,
    meta_oversample=False,
    train_meta_topk=5,
    test_meta_topk=3,
    data_dir=data_dir,
)

In [63]:
# Root of the experiment outputs; prediction files are read from here.
# NOTE(review): this rebinds `data_dir`, which the previous cell set to the benchmark
# datasets root -- re-run that cell before rebuilding the data block.
data_dir = ('/home/scai/phd/aiz218323/scratch'
            '/outputs/cachew/')

In [65]:
def load_test_predictions(exp_name, out_dir, blk, k=10):
    """Load an experiment's saved test predictions and wrap them for display.

    Args:
        exp_name: experiment sub-directory under `out_dir`.
        out_dir: root directory holding the experiment outputs.
        blk: data block passed through to `get_pred_dset`.
        k: number of top-scoring predictions per test point kept for display.

    Returns:
        `(pred_file, pred_lbl, pred_block)`: the `.npz` path, the full sparse
        prediction matrix, and the top-`k` prediction dataset.
    """
    # os.path.join avoids the '//' produced by `out_dir`'s trailing slash.
    pred_file = os.path.join(out_dir, exp_name, 'predictions', 'test_predictions.npz')
    pred_lbl = sp.load_npz(pred_file)
    return pred_file, pred_lbl, get_pred_dset(retain_topk(pred_lbl, k=k), blk)

# Model 1: Cachew trained with the metadata loss; model 2: the OAK model.
pred_file_1, pred_lbl_1, pred_block_1 = load_test_predictions(
    '01_cachew-for-wikiseealsotitles-with-meta-loss-005', data_dir, block)
pred_file_2, pred_lbl_2, pred_block_2 = load_test_predictions(
    '04_oak-for-wikiseealsotitles-001', data_dir, block)

In [26]:
# Pointwise P@1 of the Cachew predictions (one row per test point, summed per row below).
metric_1 = pointwise_eval(
    pred_lbl_1,
    block.test.dset.data.data_lbl,
    block.test.dset.data.data_lbl_filterer,
    topk=1,
    metric='P',
)

  self._set_arrayXarray(i, j, x)


In [27]:
# Pointwise R@200 of the OAK predictions (one row per test point, summed per row below).
metric_2 = pointwise_eval(
    pred_lbl_2,
    block.test.dset.data.data_lbl,
    block.test.dset.data.data_lbl_filterer,
    topk=200,
    metric='R',
)

In [67]:
# Rank test points by (Cachew P@1 - OAK R@200), most negative first: points where
# Cachew misses at rank 1 while OAK recalls the most labels in its top 200.
# NOTE(review): the two terms are different metrics (P@1 vs R@200) -- confirm intended.
score = np.ravel(
    metric_1.sum(axis=1) - metric_2.sum(axis=1)
)
idxs = score.argsort()

In [70]:
# Reference side of the comparisons below (shown under the 'GT' label): the test
# dataset passed through `XCDataset._initialize`.
gt_dset = XCDataset._initialize(
    block.test.dset
)

In [71]:
# Cachew's top-10 predictions vs. ground truth for the first 10 test points in `idxs`.
disp = CompareDataset(
    pred_block_1, gt_dset,
    'Cachew', 'GT',
)
disp.show(idxs[:10])

[5m[7m[33mCachew data_input_text[0m [33m: List of peaks by prominence[0m
[5m[7m[33mGT data_input_text[0m [33m: List of peaks by prominence[0m
[5m[7m[96mCachew lbl2data_input_text[0m [96m: ['Highest unclimbed mountain', 'List of peaks by prominence', 'Lists of mountains', 'Mountain', 'List of mountain lists', 'Ultra-prominent summit', 'List of Ultras of South America', 'Lists of highest points', 'List of peaks in Norway by prominence', 'Highest points'][0m
[5m[7m[96mGT lbl2data_input_text[0m [96m: ['List of islands by highest point', 'Topographic prominence', 'List of the most prominent summits of the United States', 'Topographic elevation', 'Topographic isolation', 'List of mountain lists', 'List of ultra-prominent summits of Africa', 'List of ultra-prominent summits of Antarctica', 'List of ultra-prominent summits of Australia', 'List of ultra-prominent summits of the Alps', 'List of the most prominent summits of the British Isles', 'List of ultra-prominent summ

In [72]:
# OAK's top-10 predictions vs. ground truth for the same 10 test points.
disp = CompareDataset(
    pred_block_2, gt_dset,
    'OAK', 'GT',
)
disp.show(idxs[:10])

[5m[7m[90mOAK data_input_text[0m [90m: List of peaks by prominence[0m
[5m[7m[90mGT data_input_text[0m [90m: List of peaks by prominence[0m
[5m[7m[92mOAK lbl2data_input_text[0m [92m: ['List of peaks by prominence', 'Topographic prominence', 'summit (topography)', 'topographic elevation', 'topographic prominence', 'topography', 'Topographic elevation', 'Topographic isolation', 'List of the most prominent summits of Central America', 'List of the most prominent summits of the Caribbean'][0m
[5m[7m[92mGT lbl2data_input_text[0m [92m: ['List of islands by highest point', 'Topographic prominence', 'List of the most prominent summits of the United States', 'Topographic elevation', 'Topographic isolation', 'List of mountain lists', 'List of ultra-prominent summits of Africa', 'List of ultra-prominent summits of Antarctica', 'List of ultra-prominent summits of Australia', 'List of ultra-prominent summits of the Alps', 'List of the most prominent summits of the British Isle