In [1]:
import numpy as np
from nudge import NUDGEM, NUDGEN
from util.knnretriever import kNNRetriever
from util.utils import calc_metrics_batch, load_hf_datasets, embed_data_and_query_sets

  from tqdm.autonotebook import tqdm, trange


In [2]:
# Pick the benchmark and the embedding model, then load the corpus with its
# query sets and embed both.
dataset_name = 'nq'
embedding_model_name = "BAAI/bge-small-en-v1.5"

dataset, query_sets = load_hf_datasets(dataset_name)
# `query_sets` is rebound to the value returned by the embedding step; later
# cells read query_sets[...]['q_embs'] from it.
data_emb, query_sets = embed_data_and_query_sets(
    dataset,
    query_sets,
    embedding_model_name,
)

loading dataset


README.md:   0%|          | 0.00/814 [00:00<?, ?B/s]

data.parquet:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7631395 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/7631395 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7631395 [00:00<?, ? examples/s]

loading qs train


qs.parquet:   0%|          | 0.00/14.1M [00:00<?, ?B/s]

qs.parquet:   0%|          | 0.00/1.83M [00:00<?, ?B/s]

qs.parquet:   0%|          | 0.00/2.33M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/61804 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/7978 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

qs_rel.parquet:   0%|          | 0.00/841k [00:00<?, ?B/s]

qs_rel.parquet:   0%|          | 0.00/114k [00:00<?, ?B/s]

qs_rel.parquet:   0%|          | 0.00/147k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/81364 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/10495 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/13148 [00:00<?, ? examples/s]

loading qs dev
loading qs test
embedding data


Batches:   0%|          | 0/238482 [00:00<?, ?it/s]

embedding qs train


Batches:   0%|          | 0/1932 [00:00<?, ?it/s]

embedding qs dev


Batches:   0%|          | 0/250 [00:00<?, ?it/s]

embedding qs test


Batches:   0%|          | 0/313 [00:00<?, ?it/s]

In [3]:
# Split the corpus at the largest answer index referenced by the train/dev
# queries. Rows before the boundary become `embeddings` (the rows handed to
# fine-tuning in later cells); rows from the boundary onward become
# `nontrain_embeddings`.
#
# Despite its name, `max_nontest_index` ends up as an *exclusive* upper bound:
# (largest train/dev answer index) + 1.
max_nontest_index = -1
for split in ["train", "dev"]:
    # Flatten the per-query answer-index lists of this split into one array.
    split_ans_indx = np.array([indx for curr_q_ans_indx in query_sets[split]['q_ans_indx'] for indx in curr_q_ans_indx])
    max_nontest_index = max(split_ans_indx.max() + 1, max_nontest_index)

# Slice by position (.iloc), not by label (.loc): `data_emb` below is sliced
# positionally with the same boundary, so both splits must use the same rule
# even if `dataset` does not carry a default 0..n-1 index.
# NOTE(review): assumes `dataset` is a pandas DataFrame (it is used with
# .loc/.shape in the original) - confirm against load_hf_datasets.
nontrain_dataset = dataset.iloc[max_nontest_index:]
if nontrain_dataset.shape[0] == 0:
    # Every row is covered by train/dev answers: nothing is left over.
    embeddings = data_emb
    nontrain_embeddings = None
else:
    embeddings = data_emb[:max_nontest_index]
    nontrain_embeddings = data_emb[max_nontest_index:]

In [5]:
# Fine-tune the train-side embeddings with NUDGEN using the train and dev
# query sets, then retrieve the top 10 results for every test query.
nudgen = NUDGEN()
new_embs_nudgen = nudgen.finetune_embeddings(
    embeddings,
    query_sets['train'],
    query_sets['dev'],
    nontrain_embeddings,
)
nudge_nret = kNNRetriever(new_embs_nudgen, nontrain_embeddings)
nudge_n_res = nudge_nret.retrieve_topk_from_emb_batch(
    k=10,
    q_embeds=query_sets['test']['q_embs'],
)

Calculating G
Finding gamma


In [6]:
# Same pipeline as the NUDGEN cell, but with NUDGEM. This is the only
# retriever in the notebook built with an explicit dist_metric ('dot'); the
# others rely on kNNRetriever's default.
nudgem = NUDGEM()
new_embs_nudgem = nudgem.finetune_embeddings(
    embeddings,
    query_sets['train'],
    query_sets['dev'],
    nontrain_embeddings,
)
nudge_mret = kNNRetriever(
    new_embs_nudgem,
    nontrain_embeddings,
    dist_metric='dot',
)
nudge_m_res = nudge_mret.retrieve_topk_from_emb_batch(
    k=10,
    q_embeds=query_sets['test']['q_embs'],
)

Calculating G
Finding gamma


In [7]:
# Baseline: retrieve with the original (not fine-tuned) embeddings so the
# NUDGE results have something to be compared against.
no_ft_ret = kNNRetriever(embeddings, nontrain_embeddings)
no_ft_res = no_ft_ret.retrieve_topk_from_emb_batch(
    k=10,
    q_embeds=query_sets['test']['q_embs'],
)

In [8]:
# Score the three sets of test-query retrieval results with the same metrics
# and print them scaled by 100 with one decimal place.
metrics = [('recall', 10), ('ndcg', 10)]
test_queries = query_sets['test']


def _test_scores(results):
    """Run calc_metrics_batch on `results` against the test-set answers."""
    return calc_metrics_batch(
        metrics,
        results,
        test_queries['q_ans_indx'],
        test_queries['q_ans_indx_rel'],
    )


no_ft_accs = _test_scores(no_ft_res)
nudgem_accs = _test_scores(nudge_m_res)
nudgen_accs = _test_scores(nudge_n_res)

print(f"No Fine-Tuning {metrics[0][0]}@{metrics[0][1]}: {no_ft_accs[0]*100:.1f}, {metrics[1][0]}@{metrics[1][1]}: {no_ft_accs[1]*100:.1f}")
print(f"NUDGE-M {metrics[0][0]}@{metrics[0][1]}: {nudgem_accs[0]*100:.1f}, {metrics[1][0]}@{metrics[1][1]}: {nudgem_accs[1]*100:.1f}")
print(f"NUDGE-N {metrics[0][0]}@{metrics[0][1]}: {nudgen_accs[0]*100:.1f}, {metrics[1][0]}@{metrics[1][1]}: {nudgen_accs[1]*100:.1f}")


No Fine-Tuning recall@10: 36.3, ndcg@10: 21.2
NUDGE-M recall@10: 43.6, ndcg@10: 38.6
NUDGE-N recall@10: 58.0, ndcg@10: 45.9
