In [1]:
import ir_datasets

import pyterrier as pt
print(pt.__version__)

# pt.init() is deprecated in this PyTerrier release: Java is started
# automatically, and the library itself points to pt.java.init() for forcing
# early initialisation. The guard keeps the cell safe to re-run.
if not pt.java.started():
    pt.java.init()
from pathlib import Path
import re
import pandas as pd
from pyterrier.measures import RR, nDCG, MAP, MRR, Recall, Precision
from tqdm.notebook import tqdm

0.13.0


Java started and loaded: pyterrier.java, pyterrier.terrier.java [version=5.11 (build: craig.macdonald 2025-01-13 21:29), helper_version=0.0.8]
java is now started automatically with default settings. To force initialisation early, run:
pt.java.init() # optional, forces java initialisation
  pt.init()


In [2]:
# Load the "antique/train" dataset through ir_datasets and print its first
# document as a quick look at the record layout (doc_id, text).
antique = ir_datasets.load("antique/train")
print(antique.docs[0])

GenericDoc(doc_id='2020338_0', text="A small group of politicians believed strongly that the fact that Saddam Hussien remained in power after the first Gulf War was a signal of weakness to the rest of the world, one that invited attacks and terrorism. Shortly after taking power with George Bush in 2000 and after the attack on 9/11, they were able to use the terrorist attacks to justify war with Iraq on this basis and exaggerated threats of the development of weapons of mass destruction. The military strength of the U.S. and the brutality of Saddam's regime led them to imagine that the military and political victory would be relatively easy.")


In [3]:
# OPTIONAL - Index the data

# Index directory, resolved relative to the current working directory.
idx_path = Path.cwd() / "indices" / "antique_train"

def antique_gen(limit=500000):
    """Yield documents from ``antique.docs`` as ``{"docno", "text"}`` dicts.

    Iteration stops once ``limit`` documents have been yielded.
    """
    for emitted, doc in enumerate(antique.docs):
        if emitted >= limit:
            break
        yield {"docno": doc.doc_id, "text": doc.text}

if not idx_path.exists() or not any(idx_path.iterdir()):
    # Scan the corpus for its longest text only when an index is actually
    # being built: the pass reads every document, and its result was previously
    # computed on every run and then thrown away. It is printed so the "text"
    # meta length reserved below can be checked against the data.
    longest_text_bytes = max(
        (len(doc.text.encode("utf-8")) for doc in antique.docs), default=0
    )
    print(f"Longest document text: {longest_text_bytes} UTF-8 bytes")

    indexer = pt.IterDictIndexer(
        str(idx_path),
        # Metadata fields stored with each document and the length reserved
        # for each of them.
        meta={
            "docno": 20,
            "text": 4096,
        },
        stemmer="porter",
        stopwords="terrier",
    )

    index_ref = indexer.index(antique_gen())
else:
    print("Indices already exist, skipping creation")

Indices already exist, skipping creation


In [4]:
# Location of the on-disk index, relative to the current working directory.
index_dir_antique = Path.cwd().joinpath("indices", "antique_train")

# Open the index stored at that location.
index_antique = pt.IndexFactory.of(str(index_dir_antique))

In [5]:
# Build the qrels and topics frames, mapping the ir_datasets field names onto
# the column names used for PyTerrier ("qid", "docno", "label", "query").
qrels_antique = pd.DataFrame(antique.qrels_iter()).rename(
    columns={"query_id": "qid", "doc_id": "docno", "relevance": "label"}
)

queries_antique = pd.DataFrame(antique.queries_iter()).rename(
    columns={"query_id": "qid", "text": "query"}
)

In [6]:
import torch
import re

# Prefer the GPU when one is usable; otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
print("Using device:", device)
print("CUDA Available:", cuda_available)
if cuda_available:
    # These two queries raise on a machine without CUDA, so they are only made
    # once a GPU has been detected; otherwise the CPU fallback above could
    # never take effect because the cell would fail here.
    current_index = torch.cuda.current_device()
    print("Current Device:", current_index)
    print("Device Name:", torch.cuda.get_device_name(current_index))
print("Device Count:", torch.cuda.device_count())


Using device: cuda
CUDA Available: True
Current Device: 0
Device Name: NVIDIA GeForce RTX 3060
Device Count: 1


In [7]:
#Model imports
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Checkpoint identifier of the T5 query-reformulation model.
MODEL_ID = "prhegde/t5-query-reformulation-RL"

tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)

# Load the weights, move them to the selected device, then switch the module
# to evaluation mode for inference.
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
model = model.to(device)
model.eval()





T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dro

In [8]:
def rewrite_query(query, nsent=1, max_length=20):
    """Generate one reformulation of ``query`` with the loaded T5 model.

    Args:
        query: Query text to reformulate.
        nsent: Accepted for backward compatibility but not used; exactly one
            reformulation is returned per call.
        max_length: Value forwarded to ``model.generate`` as ``max_length``.
            Defaults to 20, the value that was previously hard-coded.

    Returns:
        The decoded generation with special tokens skipped.

    Note:
        Generation uses ``do_sample=True``, so repeated calls can return
        different text unless the caller fixes the torch RNG state.
    """
    input_ids = tokenizer(query, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        output = model.generate(
            input_ids, max_length=max_length, num_beams=1, do_sample=True
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)

def clean_query(query):
    """Normalise a query to ASCII word characters separated by single spaces.

    Steps, in order:
      1. Drop every non-ASCII character.
      2. Delete every character that is neither a word character nor
         whitespace (this covers ``?``, quotes, backticks, apostrophes, ...).
      3. Collapse each run of whitespace (including newlines) to one space and
         strip both ends.

    Whitespace is normalised last on purpose: deleting punctuation after
    collapsing could leave double or trailing spaces behind
    (e.g. ``"rock & roll"`` became ``"rock  roll"``).

    Args:
        query: Raw query text.

    Returns:
        The cleaned query string; may be empty.
    """
    query = query.encode("ascii", "ignore").decode()
    query = re.sub(r"[^\w\s]", "", query)
    return re.sub(r"\s+", " ", query).strip()
# Cleaned version of the original queries. The explicit copy makes this an
# independent frame, so the column assignment below does not go through a
# selection of queries_antique (which triggers SettingWithCopyWarning).
queries_unchanged = queries_antique[["qid", "query"]].copy()
queries_unchanged["query"] = queries_unchanged["query"].apply(clean_query)

# rewrite_query() samples from the model; fixing the torch RNG state first
# keeps the rewritten queries from changing arbitrarily between runs.
torch.manual_seed(42)

queries_antique["rewritten_query"] = [
    clean_query(rewrite_query(query))
    for query in tqdm(queries_antique["query"], desc="rewriting queries")
]


rewriting queries:   0%|          | 0/2426 [00:00<?, ?it/s]

In [9]:
# BM25 retriever over the loaded index; used by the evaluation that follows.
bm25_antique = pt.terrier.Retriever(index_antique, wmodel="BM25")


In [10]:
# First cleaned original query, followed by its label.
# No experiment is run on queries_unchanged in this cell.
print(queries_unchanged["query"].head(1))
print("Unchanged queries")

# NOTE: this overwrites the "query" column of queries_antique with the
# rewritten text, so re-running the rewriting step afterwards would feed the
# rewrites back into the model.
queries_antique["query"] = queries_antique["rewritten_query"]
print(queries_antique["query"].head(1))
print("Rewritten queries")

# Evaluate BM25 on the rewritten queries against the qrels.
results_rewritten = pt.Experiment(
    [bm25_antique],
    queries_antique[["qid", "query"]],
    qrels_antique,
    eval_metrics=[RR @ 10, nDCG @ 20, MAP],
)

from tabulate import tabulate

# Render the metrics frame as a text table.
print("\nResults for rewritten queries:")
print(
    tabulate(
        results_rewritten,
        headers="keys",
        tablefmt="pretty",
        showindex=False,
    )
)


0    What causes severe swelling and pain in the knees
Name: query, dtype: object
Unchanged queries
0    causes for severe swelling of the knees
Name: query, dtype: object
Rewritten queries

Results for rewritten queries:
+-------------------+---------------------+---------------------+---------------------+
|       name        |        RR@10        |       nDCG@20       |         AP          |
+-------------------+---------------------+---------------------+---------------------+
| TerrierRetr(BM25) | 0.24573764377968843 | 0.12678901410908794 | 0.07606560976303532 |
+-------------------+---------------------+---------------------+---------------------+


In [11]:
# First 20 cleaned originals next to their rewrites. Raw value arrays are used
# so the two columns are paired by position rather than by index label.
comparison_df = pd.DataFrame(
    {
        "Original Query": queries_unchanged["query"].iloc[:20].values,
        "Rewritten Query": queries_antique["rewritten_query"].iloc[:20].values,
    }
)

print(comparison_df.to_string(index=False))

                                                                                     Original Query                                                                 Rewritten Query
                                                  What causes severe swelling and pain in the knees                                         causes for severe swelling of the knees
                                             why dont they put parachutes underneath airplane seats                                          what should parachutes be on airplanes
                                                                  how to clean alloy cylinder heads                                          how to clean a7 ct head cylinder heads
                                                                           how do i get them whiter                   what is the best product to use for a person with white hairs
                                                                     What is Cloud 9 and 7th Heaven 