In [1]:
! pip3 install transformers=='4.48.3'

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [2]:
! pip install beir=='2.0.0'

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [3]:
# Report the GPU, driver and CUDA version visible to this runtime.
! nvidia-smi

Wed Mar 19 09:38:52 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.77                 Driver Version: 565.77         CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A5000               On  |   00000000:03:00.0 Off |                  Off |
|  0%   26C    P8             26W /  230W |       2MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
import os
import pathlib
import logging
from datetime import timedelta
from typing import List, Dict, Union, Tuple

import numpy as np
import torch
from torch import Tensor
import torch.distributed as dist
from tqdm import trange
from transformers import AutoTokenizer, AutoModel
from transformers.file_utils import PaddingStrategy

from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

import torch
import numpy as np
from typing import List, Dict
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES


# Root logger setup: INFO level, timestamped messages, emitted through the
# LoggingHandler imported from beir.
logging.basicConfig(
    level=logging.INFO,
    handlers=[LoggingHandler()],
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)


class MedCPTDenseRetriever:
    """Dual-encoder wrapper exposing ``encode_queries`` and ``encode_corpus``.

    Two ``AutoModel`` encoders are loaded (one for queries, one for articles)
    together with a single tokenizer. The embedding of a text is the vector at
    position 0 of the encoder's ``last_hidden_state``.
    """

    def __init__(self, model_path_query="ncbi/MedCPT-Query-Encoder", model_path_corpus="ncbi/MedCPT-Article-Encoder"):
        """Load both encoders in eval mode on CUDA when available, else on CPU.

        Args:
            model_path_query: Checkpoint name or path of the query encoder.
            model_path_corpus: Checkpoint name or path of the article encoder.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Query encoder
        self.bert_q = AutoModel.from_pretrained(model_path_query).to(self.device)
        self.bert_q.eval()

        # Article encoder
        self.bert_d = AutoModel.from_pretrained(model_path_corpus).to(self.device)
        self.bert_d.eval()

        # A single tokenizer, loaded from the query checkpoint, feeds both encoders.
        # NOTE(review): assumes both checkpoints share the same vocabulary -- confirm.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path_query)

    @staticmethod
    def _empty_embeddings(encoder) -> np.ndarray:
        """Return a ``(0, hidden_size)`` array to represent "no inputs".

        ``hidden_size`` is read from ``encoder.config`` and falls back to 0 when
        the attribute is missing.
        NOTE(review): float32 is assumed to match the encoders' output dtype -- confirm.
        """
        hidden_size = getattr(getattr(encoder, "config", None), "hidden_size", 0)
        return np.empty((0, hidden_size), dtype=np.float32)

    def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> np.ndarray:
        """Encode queries into dense embeddings.

        Args:
            queries: Query strings.
            batch_size: Number of queries tokenized and encoded per forward pass.
            **kwargs: Accepted for caller compatibility; ignored.

        Returns:
            Array with one row per query. An empty ``queries`` yields an array
            with zero rows instead of raising from stacking an empty list.
        """
        if len(queries) == 0:
            return self._empty_embeddings(self.bert_q)

        batch_embeddings = []
        with torch.no_grad():
            for start_idx in trange(0, len(queries), batch_size):
                encoded = self.tokenizer(
                    queries[start_idx:start_idx + batch_size],
                    truncation=True, padding=True, return_tensors='pt', max_length=512
                ).to(self.device)

                model_out = self.bert_q(**encoded)
                # Keep only the vector at sequence position 0 and move it off the device.
                batch_embeddings.append(model_out.last_hidden_state[:, 0, :].detach().cpu())

        return torch.cat(batch_embeddings, dim=0).numpy()

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 32, **kwargs) -> np.ndarray:
        """Encode articles (title + text) into dense embeddings.

        Args:
            corpus: Documents as dicts; ``'title'`` and ``'text'`` are read and
                each defaults to ``''`` when absent.
            batch_size: Number of documents tokenized and encoded per forward pass.
            **kwargs: Accepted for caller compatibility; ignored.

        Returns:
            Array with one row per document. An empty ``corpus`` yields an array
            with zero rows instead of raising from stacking an empty list.
        """
        if len(corpus) == 0:
            return self._empty_embeddings(self.bert_d)

        batch_embeddings = []
        with torch.no_grad():
            for start_idx in trange(0, len(corpus), batch_size):
                batch = corpus[start_idx:start_idx + batch_size]
                titles = [row.get('title', '') for row in batch]
                texts = [row.get('text', '') for row in batch]

                # Title and text are tokenized as a pair; when the pair exceeds
                # max_length the longer side is truncated first.
                encoded = self.tokenizer(
                    titles, texts, truncation='longest_first', padding=True, return_tensors='pt', max_length=512
                ).to(self.device)

                model_out = self.bert_d(**encoded)
                batch_embeddings.append(model_out.last_hidden_state[:, 0, :].detach().cpu())

        return torch.cat(batch_embeddings, dim=0).numpy()

In [10]:
# Wrap the custom encoder in BEIR's DenseRetrievalExactSearch, then hand it to the evaluator.
medcpt = DRES(MedCPTDenseRetriever())
# Similarity is the raw dot product between query and document embeddings.
retriever_medcpt = EvaluateRetrieval(medcpt, score_function="dot") # or "cos_sim" for cosine similarity

In [6]:
# Download the dataset archive and unzip it under out_dir
dataset = "nfcorpus" # dataset name
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
out_dir = "datasets"
# data_path is what GenericDataLoader receives as its data_folder in the next cell.
data_path = util.download_and_unzip(url, out_dir)

In [7]:
# Load corpus, queries, and qrels for the test split.
# Note: `queries` is loaded here, but the retrieval cell below is run on the paraphrased `queries_p`.
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

2025-03-19 09:42:37 - Loading Corpus...


  0%|          | 0/3633 [00:00<?, ?it/s]

2025-03-19 09:42:37 - Loaded 3633 TEST Documents.
2025-03-19 09:42:37 - Doc Example: {'text': 'Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 participants die

In [8]:
import json
# Read the paraphrased-query file for the current dataset (looked up in the
# working directory) and parse it into a Python object.
paraphrase_file = pathlib.Path(f"{dataset}_query_paraphrased_gpt4o.json")
queries_para = json.loads(paraphrase_file.read_text(encoding='utf-8'))

In [9]:
# Keep only the 'query_p' field of every record, keyed by the record's original key
# (these keys are what retrieve() receives as query ids).
queries_p = {query_id: record['query_p'] for query_id, record in queries_para.items()}

In [11]:
# Run retrieval over the loaded corpus with the paraphrased queries (queries_p), not the original `queries`.
results_medcpt = retriever_medcpt.retrieve(corpus, queries_p)

2025-03-19 09:43:08 - Encoding Queries...


100%|██████████| 3/3 [00:00<00:00,  8.74it/s]


2025-03-19 09:43:08 - Sorting Corpus by document length (Longest first)...
2025-03-19 09:43:08 - Scoring Function: Dot Product (dot)
2025-03-19 09:43:08 - Encoding Batch 1/1...


100%|██████████| 29/29 [00:26<00:00,  1.10it/s]


In [12]:
# Score the paraphrased-query results against qrels: NDCG@k, MAP@k, Recall@k and P@k
# at the evaluator's k_values (the log below reports k = 1, 3, 5, 10, 100, 1000).
ndcg, _map, recall, precision = retriever_medcpt.evaluate(qrels, results_medcpt, retriever_medcpt.k_values)

2025-03-19 09:43:43 - For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2025-03-19 09:43:43 - 

2025-03-19 09:43:43 - NDCG@1: 0.3111
2025-03-19 09:43:43 - NDCG@3: 0.3020
2025-03-19 09:43:43 - NDCG@5: 0.2895
2025-03-19 09:43:43 - NDCG@10: 0.2720
2025-03-19 09:43:43 - NDCG@100: 0.2799
2025-03-19 09:43:43 - NDCG@1000: 0.3728
2025-03-19 09:43:43 - 

2025-03-19 09:43:43 - MAP@1: 0.0364
2025-03-19 09:43:43 - MAP@3: 0.0644
2025-03-19 09:43:43 - MAP@5: 0.0756
2025-03-19 09:43:43 - MAP@10: 0.0924
2025-03-19 09:43:43 - MAP@100: 0.1291
2025-03-19 09:43:43 - MAP@1000: 0.1442
2025-03-19 09:43:43 - 

2025-03-19 09:43:43 - Recall@1: 0.0364
2025-03-19 09:43:43 - Recall@3: 0.0782
2025-03-19 09:43:43 - Recall@5: 0.0966
2025-03-19 09:43:43 - Recall@10: 0.1299
2025-03-19 09:43:43 - Recall@100: 0.3220
2025-03-19 09:43:43 - Recall@1000: 0.6537
2025-03-19 09:43:43 - 

2025-03-19 09:43:43 - P@1: 0.3220
2025-03-19 09:43:43

In [13]:
# Header, separator, then one line per metric dictionary.
print(f"Model: MedCPT; Dataset: {dataset} (paraphrased)")
print("-" * 150)
for metric_scores in (ndcg, _map, recall, precision):
    print(metric_scores)

Model: MedCPT; Dataset: nfcorpus (paraphrased)
------------------------------------------------------------------------------------------------------------------------------------------------------
{'NDCG@1': 0.31115, 'NDCG@3': 0.30205, 'NDCG@5': 0.28955, 'NDCG@10': 0.27197, 'NDCG@100': 0.27993, 'NDCG@1000': 0.37275}
{'MAP@1': 0.0364, 'MAP@3': 0.06437, 'MAP@5': 0.07563, 'MAP@10': 0.09239, 'MAP@100': 0.12911, 'MAP@1000': 0.14422}
{'Recall@1': 0.0364, 'Recall@3': 0.07819, 'Recall@5': 0.09665, 'Recall@10': 0.12994, 'Recall@100': 0.32201, 'Recall@1000': 0.65368}
{'P@1': 0.32198, 'P@3': 0.29205, 'P@5': 0.25944, 'P@10': 0.21115, 'P@100': 0.07913, 'P@1000': 0.02132}


In [14]:
# Download the dataset archive and unzip it under out_dir
dataset = "scifact" # dataset name
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
out_dir = "datasets"
# data_path is what GenericDataLoader receives as its data_folder in the next cell.
data_path = util.download_and_unzip(url, out_dir)

In [15]:
# Load corpus, queries, and qrels for the test split.
# Note: `queries` is loaded here, but the retrieval cell below is run on the paraphrased `queries_p`.
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

2025-03-19 09:44:21 - Loading Corpus...


  0%|          | 0/5183 [00:00<?, ?it/s]

2025-03-19 09:44:21 - Loaded 5183 TEST Documents.
2025-03-19 09:44:21 - Doc Example: {'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 vers

In [16]:
import json
# Read the paraphrased-query file for the current dataset (looked up in the
# working directory) and parse it into a Python object.
paraphrase_file = pathlib.Path(f"{dataset}_query_paraphrased_gpt4o.json")
queries_para = json.loads(paraphrase_file.read_text(encoding='utf-8'))

In [17]:
# Keep only the 'query_p' field of every record, keyed by the record's original key
# (these keys are what retrieve() receives as query ids).
queries_p = {query_id: record['query_p'] for query_id, record in queries_para.items()}

In [18]:
# Run retrieval over the loaded corpus with the paraphrased queries (queries_p), not the original `queries`.
results_medcpt = retriever_medcpt.retrieve(corpus, queries_p)

2025-03-19 09:44:37 - Encoding Queries...


100%|██████████| 3/3 [00:00<00:00, 13.78it/s]


2025-03-19 09:44:37 - Sorting Corpus by document length (Longest first)...
2025-03-19 09:44:37 - Scoring Function: Dot Product (dot)
2025-03-19 09:44:37 - Encoding Batch 1/1...


100%|██████████| 41/41 [00:32<00:00,  1.25it/s]


In [19]:
# Score the paraphrased-query results against qrels: NDCG@k, MAP@k, Recall@k and P@k
# at the evaluator's k_values (the log below reports k = 1, 3, 5, 10, 100, 1000).
ndcg, _map, recall, precision = retriever_medcpt.evaluate(qrels, results_medcpt, retriever_medcpt.k_values)

2025-03-19 09:45:10 - For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2025-03-19 09:45:10 - 

2025-03-19 09:45:10 - NDCG@1: 0.5133
2025-03-19 09:45:10 - NDCG@3: 0.6085
2025-03-19 09:45:10 - NDCG@5: 0.6415
2025-03-19 09:45:10 - NDCG@10: 0.6613
2025-03-19 09:45:10 - NDCG@100: 0.6994
2025-03-19 09:45:10 - NDCG@1000: 0.7019
2025-03-19 09:45:10 - 

2025-03-19 09:45:10 - MAP@1: 0.4890
2025-03-19 09:45:10 - MAP@3: 0.5757
2025-03-19 09:45:10 - MAP@5: 0.5968
2025-03-19 09:45:10 - MAP@10: 0.6059
2025-03-19 09:45:10 - MAP@100: 0.6157
2025-03-19 09:45:10 - MAP@1000: 0.6158
2025-03-19 09:45:10 - 

2025-03-19 09:45:10 - Recall@1: 0.4890
2025-03-19 09:45:10 - Recall@3: 0.6767
2025-03-19 09:45:10 - Recall@5: 0.7578
2025-03-19 09:45:10 - Recall@10: 0.8149
2025-03-19 09:45:10 - Recall@100: 0.9783
2025-03-19 09:45:10 - Recall@1000: 0.9967
2025-03-19 09:45:10 - 

2025-03-19 09:45:10 - P@1: 0.5133
2025-03-19 09:45:10

In [20]:
# Header, separator, then one line per metric dictionary.
print(f"Model: MedCPT; Dataset: {dataset} (paraphrased)")
print("-" * 150)
for metric_scores in (ndcg, _map, recall, precision):
    print(metric_scores)

Model: MedCPT; Dataset: scifact (paraphrased)
------------------------------------------------------------------------------------------------------------------------------------------------------
{'NDCG@1': 0.51333, 'NDCG@3': 0.6085, 'NDCG@5': 0.64155, 'NDCG@10': 0.66131, 'NDCG@100': 0.69944, 'NDCG@1000': 0.70188}
{'MAP@1': 0.489, 'MAP@3': 0.57566, 'MAP@5': 0.59682, 'MAP@10': 0.60587, 'MAP@100': 0.61569, 'MAP@1000': 0.6158}
{'Recall@1': 0.489, 'Recall@3': 0.67672, 'Recall@5': 0.75778, 'Recall@10': 0.81494, 'Recall@100': 0.97833, 'Recall@1000': 0.99667}
{'P@1': 0.51333, 'P@3': 0.24444, 'P@5': 0.16733, 'P@10': 0.09133, 'P@100': 0.01107, 'P@1000': 0.00113}


In [21]:
# Download the dataset archive and unzip it under out_dir
dataset = "trec-covid" # dataset name
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
out_dir = "datasets"
# data_path is what GenericDataLoader receives as its data_folder in the next cell.
data_path = util.download_and_unzip(url, out_dir)

In [22]:
# Load corpus, queries, and qrels for the test split.
# Note: `queries` is loaded here, but the retrieval cell below is run on the paraphrased `queries_p`.
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

2025-03-19 09:45:26 - Loading Corpus...


  0%|          | 0/171332 [00:00<?, ?it/s]

2025-03-19 09:45:27 - Loaded 171332 TEST Documents.
2025-03-19 09:45:27 - Doc Example: {'text': 'OBJECTIVE: This retrospective chart review describes the epidemiology and clinical features of 40 patients with culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia. METHODS: Patients with positive M. pneumoniae cultures from respiratory specimens from January 1997 through December 1998 were identified through the Microbiology records. Charts of patients were reviewed. RESULTS: 40 patients were identified, 33 (82.5%) of whom required admission. Most infections (92.5%) were community-acquired. The infection affected all age groups but was most common in infants (32.5%) and pre-school children (22.5%). It occurred year-round but was most common in the fall (35%) and spring (30%). More than three-quarters of patients (77.5%) had comorbidities. Twenty-four isolates (60%) were associated with pneumonia, 14 (35%) with upper respiratory tract 

In [23]:
import json
# Read the paraphrased-query file for the current dataset (looked up in the
# working directory) and parse it into a Python object.
paraphrase_file = pathlib.Path(f"{dataset}_query_paraphrased_gpt4o.json")
queries_para = json.loads(paraphrase_file.read_text(encoding='utf-8'))

In [24]:
# Keep only the 'query_p' field of every record, keyed by the record's original key
# (these keys are what retrieve() receives as query ids).
queries_p = {query_id: record['query_p'] for query_id, record in queries_para.items()}

In [25]:
# Run retrieval over the loaded corpus with the paraphrased queries (queries_p), not the original `queries`.
results_medcpt = retriever_medcpt.retrieve(corpus, queries_p)

2025-03-19 09:45:48 - Encoding Queries...


100%|██████████| 1/1 [00:00<00:00, 22.48it/s]

2025-03-19 09:45:48 - Sorting Corpus by document length (Longest first)...





2025-03-19 09:45:49 - Scoring Function: Dot Product (dot)
2025-03-19 09:45:49 - Encoding Batch 1/4...


100%|██████████| 391/391 [07:25<00:00,  1.14s/it]


2025-03-19 09:53:14 - Encoding Batch 2/4...


100%|██████████| 391/391 [05:33<00:00,  1.17it/s]


2025-03-19 09:58:48 - Encoding Batch 3/4...


100%|██████████| 391/391 [02:17<00:00,  2.84it/s]


2025-03-19 10:01:06 - Encoding Batch 4/4...


100%|██████████| 167/167 [00:08<00:00, 20.15it/s]


In [26]:
# Score the paraphrased-query results against qrels: NDCG@k, MAP@k, Recall@k and P@k
# at the evaluator's k_values (the log below reports k = 1, 3, 5, 10, 100, 1000).
ndcg, _map, recall, precision = retriever_medcpt.evaluate(qrels, results_medcpt, retriever_medcpt.k_values)

2025-03-19 10:01:15 - For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2025-03-19 10:01:15 - 

2025-03-19 10:01:15 - NDCG@1: 0.4700
2025-03-19 10:01:15 - NDCG@3: 0.4564
2025-03-19 10:01:15 - NDCG@5: 0.4717
2025-03-19 10:01:15 - NDCG@10: 0.4454
2025-03-19 10:01:15 - NDCG@100: 0.3893
2025-03-19 10:01:15 - NDCG@1000: 0.4223
2025-03-19 10:01:15 - 

2025-03-19 10:01:15 - MAP@1: 0.0013
2025-03-19 10:01:15 - MAP@3: 0.0035
2025-03-19 10:01:15 - MAP@5: 0.0056
2025-03-19 10:01:15 - MAP@10: 0.0100
2025-03-19 10:01:15 - MAP@100: 0.0671
2025-03-19 10:01:15 - MAP@1000: 0.1938
2025-03-19 10:01:15 - 

2025-03-19 10:01:15 - Recall@1: 0.0013
2025-03-19 10:01:15 - Recall@3: 0.0039
2025-03-19 10:01:15 - Recall@5: 0.0067
2025-03-19 10:01:15 - Recall@10: 0.0126
2025-03-19 10:01:15 - Recall@100: 0.1027
2025-03-19 10:01:15 - Recall@1000: 0.4293
2025-03-19 10:01:15 - 

2025-03-19 10:01:15 - P@1: 0.5200
2025-03-19 10:01:15

In [27]:
# Header, separator, then one line per metric dictionary.
print(f"Model: MedCPT; Dataset: {dataset} (paraphrased)")
print("-" * 150)
for metric_scores in (ndcg, _map, recall, precision):
    print(metric_scores)

Model: MedCPT; Dataset: trec-covid (paraphrased)
------------------------------------------------------------------------------------------------------------------------------------------------------
{'NDCG@1': 0.47, 'NDCG@3': 0.45642, 'NDCG@5': 0.4717, 'NDCG@10': 0.44542, 'NDCG@100': 0.38933, 'NDCG@1000': 0.42232}
{'MAP@1': 0.00133, 'MAP@3': 0.00351, 'MAP@5': 0.0056, 'MAP@10': 0.00998, 'MAP@100': 0.06712, 'MAP@1000': 0.19384}
{'Recall@1': 0.00133, 'Recall@3': 0.00385, 'Recall@5': 0.00675, 'Recall@10': 0.01263, 'Recall@100': 0.1027, 'Recall@1000': 0.42929}
{'P@1': 0.52, 'P@3': 0.48, 'P@5': 0.516, 'P@10': 0.486, 'P@100': 0.4172, 'P@1000': 0.19526}
