In [1]:
#| default_exp 34_metadexa-for-msmarco

In [2]:
%reload_ext autoreload
%autoreload 2

In [3]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [24]:
#| export
import os,torch,json, torch.multiprocessing as mp, joblib, numpy as np, scipy.sparse as sp

from transformers import DistilBertConfig

from xcai.basics import *
from xcai.models.oakY import OAK008

In [5]:
# Turn off Weights & Biases logging for this interactive session.
os.environ.update(WANDB_MODE='disabled')

In [6]:
#| export
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',         # expose only GPUs 0 and 1 to this process
    'WANDB_PROJECT': 'mogicX_00-msmarco',  # W&B project name used for runs
})

## Setup

In [7]:
# Root of the scratch area holding datasets and outputs. It can be overridden with the
# XCAI_SCRATCH_DIR environment variable when running elsewhere; the default reproduces
# the original hard-coded cluster paths exactly.
scratch_dir = os.environ.get('XCAI_SCRATCH_DIR', '/scratch/scai/phd/aiz218323')

output_dir = f'{scratch_dir}/outputs/mogicX/34_metadexa-for-msmarco'
pkl_dir = f'{scratch_dir}/datasets/processed/mogicX'

config_file = f'{scratch_dir}/datasets/msmarco/XC/configs/entity_gpt_exact.json'
config_key = 'data_entity-gpt_exact'

# NOTE(review): in the visible cells the model is loaded from
# 'sentence-transformers/msmarco-distilbert-base-v4', not from `mname` -- confirm which
# checkpoint is intended.
mname = 'sentence-transformers/msmarco-distilbert-cos-v5'

In [8]:
# Inference switches (all off).
do_train_inference = do_test_inference = False

# Persistence switches (all off).
save_train_inference = save_test_inference = False
save_representation = False

# Data-block options; with use_sxc_sampler=True the resolved cache file carries an
# '_sxc' suffix (see the pkl_file output below).
use_sxc_sampler = True
only_test = False

In [9]:
# Resolve the cached (joblib) path for the processed MS MARCO data+metadata block; the
# resolved name depends on use_sxc_sampler (see the displayed value below).
pkl_file = get_pkl_file(pkl_dir, 'msmarco_data-meta_distilbert-base-uncased', use_sxc_sampler)

In [10]:
# Display the resolved cache path.
pkl_file

'/scratch/scai/phd/aiz218323/datasets/processed/mogicX/msmarco_data-meta_distilbert-base-uncased_sxc.joblib'

In [11]:
# Inference-related work is needed as soon as any of the run/save switches is on.
do_inference = any((
    do_train_inference,
    do_test_inference,
    save_train_inference,
    save_test_inference,
    save_representation,
))

In [12]:
%%time
os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
# Forward the `only_test` flag from the setup cell instead of a hard-coded False;
# previously, toggling that flag had no effect on the block that was built.
# do_build=False presumably loads the cached block from `pkl_file` -- TODO confirm.
block = build_block(pkl_file, config_file, use_sxc_sampler, config_key, do_build=False,
                    only_test=only_test, meta_oversample=True,
                    meta_dropout_remove=0.3, meta_dropout_replace=0.3)


CPU times: user 16.2 s, sys: 1.89 s, total: 18.1 s
Wall time: 18.2 s


In [13]:
# Metadata family used below: batch keys are built as f'{meta_name}2data_...' and the
# metadata store is looked up as f'{meta_name}_meta'.
meta_name = 'ent'

In [30]:
batch = block.train.dset.__getitems__([10, 20, 50])

# Show only metadata-related entries. The substring test also matches the
# 'p{meta_name}2...' keys (e.g. 'pent2data_idx' in the output below).
dict((key, value) for key, value in batch.items() if f'{meta_name}2' in key)

{'pent2data_idx': tensor([  490,   113, 11483]),
 'pent2data_data2ptr': tensor([1, 1, 1]),
 'ent2data_idx': tensor([  490,   113, 11483]),
 'ent2data_dropout_remove_mask': tensor([False,  True,  True]),
 'ent2data_dropout_replace_mask': tensor([ True, False, False]),
 'ent2data_data2ptr': tensor([1, 1, 1]),
 'ent2data_identifier': ['490', '113', '11483'],
 'ent2data_input_text': ['cars', 'mountain', 'venus'],
 'ent2data_input_ids': tensor([[ 101, 3079,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0],
         [ 101, 3231,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0],
         [ 101, 1396, 9299,  102,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    

In [25]:
# NOTE(review): this loads 'sentence-transformers/msmarco-distilbert-base-v4', whereas the
# setup cell defines mname = 'sentence-transformers/msmarco-distilbert-cos-v5' -- confirm
# which checkpoint is intended.
model = OAK008.from_pretrained(
    'sentence-transformers/msmarco-distilbert-base-v4',
    # retrieval loss / batching
    batch_size=100,
    num_batch_labels=5000,
    margin=0.3,
    num_negatives=10,
    tau=0.1,
    apply_softmax=True,
    # metadata prefixes: only data_aug_meta_prefix is set; the other three are None
    data_aug_meta_prefix=f'{meta_name}2data',
    lbl2data_aug_meta_prefix=None,
    data_pred_meta_prefix=None,
    lbl2data_pred_meta_prefix=None,
    # both sized from the metadata store's n_meta
    num_metadata=block.train.dset.meta[f'{meta_name}_meta'].n_meta,
    num_metadata_clusters=block.train.dset.meta[f'{meta_name}_meta'].n_meta,
    # calibration loss
    calib_margin=0.3,
    calib_num_negatives=10,
    calib_tau=0.1,
    calib_apply_softmax=True,
    calib_loss_weight=0.1,
    use_calib_loss=True,
    # misc
    use_query_loss=True,
    use_encoder_parallel=False,
    do_meta_embed_sparse=False,
)
# Initialise the heads/encoder that the checkpoint does not provide (the load warning
# below lists dr_head, meta_distilbert and cross_head weights as newly initialized).
model.init_retrieval_head()
model.init_meta_encoder()
model.init_cross_head()

Some weights of OAK008 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.alpha', 'encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_distilbert.embeddings.LayerNorm.bias', 'encoder.meta_distilbert.embeddings.LayerNorm.weight', 'encoder.meta_distilbert.embeddings.position_embeddings.weight', 'encoder.meta_distilbert.embeddings.word_embeddings.weight', 'encoder.meta_distilbert.transformer.layer.0.attention.k_lin.bias', 'encoder.meta_distilbert.transformer.layer.0.attention.k_lin.weight', 'encoder.

In [34]:
# Keep only the metadata inputs the forward pass needs; the order matches the original list.
b = prepare_batch(
    model,
    batch,
    m_args=(
        [f'p{meta_name}2data_{field}' for field in ('idx', 'data2ptr')]
        + [f'{meta_name}2data_{field}' for field in ('idx', 'input_ids', 'attention_mask',
                                                     'data2ptr', 'dropout_remove_mask',
                                                     'dropout_replace_mask')]
    ),
)

In [27]:
# Move the model to the default CUDA device.
model = model.to('cuda')

In [28]:
# Single forward pass: move the prepared batch `b` to the model's device and unpack it
# as keyword arguments.
o = model(**b.to(model.device))

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [32]:
def func(net=None, inputs=None, debug=True):
    """Run one forward pass of the model, optionally stopping in pdb first.

    Args:
        net: model to call; defaults to the notebook-global `model` (looked up at call time).
        inputs: batch object with a `.to(device)` method whose result is unpacked as
            keyword arguments; defaults to the notebook-global `b` built by `prepare_batch`.
        debug: when True (the default, as before) break into pdb before the forward pass.

    Returns:
        The model output. Previously it was stored in a local and discarded, so the
        result of the debugged pass was lost once the function returned.
    """
    if debug:
        import pdb; pdb.set_trace()
    net = model if net is None else net
    inputs = b if inputs is None else inputs
    return net(**inputs.to(net.device))
    

In [None]:
# Interactive: stops in pdb right before the forward pass (see the ipdb session below).
func()

> [0;32m/tmp/ipykernel_7783/3721260802.py[0m(3)[0;36mfunc[0;34m()[0m
[0;32m      1 [0;31m[0;32mdef[0m [0mfunc[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      2 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 3 [0;31m    [0mo[0m [0;34m=[0m [0mmodel[0m[0;34m([0m[0;34m**[0m[0mb[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mmodel[0m[0;34m.[0m[0mdevice[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      4 [0;31m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(827)[0;36mforward[0;34m()[0m
[0;32m    825 [0;31m        [0;34m**[0m[0mkwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    826 [0;31m    ):  
[0m[0;32m--> 827 [0;31m        [0mreturn_dict[0m [0;34m=[0m [0mreturn_dict[0m [0;32mif[0m [0mreturn_dict[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_return_dict[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    828 [0;31m[0;34m[0m[0m
[0m[0;32m    829 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(829)[0;36mforward[0;34m()[0m
[0;32m    827 [0;31m        [0mreturn_dict[0m [0;34m=[0m [0mreturn_dict[0m [0;32mif[0m [0mreturn_dict[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32melse[0m [0mself[0m[0;34m.[0m[0mconfig[0m[0;34m.[0m[0muse_return_dict[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    828 [0;31m[0;34m[0m[0m
[0m[0;32m--> 829 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    830 [0;31m            [0mencoder[0m [0;34m=[0m [0mXCDataParallel[0m[0;34m([0m[0mmodule[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mencoder[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    831 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(831)[0;36mforward[0;34m()[0m
[0;32m    829 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0muse_encoder_parallel[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    830 [0;31m            [0mencoder[0m [0;34m=[0m [0mXCDataParallel[0m[0;34m([0m[0mmodule[0m[0;34m=[0m[0mself[0m[0;34m.[0m[0mencoder[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 831 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    832 [0;31m[0;34m[0m[0m
[0m[0;32m    833 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(833)[0;36mforward[0;34m()[0m
[0;32m    831 [0;31m        [0;32melse[0m[0;34m:[0m [0mencoder[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mencoder[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    832 [0;31m[0;34m[0m[0m
[0m[0;32m--> 833 [0;31m        [0mdata_meta_kwargs[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_get_encoder_meta_kwargs[0m[0;34m([0m[0;34m'data'[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mdata_aug_meta_prefix[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    834 [0;31m        data_o = encoder(data_input_ids=data_input_ids, data_attention_mask=data_attention_mask, 
[0m[0;32m    835 [0;31m                         data_aug_meta_prefix=self.data_aug_meta_prefix, **data_meta_kwargs)
[0m


ipdb>  s


--Call--
> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1134)[0;36m_get_encoder_meta_kwargs[0;34m()[0m
[0;32m   1132 [0;31m        [0mx[0m[0;34m[[0m[0mmask[0m[0;34m][0m [0;34m=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1133 [0;31m[0;34m[0m[0m
[0m[0;32m-> 1134 [0;31m    [0;32mdef[0m [0m_get_encoder_meta_kwargs[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mfeat[0m[0;34m:[0m[0mstr[0m[0;34m,[0m [0mprefix[0m[0;34m:[0m[0mstr[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1135 [0;31m        [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0mfeat[0m[0;34m,[0m [0mprefix[0m[0;34m,[0m [0madditional_keys[0m[0;34m=[0m[0;34m[[0m[0;34m'dropout_remove_mask'[0m[0;34m,[0m [0;34m'dropout_replace_mask'[0m[0;34m][0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m

ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1135)[0;36m_get_encoder_meta_kwargs[0;34m()[0m
[0;32m   1133 [0;31m[0;34m[0m[0m
[0m[0;32m   1134 [0;31m    [0;32mdef[0m [0m_get_encoder_meta_kwargs[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mfeat[0m[0;34m:[0m[0mstr[0m[0;34m,[0m [0mprefix[0m[0;34m:[0m[0mstr[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1135 [0;31m        [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0mfeat[0m[0;34m,[0m [0mprefix[0m[0;34m,[0m [0madditional_keys[0m[0;34m=[0m[0;34m[[0m[0;34m'dropout_remove_mask'[0m[0;34m,[0m [0;34m'dropout_replace_mask'[0m[0;34m][0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1136 [0;31m        [0;32mif[0m [0;34mf'{prefix}_idx'[0m [0;32min[0m [0mmeta_kwargs[0m[0;34m:[0m[0;34m[0m[0;34

ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1136)[0;36m_get_encoder_meta_kwargs[0;34m()[0m
[0;32m   1134 [0;31m    [0;32mdef[0m [0m_get_encoder_meta_kwargs[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mfeat[0m[0;34m:[0m[0mstr[0m[0;34m,[0m [0mprefix[0m[0;34m:[0m[0mstr[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1135 [0;31m        [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0mfeat[0m[0;34m,[0m [0mprefix[0m[0;34m,[0m [0madditional_keys[0m[0;34m=[0m[0;34m[[0m[0;34m'dropout_remove_mask'[0m[0;34m,[0m [0;34m'dropout_replace_mask'[0m[0;34m][0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1136 [0;31m        [0;32mif[0m [0;34mf'{prefix}_idx'[0m [0;32min[0m [0mmeta_kwargs[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1137 [0;31m      

ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1137)[0;36m_get_encoder_meta_kwargs[0;34m()[0m
[0;32m   1135 [0;31m        [0mmeta_kwargs[0m [0;34m=[0m [0mParameters[0m[0;34m.[0m[0mfrom_feat_meta_aug_prefix[0m[0;34m([0m[0mfeat[0m[0;34m,[0m [0mprefix[0m[0;34m,[0m [0madditional_keys[0m[0;34m=[0m[0;34m[[0m[0;34m'dropout_remove_mask'[0m[0;34m,[0m [0;34m'dropout_replace_mask'[0m[0;34m][0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1136 [0;31m        [0;32mif[0m [0;34mf'{prefix}_idx'[0m [0;32min[0m [0mmeta_kwargs[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1137 [0;31m            [0mm_idx[0m [0;34m=[0m [0mmeta_kwargs[0m[0;34m[[0m[0;34mf'{prefix}_idx'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1138 [0;31m            [0mremove_mask[0m [0;34m=[0m [0mmeta_kwargs[0m[0;34m.[0m[0mpop[0m[0;34m([0m[0;34mf'{prefix}_dropout_remove_mas

ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1138)[0;36m_get_encoder_meta_kwargs[0;34m()[0m
[0;32m   1136 [0;31m        [0;32mif[0m [0;34mf'{prefix}_idx'[0m [0;32min[0m [0mmeta_kwargs[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1137 [0;31m            [0mm_idx[0m [0;34m=[0m [0mmeta_kwargs[0m[0;34m[[0m[0;34mf'{prefix}_idx'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1138 [0;31m            [0mremove_mask[0m [0;34m=[0m [0mmeta_kwargs[0m[0;34m.[0m[0mpop[0m[0;34m([0m[0;34mf'{prefix}_dropout_remove_mask'[0m[0;34m,[0m [0;32mNone[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1139 [0;31m            [0mreplace_mask[0m [0;34m=[0m [0mmeta_kwargs[0m[0;34m.[0m[0mpop[0m[0;34m([0m[0;34mf'{prefix}_dropout_replace_mask'[0m[0;34m,[0m [0;32mNone[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1140 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mm_idx[0m[0;34m)[0m

ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1139)[0;36m_get_encoder_meta_kwargs[0;34m()[0m
[0;32m   1137 [0;31m            [0mm_idx[0m [0;34m=[0m [0mmeta_kwargs[0m[0;34m[[0m[0;34mf'{prefix}_idx'[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1138 [0;31m            [0mremove_mask[0m [0;34m=[0m [0mmeta_kwargs[0m[0;34m.[0m[0mpop[0m[0;34m([0m[0;34mf'{prefix}_dropout_remove_mask'[0m[0;34m,[0m [0;32mNone[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1139 [0;31m            [0mreplace_mask[0m [0;34m=[0m [0mmeta_kwargs[0m[0;34m.[0m[0mpop[0m[0;34m([0m[0;34mf'{prefix}_dropout_replace_mask'[0m[0;34m,[0m [0;32mNone[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1140 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mm_idx[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1141 [0;31m                [0mmeta_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mmeta_emb

ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1140)[0;36m_get_encoder_meta_kwargs[0;34m()[0m
[0;32m   1138 [0;31m            [0mremove_mask[0m [0;34m=[0m [0mmeta_kwargs[0m[0;34m.[0m[0mpop[0m[0;34m([0m[0;34mf'{prefix}_dropout_remove_mask'[0m[0;34m,[0m [0;32mNone[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1139 [0;31m            [0mreplace_mask[0m [0;34m=[0m [0mmeta_kwargs[0m[0;34m.[0m[0mpop[0m[0;34m([0m[0;34mf'{prefix}_dropout_replace_mask'[0m[0;34m,[0m [0;32mNone[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1140 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mm_idx[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1141 [0;31m                [0mmeta_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mmeta_embeddings[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mmetadata_cluster_mapping[0m[0;34m[[0m[0mm_idx[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[

ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1141)[0;36m_get_encoder_meta_kwargs[0;34m()[0m
[0;32m   1139 [0;31m            [0mreplace_mask[0m [0;34m=[0m [0mmeta_kwargs[0m[0;34m.[0m[0mpop[0m[0;34m([0m[0;34mf'{prefix}_dropout_replace_mask'[0m[0;34m,[0m [0;32mNone[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1140 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mm_idx[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1141 [0;31m                [0mmeta_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mmeta_embeddings[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mmetadata_cluster_mapping[0m[0;34m[[0m[0mm_idx[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1142 [0;31m                [0;32mif[0m [0mself[0m[0;34m.[0m[0mtraining[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1143 [0;31m                    [0mself[0m[0;34m.[0m[0mremove_dropout[0m[0;34m([0m

ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1142)[0;36m_get_encoder_meta_kwargs[0;34m()[0m
[0;32m   1140 [0;31m            [0;32mif[0m [0mlen[0m[0;34m([0m[0mm_idx[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1141 [0;31m                [0mmeta_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mmeta_embeddings[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mmetadata_cluster_mapping[0m[0;34m[[0m[0mm_idx[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1142 [0;31m                [0;32mif[0m [0mself[0m[0;34m.[0m[0mtraining[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1143 [0;31m                    [0mself[0m[0;34m.[0m[0mremove_dropout[0m[0;34m([0m[0mmeta_repr[0m[0;34m,[0m [0mremove_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1144 [0;31m                    [0mself[0m[0;34m.[0m[0mreplace_dropout[0m[0;34m([0m[0mmeta_repr[0m[0;34m,[0m [0mself[0m[

ipdb>  self.training = True
ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1143)[0;36m_get_encoder_meta_kwargs[0;34m()[0m
[0;32m   1141 [0;31m                [0mmeta_repr[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mmeta_embeddings[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mmetadata_cluster_mapping[0m[0;34m[[0m[0mm_idx[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1142 [0;31m                [0;32mif[0m [0mself[0m[0;34m.[0m[0mtraining[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1143 [0;31m                    [0mself[0m[0;34m.[0m[0mremove_dropout[0m[0;34m([0m[0mmeta_repr[0m[0;34m,[0m [0mremove_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1144 [0;31m                    [0mself[0m[0;34m.[0m[0mreplace_dropout[0m[0;34m([0m[0mmeta_repr[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mmetadata_dropout[0m[0;34m,[0m [0mmask[0m[0;34m=[0m[0mreplace_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;3

ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1144)[0;36m_get_encoder_meta_kwargs[0;34m()[0m
[0;32m   1142 [0;31m                [0;32mif[0m [0mself[0m[0;34m.[0m[0mtraining[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1143 [0;31m                    [0mself[0m[0;34m.[0m[0mremove_dropout[0m[0;34m([0m[0mmeta_repr[0m[0;34m,[0m [0mremove_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1144 [0;31m                    [0mself[0m[0;34m.[0m[0mreplace_dropout[0m[0;34m([0m[0mmeta_repr[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mmetadata_dropout[0m[0;34m,[0m [0mmask[0m[0;34m=[0m[0mreplace_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1145 [0;31m                [0mmeta_kwargs[0m[0;34m[[0m[0;34mf'{prefix}_meta_repr'[0m[0;34m][0m [0;34m=[0m [0mmeta_repr[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1146 [0;31m        [0;32mreturn[0m [0mmeta_kwargs[0m[0;34m[0m[0;34m[0m[0m


ipdb>  meta_repr


tensor([[ 0.0222,  0.0072,  0.0079,  ..., -0.0072,  0.0113,  0.0146],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0', grad_fn=<IndexPutBackward0>)


ipdb>  remove_mask


tensor([False,  True,  True], device='cuda:0')


ipdb>  s


--Call--
> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1116)[0;36mreplace_dropout[0;34m()[0m
[0;32m   1114 [0;31m        [0mself[0m[0;34m.[0m[0mpost_init[0m[0;34m([0m[0;34m)[0m[0;34m;[0m [0mself[0m[0;34m.[0m[0mremap_post_init[0m[0;34m([0m[0;34m)[0m[0;34m;[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1115 [0;31m[0;34m[0m[0m
[0m[0;32m-> 1116 [0;31m    [0;34m@[0m[0mstaticmethod[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1117 [0;31m    [0;32mdef[0m [0mreplace_dropout[0m[0;34m([0m[0mx[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mp[0m[0;34m:[0m[0mOptional[0m[0;34m[[0m[0mfloat[0m[0;34m][0m[0;34m=[0m[0;36m0.5[0m[0;34m,[0m [0mmask[0m[0;34m:[0m[0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1118 [0;31m        [0;32mif[0m [0mmask[0m [0;32mis[0m [0;

ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1118)[0;36mreplace_dropout[0;34m()[0m
[0;32m   1116 [0;31m    [0;34m@[0m[0mstaticmethod[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1117 [0;31m    [0;32mdef[0m [0mreplace_dropout[0m[0;34m([0m[0mx[0m[0;34m:[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mp[0m[0;34m:[0m[0mOptional[0m[0;34m[[0m[0mfloat[0m[0;34m][0m[0;34m=[0m[0;36m0.5[0m[0;34m,[0m [0mmask[0m[0;34m:[0m[0mOptional[0m[0;34m[[0m[0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1118 [0;31m        [0;32mif[0m [0mmask[0m [0;32mis[0m [0;32mNone[0m[0;34m:[0m [0mn_rows[0m [0;34m=[0m [0mx[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1119 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1120 [0;31m           

ipdb>  mask


tensor([ True, False, False], device='cuda:0')


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1120)[0;36mreplace_dropout[0;34m()[0m
[0;32m   1118 [0;31m        [0;32mif[0m [0mmask[0m [0;32mis[0m [0;32mNone[0m[0;34m:[0m [0mn_rows[0m [0;34m=[0m [0mx[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1119 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1120 [0;31m            [0;32massert[0m [0mx[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m [0;34m==[0m [0mmask[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1121 [0;31m            [0mvalid_row_idx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mmask[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1122 [0;31m            [0mn_rows[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0mvalid_row_idx[0m[0;34m)[0m

ipdb>  


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1121)[0;36mreplace_dropout[0;34m()[0m
[0;32m   1119 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1120 [0;31m            [0;32massert[0m [0mx[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m [0;34m==[0m [0mmask[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1121 [0;31m            [0mvalid_row_idx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mmask[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1122 [0;31m            [0mn_rows[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0mvalid_row_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1123 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1122)[0;36mreplace_dropout[0;34m()[0m
[0;32m   1120 [0;31m            [0;32massert[0m [0mx[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m [0;34m==[0m [0mmask[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1121 [0;31m            [0mvalid_row_idx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mmask[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1122 [0;31m            [0mn_rows[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0mvalid_row_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1123 [0;31m[0;34m[0m[0m
[0m[0;32m   1124 [0;31m        [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mrand[0m[0;34m([0m[0mn_rows[0m[0;34m,[0m [0mdevice[0m[0;34m=[0m[0mx[0m[0;34m.[0m[0mdevice[0m

ipdb>  valid_row_idx


tensor([0], device='cuda:0')


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1124)[0;36mreplace_dropout[0;34m()[0m
[0;32m   1122 [0;31m            [0mn_rows[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0mvalid_row_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1123 [0;31m[0;34m[0m[0m
[0m[0;32m-> 1124 [0;31m        [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mrand[0m[0;34m([0m[0mn_rows[0m[0;34m,[0m [0mdevice[0m[0;34m=[0m[0mx[0m[0;34m.[0m[0mdevice[0m[0;34m)[0m [0;34m<[0m [0mp[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1125 [0;31m[0;34m[0m[0m
[0m[0;32m   1126 [0;31m        [0;32mif[0m [0mmask[0m [0;32mis[0m [0;32mNone[0m[0;34m:[0m [0mx[0m[0;34m[[0m[0midx[0m[0;34m][0m [0;34m=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1126)[0;36mreplace_dropout[0;34m()[0m
[0;32m   1124 [0;31m        [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mrand[0m[0;34m([0m[0mn_rows[0m[0;34m,[0m [0mdevice[0m[0;34m=[0m[0mx[0m[0;34m.[0m[0mdevice[0m[0;34m)[0m [0;34m<[0m [0mp[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1125 [0;31m[0;34m[0m[0m
[0m[0;32m-> 1126 [0;31m        [0;32mif[0m [0mmask[0m [0;32mis[0m [0;32mNone[0m[0;34m:[0m [0mx[0m[0;34m[[0m[0midx[0m[0;34m][0m [0;34m=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1127 [0;31m        [0;32melse[0m[0;34m:[0m [0mx[0m[0;34m[[0m[0mvalid_row_idx[0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m][0m [0;34m=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1128 [0;31m[0;34m[0m[0m
[0m


ipdb>  idx


tensor([], device='cuda:0', dtype=torch.int64)


ipdb>  idx = torch.where(torch.rand(n_rows, device=x.device) < p)[0]; idx


tensor([], device='cuda:0', dtype=torch.int64)


ipdb>  idx = torch.where(torch.rand(n_rows, device=x.device) < p)[0]; idx


tensor([0], device='cuda:0')


ipdb>  l


[1;32m   1121 [0m            [0mvalid_row_idx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mmask[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[1;32m   1122 [0m            [0mn_rows[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0mvalid_row_idx[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1123 [0m[0;34m[0m[0m
[1;32m   1124 [0m        [0midx[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mwhere[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mrand[0m[0;34m([0m[0mn_rows[0m[0;34m,[0m [0mdevice[0m[0;34m=[0m[0mx[0m[0;34m.[0m[0mdevice[0m[0;34m)[0m [0;34m<[0m [0mp[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[1;32m   1125 [0m[0;34m[0m[0m
[0;32m-> 1126 [0;31m        [0;32mif[0m [0mmask[0m [0;32mis[0m [0;32mNone[0m[0;34m:[0m [0mx[0m[0;34m[[0m[0midx[0m[0;34m][0m [0;34m=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[1;32m   1127 [0m        [0;32mel

ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1127)[0;36mreplace_dropout[0;34m()[0m
[0;32m   1125 [0;31m[0;34m[0m[0m
[0m[0;32m   1126 [0;31m        [0;32mif[0m [0mmask[0m [0;32mis[0m [0;32mNone[0m[0;34m:[0m [0mx[0m[0;34m[[0m[0midx[0m[0;34m][0m [0;34m=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1127 [0;31m        [0;32melse[0m[0;34m:[0m [0mx[0m[0;34m[[0m[0mvalid_row_idx[0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m][0m [0;34m=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1128 [0;31m[0;34m[0m[0m
[0m[0;32m   1129 [0;31m    [0;34m@[0m[0mstaticmethod[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  


--Return--
None
> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1127)[0;36mreplace_dropout[0;34m()[0m
[0;32m   1125 [0;31m[0;34m[0m[0m
[0m[0;32m   1126 [0;31m        [0;32mif[0m [0mmask[0m [0;32mis[0m [0;32mNone[0m[0;34m:[0m [0mx[0m[0;34m[[0m[0midx[0m[0;34m][0m [0;34m=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1127 [0;31m        [0;32melse[0m[0;34m:[0m [0mx[0m[0;34m[[0m[0mvalid_row_idx[0m[0;34m[[0m[0midx[0m[0;34m][0m[0;34m][0m [0;34m=[0m [0;36m0[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1128 [0;31m[0;34m[0m[0m
[0m[0;32m   1129 [0;31m    [0;34m@[0m[0mstaticmethod[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/scratch/scai/phd/aiz218323/projects/xcai/xcai/models/oakY.py[0m(1145)[0;36m_get_encoder_meta_kwargs[0;34m()[0m
[0;32m   1143 [0;31m                    [0mself[0m[0;34m.[0m[0mremove_dropout[0m[0;34m([0m[0mmeta_repr[0m[0;34m,[0m [0mremove_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1144 [0;31m                    [0mself[0m[0;34m.[0m[0mreplace_dropout[0m[0;34m([0m[0mmeta_repr[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mmetadata_dropout[0m[0;34m,[0m [0mmask[0m[0;34m=[0m[0mreplace_mask[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1145 [0;31m                [0mmeta_kwargs[0m[0;34m[[0m[0;34mf'{prefix}_meta_repr'[0m[0;34m][0m [0;34m=[0m [0mmeta_repr[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1146 [0;31m        [0;32mreturn[0m [0mmeta_kwargs[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1147 [0;31m[0;34m[0m[0m
[0m


In [37]:
# Training/evaluation configuration passed to XCLearner below.
args = XCLearningArguments(
    output_dir=output_dir,
    logging_first_step=True,
    per_device_train_batch_size=512,
    per_device_eval_batch_size=512,
    representation_num_beams=200,
    representation_accumulation_steps=10,
    save_strategy="steps",
    eval_strategy="steps",
    eval_steps=5000,
    save_steps=5000,
    save_total_limit=5,
    num_train_epochs=300,
    predict_with_representation=True,
    adam_epsilon=1e-6,
    warmup_steps=100,
    weight_decay=0.01,
    learning_rate=2e-4,
    representation_search_type='BRUTEFORCE',

    # Names of representation attributes the learner reads (presumably fields of the
    # model output) for prediction, labels, metadata, augmentation and clustering.
    output_representation_attribute='data_fused_repr',
    label_representation_attribute='data_repr',
    metadata_representation_attribute='data_repr',
    data_augmentation_attribute='data_repr',
    representation_attribute='data_fused_repr',
    clustering_representation_attribute='data_fused_repr',

    # Cluster-grouped batching and its update schedule.
    group_by_cluster=True,
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    use_data_metadata_for_clustering=True,
    clustering_type='EXPO',
    minimum_cluster_size=2,
    maximum_cluster_size=1600,

    # Model selection on P@1, with targets read from the plbl2data_* batch keys.
    metric_for_best_model='P@1',
    load_best_model_at_end=True,
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',

    use_distributional_representation=False,
    # NOTE(review): True here, while the model was constructed with
    # use_encoder_parallel=False -- confirm which setting actually takes effect.
    use_encoder_parallel=True,
    max_grad_norm=None,
    fp16=True,
    
    # NOTE(review): unlike the m_args given to prepare_batch earlier, this list omits
    # f'p{meta_name}2data_idx' / f'p{meta_name}2data_data2ptr' -- confirm they are not needed.
    label_names=['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_input_ids', 'lbl2data_attention_mask', 
                 f'{meta_name}2data_idx', f'{meta_name}2data_data2ptr', f'{meta_name}2data_input_ids', f'{meta_name}2data_attention_mask', 
                 f'{meta_name}2data_dropout_remove_mask', f'{meta_name}2data_dropout_replace_mask'],

    # Metadata pruning (disabled: prune_metadata=False).
    prune_metadata=False,
    num_metadata_prune_warmup_epochs=10,
    num_metadata_prune_epochs=5,
    metadata_prune_batch_size=2048,
    prune_metadata_names=[f'{meta_name}_meta'],
    use_data_metadata_for_pruning=True,

    predict_with_augmentation=False,
    use_augmentation_index_representation=True,

    data_aug_meta_name=meta_name,
    augmentation_num_beams=None,
    data_aug_prefix=meta_name,
    use_label_metadata=False,

    # Metadata augmentation (disabled: augment_metadata=False).
    data_meta_batch_size=2048,
    augment_metadata=False,
    num_metadata_augment_warmup_epochs=10,
    num_metadata_augment_epochs=5,

    use_cpu_for_searching=True,
    use_cpu_for_clustering=True,
)

In [39]:
# Wire model, arguments, the train/test datasets of `block` and its collator into the learner.
learn = XCLearner(
    model=model, args=args,
    train_dataset=block.train.dset, eval_dataset=block.test.dset,
    data_collator=block.collator,
)

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [40]:
# Build the training DataLoader and take its first collated batch for debugging.
trn_dl = learn.get_train_dataloader()
batch = next(iter(trn_dl))  # NOTE: rebinds `batch`, replacing the 3-example batch built earlier.

In [41]:
def func(inputs=None, net=None, debug=False):
    """Run a single forward pass, optionally under the pdb debugger.

    Args:
        inputs: collated batch exposing `.to(device)` whose result unpacks into
            keyword arguments; defaults to the notebook-global `batch`.
        net: model exposing `.device` and callable with keyword arguments;
            defaults to the notebook-global `model`.
        debug: when True, break into `pdb` right before the forward pass.
            Off by default so Restart-&-Run-All does not hang on a prompt.

    Returns:
        Whatever `net(**inputs.to(net.device))` returns.
    """
    # Resolve the globals at call time (not def time) so re-running earlier cells is picked up.
    inputs = batch if inputs is None else inputs
    net = model if net is None else net
    if debug:
        import pdb; pdb.set_trace()
    return net(**inputs.to(net.device))

In [43]:
# Run one forward pass of `model` on the sampled `batch`.
func()

## Driver

In [None]:
#| export
if __name__ == '__main__':
    output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/34_metadexa-for-msmarco'

    config_file = '/scratch/scai/phd/aiz218323/datasets/msmarco/XC/configs/entity_gpt_exact.json'
    config_key = 'data_entity-gpt_exact'

    # Prefix of the metadata fields used below (e.g. `ent2data_idx`, `ent_meta`).
    meta_name = 'ent'

    mname = 'sentence-transformers/msmarco-distilbert-cos-v5'

    input_args = parse_args()

    # NOTE(review): the exploratory cells load 'msmarco_data-meta_distilbert-base-uncased', while this
    # driver loads 'msmarco_data_distilbert-base-uncased'. Confirm this pickle carries the `ent`
    # metadata that the rest of the script relies on.
    pkl_file = get_pkl_file(f'{input_args.pickle_dir}/mogicX', 'msmarco_data_distilbert-base-uncased', input_args.use_sxc_sampler, 
                            input_args.exact)

    do_inference = input_args.do_train_inference or input_args.do_test_inference or input_args.save_train_prediction or input_args.save_test_prediction or input_args.save_representation

    os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
    block = build_block(pkl_file, config_file, input_args.use_sxc_sampler, config_key, do_build=input_args.build_block, 
                        only_test=input_args.only_test, n_slbl_samples=3, main_oversample=False, n_sdata_meta_samples=3, 
                        meta_oversample=False, train_meta_topk=5, test_meta_topk=3, meta_dropout_remove=0.3, 
                        meta_dropout_replace=0.3)

    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=512,
        per_device_eval_batch_size=512,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        eval_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,
        representation_search_type='BRUTEFORCE',

        output_representation_attribute='data_fused_repr',
        label_representation_attribute='data_repr',
        metadata_representation_attribute='data_repr',
        data_augmentation_attribute='data_repr',
        representation_attribute='data_fused_repr',
        clustering_representation_attribute='data_fused_repr',

        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        use_data_metadata_for_clustering=True,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',

        use_distributional_representation=False,
        use_encoder_parallel=True,
        max_grad_norm=None,
        fp16=True,

        label_names=['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_data2ptr', 'lbl2data_input_ids', 'lbl2data_attention_mask', 
                     f'{meta_name}2data_idx', f'{meta_name}2data_data2ptr', f'{meta_name}2data_input_ids', f'{meta_name}2data_attention_mask', 
                     f'{meta_name}2data_dropout_remove_mask', f'{meta_name}2data_dropout_replace_mask'],

        prune_metadata=False,
        num_metadata_prune_warmup_epochs=10,
        num_metadata_prune_epochs=5,
        metadata_prune_batch_size=2048,
        prune_metadata_names=[f'{meta_name}_meta'],
        use_data_metadata_for_pruning=True,

        predict_with_augmentation=False,
        use_augmentation_index_representation=True,

        data_aug_meta_name=meta_name,
        augmentation_num_beams=None,
        data_aug_prefix=meta_name,
        use_label_metadata=False,

        data_meta_batch_size=2048,
        augment_metadata=False,
        num_metadata_augment_warmup_epochs=10,
        num_metadata_augment_epochs=5,

        use_cpu_for_searching=True,
        use_cpu_for_clustering=True,
    )

    def model_fn(mname, bsz):
        """Build an OAK008 model from checkpoint `mname`, passing `bsz` as its `batch_size`."""
        # The metadata vocabulary size and the number of metadata clusters are both set to
        # the number of `ent` metadata items in the training block.
        n_meta = block.train.dset.meta[f'{meta_name}_meta'].n_meta
        model = OAK008.from_pretrained(mname, batch_size=bsz, num_batch_labels=5000, 
                                       margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,

                                       data_aug_meta_prefix=f'{meta_name}2data', lbl2data_aug_meta_prefix=None, 
                                       data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,

                                       num_metadata=n_meta, 
                                       num_metadata_clusters=n_meta,

                                       calib_loss_weight=0.1, use_calib_loss=True,
                                       calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 

                                       use_query_loss=True,

                                       use_encoder_parallel=True, do_meta_embed_sparse=False, metadata_dropout=0.3)
        return model

    def init_fn(model):
        """Call the model's retrieval-head, meta-encoder, cross-head and meta-embedding init hooks.

        Passed to `load_model` as the initialiser. Presumably it only runs for freshly built
        (not resumed) models; confirm against `load_model`.
        """
        model.init_retrieval_head()
        model.init_meta_encoder()
        model.init_cross_head()
        model.init_meta_embeddings()

    metric = PrecReclMrr(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                         pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200], mk=[5, 10, 20])

    # torch.cuda.device_count() is 0 on a CPU-only host. Clamp it to 1 so model_fn never
    # receives batch_size=0.
    n_devices = max(torch.cuda.device_count(), 1)
    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*n_devices
    model = load_model(args.output_dir, model_fn, {"mname": mname, "bsz": bsz}, init_fn, do_inference=do_inference, use_pretrained=input_args.use_pretrained)

    learn = XCLearner(
        model=model,
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )

    main(learn, input_args, n_lbl=block.test.dset.n_lbl, eval_k=10, train_k=10)
    