In [1]:
import ir_datasets
import tqdm
import pyterrier as pt
print(pt.__version__)

pt.init()
from pathlib import Path
import re
import pandas as pd
from pyterrier.measures import RR, nDCG, MAP

0.13.0


Java started and loaded: pyterrier.java, pyterrier.terrier.java [version=5.11 (build: craig.macdonald 2025-01-13 21:29), helper_version=0.0.8]
java is now started automatically with default settings. To force initialisation early, run:
pt.java.init() # optional, forces java initialisation
  pt.init()


In [2]:
msmarco = ir_datasets.load("msmarco-passage/train/judged")

In [3]:
# OPTIONAL - Index the data

idx_path = Path.cwd() / "indices" / "msmarco_train_judged"

def msmarco_gen(limit=100000):
    lastdoc = 0
    for elem in msmarco.docs:
        if lastdoc >= limit:
            break
        yield {
            "docno": elem.doc_id,
            "text": elem.text,
        }
        lastdoc += 1

max(len(text.encode("utf-8")) for _, text in msmarco.docs)

if not idx_path.exists() or not any(idx_path.iterdir()):
    indexer = pt.IterDictIndexer(
        str(idx_path),
        meta={
            "docno": 20,
            "text": 4096,
        },
        stemmer="porter",
        stopwords="terrier",
    )

    index_ref = indexer.index(msmarco_gen())
else:
    print("Indices already exist, skipping creation")

In [4]:
# Define index paths
index_dir_msmarco = Path.cwd() / "indices" / "msmarco_train_judged"
# Load the indexes
index_msmarco = pt.IndexFactory.of(str(index_dir_msmarco))
#index_antique = pt.IndexFactory.of(str(index_dir_antique))

# Use BM25 as the baseline retriever
retriever_msmarco = pt.BatchRetrieve(index_msmarco, wmodel="BM25")
#retriever_antique = pt.BatchRetrieve(index_antique, wmodel="BM25")

  retriever_msmarco = pt.BatchRetrieve(index_msmarco, wmodel="BM25")


In [5]:
# Convert qrels to DataFrames
qrels_msmarco = pd.DataFrame(msmarco.qrels_iter()).head(1000)
#qrels_antique = pd.DataFrame(antique.qrels_iter())

# Convert queries to DataFrames
queries_msmarco = pd.DataFrame(msmarco.queries_iter()).head(1000)
#queries_antique = pd.DataFrame(antique.queries_iter())

# Rename columns for PyTerrier compatibility
qrels_msmarco.rename(columns={"query_id": "qid", "doc_id": "docno", "relevance": "label"}, inplace=True)
#qrels_antique.rename(columns={"query_id": "qid", "doc_id": "docno", "relevance": "label"}, inplace=True)

queries_msmarco.rename(columns={"query_id": "qid", "text": "query"}, inplace=True)
#queries_antique.rename(columns={"query_id": "qid", "text": "query"}, inplace=True)

In [6]:
def clean_query(query):
    query = query.encode("ascii", "ignore").decode()
    query = query.replace("'", "").replace('"', "").replace("`", "")
    query = re.sub(r"\s+", " ", query).strip()
    return query

queries_msmarco["query"] = queries_msmarco["query"].apply(clean_query)
#queries_antique["query"] = queries_antique["query"].apply(clean_query)

In [7]:
#Model imports
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load the T5 model
MODEL_ID = "prhegde/t5-query-reformulation-RL"
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
model.eval()


  from .autonotebook import tqdm as notebook_tqdm





T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dro

In [8]:
# Function to rewrite queries
def rewrite_query(query, nsent=1):
    input_ids = tokenizer(query, return_tensors="pt").input_ids
    with torch.no_grad():
        output = model.generate(input_ids, max_length=35, num_beams=1, do_sample=True, repetition_penalty=1.8)
    return tokenizer.decode(output[0], skip_special_tokens=True)

def clean_query(query):
    query = query.replace("\n", " ").replace("\r", " ").replace("?", "")
    query = query.encode("ascii", "ignore").decode()
    query = re.sub(r"[\"`]", "", query)
    query = re.sub(r"\s+", " ", query).strip()
    query = re.sub(r"[^\w\s]", "", query)
    return query

queries_msmarco["rewritten_query"] = queries_msmarco["query"].apply(clean_query).apply(rewrite_query).apply(clean_query)


In [9]:
bm25_msmarco = pt.terrier.Retriever(index_msmarco, wmodel="BM25")


In [None]:
queries_msmarco['query'] = queries_msmarco['rewritten_query']

pt.Experiment(
    [bm25_msmarco],
    queries_msmarco,  # Use rewritten queries
    qrels_msmarco,
    eval_metrics=[RR @ 10, nDCG @ 20, MAP],
)