This file is going to use BEIR to download a dataset and create embeddings 

# Part 1. embeddings, saving & loading

In [1]:
# constants
DATASET = "quora"
sbert_model_name = "msmarco-distilbert-base-tas-b"
device = "cpu" # cuda for gpu usage
k_queries = 100
k_documents = 10000

In [4]:
# import libraries
from time import time
from beir import util
from beir_reengineered import NewSentenceBERT
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
import os, json, random

  from tqdm.autonotebook import tqdm


In [5]:
#### Download nfcorpus.zip dataset and unzip the dataset
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(DATASET)
out_dir = "datasets"
data_path = util.download_and_unzip(url, out_dir)

In [6]:
# import libraries
from time import time
from beir import util
from beir_reengineered import NewSentenceBERT
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
import os, json, random

In [7]:
#### Download nfcorpus.zip dataset and unzip the dataset
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(DATASET)
out_dir = "datasets"
data_path = util.download_and_unzip(url, out_dir)

In [8]:
#### Provide the data path where nfcorpus has been downloaded and unzipped to the data loader
# data folder would contain these files:
# (1) nfcorpus/corpus.jsonl  (format: jsonlines)
# (2) nfcorpus/queries.jsonl (format: jsonlines)
# (3) nfcorpus/qrels/test.tsv (format: tsv ("\t"))

corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")


100%|██████████| 522931/522931 [00:02<00:00, 205572.89it/s]


In [9]:
#### Dense Retrieval using SBERT (Sentence-BERT) ####
#### Provide any pretrained sentence-transformers model
#### The model was fine-tuned using cosine-similarity.
#### Complete list - https://www.sbert.net/docs/pretrained_models.html

beir_sbert = NewSentenceBERT(sbert_model_name, device=device)
model = DRES(beir_sbert, batch_size=256, corpus_chunk_size=512*9999)

In [10]:
# Create sub-sample
subset_of_queries = random.sample(queries.keys(), k_queries)
queries = {qid: queries[qid] for qid in subset_of_queries}
qrels = {qid: qrels[qid] for qid in subset_of_queries}
true_documents = set([docid for qid in qrels for docid in qrels[qid]])
false_documents = set(random.sample(list(set([docid for docid in corpus if docid not in true_documents])), k_documents))
subset_of_corpus = true_documents | false_documents
corpus = {docid: corpus[docid] for docid in subset_of_corpus}

In [11]:
# Encode queries
queries_l = [queries[qid] for qid in queries]
query_embeddings = model.model.encode_queries(
    queries_l,
    batch_size=model.batch_size,
    show_progress_bar=model.show_progress_bar,
    convert_to_tensor=model.convert_to_tensor
).cpu().numpy()

Batches: 100%|██████████| 1/1 [00:02<00:00,  2.26s/it]


In [12]:
# Encode documents
corpus_ids = sorted(corpus, key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")), reverse=True)
corpus_l = [corpus[cid] for cid in corpus_ids]
sub_corpus_embeddings = model.model.encode_corpus(
    corpus_l,
    batch_size=model.batch_size,
    show_progress_bar=model.show_progress_bar,
    convert_to_tensor=model.convert_to_tensor
).cpu().numpy()

Batches: 100%|██████████| 40/40 [03:32<00:00,  5.32s/it]


In [13]:
# Save as new dataset
os.makedirs("datasets/subquora/qrels", exist_ok=True)
with open("datasets/subquora/queries.jsonl", "w") as f:
    f.writelines([json.dumps({"_id": qid, "text": queries[qid], "metadata":{}})+"\n" for qid in queries])
with open("datasets/subquora/corpus.jsonl", "w") as f:
    f.writelines([json.dumps({"_id": docid, "title": corpus[docid].get("title"), "text": corpus[docid].get("text"), "metadata":{}})+"\n" for docid in corpus])
with open("datasets/subquora/qrels/test.tsv", "w") as f:
    f.write("query-id\tcorpus-id\tscore\n")
    for qid in qrels:
        for docid in qrels[qid]:
            f.write("{}\t{}\t{}\n".format(qid, docid, qrels[qid][docid]))

In [14]:
# # Save embeddings
corpus_embeddings_dict = dict(zip(corpus_ids, sub_corpus_embeddings))
query_embeddings_dict = dict(zip(queries.keys(), query_embeddings))
import pickle

with open("datasets/subquora/corpus_embeddings.pkl", "wb") as f:
    pickle.dump(corpus_embeddings_dict, f)
with open("datasets/subquora/query_embeddings.pkl", "wb") as f:
    pickle.dump(query_embeddings_dict, f)

In [21]:
# import os

# # List files in the directory
# files = os.listdir("datasets/subquora")

# # Check if the file is in the list
# if "query_embeddings.pkl" in files:
#     print("File found in directory.")
# else:
#     print("File not found in directory.")


File found in directory.


In [16]:
# import pickle
# with open("datasets/subquora/corpus_embeddings.pkl", "rb") as f:
#         corpus_embeddings_test = pickle.load(f)
#         print(type(corpus_embeddings_test))
# corpus_embeddings_test

<class 'dict'>


In [41]:
# with open("datasets/subquora/query_embeddings.pkl", "rb") as f:
#         query_embeddings_test = pickle.load(f)
#         print(type(query_embeddings_test))
# query_embeddings_test
# <class 'dict'>
# {'148989': array([-5.42879179e-02, -1.91017106e-01, -1.20709697e-02,  2.05447242e-01,
#          4.04155254e-01,  2.48869240e-01, -4.52450693e-01,  1.38501331e-01,
#         -2.10237205e-01, -1.51516631e-01,  4.76916820e-01,  3.10554087e-01,

<class 'dict'>


{'148989': array([-5.42879179e-02, -1.91017106e-01, -1.20709697e-02,  2.05447242e-01,
         4.04155254e-01,  2.48869240e-01, -4.52450693e-01,  1.38501331e-01,
        -2.10237205e-01, -1.51516631e-01,  4.76916820e-01,  3.10554087e-01,
         1.72650844e-01, -5.25526404e-02, -3.24088454e-01,  4.29388463e-01,
        -3.78398784e-02,  1.83394328e-01, -1.89989984e-01, -9.43334680e-03,
         2.17485838e-02,  4.44550365e-02, -1.46826476e-01, -1.65164605e-01,
         1.36961639e-01, -7.75941536e-02, -9.14802477e-02,  3.59186679e-01,
        -1.92213535e-01, -1.17055297e-01, -2.74288625e-01,  3.07339847e-01,
        -7.98731074e-02, -1.71654567e-01,  1.52620167e-01,  1.42504826e-01,
         3.79293151e-02, -1.36001781e-02,  1.02324180e-01, -1.81025997e-01,
        -2.76163489e-01,  1.60303656e-02, -8.72119442e-02,  2.68899441e-01,
        -2.27519628e-02,  1.74226537e-01, -2.49796689e-01,  2.79491276e-01,
         3.36900443e-01,  1.85794935e-01,  1.92436472e-01, -3.22766632e-01,
  

In [23]:
# query_embeddings_test
# # query_ids = list(query_embeddings_test.keys())
# # print(query_ids)
# # corpus_ids = sorted(query_ids, reverse=True)
# # corpus_ids

# import numpy as np

# max_decimal_places = 0

# for key, array in query_embeddings_test.items():
#     for element in array:
#         decimal_places = len(str(element).split('.')[1])
#         if decimal_places > max_decimal_places:
#             max_decimal_places = decimal_places

# print("Maximum number of decimal places:", max_decimal_places) # Maximum number of decimal places: 12

Maximum number of decimal places: 12


# Part 2. (only sometimes) round embeddings

In [18]:
from beir.retrieval.search import BaseSearch # type: ignore beir/retrieval/search/dense/exact_search.py
from beir.util import cos_sim #beir/util.py
import torch # type: ignore
import numpy as np
from typing import Dict
import heapq
import logging
logger = logging.getLogger(__name__)

In [42]:
# ExperiementRetrievalExactSearch is parent class for any model we are using for our experiement that can be used for retrieval
# Abstract class is BaseSearch
class ExperiementRetrievalExactSearch(BaseSearch):
    def __init__(
            self,
            model,
            rounding_decimal: int = 16,
            path_corpus_embeddings: str = "datasets/subquora/corpus_embeddings.pkl",
            path_query_embeddings: str = "datasets/subquora/query_embeddings.pkl",
            **kwargs):
        #model is class should do nothing
        self.model = model
        self.rounding_decimal = rounding_decimal
        self.path_corpus_embeddings = path_corpus_embeddings
        self.path_query_embeddings = path_query_embeddings
        self.show_progress_bar = kwargs.get("show_progress_bar", True)
        self.convert_to_tensor = kwargs.get("convert_to_tensor", True)
        self.results = {}

        logger.info("Load in Encoded Queries and Corpus from Pickle...")
        # Verify file existence
        if not os.path.exists(self.path_corpus_embeddings):
            raise FileNotFoundError(f"File '{self.path_corpus_embeddings}' not found.")
        if not os.path.exists(self.path_query_embeddings):
            raise FileNotFoundError(f"File '{self.path_query_embeddings}' not found.")

        with open(self.path_query_embeddings, "rb") as f:
            self.query_embeddings = pickle.load(f)

        with open(self.path_corpus_embeddings, "rb") as f:
            self.corpus_embeddings = pickle.load(f)

        # rounding decimal
        if rounding_decimal < 16:
            logger.info("Rounding decimal places of Queries and Corpus...")
            for key, value in self.query_embeddings.items():
                self.query_embeddings[key] = np.round(value, decimals=rounding_decimal)

            for key, value in self.corpus_embeddings.items():
                self.corpus_embeddings[key] = np.round(value, decimals=rounding_decimal)

    def search(self,
               corpus: Dict[str, Dict[str, str]],
               queries: Dict[str, str],
               top_k: int,
               score_function: str,
               return_sorted: bool = False,
               **kwargs) -> Dict[str, Dict[str, float]]:
        # Create embeddings for all queries using model.encode_queries()
        # Runs semantic search against the corpus embeddings
        # Returns a ranked list with the corpus ids

        query_ids = list(self.query_embeddings.keys())
        self.results = {qid: {} for qid in query_ids}


        logger.info("Sorting Corpus by document length (Longest first)...")
        corpus_ids = sorted(list(self.corpus_embeddings.keys()), reverse=True)

        result_heaps = {qid: [] for qid in query_ids}  # Keep only the top-k docs for each query

        # Convert dictionary values to PyTorch tensors
        corpus_tensors = [torch.tensor(embedding) for embedding in self.corpus_embeddings.values()]
        query_tensors = [torch.tensor(embedding) for embedding in self.query_embeddings.values()]
        # Stack tensors along a new dimension (batch dimension)
        corpus_embeddings_tensor = torch.stack(corpus_tensors)
        query_embeddings_tensor = torch.stack(query_tensors)

        # Compute similarites using  cosine-similarity
        cos_scores = cos_sim(query_embeddings_tensor, corpus_embeddings_tensor)
        cos_scores[torch.isnan(cos_scores)] = -1

        # Get top-k values
        cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(cos_scores, min(top_k+1, len(cos_scores[1])), dim=1, largest=True, sorted=return_sorted)
        cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
        cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()

        for query_itr in range(len(query_embeddings_tensor)):
            query_id = query_ids[query_itr]
            for sub_corpus_id, score in zip(cos_scores_top_k_idx[query_itr], cos_scores_top_k_values[query_itr]):
                corpus_id = corpus_ids[sub_corpus_id]
                if corpus_id != query_id:
                    if len(result_heaps[query_id]) < top_k:
                        # Push item on the heap
                        heapq.heappush(result_heaps[query_id], (score, corpus_id))
                    else:
                        # If item is larger than the smallest in the heap, push it on the heap then pop the smallest element
                        heapq.heappushpop(result_heaps[query_id], (score, corpus_id))

        for qid in result_heaps:
            for score, corpus_id in result_heaps[qid]:
                self.results[qid][corpus_id] = score

        return self.results

In [43]:
#### Load the SBERT model and retrieve using cosine-similarity
model = ExperiementRetrievalExactSearch(beir_sbert)
retriever = EvaluateRetrieval(model, score_function="cos_sim") # or "cos_sim" for cosine similarity
results = retriever.retrieve(corpus, queries)



In [44]:
#### Evaluate your model with NDCG@k, MAP@K, Recall@K and Precision@K  where k = [1,3,5,10,100,1000]
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
print(ndcg, _map, recall, precision )

{'NDCG@1': 0.0, 'NDCG@3': 0.0, 'NDCG@5': 0.0, 'NDCG@10': 0.0, 'NDCG@100': 0.00336, 'NDCG@1000': 0.01034} {'MAP@1': 0.0, 'MAP@3': 0.0, 'MAP@5': 0.0, 'MAP@10': 0.0, 'MAP@100': 0.00034, 'MAP@1000': 0.00045} {'Recall@1': 0.0, 'Recall@3': 0.0, 'Recall@5': 0.0, 'Recall@10': 0.0, 'Recall@100': 0.02, 'Recall@1000': 0.07455} {'P@1': 0.0, 'P@3': 0.0, 'P@5': 0.0, 'P@10': 0.0, 'P@100': 0.0002, 'P@1000': 0.00013}


In [47]:
for decimal in range(12,-1,-1):
    # print(decimal, type(decimal))
    model = ExperiementRetrievalExactSearch(beir_sbert, decimal)
    retriever = EvaluateRetrieval(model, score_function="cos_sim") # or "cos_sim" for cosine similarity
    results = retriever.retrieve(corpus, queries)
    ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
    print("decimal_places:", decimal , ndcg, _map, recall, precision )

decimal_places: 12 {'NDCG@1': 0.0, 'NDCG@3': 0.0, 'NDCG@5': 0.0, 'NDCG@10': 0.0, 'NDCG@100': 0.00336, 'NDCG@1000': 0.01034} {'MAP@1': 0.0, 'MAP@3': 0.0, 'MAP@5': 0.0, 'MAP@10': 0.0, 'MAP@100': 0.00034, 'MAP@1000': 0.00045} {'Recall@1': 0.0, 'Recall@3': 0.0, 'Recall@5': 0.0, 'Recall@10': 0.0, 'Recall@100': 0.02, 'Recall@1000': 0.07455} {'P@1': 0.0, 'P@3': 0.0, 'P@5': 0.0, 'P@10': 0.0, 'P@100': 0.0002, 'P@1000': 0.00013}
decimal_places: 11 {'NDCG@1': 0.0, 'NDCG@3': 0.0, 'NDCG@5': 0.0, 'NDCG@10': 0.0, 'NDCG@100': 0.00336, 'NDCG@1000': 0.01034} {'MAP@1': 0.0, 'MAP@3': 0.0, 'MAP@5': 0.0, 'MAP@10': 0.0, 'MAP@100': 0.00034, 'MAP@1000': 0.00045} {'Recall@1': 0.0, 'Recall@3': 0.0, 'Recall@5': 0.0, 'Recall@10': 0.0, 'Recall@100': 0.02, 'Recall@1000': 0.07455} {'P@1': 0.0, 'P@3': 0.0, 'P@5': 0.0, 'P@10': 0.0, 'P@100': 0.0002, 'P@1000': 0.00013}
decimal_places: 10 {'NDCG@1': 0.0, 'NDCG@3': 0.0, 'NDCG@5': 0.0, 'NDCG@10': 0.0, 'NDCG@100': 0.00336, 'NDCG@1000': 0.01034} {'MAP@1': 0.0, 'MAP@3': 0.0, '