# RADGA LoRA training

In [None]:
#| default_exp 92-radga-dr-ep-for-wikiseealso-lora-1-0

In [None]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import os,torch, torch.multiprocessing as mp, pickle
from xcai.basics import *
from xcai.models.radga_lora import RAD001
from xclib.utils.sparse import retain_topk

from transformers import DistilBertConfig,DistilBertModel

comet_ml is installed but `COMET_API_KEY` is not set.


In [None]:
# Disable Weights & Biases logging for interactive runs (this cell is not exported).
os.environ.update(WANDB_MODE='disabled')

In [None]:
#| export
# Restrict the run to GPUs 0 and 1 and name the W&B project.
# These must be set before CUDA is first initialised in this process.
# NOTE(review): the project name refers to experiment 66 while this notebook is 92 — confirm intended.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',
    'WANDB_PROJECT': 'xc-nlg_66-radga-dr-ep-for-wikiseealso-2',
})

In [None]:
# XC_NLG project data directory.
# NOTE(review): `data_dir` is not referenced by any other visible cell (the data is
# loaded from the pickle under `pkl_dir`) — confirm it is still needed.
data_dir = (
    '/home/scai/phd/aiz218323/'
    'Projects/XC_NLG/data'
)

In [None]:
#| export
# Directory holding the preprocessed dataset pickles.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
# os.path.join avoids the doubled separator the f-string produced
# ('datasets//processed'), since pkl_dir already ends with '/'.
pkl_file = os.path.join(
    pkl_dir, 'processed', 'wikiseealso_data-linker_distilbert-base-uncased_rm_oak-linker.pkl'
)

In [None]:
#| export
# Load the preprocessed data block. Later cells use its train/test datasets,
# label count and collator.
# NOTE: pickle.load can execute arbitrary code — only load pickles this project produced.
with open(pkl_file, 'rb') as pkl_fh:
    block = pickle.load(pkl_fh)

In [None]:
#| export
def _retain_lnk_meta_topk(split, k):
    """Truncate one split's linked metadata to its top-`k` entries per row.

    Applies `retain_topk` to `split.dset.meta.lnk_meta.data_meta`, stores the result
    in both `data_meta` and `curr_data_meta`, and returns it.
    """
    truncated = retain_topk(split.dset.meta.lnk_meta.data_meta, k=k)
    split.dset.meta.lnk_meta.data_meta = truncated
    split.dset.meta.lnk_meta.curr_data_meta = truncated
    return truncated

# Train keeps the top 5 linked items per row, test keeps the top 3.
data_meta = _retain_lnk_meta_topk(block.train, k=5)
data_meta = _retain_lnk_meta_topk(block.test, k=3)

## Training

In [None]:
#| export
# Trainer configuration for this LoRA run.
# NOTE(review): per_device_train_batch_size, eval_steps and save_steps are set to small
# debugging values (previous values in the trailing comments) — restore them before a
# full training run.
args = XCLearningArguments(
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/92-radga-dr-ep-for-wikiseealso-lora-1-0',
    logging_first_step=True,
    per_device_train_batch_size=10, # debug value; was 800
    per_device_eval_batch_size=800,
    representation_num_beams=200,
    representation_accumulation_steps=10,
    save_strategy="steps",
    evaluation_strategy="steps",
    eval_steps=10, # debug value; was 5000
    save_steps=10, # debug value; was 5000
    # NOTE(review): re-running into the same output_dir reuses existing checkpoint-N
    # directories (see the "already exists and is non-empty" warning in the training output).
    save_total_limit=5,
    num_train_epochs=300,
    predict_with_representation=True,
    adam_epsilon=1e-6,
    warmup_steps=100,
    weight_decay=0.01,
    learning_rate=2e-4,
    generation_num_beams=10,
    generation_length_penalty=1.5,
    predict_with_generation=True,
    representation_search_type='BRUTEFORCE',
    
    # Names of the model-output attributes used for each representation role.
    output_representation_attribute='data_fused_repr',
    label_representation_attribute='data_repr',
    metadata_representation_attribute='data_repr',
    data_augmentation_attribute='data_repr',
    representation_attribute='data_fused_repr',
    clustering_representation_attribute='data_fused_repr',
    
    # Cluster-based grouping of training batches.
    group_by_cluster=True,
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    use_data_metadata_for_clustering=True,
    clustering_type='EXPO',
    minimum_cluster_size=2,
    maximum_cluster_size=1600,

    # Best checkpoint is selected on P@1 and reloaded at the end of training.
    metric_for_best_model='P@1',
    load_best_model_at_end=True,
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    
    use_distributional_representation=False,
    use_encoder_parallel=True,
    max_grad_norm=None, 
    fp16=True,
    
    # Batch keys treated as labels (presumably transformers TrainingArguments
    # `label_names` semantics — confirm for XCLearningArguments).
    label_names=['lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask'],
    
    # Metadata pruning is off; the other prune settings presumably only matter when
    # prune_metadata=True.
    # NOTE(review): 'cat_meta' does not match the `lnk_meta` metadata used elsewhere in
    # this notebook — confirm before enabling pruning.
    prune_metadata=False,
    num_metadata_prune_warmup_epochs=10,
    num_metadata_prune_epochs=5,
    metadata_prune_batch_size=2048,
    prune_metadata_names=['cat_meta'],
    use_data_metadata_for_pruning=True,

    predict_with_augmentation=False,
    use_augmentation_index_representation=True,
    
    # Data-side augmentation uses the 'lnk' metadata prefix.
    data_aug_meta_name='lnk',
    augmentation_num_beams=3,
    data_aug_prefix='lnk',
    use_label_metadata=False,
    
    # Metadata augmentation is off; the epoch settings below presumably only matter
    # when augment_metadata=True.
    data_meta_batch_size=2048,
    augment_metadata=False,
    num_metadata_augment_warmup_epochs=10,
    num_metadata_augment_epochs=5,
)

In [None]:
#| export
# Precision/recall metrics over the full test split. `prop` passes the training
# label matrix (presumably for the propensity-scored Psp/Psn metrics in the logs).
metric = PrecRecl(
    block.n_lbl,
    block.test.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [None]:
#| export
# Per-device batch size times the number of visible GPUs.
# NOTE(review): `bsz` is not passed to RAD001 below (batch_size=100 is hard-coded) — confirm intended.
per_device_bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)
bsz = per_device_bsz * torch.cuda.device_count()

# Pretrained DistilBERT backbone; RAD001 wraps it with LoRA adapters (lora_r / lora_alpha).
base_model = DistilBertModel.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4')

model = RAD001(
    DistilBertConfig(),
    resize_length=5000,
    base_model=base_model,
    lora_r=8,
    lora_alpha=32,
    # contrastive-loss settings
    batch_size=100,
    num_batch_labels=5000,
    margin=0.3,
    num_negatives=10,
    tau=0.1,
    apply_softmax=True,
    # only the data side is augmented, with 'lnk2data' metadata
    data_aug_meta_prefix='lnk2data',
    lbl2data_aug_meta_prefix=None,
    data_pred_meta_prefix=None,
    lbl2data_pred_meta_prefix=None,
    use_query_loss=True,
    # calibration loss
    calib_margin=0.05,
    calib_num_negatives=10,
    calib_tau=0.1,
    calib_apply_softmax=False,
    calib_loss_weight=0.1,
    use_calib_loss=True,
    # meta and fusion losses are zero-weighted / disabled
    meta_loss_weight=0.0,
    fusion_loss_weight=0.0,
    use_fusion_loss=False,
    # NOTE(review): args.use_encoder_parallel is True but the model gets False — confirm intended.
    use_encoder_parallel=False,
)

# Initialise the heads in the same order as before: retrieval first, then cross.
model.init_retrieval_head()
model.init_cross_head()

In [None]:
# Debug evaluation: score a 100-example sample of the test split.
# `test_dset` must be created before `metric`, which reads its label filterer;
# referencing it first raises NameError on a fresh kernel.
# NOTE(review): if `.sample` is random, fix a seed so evaluation runs are comparable.
test_dset = block.test.dset.sample(n=100)
metric = PrecRecl(block.n_lbl, test_dset.data.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

In [None]:
#| export
# `test_dset` (the small debug evaluation sample) is defined in a non-exported cell,
# so the nbdev-exported module would raise NameError on it. When it is absent, use
# the full test split. The exported `metric` is built from `block.test`, so the
# dataset and metric still match.
eval_dset = globals().get('test_dset', block.test.dset)

learn = XCLearner(
    model=model,
    args=args,
    train_dataset=block.train.dset,
    eval_dataset=eval_dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
def func(debug: bool = True):
    """Run `learn.train()`, optionally stopping in pdb first.

    Args:
        debug: when True (the default, matching the previous behaviour), enter the
            debugger before training so breakpoints can be set on the model.
            Pass False to train without stopping.

    Returns:
        Whatever `learn.train()` returns.
    """
    if debug:
        import pdb; pdb.set_trace()
    return learn.train()
    

In [None]:
func()  # stops in pdb before training starts; enter 'c' to continue

> /tmp/ipykernel_16160/199652302.py(3)func()
      1 def func():
      2     import pdb; pdb.set_trace()
----> 3     return learn.train()
      4 



ipdb>  c


Step,Training Loss,Validation Loss,P@1,P@10,P@3,P@5,N@1,N@10,N@3,N@5,Psp@1,Psp@10,Psp@3,Psp@5,Psn@1,Psn@10,Psn@3,Psn@5,R@200,R@10,R@100
10,0.0367,0.046844,0.1,0.048,0.093333,0.078,0.1,0.157398,0.128247,0.146054,0.062985,0.135187,0.09813,0.119156,0.062985,0.118691,0.088369,0.105291,0.433421,0.21198,0.38377
20,0.0367,0.040589,0.1,0.049,0.096667,0.076,0.1,0.161122,0.132705,0.146201,0.055039,0.137724,0.101172,0.116275,0.055039,0.117239,0.087297,0.102091,0.409948,0.223702,0.370837


  0%|          | 0/196 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)
Checkpoint destination directory /home/scai/phd/aiz218323/scratch/outputs/92-radga-dr-ep-for-wikiseealso-lora-1-0/checkpoint-10 already exists and is non-empty. Saving will proceed but saved results may be invalid.


  0%|          | 0/196 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)



Program interrupted. (Use 'cont' to resume).
--Return--
None
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/autograd/__init__.py(266)backward()
    264     # some Python versions print out the first line of a multi-line function
    265     # calls in the traceback and some print out the last line
--> 266     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    267         tensors,
    268         grad_tensors_,



ipdb>  n


--Return--
None
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/_tensor.py(522)backward()
    520                 inputs=inputs,
    521             )
--> 522         torch.autograd.backward(
    523             self, gradient, retain_graph, create_graph, inputs=inputs
    524         )



ipdb>  r


--Return--
None
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/accelerator.py(1999)backward()
   1997             return
   1998         elif self.scaler is not None:
-> 1999             self.scaler.scale(loss).backward(**kwargs)
   2000         else:
   2001             loss.backward(**kwargs)



ipdb>  r


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/comet_ml/monkey_patching.py(294)wrapper()
    292 
    293         # Call after callbacks once we have the return value
--> 294         if should_run:
    295             for callback in after_callbacks:
    296                 callback_allows_exception = getattr(



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/comet_ml/monkey_patching.py(316)wrapper()
    314                     )
    315 
--> 316         if exception_raised is not None:
    317             raise exception_raised
    318 



ipdb>  r


--Return--
None
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/comet_ml/monkey_patching.py(319)wrapper()
    317             raise exception_raised
    318 
--> 319         return return_value
    320 
    321     # Simulate functools.wraps behavior but make it working with mocks



ipdb>  n


> /home/scai/phd/aiz218323/.local/lib/python3.9/site-packages/transformers/trainer.py(2913)training_step()
   2911             self.accelerator.backward(loss)
   2912 
-> 2913         return loss.detach() / self.args.gradient_accumulation_steps
   2914 
   2915     def compute_loss(self, model, inputs, return_outputs=False):



ipdb>  self.model.forward


<bound method convert_outputs_to_fp32.<locals>.forward of RAD001(
  (encoder): Encoder(
    (distilbert): PeftModel(
      (base_model): LoraModel(
        (model): DistilBertModel(
          (embeddings): Embeddings(
            (word_embeddings): Embedding(30522, 768, padding_idx=0)
            (position_embeddings): Embedding(512, 768)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (transformer): Transformer(
            (layer): ModuleList(
              (0-5): 6 x TransformerBlock(
                (attention): MultiHeadSelfAttention(
                  (dropout): Dropout(p=0.1, inplace=False)
                  (q_lin): lora.Linear(
                    (base_layer): Linear(in_features=768, out_features=768, bias=True)
                    (lora_dropout): ModuleDict(
                      (lbl2data): Dropout(p=0.05, inplace=False)
                      (lnk2data): Dropout(p=0.05,

ipdb>  b self.model.forward


Breakpoint 3 at /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/utils/operations.py:821


ipdb>  b self.model.encoder.forward


Breakpoint 4 at /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py:352


ipdb>  c


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/utils/operations.py(822)forward()
    820 
3   821     def forward(*args, **kwargs):
--> 822         return model_forward(*args, **kwargs)
    823 
    824     # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`



ipdb>  s


--Call--
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/utils/operations.py(809)__call__()
    807         update_wrapper(self, model_forward)
    808 
--> 809     def __call__(self, *args, **kwargs):
    810         return convert_to_fp32(self.model_forward(*args, **kwargs))
    811 



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/utils/operations.py(810)__call__()
    808 
    809     def __call__(self, *args, **kwargs):
--> 810         return convert_to_fp32(self.model_forward(*args, **kwargs))
    811 
    812     def __getstate__(self):



ipdb>  s


--Call--
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/amp/autocast_mode.py(13)decorate_autocast()
     11 
     12 def autocast_decorator(autocast_instance, func):
---> 13     @functools.wraps(func)
     14     def decorate_autocast(*args, **kwargs):
     15         with autocast_instance:



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/amp/autocast_mode.py(15)decorate_autocast()
     13     @functools.wraps(func)
     14     def decorate_autocast(*args, **kwargs):
---> 15         with autocast_instance:
     16             return func(*args, **kwargs)
     17 



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/torch/amp/autocast_mode.py(16)decorate_autocast()
     14     def decorate_autocast(*args, **kwargs):
     15         with autocast_instance:
---> 16             return func(*args, **kwargs)
     17 
     18     decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode"  # type: ignore[attr-defined]



ipdb>  s


--Call--
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/utils/operations.py(821)forward()
    819     model_forward = ConvertOutputsToFp32(model_forward)
    820 
3-> 821     def forward(*args, **kwargs):
    822         return model_forward(*args, **kwargs)
    823 



ipdb>  n


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/utils/operations.py(822)forward()
    820 
3   821     def forward(*args, **kwargs):
--> 822         return model_forward(*args, **kwargs)
    823 
    824     # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`



ipdb>  s


--Call--
> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/utils/operations.py(809)__call__()
    807         update_wrapper(self, model_forward)
    808 
--> 809     def __call__(self, *args, **kwargs):
    810         return convert_to_fp32(self.model_forward(*args, **kwargs))
    811 



ipdb>  c


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/utils/operations.py(822)forward()
    820 
3   821     def forward(*args, **kwargs):
--> 822         return model_forward(*args, **kwargs)
    823 
    824     # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`



ipdb>  c


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(361)forward()
    359         **kwargs
    360     ):
--> 361         data_o = self.encode(data_input_ids, data_attention_mask)
    362 
    363         if data_type is not None and data_type == "meta":



ipdb>  xx = [n for n,p in self.named_parameters() if p.requires_grad]
ipdb>  len(xx)


198


ipdb>  self.distilbert.active_adapters


['lbl2data']


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(363)forward()
    361         data_o = self.encode(data_input_ids, data_attention_mask)
    362 
--> 363         if data_type is not None and data_type == "meta":
    364             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    365         else:



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(366)forward()
    364             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    365         else:
--> 366             data_repr = self.dr(data_o[0], data_attention_mask)
    367 
    368         data_fused_repr = meta_repr = None



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(368)forward()
    366             data_repr = self.dr(data_o[0], data_attention_mask)
    367 
--> 368         data_fused_repr = meta_repr = None
    369         if data_aug_meta_prefix is not None:
    370             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(369)forward()
    367 
    368         data_fused_repr = meta_repr = None
--> 369         if data_aug_meta_prefix is not None:
    370             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    371             if len(meta_kwargs):



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(370)forward()
    368         data_fused_repr = meta_repr = None
    369         if data_aug_meta_prefix is not None:
--> 370             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    371             if len(meta_kwargs):
    372                 if self.training: self._mark_only_adapters_as_trainable()



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(371)forward()
    369         if data_aug_meta_prefix is not None:
    370             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
--> 371             if len(meta_kwargs):
    372                 if self.training: self._mark_only_adapters_as_trainable()
    373                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(372)forward()
    370             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    371             if len(meta_kwargs):
--> 372                 if self.training: self._mark_only_adapters_as_trainable()
    373                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    374                                                                              data_attention_mask,



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(373)forward()
    371             if len(meta_kwargs):
    372                 if self.training: self._mark_only_adapters_as_trainable()
--> 373                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    374                                                                              data_attention_mask,
    375                                                                              meta_kwargs)



ipdb>  xx = [n for n,p in self.named_parameters() if p.requires_grad]
ipdb>  len(xx)


98


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(374)forward()
    372                 if self.training: self._mark_only_adapters_as_trainable()
    373                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
--> 374                                                                              data_attention_mask,
    375                                                                              meta_kwargs)
    376                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(375)forward()
    373                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    374                                                                              data_attention_mask,
--> 375                                                                              meta_kwargs)
    376                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    377 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(373)forward()
    371             if len(meta_kwargs):
    372                 if self.training: self._mark_only_adapters_as_trainable()
--> 373                 data_fused_embed, meta_repr = self.fuse_meta_into_embeddings(data_o[0], 
    374                                                                              data_attention_mask,
    375                                                                              meta_kwargs)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(376)forward()
    374                                                                              data_attention_mask,
    375                                                                              meta_kwargs)
--> 376                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    377 
    378                 self.distilbert.set_adapter('lbl2data')



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(378)forward()
    376                 data_fused_repr = self.dr_fused(data_fused_embed, data_attention_mask)
    377 
--> 378                 self.distilbert.set_adapter('lbl2data')
    379                 if self.training: self._mark_entire_encoder_as_trainable()
    380 



ipdb>  self.distilbert.active_adapters


['lnk2data']


ipdb>  xx = [n for n,p in self.named_parameters() if p.requires_grad]
ipdb>  len(xx)


62


ipdb>  aa = [o for o in xx if 'lnk2data' in o]
ipdb>  len(aa)


36


ipdb>  aa = [o for o in xx if 'lbl2data' in o]
ipdb>  len(aa)


0


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(379)forward()
    377 
    378                 self.distilbert.set_adapter('lbl2data')
--> 379                 if self.training: self._mark_entire_encoder_as_trainable()
    380 
    381         return EncoderOutput(



ipdb>  self.distilbert.active_adapters


['lbl2data']


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(381)forward()
    379                 if self.training: self._mark_entire_encoder_as_trainable()
    380 
--> 381         return EncoderOutput(
    382             rep=data_repr,
    383             fused_rep=data_fused_repr,



ipdb>  xx = [n for n,p in self.named_parameters() if p.requires_grad]
ipdb>  len(xx)


198


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(382)forward()
    380 
    381         return EncoderOutput(
--> 382             rep=data_repr,
    383             fused_rep=data_fused_repr,
    384             meta_repr=meta_repr,



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(383)forward()
    381         return EncoderOutput(
    382             rep=data_repr,
--> 383             fused_rep=data_fused_repr,
    384             meta_repr=meta_repr,
    385         )



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(384)forward()
    382             rep=data_repr,
    383             fused_rep=data_fused_repr,
--> 384             meta_repr=meta_repr,
    385         )
    386 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(381)forward()
    379                 if self.training: self._mark_entire_encoder_as_trainable()
    380 
--> 381         return EncoderOutput(
    382             rep=data_repr,
    383             fused_rep=data_fused_repr,



ipdb>  


--Return--
EncoderOutput...vBackward0>)})
> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(381)forward()
    379                 if self.training: self._mark_entire_encoder_as_trainable()
    380 
--> 381         return EncoderOutput(
    382             rep=data_repr,
    383             fused_rep=data_fused_repr,



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(565)forward()
    563 
    564 
--> 565         loss = None; lbl2data_o = EncoderOutput()
    566         if lbl2data_input_ids is not None:
    567             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(566)forward()
    564 
    565         loss = None; lbl2data_o = EncoderOutput()
--> 566         if lbl2data_input_ids is not None:
    567             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
    568             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(567)forward()
    565         loss = None; lbl2data_o = EncoderOutput()
    566         if lbl2data_input_ids is not None:
--> 567             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
    568             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
    569                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(568)forward()
    566         if lbl2data_input_ids is not None:
    567             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
--> 568             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
    569                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    570 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(569)forward()
    567             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
    568             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
--> 569                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    570 
    571             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(568)forward()
    566         if lbl2data_input_ids is not None:
    567             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
--> 568             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
    569                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    570 



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(569)forward()
    567             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
    568             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
--> 569                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    570 
    571             loss = self.compute_loss(data_o.fused_rep, lbl2data_o.rep,lbl2data_data2ptr,lbl2data_idx,



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(568)forward()
    566         if lbl2data_input_ids is not None:
    567             lbl2data_meta_kwargs = Parameters.from_feat_meta_aug_prefix('lbl2data', self.lbl2data_aug_meta_prefix, **kwargs)
--> 568             lbl2data_o = encoder(data_input_ids=lbl2data_input_ids, data_attention_mask=lbl2data_attention_mask, 
    569                                  data_aug_meta_prefix=self.lbl2data_aug_meta_prefix, **lbl2data_meta_kwargs)
    570 



ipdb>  c


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(361)forward()
    359         **kwargs
    360     ):
--> 361         data_o = self.encode(data_input_ids, data_attention_mask)
    362 
    363         if data_type is not None and data_type == "meta":



ipdb>  xx = [n for n,p in self.named_parameters() if p.requires_grad]
ipdb>  len(xx)


198


ipdb>  aa = [o for o in xx if 'lnk2data' in o]
ipdb>  len(aa)


36


ipdb>  aa = [o for o in xx if 'lbl2data' in o]
ipdb>  len(aa)


36


ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(363)forward()
    361         data_o = self.encode(data_input_ids, data_attention_mask)
    362 
--> 363         if data_type is not None and data_type == "meta":
    364             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    365         else:



ipdb>  n


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(366)forward()
    364             data_repr = self.meta_unnormalized(data_o[0], data_attention_mask) if data_unnormalized else self.meta(data_o[0], data_attention_mask)
    365         else:
--> 366             data_repr = self.dr(data_o[0], data_attention_mask)
    367 
    368         data_fused_repr = meta_repr = None



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(368)forward()
    366             data_repr = self.dr(data_o[0], data_attention_mask)
    367 
--> 368         data_fused_repr = meta_repr = None
    369         if data_aug_meta_prefix is not None:
    370             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(369)forward()
    367 
    368         data_fused_repr = meta_repr = None
--> 369         if data_aug_meta_prefix is not None:
    370             meta_kwargs = Parameters.from_meta_aug_prefix(data_aug_meta_prefix, **kwargs)
    371             if len(meta_kwargs):



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(381)forward()
    379                 if self.training: self._mark_entire_encoder_as_trainable()
    380 
--> 381         return EncoderOutput(
    382             rep=data_repr,
    383             fused_rep=data_fused_repr,



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(382)forward()
    380 
    381         return EncoderOutput(
--> 382             rep=data_repr,
    383             fused_rep=data_fused_repr,
    384             meta_repr=meta_repr,



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/radga_lora.py(383)forward()
    381         return EncoderOutput(
    382             rep=data_repr,
--> 383             fused_rep=data_fused_repr,
    384             meta_repr=meta_repr,
    385         )



ipdb>  c


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg/lib/python3.9/site-packages/accelerate/utils/operations.py(822)forward()
    820 
3   821     def forward(*args, **kwargs):
--> 822         return model_forward(*args, **kwargs)
    823 
    824     # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`



ipdb>  q


xx = [n for n,p in self.named_parameters() if p.requires_grad]

aa = [o for o in xx if 'lnk2data' in o]

In [None]:
# Train without entering the debugger.
# NOTE(review): on Restart & Run All this notebook calls learn.train() three times
# (via func() above, this cell, and the __main__ cell below).
learn.train()

Step,Training Loss,Validation Loss


  0%|          | 0/196 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)

KeyboardInterrupt



In [None]:
#| export
# Entry point for the nbdev-exported script.
# NOTE: inside Jupyter `__name__` is also '__main__', so running this cell in the
# notebook starts training as well.
if __name__ == '__main__':
    mp.freeze_support()  # only has an effect in frozen Windows executables
    learn.train()
    

## Prediction

In [None]:
# Arguments for the prediction section. This is a copy of the training `args` with
# full-size batch/eval/save values.
# NOTE(review): output_dir points at experiment 80, not this notebook's 92 directory —
# confirm this is intended before writing any outputs there.
args = XCLearningArguments(
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/80-radga-dr-ep-for-wikiseealso-1-0',
    logging_first_step=True,
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    representation_num_beams=200,
    representation_accumulation_steps=10,
    save_strategy="steps",
    evaluation_strategy="steps",
    eval_steps=5000,
    save_steps=5000,
    save_total_limit=5,
    num_train_epochs=300,
    predict_with_representation=True,
    adam_epsilon=1e-6,
    warmup_steps=100,
    weight_decay=0.01,
    learning_rate=2e-4,
    generation_num_beams=10,
    generation_length_penalty=1.5,
    predict_with_generation=True,
    representation_search_type='BRUTEFORCE',
    
    # Names of the model-output attributes used for each representation role.
    output_representation_attribute='data_fused_repr',
    label_representation_attribute='data_repr',
    metadata_representation_attribute='data_repr',
    data_augmentation_attribute='data_repr',
    representation_attribute='data_fused_repr',
    clustering_representation_attribute='data_fused_repr',
    
    group_by_cluster=True,
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    use_data_metadata_for_clustering=True,
    clustering_type='EXPO',
    minimum_cluster_size=2,
    maximum_cluster_size=1600,

    metric_for_best_model='P@1',
    load_best_model_at_end=True,
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    
    use_distributional_representation=False,
    use_encoder_parallel=True,
    max_grad_norm=None, 
    fp16=True,
    
    label_names=['lnk2data_idx', 'lnk2data_input_ids', 'lnk2data_attention_mask'],

    # Metadata pruning is off (same settings as the training args).
    prune_metadata=False,
    num_metadata_prune_warmup_epochs=10,
    num_metadata_prune_epochs=5,
    metadata_prune_batch_size=2048,
    prune_metadata_names=['cat_meta'],
    use_data_metadata_for_pruning=True,

    predict_with_augmentation=False,
    use_augmentation_index_representation=True,
    
    data_aug_meta_name='lnk',
    augmentation_num_beams=3,
    data_aug_prefix='lnk',
    use_label_metadata=False,
    
    # Metadata augmentation is off (same settings as the training args).
    data_meta_batch_size=2048,
    augment_metadata=False,
    num_metadata_augment_warmup_epochs=10,
    num_metadata_augment_epochs=5,
)

In [None]:
# Precision/recall evaluator over the label space. Test predictions go through the
# test split's data-label filterer; `prop` is the train data-label matrix.
# pk/rk are the precision/recall cutoffs, and rep_pk/rep_rk the cutoffs reported
# for representation-based predictions.
rep_precision_cutoffs = [1, 3, 5, 10]
rep_recall_cutoffs = [10, 100, 200]
metric = PrecRecl(
    block.n_lbl,
    block.test.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=rep_precision_cutoffs,
    rep_rk=rep_recall_cutoffs,
)

In [None]:
# Global batch size: the larger of the train/eval per-device batch sizes times the
# number of visible GPUs. torch.cuda.device_count() returns 0 on a CPU-only host,
# so clamp it to 1 to keep bsz from collapsing to 0 there.
n_devices = max(torch.cuda.device_count(), 1)
bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*n_devices

# NOTE(review): the model is built from the hub checkpoint, not from this
# experiment's trained weights -- the load warning lists the cross/dr/meta heads
# as newly initialized. The metrics reported by the prediction cell therefore
# evaluate untrained heads unless a trained checkpoint is loaded. Confirm this is
# intended.
# NOTE(review): only RAD001 is imported explicitly at the top of the notebook;
# RAD006 presumably comes from `from xcai.basics import *` -- verify.
model = RAD006.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=bsz, num_batch_labels=5000, 
                               margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,

                               resize_length=5000, use_noise=False, shuffle_noise_pct=0.5, dropout_noise_pct=0.1,
                               
                               use_query_loss=True,

                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, calib_loss_weight=0.1,
                               use_calib_loss=True,
                               
                               meta_loss_weight=0.0, fusion_loss_weight=0.0, use_fusion_loss=False,
                               use_encoder_parallel=False)

Some weights of RAD006 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 'encoder.meta_head.transform.bias', 'encoder.meta_head.transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Bundle the model, the prediction arguments, both dataset splits, the collator and
# the metric into the learner used for representation extraction and prediction.
learner_kwargs = dict(
    model=model,
    args=args,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    data_collator=block.collator,
    compute_metrics=metric,
)
learn = XCLearner(**learner_kwargs)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
model.eval()

def _data_representation(dataset, use_metadata, representation_attribute):
    """Encode the data points of `dataset` and return one representation per point.

    Args:
        dataset: dataset to encode (e.g. `learn.train_dataset` or `learn.eval_dataset`).
        use_metadata: forwarded to `learn._get_dataset` for the data-only view.
        representation_attribute: name of the encoder output to collect.

    Returns:
        Whatever `learn.get_representation` returns for that attribute.
    """
    data_dset = learn._get_dataset(dataset, dset_type='data', use_metadata=use_metadata)
    dataloader = learn.get_test_dataloader(data_dset)
    return learn.get_representation(dataloader, representation_attribute=representation_attribute)

# NOTE(review): train points are encoded without metadata ('data_repr') while
# test points use metadata ('data_fused_repr'), and the two are compared directly
# in the top-k search below -- confirm the spaces are meant to be compatible.
train_data_repr = _data_representation(learn.train_dataset, use_metadata=False, representation_attribute='data_repr')
test_data_repr = _data_representation(learn.eval_dataset, use_metadata=True, representation_attribute='data_fused_repr')

  0%|          | 0/434 [00:00<?, ?it/s]

  0%|          | 0/111 [00:00<?, ?it/s]

In [None]:
# Put both representation matrices on the same device for the similarity search.
# Fall back to CPU so the cell does not crash on a machine without a GPU.
repr_device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_data_repr = train_data_repr.to(repr_device)
test_data_repr = test_data_repr.to(repr_device)

In [None]:
from scipy import sparse
from xcai.analysis import *

In [None]:
from torch.utils.data import DataLoader
from scipy import sparse
from tqdm.auto import tqdm

# Exact (brute-force) top-k search: for every test point, find the `topk` train
# points with the largest inner product between their representations.
# Queries are scored in chunks of 1000 rows, so only a (1000 x n_train) score
# block exists at any time.
topk = min(3, train_data_repr.shape[0])  # torch.topk raises if k exceeds the number of train points
score_chunks, index_chunks = [], []
with torch.no_grad():
    for x in tqdm(torch.split(test_data_repr, 1000), desc='top-k search'):
        sc, idx = torch.topk(x @ train_data_repr.T, topk, dim=1)
        score_chunks.append(sc)
        index_chunks.append(idx)

# Concatenate once at the end; concatenating inside the loop re-copies all
# accumulated results on every iteration.
score, indices = torch.cat(score_chunks).cpu(), torch.cat(index_chunks).cpu()
# CSR row pointer: every test row holds exactly `topk` entries.
indptr = torch.arange(score.shape[0] + 1) * topk

  0%|          | 0/178 [00:00<?, ?it/s]

In [None]:
# Sparse (n_test x n_train) neighbour matrix: row i stores the top-k retrieval
# scores of test point i at the column indices of the retrieved train points.
# The shape is passed explicitly. If it were inferred from the indices, the
# matrix would silently lose trailing columns whenever the last train points are
# never retrieved, and products with train-indexed matrices would then fail on a
# dimension mismatch.
test_hlk = sparse.csr_matrix(
    (score.flatten().numpy(), indices.flatten().numpy(), indptr.numpy()),
    shape=(score.shape[0], train_data_repr.shape[0]),
)

In [None]:
# Persist the test -> train neighbour matrix so later sessions can reuse it
# without recomputing representations.
fname = "test_hlk.pkl"
with open(fname, mode='wb') as hlk_file:
    pickle.dump(test_hlk, hlk_file)

In [None]:
from xcai.data import *

# Text view of the retrieval result: each test point's text next to the texts of
# the train points stored in its row of `test_hlk` (spot-checked on one example).
hlk_main_dset = MainXCDataset(block.test.dset.data.data_info, test_hlk, block.train.dset.data.data_info)
test_dset = TextColumns(hlk_main_dset)
test_dset[2000]

{'data_input_text': 'Mathematical model',
 'lbl2data_input_text': ['Polyhedron model',
  'Data model',
  'Simulation modeling']}

In [None]:
# Project the train points' cat_meta matrix onto the test points through their
# retrieved neighbours (rows of test_hlk act as score weights). Install the result
# as the test split's lnk metadata, both as the full and the current matrix.
test_cat = test_hlk @ block.train.dset.meta.cat_meta.data_meta
test_lnk_meta = block.test.dset.meta.lnk_meta
test_lnk_meta.data_meta = test_lnk_meta.curr_data_meta = test_cat

In [None]:
# Run prediction on the test split. Leaving the metrics dict as the cell's last
# expression renders it through Jupyter's rich display instead of print().
o = learn.predict(block.test.dset)
o.metrics

  0%|          | 0/196 [00:00<?, ?it/s]

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


  self._set_arrayXarray(i, j, x)


{'test_loss': 0.10378105938434601, 'test_P@1': 0.16729290482494438, 'test_P@10': 0.05309973805035105, 'test_P@3': 0.10881897304453633, 'test_P@5': 0.08202799763404925, 'test_N@1': 0.16729290783405304, 'test_N@10': 0.18368692696094513, 'test_N@3': 0.16428664326667786, 'test_N@5': 0.17077775299549103, 'test_PSP@1': 0.1586184204459441, 'test_PSP@10': 0.1977837675423582, 'test_PSP@3': 0.1613118277559351, 'test_PSP@5': 0.17142875641498287, 'test_PSN@1': 0.15861842036247253, 'test_PSN@10': 0.19298529624938965, 'test_PSN@3': 0.16770175099372864, 'test_PSN@5': 0.17800629138946533, 'test_R@200': 0.3965625126598376, 'test_R@10': 0.22037828039708174, 'test_R@100': 0.35519597960181626, 'test_runtime': 295.0679, 'test_samples_per_second': 601.607, 'test_steps_per_second': 0.376}


In [None]:
# Text-only view of the predictions, keeping the data text and the cat2data /
# lnk2data text columns selected by `pattern`. Uses `o`, the output of
# `learn.predict` in the prediction cell above (`pred` is never defined in this
# section and raises NameError on a fresh kernel).
pattern = r'^(data|cat2data|lnk2data)_input_text$'
dset = TextColumns(get_pred_dset(o, block), pat=pattern)

In [None]:
# Inspect the first record of the prediction text view.
dset[0]