## Getting started (Adjust settings to your experiment's needs)

In [1]:
import os
import logging
import torch
from torch.optim import AdamW
from ir_measures import nDCG, AP, P, R, RR
from IRutils import models, train, inference
from IRutils.load_data import load, preprocess

#################### THINGS TO CHANGE FOR YOUR EXPERIMENTS ####################

# Experiment dataset.
dataset_name = "fiqa"

# Pretrained model to fine-tune.
#
# Some options:
#   DistilBERT:  "distilbert-base-uncased" (66M params)
#   BERT:        "bert-base-uncased" (110M params) or "bert-large-uncased" (340M params)
#   RoBERTa:     "roberta-base" (125M params) or "roberta-large" (355M params)
#   ALBERT:      "albert-base-v2" (12M params) or "albert-xxlarge-v2" (235M params)
#   ELECTRA:     "google/electra-small-generator" (14M params) or "google/electra-base-generator" (110M params)
#   DeBERTa:     "microsoft/deberta-base" (140M params) or "microsoft/deberta-v3-base" (184M params)
#   MPNet:       "microsoft/mpnet-base" (110M params)
#   XLM-RoBERTa: "xlm-roberta-base" (125M params) or "xlm-roberta-large" (355M params)
#   T5:          "t5-small" (60M params) or "t5-base" (220M params)
#   BART:        "facebook/bart-base" (140M params) or "facebook/bart-large" (406M params)
#   LongFormer:  "allenai/longformer-base-4096" (149M params)
#
# Distilled/smaller variants that are closer to DistilBERT in size and speed:
#   TinyBERT:      "huawei-noah/TinyBERT_General_4L_312D" (15M params)
#   MobileBERT:    "google/mobilebert-uncased" (25M params)
#   DistilRoBERTa: "distilroberta-base" (82M params)
model_name = "distilroberta-base"

# Query-length range the dataset is built from.
#
# Options:
#   short  - 0-33 percentile (length)
#   medium - 33-67 percentile (length)
#   long   - 67-100 percentile (length)
#   full   - all data (sampled to one third)
length_setting = 'full'

# Measures reported after evaluation, one family per line; the order is the
# order in which the scores are printed later in the notebook.
metrics = [
    nDCG@10, nDCG@100,
    AP@10, AP@100,
    P@10, R@10,
    P@100, R@100,
    RR,
]

#################### THINGS TO CHANGE FOR YOUR EXPERIMENTS ####################

# Fixed settings that are passed on to preprocess() further down.
max_len_doc = 512  # max token length
random_state = 42

# Silence every log record at WARNING severity and below for the whole session.
logging.disable(level=logging.WARNING)

In [2]:
# load() returns seven values: a flag followed by the train-side and test-side
# collections. They are unpacked by position, so the order below must match.
(
    train_available,
    docs,
    queries,
    qrels,
    docs_test,
    queries_test,
    qrels_test,
) = load(dataset_name)
print('Loading complete!')

  0%|          | 0/57638 [00:00<?, ?it/s]

  0%|          | 0/57638 [00:00<?, ?it/s]

Train and test set available!
Loading complete!


In [3]:
# The test-side collections are only forwarded when train_available is truthy;
# otherwise preprocess() is called without them.
test_split_kwargs = (
    {"queries_test": queries_test, "docs_test": docs_test, "qrels_test": qrels_test}
    if train_available
    else {}
)

# Single call site for both cases; the positional arguments are identical.
(
    train_loader,
    val_loader,
    test_loader,
    split_queries_test,
    split_qrels_test,
) = preprocess(
    queries, docs, qrels, model_name, length_setting, train_available,
    **test_split_kwargs,
    max_len_doc=max_len_doc, random_state=random_state,
)

print('Preprocessing complete!')

Dataset size: 5500
test size: 648
Example query from full subset:
('11104', 'Selling a stock for gain to offset other stock loss')
Length of subset of full validation queries: 1100
Length of subset of full training queries: 4399
Length of subset of full queries: 5499
Number of negatives in qrels: 0


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config.json:   0%|          | 0.00/480 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Creating training dataset...


100%|██████████| 4399/4399 [00:11<00:00, 384.19it/s]


Creating validation dataset...


100%|██████████| 1100/1100 [00:03<00:00, 358.31it/s]


Creating testing dataset...


100%|██████████| 648/648 [00:01<00:00, 393.18it/s]

Preprocessing complete!





### Initialize model

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the triplet ranker on that device and attach an AdamW optimizer to it.
model = models.TripletRankerModel(model_name)
model = model.to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)

# Checkpoints live under models/<model>/<dataset>/, one file per length setting.
# The directory is created relative to the current working directory, which is
# also the base of the absolute model_path.
os.makedirs(f'models/{model_name}/{dataset_name}', exist_ok=True)
model_path = os.path.join(
    os.getcwd(),
    f'models/{model_name}/{dataset_name}/{length_setting}_queries.pth',
)

model.safetensors:   0%|          | 0.00/331M [00:00<?, ?B/s]

### Train model (load directly if already trained)

In [None]:
print(model_path)

if not os.path.isfile(model_path):
    # No checkpoint file yet: train the ranker. model_path is handed to the
    # trainer -- presumably where it stores the weights; confirm in IRutils.train.
    model = train.train_triplet_ranker(
        model, train_loader, val_loader, optimizer, device, model_path
    )
else:
    # A checkpoint file exists: restore its state dict onto the active device
    # instead of training again.
    model.load_state_dict(torch.load(model_path, map_location=device))

C:\Users\chena\PycharmProjects\IR-rankingmodels\models/distilroberta-base/fiqa/full_queries.pth


Epoch 1/10 (Training):   0%|          | 0/27880 [00:00<?, ?it/s]

## Run inference on test set

In [None]:
# Score against the test qrels returned by load() when train_available is
# truthy; otherwise against the qrels split returned by preprocess().
eval_qrels = qrels_test if train_available else split_qrels_test
metric_scores = inference.evaluate(model, test_loader, device, eval_qrels)

# Print every configured measure to four decimal places.
for metric in metrics:
    print(f'Metric {metric} score: {metric_scores[metric]:.4f}')

## Write results to output

In [None]:
# Results are stored per model and dataset, one text file per query-length setting.
save_dir = f"results/{model_name}/{dataset_name}"
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, f'{length_setting}_queries.txt')

# Build the whole report before touching the file, so that a missing metric
# raises before an existing results file is truncated. An empty string is a
# blank separator line.
report_lines = [
    f"Evaluation Results for {model_name} model finetuned on {length_setting} queries from {dataset_name} dataset:",
    f"normalized Discounted Cumulative Gain@10: {metric_scores[nDCG@10]:.4f}",
    f"normalized Discounted Cumulative Gain@100: {metric_scores[nDCG@100]:.4f}",
    "",
    f"[Mean] Average Precision@10: {metric_scores[AP@10]:.4f}",
    f"[Mean] Average Precision@100: {metric_scores[AP@100]:.4f}",
    "",
    f"Precision@10: {metric_scores[P@10]:.4f}",
    f"Recall@10: {metric_scores[R@10]:.4f}",
    "",
    f"Precision@100: {metric_scores[P@100]:.4f}",
    f"Recall@100: {metric_scores[R@100]:.4f}",
    "",
    f"[Mean] Reciprocal Rank: {metric_scores[RR]:.4f}",
    "",
    "----------------------------------------------------",
    "",
    "Explanation of metrics:",
    # The closing parenthesis after "Gain" was missing in the original text.
    "NDCG@k (Normalized Discounted Cumulative Gain): Ranking Quality | Prioritizes highly relevant documents appearing earlier in the ranking.",
    "MAP (Mean Average Precision): Overall Relevance | Measures ranking precision across all relevant documents. Best for small-scale retrieval tasks.",
    "Precision@k: Relevance | Measures how many of the top-k documents are relevant. Works well in precision-sensitive applications.",
    "Recall@k: Coverage | Measures how many relevant documents appear in the top-k results. Important in recall-sensitive tasks.",
    "MRR (Mean Reciprocal Rank): Single Relevant Result | Focuses on ranking the first relevant document. Good for QA tasks.",
]

# An explicit encoding keeps the file independent of the platform's locale
# default; the report is written in one call, each line newline-terminated.
with open(save_path, "w", encoding="utf-8") as f:
    f.write("\n".join(report_lines) + "\n")

print(f'Successfully written results to {save_path}.')