In [14]:
!pip install beir

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




# Experiment 1: exact k-NN retrieval baseline (`KNNSearch`, Euclidean distance)

In [3]:
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from sentence_transformers import SentenceTransformer
# from beir.retrieval.search.dense.prob_index_search import ProbRankModel, ProbRankModelTrainer, DataLoader
from beir.retrieval.embedding import SentenceTransformerEmbedding 
from beir.retrieval.search.dense import knn_search
import numpy as np
import logging
import pathlib, os
import random


In [4]:

#### Route INFO-level log records through BEIR's LoggingHandler so progress and metrics appear in the cell output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[LoggingHandler()],
)

In [5]:
dataset = "nfcorpus"

#### Fetch the BEIR archive for `dataset` and extract it into the project's Drive folder.
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
out_dir = os.path.join("./drive/MyDrive/NNsearch-project/Datasets/", "nfcorpus-1")
data_path = util.download_and_unzip(url, out_dir)

#### Each split is read with its own GenericDataLoader; load() yields (corpus, queries, qrels).
[(train_corpus, train_queries, train_qrels),
 (dev_corpus, dev_queries, dev_qrels),
 (test_corpus, test_queries, test_qrels)] = [
    GenericDataLoader(data_folder=data_path).load(split=split_name)
    for split_name in ("train", "dev", "test")
]

2023-11-28 13:21:10 - Loading Corpus...


100%|██████████| 3633/3633 [00:00<00:00, 193753.10it/s]


2023-11-28 13:21:10 - Loaded 3633 TRAIN Documents.
2023-11-28 13:21:10 - Doc Example: {'text': 'Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 participants di

100%|██████████| 3633/3633 [00:00<00:00, 316126.02it/s]


2023-11-28 13:21:10 - Loaded 3633 DEV Documents.
2023-11-28 13:21:10 - Doc Example: {'text': 'Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 participants died

100%|██████████| 3633/3633 [00:00<00:00, 364682.81it/s]

2023-11-28 13:21:10 - Loaded 3633 TEST Documents.
2023-11-28 13:21:10 - Doc Example: {'text': 'Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 participants die




In [33]:
# Exact k-NN baseline: KNNSearch over the train split, scored with score_function="euclidean".
# NOTE(review): in the recorded output NDCG@10 (0.0098) is far below NDCG@1000 (0.18) while
# Recall@1000 is 0.62 - the pattern you would expect if KNNSearch returns raw distances as scores,
# because BEIR's evaluator ranks higher scores first (nearest documents would then rank last).
# Confirm the score sign in knn_search. Also note this evaluates the *train* split, whereas
# Experiment 2 reports test-split metrics, so the two are not directly comparable.
model = knn_search.KNNSearch()
retriever = EvaluateRetrieval(model, score_function="euclidean")
results = retriever.retrieve(train_corpus, train_queries)

logging.info(f"Retriever evaluation for k in: {retriever.k_values}")
ndcg, _map, recall, precision = retriever.evaluate(
    train_qrels,
    results,
    retriever.k_values,
)

Batches: 100%|██████████| 1/1 [00:00<00:00, 82.66it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 480.89it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 704.45it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 711.74it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 713.68it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 723.78it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 701.86it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 688.04it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 697.77it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 677.48it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 700.57it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 692.47it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 710.78it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 721.54it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 720.55it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 715.02it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 715.39it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 715.7

2023-11-26 20:12:21 - Retriever evaluation for k in: [1, 3, 5, 10, 100, 1000]
2023-11-26 20:12:21 - For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2023-11-26 20:12:21 - 

2023-11-26 20:12:21 - NDCG@1: 0.0031
2023-11-26 20:12:21 - NDCG@3: 0.0114
2023-11-26 20:12:21 - NDCG@5: 0.0103
2023-11-26 20:12:21 - NDCG@10: 0.0098
2023-11-26 20:12:21 - NDCG@100: 0.0154
2023-11-26 20:12:21 - NDCG@1000: 0.1802
2023-11-26 20:12:21 - 

2023-11-26 20:12:21 - MAP@1: 0.0000
2023-11-26 20:12:21 - MAP@3: 0.0007
2023-11-26 20:12:21 - MAP@5: 0.0008
2023-11-26 20:12:21 - MAP@10: 0.0009
2023-11-26 20:12:21 - MAP@100: 0.0015
2023-11-26 20:12:21 - MAP@1000: 0.0100
2023-11-26 20:12:21 - 

2023-11-26 20:12:21 - Recall@1: 0.0000
2023-11-26 20:12:21 - Recall@3: 0.0018
2023-11-26 20:12:21 - Recall@5: 0.0019
2023-11-26 20:12:21 - Recall@10: 0.0031
2023-11-26 20:12:21 - Recall@100: 0.0209
2023-11-26 20:12:21 - Recall@1000: 0.6233

In [34]:
# Show the four metric dictionaries returned by retriever.evaluate, one per line.
print(
    *(f"{name}: {value}" for name, value in (
        ("ndcg", ndcg),
        ("_map", _map),
        ("_recall", recall),
        ("precision", precision),
    )),
    sep="\n",
)

ndcg: {'NDCG@1': 0.0031, 'NDCG@3': 0.01136, 'NDCG@5': 0.01027, 'NDCG@10': 0.00984, 'NDCG@100': 0.01544, 'NDCG@1000': 0.18017}
_map: {'MAP@1': 3e-05, 'MAP@3': 0.00073, 'MAP@5': 0.00076, 'MAP@10': 0.0009, 'MAP@100': 0.00155, 'MAP@1000': 0.01004}
_recall: {'Recall@1': 3e-05, 'Recall@3': 0.00184, 'Recall@5': 0.00194, 'Recall@10': 0.00306, 'Recall@100': 0.02088, 'Recall@1000': 0.6233}
precision: {'P@1': 0.0031, 'P@3': 0.01445, 'P@5': 0.01176, 'P@10': 0.01022, 'P@100': 0.00913, 'P@1000': 0.02076}


In [None]:
# Same exact k-NN evaluation as the train-split cell, but on the held-out test split.
# NOTE(review): relies on `model` still being the KNNSearch instance created in that cell.
retriever = EvaluateRetrieval(model, score_function="euclidean")
results = retriever.retrieve(test_corpus, test_queries)

logging.info(f"Retriever evaluation for k in: {retriever.k_values}")
ndcg, _map, recall, precision = retriever.evaluate(
    test_qrels,
    results,
    retriever.k_values,
)

# Experiment 2: probabilistic cluster-index search (`ProbRankModel` cluster classifier + `ProbIndexSearch`)

In [1]:

from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from sentence_transformers import SentenceTransformer

from beir.retrieval.search.dense.prob_index import ProbRankModel, ProbRankModelTrainer, DataLoader
from beir.retrieval.search.dense.prob_knn_search import ProbIndexSearch
from beir.retrieval.embedding import SentenceTransformerEmbedding 
from beir.retrieval.search.dense import knn_search

import numpy as np
import logging
import pathlib, os
import random

# ---- Load data: BEIR nfcorpus, same location as Experiment 1 ----
dataset = "nfcorpus"

# Download the zipped BEIR dataset and extract it under the project's Drive folder.
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/%s.zip" % dataset
out_dir = os.path.join("./drive/MyDrive/NNsearch-project/Datasets/", "nfcorpus-1")
data_path = util.download_and_unzip(url, out_dir)

# One GenericDataLoader per split, loaded in train -> dev -> test order.
(train_corpus, train_queries, train_qrels), (dev_corpus, dev_queries, dev_qrels), (test_corpus, test_queries, test_qrels) = map(
    lambda split_name: GenericDataLoader(data_folder=data_path).load(split=split_name),
    ("train", "dev", "test"),
)

  from tqdm.autonotebook import tqdm
100%|██████████| 3633/3633 [00:00<00:00, 345209.81it/s]
100%|██████████| 3633/3633 [00:00<00:00, 361284.74it/s]
100%|██████████| 3633/3633 [00:00<00:00, 367923.18it/s]


In [2]:

import numpy as np
import torch
from torch import nn, optim

##### hyperparameters
SEED = 42
num_clusters = 256   # number of corpus clusters; also the output size of the ranker
LEARNING_RATE = 0.1
BATCH_SIZE = 16
NUM_EPOCHS = 100

# Seed the Python, NumPy and PyTorch RNGs so model initialisation (and any clustering that
# draws from these generators) is repeatable under Restart & Run All.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

##### data
# NOTE(review): the test pair is passed as (test_queries, test_corpus) while the train pair is
# (train_corpus, train_queries) - confirm this matches DataLoader's parameter order.
dataloader = DataLoader(train_corpus, train_queries, test_queries, test_corpus, train_qrels, test_qrels)
X_train, y_train, X_test, y_test, corpus_clusters = dataloader.load_data(num_clusters)

##### model
input_dimension = X_train.shape[1]  # feature width of the training inputs
model_f = ProbRankModel(input_dimension, num_clusters)

##### trainer
criterion = nn.CrossEntropyLoss()
# NOTE(review): the recorded loss stays at 16.2490 for every logged epoch, far above
# ln(num_clusters) = ln(256) ~ 5.55 (the loss of a uniform guess): the model is confidently
# wrong and not learning. lr=0.1 is very large for Adam (PyTorch default 1e-3) - try lowering it.
optimizer = optim.Adam(model_f.parameters(), lr=LEARNING_RATE)
trainer = ProbRankModelTrainer(
    model_f,
    criterion=criterion,
    optimizer=optimizer,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS,
    num_clusters=num_clusters,
)
trainer.fit(X_train, y_train)





Batches: 100%|██████████| 1/1 [00:00<00:00,  1.40it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 455.41it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 530.79it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 556.13it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 518.01it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 539.81it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 560.81it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 536.42it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 530.39it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 547.27it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 560.51it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 547.20it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 549.06it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 551.59it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 565.19it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 514.01it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 553.48it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 538.9

Successfully embedded the corpus!


Batches: 100%|██████████| 1/1 [00:00<00:00, 546.13it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 669.48it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 701.27it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 706.83it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 716.73it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 668.52it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 704.45it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 716.85it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 692.47it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 710.30it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 682.67it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 725.66it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 697.66it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 724.53it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 732.12it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 719.06it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 722.78it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 697.

Successfully embedded the queries!


Batches: 100%|██████████| 1/1 [00:00<00:00, 455.26it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 517.05it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 569.41it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 559.84it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 521.94it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 537.94it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 556.64it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 535.26it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 533.15it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 544.22it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 591.33it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 533.83it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 543.94it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 579.80it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 580.21it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 511.00it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 548.42it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 522.

Successfully embedded the corpus!


Batches: 100%|██████████| 1/1 [00:00<00:00, 556.50it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 674.43it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 702.56it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 716.61it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 714.29it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 722.16it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 716.49it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 715.02it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 712.59it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 733.78it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 728.56it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 715.39it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 732.37it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 711.26it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 715.87it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 727.42it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 742.75it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 717.

Successfully embedded the queries!


Batches: 100%|██████████| 1/1 [00:00<00:00, 504.43it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 529.25it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 561.94it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 555.54it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 507.60it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 528.38it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 551.59it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 528.52it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 524.68it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 528.98it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 586.94it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 535.94it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 546.35it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 566.64it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 571.12it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 512.31it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 540.29it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 515.

Successfully embedded the corpus!
Epoch [5/100], Loss: 16.2490
Epoch [10/100], Loss: 16.2490
Epoch [15/100], Loss: 16.2490
Epoch [20/100], Loss: 16.2490
Epoch [25/100], Loss: 16.2490
Epoch [30/100], Loss: 16.2490
Epoch [35/100], Loss: 16.2490
Epoch [40/100], Loss: 16.2490
Epoch [45/100], Loss: 16.2490
Epoch [50/100], Loss: 16.2490
Epoch [55/100], Loss: 16.2490
Epoch [60/100], Loss: 16.2490
Epoch [65/100], Loss: 16.2490
Epoch [70/100], Loss: 16.2490
Epoch [75/100], Loss: 16.2490
Epoch [80/100], Loss: 16.2490
Epoch [85/100], Loss: 16.2490
Epoch [90/100], Loss: 16.2490
Epoch [95/100], Loss: 16.2490
Epoch [100/100], Loss: 16.2490


In [5]:
# Retrieve the trained network from the trainer and enforce (not just print) that it is a ProbRankModel.
index_model = trainer.get_model()
assert isinstance(index_model, ProbRankModel), (
    f"expected ProbRankModel, got {type(index_model).__name__}"
)

# Test-set predictions; in this notebook they are only read by the commented-out re-ranking cell.
test_preds, test_labels = trainer.predict(X_test)


True


In [6]:
#### Probabilistic cluster-index search over the test split.
# Search hyperparameters (meaning presumed from the names - confirm in prob_knn_search):
TOPK_CLUSTER = 5   # clusters considered per query
TOPK_EMB = 20      # documents kept per query
# Reuse num_clusters so the search always agrees with the size the trainer was fitted with.
model = ProbIndexSearch(trainer, num_clusters=num_clusters, topk_cluster=TOPK_CLUSTER, topk_emb=TOPK_EMB)
# ProbIndexSearch computes its own scores, so no BEIR score function is used here.
retriever = EvaluateRetrieval(model, score_function="")

# NOTE(review): y_test and test_qrels (test labels) are handed to the search itself - confirm
# prob_knn_search does not use them to pick or score candidates, or the metrics leak labels.
# The recorded output also shows a second 100-epoch training run during this call; confirm
# whether the already-fitted trainer is being retrained.
results = retriever.retrieve(
    test_corpus,
    test_queries,
    X_train=X_train, y_train=y_train,
    X_test=X_test, y_test=y_test,
    cluster_dict=corpus_clusters,
    test_qrels=test_qrels,
)

#### Evaluate once using NDCG@k, MAP@k, Recall@k, P@k
ndcg, _map, recall, precision = EvaluateRetrieval.evaluate(test_qrels, results, retriever.k_values)

#### Print the top-k documents retrieved for one random query
top_k = model.topk_emb
if results:
    query_id, ranking_scores = random.choice(list(results.items()))
    # Highest score first, the same direction BEIR's evaluator ranks in.
    scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
    logging.info("Query : %s\n" % test_queries[query_id])
    # Slice instead of indexing so a query with fewer than top_k hits does not raise IndexError.
    for rank, (doc_id, _score) in enumerate(scores_sorted[:top_k], start=1):
        # Format: Rank x: ID [Title] Body
        logging.info("Rank %d: %s [%s] - %s\n" % (rank, doc_id, test_corpus[doc_id].get("title"), test_corpus[doc_id].get("text")))

print(f"ndcg: {ndcg}")
print(f"_map: {_map}")
print(f"_recall: {recall}")
print(f"precision: {precision}")

Batches: 100%|██████████| 1/1 [00:00<00:00, 83.39it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 490.28it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 687.59it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 719.68it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 706.11it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 699.52it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 712.23it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 705.16it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 719.68it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 734.43it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 666.71it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 717.10it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 714.05it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 697.42it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 716.49it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 713.56it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 702.80it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 719.0

Successfully embedded the queries!
Epoch [5/100], Loss: 16.2565
Epoch [10/100], Loss: 16.2565
Epoch [15/100], Loss: 16.2565
Epoch [20/100], Loss: 16.2565
Epoch [25/100], Loss: 16.2565
Epoch [30/100], Loss: 16.2565
Epoch [35/100], Loss: 16.2565
Epoch [40/100], Loss: 16.2565
Epoch [45/100], Loss: 16.2565
Epoch [50/100], Loss: 16.2565
Epoch [55/100], Loss: 16.2565
Epoch [60/100], Loss: 16.2565
Epoch [65/100], Loss: 16.2565
Epoch [70/100], Loss: 16.2565
Epoch [75/100], Loss: 16.2565
Epoch [80/100], Loss: 16.2565
Epoch [85/100], Loss: 16.2565
Epoch [90/100], Loss: 16.2565
Epoch [95/100], Loss: 16.2565
Epoch [100/100], Loss: 16.2565


Batches: 100%|██████████| 255/255 [02:04<00:00,  2.05it/s]


ndcg: {'NDCG@1': 0.02477, 'NDCG@3': 0.02039, 'NDCG@5': 0.01821, 'NDCG@10': 0.01992, 'NDCG@100': 0.02566, 'NDCG@1000': 0.0225}
_map: {'MAP@1': 0.00375, 'MAP@3': 0.00438, 'MAP@5': 0.00457, 'MAP@10': 0.00515, 'MAP@100': 0.00676, 'MAP@1000': 0.00676}
_recall: {'Recall@1': 0.00375, 'Recall@3': 0.0047, 'Recall@5': 0.00502, 'Recall@10': 0.00696, 'Recall@100': 0.03044, 'Recall@1000': 0.03048}
precision: {'P@1': 0.02477, 'P@3': 0.02064, 'P@5': 0.01734, 'P@10': 0.02043, 'P@100': 0.01375, 'P@1000': 0.00138}


In [7]:
print(trainer.__class__)  # debug: show the concrete trainer class

<class 'beir.retrieval.search.dense.prob_index.ProbRankModelTrainer'>


In [8]:
# Re-display the metric dictionaries from the most recent evaluation.
print("\n".join(
    f"{label}: {metric}"
    for label, metric in zip(("ndcg", "_map", "_recall", "precision"), (ndcg, _map, recall, precision))
))

ndcg: {'NDCG@1': 0.02477, 'NDCG@3': 0.02039, 'NDCG@5': 0.01821, 'NDCG@10': 0.01992, 'NDCG@100': 0.02566, 'NDCG@1000': 0.0225}
_map: {'MAP@1': 0.00375, 'MAP@3': 0.00438, 'MAP@5': 0.00457, 'MAP@10': 0.00515, 'MAP@100': 0.00676, 'MAP@1000': 0.00676}
_recall: {'Recall@1': 0.00375, 'Recall@3': 0.0047, 'Recall@5': 0.00502, 'Recall@10': 0.00696, 'Recall@100': 0.03044, 'Recall@1000': 0.03048}
precision: {'P@1': 0.02477, 'P@3': 0.02064, 'P@5': 0.01734, 'P@10': 0.02043, 'P@100': 0.01375, 'P@1000': 0.00138}


In [None]:
# NOTE(review): dead code from an earlier version of this notebook. It references `qrels`,
# `queries` and `corpus`, which are never defined here (the splits are named test_qrels /
# test_queries / test_corpus). Either rename and uncomment, or delete the cell.
# mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
# recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="recall_cap")
# hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
# top_k_accuracy = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="top_k_accuracy")

# #### Print top-k documents retrieved ####
# top_k = 10

# query_id, ranking_scores = random.choice(list(results.items()))
# scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
# logging.info("Query : %s\n" % queries[query_id])

# for rank in range(top_k):
#     doc_id = scores_sorted[rank][0]
#     # Format: Rank x: ID [Title] Body
#     logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))

In [9]:
# NOTE(review): debugging scratch cell (prints the type of every KMeans label); it is fully
# commented out, so the saved output under it is stale. Safe to delete.
# from sklearn.cluster import KMeans
# import numpy as np

# emb_model = SentenceTransformer('all-MiniLM-L6-v2').to('cuda')
# kmeans_model = KMeans(n_clusters=num_clusters, random_state=42) # default metric is euclidean distance
# STembedder = SentenceTransformerEmbedding()
# corpus_embeddings =  STembedder.embed_corpus(dev_corpus, emb_model)
# kmeans_model.fit(corpus_embeddings)

# for i, cluster_label in enumerate(kmeans_model.labels_):
#     print(type(cluster_label))


Batches: 100%|██████████| 1/1 [00:00<00:00, 432.67it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 479.84it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 560.21it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 548.92it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 515.33it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 513.32it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 556.50it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 534.92it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 514.39it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 523.11it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 585.22it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 534.17it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 540.99it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 575.67it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 573.38it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 491.19it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 541.34it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 525.

Successfully embedded the corpus!
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'>
<class 'numpy.int32'

In [6]:
# NOTE(review): commented-out prototype of cluster search + cross-encoder re-ranking, kept for
# reference; the saved output under this cell is stale. Known issues if revived are marked inline.
# from beir.reranking.rerank import Rerank


# cluster_dict = corpus_clusters

# emb_model = SentenceTransformer('all-MiniLM-L6-v2').to('cuda')

# STembedder = SentenceTransformerEmbedding()
# query_embeddings =  STembedder.embed_queries(test_queries, emb_model)


# query_topk_cluster_dict = {}

# NOTE(review): the 5 below is a hard-coded cluster count; use the search's topk_cluster setting instead.
# indices_and_keys = list(zip(range(len(test_queries)), test_queries))
# for i, query_id in indices_and_keys:
#     # Extract the indices of the k largest values
#     topk_indices = np.argsort(np.array(test_preds[i]))[::-1][:5]
#     # Get the corresponding cluster labels
#     topk_cluster_labels = topk_indices.tolist()
#     # Store the result in the dictionary
#     query_topk_cluster_dict[query_id] = topk_cluster_labels


# result_dict = {}

# NOTE(review): the print inside the inner loop floods the output; remove it if this is revived.
# for query_index, (query_id, topk_clusters) in enumerate(query_topk_cluster_dict.items()):
#     query_result_dict = {}
#     query_embedding = query_embeddings[query_index]
#     for cluster_label in topk_clusters:
#         print(cluster_label, type(cluster_label))
#         if cluster_label in cluster_dict:
#             cluster = cluster_dict[cluster_label]
#             for corpus_info in cluster:
#                 corpus_id = corpus_info['corpus_id']
#                 corpus_embedding = corpus_info['corpus_embedding']

#                 # NOTE(review): this stores a Euclidean *distance* as the score. Code that sorts
#                 # scores descending (the display loop below, and presumably BEIR's evaluator and
#                 # Rerank) would then put the farthest documents first - store -distance instead.
#                 distance = np.linalg.norm(query_embedding - corpus_embedding)
#                 query_result_dict[corpus_id] = distance

#     result_dict[query_id] = query_result_dict # all embeddings from topk_clusters


# ################################################
# #### (2) RERANK Top-20 docs using Cross-Encoder
# ################################################
# from beir.reranking.models import CrossEncoder

# #### Reranking using Cross-Encoder models #####
# #### https://www.sbert.net/docs/pretrained_cross-encoders.html
# cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-electra-base')

# #### Or use MiniLM, TinyBERT etc. CE models (https://www.sbert.net/docs/pretrained-models/ce-msmarco.html)
# # cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
# # cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-6')

# reranker = Rerank(cross_encoder_model, batch_size=128)
# topk_emb = 20
# queries = test_queries
# corpus = test_corpus

# # Rerank all results using the reranker provided
# NOTE(review): len(result_dict) is the number of *queries* (result_dict is keyed by query_id),
# not a per-query document cutoff; the header above suggests top_k=topk_emb was intended.
# rerank_results = reranker.rerank(corpus, queries, result_dict, top_k=len(result_dict))

# #### Evaluate your retrieval using NDCG@k, MAP@K ...
# ndcg, _map, recall, precision = EvaluateRetrieval.evaluate(test_qrels, rerank_results, retriever.k_values)

# #### Print top-k documents retrieved ####
# top_k = topk_emb

# query_id, ranking_scores = random.choice(list(rerank_results.items()))
# scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
# logging.info("Query : %s\n" % queries[query_id])

# for rank in range(top_k):
#     doc_id = scores_sorted[rank][0]
#     # Format: Rank x: ID [Title] Body
#     logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))


# result_dict
# print(len(result_dict))

Batches: 100%|██████████| 1/1 [00:00<00:00, 523.83it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 671.95it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 705.52it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 709.58it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 731.99it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 730.08it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 726.92it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 719.93it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 717.96it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 743.41it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 659.17it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 727.42it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 740.52it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 728.18it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 729.70it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 717.83it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 735.58it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 749.

Successfully embedded the queries!
172 <class 'int'>
255 <class 'int'>
126 <class 'int'>
92 <class 'int'>
91 <class 'int'>
172 <class 'int'>
255 <class 'int'>
126 <class 'int'>
92 <class 'int'>
91 <class 'int'>
172 <class 'int'>
255 <class 'int'>
126 <class 'int'>
92 <class 'int'>
91 <class 'int'>
172 <class 'int'>
255 <class 'int'>
126 <class 'int'>
92 <class 'int'>
91 <class 'int'>
172 <class 'int'>
255 <class 'int'>
126 <class 'int'>
92 <class 'int'>
91 <class 'int'>
172 <class 'int'>
255 <class 'int'>
126 <class 'int'>
92 <class 'int'>
91 <class 'int'>
172 <class 'int'>
255 <class 'int'>
126 <class 'int'>
92 <class 'int'>
91 <class 'int'>
172 <class 'int'>
255 <class 'int'>
126 <class 'int'>
92 <class 'int'>
91 <class 'int'>
172 <class 'int'>
255 <class 'int'>
126 <class 'int'>
92 <class 'int'>
91 <class 'int'>
172 <class 'int'>
255 <class 'int'>
126 <class 'int'>
92 <class 'int'>
91 <class 'int'>
172 <class 'int'>
255 <class 'int'>
126 <class 'int'>
92 <class 'int'>
91 <class 'int

Batches: 100%|██████████| 139/139 [01:07<00:00,  2.06it/s]


323
