In [1]:
import ir_datasets

import pyterrier as pt
print(pt.__version__)

pt.init()
from pathlib import Path
import re
import pandas as pd
from pyterrier.measures import RR, nDCG, MAP, MRR, Recall, Precision
from tqdm.notebook import tqdm

0.13.0


Java started and loaded: pyterrier.java, pyterrier.terrier.java [version=5.11 (build: craig.macdonald 2025-01-13 21:29), helper_version=0.0.8]
java is now started automatically with default settings. To force initialisation early, run:
pt.java.init() # optional, forces java initialisation
  pt.init()


In [2]:
msmarco = ir_datasets.load("msmarco-passage/dev/small")
print(msmarco.docs[0])

GenericDoc(doc_id='0', text='The presence of communication amid scientific minds was equally important to the success of the Manhattan Project as scientific intellect was. The only cloud hanging over the impressive achievement of the atomic researchers and engineers is what their success truly meant; hundreds of thousands of innocent lives obliterated.')


In [3]:
# OPTIONAL - Index the data

idx_path = Path.cwd() / "indices" / "msmarco_dev_small"

def msmarco_gen(limit=100000):
    lastdoc = 0
    for elem in msmarco.docs:
        if lastdoc >= limit:
            break
        yield {
            "docno": elem.doc_id,
            "text": elem.text,
        }
        lastdoc += 1

max(len(text.encode("utf-8")) for _, text in msmarco.docs)

if not idx_path.exists() or not any(idx_path.iterdir()):
    indexer = pt.IterDictIndexer(
        str(idx_path),
        meta={
            "docno": 20,
            "text": 4096,
        },
        stemmer="porter",
        stopwords="terrier",
    )

    index_ref = indexer.index(msmarco_gen())
else:
    print("Indices already exist, skipping creation")

Indices already exist, skipping creation


In [4]:
# Define index paths
index_dir_msmarco = Path.cwd() / "indices" / "msmarco_dev_small"
# Load the indexes
index_msmarco = pt.IndexFactory.of(str(index_dir_msmarco))
#index_antique = pt.IndexFactory.of(str(index_dir_antique))

# Use BM25 as the baseline retriever
retriever_msmarco = pt.BatchRetrieve(index_msmarco, wmodel="BM25")
#retriever_antique = pt.BatchRetrieve(index_antique, wmodel="BM25")

12:26:20.050 [main] WARN org.terrier.structures.BaseCompressingMetaIndex -- Structure meta reading data file directly from disk (SLOW) - try index.meta.data-source=fileinmem in the index properties file. 1,9 GiB of memory would be required.


  retriever_msmarco = pt.BatchRetrieve(index_msmarco, wmodel="BM25")


In [5]:
# Convert qrels to DataFrames
qrels_msmarco = pd.DataFrame(msmarco.qrels_iter())
#qrels_antique = pd.DataFrame(antique.qrels_iter())

# Convert queries to DataFrames
queries_msmarco = pd.DataFrame(msmarco.queries_iter())
#queries_antique = pd.DataFrame(antique.queries_iter())

# Rename columns for PyTerrier compatibility
qrels_msmarco.rename(columns={"query_id": "qid", "doc_id": "docno", "relevance": "label"}, inplace=True)
#qrels_antique.rename(columns={"query_id": "qid", "doc_id": "docno", "relevance": "label"}, inplace=True)

queries_msmarco.rename(columns={"query_id": "qid", "text": "query"}, inplace=True)
#queries_antique.rename(columns={"query_id": "qid", "text": "query"}, inplace=True)

In [6]:
import torch
import re

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
print("CUDA Available:", torch.cuda.is_available())
print("Current Device:", torch.cuda.current_device())
print("Device Name:", torch.cuda.get_device_name(0))
print("Device Count:", torch.cuda.device_count())

Using device: cuda
CUDA Available: True
Current Device: 0
Device Name: NVIDIA GeForce RTX 3060
Device Count: 1


In [7]:
#Model imports
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load the T5 model
MODEL_ID = "prhegde/t5-query-reformulation-RL"
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID).to(device)
model.eval()




T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dro

In [8]:
# Function to rewrite queries
def rewrite_query(query, nsent=1):
    input_ids = tokenizer(query, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        output = model.generate(input_ids, max_length=20, num_beams=5, do_sample=False)
    return tokenizer.decode(output[0], skip_special_tokens=True)

def clean_query(query):
    query = query.replace("\n", " ").replace("\r", " ").replace("?", "")
    query = query.encode("ascii", "ignore").decode()
    query = re.sub(r"[\"`]", "", query)
    query = re.sub(r"\s+", " ", query).strip()
    query = re.sub(r"[^\w\s]", "", query)
    return query

queries_unchanged = queries_msmarco[["qid", "query"]]#.copy(deep=True)

queries_unchanged["query"] = queries_unchanged["query"].apply(clean_query)
# queries_antique["rewritten_query"] = queries_antique["query"].apply(rewrite_query).apply(clean_query)

queries_msmarco["rewritten_query"] = [
    clean_query(rewrite_query(query))
    for query in tqdm(queries_msmarco["query"], desc="rewriting queries")
]


rewriting queries:   0%|          | 0/6980 [00:00<?, ?it/s]

In [9]:
bm25_msmarco = pt.terrier.Retriever(index_msmarco, wmodel="BM25")


In [10]:
print(queries_unchanged['query'].head(1))


print("Unchanged queries")


# results_unchanged = pt.Experiment(
#     [bm25_msmarco],
#     queries_unchanged[["qid", "query"]],  # Use rewritten queries
#     qrels_msmarco,
#     eval_metrics=[RR @ 10, nDCG @ 20, MAP],
# )
queries_msmarco['query'] = queries_msmarco['rewritten_query']
print(queries_msmarco['query'].head(1))

print("Rewritten queries")

results_rewritten = pt.Experiment(
    [bm25_msmarco],
    queries_msmarco[["qid", "query"]],  # Use rewritten queries
    qrels_msmarco,
    eval_metrics=[RR @ 10, nDCG @ 20, MAP],
)


from tabulate import tabulate

# Optionally, print the results of both experiments using tabulate
# print("Results for unchanged queries:")
# print(tabulate(results_unchanged, headers='keys', tablefmt='pretty', showindex=False))

print("\nResults for rewritten queries:")
print(tabulate(results_rewritten, headers='keys', tablefmt='pretty', showindex=False))

0    what is paula deens brother
Name: query, dtype: object
Unchanged queries
0    who is paula deens brother
Name: query, dtype: object
Rewritten queries

Results for rewritten queries:
+-------------------+---------------------+---------------------+---------------------+
|       name        |        RR@10        |       nDCG@20       |         AP          |
+-------------------+---------------------+---------------------+---------------------+
| TerrierRetr(BM25) | 0.15906268476827234 | 0.21941736655219785 | 0.16683199072530513 |
+-------------------+---------------------+---------------------+---------------------+


In [12]:
# Get first 20 queries
queries_unchanged_20 = queries_unchanged.head(20).copy()
queries_rewritten_20 = queries_msmarco.head(20).copy()

# Run experiments
results_unchanged = pt.Experiment(
    [bm25_msmarco],
    queries_unchanged_20,
    qrels_msmarco,
    eval_metrics=[RR @ 10, nDCG @ 20, MAP],
    perquery=True,

)

results_rewritten = pt.Experiment(
    [bm25_msmarco],
    queries_rewritten_20,
    qrels_msmarco,
    eval_metrics=[RR @ 10, nDCG @ 20, MAP],
    perquery=True,
)

In [15]:
i = 0
for qid in results_unchanged['qid'].unique():
    print("Query number " + str(i))
    i+=1
    # Filter the results for the current qid
    unchanged_row = results_unchanged[results_unchanged['qid'] == qid]
    rewritten_row = results_rewritten[results_rewritten['qid'] == qid]
    queries_unchanged_row = queries_unchanged_20[queries_unchanged_20['qid'] == qid]
    queries_rewritten_row = queries_rewritten_20[queries_rewritten_20['qid'] == qid]



    # Print results for unchanged
    print("=== Unchanged Results for qid:", qid, "===")
    if not unchanged_row.empty:
        print("Query:", queries_unchanged_row['query'].values)  # Print the matching query

        print(unchanged_row)
    else:
        print("No results found.")

    # Print results for rewritten
    print("=== Rewritten Results for qid:", qid, "===")
    if not rewritten_row.empty:
        print("Query:", queries_rewritten_row['query'].values)  # Print the matching query

        print(rewritten_row)
    else:
        print("No results found.")

    print("\n")  # Add a newline for better separation

    if(i > 10):
        break

Query number 0
=== Unchanged Results for qid: 1048585 ===
Query: ['what is paula deens brother']
                name      qid  measure  value
0  TerrierRetr(BM25)  1048585       AP    1.0
1  TerrierRetr(BM25)  1048585  nDCG@20    1.0
2  TerrierRetr(BM25)  1048585    RR@10    1.0
=== Rewritten Results for qid: 1048585 ===
Query: ['who is paula deens brother']
                name      qid  measure    value
0  TerrierRetr(BM25)  1048585       AP  0.50000
1  TerrierRetr(BM25)  1048585  nDCG@20  0.63093
2  TerrierRetr(BM25)  1048585    RR@10  0.50000


Query number 1
=== Unchanged Results for qid: 1048642 ===
Query: ['what is paranoid sc']
                 name      qid  measure     value
9   TerrierRetr(BM25)  1048642       AP  0.166667
10  TerrierRetr(BM25)  1048642  nDCG@20  0.356207
11  TerrierRetr(BM25)  1048642    RR@10  0.250000
=== Rewritten Results for qid: 1048642 ===
Query: ['paranoid schizophrenia definition']
                 name      qid  measure    value
9   TerrierRetr(BM