In [1]:
! pip3 install transformers=='4.48.3'

[notice] A new release of pip is available: 24.2 -> 25.0.1
[notice] To update, run: python -m pip install --upgrade pip


In [2]:
! pip install beir=='2.0.0'

[notice] A new release of pip is available: 24.2 -> 25.0.1
[notice] To update, run: python -m pip install --upgrade pip


In [3]:
# Record the GPU model, driver/CUDA version and memory usage available for this run.
! nvidia-smi

Wed Mar 19 10:14:58 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.77                 Driver Version: 565.77         CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A5000               On  |   00000000:03:00.0 Off |                  Off |
|  0%   21C    P8             25W /  230W |       2MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [4]:
import os
import pathlib
import logging
from datetime import timedelta
from typing import List, Dict, Union, Tuple

import numpy as np
import torch
from torch import Tensor
import torch.distributed as dist
from tqdm import trange
from transformers import AutoTokenizer, AutoModel
from transformers.file_utils import PaddingStrategy

from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

import torch
import numpy as np
from typing import List, Dict
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES


# Timestamped INFO-level logging, routed through BEIR's LoggingHandler so that
# BEIR's own progress messages appear in the cell output.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)
logger = logging.getLogger(__name__)


class ContrieverDenseRetriever:
    """Dense encoder for BEIR's DenseRetrievalExactSearch backed by one Contriever checkpoint.

    The same model and tokenizer encode both queries and documents. Each input
    is embedded as the mean of the model's last hidden states over its
    non-padding tokens. Embeddings are returned as numpy arrays with one row
    per input.
    """

    def __init__(self, model_path="facebook/contriever-msmarco", max_length: int = 512):
        """Load the model and tokenizer.

        Args:
            model_path: Hugging Face hub id or local directory of the checkpoint.
            max_length: maximum number of tokens per encoded input; longer
                inputs are truncated. Defaults to 512.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load once and switch to eval mode for inference.
        self.model = AutoModel.from_pretrained(model_path).to(self.device)
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.max_length = max_length

    def mean_pooling(self, token_embeddings, mask):
        """Average token embeddings over the positions where ``mask`` is non-zero.

        Args:
            token_embeddings: per-token hidden states; the pooling below assumes
                (batch, seq_len, hidden) — TODO confirm for non-BERT checkpoints.
            mask: attention mask of shape (batch, seq_len); non-zero marks real tokens.

        Returns:
            Tensor of shape (batch, hidden).
        """
        keep = mask[..., None].bool()
        summed = token_embeddings.masked_fill(~keep, 0.).sum(dim=1)
        # Clamp the token count so a row with an all-zero mask yields zeros, not NaN.
        counts = mask.sum(dim=1).clamp(min=1)[..., None]
        return summed / counts

    def _empty_embeddings(self) -> np.ndarray:
        """Zero-row embedding matrix for empty input (np.vstack([]) would raise)."""
        return np.zeros((0, self.model.config.hidden_size), dtype=np.float32)

    def _embed(self, *texts) -> np.ndarray:
        """Tokenize one batch, run the model, mean-pool, and return a numpy array.

        ``texts`` is either a single list of strings, or two parallel lists
        that are encoded as sentence pairs.
        """
        # For a single sequence 'longest_first' behaves exactly like truncation=True.
        encoded = self.tokenizer(
            *texts,
            truncation='longest_first',
            padding=True,
            return_tensors='pt',
            max_length=self.max_length,
        ).to(self.device)
        outputs = self.model(**encoded)
        embeds = self.mean_pooling(outputs.last_hidden_state, encoded['attention_mask'])
        return embeds.cpu().numpy()

    def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> np.ndarray:
        """Encode query strings into dense embeddings.

        Args:
            queries: query texts.
            batch_size: number of queries per forward pass.
            **kwargs: options passed by BEIR. ``show_progress_bar`` (default True)
                toggles the progress bar; all other options are ignored.

        Returns:
            Array with one embedding row per query (zero rows for empty input).
        """
        if len(queries) == 0:
            return self._empty_embeddings()

        show_progress = kwargs.get("show_progress_bar", True)
        batches = []
        with torch.no_grad():
            for start in trange(0, len(queries), batch_size, disable=not show_progress):
                batches.append(self._embed(queries[start:start + batch_size]))
        return np.vstack(batches)

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 16, **kwargs) -> np.ndarray:
        """Encode documents (title + text) into dense embeddings.

        Each document is tokenized as a (title, text) sentence pair. With
        'longest_first' truncation, the longer member of the pair is trimmed
        first once the pair exceeds ``max_length``.

        Args:
            corpus: documents as dicts with optional 'title' and 'text' keys.
                A missing key, or a key holding None, is treated as "".
            batch_size: number of documents per forward pass.
            **kwargs: options passed by BEIR. ``show_progress_bar`` (default True)
                toggles the progress bar; all other options are ignored.

        Returns:
            Array with one embedding row per document (zero rows for an empty corpus).
        """
        # NOTE(review): pair encoding may give the text a different segment id
        # than a single "title text" string would. Confirm which input format the
        # checkpoint expects before comparing against published numbers.
        if len(corpus) == 0:
            return self._empty_embeddings()

        show_progress = kwargs.get("show_progress_bar", True)
        batches = []
        with torch.no_grad():
            for start in trange(0, len(corpus), batch_size, disable=not show_progress):
                rows = corpus[start:start + batch_size]
                # `or ''` also covers keys that are present but hold None.
                titles = [row.get('title') or '' for row in rows]
                texts = [row.get('text') or '' for row in rows]
                batches.append(self._embed(titles, texts))
        return np.vstack(batches)

In [5]:
# Exact dense search over Contriever embeddings, scored by dot product.
# Passing score_function="cos_sim" would select cosine similarity instead.
contriever_encoder = ContrieverDenseRetriever()
contriever = DRES(contriever_encoder)
retriever_contriever = EvaluateRetrieval(contriever, score_function="dot")

config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [6]:
# Fetch the BEIR nfcorpus archive and unzip it under ./datasets.
dataset = "nfcorpus"
out_dir = "datasets"
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
data_path = util.download_and_unzip(url, out_dir)

In [7]:
# Read the test split: documents, original queries and relevance judgements (qrels).
data_loader = GenericDataLoader(data_folder=data_path)
corpus, queries, qrels = data_loader.load(split="test")

2025-03-19 10:21:16 - Loading Corpus...


  0%|          | 0/3633 [00:00<?, ?it/s]

2025-03-19 10:21:16 - Loaded 3633 TEST Documents.
2025-03-19 10:21:16 - Doc Example: {'text': 'Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 participants die

In [8]:
import json

# Load the GPT-4o paraphrased queries prepared for the current dataset.
paraphrase_file = f"{dataset}_query_paraphrased_gpt4o.json"
with open(paraphrase_file, encoding='utf-8') as fh:
    queries_para = json.load(fh)

In [9]:
# Map each query id to its paraphrased text.
queries_p = {qid: entry['query_p'] for qid, entry in queries_para.items()}

In [10]:
# Rank the whole corpus for every paraphrased query.
results_contriever = retriever_contriever.retrieve(
    corpus,
    queries_p,
)

2025-03-19 10:21:38 - Encoding Queries...


100%|██████████| 3/3 [00:00<00:00,  7.04it/s]


2025-03-19 10:21:38 - Sorting Corpus by document length (Longest first)...
2025-03-19 10:21:38 - Scoring Function: Dot Product (dot)
2025-03-19 10:21:38 - Encoding Batch 1/1...


100%|██████████| 29/29 [00:27<00:00,  1.05it/s]


In [11]:
# Score the run against the qrels with NDCG, MAP, Recall and Precision at the
# retriever's cut-offs k = [1, 3, 5, 10, 100, 1000].
metric_cutoffs = retriever_contriever.k_values
ndcg, _map, recall, precision = retriever_contriever.evaluate(qrels, results_contriever, metric_cutoffs)

2025-03-19 10:22:06 - For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2025-03-19 10:22:06 - 

2025-03-19 10:22:06 - NDCG@1: 0.3669
2025-03-19 10:22:06 - NDCG@3: 0.3306
2025-03-19 10:22:06 - NDCG@5: 0.3142
2025-03-19 10:22:06 - NDCG@10: 0.2882
2025-03-19 10:22:06 - NDCG@100: 0.2725
2025-03-19 10:22:06 - NDCG@1000: 0.3667
2025-03-19 10:22:06 - 

2025-03-19 10:22:06 - MAP@1: 0.0459
2025-03-19 10:22:06 - MAP@3: 0.0723
2025-03-19 10:22:06 - MAP@5: 0.0864
2025-03-19 10:22:06 - MAP@10: 0.1008
2025-03-19 10:22:06 - MAP@100: 0.1283
2025-03-19 10:22:06 - MAP@1000: 0.1427
2025-03-19 10:22:06 - 

2025-03-19 10:22:06 - Recall@1: 0.0459
2025-03-19 10:22:06 - Recall@3: 0.0814
2025-03-19 10:22:06 - Recall@5: 0.1076
2025-03-19 10:22:06 - Recall@10: 0.1368
2025-03-19 10:22:06 - Recall@100: 0.2821
2025-03-19 10:22:06 - Recall@1000: 0.6219
2025-03-19 10:22:06 - 

2025-03-19 10:22:06 - P@1: 0.3808
2025-03-19 10:22:06

In [12]:
# Summary header followed by one line per metric family.
print(f"Model: Contriever; Dataset: {dataset} (paraphrased)")
print("-" * 150)
for metric_scores in (ndcg, _map, recall, precision):
    print(metric_scores)

Model: Contriever; Dataset: nfcorpus (paraphrased)
------------------------------------------------------------------------------------------------------------------------------------------------------
{'NDCG@1': 0.36687, 'NDCG@3': 0.33062, 'NDCG@5': 0.31419, 'NDCG@10': 0.28821, 'NDCG@100': 0.27245, 'NDCG@1000': 0.36667}
{'MAP@1': 0.04593, 'MAP@3': 0.07234, 'MAP@5': 0.08641, 'MAP@10': 0.10076, 'MAP@100': 0.12834, 'MAP@1000': 0.14266}
{'Recall@1': 0.04593, 'Recall@3': 0.08135, 'Recall@5': 0.10763, 'Recall@10': 0.13679, 'Recall@100': 0.28206, 'Recall@1000': 0.6219}
{'P@1': 0.3808, 'P@3': 0.31269, 'P@5': 0.27492, 'P@10': 0.21765, 'P@100': 0.07288, 'P@1000': 0.02046}


In [13]:
# Fetch the BEIR scifact archive and unzip it under ./datasets.
dataset = "scifact"
out_dir = "datasets"
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
data_path = util.download_and_unzip(url, out_dir)

In [14]:
# Read the test split: documents, original queries and relevance judgements (qrels).
data_loader = GenericDataLoader(data_folder=data_path)
corpus, queries, qrels = data_loader.load(split="test")

2025-03-19 10:22:06 - Loading Corpus...


  0%|          | 0/5183 [00:00<?, ?it/s]

2025-03-19 10:22:06 - Loaded 5183 TEST Documents.
2025-03-19 10:22:06 - Doc Example: {'text': 'Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of prematurity on cerebral white matter development, early gestation preterm infants (n = 10) were studied a second time at term. In the central white matter the mean apparent diffusion coefficient at 28 wk was high, 1.8 microm2/ms, and decreased toward term to 1.2 microm2/ms. In the posterior limb of the internal capsule, the mean apparent diffusion coefficients at both times were similar (1.2 vers

In [15]:
import json

# Load the GPT-4o paraphrased queries prepared for the current dataset.
paraphrase_file = f"{dataset}_query_paraphrased_gpt4o.json"
with open(paraphrase_file, encoding='utf-8') as fh:
    queries_para = json.load(fh)

In [16]:
# Map each query id to its paraphrased text.
queries_p = {qid: entry['query_p'] for qid, entry in queries_para.items()}

In [17]:
# Rank the whole corpus for every paraphrased query.
results_contriever = retriever_contriever.retrieve(
    corpus,
    queries_p,
)

2025-03-19 10:22:34 - Encoding Queries...


100%|██████████| 3/3 [00:00<00:00, 11.39it/s]


2025-03-19 10:22:34 - Sorting Corpus by document length (Longest first)...
2025-03-19 10:22:34 - Scoring Function: Dot Product (dot)
2025-03-19 10:22:34 - Encoding Batch 1/1...


100%|██████████| 41/41 [00:37<00:00,  1.11it/s]


In [18]:
# Score the run against the qrels with NDCG, MAP, Recall and Precision at the
# retriever's cut-offs k = [1, 3, 5, 10, 100, 1000].
metric_cutoffs = retriever_contriever.k_values
ndcg, _map, recall, precision = retriever_contriever.evaluate(qrels, results_contriever, metric_cutoffs)

2025-03-19 10:23:11 - For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2025-03-19 10:23:11 - 

2025-03-19 10:23:11 - NDCG@1: 0.5433
2025-03-19 10:23:11 - NDCG@3: 0.6295
2025-03-19 10:23:11 - NDCG@5: 0.6516
2025-03-19 10:23:11 - NDCG@10: 0.6747
2025-03-19 10:23:11 - NDCG@100: 0.7015
2025-03-19 10:23:11 - NDCG@1000: 0.7084
2025-03-19 10:23:11 - 

2025-03-19 10:23:11 - MAP@1: 0.5192
2025-03-19 10:23:11 - MAP@3: 0.5991
2025-03-19 10:23:11 - MAP@5: 0.6160
2025-03-19 10:23:11 - MAP@10: 0.6262
2025-03-19 10:23:11 - MAP@100: 0.6319
2025-03-19 10:23:11 - MAP@1000: 0.6321
2025-03-19 10:23:11 - 

2025-03-19 10:23:11 - Recall@1: 0.5192
2025-03-19 10:23:11 - Recall@3: 0.6892
2025-03-19 10:23:11 - Recall@5: 0.7433
2025-03-19 10:23:11 - Recall@10: 0.8118
2025-03-19 10:23:11 - Recall@100: 0.9360
2025-03-19 10:23:11 - Recall@1000: 0.9900
2025-03-19 10:23:11 - 

2025-03-19 10:23:11 - P@1: 0.5433
2025-03-19 10:23:11

In [19]:
# Summary header followed by one line per metric family.
print(f"Model: Contriever; Dataset: {dataset} (paraphrased)")
print("-" * 150)
for metric_scores in (ndcg, _map, recall, precision):
    print(metric_scores)

Model: Contriever; Dataset: scifact (paraphrased)
------------------------------------------------------------------------------------------------------------------------------------------------------
{'NDCG@1': 0.54333, 'NDCG@3': 0.6295, 'NDCG@5': 0.65159, 'NDCG@10': 0.67473, 'NDCG@100': 0.70149, 'NDCG@1000': 0.70842}
{'MAP@1': 0.51917, 'MAP@3': 0.59915, 'MAP@5': 0.61597, 'MAP@10': 0.62622, 'MAP@100': 0.63185, 'MAP@1000': 0.63213}
{'Recall@1': 0.51917, 'Recall@3': 0.68922, 'Recall@5': 0.74333, 'Recall@10': 0.81178, 'Recall@100': 0.936, 'Recall@1000': 0.99}
{'P@1': 0.54333, 'P@3': 0.24889, 'P@5': 0.16533, 'P@10': 0.091, 'P@100': 0.0106, 'P@1000': 0.00112}


In [20]:
# Fetch the BEIR trec-covid archive and unzip it under ./datasets.
dataset = "trec-covid"
out_dir = "datasets"
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
data_path = util.download_and_unzip(url, out_dir)

In [21]:
# Read the test split: documents, original queries and relevance judgements (qrels).
data_loader = GenericDataLoader(data_folder=data_path)
corpus, queries, qrels = data_loader.load(split="test")

2025-03-19 10:23:11 - Loading Corpus...


  0%|          | 0/171332 [00:00<?, ?it/s]

2025-03-19 10:23:13 - Loaded 171332 TEST Documents.
2025-03-19 10:23:13 - Doc Example: {'text': 'OBJECTIVE: This retrospective chart review describes the epidemiology and clinical features of 40 patients with culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia. METHODS: Patients with positive M. pneumoniae cultures from respiratory specimens from January 1997 through December 1998 were identified through the Microbiology records. Charts of patients were reviewed. RESULTS: 40 patients were identified, 33 (82.5%) of whom required admission. Most infections (92.5%) were community-acquired. The infection affected all age groups but was most common in infants (32.5%) and pre-school children (22.5%). It occurred year-round but was most common in the fall (35%) and spring (30%). More than three-quarters of patients (77.5%) had comorbidities. Twenty-four isolates (60%) were associated with pneumonia, 14 (35%) with upper respiratory tract 

In [22]:
import json

# Load the GPT-4o paraphrased queries prepared for the current dataset.
paraphrase_file = f"{dataset}_query_paraphrased_gpt4o.json"
with open(paraphrase_file, encoding='utf-8') as fh:
    queries_para = json.load(fh)

In [23]:
# Map each query id to its paraphrased text.
queries_p = {qid: entry['query_p'] for qid, entry in queries_para.items()}

In [24]:
# Rank the whole corpus for every paraphrased query.
results_contriever = retriever_contriever.retrieve(
    corpus,
    queries_p,
)

2025-03-19 10:23:30 - Encoding Queries...


100%|██████████| 1/1 [00:00<00:00, 17.27it/s]

2025-03-19 10:23:30 - Sorting Corpus by document length (Longest first)...





2025-03-19 10:23:30 - Scoring Function: Dot Product (dot)
2025-03-19 10:23:30 - Encoding Batch 1/4...


100%|██████████| 391/391 [07:36<00:00,  1.17s/it]


2025-03-19 10:31:07 - Encoding Batch 2/4...


100%|██████████| 391/391 [06:04<00:00,  1.07it/s]


2025-03-19 10:37:12 - Encoding Batch 3/4...


100%|██████████| 391/391 [02:24<00:00,  2.70it/s]


2025-03-19 10:39:37 - Encoding Batch 4/4...


100%|██████████| 167/167 [00:09<00:00, 18.25it/s]


In [25]:
# Score the run against the qrels with NDCG, MAP, Recall and Precision at the
# retriever's cut-offs k = [1, 3, 5, 10, 100, 1000].
metric_cutoffs = retriever_contriever.k_values
ndcg, _map, recall, precision = retriever_contriever.evaluate(qrels, results_contriever, metric_cutoffs)

2025-03-19 10:39:46 - For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2025-03-19 10:39:46 - 

2025-03-19 10:39:46 - NDCG@1: 0.5100
2025-03-19 10:39:46 - NDCG@3: 0.4864
2025-03-19 10:39:46 - NDCG@5: 0.4832
2025-03-19 10:39:46 - NDCG@10: 0.4527
2025-03-19 10:39:46 - NDCG@100: 0.3382
2025-03-19 10:39:46 - NDCG@1000: 0.3239
2025-03-19 10:39:46 - 

2025-03-19 10:39:46 - MAP@1: 0.0014
2025-03-19 10:39:46 - MAP@3: 0.0040
2025-03-19 10:39:46 - MAP@5: 0.0061
2025-03-19 10:39:46 - MAP@10: 0.0103
2025-03-19 10:39:46 - MAP@100: 0.0495
2025-03-19 10:39:46 - MAP@1000: 0.1299
2025-03-19 10:39:46 - 

2025-03-19 10:39:46 - Recall@1: 0.0014
2025-03-19 10:39:46 - Recall@3: 0.0044
2025-03-19 10:39:46 - Recall@5: 0.0070
2025-03-19 10:39:46 - Recall@10: 0.0128
2025-03-19 10:39:46 - Recall@100: 0.0798
2025-03-19 10:39:46 - Recall@1000: 0.3137
2025-03-19 10:39:46 - 

2025-03-19 10:39:46 - P@1: 0.5600
2025-03-19 10:39:46

In [26]:
# Summary header followed by one line per metric family.
print(f"Model: Contriever; Dataset: {dataset} (paraphrased)")
print("-" * 150)
for metric_scores in (ndcg, _map, recall, precision):
    print(metric_scores)

Model: Contriever; Dataset: trec-covid (paraphrased)
------------------------------------------------------------------------------------------------------------------------------------------------------
{'NDCG@1': 0.51, 'NDCG@3': 0.48642, 'NDCG@5': 0.48318, 'NDCG@10': 0.45266, 'NDCG@100': 0.33818, 'NDCG@1000': 0.32393}
{'MAP@1': 0.00141, 'MAP@3': 0.00399, 'MAP@5': 0.00613, 'MAP@10': 0.01031, 'MAP@100': 0.04945, 'MAP@1000': 0.12994}
{'Recall@1': 0.00141, 'Recall@3': 0.0044, 'Recall@5': 0.007, 'Recall@10': 0.01278, 'Recall@100': 0.07982, 'Recall@1000': 0.3137}
{'P@1': 0.56, 'P@3': 0.54, 'P@5': 0.536, 'P@10': 0.488, 'P@100': 0.3518, 'P@1000': 0.15158}
