# NGAME training pipeline with multi-triplet loss, clustering, and input augmentation

In [1]:
#| default_exp 24-oak-training-pipeline-with-multitriplet-loss-and-clustering-and-input-augmentation

In [2]:
#| hide
from nbdev.showdoc import *
import nbdev

# Run nbdev's notebook export from this (hidden) cell.
nbdev.nbdev_export()

In [3]:
# Load IPython's autoreload extension and select mode 2, so that imported
# modules are reloaded before code is executed (library edits are picked up
# without restarting the kernel).
%load_ext autoreload
%autoreload 2

In [14]:
#| export
import os,torch, torch.multiprocessing as mp
from xcai.basics import *
from xcai.models.MMM00X import DBT013
from xcai.transform import AugmentMetaInputIdsTfm
from xcai.generation.generate import XCTrieBeamSearch

In [5]:
# Notebook-only setting (this cell has no export directive): put
# Weights & Biases into 'disabled' mode for the interactive session.
os.environ["WANDB_MODE"] = "disabled"

In [7]:
#| export
# Name of the Weights & Biases project that runs from this pipeline are filed under.
os.environ["WANDB_PROJECT"] = "xc-nlg_24-oak-training-pipeline-with-multitriplet-loss-and-clustering-and-input-augmentation"

In [12]:
#| export
# Build the data block from the 'data_meta' configuration found under the
# scratch dataset directory, using the DistilBERT (uncased) tokenizer.
# NOTE(review): valid_pct=0.001 is presumably the fraction held out for
# validation — confirm against XCBlock.from_cfg.
block = XCBlock.from_cfg(
    '/home/aiscuser/scratch/datasets',
    'data_meta',
    valid_pct=0.001,
    tfm='xc',
    tokenizer='distilbert-base-uncased',
)

In [15]:
#| export
# Re-bind `block` to the output of the metadata input-augmentation transform.
# NOTE(review): the three positional arguments after `block` are passed
# exactly as before; their meaning is not visible in this notebook — confirm
# against AugmentMetaInputIdsTfm.apply.
block = AugmentMetaInputIdsTfm.apply(
    block,
    'hlk_meta',
    32,
    True,
)

In [26]:
#| export
# Point the training split's model inputs at the 'hlk'-augmented versions:
# the plain `input_ids` / `attention_mask` entries are overwritten with the
# `*_aug_hlk` entries of the same mapping.
train_info = block.train.dset.data.data_info
train_info['input_ids'] = train_info['input_ids_aug_hlk']
train_info['attention_mask'] = train_info['attention_mask_aug_hlk']

In [None]:
#| export
# Training configuration. The keyword arguments are grouped by topic; every
# value is the same as before, only the ordering and layout differ.
args = XCLearningArguments(
    # --- output, logging and checkpointing ---
    output_dir='/home/aiscuser/outputs/24-oak-training-pipeline-with-multitriplet-loss-and-clustering-and-input-augmentation',
    logging_first_step=True,
    save_strategy='steps',
    save_steps=1000,
    save_total_limit=5,

    # --- evaluation schedule ---
    evaluation_strategy='steps',
    eval_steps=1000,
    label_names=['lbl2data_idx'],

    # --- batching and optimisation ---
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    num_train_epochs=50,
    learning_rate=2e-4,
    warmup_steps=100,
    weight_decay=0.01,
    adam_epsilon=1e-6,

    # --- representation-based prediction ---
    predict_with_representation=True,
    representation_num_beams=200,
    representation_accumulation_steps=100,
    representation_search_type='BRUTEFORCE',

    # --- generation-based prediction ---
    predict_with_generation=True,
    generation_num_beams=10,
    generation_length_penalty=0.0,

    # --- cluster-grouped batching ---
    group_by_cluster=True,
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    clustering_type='EXPO',
    minimum_cluster_size=1,
)

In [10]:
#| export
# Evaluate on a fixed-seed sample of 100 items drawn from the test dataset.
test_dset = block.test.dset.sample(n=100, seed=50)

# Precision/recall metric over the block's label space, using the sampled
# test set's label filterer and the training labels as `prop`.
metric = PrecRecl(
    block.n_lbl,
    test_dset.data.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [11]:
#| export
# Batch size handed to the model: the larger of the per-device train/eval
# batch sizes, scaled by the number of CUDA devices.
# `torch.cuda.device_count()` is 0 on a CPU-only host, which would silently
# yield bsz == 0; clamp the device count to at least 1 so the value stays
# usable there. On GPU hosts the result is unchanged.
bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size) * max(1, torch.cuda.device_count())

# Initialise DBT013 from the pretrained DistilBERT (uncased) checkpoint; the
# remaining keyword arguments are model-specific hyper-parameters passed
# through unchanged.
model = DBT013.from_pretrained('distilbert-base-uncased', ig_tok=0, bsz=bsz, tn_targ=10_000, margin=0.3, tau=0.1,
                               n_negatives=50, apply_softmax=True, lw=4)

Some weights of DBT013 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['gen_lfn.o', 'repr_lfn.u', 'repr_lfn.v']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
#| export
# Trie beam search built from the data block.
# `n_bm` and `len_penalty` carry the same values as `generation_num_beams`
# and `generation_length_penalty` in the training arguments.
tbs = XCTrieBeamSearch.from_block(
    block,
    max_height=32,
    sos_id=101,
    eos_id=102,
    pad_token=0,
    n_bm=10,
    len_penalty=0.0,
)

In [12]:
#| export
# Assemble the learner: model and arguments, the train/eval datasets with the
# block's collator, then the trie generator and the metric callback.
learn = XCLearner(
    model=model, args=args,
    train_dataset=block.train.dset, eval_dataset=test_dset, data_collator=block.collator,
    trie_generator=tbs, compute_metrics=metric,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [None]:
#| export
# Script entry point: training starts only when the exported module is run
# directly, never on import.
if __name__ == "__main__":
    # Multiprocessing start-up hook, called first inside the main guard.
    mp.freeze_support()
    learn.train()