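# Ramen-style clover training with input augmentation

Trains `DBT017` (DistilBERT backbone) on the wikiseealso data, replacing each example's `input_ids`/`attention_mask` with the metadata-augmented versions produced by `AugmentMetaInputIdsTfm` (`hlk_meta`).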

In [1]:
#| default_exp 25-ramen-style-clover-training-with-input-augmentation

In [2]:
# Automatically reload edited library modules (e.g. xcai) before each cell executes.
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
# Run nbdev's export step so the modules generated from `#| export` cells are up to date.
import nbdev; nbdev.nbdev_export()

In [3]:
#| export
import os,torch, torch.multiprocessing as mp, sys, pickle
from xcai.basics import *
from xcai.models.MMM0XX import DBT017
from xcai.transform import AugmentMetaInputIdsTfm

In [4]:
# Interactive runs only (not exported): switch Weights & Biases logging off.
os.environ.update(WANDB_MODE='disabled')

In [5]:
#| export
# Restrict the run to GPUs 2 and 3 and name the W&B project. NOTE(review): the GPU mask
# only takes effect if CUDA has not already been initialised by the imports above.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '2,3',
    'WANDB_PROJECT': 'xc-nlg_25-ramen-style-clover-training-with-input-augmentation',
})

In [None]:
#| export
# Root directory of the datasets. Set the XC_DATA_DIR environment variable to point at a
# different location instead of editing this machine-specific absolute default.
data_dir = os.environ.get('XC_DATA_DIR', '/home/aiscuser/scratch/datasets')

In [6]:
# Build the tokenised data block from config (expensive; not exported). The exported
# script instead reloads the pickled copy written by the following cells.
block = XCBlock.from_cfg(
    data_dir,
    'data_metas',
    valid_pct=0.001,
    tfm='rm',
    tokenizer='distilbert-base-uncased',
    smp_features=[('lbl2data|cat2lbl2data', 1, 2), ('cat2data', 1, 1)],
    n_data_meta_samples=50,
    n_lbl_meta_samples=50,
)

  self._set_arrayXarray(i, j, x)


In [7]:
#| export
# Directory holding the pickled, pre-tokenised data block. No trailing slash: every user
# appends '/<file>', and a trailing '/' here produced '.../processed//<file>' paths.
pkl_dir = f'{data_dir}/processed'

In [8]:
# Cache the freshly built block so the exported script can reload it without re-tokenising.
# Create the target directory first: open(..., 'wb') raises FileNotFoundError if it is missing.
os.makedirs(pkl_dir, exist_ok=True)
with open(f'{pkl_dir}/wikiseealso_data-metas_distilbert-base-uncased_rm_ramen-cat.pkl', 'wb') as pkl_file:
    pickle.dump(block, pkl_file)

In [None]:
#| export
# Reload the cached data block. NOTE: pickle.load can execute arbitrary code, so only
# load files produced by this pipeline.
pkl_path = f'{pkl_dir}/wikiseealso_data-metas_distilbert-base-uncased_rm_ramen-cat.pkl'
with open(pkl_path, 'rb') as pkl_file:
    block = pickle.load(pkl_file)

In [9]:
#| export
block = AugmentMetaInputIdsTfm.apply(block, 'hlk_meta', 32, True)

# For the train and test splits, point the model inputs at the augmented '<key>_aug_hlk'
# columns produced above (same assignment order as writing each line out by hand).
for _split in (block.train, block.test):
    _info = _split.dset.data.data_info
    for _key in ('input_ids', 'attention_mask'):
        _info[_key] = _info[f'{_key}_aug_hlk']

In [18]:
#| export
args = XCLearningArguments(
    # --- run bookkeeping ---
    output_dir='/home/aiscuser/outputs/25-ramen-style-clover-training-with-input-augmentation',
    logging_first_step=True,
    # --- batch sizes (per device) ---
    per_device_train_batch_size=800,
    per_device_eval_batch_size=800,
    # --- optimisation ---
    num_train_epochs=300,
    learning_rate=2e-4,
    warmup_steps=100,
    weight_decay=0.1,
    adam_epsilon=1e-6,
    fp16=True,
    # --- evaluation / checkpoint schedule; best checkpoint selected by P@1 ---
    evaluation_strategy='steps',
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=5,
    metric_for_best_model='P@1',
    load_best_model_at_end=True,
    # --- representation-based prediction ---
    predict_with_representation=True,
    representation_search_type='INDEX',
    representation_num_beams=200,
    representation_accumulation_steps=10,
    output_concatenation_weight=1.0,
    # --- generation-based prediction ---
    predict_with_generation=True,
    generation_num_beams=10,
    generation_length_penalty=1.5,
    # --- clustering (group_by_cluster) ---
    group_by_cluster=True,
    clustering_type='EXPO',
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=10,
    minimum_cluster_size=1,
    maximum_cluster_size=300,
    # --- batch keys for targets and extra label inputs ---
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    label_names=['cat2data_idx', 'cat2data_input_ids', 'cat2data_attention_mask',
                 'cat2lbl2data_idx', 'cat2lbl2data_input_ids', 'cat2lbl2data_attention_mask'],
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [11]:
#| export
# Fixed-seed 2000-example subset of the test split, used as the evaluation set.
test_dset = block.test.dset.sample(n=2000, seed=50)

metric = PrecRecl(
    block.n_lbl,
    test_dset.data.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
)

In [12]:
#| export
# Largest batch the model may receive across all visible GPUs (the run uses DataParallel
# replicas, per the traceback below). Clamp the device count to >= 1, as HF
# `TrainingArguments.train_batch_size` does, so a CPU-only host gets the per-device size
# instead of bsz=0.
bsz = (max(args.per_device_train_batch_size, args.per_device_eval_batch_size)
       * max(torch.cuda.device_count(), 1))

# lw scales the LM loss added to the DR loss; m_lw is the meta-input loss weight (divided
# across meta inputs when it is a float) -- see MMM0XX.forward in the debug trace below.
model = DBT017.from_pretrained('distilbert-base-uncased', ig_tok=0, bsz=bsz, tn_targ=1000, margin=0.3, tau=0.1,
                               n_negatives=5, apply_softmax=True, lw=0.01, m_lw=0.3, meta_prefix='cat')
# Initialise the dr_* head, which the load warning reports as missing from the checkpoint.
model.init_dr_head()

Some weights of DBT017 were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['dr_layer_norm.bias', 'dr_layer_norm.weight', 'dr_projector.bias', 'dr_projector.weight', 'dr_transform.bias', 'dr_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
#| export
# Trie built from the data block and handed to XCLearner below (presumably for
# generation-based prediction, since predict_with_generation=True -- confirm).
trie = XCTrie.from_block(block)

  0%|          | 0/312330 [00:00<?, ?it/s]

In [19]:
#| export
learn = XCLearner(
    model=model,
    args=args,
    # data: full training split for training, the 2000-example test sample for evaluation
    train_dataset=block.train.dset,
    eval_dataset=test_dset,
    data_collator=block.collator,
    # evaluation helpers
    compute_metrics=metric,
    trie=trie,
)

In [20]:
# NOTE(review): this interactive run stopped at a leftover `import pdb; pdb.set_trace()`
# in xcai/models/MMM0XX.py (forward, line 931 per the traceback below). With multiple
# GPUs the breakpoint fired inside a DataParallel replica and crashed with a KeyError
# from linecache. Remove that breakpoint from the library before re-running.
learn.train()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[2024-05-30 21:05:03,670] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

node-0:1810131:1810131 [0] NCCL INFO Bootstrap : Using eth0:10.13.60.215<0>
node-0:1810131:1810131 [0] NCCL INFO NET/Plugin : Plugin load (librccl-net.so) returned 2 : librccl-net.so: cannot open shared object file: No such file or directory
node-0:1810131:1810131 [0] NCCL INFO NET/Plugin : No plugin found, using internal implementation
node-0:1810131:1810131 [0] NCCL INFO Kernel version: 5.15.0-1042-azure
RCCL version 2.17.1+hip5.7 HEAD:cbbb3d8+

node-0:1810131:1820050 [0] /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/rccl/build/hipify/src/misc/ibvwrap.cc:222 NCCL WARN Call to ibv_open_device failed

node-0:1810131:1820050 [0] /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/rccl/build/hipify/src/transport/net_ib.cc:199 NCCL WARN NET/IB : Unable to open device mlx5_0

node-0:1810131:1820050 [0] /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/rccl/build/hipify/src/misc/ibvwrap.cc:222 NCCL WARN Call to ibv_open_device failed

nod

node-0:1810131:1820050 [0] NCCL INFO Ring 4 : 1 -> 0 -> 1 comm 0x1fd1cc80 nRanks 02 busId 300000
node-0:1810131:1820050 [0] NCCL INFO Ring 5 : 1 -> 0 -> 1 comm 0x1fd1cc80 nRanks 02 busId 300000
node-0:1810131:1820050 [0] NCCL INFO Ring 6 : 1 -> 0 -> 1 comm 0x1fd1cc80 nRanks 02 busId 300000
node-0:1810131:1820050 [0] NCCL INFO Ring 7 : 1 -> 0 -> 1 comm 0x1fd1cc80 nRanks 02 busId 300000
node-0:1810131:1820050 [0] NCCL INFO Ring 8 : 1 -> 0 -> 1 comm 0x1fd1cc80 nRanks 02 busId 300000
node-0:1810131:1820050 [0] NCCL INFO Ring 9 : 1 -> 0 -> 1 comm 0x1fd1cc80 nRanks 02 busId 300000
node-0:1810131:1820050 [0] NCCL INFO Ring 10 : 1 -> 0 -> 1 comm 0x1fd1cc80 nRanks 02 busId 300000
node-0:1810131:1820050 [0] NCCL INFO Ring 11 : 1 -> 0 -> 1 comm 0x1fd1cc80 nRanks 02 busId 300000
node-0:1810131:1820050 [0] NCCL INFO Ring 12 : 1 -> 0 -> 1 comm 0x1fd1cc80 nRanks 02 busId 300000
node-0:1810131:1820050 [0] NCCL INFO Ring 13 : 1 -> 0 -> 1 comm 0x1fd1cc80 nRanks 02 busId 300000
node-0:1810131:1820050 [0]

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


> /home/aiscuser/scratch/Projects/xcai/xcai/models/MMM0XX.py(943)forward()
    941             dr_loss = self.rep_lfn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, 
    942                                    plbl2data_data2ptr, plbl2data_idx, **kwargs)
--> 943             loss = dr_loss + self.lw*lm_loss
    944 
    945             meta_inputs = self._get_meta_inputs(**kwargs)

ipdb> n
> /home/aiscuser/scratch/Projects/xcai/xcai/models/MMM0XX.py(945)forward()
    943             loss = dr_loss + self.lw*lm_loss
    944 
--> 945             meta_inputs = self._get_meta_inputs(**kwargs)
    946             if isinstance(self.m_lw, float):
    947                 meta_lw = self.m_lw/len(meta_inputs) if len(meta_inputs) else None

ipdb> n
> /home/aiscuser/scratch/Projects/xcai/xcai/models/MMM0XX.py(946)forward()
    944 
    945             meta_inputs = self._get_meta_inputs(**kwargs)
--> 946             if isinstance(self.m_lw, float):
    947                 meta_lw = self

--> 961                         loss += m_lw * (m_drl + self.lw* m_lml)
    962 
    963                 elif 'data2ptr' in m:

ipdb> 
> /home/aiscuser/scratch/Projects/xcai/xcai/models/MMM0XX.py(953)forward()
    951                 meta_lw = self.m_lw
    952 
--> 953             for m,m_lw in zip(meta_inputs.values(), meta_lw):
    954                 if 'lbl2data2ptr' in m:
    955                     valid_idx = torch.where(m['lbl2data2ptr'])[0]

ipdb> 
> /home/aiscuser/scratch/Projects/xcai/xcai/models/MMM0XX.py(954)forward()
    952 
    953             for m,m_lw in zip(meta_inputs.values(), meta_lw):
--> 954                 if 'lbl2data2ptr' in m:
    955                     valid_idx = torch.where(m['lbl2data2ptr'])[0]
    956                     if len(valid_idx) > 0:

ipdb> 
> /home/aiscuser/scratch/Projects/xcai/xcai/models/MMM0XX.py(963)forward()
    961                         loss += m_lw * (m_drl + self.lw* m_lml)
    962 
--> 963                 elif 'data2ptr' in m:


KeyError: Caught KeyError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/opt/conda/envs/ptca/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/envs/ptca/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aiscuser/scratch/Projects/xcai/xcai/models/MMM0XX.py", line 931, in forward
    import pdb; pdb.set_trace()
  File "/opt/.singularity/lib/python3.9/site-packages/IPython/core/debugger.py", line 1099, in set_trace
    Pdb().set_trace(frame or sys._getframe().f_back)
  File "/opt/.singularity/lib/python3.9/site-packages/IPython/core/debugger.py", line 401, in set_trace
    return super().set_trace(frame)
  File "/opt/conda/envs/ptca/lib/python3.9/bdb.py", line 328, in set_trace
    self.reset()
  File "/opt/conda/envs/ptca/lib/python3.9/pdb.py", line 197, in reset
    bdb.Bdb.reset(self)
  File "/opt/conda/envs/ptca/lib/python3.9/bdb.py", line 57, in reset
    linecache.checkcache()
  File "/opt/.singularity/lib/python3.9/site-packages/IPython/core/compilerop.py", line 185, in check_linecache_ipython
    linecache._checkcache_ori(*args)
  File "/opt/conda/envs/ptca/lib/python3.9/linecache.py", line 64, in checkcache
    entry = cache[filename]
KeyError: '/tmp/ipykernel_1810131/869285730.py'


In [None]:
#| export
# Entry point of the exported script.
# NOTE(review): all setup above (data loading, model construction) runs at import time,
# outside this guard. If multiprocessing uses the "spawn" start method, any worker that
# re-imports this module repeats that work. freeze_support() only matters for frozen
# Windows executables.
if __name__ == '__main__':
    mp.freeze_support()
    learn.train()