In [1]:
# Standard library
from pathlib import Path
import re

# Third-party
import ir_datasets
import pandas as pd
import pyterrier as pt
import tqdm
from pyterrier.measures import RR, nDCG, MAP

print(pt.__version__)

# pt.init() is deprecated in this PyTerrier release: Java is started
# automatically on first use, and the library's own notice recommends
# pt.java.init() when early initialisation is wanted. Starting the JVM here
# keeps the (slow) startup out of the first retrieval/indexing cell.
pt.java.init()

0.13.0


Java started and loaded: pyterrier.java, pyterrier.terrier.java [version=5.11 (build: craig.macdonald 2025-01-13 21:29), helper_version=0.0.8]
java is now started automatically with default settings. To force initialisation early, run:
pt.java.init() # optional, forces java initialisation
  pt.init()


In [None]:
# Load the "msmarco-passage/eval/small" dataset handle from ir_datasets; the
# rest of the notebook reads documents, queries and qrels from this object.
msmarco = ir_datasets.load("msmarco-passage/eval/small")
# Print the first document as a sanity check. As the captured log shows, this
# first document access builds the docstore and starts downloading the
# collection archive (~1 GB) when no local copy is available.
print(msmarco.docs[0])

[INFO] [starting] building docstore
[INFO] Please confirm you agree to the MSMARCO data usage agreement found at <http://www.msmarco.org/dataset.aspx>
[INFO] If you have a local copy of https://msmarco.z22.web.core.windows.net/msmarcoranking/collectionandqueries.tar.gz, you can symlink it here to avoid downloading it again: C:\Users\waded\.ir_datasets\downloads\31644046b18952c1386cd4564ba2ae69
[INFO] [starting] https://msmarco.z22.web.core.windows.net/msmarcoranking/collectionandqueries.tar.gz
docs_iter:   0%|                                   | 0/8841823 [00:00<?, ?doc/s]
https://msmarco.z22.web.core.windows.net/msmarcoranking/collectionandqueries.tar.gz: 0.0%| 0.00/1.06G [00:00<?, ?B/s][A
https://msmarco.z22.web.core.windows.net/msmarcoranking/collectionandqueries.tar.gz: 0.0%| 16.4k/1.06G [00:00<2:35:23, 113kB/s][A
https://msmarco.z22.web.core.windows.net/msmarcoranking/collectionandqueries.tar.gz: 0.0%| 49.2k/1.06G [00:00<1:47:58, 163kB/s][A
https://msmarco.z22.web.core.windows.

In [None]:
# OPTIONAL - Index the data

# Directory that holds the Terrier index, resolved against the current
# working directory: <cwd>/indices/msmarco_test
idx_path = Path.cwd().joinpath("indices", "msmarco_test")

def msmarco_gen(limit=100000, docs=None):
    """Yield indexer-ready dicts for the first ``limit`` documents.

    Args:
        limit: Maximum number of documents to yield. ``None`` means no cap;
            zero or a negative value yields nothing.
        docs: Optional iterable of objects exposing ``doc_id`` and ``text``
            attributes. Defaults to ``msmarco.docs`` (looked up lazily, when
            the generator is first advanced).

    Yields:
        ``{"docno": <doc_id>, "text": <text>}`` for each document, in source
        order.
    """
    from itertools import islice

    source = msmarco.docs if docs is None else docs
    # islice stops after exactly `limit` items, so no extra document is pulled
    # from the source just to discover that the limit was reached.
    stop = None if limit is None else max(limit, 0)
    for doc in islice(source, stop):
        yield {
            "docno": doc.doc_id,
            "text": doc.text,
        }

# Build the Terrier index only when the target directory is missing or empty.
#
# A bare `max(len(text.encode("utf-8")) for _, text in msmarco.docs)` used to
# run here on every execution. It iterated the entire collection even when the
# index already existed, and its value was discarded (it was not the cell's
# last expression), so it has been removed. To re-derive a suitable "text"
# meta length, evaluate that expression once in a scratch cell.
if not idx_path.exists() or not any(idx_path.iterdir()):
    indexer = pt.IterDictIndexer(
        str(idx_path),
        # Per-field length limits for the stored document metadata.
        meta={
            "docno": 20,
            "text": 4096,
        },
        stemmer="porter",
        stopwords="terrier",
    )

    # msmarco_gen() caps the number of indexed passages (100000 by default).
    index_ref = indexer.index(msmarco_gen())
else:
    print("Indices already exist, skipping creation")

In [None]:
# Define index paths
index_dir_msmarco = Path.cwd() / "indices" / "msmarco_test"
# Load the indexes
index_msmarco = pt.IndexFactory.of(str(index_dir_msmarco))
#index_antique = pt.IndexFactory.of(str(index_dir_antique))

# Use BM25 as the baseline retriever.
# pt.terrier.Retriever is the current name of the Terrier retriever (the
# evaluation cell further down already uses it); pt.BatchRetrieve is the
# deprecated spelling and emits a deprecation warning on recent PyTerrier.
retriever_msmarco = pt.terrier.Retriever(index_msmarco, wmodel="BM25")
#retriever_antique = pt.terrier.Retriever(index_antique, wmodel="BM25")

In [None]:
# Relevance judgements as a DataFrame, with the ir_datasets field names
# mapped to the column names PyTerrier works with.
qrels_msmarco = pd.DataFrame(msmarco.qrels_iter()).rename(
    columns={"query_id": "qid", "doc_id": "docno", "relevance": "label"}
)
#qrels_antique = pd.DataFrame(antique.qrels_iter()).rename(
#    columns={"query_id": "qid", "doc_id": "docno", "relevance": "label"}
#)

# Queries (topics) as a DataFrame, renamed the same way.
queries_msmarco = pd.DataFrame(msmarco.queries_iter()).rename(
    columns={"query_id": "qid", "text": "query"}
)
#queries_antique = pd.DataFrame(antique.queries_iter()).rename(
#    columns={"query_id": "qid", "text": "query"}
#)

In [None]:
def clean_query(query):
    """Return ``query`` as plain ASCII with quote characters removed.

    Non-ASCII characters are dropped; single quotes, double quotes and
    backticks are deleted; every run of whitespace becomes a single space and
    leading/trailing whitespace is trimmed.
    """
    ascii_text = query.encode("ascii", "ignore").decode()
    for quote_char in ("'", '"', "`"):
        ascii_text = ascii_text.replace(quote_char, "")
    collapsed = re.sub(r"\s+", " ", ascii_text)
    return collapsed.strip()

# Normalise every query string with clean_query, writing the result back into
# the same "query" column (the uncleaned text is not kept).
queries_msmarco["query"] = queries_msmarco["query"].apply(clean_query)
#queries_antique["query"] = queries_antique["query"].apply(clean_query)

In [None]:
#Model imports
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load the T5 model
# Identifier of the pretrained query-reformulation checkpoint; the tokenizer
# and the model weights are both loaded from this same id.
MODEL_ID = "prhegde/t5-query-reformulation-RL"
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
# Put the model in evaluation mode; this notebook only generates with it.
model.eval()


In [None]:
# Function to rewrite queries
def rewrite_query(query, nsent=1):
    """Generate one reformulation of ``query`` with the loaded T5 model.

    Args:
        query: Query text to reformulate.
        nsent: Accepted for compatibility but not used; exactly one
            reformulation is produced per call.

    Returns:
        The decoded text of the first generated sequence, with special tokens
        stripped.

    Note:
        Generation runs with ``do_sample=True``, so repeated calls on the
        same query may return different text.
    """
    encoded = tokenizer(query, return_tensors="pt")
    generation_options = dict(
        max_length=35,
        num_beams=1,
        do_sample=True,
        repetition_penalty=1.8,
    )
    # Inference only: skip gradient tracking while generating.
    with torch.no_grad():
        generated = model.generate(encoded.input_ids, **generation_options)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

def clean_query(query):
    """Reduce ``query`` to ASCII word characters separated by single spaces.

    Steps:
      1. Drop every non-ASCII character.
      2. Delete every character that is neither a word character
         (letters, digits, underscore) nor whitespace; this covers ``?``,
         quotes, backticks and all other punctuation.
      3. Collapse each whitespace run (including newlines and carriage
         returns) to one space and trim both ends.

    Whitespace is normalised last on purpose: stripping punctuation can leave
    doubled or trailing spaces behind (``"a - b"`` -> ``"a  b"``), so
    collapsing before that step would let them leak into the result.

    Args:
        query: Raw query text.

    Returns:
        The cleaned query; may be empty if nothing survives the filtering.
    """
    query = query.encode("ascii", "ignore").decode()
    query = re.sub(r"[^\w\s]", "", query)
    return re.sub(r"\s+", " ", query).strip()

# rewrite_query samples tokens (do_sample=True), so without a fixed seed each
# run produces different rewrites and therefore different evaluation numbers.
# Seed PyTorch's RNG immediately before generating to make the run repeatable.
REWRITE_SEED = 42
torch.manual_seed(REWRITE_SEED)

queries_msmarco["rewritten_query"] = (
    queries_msmarco["query"]
    .apply(clean_query)    # normalise the text fed to the model
    .apply(rewrite_query)  # generate the reformulation
    .apply(clean_query)    # normalise the generated text for retrieval
)


In [None]:
# BM25 retriever over the loaded MS MARCO index; this is the system passed to
# pt.Experiment in the evaluation cell.
bm25_msmarco = pt.terrier.Retriever(index_msmarco, wmodel="BM25")


In [None]:
# Keep the pre-rewrite query text before the "query" column is overwritten.
# The guard makes this safe to re-run: on later executions "query" already
# holds the rewrites and must not replace the saved originals.
if "original_query" not in queries_msmarco.columns:
    queries_msmarco["original_query"] = queries_msmarco["query"]

# The topics frame is evaluated through its "query" column, so point that
# column at the rewritten text.
queries_msmarco["query"] = queries_msmarco["rewritten_query"]

pt.Experiment(
    [bm25_msmarco],
    queries_msmarco,  # Use rewritten queries
    qrels_msmarco,
    eval_metrics=[RR @ 10, nDCG @ 20, MAP],
)