In [None]:
#| default_exp 12_mogic-for-wikiseealsotitles-noise

In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
#| export
import os,torch,json, torch.multiprocessing as mp, joblib, numpy as np, scipy.sparse as sp
from transformers import DistilBertConfig

from xcai.main import *
from xcai.basics import *
from xcai.clustering.cluster import get_cluster_mapping, get_cluster_size

from xcai.models.oak import OAK008
from xcai.models.distillation import DTL004,TCH001

In [None]:
# Interactive-only setting (this cell carries no export directive): set WANDB_MODE to 'disabled'.
os.environ.update({'WANDB_MODE': 'disabled'})

In [None]:
#| export
# Environment for the exported script: visible CUDA devices and the W&B project name.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',
    'WANDB_PROJECT': 'mogicX_01-wikiseealsotitles',
})

## Setup

In [None]:
# Directory later passed to XCLearningArguments as `output_dir`.
output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/12_mogic-for-wikiseealsotitles-noise'

# Dataset configuration file (the `noise-050` linker config) and the key passed with it to build_block.
config_key = 'data_category'
config_file = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/(mapped)LF-WikiSeeAlsoTitles-320K/configs/data_category_linker_noise-050.json'

# Checkpoints: student encoder, teacher model, and the .npy file with initial metadata embeddings.
student_model = 'sentence-transformers/msmarco-distilbert-base-v4'
teacher_model = '/home/scai/phd/aiz218323/scratch/outputs/xc_nlg/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4/teacher/'
meta_embed_init_file = '/data/OGB_Weights/LF-WikiSeeAlsoTitles-320K/emb_weights.npy'

# Metadata name; used to form the `<meta_name>2data_*` fields and the `<meta_name>_meta` key.
meta_name = 'lnk'

In [None]:
# Interactive run switches; every inference/save switch is off here.
do_train_inference = do_test_inference = False
save_train_inference = save_test_inference = False
save_representation = False

# Sampler choice, test-only mode and pretrained-checkpoint use.
use_sxc_sampler = True
only_test = False
use_pretrained = False

# Root directory under which the cached dataset file name is formed.
pkl_dir = '/scratch/scai/phd/aiz218323/datasets/processed/'

In [None]:
# Cache file name = base name + sampler tag + optional only-test tag + extension.
pkl_file = ''.join([
    f'{pkl_dir}/mogicX/wikiseealsotitles-noise_data-category-linker_distilbert-base-uncased',
    '_sxc' if use_sxc_sampler else '_xcs',
    '_only-test' if only_test else '',
    '.joblib',
])

In [None]:
# Truthy as soon as any inference or save switch is set.
do_inference = (
    do_train_inference
    or do_test_inference
    or save_train_inference
    or save_test_inference
    or save_representation
)

In [None]:
# Display the cache path assembled above.
pkl_file

'/scratch/scai/phd/aiz218323/datasets/processed//mogicX/wikiseealsotitles-noise_data-category-linker_distilbert-base-uncased_sxc.joblib'

In [None]:
# NOTE(review): this cell rebinds `config_file` and `pkl_file`, so the cells that follow use
# `data_category.json` and the `wikiseealsotitles_data-category` cache instead of the noise-050
# values set earlier; `only_test` is not consulted here. Confirm this override is intended.
config_file = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/(mapped)LF-WikiSeeAlsoTitles-320K/configs/data_category.json'

pkl_file = ''.join([
    f'{pkl_dir}/mogicX/wikiseealsotitles_data-category_distilbert-base-uncased',
    '_sxc' if use_sxc_sampler else '_xcs',
    '.joblib',
])

In [None]:
%%time
# Make sure the directory that holds the cache file exists before calling build_block.
os.makedirs(os.path.dirname(pkl_file), exist_ok=True)

block = build_block(
    pkl_file,
    config_file,
    use_sxc_sampler,
    config_key,
    n_slbl_samples=4,
    main_oversample=False,
    n_sdata_meta_samples=3,
    meta_oversample=False,
    train_meta_topk=5,
    test_meta_topk=3,
)



CPU times: user 7min 46s, sys: 1min 3s, total: 8min 50s
Wall time: 5min 42s


In [None]:
# Learner arguments, grouped by concern. All values are passed by keyword.
args = XCLearningArguments(
    output_dir=output_dir,
    logging_first_step=True,

    # Batch sizes
    per_device_train_batch_size=512,
    per_device_eval_batch_size=512,

    # Training length, optimiser and precision
    num_train_epochs=300,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_steps=100,
    adam_epsilon=1e-6,
    max_grad_norm=None,
    fp16=True,

    # Evaluation / checkpoint cadence and best-model selection
    eval_strategy="steps",
    eval_steps=5000,
    save_strategy="steps",
    save_steps=5000,
    save_total_limit=5,
    metric_for_best_model='P@1',
    load_best_model_at_end=True,

    # Representation-based prediction and search
    predict_with_representation=True,
    representation_search_type='BRUTEFORCE',
    representation_num_beams=200,
    representation_accumulation_steps=10,
    use_distributional_representation=False,
    use_encoder_parallel=True,
    use_cpu_for_searching=True,

    # Representation attribute names
    output_representation_attribute='data_fused_repr',
    representation_attribute='data_fused_repr',
    clustering_representation_attribute='data_fused_repr',
    label_representation_attribute='data_repr',
    metadata_representation_attribute='data_repr',
    data_augmentation_attribute='data_repr',

    # Clustering
    group_by_cluster=True,
    clustering_type='EXPO',
    num_clustering_warmup_epochs=10,
    num_cluster_update_epochs=5,
    num_cluster_size_update_epochs=25,
    minimum_cluster_size=2,
    maximum_cluster_size=1600,
    use_data_metadata_for_clustering=True,
    use_cpu_for_clustering=True,

    # Target keys and label field names (the last four are derived from `meta_name`)
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
    label_names=[
        'lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr',
        'plbl2data_idx', 'plbl2data_data2ptr',
        *(f'{meta_name}2data_{field}' for field in ('idx', 'input_ids', 'attention_mask', 'data2ptr')),
    ],

    # Metadata pruning (prune_metadata is False)
    prune_metadata=False,
    prune_metadata_names=[f'{meta_name}_meta'],
    num_metadata_prune_warmup_epochs=10,
    num_metadata_prune_epochs=5,
    metadata_prune_batch_size=2048,
    use_data_metadata_for_pruning=True,

    # Metadata augmentation (augment_metadata is False)
    augment_metadata=False,
    predict_with_augmentation=False,
    use_augmentation_index_representation=True,
    data_aug_meta_name=meta_name,
    data_aug_prefix=meta_name,
    augmentation_num_beams=None,
    use_label_metadata=False,
    data_meta_batch_size=2048,
    num_metadata_augment_warmup_epochs=10,
    num_metadata_augment_epochs=5,
)


comet_ml version 3.39.1 is installed, but version 3.43.2 or higher is required. Please update comet_ml to the latest version to enable Comet logging with pip install 'comet-ml>=3.43.2'.


In [None]:
# Evaluation metric: built from the label count, the test split's filterer and the
# training split's `data_lbl` (passed as `prop`).
metric = PrecReclMrr(
    block.n_lbl,
    block.test.data_lbl_filterer,
    prop=block.train.dset.data.data_lbl,
    pk=10,
    rk=200,
    rep_pk=[1, 3, 5, 10],
    rep_rk=[10, 100, 200],
    mk=[5, 10, 20],
)

In [None]:
def model_fn(teacher_model, student_model, mname, meta_embed_init_file, do_inference, use_pretrained, bsz):
    """Assemble the teacher/student distillation model.

    Reads the notebook-level names `block` and `meta_name`.

    Args:
        teacher_model: Path given to `TCH001.from_pretrained`; the label-cluster cache
            `clusters_003.joblib` is read from / written to the same location.
        student_model: Checkpoint name given to `OAK008.from_pretrained`.
        mname: Checkpoint given to `DTL004.from_pretrained` on the load-from-checkpoint path.
        meta_embed_init_file: `.npy` file loaded with `np.load` and installed as the
            student's pretrained meta embeddings on the build-from-scratch path.
        do_inference: When falsy, the distillation model is built from a fresh `DistilBertConfig()`.
        use_pretrained: When truthy, the fresh-build path is taken even if `do_inference` is set.
        bsz: Batch size forwarded to both the student and the distillation wrapper.

    Returns:
        A `DTL004` model wrapping the student and the teacher.
    """
    cluster_sz = 3
    build_fresh = (not do_inference) or use_pretrained

    m_teacher = TCH001.from_pretrained(teacher_model, n_data=block.train.dset.n_data, n_lbl=block.n_lbl)
    m_teacher.freeze_embeddings()

    if build_fresh:
        # The label-to-cluster mapping is cached next to the teacher checkpoint.
        cluster_file = f'{teacher_model}/clusters_{cluster_sz:03d}.joblib'
        if not os.path.exists(cluster_file):
            label_cluster_mapping, n_clusters = get_cluster_mapping(m_teacher.lbl_repr.weight, cluster_sz=cluster_sz)
            joblib.dump((label_cluster_mapping, n_clusters), cluster_file)
        else:
            label_cluster_mapping, n_clusters = joblib.load(cluster_file)
    else:
        # Only the cluster count is needed here; the mapping itself is not computed.
        n_clusters = get_cluster_size(m_teacher.lbl_repr.weight.shape[0], cluster_sz=cluster_sz)

    m_student = OAK008.from_pretrained(
        student_model,
        batch_size=bsz,
        num_batch_labels=5000,
        margin=0.3,
        num_negatives=10,
        tau=0.1,
        apply_softmax=True,
        data_aug_meta_prefix=f'{meta_name}2data',
        lbl2data_aug_meta_prefix=None,
        data_pred_meta_prefix=None,
        lbl2data_pred_meta_prefix=None,
        num_metadata=block.train.dset.meta[f'{meta_name}_meta'].n_meta,
        resize_length=5000,
        n_clusters=n_clusters,
        n_labels=block.n_lbl,
        calib_margin=0.05,
        calib_num_negatives=10,
        calib_tau=0.1,
        calib_apply_softmax=False,
        calib_loss_weight=0.1,
        use_calib_loss=True,
        use_query_loss=True,
        meta_loss_weight=0.0,
        fusion_loss_weight=0.1,
        use_fusion_loss=False,
        use_encoder_parallel=True,
    )

    # Keyword arguments shared by both ways of constructing the distillation wrapper.
    distill_kwargs = dict(
        m_student=m_student,
        m_teacher=m_teacher,
        bsz=bsz,
        tn_targ=5000,
        margin=0.3,
        tau=0.1,
        n_negatives=10,
        apply_softmax=True,
        teacher_data_student_label_loss_weight=1.0,
        student_data_teacher_label_loss_weight=0.0,
        data_mse_loss_weight=0.1,
        label_mse_loss_weight=0.0,
    )

    if not build_fresh:
        return DTL004.from_pretrained(mname, **distill_kwargs)

    model = DTL004(DistilBertConfig(), **distill_kwargs)
    model.m_student.set_label_cluster_mapping(label_cluster_mapping)
    meta_embeddings = np.load(meta_embed_init_file)
    model.m_student.encoder.set_pretrained_meta_embeddings(torch.tensor(meta_embeddings, dtype=torch.float32))
    return model

def init_fn(model):
    """Call the model's four `init_*` hooks in order, then freeze the encoder's pretrained meta embeddings.

    Handed to `load_model` together with `model_fn`.
    """
    for hook in ('init_retrieval_head', 'init_cross_head', 'init_meta_embeddings', 'init_label_embeddings'):
        getattr(model, hook)()
    model.encoder.freeze_pretrained_meta_embeddings()
    

In [None]:
# Batch size given to the models: the larger per-device batch size times the CUDA device count.
bsz = torch.cuda.device_count() * max(args.per_device_train_batch_size, args.per_device_eval_batch_size)

model = load_model(
    args.output_dir,
    model_fn,
    dict(
        teacher_model=teacher_model,
        student_model=student_model,
        mname=None,
        meta_embed_init_file=meta_embed_init_file,
        do_inference=do_inference,
        use_pretrained=use_pretrained,
        bsz=bsz,
    ),
    init_fn,
    do_inference=do_inference,
    use_pretrained=use_pretrained,
)

Updating clusters with size 3
Tree depth = 17
doing random split
lengths: [156165, 156165]
remaining levels for GPU split=16
==> gpu splitting random clusters 0 to 2
[34m rank=0 => Total clusters 2	Avg. Cluster size                 78082.50	Time to split nodes on this level 0.30 sec
[34m rank=0 => Total clusters 4	Avg. Cluster size                 39041.25	Time to split nodes on this level 0.09 sec
[34m rank=0 => Total clusters 8	Avg. Cluster size                 19520.62	Time to split nodes on this level 0.12 sec
[34m rank=0 => Total clusters 16	Avg. Cluster size                 9760.31	Time to split nodes on this level 0.23 sec
[34m rank=0 => Total clusters 32	Avg. Cluster size                 4880.16	Time to split nodes on this level 0.25 sec
[34m rank=0 => Total clusters 64	Avg. Cluster size                 2440.08	Time to split nodes on this level 0.31 sec
[34m rank=0 => Total clusters 128	Avg. Cluster size                 1220.04	Time to split nodes on this level 1.57 sec


Some weights of OAK007 were not initialized from the model checkpoint at sentence-transformers/msmarco-distilbert-base-v4 and are newly initialized: ['encoder.cross_head.k.bias', 'encoder.cross_head.k.weight', 'encoder.cross_head.o.bias', 'encoder.cross_head.o.weight', 'encoder.cross_head.q.bias', 'encoder.cross_head.q.weight', 'encoder.cross_head.v.bias', 'encoder.cross_head.v.weight', 'encoder.dr_head.layer_norm.bias', 'encoder.dr_head.layer_norm.weight', 'encoder.dr_head.projector.bias', 'encoder.dr_head.projector.weight', 'encoder.dr_head.transform.bias', 'encoder.dr_head.transform.weight', 'label_cluster_mapping', 'label_embeddings.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Bundle the model, arguments, train/test datasets, collator and metric into the learner.
learn = XCLearner(
    model=model,
    args=args,
    train_dataset=block.train.dset,
    eval_dataset=block.test.dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

# NOTE(review): no cell before this one assigns `input_args`; in this notebook it is only
# created by `parse_args()` inside the Driver cell. Unless a star import provides it, this
# call raises NameError on a fresh kernel — confirm how this cell is meant to be run.
main(learn, input_args, n_lbl=block.n_lbl)

## Driver

In [None]:
#| export
if __name__ == '__main__':
    # Directory later passed to XCLearningArguments as `output_dir`.
    output_dir = '/scratch/scai/phd/aiz218323/outputs/mogicX/12_mogic-for-wikiseealsotitles-noise'

    # Dataset configuration file (the `noise-050` linker config) and the key passed with it to build_block.
    config_key = 'data_category'
    config_file = '/home/scai/phd/aiz218323/scratch/datasets/benchmarks/(mapped)LF-WikiSeeAlsoTitles-320K/configs/data_category_linker_noise-050.json'

    # Checkpoints: student encoder, teacher model, and the .npy file with initial metadata embeddings.
    student_model = 'sentence-transformers/msmarco-distilbert-base-v4'
    teacher_model = '/home/scai/phd/aiz218323/scratch/outputs/xc_nlg/67-ngame-ep-for-wikiseealso-with-input-concatenation-1-4/teacher/'
    meta_embed_init_file = '/data/OGB_Weights/LF-WikiSeeAlsoTitles-320K/emb_weights.npy'

    # Metadata name; used to form the `<meta_name>2data_*` fields and the `<meta_name>_meta` key.
    meta_name = 'lnk'

    input_args = parse_args()

    # Cache file name = base name + sampler tag + optional only-test tag + extension.
    pkl_file = ''.join([
        f'{input_args.pickle_dir}/mogicX/wikiseealsotitles-noise_data-category-linker_distilbert-base-uncased',
        '_sxc' if input_args.use_sxc_sampler else '_xcs',
        '_only-test' if input_args.only_test else '',
        '.joblib',
    ])

    # Truthy as soon as any inference/prediction/representation switch was requested.
    do_inference = (
        input_args.do_train_inference
        or input_args.do_test_inference
        or input_args.save_train_prediction
        or input_args.save_test_prediction
        or input_args.save_representation
    )

    # Make sure the directory that holds the cache file exists before calling build_block.
    os.makedirs(os.path.dirname(pkl_file), exist_ok=True)
    block = build_block(
        pkl_file,
        config_file,
        input_args.use_sxc_sampler,
        config_key,
        do_build=input_args.build_block,
        only_test=input_args.only_test,
        n_slbl_samples=4,
        main_oversample=False,
        n_sdata_meta_samples=3,
        meta_oversample=False,
        train_meta_topk=5,
        test_meta_topk=3,
    )

    # Learner arguments, grouped by concern. All values are passed by keyword.
    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,

        # Batch sizes
        per_device_train_batch_size=512,
        per_device_eval_batch_size=512,

        # Training length, optimiser and precision
        num_train_epochs=300,
        learning_rate=2e-4,
        weight_decay=0.01,
        warmup_steps=100,
        adam_epsilon=1e-6,
        max_grad_norm=None,
        fp16=True,

        # Evaluation / checkpoint cadence and best-model selection
        eval_strategy="steps",
        eval_steps=5000,
        save_strategy="steps",
        save_steps=5000,
        save_total_limit=5,
        metric_for_best_model='P@1',
        load_best_model_at_end=True,

        # Representation-based prediction and search
        predict_with_representation=True,
        representation_search_type='BRUTEFORCE',
        representation_num_beams=200,
        representation_accumulation_steps=10,
        use_distributional_representation=False,
        use_encoder_parallel=True,
        use_cpu_for_searching=True,

        # Representation attribute names
        output_representation_attribute='data_fused_repr',
        representation_attribute='data_fused_repr',
        clustering_representation_attribute='data_fused_repr',
        label_representation_attribute='data_repr',
        metadata_representation_attribute='data_repr',
        data_augmentation_attribute='data_repr',

        # Clustering
        group_by_cluster=True,
        clustering_type='EXPO',
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        minimum_cluster_size=2,
        maximum_cluster_size=1600,
        use_data_metadata_for_clustering=True,
        use_cpu_for_clustering=True,

        # Target keys and label field names (the last four are derived from `meta_name`)
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',
        label_names=[
            'lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr',
            'plbl2data_idx', 'plbl2data_data2ptr',
            *(f'{meta_name}2data_{field}' for field in ('idx', 'input_ids', 'attention_mask', 'data2ptr')),
        ],

        # Metadata pruning (prune_metadata is False)
        prune_metadata=False,
        prune_metadata_names=[f'{meta_name}_meta'],
        num_metadata_prune_warmup_epochs=10,
        num_metadata_prune_epochs=5,
        metadata_prune_batch_size=2048,
        use_data_metadata_for_pruning=True,

        # Metadata augmentation (augment_metadata is False)
        augment_metadata=False,
        predict_with_augmentation=False,
        use_augmentation_index_representation=True,
        data_aug_meta_name=meta_name,
        data_aug_prefix=meta_name,
        augmentation_num_beams=None,
        use_label_metadata=False,
        data_meta_batch_size=2048,
        num_metadata_augment_warmup_epochs=10,
        num_metadata_augment_epochs=5,
    )

    def model_fn(teacher_model, student_model, mname, meta_embed_init_file, do_inference, use_pretrained, bsz):
        """Assemble the teacher/student distillation model.

        Reads `block` and `meta_name` from the enclosing scope.

        Args:
            teacher_model: Path given to `TCH001.from_pretrained`; the label-cluster cache
                `clusters_003.joblib` is read from / written to the same location.
            student_model: Checkpoint name given to `OAK008.from_pretrained`.
            mname: Checkpoint given to `DTL004.from_pretrained` on the load-from-checkpoint path.
            meta_embed_init_file: `.npy` file loaded with `np.load` and installed as the
                student's pretrained meta embeddings on the build-from-scratch path.
            do_inference: When falsy, the distillation model is built from a fresh `DistilBertConfig()`.
            use_pretrained: When truthy, the fresh-build path is taken even if `do_inference` is set.
            bsz: Batch size forwarded to both the student and the distillation wrapper.

        Returns:
            A `DTL004` model wrapping the student and the teacher.
        """
        cluster_sz = 3
        build_fresh = (not do_inference) or use_pretrained

        m_teacher = TCH001.from_pretrained(teacher_model, n_data=block.train.dset.n_data, n_lbl=block.n_lbl)
        m_teacher.freeze_embeddings()

        if build_fresh:
            # The label-to-cluster mapping is cached next to the teacher checkpoint.
            cluster_file = f'{teacher_model}/clusters_{cluster_sz:03d}.joblib'
            if not os.path.exists(cluster_file):
                label_cluster_mapping, n_clusters = get_cluster_mapping(m_teacher.lbl_repr.weight, cluster_sz=cluster_sz)
                joblib.dump((label_cluster_mapping, n_clusters), cluster_file)
            else:
                label_cluster_mapping, n_clusters = joblib.load(cluster_file)
        else:
            # Only the cluster count is needed here; the mapping itself is not computed.
            n_clusters = get_cluster_size(m_teacher.lbl_repr.weight.shape[0], cluster_sz=cluster_sz)

        m_student = OAK008.from_pretrained(
            student_model,
            batch_size=bsz,
            num_batch_labels=5000,
            margin=0.3,
            num_negatives=10,
            tau=0.1,
            apply_softmax=True,
            data_aug_meta_prefix=f'{meta_name}2data',
            lbl2data_aug_meta_prefix=None,
            data_pred_meta_prefix=None,
            lbl2data_pred_meta_prefix=None,
            num_metadata=block.train.dset.meta[f'{meta_name}_meta'].n_meta,
            resize_length=5000,
            n_clusters=n_clusters,
            n_labels=block.n_lbl,
            calib_margin=0.05,
            calib_num_negatives=10,
            calib_tau=0.1,
            calib_apply_softmax=False,
            calib_loss_weight=0.1,
            use_calib_loss=True,
            use_query_loss=True,
            meta_loss_weight=0.0,
            fusion_loss_weight=0.1,
            use_fusion_loss=False,
            use_encoder_parallel=True,
        )

        # Keyword arguments shared by both ways of constructing the distillation wrapper.
        distill_kwargs = dict(
            m_student=m_student,
            m_teacher=m_teacher,
            bsz=bsz,
            tn_targ=5000,
            margin=0.3,
            tau=0.1,
            n_negatives=10,
            apply_softmax=True,
            teacher_data_student_label_loss_weight=1.0,
            student_data_teacher_label_loss_weight=0.0,
            data_mse_loss_weight=0.1,
            label_mse_loss_weight=0.0,
        )

        if not build_fresh:
            return DTL004.from_pretrained(mname, **distill_kwargs)

        model = DTL004(DistilBertConfig(), **distill_kwargs)
        model.m_student.set_label_cluster_mapping(label_cluster_mapping)
        meta_embeddings = np.load(meta_embed_init_file)
        model.m_student.encoder.set_pretrained_meta_embeddings(torch.tensor(meta_embeddings, dtype=torch.float32))
        return model

    def init_fn(model):
        """Call the model's four `init_*` hooks in order, then freeze the encoder's pretrained meta embeddings.

        Handed to `load_model` together with `model_fn`.
        """
        for hook in ('init_retrieval_head', 'init_cross_head', 'init_meta_embeddings', 'init_label_embeddings'):
            getattr(model, hook)()
        model.encoder.freeze_pretrained_meta_embeddings()
        
    # Batch size given to the models: the larger per-device batch size times the number of
    # CUDA devices torch reports.
    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*torch.cuda.device_count()

    # `use_pretrained` is read from the parsed command-line arguments in both places.
    # No exported cell assigns a bare `use_pretrained`, so referencing it in the model_fn
    # kwargs would fail in the exported script and would disagree with the flag that
    # `load_model` itself receives.
    model = load_model(args.output_dir, model_fn, {'teacher_model': teacher_model, 'student_model': student_model, 'mname': None, 
                                                   'meta_embed_init_file': meta_embed_init_file, 'do_inference': do_inference, 
                                                   'use_pretrained': input_args.use_pretrained, 'bsz': bsz},
                       init_fn, do_inference=do_inference, use_pretrained=input_args.use_pretrained)
    
    # Evaluation metric: built from the label count, the test split's filterer and the
    # training split's `data_lbl` (passed as `prop`).
    metric = PrecReclMrr(
        block.n_lbl,
        block.test.data_lbl_filterer,
        prop=block.train.dset.data.data_lbl,
        pk=10,
        rk=200,
        rep_pk=[1, 3, 5, 10],
        rep_rk=[10, 100, 200],
        mk=[5, 10, 20],
    )

    # Bundle the model, arguments, train/test datasets, collator and metric into the learner.
    learn = XCLearner(
        model=model,
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )

    # Hand over to the shared entry point with the parsed command-line arguments.
    main(learn, input_args, n_lbl=block.n_lbl)
    