In [1]:
#| default_exp 23_ngame-linker-for-wikiseealsotitles

In [2]:
# Re-import edited modules (e.g. the xcai package) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [3]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, numpy as np
from scipy import sparse

from xcai.basics import *
from xcai.models.PPP0XX import DBT009,DBT011,DBT021
from xcai.models.distillation import TCH001,DTL007,TCH003,DTL008
from xcai.data import MetaXCDataset,XCDataset,MainXCDataset

from transformers import DistilBertConfig

comet_ml is installed but `COMET_API_KEY` is not set.


In [4]:
# Notebook-only: this cell has no `#| export` marker, so it does not apply to the exported script.
os.environ['WANDB_MODE'] = 'disabled'

In [5]:
#| export
# Default GPU selection and W&B project for this experiment.
# `setdefault` keeps any value already exported by the user or a job scheduler
# (e.g. `CUDA_VISIBLE_DEVICES=2 python ...`) instead of silently overriding it at import time.
# CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is first initialised.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1')
os.environ.setdefault('WANDB_PROJECT', 'medic_00-wikiseealsotitles')

## Load data

In [6]:
# Raw dataset root, used only when (re)building the XCBlock in the next cell.
data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'

In [62]:
# Builds the XCBlock from raw data (tokenizer='distilbert-base-uncased'); only needed when the
# pickle cache loaded further down is missing or stale.
# NOTE(review): the Driver section builds the same cache with oversample=True — confirm which setting is intended.
block = XCBlock.from_cfg(data_dir, 'data_meta', transform_type='xcs', tokenizer='distilbert-base-uncased', 
                         sampling_features=[('lbl2data',4)], oversample=False)



In [7]:
# Location of the pickled XCBlock cache (the Driver section reads the same path).
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data-meta_distilbert-base-uncased_xcs.pkl'

In [64]:
# Overwrites the cache file: run only right after rebuilding `block` above.
with open(pkl_file, 'wb') as file: pickle.dump(block, file)

In [8]:
# Load the cached block. pickle.load can execute arbitrary code — only load caches you created.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

## Linker task

In [9]:
#| export
def get_meta_dataset(meta, idx):
    """Build a `MetaXCDataset` that keeps only the metadata items (columns) listed in `idx`.

    `meta.data_meta` and `meta.lbl_meta` are column-sliced with `idx`, and every field of
    `meta.meta_info` is reduced to the entries at the same positions, so all three stay
    aligned. The returned dataset reuses `meta.prefix`.
    """
    sub_data_meta = meta.data_meta[:, idx]
    sub_lbl_meta = meta.lbl_meta[:, idx]
    sub_info = {}
    for field, values in meta.meta_info.items():
        sub_info[field] = [values[pos] for pos in idx]
    return MetaXCDataset(meta.prefix, sub_data_meta, sub_lbl_meta, sub_info)
    
def threshold_meta_dataset(train_meta, test_meta, thresh=100):
    """Keep metadata columns whose training frequency lies strictly between 0 and `thresh`.

    Frequency is the per-column count of stored entries in `train_meta.data_meta`.
    The same column subset is applied to both splits so they stay aligned.
    Returns `(filtered_train_meta, filtered_test_meta)`.
    """
    freq = train_meta.data_meta.getnnz(axis=0)
    keep = np.flatnonzero((freq > 0) & (freq < thresh))
    return tuple(get_meta_dataset(split, keep) for split in (train_meta, test_meta))
    

In [10]:
# Keep only categories whose training frequency is in (0, 100); test keeps the same columns.
# Mutates `block` in place — re-run from the load cell to reset it.
train_meta, test_meta = threshold_meta_dataset(block.train.dset.meta['cat_meta'], block.test.dset.meta['cat_meta'], thresh=100)
block.train.dset.meta['cat_meta'], block.test.dset.meta['cat_meta'] = train_meta, test_meta

In [11]:
# Restrict training to points that still have at least one category after thresholding.
train_idx = np.where(block.train.dset.meta['cat_meta'].data_meta.getnnz(axis=1) > 0)[0]
train_block = block.train._getitems(train_idx)

In [12]:
# Same restriction for the test split.
test_idx = np.where(block.test.dset.meta['cat_meta'].data_meta.getnnz(axis=1) > 0)[0]
test_block = block.test._getitems(test_idx)

In [13]:
# Linker task: the targets are the kept categories (cat_meta.data_meta / meta_info), and each point's
# original labels are attached as 'sal' metadata. lbl_meta is sliced by category columns, so
# lbl_meta.T maps category -> original label; `.tocsr()` converts the transpose back to CSR.
sal_meta = MetaXCDataset('sal', train_block.dset.data.data_lbl, train_block.dset.meta['cat_meta'].lbl_meta.T.tocsr(), 
                         train_block.dset.data.lbl_info)
train_dset = XCDataset(MainXCDataset(train_block.dset.data.data_info, train_block.dset.meta.cat_meta.data_meta, 
                                     train_block.dset.meta.cat_meta.meta_info), sal_meta=sal_meta)

In [14]:
# NOTE(review): this test-side `sal_meta` is built but never passed to `test_dset`. It is derived from
# the test split's own data_lbl, so passing it would expose test labels at eval time — confirm the
# omission is intended (and drop the dead construction if so).
sal_meta = MetaXCDataset('sal', test_block.dset.data.data_lbl, test_block.dset.meta['cat_meta'].lbl_meta.T.tocsr(), 
                         test_block.dset.data.lbl_info)
test_dset = XCDataset(MainXCDataset(test_block.dset.data.data_info, test_block.dset.meta.cat_meta.data_meta, 
                                    test_block.dset.meta.cat_meta.meta_info))

In [15]:
# Mutates the shared `block.collator`, which is later passed to XCLearner as data_collator.
# NOTE(review): presumably samples 4 positives per point with 1 sal item each, plus 1 sal item per point — confirm in xcai.
block.collator.tfms.tfms[0].sampling_features = [('lbl2data,sal2lbl2data', (4,1)), ('sal2data', 1)]

## Training

In [16]:
# Training configuration for the notebook run (the Driver uses per-device batch size 800 instead of 100).
args = XCLearningArguments(
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/medic/23_ngame-linker-for-wikiseealsotitles',
    logging_first_step=True,
    per_device_train_batch_size=100,
    per_device_eval_batch_size=100,
    representation_num_beams=200,
    representation_accumulation_steps=10,
    save_strategy="steps",
    evaluation_strategy="steps",
    eval_steps=5000,
    save_steps=5000,
    save_total_limit=5,
    num_train_epochs=300,
    # Predict by searching over learned representations (brute force) — see XCLearningArguments.
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    adam_epsilon=1e-6,
    warmup_steps=100,
    weight_decay=0.01,
    learning_rate=2e-4,
    
    # Cluster-grouped batching (presumably NGAME-style hard-negative batches); the num_*_epochs
    # values control its schedule — check XCLearningArguments for exact semantics.
    group_by_cluster=True,
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    clustering_type='EXPO',
    minimum_cluster_size=2,
    maximum_cluster_size=1600,
    
    # Model selection on P@1; ground truth presumably read from the collator's plbl2data_* keys.
    metric_for_best_model='P@1',
    load_best_model_at_end=True,
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    
    use_encoder_parallel=True,
    # NOTE(review): in HF TrainingArguments, max_grad_norm=None disables gradient clipping — confirm intended.
    max_grad_norm=None,
    fp16=True,

    # Input keys the trainer treats as labels.
    # NOTE(review): the batch inspected below exposes 'sal2lbl_*' keys, not 'sal2lbl2data_*' — confirm these names match.
    label_names=['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask',
                 'sal2data_idx', 'sal2data_input_ids', 'sal2data_attention_mask',
                 'sal2lbl2data_idx', 'sal2lbl2data_input_ids', 'sal2lbl2data_attention_mask'],

    # Periodic pruning of 'sal_meta' (names suggest: top-3 kept, threshold 0.0, every 5 epochs, no warmup).
    prune_metadata=True,
    num_metadata_prune_warmup_epochs=0,
    num_metadata_prune_epochs=5,
    metadata_prune_batch_size=2048,
    prune_metadata_names=['sal_meta'],
    use_data_metadata_for_pruning=True,
    prune_metadata_threshold=0.0,
    prune_metadata_topk=3,
)

In [17]:
# Pretrained NGAME teacher for the original task (sized by block.train.dset.n_data / block.n_lbl); embeddings frozen.
model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4'

teacher = TCH001.from_pretrained(f'{model_output}/teacher', n_data=block.train.dset.n_data, n_lbl=block.n_lbl)
teacher.freeze_embeddings()

In [18]:
# Linker-task teacher: a frozen embedding table initialised from the original teacher's data
# representations, restricted to the kept training points (`train_idx`).
m_teacher = TCH003(DistilBertConfig(), n_data=len(train_idx))
m_teacher.init_embeddings(teacher.data_repr.weight.data[train_idx])
m_teacher.freeze_embeddings()

In [19]:
# Global batch size across visible GPUs.
# NOTE(review): evaluates to 0 when no GPU is visible.
bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()

# Student encoder from the msmarco DistilBERT checkpoint, using 'sal2data' / 'sal2lbl' metadata prefixes.
# NOTE(review): use_encoder_parallel=False here, while args and the Driver use True.
m_student = DBT021.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', bsz=bsz, tn_targ=1000, margin=0.3, tau=0.1, 
                                   apply_softmax=True, n_negatives=10, m_lw=0.2, data_meta_prefix='sal2data', 
                                   lbl2data_meta_prefix='sal2lbl', use_encoder_parallel=False, task_repr_type='pool', meta_repr_type='pool')
# Presumably initialises the newly added dr_* head weights reported as uninitialised in the warning below.
m_student.init_dr_head()

Some weights of DBT021 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.dr_layer_norm.bias', 'encoder.dr_layer_norm.weight', 'encoder.dr_projector.bias', 'encoder.dr_projector.weight', 'encoder.dr_transform.bias', 'encoder.dr_transform.weight', 'encoder.meta_layer_norm.bias', 'encoder.meta_layer_norm.weight', 'encoder.meta_projector.bias', 'encoder.meta_projector.weight', 'encoder.meta_transform.bias', 'encoder.meta_transform.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [20]:
# Distillation model wrapping the trainable student and the frozen linker teacher.
model = DTL008(DistilBertConfig(), m_student=m_student, m_teacher=m_teacher, bsz=bsz, tn_targ=5000, margin=0.3, tau=0.1, 
               n_negatives=10, apply_softmax=True, teacher_data_student_label_loss_weight=1.0, data_mse_loss_weight=0.1)

In [21]:
# NOTE(review): PrecRecl is built from the ORIGINAL task (block.n_lbl, the full test split's label
# filterer, original label propensities), while train_dset/test_dset are the category linker task
# restricted to train_idx/test_idx. Confirm these metrics match the linker targets.
metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

In [22]:
# Trainer over the linker datasets, reusing the original block's collator (sampling_features set above).
learn = XCLearner(
    model=model, 
    args=args,
    train_dataset=train_dset,
    eval_dataset=test_dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [23]:
# Debug: pull one training batch to inspect the collated tensors.
bb = next(iter(learn.get_train_dataloader()))

In [28]:
# NOTE(review): ids and mask agree in this batch (392x18), but the failing training step below had
# sal2lbl_input_ids (433, 16) vs sal2lbl_attention_mask (433, 15), so the mismatch is batch-dependent.
# TODO: assert matching shapes in the collator's sal2lbl padding.
bb['sal2lbl_input_ids'].shape

torch.Size([392, 18])

In [29]:
bb['sal2lbl_attention_mask'].shape

torch.Size([392, 18])

In [25]:
# Debug: a single forward pass on the inspected batch.
o = model(**bb.to(model.device))

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


In [24]:
# Last run crashed in DistilBERT attention inside compute_meta_loss because of the sal2lbl
# input_ids / attention_mask length mismatch (see the %debug session below).
learn.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


  0%|          | 0/3171 [00:00<?, ?it/s]

  0%|          | 0/3026 [00:00<?, ?it/s]

  0%|          | 0/1562 [00:00<?, ?it/s]

  0%|          | 0/674 [00:00<?, ?it/s]

  0%|          | 0/488 [00:00<?, ?it/s]

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss


RuntimeError: shape '[433, 1, 1, 16]' is invalid for input of size 6495

In [30]:
# Post-mortem debugging of the failed learn.train(); remove once the sal2lbl shape mismatch is fixed.
%debug

> /home/scai/phd/aiz218323/.local/lib/python3.9/site-packages/transformers/models/distilbert/modeling_distilbert.py(244)forward()
    242         q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
    243         scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
--> 244         mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
    245         scores = scores.masked_fill(
    246             mask, torch.tensor(torch.finfo(scores.dtype).min)



ipdb>  u


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1541)_call_impl()
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1532)_wrapped_call_impl()
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):



ipdb>  


> /home/scai/phd/aiz218323/.local/lib/python3.9/site-packages/transformers/models/distilbert/modeling_distilbert.py(513)forward()
    511         """
    512         # Self-Attention
--> 513         sa_output = self.attention(
    514             query=x,
    515             key=x,



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1541)_call_impl()
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1532)_wrapped_call_impl()
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):



ipdb>  


> /home/scai/phd/aiz218323/.local/lib/python3.9/site-packages/transformers/models/distilbert/modeling_distilbert.py(587)forward()
    585                 )
    586             else:
--> 587                 layer_outputs = layer_module(
    588                     hidden_state,
    589                     attn_mask,



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1541)_call_impl()
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1532)_wrapped_call_impl()
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):



ipdb>  


> /home/scai/phd/aiz218323/.local/lib/python3.9/site-packages/transformers/models/distilbert/modeling_distilbert.py(822)forward()
    820                 attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)
    821 
--> 822         return self.transformer(
    823             x=embeddings,
    824             attn_mask=attention_mask,



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1541)_call_impl()
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:



ipdb>  u


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1532)_wrapped_call_impl()
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/PPP0XX.py(930)forward()
    928         if repr_type is None: repr_type = self.repr_type
    929 
--> 930         o = self.distilbert(
    931             input_ids=input_ids,
    932             attention_mask=attention_mask,



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1541)_call_impl()
   1539                 or _global_backward_pre_hooks or _global_backward_hooks
   1540                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541             return forward_call(*args, **kwargs)
   1542 
   1543         try:



ipdb>  


> /scratch/scai/phd/aiz218323/anaconda3/envs/xc_nlg_2/lib/python3.9/site-packages/torch/nn/modules/module.py(1532)_wrapped_call_impl()
   1530             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531         else:
-> 1532             return self._call_impl(*args, **kwargs)
   1533 
   1534     def _call_impl(self, *args, **kwargs):



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/PPP0XX.py(1001)compute_meta_loss()
    999                 idx = torch.where(inputs['lbl2data2ptr'])[0]
   1000                 if len(idx) > 0:
-> 1001                     inputs_o = encoder(input_ids=inputs['input_ids'],
   1002                                        attention_mask=inputs['attention_mask'],
   1003                                        input_type="meta", repr_type=self.meta_repr_type)



ipdb>  


> /scratch/scai/phd/aiz218323/Projects/xcai/xcai/models/PPP0XX.py(1048)forward()
   1046             lbl2data_o, lbl2data_repr = encoder(lbl2data_input_ids, lbl2data_attention_mask, repr_type=self.task_repr_type)
   1047             loss = self.loss_fn(data_repr, lbl2data_repr, lbl2data_data2ptr, lbl2data_idx, plbl2data_data2ptr, plbl2data_idx)
-> 1048             loss += self.compute_meta_loss(data_repr, lbl2data_repr, **kwargs)
   1049 
   1050         if not return_dict:



ipdb>  kwargs.keys()


dict_keys(['psal2lbl_idx', 'psal2lbl_data2ptr', 'psal2lbl_lbl2data2ptr', 'sal2lbl_data2ptr', 'sal2lbl_lbl2data2ptr', 'sal2lbl_idx', 'sal2lbl_input_ids', 'sal2lbl_attention_mask', 'psal2data_idx', 'psal2data_data2ptr', 'sal2data_idx', 'sal2data_input_ids', 'sal2data_attention_mask', 'sal2data_data2ptr'])


ipdb>  kwargs['sal2lbl_input_ids'].shape


torch.Size([433, 16])


ipdb>  kwargs['sal2lbl_attention_mask'].shape


torch.Size([433, 15])


ipdb>  q


## Driver

In [None]:
#| export
if __name__ == '__main__':
    # Per the multiprocessing docs this belongs right after the main guard; it is a no-op
    # unless the script is frozen into a Windows executable.
    mp.freeze_support()

    # ---- Configuration ----
    build_block = False  # True: rebuild the XCBlock from raw data and refresh the pickle cache.

    meta_freq_threshold = 100  # keep categories with 0 < training frequency < this value
    pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
    data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'

    output_dir = '/home/scai/phd/aiz218323/scratch/outputs/medic/23_ngame-linker-for-wikiseealsotitles'
    model_output = '/home/scai/phd/aiz218323/scratch/outputs/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4'

    # ---- Load data ----
    pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data-meta_distilbert-base-uncased_xcs.pkl'
    if build_block:
        # NOTE(review): the exploratory cells build this cache with oversample=False — confirm which is intended.
        block = XCBlock.from_cfg(data_dir, 'data_meta', transform_type='xcs', tokenizer='distilbert-base-uncased',
                                 sampling_features=[('lbl2data', 4)], oversample=True)
        # The 'processed' sub-directory may not exist yet on a fresh machine.
        os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
        with open(pkl_file, 'wb') as file: pickle.dump(block, file)
    else:
        # pickle.load can execute arbitrary code: only load a cache produced by this script.
        with open(pkl_file, 'rb') as file: block = pickle.load(file)

    # ---- Linker dataset ----
    # Keep categories by training frequency; the same columns are kept for the test split.
    train_meta, test_meta = threshold_meta_dataset(block.train.dset.meta['cat_meta'], block.test.dset.meta['cat_meta'],
                                                   thresh=meta_freq_threshold)
    block.train.dset.meta['cat_meta'], block.test.dset.meta['cat_meta'] = train_meta, test_meta

    # Drop data points left without any category after thresholding.
    train_idx = np.where(block.train.dset.meta['cat_meta'].data_meta.getnnz(axis=1) > 0)[0]
    train_block = block.train._getitems(train_idx)
    test_idx = np.where(block.test.dset.meta['cat_meta'].data_meta.getnnz(axis=1) > 0)[0]
    test_block = block.test._getitems(test_idx)

    # The original labels become 'sal' metadata. `.T` of a scipy CSR matrix is CSC, so convert back
    # with `.tocsr()`, matching the exploratory cells (harmless if the matrix is already CSR).
    sal_meta = MetaXCDataset('sal', train_block.dset.data.data_lbl, train_block.dset.meta['cat_meta'].lbl_meta.T.tocsr(),
                             train_block.dset.data.lbl_info)
    train_dset = XCDataset(MainXCDataset(train_block.dset.data.data_info, train_block.dset.meta.cat_meta.data_meta,
                                         train_block.dset.meta.cat_meta.meta_info), sal_meta=sal_meta)
    # NOTE(review): built but not passed to test_dset. It is derived from the test split's own data_lbl,
    # so passing it would leak test labels into evaluation — confirm the omission is intended.
    sal_meta = MetaXCDataset('sal', test_block.dset.data.data_lbl, test_block.dset.meta['cat_meta'].lbl_meta.T.tocsr(),
                             test_block.dset.data.lbl_info)
    test_dset = XCDataset(MainXCDataset(test_block.dset.data.data_info, test_block.dset.meta.cat_meta.data_meta,
                                        test_block.dset.meta.cat_meta.meta_info))

    # The same (mutated) collator is handed to XCLearner below.
    block.collator.tfms.tfms[0].sampling_features = [('lbl2data,sal2lbl2data', (4,1)), ('sal2data', 1)]

    # ---- Training arguments ----
    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        evaluation_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        representation_search_type='BRUTEFORCE',
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,

        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',

        use_encoder_parallel=True,
        max_grad_norm=None,
        fp16=True,

        label_names=['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask',
                     'sal2data_idx', 'sal2data_input_ids', 'sal2data_attention_mask',
                     'sal2lbl2data_idx', 'sal2lbl2data_input_ids', 'sal2lbl2data_attention_mask'],

        prune_metadata=True,
        num_metadata_prune_warmup_epochs=0,
        num_metadata_prune_epochs=5,
        metadata_prune_batch_size=2048,
        prune_metadata_names=['sal_meta'],
        use_data_metadata_for_pruning=True,
        prune_metadata_threshold=0.0,
        prune_metadata_topk=3,
    )

    # NOTE(review): PrecRecl is built from the ORIGINAL task (block.n_lbl, full-test filterer, original
    # label propensities), while train_dset/test_dset are the category linker task restricted to
    # train_idx/test_idx. Confirm these metrics match the linker targets.
    metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                      pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

    # ---- Model ----
    # Frozen teacher for the original task, then a frozen linker teacher restricted to train_idx.
    teacher = TCH001.from_pretrained(f'{model_output}/teacher', n_data=block.train.dset.n_data, n_lbl=block.n_lbl)
    teacher.freeze_embeddings()
    m_teacher = TCH003(DistilBertConfig(), n_data=len(train_idx))
    m_teacher.init_embeddings(teacher.data_repr.weight.data[train_idx])
    m_teacher.freeze_embeddings()

    # Global batch size over visible GPUs; clamp the device count so a CPU-only run does not get bsz == 0.
    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*max(1, torch.cuda.device_count())

    m_student = DBT021.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', bsz=bsz, tn_targ=1000, margin=0.3, tau=0.1,
                                       apply_softmax=True, n_negatives=10, m_lw=0.2, data_meta_prefix='sal2data',
                                       lbl2data_meta_prefix='sal2lbl', use_encoder_parallel=True, task_repr_type='pool',
                                       meta_repr_type='pool')
    m_student.init_dr_head()

    model = DTL008(DistilBertConfig(), m_student=m_student, m_teacher=m_teacher, bsz=bsz, tn_targ=5000, margin=0.3, tau=0.1,
                   n_negatives=10, apply_softmax=True, teacher_data_student_label_loss_weight=1.0, data_mse_loss_weight=0.1)

    # ---- Training ----
    learn = XCLearner(
        model=model,
        args=args,
        train_dataset=train_dset,
        eval_dataset=test_dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )

    learn.train()
    