# RADGA DR training

In [None]:
#| default_exp 72-1-radga-dr-ep-for-wikiseealso-1-0

In [None]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, torch.nn as nn
from xcai.basics import *
from xcai.models.radga import RAD002

comet_ml is installed but `COMET_API_KEY` is not set.


In [None]:
import math, numpy as np
from scipy import sparse
from xcai.data import MetaXCDataset
from xcai.clustering.cluster import BalancedClusters

In [None]:
#| export
# Turn Weights & Biases logging off for this run.
os.environ['WANDB_MODE'] = 'disabled'

In [None]:
#| export
# Restrict this process to GPUs 0 and 1.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# W&B project name; presumably has no effect while WANDB_MODE is 'disabled'.
os.environ['WANDB_PROJECT']='xc-nlg_66-radga-dr-ep-for-wikiseealso-2'

In [None]:
# Root directory passed to `XCBlock.from_cfg` when building the block from scratch.
data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'

In [None]:
# Build the data block from the 'data_metas' config (category metadata only).
# NOTE(review): the following cell reassigns `block` from the 'data_catlnk' config,
# so this result is discarded when both cells are run in order.
block = XCBlock.from_cfg(data_dir, 'data_metas', tfm='rm', tokenizer='distilbert-base-uncased', 
                         smp_features=[('lbl2data|cat2lbl2data',1,(1,3)), ('cat2data',1,3)])



In [None]:
# Build the data block from the 'data_catlnk' config; compared with 'data_metas'
# the sampled features additionally include `lnk2lbl2data` and `lnk2data`.
block = XCBlock.from_cfg(data_dir, 'data_catlnk', tfm='rm', tokenizer='distilbert-base-uncased', 
                         smp_features=[('lbl2data|cat2lbl2data|lnk2lbl2data',1, (1,3,3)), ('cat2data',1,3), ('lnk2data',1,3)])



In [None]:
#| export
# Directory holding the cached (pickled) blocks. The path template below reads
# `pkl_dir`, so the directory must be bound to that name; binding it to
# `data_dir` left `pkl_dir` undefined (NameError on a fresh kernel) and also
# clobbered the raw-data `data_dir` used by `XCBlock.from_cfg`.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
# Full path of the pickled block that this notebook saves and loads.
pkl_file = f'{pkl_dir}/processed/wikiseealso_data-metas_distilbert-base-uncased_rm_radga-cat-linker.pkl'

In [None]:
# Cache the in-memory `block` at `pkl_file` (overwrites any existing file).
with open(pkl_file, 'wb') as file:
    pickle.dump(block, file)

In [None]:
#| export
# Restore the cached `block` from `pkl_file`.
# Caution: unpickling can execute arbitrary code, so only load files produced
# by this project.
with open(pkl_file, 'rb') as file:
    block = pickle.load(file)

## XCLearner

In [None]:
from fastcore.utils import *
from torch.utils.data import Dataset
from typing import Any, Optional
from xcai.data import MainXCDataset
from xcai.representation.search import BruteForceSearch,IndexSearch

In [None]:
@patch
def _build_aug_index(self:XCLearner, dataset:Optional[Dataset]=None):
    """Build `self.aug_idxs`, the search index over augmentation-metadata representations.

    Dataset precedence: `self.train_dataset` wins over `self.eval_dataset`, which
    wins over the `dataset` argument; the argument is only used when both learner
    datasets are `None`.

    Nothing is built (and `self.aug_idxs` is left untouched) unless the chosen
    dataset has a `meta` mapping that contains `'<args.data_aug_meta_name>_meta'`.
    """
    if self.eval_dataset is not None:
        dataset = self.eval_dataset
    if self.train_dataset is not None:
        dataset = self.train_dataset

    aug_name = self.args.data_aug_meta_name
    meta_name = None if aug_name is None else f'{aug_name}_meta'

    # Bail out when there is no metadata to index.
    if dataset is None or dataset.meta is None:
        return
    if meta_name is None or meta_name not in dataset.meta:
        return

    # Pick the search backend requested by the learner arguments.
    if self.args.representation_search_type == 'BRUTEFORCE':
        self.aug_idxs = BruteForceSearch(n_bm=self.args.augmentation_num_beams)
    else:
        self.aug_idxs = IndexSearch(
            space=self.args.index_space,
            efc=self.args.index_efc,
            m=self.args.index_m,
            efs=self.args.index_efs,
            n_bm=self.args.augmentation_num_beams,
            n_threads=self.args.index_num_threads,
        )

    # Encode every metadata item and hand the representations to the index.
    meta_dset = MainXCDataset(dataset.meta[meta_name].meta_info)
    meta_loader = self.get_test_dataloader(meta_dset)
    # Representations are moved to CPU only for the IndexSearch backend.
    meta_rep = self.get_meta_representation(meta_loader, to_cpu=isinstance(self.aug_idxs, IndexSearch))
    self.aug_idxs.build(meta_rep)
        

In [None]:
@patch
def augmentation_output(
    self:XCLearner,
    model:nn.Module,
    inputs:Dict[str, Union[torch.Tensor, Any]],
    **kwargs
):
    """Fetch augmentation-metadata representations for a batch from `self.aug_idxs`.

    The prefix is `args.data_aug_prefix` when set, otherwise `args.data_aug_meta_name`.
    The batch entry `'<prefix>2data_idx'` is passed to `self.aug_idxs.get_items`, and
    the result is returned under the key `'<prefix>2data_meta_repr'`.

    `model` and `**kwargs` are accepted but not used here.

    Raises:
        ValueError: if `self.aug_idxs` is `None` (the index has not been built).
    """
    if self.aug_idxs is None:
        # The message names the real attribute (`aug_idxs`) so it can be searched for.
        raise ValueError('Augmentation index `aug_idxs` is not initialized.')

    # Preparing augmentation input
    data_aug_prefix = self.args.data_aug_meta_name if self.args.data_aug_prefix is None else self.args.data_aug_prefix
    rep = self.aug_idxs.get_items(inputs[f'{data_aug_prefix}2data_idx'])

    return {
        f'{data_aug_prefix}2data_meta_repr': rep,
    }
    

## Prediction

In [None]:
#| export
# Learner configuration. `output_dir` is also used further down to locate the
# checkpoint that is loaded for prediction.
args = XCLearningArguments(
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/72-radga-dr-ep-for-wikiseealso-1-0',
    logging_first_step=True,
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    representation_num_beams=200,
    representation_accumulation_steps=10,
    # Checkpointing and evaluation both happen every 5000 steps.
    save_strategy="steps",
    evaluation_strategy="steps",
    eval_steps=5000,
    save_steps=5000,
    save_total_limit=5,
    num_train_epochs=300,
    predict_with_representation=True,
    # Optimiser settings.
    adam_epsilon=1e-6,
    warmup_steps=100,
    weight_decay=0.01,
    learning_rate=2e-4,
    generation_num_beams=10,
    generation_length_penalty=1.5,
    predict_with_generation=True,
    representation_search_type='BRUTEFORCE',
    
    # Names of the model-output attributes used for each purpose.
    output_representation_attribute='data_fused_repr',
    label_representation_attribute='data_repr',
    metadata_representation_attribute='data_repr',
    data_augmentation_attribute='data_repr',
    representation_attribute='data_repr',
    clustering_representation_attribute='data_fused_repr',
    
    # Cluster-based batching options.
    group_by_cluster=True,
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    use_data_metadata_for_clustering=True,
    clustering_type='EXPO',
    minimum_cluster_size=2,
    maximum_cluster_size=1600,

    metric_for_best_model='P@1',
    load_best_model_at_end=True,
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    
    use_distributional_representation=False,
    use_encoder_parallel=True,
    max_grad_norm=None, 
    fp16=True,
    
    label_names=['cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask'],

    # Metadata pruning is switched off; the remaining prune_* values are kept for reference.
    prune_metadata=False,
    num_metadata_prune_warmup_epochs=10,
    num_metadata_prune_epochs=5,
    metadata_prune_batch_size=2048,
    prune_metadata_names=['cat_meta'],
    use_data_metadata_for_pruning=True,

    # Augmentation at prediction time uses the 'cat' metadata
    # (`_build_aug_index` looks up '<data_aug_meta_name>_meta', i.e. 'cat_meta').
    predict_with_augmentation=True,
    use_augmentation_index_representation=False,
    
    data_aug_meta_name='cat',
    augmentation_num_beams=3,
    data_aug_prefix='cat',
    use_label_metadata=False,
    
    data_meta_batch_size=2048,
    augment_metadata=False,
    num_metadata_augment_warmup_epochs=10,
    num_metadata_augment_epochs=5,
)

In [None]:
#| export
# Evaluation metric object passed to the learner as `compute_metrics`.
# Presumably `rep_pk` / `rep_rk` are the cut-offs reported as P@k / R@k — the
# recorded prediction output lists P@1,3,5,10 and R@10,100,200; confirm against PrecRecl.
metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

In [None]:
#| export
# Rebuild the output directory from the basename of `args.output_dir`, then
# resolve the checkpoint directory that `get_best_model` selects inside it.
output_dir = f"/home/scai/phd/aiz218323/scratch/outputs/{os.path.basename(args.output_dir)}"
mname = f'{output_dir}/{os.path.basename(get_best_model(output_dir))}'

In [None]:
#| export
# Global batch size: the larger of the per-device train/eval sizes times the
# number of visible CUDA devices.
bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()

# Load the RAD002 model from the selected checkpoint directory `mname`.
# Only the data side uses augmentation metadata ('cat2data'); the label-side
# and prediction-side metadata prefixes are disabled (None).
model = RAD002.from_pretrained(mname, num_batch_labels=5000, batch_size=bsz,
                               margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='cat2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               resize_length=5000, use_noise=False, noise_percent=0.5,
                               
                               meta_loss_weight=0.3, fusion_loss_weight=0.1, use_fusion_loss=False,  
                               
                               use_encoder_parallel=True)

Some weights of RAD002 were not initialized from the model checkpoint at /home/scai/phd/aiz218323/scratch/outputs/72-radga-dr-ep-for-wikiseealso-1-0/checkpoint-20 and are newly initialized: ['encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
#| export
# Assemble the learner from the model, arguments, train/test datasets,
# collator and metric defined above.
learn = XCLearner(
    model=model, 
    args=args,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
#| export
# Run prediction on the test split and print the raw metric dictionary.
# NOTE(review): the recorded metrics are close to zero. The model-loading output
# reports that the `encoder.dr_fused_head.*` weights were newly initialised, while
# `args.output_representation_attribute` is 'data_fused_repr' — presumably the
# predictions come from an untrained fused head; confirm the checkpoint is the intended one.
o = learn.predict(block.test.dset)
print(o.metrics)

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/411 [00:00<?, ?it/s]

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


  self._set_arrayXarray(i, j, x)


{'test_loss': 0.03902879357337952, 'test_P@1': 1.6899980283356335e-05, 'test_P@10': 1.126665352223756e-05, 'test_P@3': 1.1266653522237557e-05, 'test_P@5': 7.88665746556629e-06, 'test_N@1': 1.6899979527806863e-05, 'test_N@10': 2.7909240088774823e-05, 'test_N@3': 1.6435647921753116e-05, 'test_N@5': 1.5355926734628156e-05, 'test_PSP@1': 1.2167058785972205e-05, 'test_PSP@10': 3.336352037262817e-05, 'test_PSP@3': 1.4832848831197245e-05, 'test_PSP@5': 1.3942623317441168e-05, 'test_PSN@1': 1.2167058230261318e-05, 'test_PSN@10': 1.92723728105193e-05, 'test_PSN@3': 1.2745624189847149e-05, 'test_PSN@5': 1.1565142813196871e-05, 'test_R@200': 0.0008486172194166132, 'test_R@10': 4.87819272147357e-05, 'test_R@100': 0.00045826033674922886, 'test_runtime': 323.1363, 'test_samples_per_second': 549.35, 'test_steps_per_second': 0.344}


In [None]:
# Show the prediction metrics as a single-row table.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,0.0017,0.0011,0.0008,0.0011,0.0017,0.0016,0.0015,0.0028,0.0012,0.0015,0.0014,0.0033,0.0012,0.0013,0.0012,0.0019,0.0049,0.0458,0.0849,0.039,323.1363,549.35,0.344


## Super metadata

In [None]:
# Encode every item of the training split's 'cat_meta' metadata with the learner.
cat_meta = block.train.dset.meta['cat_meta']
dset = MainXCDataset(cat_meta.meta_info)
dataloader = learn.get_test_dataloader(dset)
rep = learn.get_meta_representation(dataloader)

  0%|          | 0/411 [00:00<?, ?it/s]

In [None]:
# Target number of clusters, and the per-cluster size (rounded up) needed to
# cover all rows of `rep` with that many clusters.
n_clusters = 64
cluster_size = math.ceil(rep.shape[0]/n_clusters)

In [None]:
# Cluster the metadata representations; the result is iterated below as a
# collection of index arrays, one per cluster.
clusters = BalancedClusters.proc(rep, cluster_size)

Updating clusters with size 10252
Tree depth = 6
doing random split
lengths: [328043, 328043]
remaining levels for GPU split=5
==> gpu splitting random clusters 0 to 2
 rank=1 => Total clusters 2	Avg. Cluster size                 164021.50	Time to split nodes on this level 1.22 sec
 rank=0 => Total clusters 2	Avg. Cluster size                 164021.50	Time to split nodes on this level 2.07 sec
  rank=1 => Total clusters 4	Avg. Cluster size                 82010.75	Time to split nodes on this level 0.02 secrank=0 => Total clusters 4	Avg. Cluster size                 82010.75	Time to split nodes on this level 0.02 sec

  rank=1 => Total clusters 8	Avg. Cluster size                 41005.38	Time to split nodes on this level 0.02 sec
rank=0 => Total clusters 8	Avg. Cluster size                 41005.38	Time to split nodes on this level 0.02 sec
  rank=1 => Total clusters 16	Avg. Cluster size                 20502.69	Time to split nodes on this level 0.02 secrank=0 => Total clusters 16	Avg

In [None]:
def get_super_clusters(data_meta:sparse.csr_matrix):
    """Collapse the metadata columns of `data_meta` into one column per cluster.

    Reads the module-level `clusters` (an iterable of column-index arrays). Entry
    `(i, c)` of the result is the number of stored entries that row `i` of
    `data_meta` has among the columns of cluster `c`.

    Returns a CSR matrix of shape `(data_meta.shape[0], len(clusters))`.
    """
    # One count vector per cluster, stacked as rows: (n_clusters, n_rows).
    per_cluster_counts = np.vstack([data_meta[:, members].getnnz(axis=1) for members in clusters])
    # Transpose so that rows line up with the rows of `data_meta`.
    return sparse.csr_matrix(per_cluster_counts.T)
    

In [None]:
# Representative item per cluster: the element at the middle position of each
# cluster's index array.
super_cluster_idx = [o[len(o)//2] for o in clusters]

# NOTE(review): `clusters` was computed from the 'cat_meta' representations, but
# the matrices and meta_info used here come from `lnk_meta` — confirm that the
# two metadata sets share the same indexing, or that `cat_meta` was intended.
super_data_meta = get_super_clusters(block.test.dset.meta.lnk_meta.data_meta)
super_lbl_meta = get_super_clusters(block.test.dset.meta.lnk_meta.lbl_meta)

# Keep, for every meta_info field, only the entries of the representative items.
super_meta_info = {k: [v[idx] for idx in super_cluster_idx] for k,v in block.train.dset.meta.lnk_meta.meta_info.items()}

In [None]:
# Option A: use the representative item's representation for each cluster.
# The next cell assigns the same key again (cluster mean), so this value is
# overwritten when both cells are run.
super_meta_info['meta_repr'] = rep[super_cluster_idx].tolist()

In [None]:
# Option B: use the mean representation of each cluster's members, stacked
# into one row per cluster.
super_meta_info['meta_repr'] = torch.cat([rep[o].mean(dim=0, keepdims=True) for o in clusters])

torch.Size([64, 768])

In [None]:
# Register the cluster-level ("super") metadata on the test split as 'sup_meta'.
block.test.dset.meta['sup_meta'] = MetaXCDataset('sup', super_data_meta, super_lbl_meta, super_meta_info)

## Meta-clustering based thresholding

In [None]:
from xclib.utils.sparse import retain_topk
from tqdm.auto import tqdm
import numpy as np

In [None]:
# Encode every item of the training split's 'cat_meta' metadata with the learner
# (same computation as in the "Super metadata" section).
cat_meta = block.train.dset.meta['cat_meta']
dset = MainXCDataset(cat_meta.meta_info)
dataloader = learn.get_test_dataloader(dset)
rep = learn.get_meta_representation(dataloader)

  0%|          | 0/411 [00:00<?, ?it/s]

In [None]:
# Target number of clusters, and the per-cluster size (rounded up) needed to
# cover all rows of `rep` with that many clusters.
n_clusters = 64
cluster_size = math.ceil(rep.shape[0]/n_clusters)

In [None]:
# Cluster the metadata representations; `threshold` below iterates over the
# result as a collection of column-index arrays.
clusters = BalancedClusters.proc(rep, cluster_size)

Updating clusters with size 10252
Tree depth = 6
doing random split
lengths: [328043, 328043]
remaining levels for GPU split=5
==> gpu splitting random clusters 0 to 2
 rank=0 => Total clusters 2	Avg. Cluster size                 164021.50	Time to split nodes on this level 0.26 sec
 rank=0 => Total clusters 4	Avg. Cluster size                 82010.75	Time to split nodes on this level 0.02 sec
 rank=0 => Total clusters 8	Avg. Cluster size                 41005.38	Time to split nodes on this level 0.02 sec
 rank=0 => Total clusters 16	Avg. Cluster size                 20502.69	Time to split nodes on this level 0.02 sec
 rank=0 => Total clusters 32	Avg. Cluster size                 10251.34	Time to split nodes on this level 0.03 sec
 rank=1 => Total clusters 2	Avg. Cluster size                 164021.50	Time to split nodes on this level 0.27 sec
 rank=1 => Total clusters 4	Avg. Cluster size                 82010.75	Time to split nodes on this level 0.02 sec
 rank=1 => Total clusters 8	Av

In [None]:
def threshold(data_meta:sparse.csr_matrix, k):
    """Apply `retain_topk` to `data_meta` separately within each cluster's columns.

    Reads the module-level `clusters` (an iterable of column-index arrays). For
    each cluster, the matching columns of `data_meta` are sliced out and passed to
    `retain_topk(..., k=k)` (presumably keeping the `k` largest entries per row —
    confirm against xclib). The per-cluster results are concatenated column-wise,
    so the output columns follow cluster order (`np.hstack(clusters)`), not the
    original column order.

    Returns:
        A CSR matrix with the same number of rows as `data_meta`.
    """
    thresholded_data_meta = [
        retain_topk(data_meta[:, o], k=k)
        for o in tqdm(clusters, total=len(clusters))
    ]
    # `sparse.hstack` returns COO by default for CSR blocks; COO cannot be
    # row-indexed, so ask for CSR to match the input and `get_super_clusters`.
    return sparse.hstack(thresholded_data_meta, format='csr')
    

In [None]:
# Keep one entry per row within each cluster (k=1). The same call is repeated
# in the next cell.
thresholded_data_meta = threshold(block.test.dset.meta.lnk_meta.data_meta, k=1)

  0%|          | 0/64 [00:00<?, ?it/s]

In [None]:
# Flattened item indices in cluster order; this matches the column order
# produced by `threshold`, which concatenates the per-cluster column slices.
cluster_idx = np.hstack(clusters)

# NOTE(review): `clusters` was computed from the 'cat_meta' representations, but
# the matrices and meta_info used here come from `lnk_meta` — confirm that the
# two metadata sets share the same indexing, or that `cat_meta` was intended.
thresholded_data_meta = threshold(block.test.dset.meta.lnk_meta.data_meta, k=1)
thresholded_lbl_meta = threshold(block.test.dset.meta.lnk_meta.lbl_meta, k=1)

# Reorder every meta_info field to the same cluster order as the columns.
threholded_meta_info = {k: [v[idx] for idx in cluster_idx] for k,v in block.train.dset.meta.lnk_meta.meta_info.items()}

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

In [None]:
# Register the thresholded metadata on the test split as 'sup_meta'. This
# replaces any 'sup_meta' set earlier, e.g. by the "Super metadata" section.
block.test.dset.meta['sup_meta'] = MetaXCDataset('sup', thresholded_data_meta, thresholded_lbl_meta, threholded_meta_info)