# RADGA LoRA training

In [None]:
#| default_exp 92-radga-dr-ep-for-wikiseealso-lora-1-0

In [None]:
# Reload edited library modules automatically so changes to xcai show up without a kernel restart.
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
# Regenerate the python modules from the `#| export` cells of the project's notebooks.
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import os,torch, torch.multiprocessing as mp, pickle
from xcai.basics import *
from xcai.models.radga_lora import RAD001
from xclib.utils.sparse import retain_topk

from transformers import DistilBertConfig,DistilBertModel

comet_ml is installed but `COMET_API_KEY` is not set.


In [None]:
# Interactive sessions only (this cell is not exported): switch Weights & Biases logging off.
os.environ.update(WANDB_MODE='disabled')

In [None]:
#| export
os.environ.update({
    # Restrict this process to GPUs 0 and 1.
    'CUDA_VISIBLE_DEVICES': '0,1',
    # NOTE(review): the project name refers to experiment 66 while this notebook is 92 -- confirm intended.
    'WANDB_PROJECT': 'xc-nlg_66-radga-dr-ep-for-wikiseealso-2',
})

In [None]:
# Project data directory. NOTE(review): not referenced by any other cell visible in this notebook.
data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'

In [None]:
#| export
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
# Pickled, preprocessed data block for WikiSeeAlso with linker metadata.
# os.path.join avoids the '//' the f-string produced (pkl_dir already ends with '/').
pkl_file = os.path.join(pkl_dir, 'processed', 'wikiseealso_data-linker_distilbert-base-uncased_rm_oak-linker.pkl')

In [None]:
#| export
# Deserialize the preprocessed data block (provides .train/.test datasets, .collator, .n_lbl).
# pickle can execute arbitrary code on load -- only open files you produced yourself.
with open(pkl_file, 'rb') as fh:
    block = pickle.load(fh)

In [None]:
#| export
# Keep only the top-k entries per row of the linker metadata matrix (k=5 for train, k=3 for
# test) and store the pruned matrix in both `data_meta` and `curr_data_meta`.
for split, k in ((block.train, 5), (block.test, 3)):
    lnk_meta = split.dset.meta.lnk_meta
    data_meta = retain_topk(lnk_meta.data_meta, k=k)
    lnk_meta.data_meta = lnk_meta.curr_data_meta = data_meta

## Training

In [None]:
#| export
# Smoke-test switch. The run recorded in this notebook used the quick values (train batch 10,
# evaluate/save every 10 steps); the full-run values (800 / 5000) previously survived only as
# inline comments. Set to False for a full training run.
QUICK_DEBUG_RUN = True
train_batch_size, eval_save_steps = (10, 10) if QUICK_DEBUG_RUN else (800, 5000)

args = XCLearningArguments(
    # --- outputs, checkpointing and evaluation schedule ---
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/92-radga-dr-ep-for-wikiseealso-lora-1-0',
    logging_first_step=True,
    save_strategy="steps",
    save_steps=eval_save_steps,
    save_total_limit=5,
    evaluation_strategy="steps",
    eval_steps=eval_save_steps,
    metric_for_best_model='P@1',
    load_best_model_at_end=True,

    # --- optimisation ---
    num_train_epochs=300,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=800,
    learning_rate=2e-4,
    warmup_steps=100,
    weight_decay=0.01,
    adam_epsilon=1e-6,
    max_grad_norm=None,
    fp16=True,

    # --- prediction ---
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    representation_num_beams=200,
    representation_accumulation_steps=10,
    predict_with_generation=True,
    generation_num_beams=10,
    generation_length_penalty=1.5,
    predict_with_augmentation=False,

    # --- which model outputs are used as representations ---
    output_representation_attribute='data_fused_repr',
    label_representation_attribute='data_repr',
    metadata_representation_attribute='data_repr',
    data_augmentation_attribute='data_repr',
    representation_attribute='data_fused_repr',
    clustering_representation_attribute='data_fused_repr',
    use_distributional_representation=False,
    use_encoder_parallel=True,

    # --- cluster-based batching ---
    group_by_cluster=True,
    clustering_type='EXPO',
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    use_data_metadata_for_clustering=True,
    minimum_cluster_size=2,
    maximum_cluster_size=1600,

    # --- targets / label columns ---
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    label_names=['lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask'],

    # --- metadata pruning (switched off) ---
    prune_metadata=False,
    prune_metadata_names=['cat_meta'],
    num_metadata_prune_warmup_epochs=10,
    num_metadata_prune_epochs=5,
    metadata_prune_batch_size=2048,
    use_data_metadata_for_pruning=True,

    # --- linker ('lnk') augmentation ---
    data_aug_meta_name='lnk',
    data_aug_prefix='lnk',
    augmentation_num_beams=3,
    use_augmentation_index_representation=True,
    use_label_metadata=False,
    data_meta_batch_size=2048,
    augment_metadata=False,
    num_metadata_augment_warmup_epochs=10,
    num_metadata_augment_epochs=5,
)

In [None]:
#| export
# Ranking metrics on the test split (P@k, N@k, PSP@k, PSN@k, R@k appear in the logged results).
# `prop` receives the training data-label matrix -- presumably for the propensity-scored metrics.
metric = PrecRecl(
    block.n_lbl,
    block.test.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [None]:
#| export
n_devices = torch.cuda.device_count()
bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size) * n_devices
# NOTE(review): `bsz` is not passed to RAD001 below (batch_size=100 is hard-coded), whereas the
# prediction section passes batch_size=bsz -- confirm which is intended.

base_model = DistilBertModel.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4')

model = RAD001(
    DistilBertConfig(),
    resize_length=5000,
    base_model=base_model,
    # LoRA settings
    lora_r=8,
    lora_alpha=32,
    # batching / loss settings
    batch_size=100,
    num_batch_labels=5000,
    margin=0.3,
    num_negatives=10,
    tau=0.1,
    apply_softmax=True,
    # metadata prefixes: only the linker ('lnk2data') augmentation prefix is set
    data_aug_meta_prefix='lnk2data',
    lbl2data_aug_meta_prefix=None,
    data_pred_meta_prefix=None,
    lbl2data_pred_meta_prefix=None,
    use_query_loss=True,
    # calibration loss
    calib_margin=0.05,
    calib_num_negatives=10,
    calib_tau=0.1,
    calib_apply_softmax=False,
    calib_loss_weight=0.1,
    use_calib_loss=True,
    # auxiliary losses: zero weight / fusion loss off
    meta_loss_weight=0.0,
    fusion_loss_weight=0.0,
    use_fusion_loss=False,
    use_encoder_parallel=False,
)

model.init_retrieval_head()
model.init_cross_head()

In [None]:
#| export
# Trainer: fit on the train split, evaluate on the test split, score with `metric`.
learn = XCLearner(model=model, args=args,
                  train_dataset=block.train.dset, eval_dataset=block.test.dset,
                  data_collator=block.collator, compute_metrics=metric)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Interactive training run (this cell is not exported; the exported script trains in the `__main__` cell).
learn.train()

Step,Training Loss,Validation Loss,P@1,P@10,P@3,P@5,N@1,N@10,N@3,N@5,Psp@1,Psp@10,Psp@3,Psp@5,Psn@1,Psn@10,Psn@3,Psn@5,R@200,R@10,R@100
10,0.0359,0.084405,0.146348,0.051153,0.098865,0.076378,0.146348,0.171785,0.148467,0.156732,0.111032,0.16792,0.123681,0.137515,0.111032,0.151191,0.124306,0.135325,0.428242,0.214766,0.376105


  0%|          | 0/196 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)


  0%|          | 0/196 [00:00<?, ?it/s]


KeyboardInterrupt



In [None]:
#| export
# Entry point of the exported script: train only when the module is run directly, not on import.
if __name__ == '__main__':
    # Only has an effect for frozen Windows executables; harmless elsewhere.
    mp.freeze_support()
    learn.train()
    

## Prediction

In [None]:
# Arguments for the prediction section (not exported).
# NOTE(review): output_dir points at experiment 80's output, not this notebook's (92) -- confirm.
args = XCLearningArguments(
    # --- outputs, checkpointing and evaluation schedule ---
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/80-radga-dr-ep-for-wikiseealso-1-0',
    logging_first_step=True,
    save_strategy="steps",
    save_steps=5000,
    save_total_limit=5,
    evaluation_strategy="steps",
    eval_steps=5000,
    metric_for_best_model='P@1',
    load_best_model_at_end=True,

    # --- optimisation ---
    num_train_epochs=300,
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    learning_rate=2e-4,
    warmup_steps=100,
    weight_decay=0.01,
    adam_epsilon=1e-6,
    max_grad_norm=None,
    fp16=True,

    # --- prediction ---
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    representation_num_beams=200,
    representation_accumulation_steps=10,
    predict_with_generation=True,
    generation_num_beams=10,
    generation_length_penalty=1.5,
    predict_with_augmentation=False,

    # --- which model outputs are used as representations ---
    output_representation_attribute='data_fused_repr',
    label_representation_attribute='data_repr',
    metadata_representation_attribute='data_repr',
    data_augmentation_attribute='data_repr',
    representation_attribute='data_fused_repr',
    clustering_representation_attribute='data_fused_repr',
    use_distributional_representation=False,
    use_encoder_parallel=True,

    # --- cluster-based batching ---
    group_by_cluster=True,
    clustering_type='EXPO',
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    use_data_metadata_for_clustering=True,
    minimum_cluster_size=2,
    maximum_cluster_size=1600,

    # --- targets / label columns ---
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    label_names=['lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask'],

    # --- metadata pruning (switched off) ---
    prune_metadata=False,
    prune_metadata_names=['cat_meta'],
    num_metadata_prune_warmup_epochs=10,
    num_metadata_prune_epochs=5,
    metadata_prune_batch_size=2048,
    use_data_metadata_for_pruning=True,

    # --- linker ('lnk') augmentation ---
    data_aug_meta_name='lnk',
    data_aug_prefix='lnk',
    augmentation_num_beams=3,
    use_augmentation_index_representation=True,
    use_label_metadata=False,
    data_meta_batch_size=2048,
    augment_metadata=False,
    num_metadata_augment_warmup_epochs=10,
    num_metadata_augment_epochs=5,
)

In [None]:
# Same ranking metrics as in the training section, rebuilt for the prediction run.
metric = PrecRecl(
    block.n_lbl,
    block.test.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [None]:
n_devices = torch.cuda.device_count()
bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size) * n_devices

# NOTE(review): this loads the public sentence-transformers checkpoint, so the cross_head,
# dr_head and meta_head weights are freshly initialised (per the warning printed by this cell)
# rather than restored from a training run -- confirm this is the model meant to be evaluated.
# RAD006 is not imported explicitly here; presumably it comes from `from xcai.basics import *`.
model = RAD006.from_pretrained(
    'sentence-transformers/msmarco-distilbert-base-v4',
    # batching / loss settings
    batch_size=bsz,
    num_batch_labels=5000,
    margin=0.3,
    num_negatives=10,
    tau=0.1,
    apply_softmax=True,
    # metadata prefixes: only the linker ('lnk2data') augmentation prefix is set
    data_aug_meta_prefix='lnk2data',
    lbl2data_aug_meta_prefix=None,
    data_pred_meta_prefix=None,
    lbl2data_pred_meta_prefix=None,
    resize_length=5000,
    # noise settings (use_noise=False)
    use_noise=False,
    shuffle_noise_pct=0.5,
    dropout_noise_pct=0.1,
    use_query_loss=True,
    # calibration loss (calib_margin is 0.3 here vs 0.05 in the training section)
    calib_margin=0.3,
    calib_num_negatives=10,
    calib_tau=0.1,
    calib_apply_softmax=False,
    calib_loss_weight=0.1,
    use_calib_loss=True,
    # auxiliary losses: zero weight / fusion loss off
    meta_loss_weight=0.0,
    fusion_loss_weight=0.0,
    use_fusion_loss=False,
    use_encoder_parallel=False,
)

Some weights of RAD006 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 'encoder.meta_head.transform.bias', 'encoder.meta_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Trainer wrapper for the prediction model (same datasets, collator and metric as before).
learn = XCLearner(model=model, args=args,
                  train_dataset=block.train.dset, eval_dataset=block.test.dset,
                  data_collator=block.collator, compute_metrics=metric)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
model.eval()


def _encode_split(dset, use_metadata, representation_attribute):
    """Encode the data side of `dset` and return the requested representation output."""
    data_only = learn._get_dataset(dset, dset_type='data', use_metadata=use_metadata)
    loader = learn.get_test_dataloader(data_only)
    return learn.get_representation(loader, representation_attribute=representation_attribute)


# Train points: 'data_repr', encoded without metadata.
train_data_repr = _encode_split(learn.train_dataset, False, 'data_repr')
# Test points: 'data_fused_repr', encoded with metadata.
test_data_repr = _encode_split(learn.eval_dataset, True, 'data_fused_repr')

  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/111 [00:00<?, ?it/s]

In [None]:
# Put both representation matrices on the default CUDA device for the similarity search below.
train_data_repr, test_data_repr = (r.to('cuda') for r in (train_data_repr, test_data_repr))

In [None]:
from scipy import sparse
from xcai.analysis import *

In [None]:
from torch.utils.data import DataLoader
from scipy import sparse
from tqdm.auto import tqdm


def topk_inner_product(queries, keys, k, batch_size=1000):
    """Exact top-k inner-product search of every row of `queries` against the rows of `keys`.

    Queries are processed `batch_size` rows at a time, so only a (batch_size, len(keys)) score
    block is materialised at once.

    Returns:
        (scores, indices): CPU tensors with one row per query and `k` columns; `indices` are
        row positions in `keys`, sorted by decreasing score (torch.topk defaults).
    """
    score_chunks, index_chunks = [], []
    with torch.no_grad():  # pure inference: never build an autograd graph
        # Tensor.split slices rows directly instead of a DataLoader re-collating them one by one.
        for chunk in tqdm(queries.split(batch_size)):
            chunk_scores, chunk_indices = torch.topk(chunk @ keys.T, k, dim=1)
            score_chunks.append(chunk_scores)
            index_chunks.append(chunk_indices)
    if not score_chunks:  # no queries: return empty results instead of crashing
        return (torch.empty((0, k), dtype=keys.dtype),
                torch.empty((0, k), dtype=torch.long))
    # Concatenate once; growing the result with torch.cat inside the loop is quadratic.
    return torch.cat(score_chunks).cpu(), torch.cat(index_chunks).cpu()


topk = 3
score, indices = topk_inner_product(test_data_repr, train_data_repr, topk, batch_size=1000)
# CSR row pointer: every test row holds exactly `topk` neighbours.
indptr = torch.arange(0, (score.shape[0] + 1) * topk, topk)

  0%|          | 0/178 [00:00<?, ?it/s]

In [None]:
# Sparse (n_test x n_train) matrix of the top-k neighbour scores of each test point.
# The shape is passed explicitly. Otherwise scipy infers the column count from the largest
# retrieved index, which drops trailing training columns whenever the last training points
# are never retrieved. That would break the later product with the training metadata matrix.
test_hlk = sparse.csr_matrix(
    (score.flatten().numpy(), indices.flatten().numpy(), indptr.numpy()),
    shape=(score.shape[0], train_data_repr.shape[0]),
)

In [None]:
# Persist the test->train neighbour matrix to the working directory for reuse.
fname = "test_hlk.pkl"
with open(fname, 'wb') as out_file:
    pickle.dump(test_hlk, out_file)

In [None]:
from xcai.data import *

# Pair test data with the test->train neighbour matrix and the training data, so the
# neighbours retrieved for a test point can be inspected as text.
neighbour_dset = MainXCDataset(block.test.dset.data.data_info, test_hlk, block.train.dset.data.data_info)
test_dset = TextColumns(neighbour_dset)
test_dset[2000]

{'data_input_text': 'Mathematical model',
 'lbl2data_input_text': ['Polyhedron model',
  'Data model',
  'Simulation modeling']}

In [None]:
# Propagate the training points' `cat_meta` rows to each test point, weighted by the neighbour
# scores. Use the result as the test split's linker metadata; this overwrites, in place, the
# top-3 pruned matrix assigned earlier in the notebook.
test_cat = test_hlk @ block.train.dset.meta.cat_meta.data_meta
for attr in ('data_meta', 'curr_data_meta'):
    setattr(block.test.dset.meta.lnk_meta, attr, test_cat)

In [None]:
# Run prediction on the test split (with the linker metadata set above) and print its metrics.
o = learn.predict(block.test.dset)
test_metrics = o.metrics
print(test_metrics)

  0%|          | 0/196 [00:00<?, ?it/s]

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


  self._set_arrayXarray(i, j, x)


{'test_loss': 0.10378105938434601, 'test_P@1': 0.16729290482494438, 'test_P@10': 0.05309973805035105, 'test_P@3': 0.10881897304453633, 'test_P@5': 0.08202799763404925, 'test_N@1': 0.16729290783405304, 'test_N@10': 0.18368692696094513, 'test_N@3': 0.16428664326667786, 'test_N@5': 0.17077775299549103, 'test_PSP@1': 0.1586184204459441, 'test_PSP@10': 0.1977837675423582, 'test_PSP@3': 0.1613118277559351, 'test_PSP@5': 0.17142875641498287, 'test_PSN@1': 0.15861842036247253, 'test_PSN@10': 0.19298529624938965, 'test_PSN@3': 0.16770175099372864, 'test_PSN@5': 0.17800629138946533, 'test_R@200': 0.3965625126598376, 'test_R@10': 0.22037828039708174, 'test_R@100': 0.35519597960181626, 'test_runtime': 295.0679, 'test_samples_per_second': 601.607, 'test_steps_per_second': 0.376}


In [None]:
# Text view of the predictions, keeping only the data, cat2data and lnk2data input-text columns.
# NOTE(review): `pred` is not defined by any cell visible in this notebook (the `learn.predict`
# result is bound to `o`), so this raises NameError on a fresh kernel. Presumably `pred` must be
# derived from `o` first -- TODO confirm what `get_pred_dset` expects.
pattern = r'^(data|cat2data|lnk2data)_input_text$'
dset = TextColumns(get_pred_dset(pred, block), pat=pattern)

In [None]:
# Display the first example of the prediction text view.
dset[0]