In [2]:
#| default_exp 03_benchmarking_nvembed_bm25

In [3]:
# Reload edited modules (e.g. xcai) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

from nbdev.showdoc import *
# nbdev: write this notebook's `#| export` cells to the `default_exp` module.
# NOTE(review): presumably reads the saved .ipynb from disk, so save before running to export current edits.
import nbdev; nbdev.nbdev_export()

In [4]:
#| export
import os,torch, torch.multiprocessing as mp, pickle, numpy as np, math, transformers
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

from xcai.basics import *

from xclib.utils.sparse import retain_topk

from fastcore.utils import *

In [5]:
# Interactive runs in this notebook do not log to W&B (this cell is not exported).
os.environ.update(WANDB_MODE='disabled')

In [6]:
#| export
# Exported defaults: restrict the run to GPUs 0 and 1 and name the W&B project.
# Both values overwrite anything already present in the environment.
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0,1',
    'WANDB_PROJECT': 'oakVn_00-wikiseealsotitles',
})

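## Hugging Face `NV-Embed-v2` example

Sanity check of `model.encode`: embed two instruction-prefixed queries and two passages, then score every query–passage pair.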

In [11]:
task_name_to_instruct = {"example": "Given a question, retrieve passages that answer the question"}

# Queries carry an "Instruct: ...\nQuery: " prefix; the passages below get an empty prefix.
query_prefix = f"Instruct: {task_name_to_instruct['example']}\nQuery: "
queries = ['are judo throws allowed in wrestling?',
           'how to become a radiology technician in michigan?']

# Passages are encoded with an empty instruction; only the queries are prefixed (see `query_prefix`).
passage_prefix = ""
# passages[i] is the relevant passage for queries[i].
passages = [
    "Since you're reading this, you are probably someone from a judo background or someone who is just wondering how judo techniques can be applied under wrestling rules. So without further ado, let's get to the question. Are Judo throws allowed in wrestling? Yes, judo throws are allowed in freestyle and folkstyle wrestling. You only need to be careful to follow the slam rules when executing judo throws. In wrestling, a slam is lifting and returning an opponent to the mat with unnecessary force.",
    "Below are the basic steps to becoming a radiologic technologist in Michigan:Earn a high school diploma. As with most careers in health care, a high school education is the first step to finding entry-level employment. Taking classes in math and science, such as anatomy, biology, chemistry, physiology, and physics, can help prepare students for their college studies and future careers.Earn an associate degree. Entry-level radiologic positions typically require at least an Associate of Applied Science. Before enrolling in one of these degree programs, students should make sure it has been properly accredited by the Joint Review Committee on Education in Radiologic Technology (JRCERT).Get licensed or certified in the state of Michigan."
]

In [25]:
# NOTE(review): trust_remote_code=True runs the model repo's custom Python code;
# consider pinning `revision=` for reproducibility.
model = AutoModel.from_pretrained('nvidia/NV-Embed-v2', trust_remote_code=True)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [14]:
# Embed the queries with their instruction prefix and the passages with the empty one.
# NOTE(review): 32768 is presumably only a truncation limit for `encode`; these texts are far shorter.
max_length = 32768
query_embeddings = model.encode(queries, instruction=query_prefix, max_length=max_length)
passage_embeddings = model.encode(passages, instruction=passage_prefix, max_length=max_length)

  'input_ids': torch.tensor(batch_dict.get('input_ids').to(batch_dict.get('input_ids')).long()),


In [16]:
# Query x passage similarity matrix scaled by 100: row i is query i, column j is passage j.
# The matching pairs should sit on the diagonal.
scores = 100 * (query_embeddings @ passage_embeddings.T)
print(scores.tolist())

[[87.42693328857422, 0.46283310651779175], [0.9652641415596008, 86.0372085571289]]


## Load data

In [6]:
# Cluster-specific locations: pickled data blocks, the data directory read by XCBlock.from_cfg, and run outputs.
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets/'
data_dir = '/home/scai/phd/aiz218323/Projects/XC_NLG/data'
output_dir = ('/home/scai/phd/aiz218323/scratch/outputs/mogic/'
              '03_benchmarking_nvembed_bm25')

In [7]:
# Tokenizer used below to re-tokenize the prompt-prefixed inputs (truncated to 64 tokens).
tokenizer = AutoTokenizer.from_pretrained('nvidia/NV-Embed-v2')

In [19]:
# Older DistilBERT-tokenized block; superseded by the NV-Embed-v2 path assigned in the next cell.
pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data_distilbert-base-uncased_xcs.pkl'

In [7]:
# Pickle of the NV-Embed-v2-tokenized, prompt-prefixed XCBlock (written and read further down).
pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data_nv-embed-v2_xcs.pkl'

In [9]:
# Build the train/test data block from `data_dir` with the NV-Embed-v2 tokenizer (64-token sequences).
block = XCBlock.from_cfg(
    data_dir, 'data',
    tokenizer='nvidia/NV-Embed-v2',
    transform_type='xcs',
    max_sequence_length=64,
    sampling_features=[('lbl2data', 1)],
    oversample=False,
)

In [73]:
def prompt_func(x):
    "Title+categories prompt variant; when run top to bottom it is shadowed by the redefinition in the next cell."
    task = ("Given the title of a wikipedia article and the corresponding categories of that article on wikipedia, "
            "your task is to predict the titles of all articles which are likely to be listed in the see also "
            "section of the mentioned article.")
    return f"Instruct: {task}\nQuery: {x}"
    

In [10]:
def prompt_func(x):
    "Wrap an article title `x` in the 'Instruct: ...\\nQuery: ...' prompt for predicting its See-Also titles."
    task = ("Given the title of a wikipedia article, your task is to predict the titles of all articles "
            "which are likely to be listed in the see also section of the mentioned article.")
    return f"Instruct: {task}\nQuery: {x}"
    

In [11]:
# Re-tokenize the data side of both splits with the instruction prompt prepended, overwriting the
# tokenizer outputs (e.g. input_ids / attention_mask) stored in each split's `data_info`.
# One loop instead of two copy-pasted stanzas so train and test can never be tokenized differently.
for split in (block.train, block.test):
    data_info = split.dset.data.data_info
    input_text = [prompt_func(o) for o in data_info['input_text']]
    tokenized_text = tokenizer.batch_encode_plus(input_text, truncation=True, max_length=64)
    data_info.update(tokenized_text)

In [12]:
# Cache the prompt-tokenized block so later sessions (and the driver with build_block=False) can load it.
with open(pkl_file, 'wb') as file: pickle.dump(block, file)

In [8]:
# NOTE: pickle.load can execute arbitrary code — only load a file this notebook/driver wrote.
with open(pkl_file, 'rb') as file: block = pickle.load(file)

In [9]:
# Collate one training batch to inspect its keys and tensor shapes.
batch = next(iter(block.train.dl))

In [10]:
# Fields produced by the collator (data_*, lbl2data_*, plbl2data_* groups).
batch.keys()

dict_keys(['plbl2data_idx', 'plbl2data_data2ptr', 'lbl2data_idx', 'lbl2data_identifier', 'lbl2data_input_text', 'lbl2data_input_ids', 'lbl2data_attention_mask', 'lbl2data_data2ptr', 'data_identifier', 'data_input_text', 'data_input_ids', 'data_attention_mask', 'data_idx'])

In [102]:
# NOTE(review): this loads a second full copy of NV-Embed-v2 next to `model` from the example
# section (same arguments); reusing that instance would save a lot of memory.
m = AutoModel.from_pretrained('nvidia/NV-Embed-v2', trust_remote_code=True)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [103]:
# Inspection-only forward pass on the collated batch. no_grad stops autograd from recording the
# graph (and holding every activation) through the whole model, which nothing here needs.
with torch.no_grad():
    o = m(input_ids=batch['data_input_ids'], attention_mask=batch['data_attention_mask'])



In [107]:
# 'sentence_embeddings' keeps a sequence axis (47 here, matching input_ids below), i.e. it looks un-pooled.
o['sentence_embeddings'].shape

torch.Size([10, 47, 4096])

In [11]:
# Token ids of the prompt-prefixed inputs: (batch, seq_len).
batch['data_input_ids'].shape

torch.Size([10, 47])

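## Driver

Exported training script: builds (or loads) the prompt-tokenized data block, then fine-tunes `NVM009` initialised from `nvidia/NV-Embed-v2` with `XCLearner`.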

In [None]:
#| export
def prompt_func(x):
    "Wrap an article title `x` in the 'Instruct: ...\\nQuery: ...' prompt for predicting its See-Also titles."
    task = ("Given the title of a wikipedia article, your task is to predict the titles of all articles "
            "which are likely to be listed in the see also section of the mentioned article.")
    return f"Instruct: {task}\nQuery: {x}"

In [None]:
#| export
if __name__ == '__main__':
    import sys

    # freeze_support() belongs right after the main guard; it is a no-op except for frozen Windows executables.
    mp.freeze_support()

    # True: tokenize the data with the instruction prompt, pickle the block and exit.
    # False: load the previously pickled block and train.
    build_block = True

    # Defaults are the original cluster paths; override them via the environment instead of editing the script.
    pkl_dir = os.environ.get('XC_PKL_DIR', '/home/scai/phd/aiz218323/scratch/datasets/')
    data_dir = os.environ.get('XC_DATA_DIR', '/home/scai/phd/aiz218323/Projects/XC_NLG/data')
    output_dir = os.environ.get('XC_OUTPUT_DIR', '/home/scai/phd/aiz218323/scratch/outputs/mogic/03_benchmarking_nvembed_bm25')

    # --- Load data ---
    pkl_file = f'{pkl_dir}/processed/wikiseealsotitles_data_nv-embed-v2_xcs.pkl'

    if build_block:
        block = XCBlock.from_cfg(data_dir, 'data', transform_type='xcs', tokenizer='nvidia/NV-Embed-v2',
                                 sampling_features=[('lbl2data',1)], max_sequence_length=64, oversample=False)

        # The notebook's `tokenizer` cell is not exported, so the script must create its own;
        # without this the exported module fails with NameError here.
        tokenizer = AutoTokenizer.from_pretrained('nvidia/NV-Embed-v2')

        # Overwrite each split's stored tokenization with that of the instruction-prefixed text.
        for dset in (block.train.dset, block.test.dset):
            input_text = [prompt_func(o) for o in dset.data.data_info['input_text']]
            tokenized_text = tokenizer.batch_encode_plus(input_text, truncation=True, max_length=64)
            dset.data.data_info.update(tokenized_text)

        with open(pkl_file, 'wb') as file: pickle.dump(block, file)
        # sys.exit rather than `exit`, which the site module provides for interactive use only.
        sys.exit()
    else:
        # NOTE: pickle.load can execute arbitrary code; only load a file produced by this script.
        with open(pkl_file, 'rb') as file: block = pickle.load(file)

    # --- Training arguments ---
    args = XCLearningArguments(
        output_dir=output_dir,
        logging_first_step=True,
        per_device_train_batch_size=800,
        per_device_eval_batch_size=800,
        representation_num_beams=200,
        representation_accumulation_steps=10,
        save_strategy="steps",
        evaluation_strategy="steps",
        eval_steps=5000,
        save_steps=5000,
        save_total_limit=5,
        num_train_epochs=300,
        predict_with_representation=True,
        adam_epsilon=1e-6,
        warmup_steps=100,
        weight_decay=0.01,
        learning_rate=2e-4,
        representation_search_type='BRUTEFORCE',

        output_representation_attribute='data_repr',
        label_representation_attribute='data_repr',
        metadata_representation_attribute='data_repr',
        data_augmentation_attribute='data_repr',
        representation_attribute='data_repr',
        clustering_representation_attribute='data_repr',

        group_by_cluster=True,
        num_clustering_warmup_epochs=10,
        num_cluster_update_epochs=5,
        num_cluster_size_update_epochs=25,
        use_data_metadata_for_clustering=True,
        clustering_type='EXPO',
        minimum_cluster_size=2,
        maximum_cluster_size=1600,

        metric_for_best_model='P@1',
        load_best_model_at_end=True,
        target_indices_key='plbl2data_idx',
        target_pointer_key='plbl2data_data2ptr',

        use_distributional_representation=False,
        use_encoder_parallel=True,
        max_grad_norm=None,
        fp16=True,

        label_names=['lbl2data_idx', 'lbl2data_input_ids', 'lbl2data_attention_mask'],

        predict_with_augmentation=False,
        use_augmentation_index_representation=True,

        use_cpu_for_searching=False,
        use_cpu_for_clustering=True,
    )

    # --- Model ---
    # torch.cuda.device_count() is 0 when no GPU is visible, which would make bsz 0.
    bsz = max(args.per_device_train_batch_size, args.per_device_eval_batch_size)*max(torch.cuda.device_count(), 1)
    # NOTE(review): args set use_encoder_parallel=True while the model gets False — confirm this is intended.
    model = NVM009.from_pretrained('nvidia/NV-Embed-v2', bsz=bsz, margin=0.3, tau=0.1, n_negatives=10, apply_softmax=True,
                                   use_encoder_parallel=False)

    model.init_dr_head()

    # --- Training ---
    metric = PrecRecl(block.n_lbl, block.test.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                      pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

    learn = XCLearner(
        model=model,
        args=args,
        train_dataset=block.train.dset,
        eval_dataset=block.test.dset,
        data_collator=block.collator,
        compute_metrics=metric,
    )

    learn.train()
    

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,P@1,P@10,P@3,P@5,N@1,N@10,N@3,N@5,Psp@1,Psp@10,Psp@3,Psp@5,Psn@1,Psn@10,Psn@3,Psn@5,R@200,R@10,R@100
10,0.0788,0.086276,0.175101,0.056812,0.115025,0.087132,0.175101,0.194849,0.173305,0.180583,0.163741,0.209324,0.168619,0.180018,0.163741,0.201765,0.174354,0.185467,0.429889,0.235065,0.384685


  0%|          | 0/15617 [00:00<?, ?it/s]

  self._set_arrayXarray(i, j, x)


  0%|          | 0/15617 [00:00<?, ?it/s]