# NGAME training pipeline with multi-triplet loss and clustering

In [1]:
#| default_exp 26-oak-training-pipeline-with-multitriplet-loss-and-clustering-predictions

In [2]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [3]:
# Load IPython's autoreload extension in mode 2 so edited modules are
# re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [4]:
#| export
import os,torch, torch.multiprocessing as mp
from xcai.basics import *
from xcai.models.MMM00X import DBT013

In [5]:
# Disable Weights & Biases logging for this interactive session.
# This cell carries no `#| export` directive, so the exported module is unaffected.
os.environ['WANDB_MODE'] = 'disabled'

In [6]:
#| export
# Expose only GPU index 4 to CUDA for this process.
# NOTE(review): torch is imported in an earlier cell; this relies on CUDA not
# having been initialised before this assignment — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '4'
# W&B project name under which runs from this notebook are grouped.
os.environ['WANDB_PROJECT']='xc-nlg_26-oak-training-pipeline-with-multitriplet-loss-and-clustering-predictions'

## Load model and data

In [8]:
#| export
# Build the dataset block from the 'data' config under the scratch dataset
# directory, using the 'distilbert-base-uncased' tokenizer.
block = XCBlock.from_cfg(
    '/home/aiscuser/scratch/datasets',
    'data',
    valid_pct=0.001,
    tfm='xc',
    tokenizer='distilbert-base-uncased',
)

  self._set_arrayXarray(i, j, x)


In [9]:
#| export
# Learner arguments. `output_dir` is also the directory the checkpoint is
# loaded from when the model is built later in this notebook.
# NOTE(review): the directory name is prefixed "22-" while this notebook is
# "26-"; presumably checkpoints of an earlier experiment are reused — confirm.
args = XCLearningArguments(
    output_dir='/home/aiscuser/outputs/22-oak-training-pipeline-with-multitriplet-loss-and-clustering-2-4',
    logging_first_step=True,
    # Batch sizes (also used to derive `bsz` for the model below).
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    representation_num_beams=200,
    representation_accumulation_steps=100,
    # Evaluation and checkpointing share the same 100-step cadence.
    save_strategy="steps",
    evaluation_strategy='steps',
    eval_steps=100,
    save_steps=100,
    save_total_limit=5,
    num_train_epochs=50,
    predict_with_representation=True,
    # Optimiser settings.
    adam_epsilon=1e-8,
    warmup_steps=0,
    weight_decay=0.1,
    learning_rate=1e-4,
    # Generation-based prediction settings.
    generation_num_beams=10,
    generation_length_penalty=1.5,
    predict_with_generation=True,
    # Batch key read back as targets in `reconstruct_score_matrix`.
    label_names=['lbl2data_idx'],
    representation_search_type='BRUTEFORCE',
    # Clustering-related options; their exact semantics are defined by
    # XCLearningArguments and are not visible in this notebook.
    group_by_cluster=True,
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    clustering_type='EXPO',
    minimum_cluster_size=1,
    output_concatenation_weight=1.0,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [10]:
#| export
# Evaluate on a fixed, seeded sample of 2000 examples drawn from the test split.
test_dset = block.test.dset.sample(n=2000, seed=50)

# Metric object; it receives the sampled set's label filterer and the training
# data-label matrix as `prop`.
metric = PrecRecl(
    block.n_lbl,
    test_dset.data.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [12]:
!ls /home/aiscuser/outputs/22-oak-training-pipeline-with-multitriplet-loss-and-clustering-2-4/

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


checkpoint-14700  checkpoint-14900  checkpoint-15100
checkpoint-14800  checkpoint-15000


In [16]:
#| export
# Checkpoint to load, located inside the learner's output directory.
mname = f'{args.output_dir}/checkpoint-15000/'

# `bsz` handed to the model is twice the larger of the two per-device batch sizes.
bsz = 2 * max(args.per_device_train_batch_size, args.per_device_eval_batch_size)

model = DBT013.from_pretrained(
    mname,
    ig_tok=0,
    bsz=bsz,
    tn_targ=10_000,
    margin=0.3,
    tau=0.1,
    n_negatives=10,
    apply_softmax=True,
    lw=1.0,
)

In [17]:
#| export
# Build the trie from the data block; it is handed to XCLearner as `trie` below.
trie = XCTrie.from_block(block)

  0%|          | 0/312330 [00:00<?, ?it/s]

In [18]:
#| export
# Assemble the learner: the full training set for training, the 2000-example
# test sample for evaluation, and `metric` as the metric callback.
learn = XCLearner(
    model=model, 
    args=args,
    trie=trie,
    train_dataset=block.train.dset,
    eval_dataset=test_dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


## Prediction

In [19]:
def reconstruct_score_matrix(learner, test_dset, metric):
    """Run prediction on `test_dset` and build prediction/target structures.

    Predictions come from `learner.predict(test_dset)`; targets are read from
    the batches of `learner.get_test_dataloader(test_dset)`, so both sides are
    derived from the same dataset and the same learner that was passed in
    (no reliance on notebook globals).

    Args:
        learner: object exposing `predict(dataset)` and
            `get_test_dataloader(dataset)`.
        test_dset: dataset to predict on and to read targets from.
        metric: object exposing `get_pred(dict)` and `get_targ(dict)`.

    Returns:
        Tuple `(o, pred_sparse, targ_sparse)`: the raw prediction output, the
        result of `metric.get_pred`, and the result of `metric.get_targ`.
    """
    def _counts_to_ptr(counts):
        # Running sum with a leading 0. Presumably `counts` holds per-row entry
        # counts and the result is a CSR-style row pointer — TODO confirm.
        return torch.cat([torch.tensor([0]), counts.cumsum(dim=0)])

    o = learner.predict(test_dset)

    # Gather ground-truth label indices and their per-batch pointer values.
    targ_idx, targ_ptr = [], []
    for b in learner.get_test_dataloader(test_dset):
        b = b.to('cpu')
        targ_idx.append(b['lbl2data_idx'])
        targ_ptr.append(b['lbl2data_data2ptr'])
    targ_idx, targ_ptr = torch.cat(targ_idx), torch.cat(targ_ptr)

    pred = {
        'pred_idx': o.pred_idx,
        'pred_ptr': _counts_to_ptr(o.pred_ptr),
        'pred_score': o.pred_score,
    }
    pred_sparse = metric.get_pred(pred)

    targ = {'targ_idx': targ_idx, 'targ_ptr': _counts_to_ptr(targ_ptr)}
    targ_sparse = metric.get_targ(targ)

    return o, pred_sparse, targ_sparse


In [20]:
# Predict on the sampled test set; keep the raw output plus the structures
# returned by metric.get_pred / metric.get_targ.
o, pred_sparse, targ_sparse = reconstruct_score_matrix(learn, test_dset, metric)

  0%|          | 0/391 [00:00<?, ?it/s]

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


  self._set_arrayXarray(i, j, x)
  self._set_arrayXarray(i, j, x)


In [21]:
# Render the metrics attached to the prediction output.
display_metric(o.metrics)

Unnamed: 0,P@1,P@1_GEN,P@1_REPR,P@3,P@3_GEN,P@3_REPR,P@5,P@5_GEN,P@5_REPR,P@10,P@10_GEN,P@10_REPR,N@1,N@1_GEN,N@1_REPR,N@3,N@3_GEN,N@3_REPR,N@5,N@5_GEN,N@5_REPR,N@10,N@10_GEN,N@10_REPR,PSP@1,PSP@1_GEN,PSP@1_REPR,PSP@3,PSP@3_GEN,PSP@3_REPR,PSP@5,PSP@5_GEN,PSP@5_REPR,PSP@10,PSP@10_GEN,PSP@10_REPR,PSN@1,PSN@1_GEN,PSN@1_REPR,PSN@3,PSN@3_GEN,PSN@3_REPR,PSN@5,PSN@5_GEN,PSN@5_REPR,PSN@10,PSN@10_GEN,PSN@10_REPR,R@10,R@10_GEN,R@10_REPR,R@100,R@100_GEN,R@100_REPR,R@200,R@200_GEN,R@200_REPR,loss,runtime,samples_per_second,steps_per_second
0,18.5,8.4,25.3,12.5667,5.0833,15.7,9.92,3.97,11.67,7.03,2.635,7.45,18.5,8.4,25.3,18.2993,7.8842,23.5861,19.5174,8.3136,24.1841,22.0413,9.1373,25.7319,13.9373,5.1211,22.296,15.6936,5.0404,21.1397,17.9193,5.6698,22.1449,23.4207,6.9251,25.2607,13.9373,5.1211,22.296,15.6964,5.295,22.3046,17.4526,5.8074,23.4061,20.2287,6.5168,25.0937,28.0914,11.0066,29.2782,43.5747,11.0295,43.3408,47.1039,11.0295,46.9202,3.7838,159.7358,12.521,0.019
