In [1]:
import os
import sys
import logging
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from torch.utils.data import Dataset
import torch
from transformers import DistilBertTokenizer
from torch.utils.data import DataLoader
import torch.nn as nn
from transformers import DistilBertModel
from torch.optim import AdamW
from ir_measures import nDCG, AP, P, R, RR
from IRutils import dataprocessor, models, train, inference
from IRutils.dataset import TripletRankingDataset

logging.disable(logging.WARNING)

# Download and unzip the dataset
dataset_name = "scidocs"
model_name = "distilbert-base-uncased"

# Create dataset for a specific query length range (e.g., short queries)
length_setting = 'medium'
ranges = {'short': (1, 6), 'medium': (7, 10), 'long': (11, sys.maxsize) }

metrics = [nDCG@10, nDCG@100, AP@10, AP@100, P@10, R@10, P@100, R@100, RR]

max_len_doc = 512  # max token length

url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
data_path = util.download_and_unzip(url, "datasets")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the dataset
docs, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

  from tqdm.autonotebook import tqdm


  0%|          | 0/25657 [00:00<?, ?it/s]

In [2]:
# Load the tokenizer
tokenizer = DistilBertTokenizer.from_pretrained(model_name)

In [3]:
range = ranges[length_setting]

dp = dataprocessor.DataProcessor(queries, docs, qrels)
query_subset, qrels_subset = dp.get_subset(range[0], range[1])  # Adjust min/max length

print(f'Example query from {length_setting} subset:\n{query_subset.popitem()}')

Example query from medium subset:
('89e58773fa59ef5b57f229832c2a1b3e3efff37e', 'Analyzing EEG signals to detect unexpected obstacles during walking')


In [4]:
query_train, query_val, query_test, qrel_train, qrel_val, qrel_test = dp.train_val_test_split(train_ratio=0.8, val_ratio=0.2, test_ratio = 0.3, random_state=42)  # adjust if needed

print(f'Total length of dataset: {len(queries)}')
print(f'Length of subset of {length_setting} queries: {len(query_subset)}')
print(f'Length of subset of {length_setting} training queries: {len(query_train)}')

Total length of dataset: 1000
Length of subset of medium queries: 481
Length of subset of medium training queries: 269


In [5]:
qrel_scores = list(qrels.values())
relevance_scores = [list(item.values()) for item in qrel_scores]
num_negatives = relevance_scores[0].count(0)
print(num_negatives)

25


In [6]:
train_dataset = TripletRankingDataset(query_train, docs, qrel_train, tokenizer, num_negatives, max_length=max_len_doc)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

val_dataset = TripletRankingDataset(query_val, docs, qrel_val, tokenizer, num_negatives,max_length=max_len_doc)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)

test_dataset = TripletRankingDataset(query_test, docs, qrel_test, tokenizer, num_negatives,max_length=max_len_doc)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=True)

In [7]:
# Initialize model and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.TripletRankerModel(model_name).to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)
model_path = os.path.join(os.getcwd(), f'models/{model_name}+{dataset_name}+{length_setting}_queries.pth')

In [None]:
print(model_path)
if os.path.isfile(model_path):
    model.load_state_dict(torch.load(model_path, map_location=device))
else:
    # Train the model
    model = train.train_triplet_ranker(model, train_loader, val_loader, optimizer, device, model_path)

C:\Users\chena\PycharmProjects\IR-rankingmodels\models/distilbert-base-uncased+scidocs+medium_queries.pth


Epoch 1/10 (Training):   0%|          | 0/4157 [00:00<?, ?it/s]

In [None]:
# Example usage (replace with your data and model)
metric_scores = inference.evaluate(model, test_loader, device, qrel_test)

for metric in metrics:
    print(f'Metric {metric} score: {metric_scores[metric]:.4f}')

In [None]:
# Save results to a file
with open(f"results/{model_name}+{dataset_name}+{length_setting}_queries.txt", "w") as f:
    f.write(f"Evaluation Results for {model_name} model finetuned on {length_setting} queries from {dataset_name} dataset:\n")
    f.write(f"normalized Discounted Cumulative Gain@10: {metric_scores[nDCG@10]:.4f}\n")
    f.write(f"normalized Discounted Cumulative Gain@100: {metric_scores[nDCG@100]:.4f}\n")
    f.write(f"\n")
    f.write(f"[Mean] Average Precision@10: {metric_scores[AP@10]:.4f}\n")
    f.write(f"[Mean] Average Precision@100: {metric_scores[AP@100]:.4f}\n")
    f.write(f"\n")
    f.write(f"Precision@10: {metric_scores[P@10]:.4f}\n")
    f.write(f"Recall@10: {metric_scores[R@10]:.4f}\n")
    f.write(f"\n")
    f.write(f"Precision@100: {metric_scores[P@100]:.4f}\n")
    f.write(f"Recall@100: {metric_scores[R@100]:.4f}\n")
    f.write(f"\n")
    f.write(f"[Mean] Reciprocal Rank: {metric_scores[RR]:.4f}\n")
    f.write(f"\n")
    f.write(f"----------------------------------------------------\n")
    f.write(f"\n")
    f.write(f"Explanation of metrics:\n")
    f.write(f"NDCG@k (Normalized Discounted Cumulative Gain: Ranking Quality | Prioritizes highly relevant documents appearing earlier in the ranking.\n")
    f.write(f"MAP (Mean Average Precision): Overall Relevance | Measures ranking precision across all relevant documents. Best for small-scale retrieval tasks.\n")
    f.write(f"Precision@k: Relevance | Measures how many of the top-k documents are relevant. Works well in precision-sensitive applications.\n")
    f.write(f"Recall@k: Coverage | Measures how many relevant documents appear in the top-k results. Important in recall-sensitive tasks.\n")
    f.write(f"MRR (Mean Reciprocal Rank): Single Relevant Result | Focuses on ranking the first relevant document. Good for QA tasks.\n")