In [9]:
import numpy as np
from nudge import NUDGEM, NUDGEN
from knnretriever import kNNRetriever
from utils import calc_metrics_batch, load_hf_datasets, embed_data_and_query_sets

In [None]:
# Which benchmark to run on and which embedding checkpoint to use.
dataset_name = 'nq'
embedding_model_name = "BAAI/bge-small-en-v1.5"

# Load the corpus together with its query splits (the later cells read the
# 'train', 'dev' and 'test' entries of `query_sets`).
dataset, query_sets = load_hf_datasets(dataset_name)

# Embed the corpus and the query sets; `query_sets` is replaced by the returned
# version, which later cells read 'q_embs' from.
data_emb, query_sets = embed_data_and_query_sets(
    dataset,
    query_sets,
    embedding_model_name,
)

loading dataset
loading qs train
loading qs dev
loading qs test




embedding data


In [None]:
# Split the corpus embeddings into a prefix that covers every document referenced
# by a train/dev query (`embeddings`) and the remaining suffix
# (`nontrain_embeddings`).
#
# `max_nontest_index` is one past the largest answer index seen in train/dev, so
# it can be used directly as a positional slice bound.
max_nontest_index = 0
for split in ("train", "dev"):
    answer_ids = [
        idx
        for answers in query_sets[split]['q_ans_indx']
        for idx in answers
    ]
    if not answer_ids:
        # A split without any answer indices contributes nothing; np.max would
        # raise on an empty sequence.
        continue
    max_nontest_index = max(int(np.max(answer_ids)) + 1, max_nontest_index)

# Positional slice (iloc), so the emptiness check agrees with the positional
# slicing of `data_emb` below regardless of the DataFrame's index labels.
nontrain_dataset = dataset.iloc[max_nontest_index:]
if nontrain_dataset.shape[0] == 0:
    # Every document falls in the train/dev prefix: nothing is held out.
    embeddings = data_emb
    nontrain_embeddings = None
else:
    embeddings = data_emb[:max_nontest_index]
    nontrain_embeddings = data_emb[max_nontest_index:]

In [None]:
# NUDGE-N: fine-tune the train/dev portion of the embeddings, then retrieve the
# top 10 results for the test queries.
nudgen = NUDGEN()
new_embs_nudgen = nudgen.finetune_embeddings(
    embeddings,
    query_sets['train'],
    query_sets['dev'],
    # Held-out embeddings are passed as a pair whose second slot is None here.
    # NOTE(review): confirm the meaning of the second element against nudge's API.
    (nontrain_embeddings, None),
)

# Retriever over the fine-tuned embeddings plus the untouched held-out ones,
# using kNNRetriever's default distance metric.
nudge_nret = kNNRetriever(new_embs_nudgen, nontrain_embeddings)
nudge_n_res = nudge_nret.retrieve_topk_from_emb_batch(
    k=10,
    q_embeds=query_sets['test']['q_embs'],
)

In [None]:
# NUDGE-M: fine-tune the train/dev portion of the embeddings, then retrieve the
# top 10 results for the test queries.
nudgem = NUDGEM()
new_embs_nudgem = nudgem.finetune_embeddings(
    embeddings,
    query_sets['train'],
    query_sets['dev'],
    # Held-out embeddings are passed as a pair whose second slot is None here.
    # NOTE(review): confirm the meaning of the second element against nudge's API.
    (nontrain_embeddings, None),
)

# Unlike the other retrievers in this notebook, this one is built with an
# explicit dot-product distance metric.
nudge_mret = kNNRetriever(
    new_embs_nudgem,
    nontrain_embeddings,
    dist_metric='dot',
)
nudge_m_res = nudge_mret.retrieve_topk_from_emb_batch(
    k=10,
    q_embeds=query_sets['test']['q_embs'],
)

In [None]:
# Baseline: retrieve with the original (not fine-tuned) embeddings so the NUDGE
# variants have a reference point. Uses kNNRetriever's default distance metric.
no_ft_ret = kNNRetriever(embeddings, nontrain_embeddings)
no_ft_res = no_ft_ret.retrieve_topk_from_emb_batch(
    k=10,
    q_embeds=query_sets['test']['q_embs'],
)

In [None]:
# Metrics to report, as (name, cutoff) pairs.
metrics = [('recall', 10), ('ndcg', 10)]

# Test-split labels: answer indices and their relevance values, shared by all
# three evaluations.
test_answers = query_sets['test']['q_ans_indx']
test_relevance = query_sets['test']['q_ans_indx_rel']

no_ft_accs = calc_metrics_batch(metrics, no_ft_res, test_answers, test_relevance)
nudgem_accs = calc_metrics_batch(metrics, nudge_m_res, test_answers, test_relevance)
nudgen_accs = calc_metrics_batch(metrics, nudge_n_res, test_answers, test_relevance)

# One report line per method. Every entry of `metrics` is printed (paired with
# the score at the same position), so adding a metric above needs no change
# here. Scores are scaled by 100 for display.
for label, accs in (
    ("No Fine-Tuning", no_ft_accs),
    ("NUDGE-M", nudgem_accs),
    ("NUDGE-N", nudgen_accs),
):
    summary = ", ".join(
        f"{name}@{k}: {acc*100:.1f}" for (name, k), acc in zip(metrics, accs)
    )
    print(f"{label} {summary}")
