## Getting started (Adjust settings to your experiment's needs)

In [2]:
import os
import sys
import logging
from beir import util
from beir.datasets.data_loader import GenericDataLoader
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from ir_measures import nDCG, AP, P, R, RR
from IRutils import dataprocessor, models, train, inference
from IRutils.dataset import TripletRankingDataset
from transformers import AutoTokenizer

#################### THINGS TO CHANGE FOR YOUR EXPERIMENTS ####################

dataset_name = "scifact"  # SELECT YOUR EXPERIMENT DATASET HERE
model_name = "distilbert-base-uncased"  # SELECT YOUR MODEL HERE

# Create dataset for a specific query length range (e.g., short queries)
length_setting = 'medium'
ranges = {'short': (1, 6), 'medium': (7, 10), 'long': (11, sys.maxsize) }

metrics = [nDCG@10, nDCG@100, AP@10, AP@100, P@10, R@10, P@100, R@100, RR]

#################### THINGS TO CHANGE FOR YOUR EXPERIMENTS ####################

logging.disable(logging.WARNING)

datasets = {'msmarco': ['train', 'dev'],
            'hotpotqa': ['train', 'dev', 'test'],
            'arguana': ['test'],
            'quora': ['dev', 'test'],
            'scidocs': ['test'],  # small
            'fever': ['train', 'dev', 'test'],  # large
            'climate-fever': ['test'],
            'scifact': ['train', 'test'],
            'fiqa': ['train', 'dev', 'test'],
            'nfcorpus': ['train', 'dev', 'test']
            }

max_len_doc = 512  # max token length
random_state = 42

  from tqdm.autonotebook import tqdm


  0%|          | 0/5183 [00:00<?, ?it/s]

  0%|          | 0/5183 [00:00<?, ?it/s]

In [None]:
# Download and unzip the dataset
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
data_path = util.download_and_unzip(url, "datasets")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_available = False
if 'train' in datasets[dataset_name]:
    # Load the dataset
    docs, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="train")
    docs_test, queries_test, qrels_test = GenericDataLoader(data_folder=data_path).load(split="test")
    train_available = True
else:
    # Load the dataset
    docs, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

In [3]:
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

### Get test split (if no seperate test set available)

In [4]:
range = ranges[length_setting]

dp = dataprocessor.DataProcessor(queries, docs, qrels)

print(f'Dataset size: {len(queries)}')

# first seperate the test set (include queries of all lengths)
if not train_available:
    query_test, qrel_test = dp.get_testset(test_ratio=0.2, random_state=random_state)
    print(f'test size: {len(query_test)}')
else:
    print(f'test size: {len(queries_test)}')

Dataset size: 809
test size: 300


### Query length filtering

In [5]:
# filter by query length
query_subset, qrels_subset = dp.get_subset(range[0], range[1])  # Adjust min/max length

### Train Val split

In [6]:
print(f'Example query from {length_setting} subset:\n{query_subset.popitem()}')

query_train, query_val, qrel_train, qrel_val = dp.train_val_split(train_ratio=0.8, val_ratio=0.2, train_available=train_available, random_state=42)  # adjust if needed

print(f'Length of subset of medium validation queries: {len(query_val)}')
print(f'Length of subset of {length_setting} training queries: {len(query_train)}')
print(f'Length of subset of {length_setting} queries: {len(query_subset)}')

Example query from medium subset:
('1406', 'β-sheet opening occurs during pleurotolysin pore formation.')
Length of subset of medium validation queries: 53
Length of subset of medium training queries: 210
Length of subset of medium queries: 263


### Check if the qrels already contain negative samples （If not create later)

In [7]:
qrel_scores = list(qrels.values()) 
relevance_scores = [list(item.values()) for item in qrel_scores]
num_negatives = relevance_scores[0].count(0)
print(num_negatives)

0


### Create datasets and data loaders

In [8]:
print('Creating training dataset...')
train_dataset = TripletRankingDataset(query_train, docs, qrel_train, tokenizer, num_negatives, max_length=max_len_doc)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

print('Creating validation dataset...')
val_dataset = TripletRankingDataset(query_val, docs, qrel_val, tokenizer, num_negatives,max_length=max_len_doc)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)

print('Creating testing dataset...')
if train_available:
    test_dataset = TripletRankingDataset(queries_test, docs_test, qrels_test, tokenizer, num_negatives,max_length=max_len_doc)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=True)
else:
    test_dataset = TripletRankingDataset(query_test, docs, qrel_test, tokenizer, num_negatives,max_length=max_len_doc) 

Creating training dataset...


100%|██████████| 210/210 [00:00<00:00, 1381.74it/s]


Creating validation dataset...


100%|██████████| 53/53 [00:00<00:00, 5058.55it/s]


Creating testing dataset...


100%|██████████| 300/300 [00:00<00:00, 5213.96it/s]


### Initialize model

In [9]:
# Initialize model and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.TripletRankerModel(model_name).to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)
model_path = os.path.join(os.getcwd(), f'models/{model_name}+{dataset_name}+{length_setting}_queries.pth')

### Train model (load directly if already trained)

In [10]:
print(model_path)
if os.path.isfile(model_path):
    model.load_state_dict(torch.load(model_path, map_location=device))
else:
    # Train the model
    model = train.train_triplet_ranker(model, train_loader, val_loader, optimizer, device, model_path)

C:\Users\chena\PycharmProjects\IR-rankingmodels\models/distilbert-base-uncased+scifact+medium_queries.pth


Epoch 1/10 (Training):   0%|          | 0/595 [00:00<?, ?it/s]

Epoch 1/10, Average Training Loss: 0.6046


Validation: 100%|██████████| 153/153 [00:16<00:00,  9.34it/s]


Validation Loss: 0.5743
Validation loss improved. Saving model.


Epoch 2/10 (Training):   0%|          | 0/595 [00:00<?, ?it/s]

Epoch 2/10, Average Training Loss: 0.5425


Validation: 100%|██████████| 153/153 [00:16<00:00,  9.38it/s]


Validation Loss: 0.5457
Validation loss improved. Saving model.


Epoch 3/10 (Training):   0%|          | 0/595 [00:00<?, ?it/s]

Epoch 3/10, Average Training Loss: 0.4183


Validation: 100%|██████████| 153/153 [00:16<00:00,  9.46it/s]


Validation Loss: 0.3761
Validation loss improved. Saving model.


Epoch 4/10 (Training):   0%|          | 0/595 [00:00<?, ?it/s]

Epoch 4/10, Average Training Loss: 0.1768


Validation: 100%|██████████| 153/153 [00:16<00:00,  9.40it/s]


Validation Loss: 0.3640
Validation loss improved. Saving model.


Epoch 5/10 (Training):   0%|          | 0/595 [00:00<?, ?it/s]

Epoch 5/10, Average Training Loss: 0.1320


Validation: 100%|██████████| 153/153 [00:16<00:00,  9.43it/s]

Validation Loss: 0.4467
Validation loss did not improve. Patience: 1/3





Epoch 6/10 (Training):   0%|          | 0/595 [00:00<?, ?it/s]

Epoch 6/10, Average Training Loss: 0.0807


Validation: 100%|██████████| 153/153 [00:16<00:00,  9.42it/s]

Validation Loss: 0.3974
Validation loss did not improve. Patience: 2/3





Epoch 7/10 (Training):   0%|          | 0/595 [00:00<?, ?it/s]

Epoch 7/10, Average Training Loss: 0.0993


Validation: 100%|██████████| 153/153 [00:16<00:00,  9.44it/s]


Validation Loss: 0.2897
Validation loss improved. Saving model.


Epoch 8/10 (Training):   0%|          | 0/595 [00:00<?, ?it/s]

Epoch 8/10, Average Training Loss: 0.0526


Validation: 100%|██████████| 153/153 [00:16<00:00,  9.37it/s]


Validation Loss: 0.2201
Validation loss improved. Saving model.


Epoch 9/10 (Training):   0%|          | 0/595 [00:00<?, ?it/s]

Epoch 9/10, Average Training Loss: 0.0736


Validation: 100%|██████████| 153/153 [00:16<00:00,  9.43it/s]

Validation Loss: 0.3975
Validation loss did not improve. Patience: 1/3





Epoch 10/10 (Training):   0%|          | 0/595 [00:00<?, ?it/s]

Epoch 10/10, Average Training Loss: 0.0641


Validation: 100%|██████████| 153/153 [00:16<00:00,  9.38it/s]


Validation Loss: 0.2573
Validation loss did not improve. Patience: 2/3
Loaded best model based on validation loss.
Training complete!


## Run inference on test set

In [11]:
# Example usage (replace with your data and model)
if train_available:
    metric_scores = inference.evaluate(model, test_loader, device, qrels_test)
else:
    metric_scores = inference.evaluate(model, test_loader, device, qrel_test)

for metric in metrics:
    print(f'Metric {metric} score: {metric_scores[metric]:.4f}')

Evaluating: 100%|██████████| 848/848 [01:32<00:00,  9.20it/s]


Metric nDCG@10 score: 0.6881
Metric nDCG@100 score: 0.7235
Metric AP@10 score: 0.6300
Metric AP@100 score: 0.6392
Metric P@10 score: 0.0947
Metric R@10 score: 0.8586
Metric P@100 score: 0.0113
Metric R@100 score: 1.0000
Metric RR score: 0.6447


## Write results to output

In [12]:
# Save results to a file
with open(f"results/{model_name}+{dataset_name}+{length_setting}_queries.txt", "w") as f:
    f.write(f"Evaluation Results for {model_name} model finetuned on {length_setting} queries from {dataset_name} dataset:\n")
    f.write(f"normalized Discounted Cumulative Gain@10: {metric_scores[nDCG@10]:.4f}\n")
    f.write(f"normalized Discounted Cumulative Gain@100: {metric_scores[nDCG@100]:.4f}\n")
    f.write(f"\n")
    f.write(f"[Mean] Average Precision@10: {metric_scores[AP@10]:.4f}\n")
    f.write(f"[Mean] Average Precision@100: {metric_scores[AP@100]:.4f}\n")
    f.write(f"\n")
    f.write(f"Precision@10: {metric_scores[P@10]:.4f}\n")
    f.write(f"Recall@10: {metric_scores[R@10]:.4f}\n")
    f.write(f"\n")
    f.write(f"Precision@100: {metric_scores[P@100]:.4f}\n")
    f.write(f"Recall@100: {metric_scores[R@100]:.4f}\n")
    f.write(f"\n")
    f.write(f"[Mean] Reciprocal Rank: {metric_scores[RR]:.4f}\n")
    f.write(f"\n")
    f.write(f"----------------------------------------------------\n")
    f.write(f"\n")
    f.write(f"Explanation of metrics:\n")
    f.write(f"NDCG@k (Normalized Discounted Cumulative Gain: Ranking Quality | Prioritizes highly relevant documents appearing earlier in the ranking.\n")
    f.write(f"MAP (Mean Average Precision): Overall Relevance | Measures ranking precision across all relevant documents. Best for small-scale retrieval tasks.\n")
    f.write(f"Precision@k: Relevance | Measures how many of the top-k documents are relevant. Works well in precision-sensitive applications.\n")
    f.write(f"Recall@k: Coverage | Measures how many relevant documents appear in the top-k results. Important in recall-sensitive tasks.\n")
    f.write(f"MRR (Mean Reciprocal Rank): Single Relevant Result | Focuses on ranking the first relevant document. Good for QA tasks.\n")