In [49]:
from beir import LoggingHandler
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval import models

import pandas as pd

import dataclasses
import json
import logging
import os
import pathlib

# Emit INFO-and-above records through BEIR's LoggingHandler, prefixing every
# message with a second-resolution timestamp.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[LoggingHandler()],
)



In [58]:
@dataclasses.dataclass(init=False)
class BeirTopKSimilarityMetadata:
    """Summary statistics describing one top-k similarity export.

    The constructor is hand-written (hence ``init=False``): it receives the
    loaded ``queries`` / ``corpus`` / ``qrels`` mappings and stores only
    their sizes, so the instance stays small and can be serialised with
    ``dataclasses.asdict``.

    Args:
        dataset_name: Name of the dataset the statistics refer to.
        queries: Mapping whose length is the number of queries.
        corpus: Mapping whose length is the number of documents.
        qrels: Mapping ``query id -> {doc id -> relevance score}``.
        model_name: Name of the model used to score the documents.
        top_k: Number of documents kept per query.
        include_text: Whether the export also contains the raw texts.
    """

    dataset_name: str
    num_queries: int
    num_documents: int
    num_positive_annotations: int
    model_name: str
    top_k: int  # a count (e.g. 500), not a flag
    include_text: bool

    def __init__(self, dataset_name, queries, corpus, qrels, model_name, top_k, include_text):
        self.dataset_name = dataset_name
        self.num_queries = len(queries)
        self.num_documents = len(corpus)
        # Count only judgments with a positive score: qrels may also hold
        # explicit "not relevant" (score 0) entries, which are not positives.
        self.num_positive_annotations = sum(
            1
            for rel_docids in qrels.values()
            for score in rel_docids.values()
            if score > 0
        )
        self.model_name = model_name
        self.top_k = top_k
        self.include_text = include_text
        

def load_data(dataset_name, datasets_path=".."):
    """Download a BEIR dataset archive and load its test split.

    Args:
        dataset_name: Name substituted into the BEIR download URL.
        datasets_path: Directory under which the ``datasets`` folder is
            placed. The default keeps the previous location,
            ``../datasets``.

    Returns:
        A tuple ``(data_path, corpus, queries, qrels)`` where ``data_path``
        is the value returned by ``util.download_and_unzip`` and the other
        three come from ``GenericDataLoader.load(split="test")``.
    """
    url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset_name)
    # Honour the caller-supplied location instead of a hard-coded "..".
    out_dir = os.path.join(datasets_path, "datasets")
    data_path = util.download_and_unzip(url, out_dir)

    # Only the "test" split is loaded from the downloaded dataset folder.
    corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
    return data_path, corpus, queries, qrels


def _from_notebook_globals(name, value):
    """Return ``value``; when it is None, fall back to the notebook-level variable ``name``.

    Exists only so that legacy calls which relied on ``corpus`` / ``queries`` /
    ``qrels`` being defined in the notebook namespace keep working. New code
    should always pass the data explicitly.

    Raises:
        NameError: if ``value`` is None and no such notebook-level variable exists.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"'{name}' was not passed explicitly and is not defined at notebook level"
        ) from None


def compute_similarity(model_name, batch_size, score_function=None,
                       corpus=None, queries=None) -> dict[str, dict[str, float]]:
    """Score documents against queries with a SentenceBERT dense retriever.

    Args:
        model_name: Model identifier handed to ``models.SentenceBERT``.
        batch_size: Batch size handed to the exact-search wrapper.
        score_function: ``"dot"`` (default when None) or ``"cos_sim"``.
        corpus: Documents to search. Pass it explicitly; when omitted the
            notebook-level ``corpus`` variable is used (legacy behaviour).
        queries: Queries to run. Same fallback rule as ``corpus``.

    Returns:
        The result of ``EvaluateRetrieval.retrieve``: per query id, a mapping
        from document id to similarity score.

    Raises:
        ValueError: if ``score_function`` is not one of the supported names.
        NameError: if ``corpus`` / ``queries`` are neither passed nor defined
            at notebook level.
    """
    if score_function is None:
        score_function = "dot"
    if score_function not in ("dot", "cos_sim"):
        raise ValueError(f"score_function must be either 'dot' or 'cos_sim'. Received: '{score_function}'")

    corpus = _from_notebook_globals("corpus", corpus)
    queries = _from_notebook_globals("queries", queries)

    model = DRES(models.SentenceBERT(model_name, trust_remote_code=True), batch_size=batch_size, trust_remote_code=True)
    retriever = EvaluateRetrieval(model, score_function=score_function)
    return retriever.retrieve(corpus, queries)


def get_top_k_similar_docs(doc_sim: dict[str, float], k: int) -> dict[str, float]:
    """Return the ``k`` highest-scoring entries of ``doc_sim``, best first.

    Ties keep their original relative order because the sort is stable.
    """
    ranked = sorted(doc_sim.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:k])


def get_labels_similarities(query, top_k_similar_docs, qrels):
    """Attach a binary relevance label to each retrieved document.

    Args:
        query: Query id used to look up judgments in ``qrels``.
        top_k_similar_docs: Mapping ``doc id -> similarity score``.
        qrels: Mapping ``query id -> {doc id -> relevance score}``.

    Returns:
        ``(labels, similarity_scores)``: two dicts keyed by doc id, in the
        iteration order of ``top_k_similar_docs``. A label is 1 when the
        document has a positive relevance score for the query, else 0.
    """
    # A query without any judgment simply has no relevant documents.
    judged = qrels.get(query, {})
    # Any positive grade counts as relevant, so graded judgments (e.g. 2)
    # are not mislabelled as 0; unjudged documents default to 0.
    labels = {doc: int(judged.get(doc, 0) > 0) for doc in top_k_similar_docs}
    similarity_scores = dict(top_k_similar_docs)
    return labels, similarity_scores


def get_top_k_data_frame(q_doc_sim, top_k, include_text=False,
                         qrels=None, queries=None, corpus=None):
    """Build a long-format frame with one row per (query, top-k document).

    Args:
        q_doc_sim: Mapping ``query id -> {doc id -> similarity score}``.
        top_k: Number of best-scoring documents kept per query.
        include_text: When True, add ``query``, ``title`` and ``corpus``
            text columns.
        qrels: Relevance judgments used for the ``rel`` column. Pass it
            explicitly; when omitted the notebook-level ``qrels`` variable
            is used (legacy behaviour).
        queries: Query texts by id; only read when ``include_text`` is True.
            Same fallback rule as ``qrels``.
        corpus: Documents by id (dicts with optional ``title`` / ``text``);
            only read when ``include_text`` is True. Same fallback rule.

    Returns:
        A ``pandas.DataFrame`` with columns ``qid``, ``docid``, ``rel``,
        ``sim`` and, if requested, ``query``, ``title``, ``corpus``.
    """
    qrels = _from_notebook_globals("qrels", qrels)
    if include_text:
        queries = _from_notebook_globals("queries", queries)
        corpus = _from_notebook_globals("corpus", corpus)

    columns = {"qid": [], "docid": [], "rel": [], "sim": []}
    if include_text:
        columns.update({"query": [], "title": [], "corpus": []})

    for qid, doc_sim in q_doc_sim.items():
        top_k_doc_scores = get_top_k_similar_docs(doc_sim, top_k)
        docid_labels, similarities = get_labels_similarities(qid, top_k_doc_scores, qrels)
        for docid, rel in docid_labels.items():
            columns["qid"].append(qid)
            columns["docid"].append(docid)
            columns["rel"].append(rel)
            columns["sim"].append(similarities[docid])
            if include_text:
                columns["query"].append(queries[qid])
                columns["title"].append(corpus[docid].get("title", ""))
                columns["corpus"].append(corpus[docid].get("text", ""))

    return pd.DataFrame(columns)


def main(dataset_name=None):
    """Export top-k similarity data for one dataset, or for the built-in list.

    For each processed dataset this writes ``<data_path>.csv`` (the top-k
    frame) and ``<data_path>_metadata.json`` (summary statistics), where
    ``data_path`` is the first value returned by ``load_data``.

    Args:
        dataset_name: Dataset to process. When None, every dataset in the
            built-in list is processed in turn.
    """
    datasets = [
        "trec-covid",
        "nq",
        "hotpotqa",
        "arguana",
        "webis-touche2020",
        # NOTE(review): cqadupstack appears to ship as several sub-forum
        # folders rather than one corpus — confirm load_data handles it.
        "cqadupstack",
        "dbpedia-entity",
        "scidocs",
        "fever",
    ]
    model_name = "msmarco-distilbert-base-tas-b"
    batch_size = 64
    top_k = 500
    include_text = False

    def process_dataset(dataset_name):
        # Load dataset.
        logging.info("## Load dataset.")
        data_path, corpus, queries, qrels = load_data(dataset_name)
        # Compute similarity scores on the data that was just loaded (not on
        # whatever happens to be lying around in the notebook namespace).
        logging.info("## Compute similarity scores.")
        q_doc_sim = compute_similarity(model_name, batch_size, corpus=corpus, queries=queries)
        # Get DF.
        logging.info("## Get the dataframe.")
        df = get_top_k_data_frame(
            q_doc_sim, top_k, include_text, qrels=qrels, queries=queries, corpus=corpus)
        metadata = BeirTopKSimilarityMetadata(
            dataset_name, queries, corpus, qrels, model_name, top_k, include_text)

        # Write to files
        logging.info("## Write to files.")
        df.to_csv(data_path + ".csv")
        with open(data_path + "_metadata.json", "w") as fout:
            json.dump(dataclasses.asdict(metadata), fout, indent=4)

    if dataset_name is None:
        for dataset_idx, name in enumerate(datasets):
            logging.info(f"#### DATASET {dataset_idx}: {name}")
            # process_dataset takes the name only; the index is just for logging.
            process_dataset(name)
    else:
        # Only process the specified dataset.
        logging.info(f"#### DATASET: {dataset_name}")
        process_dataset(dataset_name)


In [59]:
# Process only the "scifact" dataset; calling main() without an argument
# iterates over the dataset list defined inside main instead.
main(dataset_name="scifact")

2024-09-02 16:12:49 - #### DATASET: scifact
2024-09-02 16:12:49 - ## Load dataset.
2024-09-02 16:12:49 - Loading Corpus...


  0%|          | 0/5183 [00:00<?, ?it/s]

2024-09-02 16:12:49 - Loaded 5183 TEST Documents.
2024-09-02 16:12:49 - Doc Example: {'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 vers

Batches:   0%|          | 0/5 [00:00<?, ?it/s]

2024-09-02 16:12:50 - Sorting Corpus by document length (Longest first)...
2024-09-02 16:12:50 - Scoring Function: Dot Product (dot)
2024-09-02 16:12:50 - Encoding Batch 1/1...


Batches:   0%|          | 0/81 [00:00<?, ?it/s]

2024-09-02 16:13:05 - ## Get the dataframe.
2024-09-02 16:13:05 - ## Write to files.
