In [1]:
#| default_exp 36_oak-for-wikiseealsotitles-free-param-clustering

In [2]:
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, numpy as np
from transformers import DistilBertConfig

from xcai.basics import *
from xcai.models.oak import OAK010
from xcai.clustering.cluster import BalancedClusters

from xclib.utils.sparse import retain_topk

from fastcore.utils import *

In [4]:
# Notebook-only setting (this cell carries no `#| export` directive): puts
# Weights & Biases into its "disabled" mode for interactive runs.
os.environ['WANDB_MODE'] = 'disabled'

In [5]:
#| export
# Exported process-wide settings.
# Expose only GPUs 0 and 1 to this process.
# NOTE(review): this presumably has to run before CUDA is first initialised to take effect — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
# Weights & Biases project under which runs are grouped.
os.environ['WANDB_PROJECT']='oak_00-wikiseealsotitles'

## Load data

In [6]:
# Switch between rebuilding the block from the raw data (True) and reading the
# previously pickled block from disk (False).
build_block = False

data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'

pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data-lnk_distilbert-base-uncased_xcs.pkl'

if not build_block:
    # Fast path: reuse the cached block.
    # NOTE: pickle executes arbitrary code on load; only use trusted cache files.
    with open(pkl_file, 'rb') as fh:
        block = pickle.load(fh)
else:
    # Slow path: build the block from the config, then refresh the cache.
    block = XCBlock.from_cfg(
        data_dir, 'data_lnk',
        transform_type='xcs',
        tokenizer='distilbert-base-uncased',
        sampling_features=[('lbl2data',4), ('lnk2data',3)],
        oversample=False,
    )
    with open(pkl_file, 'wb') as fh:
        pickle.dump(block, fh)

## Metadata clusters

In [7]:
#| export
def get_metadata_remap(meta_repr:torch.Tensor, cluster_sz:int=3):
    """Cluster metadata representations and build an index -> cluster-id lookup.

    Parameters
    ----------
    meta_repr:
        One representation per metadata item along the first axis. A
        `torch.Tensor` or anything `torch.as_tensor` accepts (e.g. the NumPy
        array returned by `np.load`) is supported.
    cluster_sz:
        Forwarded to `BalancedClusters.proc` as `min_cluster_sz`.

    Returns
    -------
    (meta_remap, n_clusters):
        `meta_remap` is an int64 tensor of length `meta_repr.shape[0]` where
        `meta_remap[j]` is the position of the cluster that lists item `j`;
        `n_clusters` is `len(clusters)`.
    """
    # Coerce first so array inputs do not fail on the tensor-only `.half()`;
    # an existing tensor is passed through unchanged by `as_tensor`.
    meta_repr = torch.as_tensor(meta_repr)

    # The representations are handed to the clustering routine in half precision.
    # NOTE(review): `clusters` is assumed to be a sized iterable of index
    # collections into `meta_repr` — confirm against BalancedClusters.
    clusters = BalancedClusters.proc(meta_repr.half(), min_cluster_sz=cluster_sz)

    # Items that appear in no cluster keep the initial value 0.
    meta_remap = torch.zeros(meta_repr.shape[0], dtype=torch.int64)
    for cluster_id, members in enumerate(clusters):
        meta_remap[members] = cluster_id

    return meta_remap, len(clusters)
    

In [8]:
# Location of the metadata embedding matrix on disk; loading it is currently
# switched off in favour of the random matrix created below.
meta_embed_file = '/home/aiscuser/scratch/OGB_Weights/LF-WikiSeeAlsoTitles-320K/emb_weights.npy'
# meta_embeddings = np.load(meta_embed_file)

# Random stand-in: one 768-dimensional row per entry of the 'lnk_meta' metadata.
meta_embeddings = torch.randn((block.train.dset.meta['lnk_meta'].n_meta, 768))

# Cluster the metadata rows and obtain the index -> cluster lookup.
meta_remap, n_clusters = get_metadata_remap(meta_embeddings, cluster_sz=3)

Updating clusters with size 3
Tree depth = 18
doing random split
lengths: [328043, 328043]
remaining levels for GPU split=17
==> gpu splitting random clusters 0 to 2
 rank=0 => Total clusters 2	Avg. Cluster size                 164021.50	Time to split nodes on this level 0.17 sec
 rank=0 => Total clusters 4	Avg. Cluster size                 82010.75	Time to split nodes on this level 0.06 sec
 rank=0 => Total clusters 8	Avg. Cluster size                 41005.38	Time to split nodes on this level 0.04 sec
 rank=0 => Total clusters 16	Avg. Cluster size                 20502.69	Time to split nodes on this level 0.06 sec
 rank=1 => Total clusters 2	Avg. Cluster size                 164021.50	Time to split nodes on this level 0.16 sec
 rank=1 => Total clusters 4	Avg. Cluster size                 82010.75	Time to split nodes on this level 0.03 sec
 rank=0 => Total clusters 32	Avg. Cluster size                 10251.34	Time to split nodes on this level 0.49 sec
 rank=1 => Total clusters 8	Avg.

Process SpawnProcess-3:
Traceback (most recent call last):
  File "/home/scai/phd/aiz218323/scratch/anaconda3/envs/xc_nlg_2/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/scai/phd/aiz218323/scratch/anaconda3/envs/xc_nlg_2/lib/python3.9/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/scratch/scai/phd/aiz218323/Projects/xcai/xcai/clustering/fast_cluster.py", line 105, in balanced_cluster_gpu
    new_clusters += split_cluster(embs[cluster], cluster)
  File "/scratch/scai/phd/aiz218323/Projects/xcai/xcai/clustering/fast_cluster.py", line 75, in split_cluster
    _centeroids = labels_features[cluster]
KeyboardInterrupt


KeyboardInterrupt: 

## Driver

In [None]:
#| export
if __name__ == '__main__':
    # Driver: load the cached data block, trim the link metadata, cluster the
    # metadata embeddings, build the OAK010 model and launch training.
    build_block = False
    pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
    data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'
    
    output_dir = '/home/scai/phd/aiz218323/scratch/outputs/medic/36_oak-for-wikiseealsotitles-free-param-clustering'
    meta_embed_file = '/home/aiscuser/scratch/OGB_Weights/LF-WikiSeeAlsoTitles-320K/emb_weights.npy'

    # --- Load data ---
    pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data-lnk_distilbert-base-uncased_xcs.pkl'

    if build_block:
        block = XCBlock.from_cfg(data_dir, 'data_lnk', transform_type='xcs', tokenizer='distilbert-base-uncased', 
                                 sampling_features=[('lbl2data',4), ('lnk2data',3)], oversample=False)
        with open(pkl_file, 'wb') as file: pickle.dump(block, file)
    else:
        # NOTE: pickle executes arbitrary code on load; only use trusted cache files.
        with open(pkl_file, 'rb') as file: block = pickle.load(file)

    # --- Prune metadata ---
    # Keep the top 5 link-metadata entries per training row and the top 3 per test row.
    data_meta = retain_topk(block.train.dset.meta.lnk_meta.data_meta, k=5)
    lbl_meta = block.train.dset.meta.lnk_meta.lbl_meta
    block.train.dset.meta.lnk_meta.update_meta_matrix(data_meta, lbl_meta)
    
    data_meta = retain_topk(block.test.dset.meta.lnk_meta.data_meta, k=3)
    lbl_meta = block.test.dset.meta.lnk_meta.lbl_meta
    block.test.dset.meta.lnk_meta.update_meta_matrix(data_meta, lbl_meta)

    # Re-apply the sampling configuration on the (possibly unpickled) collator.
    block.collator.tfms.tfms[0].sampling_features = [('lbl2data',4),('lnk2data',3)]
    block.collator.tfms.tfms[0].oversample = False
    
    block.train.dset.meta.lnk_meta.meta_info = None
    block.test.dset.meta.lnk_meta.meta_info = None

    # --- Training arguments ---
    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        evaluation_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,
        representation_search_type='BRUTEFORCE',
        
        output_representation_attribute='data_fused_repr',
        label_representation_attribute='data_repr',
        metadata_representation_attribute='data_repr',
        data_augmentation_attribute='data_repr',
        representation_attribute='data_fused_repr',
        clustering_representation_attribute='data_fused_repr',
    
        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        use_data_metadata_for_clustering=True,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',
        
        use_distributional_representation=False,
        use_encoder_parallel=True,
        max_grad_norm=None, 
        fp16=True,
        
        label_names=['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lnk2data_idx'],
        
        prune_metadata=False,
        num_metadata_prune_warmup_epochs=10,
        num_metadata_prune_epochs=5,
        metadata_prune_batch_size=2048,
        prune_metadata_names=['lnk_meta'],
        use_data_metadata_for_pruning=True,
    
        predict_with_augmentation=False,
        use_augmentation_index_representation=True,
    
        data_aug_meta_name='lnk',
        augmentation_num_beams=None,
        data_aug_prefix='lnk',
        use_label_metadata=False,
        
        data_meta_batch_size=2048,
        augment_metadata=False,
        num_metadata_augment_warmup_epochs=10,
        num_metadata_augment_epochs=5,
    
        use_cpu_for_searching=True,
        use_cpu_for_clustering=True,
    )
    
    # --- Metadata clustering ---
    # `meta_embeddings` must be created inside the driver: the notebook cell that
    # defines it is not exported, so the exported script would otherwise hit a
    # NameError on the call below. The random matrix mirrors that notebook cell;
    # loading the real embeddings from `meta_embed_file` is currently disabled.
    # meta_embeddings = np.load(meta_embed_file)
    meta_embeddings = torch.randn(block.train.dset.meta['lnk_meta'].n_meta, 768)
    meta_remap, n_clusters = get_metadata_remap(meta_embeddings, 3)

    # --- Model ---
    # Batch size seen by the model: the larger per-device batch size times the GPU count.
    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()
    model = OAK010.from_pretrained('sentence-transformers/msmarco-distilbert-base-v4', batch_size=bsz, num_batch_labels=5000, 
                                   margin=0.3, num_negatives=10, tau=0.1, apply_softmax=True,
                               
                                   data_aug_meta_prefix='lnk2data', lbl2data_aug_meta_prefix=None, 
                                   data_pred_meta_prefix=None, lbl2data_pred_meta_prefix=None,
                                   
                                   n_metadata=block.train.dset.meta['lnk_meta'].n_meta, n_clusters=n_clusters, resize_length=5000,
                                   
                                   calib_margin=0.05, calib_num_negatives=10, calib_tau=0.1, calib_apply_softmax=False, 
                                   calib_loss_weight=0.1, use_calib_loss=False,
    
                                   use_query_loss=True,
    
                                   meta_loss_weight=0.0, 
                                   
                                   fusion_loss_weight=0.0, use_fusion_loss=False,
                                   
                                   use_encoder_parallel=True)
    
    model.init_retrieval_head()
    model.init_cross_head()
    model.init_meta_embeddings()
    
    # The pretrained metadata embeddings are set to zeros and frozen here.
    # model.encoder.set_pretrained_meta_embeddings(torch.tensor(meta_embeddings, dtype=torch.float32))
    model.encoder.set_pretrained_meta_embeddings(torch.zeros(block.train.dset.meta['lnk_meta'].n_meta, model.config.dim))
    model.encoder.freeze_pretrained_meta_embeddings()
    
    model.encoder.set_metadata_remap(meta_remap)
    
    # --- Training ---
    metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                      pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

    learn = XCLearner(
        model=model, 
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )
    
    mp.freeze_support()
    learn.train()
    

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,P@1,P@10,P@3,P@5,N@1,N@10,N@3,N@5,Psp@1,Psp@10,Psp@3,Psp@5,Psn@1,Psn@10,Psn@3,Psn@5,R@200,R@10,R@100
10,0.0788,0.086276,0.175101,0.056812,0.115025,0.087132,0.175101,0.194849,0.173305,0.180583,0.163741,0.209324,0.168619,0.180018,0.163741,0.201765,0.174354,0.185467,0.429889,0.235065,0.384685


  0%|          | 0/15617 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)


  0%|          | 0/15617 [00:00<?, ?it/s]