In [None]:
#| default_exp 99-wikititles-with-oak-multi-optimizer-1-0

In [None]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, numpy as np, transformers
from transformers import DistilBertConfig

from xcai.basics import *
from xcai.models.oakX import OAK001

from xclib.utils.sparse import retain_topk

from fastcore.utils import *

comet_ml is installed but `COMET_API_KEY` is not set.


In [None]:
# Set WANDB_MODE to 'disabled' for this interactive session. The cell carries no
# `#| export` directive, so presumably this override is not part of the exported module.
os.environ['WANDB_MODE'] = 'disabled'

In [None]:
#| export
# Restrict the process to GPUs 0 and 1 via CUDA_VISIBLE_DEVICES.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# Project name read from the environment by Weights & Biases logging.
os.environ['WANDB_PROJECT']='xc-nlg_97-oak-dr-ep-for-wikititles'

In [None]:
# Root directory of the raw datasets (absolute, machine-specific path).
data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'

# Build the dataset block from the 'data_lnk' configuration, tokenised with
# distilbert-base-uncased. `smp_features` pairs a feature name with an integer
# (presumably a per-example sample count; confirm against XCBlock.from_cfg).
block = XCBlock.from_cfg(data_dir, 'data_lnk', transform_type='xcs', tokenizer='distilbert-base-uncased', 
                         smp_features=[('lbl2data',4), ('lnk2data',3)])



In [None]:
#| export
# Location of the pickled, pre-tokenised block (absolute, machine-specific path).
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
# NOTE: `pkl_dir` already ends in '/', so the resulting path contains '//' before 'processed'.
pkl_file = f'{pkl_dir}/processed/wikititles_data-lnk_distilbert-base-uncased_xcs.pkl'

In [None]:
# Serialise the freshly built `block` to `pkl_file`; the exported cell that follows
# reads this same file back instead of rebuilding the block.
with open(pkl_file, 'wb') as file: pickle.dump(block, file)

In [None]:
#| export
# Load the cached block from `pkl_file`.
# SECURITY: pickle.load executes arbitrary code embedded in the file; only use on a
# cache file you produced yourself.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [None]:
#| export
def _prune_lnk_meta(dset, k):
    """Run `retain_topk(..., k=k)` on `dset.meta.lnk_meta.data_meta` and store the result.

    The pruned matrix is written to both `curr_data_meta` and `data_meta` of the
    same `lnk_meta` object (in that order) and is also returned.
    """
    lnk_meta = dset.meta.lnk_meta
    pruned = retain_topk(lnk_meta.data_meta, k=k)
    lnk_meta.curr_data_meta = pruned
    lnk_meta.data_meta = pruned
    return pruned


# Train keeps k=5 link-metadata entries, test keeps k=3; `data_meta` ends up
# bound to the test-split result.
data_meta = _prune_lnk_meta(block.train.dset, k=5)
data_meta = _prune_lnk_meta(block.test.dset, k=3)

## Optimizer

In [None]:
#| export
import torch
from itertools import chain
from collections import defaultdict

In [None]:
#| export
class MultipleOptimizer(torch.optim.Optimizer):
    """Wrapper that drives several optimizers as if they were a single one.

    `Optimizer.__init__` is not called; instead `state` and `param_groups` are
    read-only properties computed on demand from the wrapped optimizers, and
    every other operation is fanned out to each wrapped optimizer in order.

    Args:
        optimizers: iterable of already-constructed optimizers. Order matters:
            `state_dict()` / `load_state_dict()` and `__getstate__()` /
            `__setstate__()` pair entries positionally.
    """

    def __init__(self, optimizers):
        # Materialise once so that a generator argument survives repeated iteration.
        self.optimizers = list(optimizers)

    @property
    def state(self):
        """Merged snapshot of every wrapped optimizer's per-parameter state.

        A new mapping is built on each access; the per-parameter entries
        themselves are the same objects held by the wrapped optimizers.
        """
        merged = defaultdict(dict)
        for optimizer in self.optimizers:
            merged.update(optimizer.state)
        return merged

    @property
    def param_groups(self):
        """Concatenation (new list) of the wrapped optimizers' param groups, in order."""
        return [group for optimizer in self.optimizers for group in optimizer.param_groups]

    def __getstate__(self):
        """Return one `__getstate__()` result per wrapped optimizer."""
        return [optimizer.__getstate__() for optimizer in self.optimizers]

    def __setstate__(self, state):
        """Restore each wrapped optimizer from the matching entry of `state`.

        Requires `self.optimizers` to exist already. The original unpacked
        `zip(self.optimizers, state)` into swapped names and therefore called
        `__setstate__` on the state object rather than on the optimizer.
        """
        for optimizer, opt_state in zip(self.optimizers, state):
            optimizer.__setstate__(opt_state)

    def __repr__(self):
        # One wrapped-optimizer repr per line, wrapped in "ClassName (...)".
        pieces = [self.__class__.__name__ + ' (']
        for optimizer in self.optimizers:
            pieces.append('\n')
            pieces.append(optimizer.__repr__())
        pieces.append(')')
        return ''.join(pieces)

    def _hook_for_profile(self):
        for optimizer in self.optimizers:
            optimizer._hook_for_profile()

    def state_dict(self):
        """Return a list holding each wrapped optimizer's `state_dict()`."""
        return [optimizer.state_dict() for optimizer in self.optimizers]

    def load_state_dict(self, state_dict):
        """Load a list produced by `state_dict()` back into the wrapped optimizers.

        Raises:
            ValueError: if the number of entries differs from the number of
                wrapped optimizers (previously the surplus was silently ignored).
        """
        states = list(state_dict)
        if len(states) != len(self.optimizers):
            raise ValueError(
                f'expected {len(self.optimizers)} optimizer state dicts, got {len(states)}'
            )
        for state, optimizer in zip(states, self.optimizers):
            optimizer.load_state_dict(state)

    def zero_grad(self, set_to_none: bool = False):
        for optimizer in self.optimizers:
            optimizer.zero_grad(set_to_none=set_to_none)

    def add_param_group(self, param_group):
        # There is no single obvious target optimizer for a new group.
        raise NotImplementedError()

    def step(self, closure=None):
        """Evaluate `closure` once (if given), then step every wrapped optimizer.

        Returns:
            The value returned by `closure`, or None when no closure is passed.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # The closure is deliberately not forwarded: it has already been evaluated above.
        for optimizer in self.optimizers:
            optimizer.step()

        return loss


In [None]:
#| export
class MultipleScheduler:
    """Bundle of LR schedulers that are stepped and (de)serialised together."""

    def __init__(self, sched):
        # `sched` is the collection of wrapped schedulers; order is significant.
        self.schedulers = sched

    def step(self, *args, **kwargs):
        """Forward `step(*args, **kwargs)` to each wrapped scheduler in order."""
        for scheduler in self.schedulers:
            scheduler.step(*args, **kwargs)

    def get_last_lr(self):
        """Return a flat list of the last learning rates of all schedulers, in order."""
        per_scheduler = [scheduler.get_last_lr() for scheduler in self.schedulers]
        flat = []
        for rates in per_scheduler:
            flat.extend(rates)
        return flat

    def state_dict(self):
        """Return a list with one `state_dict()` per wrapped scheduler."""
        collected = []
        for scheduler in self.schedulers:
            collected.append(scheduler.state_dict())
        return collected

    def load_state_dict(self, state_dict):
        """Load entries of `state_dict` positionally into the wrapped schedulers."""
        for scheduler, saved in zip(self.schedulers, state_dict):
            scheduler.load_state_dict(saved)

In [None]:
#| export
@patch
def create_optimizer_and_scheduler(self:XCLearner, num_training_steps: int):
    """Create the paired optimizers and LR schedulers and store them on the learner.

    Trainable parameters are split by name. Those whose name contains
    "meta_embeddings" are handed to `torch.optim.SparseAdam`, with learning rate
    `args.learning_rate * args.free_parameter_lr_coefficient` and a cosine
    schedule using `args.free_parameter_warmup_steps`. All other trainable
    parameters go to `torch.optim.AdamW`, with `args.learning_rate` and a linear
    schedule using `args.warmup_steps`.

    The results are assigned to `self.optimizer` (a `MultipleOptimizer`) and
    `self.lr_scheduler` (a `MultipleScheduler`); nothing is returned.

    Args:
        num_training_steps: total number of training steps passed to both
            schedule factories.
    """
    # Parameter-name substrings that get weight_decay=0.0 in the dense optimizer.
    # NOTE(review): the heads listed in the model-load warning are named
    # `...layer_norm.weight`, which 'LayerNorm.weight' does not match; confirm
    # whether those should be decay-free as well.
    NO_DECAY = ['bias', 'LayerNorm.weight']

    # Read parameters from the learner's own model rather than a notebook-level
    # global, so the method also works when the learner wraps a different model.
    dense, sparse = [], []
    for name, param in self.model.named_parameters():
        if not param.requires_grad:
            continue
        if "meta_embeddings" in name:
            sparse.append(param)
        else:
            dense.append((name, param))

    # weight_decay (0.01) and eps (1e-6) are hard-coded here; args.weight_decay
    # and args.adam_epsilon are not consulted.
    params = [
        {'params': [p for n, p in dense if not any(nd in n for nd in NO_DECAY)], 'weight_decay': 0.01},
        {'params': [p for n, p in dense if any(nd in n for nd in NO_DECAY)], 'weight_decay': 0.0},
    ]

    sparse_lr = self.args.learning_rate * self.args.free_parameter_lr_coefficient
    optimizer_list = [torch.optim.AdamW(params, lr=self.args.learning_rate, eps=1e-6),
                      torch.optim.SparseAdam(sparse, lr=sparse_lr, eps=1e-6)]

    self.optimizer = MultipleOptimizer(optimizer_list)

    # One schedule per optimizer, in the same order as `optimizer_list`.
    dense_optimizer, sparse_optimizer = self.optimizer.optimizers[0], self.optimizer.optimizers[1]
    scheduler_list = [transformers.get_linear_schedule_with_warmup(dense_optimizer,
                                                                   num_warmup_steps=self.args.warmup_steps,
                                                                   num_training_steps=num_training_steps),
                      transformers.get_cosine_schedule_with_warmup(sparse_optimizer,
                                                                   num_warmup_steps=self.args.free_parameter_warmup_steps,
                                                                   num_training_steps=num_training_steps)]

    self.lr_scheduler = MultipleScheduler(scheduler_list)

## Training

In [None]:
#| export
# Training / evaluation configuration passed to XCLearner below.
# NOTE(review): the trailing `#800` / `#5000` comments look like the full-run
# values for settings currently reduced to 10; confirm before a real run.
args = XCLearningArguments(
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/99-wikititles-with-oak-multi-optimizer-1-0',
    logging_first_step=True,
    per_device_train_batch_size=10, #800,
    per_device_eval_batch_size=800,
    representation_num_beams=200,
    representation_accumulation_steps=10,
    save_strategy="steps",
    evaluation_strategy="steps",
    eval_steps=10, #5000,
    save_steps=10, #5000,
    save_total_limit=5,
    num_train_epochs=300,
    predict_with_representation=True,
    adam_epsilon=1e-6,
    warmup_steps=100,
    weight_decay=0.01,
    learning_rate=2e-4,
    generation_num_beams=10,
    generation_length_penalty=1.5,
    predict_with_generation=True,
    representation_search_type='BRUTEFORCE',
    
    # Names of the model-output attributes used for each purpose.
    output_representation_attribute='data_fused_repr',
    label_representation_attribute='data_repr',
    metadata_representation_attribute='data_repr',
    data_augmentation_attribute='data_repr',
    representation_attribute='data_fused_repr',
    clustering_representation_attribute='data_fused_repr',

    # Clustering-related settings.
    group_by_cluster=True,
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    use_data_metadata_for_clustering=True,
    clustering_type='EXPO',
    minimum_cluster_size=2,
    maximum_cluster_size=1600,

    metric_for_best_model='P@1',
    load_best_model_at_end=True,
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    
    use_distributional_representation=False,
    # NOTE(review): True here, while the model below is built with
    # use_encoder_parallel=False; confirm that the mismatch is intended.
    use_encoder_parallel=True,
    max_grad_norm=None, 
    fp16=True,
    
    label_names=['lnk2data_idx'],
    
    # Metadata pruning is switched off; the remaining prune_* values are still supplied.
    prune_metadata=False,
    num_metadata_prune_warmup_epochs=10,
    num_metadata_prune_epochs=5,
    metadata_prune_batch_size=2048,
    prune_metadata_names=['lnk_meta'],
    use_data_metadata_for_pruning=True,

    predict_with_augmentation=False,
    use_augmentation_index_representation=True,

    data_aug_meta_name='lnk',
    augmentation_num_beams=3,
    data_aug_prefix='lnk',
    use_label_metadata=False,
    
    data_meta_batch_size=2048,
    augment_metadata=False,
    num_metadata_augment_warmup_epochs=10,
    num_metadata_augment_epochs=5,

    # Read by create_optimizer_and_scheduler for the "meta_embeddings" optimizer:
    # warmup steps of its cosine schedule and the multiplier on learning_rate.
    free_parameter_warmup_steps=0,
    free_parameter_lr_coefficient=1000,
)


In [None]:
#| export
# Largest per-device batch size (train vs. eval) times the number of visible CUDA
# devices. NOTE: this evaluates to 0 when torch.cuda.device_count() is 0.
bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()

# Instantiate OAK001 from the msmarco DistilBERT checkpoint. `num_metadata` is taken
# from the training split's 'lnk_meta' metadata.
model = OAK001.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=bsz, num_batch_labels=5000, 
                               margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                               data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                               data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                               
                               num_metadata=block.train.dset.meta['lnk_meta'].n_meta, resize_length=5000,
                               
                               calib_margin=0.3, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                               calib_loss_weight=0.1, use_calib_loss=True,

                               use_query_loss=True,

                               meta_loss_weight=0.0, 
                               
                               fusion_loss_weight=0.1, use_fusion_loss=False,
                               
                               use_encoder_parallel=False)
# Initialise the heads and the metadata embedding table that the checkpoint does not provide.
model.init_retrieval_head()
model.init_cross_head()
model.init_meta_embeddings()

# Path to externally trained metadata embeddings. Currently unused: the two lines
# that would load and install it are commented out.
meta_embed_file = '/home/aiscuser/scratch/OGB_Weights/LF-WikiSeeAlsoTitles-320K/emb_weights.npy'
# meta_embeddings = np.load(meta_embed_file)
# model.encoder.set_pretrained_meta_embeddings(torch.tensor(meta_embeddings, dtype=torch.float32))

# Install an all-zeros (n_meta, model.config.dim) tensor as the pretrained metadata
# embeddings, then freeze them.
model.encoder.set_pretrained_meta_embeddings(torch.zeros(block.train.dset.meta['lnk_meta'].n_meta, model.config.dim))
model.encoder.freeze_pretrained_meta_embeddings()

Some weights of OAK001 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_fused_head.layer_norm.bias', 'encoder.dr_fused_head.layer_norm.weight', 'encoder.dr_fused_head.projector.bias', 'encoder.dr_fused_head.projector.weight', 'encoder.dr_fused_head.transform.bias', 'encoder.dr_fused_head.transform.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'encoder.meta_head.layer_norm.bias', 'encoder.meta_head.layer_norm.weight', 'encoder.meta_head.projector.bias', 'encoder.meta_head.projector.weight', 

In [None]:
#| export
# Evaluation metric object, built from the label count, the test-split label filterer
# and the training data-label matrix (passed as `prop`). The pk/rk and rep_pk/rep_rk
# values are presumably precision/recall cut-offs; confirm against PrecRecl.
metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

In [None]:
#| export
# Assemble the learner: model, arguments, train/eval datasets, collator and metric.
# Its optimizers and schedulers come from the create_optimizer_and_scheduler patch
# defined earlier in this notebook.
learn = XCLearner(
    model=model, 
    args=args,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [None]:
# Interactive training run (this cell has no `#| export` directive; the exported
# entry point is the `__main__` guard in the next cell).
learn.train()

In [None]:
#| export
# Entry point when the code runs as a script: call torch.multiprocessing's
# freeze_support() before starting training.
if __name__ == '__main__':
    mp.freeze_support()
    learn.train()

In [None]:
# Run prediction on the test split; `o.metrics` is displayed in the next cell.
o = learn.predict(block.test.dset)

  0%|          | 0/196 [00:00<?, ?it/s]

  return torch.sparse_csr_tensor(data_ptr, data_idx, scores, device=data_ptr.device)


  self._set_arrayXarray(i, j, x)


In [None]:
# Render the metrics of the prediction run as a table.
display_metric(o.metrics)

Unnamed: 0,P@1,P@3,P@5,P@10,N@1,N@3,N@5,N@10,PSP@1,PSP@3,PSP@5,PSP@10,PSN@1,PSN@3,PSN@5,PSN@10,R@10,R@100,R@200,loss,runtime,samples_per_second,steps_per_second
0,17.296,11.305,8.5469,5.5415,17.296,17.0792,17.7819,19.1372,16.3694,16.7477,17.8326,20.6162,16.3694,17.404,18.4949,20.0678,22.993,37.2897,41.666,0.9435,120.816,1469.3,0.919


In [None]:
from tqdm.auto import tqdm

In [None]:
# Training data-label matrix and 1600 random row indices (drawn with replacement)
# used by the lookup benchmarks below.
# NOTE: the RNG is not seeded, so `idx` differs between runs.
data_lbl = block.train.dset.data.data_lbl
idx = np.random.randint(data_lbl.shape[0], size=1600)

In [None]:
# Per-row label indices as plain Python lists, sliced straight from the CSR buffers
# instead of materialising one sparse row object per row (about 1.8M rows here).
# Assumes `data_lbl` is a SciPy sparse matrix; `tocsr()` returns the matrix itself
# when it is already CSR. Avoiding the loop variable `o` also stops it from reading
# like the prediction output bound to `o` earlier in the notebook.
_data_lbl_csr = data_lbl.tocsr()
_row_bounds = _data_lbl_csr.indptr.tolist()
data_lbl_indices = [_data_lbl_csr.indices[start:stop].tolist()
                    for start, stop in zip(_row_bounds[:-1], _row_bounds[1:])]

In [None]:
# Number of per-row index lists that were built.
len(data_lbl_indices)

1813391

In [None]:
%%timeit
# Baseline being timed: slice each sampled row out of the sparse matrix on demand.
indices = [data_lbl[i].indices.tolist() for i in idx]

107 ms ± 589 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
%%timeit
# Alternative being timed: look the same rows up in the precomputed Python lists.
indices = [data_lbl_indices[i] for i in idx]

177 µs ± 110 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
# Shorthand for the training split's data object (the one that owns `data_info`).
train_dset = block.train.dset.data

In [None]:
# Attach the precomputed per-row label indices as a new 'indices' entry.
# NOTE: this mutates the shared `data_info` dict of the training dataset in place.
train_dset.data_info['indices'] = data_lbl_indices

In [None]:
# Sanity check: every `data_info` field should report the same length.
{k:len(v) for k,v in train_dset.data_info.items()}

{'identifier': 1813391,
 'input_text': 1813391,
 'input_ids': 1813391,
 'attention_mask': 1813391,
 'indices': 1813391}