# Information Retrieval Lab WiSe 2024/2025: Baseline Retrieval System

This Jupyter notebook serves as a baseline retrieval system that you can improve upon.
We use subsets of the MS MARCO datasets to retrieve passages of web documents.
We will show you how to create a software submission to TIRA from this notebook.

An overview of all corpora that we use in the current course is available at [https://tira.io/datasets?query=ir-lab-wise-2024](https://tira.io/datasets?query=ir-lab-wise-2024). The dataset IDs for loading the datasets are:

- `ir-lab-wise-2024/subsampled-ms-marco-deep-learning-20241201-training`: A subsample of the TREC 2019/2020 Deep Learning tracks on the MS MARCO v1 passage dataset. Use this dataset to tune your system(s).
- `ir-lab-wise-2024/subsampled-ms-marco-rag-20241202-training` (_work in progress_): A subsample of the TREC 2024 Retrieval-Augmented Generation track on the MS MARCO v2.1 passage dataset. Use this dataset to tune your system(s).
- `ir-lab-wise-2024/ms-marco-rag-20241203-test` (_work in progress_): The test corpus that we have created together in the course, based on the MS MARCO v2.1 passage dataset. We will use this dataset as the test dataset, i.e., evaluation scores become available only after the submission deadline.

### Step 1: Import libraries

We will use [tira](https://tira.io/), an information retrieval shared task platform, and [ir_datasets](https://ir-datasets.com/) for loading the datasets. Subsequently, we will build a retrieval system with [PyTerrier](https://github.com/terrier-org/pyterrier), an open-source search engine framework.

First, we need to install the required libraries.

In [None]:
!pip3 install "tira>=0.0.139" ir-datasets "python-terrier==0.10.0"

Create an API client to interact with the TIRA platform (e.g., to load datasets and submit runs).

In [None]:
from tira.third_party_integrations import ensure_pyterrier_is_loaded
from tira.rest_api_client import Client

ensure_pyterrier_is_loaded()
tira = Client()

### Step 2: Load the dataset

We load the dataset by its ir_datasets ID (as listed in the Readme). Just be sure to add the `irds:` prefix before the dataset ID to tell PyTerrier to load the data from ir_datasets.

In [3]:
import pyterrier as pt

pt_dataset = pt.get_dataset('irds:ir-lab-wise-2024/subsampled-ms-marco-deep-learning-20241201-training')

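Before indexing, we clean the passage texts. `data_cleaning` is a local module (not included in this notebook) whose `clean_document` function is applied to every passage while the corpus is streamed into the indexer.

In [None]:
from collections.abc import Iterator
from importlib import reload

import data_cleaning

# Reload the local module so that edits to data_cleaning.py take effect without restarting the kernel.
reload(data_cleaning)


class DataCleaningIter(Iterator):
    """Wraps a corpus iterator and cleans the 'text' field of every document on the fly."""

    def __init__(self, dataset_iter) -> None:
        self.dataset_iter = iter(dataset_iter)

    def __iter__(self):
        return self

    def __next__(self):
        # A StopIteration from the wrapped iterator also ends this iterator.
        item = next(self.dataset_iter)
        item["text"] = data_cleaning.clean_document(item["text"])
        return item


# Like any iterator, this can be consumed only once; recreate it before re-indexing.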
data_cleaning_iter = DataCleaningIter(pt_dataset.get_corpus_iter(verbose=True))
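Because `data_cleaning.clean_document` is not shown here, the wrapping idea can be illustrated with a generator function instead of a class. This is a minimal sketch, assuming a hypothetical `clean_document` that only collapses whitespace and lowercases; the real module may clean the text differently.

```python
import re
from typing import Callable, Dict, Iterable, Iterator


def clean_document(text: str) -> str:
    """Hypothetical cleaner (not the course's data_cleaning module): collapse whitespace, lowercase."""
    return re.sub(r"\s+", " ", text).strip().lower()


def clean_corpus(
    corpus_iter: Iterable[Dict[str, str]],
    clean: Callable[[str], str] = clean_document,
) -> Iterator[Dict[str, str]]:
    """Lazily yield each corpus record with its 'text' field cleaned.

    Records are copied, so the caller's dicts are left unchanged.
    """
    for record in corpus_iter:
        cleaned = dict(record)
        cleaned["text"] = clean(record["text"])
        yield cleaned


# Usage with a tiny in-memory corpus:
docs = [{"docno": "d1", "text": "  Hello\n  World "}, {"docno": "d2", "text": "BM25  rocks"}]
cleaned_docs = list(clean_corpus(docs))
print(cleaned_docs)
```

A generator is as lazy as the class above and, like it, can be consumed only once. In the notebook, `clean_corpus(pt_dataset.get_corpus_iter(), clean=data_cleaning.clean_document)` would be a drop-in replacement; unlike the class, it copies each record instead of mutating it in place.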


### Step 3: Build an index

We will then create an index from the documents in the dataset we just loaded.
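Conceptually, the index maps each term to the documents that contain it (an inverted index). The toy sketch below assumes plain whitespace tokenization with lowercasing; Terrier's indexer uses its own tokenizer, applies stopword removal and stemming by default, and stores many more statistics, so this only illustrates the data structure.

```python
from collections import Counter, defaultdict
from typing import Dict


def build_inverted_index(docs: Dict[str, str]) -> Dict[str, Dict[str, int]]:
    """Map each term to {docno: term frequency}, using lowercase whitespace tokens."""
    postings: Dict[str, Dict[str, int]] = defaultdict(dict)
    for docno, text in docs.items():
        for term, tf in Counter(text.lower().split()).items():
            postings[term][docno] = tf
    return dict(postings)


toy_docs = {
    "d1": "BM25 is a ranking function",
    "d2": "a ranking of passages",
}
toy_index = build_inverted_index(toy_docs)
print(toy_index["ranking"])   # {'d1': 1, 'd2': 1}
print(toy_index["passages"])  # {'d2': 1}
```

At query time, only the posting lists of the query terms have to be read, which is what makes retrieval over large corpora fast.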

In [None]:
import os

indexer = pt.IterDictIndexer(
    index_path=os.getcwd() + os.sep + "index",
    meta={'docno': 50, 'text': 4096},
    # If an index already exists there, then overwrite it.
    overwrite=True,
)

index = indexer.index(data_cleaning_iter)

### Step 4: Define the retrieval pipeline

We will define a simple retrieval pipeline using just BM25 as a baseline. For details, refer to the PyTerrier [documentation](https://pyterrier.readthedocs.io) or [tutorial](https://github.com/terrier-org/ecir2021tutorial).

In [6]:
bm25 = pt.BatchRetrieve(index, wmodel="BM25")
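For intuition about what `wmodel="BM25"` computes, here is a sketch of textbook Okapi BM25, assuming k1 = 1.2, b = 0.75, whitespace tokenization, and the non-negative IDF variant log(1 + (N - df + 0.5) / (df + 0.5)). Terrier's `BM25` weighting model uses its own IDF formulation and query-term weighting, so its absolute scores will differ.

```python
import math
from collections import Counter
from typing import Dict, List, Tuple


def bm25_rank(query: str, docs: Dict[str, str], k1: float = 1.2, b: float = 0.75) -> List[Tuple[str, float]]:
    """Score every document for the query with Okapi BM25 and sort by score, descending."""
    tokenized = {docno: text.lower().split() for docno, text in docs.items()}
    n_docs = len(tokenized)
    avgdl = sum(len(toks) for toks in tokenized.values()) / n_docs
    # Document frequency: the number of documents that contain each term.
    df = Counter(term for toks in tokenized.values() for term in set(toks))

    scores = {}
    for docno, toks in tokenized.items():
        tf = Counter(toks)
        dl = len(toks)
        score = 0.0
        for term in query.lower().split():
            if tf[term] == 0:
                continue
            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
            # Term-frequency saturation (k1) and document-length normalization (b).
            norm = tf[term] + k1 * (1 - b + b * dl / avgdl)
            score += idf * tf[term] * (k1 + 1) / norm
        scores[docno] = score
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


toy_docs = {
    "d1": "bm25 is a ranking function used by search engines",
    "d2": "passage ranking with bm25 and query expansion",
    "d3": "a recipe for pancakes",
}
ranked = bm25_rank("bm25 query expansion", toy_docs)
for docno, score in ranked:
    print(docno, round(score, 3))
```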

### Step 5: Create the run
In the next steps, we apply our retrieval system to the topics to produce a 'run', i.e., a ranked list of retrieved documents for every topic.

First, let's have a short look at the first three topics:

In [None]:
# The `'text'` argument below selects the topics' `text` field as the query.
pt_dataset.get_topics('text').head(3)

In [None]:
import string

import torch

# Check whether a CUDA GPU is available; generating query expansions on the CPU is slower.
torch.cuda.is_available()

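### Configure the query expansion model

We expand each query in a query2doc-like fashion: the instruction-tuned `google/flan-t5-base` model generates a short text for the query, and the original query, repeated `repeat_original_query_x_times` times, is concatenated with that text before BM25 retrieval.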

In [115]:
# Instruction-tuned sequence-to-sequence model that generates the expansion text.
model_name = "google/flan-t5-base"
prompt_template = string.Template('''
Rewrite the following query by expanding on the topic, providing additional context or details. 
If applicable, offer a brief response or overview to help clarify the answer.
Query: $query
Answer: ''')
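# How often the original query is repeated in front of the generated text,
# so that its terms keep more weight than the terms of the longer expansion.
repeat_original_query_x_times: int = 4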

In [116]:
import re
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


class MyQueryTransformer(pt.Transformer):
    """Query2doc-style query expansion: append model-generated text to the repeated original query."""

    def generate_text(self, query: str) -> str:
        prompt = prompt_template.substitute(query=query)
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        generated_ids = model.generate(
            input_ids,
            max_length=1000,
            num_beams=5,
            no_repeat_ngram_size=2,
            early_stopping=True,
        )
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        # Keep only letters, digits, spaces and underscores, since punctuation in the
        # generated text can otherwise break Terrier's query parser.
        generated_text = re.sub(r'[^a-zA-Z0-9 _]', '', generated_text)

        # Repeat the original query to up-weight its terms relative to the expansion.
        return repeat_original_query_x_times * (query + " ") + generated_text

    def transform(self, df):
        # Work on a copy so that the caller's topics DataFrame is not modified in place.
        df = df.copy()
        df["query"] = df["query"].apply(self.generate_text)
        return df
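The string handling in `generate_text` can be checked without loading the model. The sketch below reproduces the character filter and the query repetition, with the model output replaced by a fixed stub string and the prompt replaced by a hypothetical shortened template; both are illustrative assumptions, not the configuration above.

```python
import re
import string


def build_expanded_query(query: str, generated_text: str, repeat: int = 4) -> str:
    """Combine the original query and a generated text (query2doc style).

    The original query is repeated `repeat` times so that its terms keep a higher
    weight than the terms of the longer generated text.
    """
    # Drop every character except ASCII letters, digits, spaces and underscores,
    # mirroring the filter applied to the model output in MyQueryTransformer.
    cleaned = re.sub(r"[^a-zA-Z0-9 _]", "", generated_text)
    return repeat * (query + " ") + cleaned


# Hypothetical short template and stubbed model output, for illustration only.
template = string.Template("Query: $query\nAnswer: ")
print(template.substitute(query="who is aziz hashim"))

expanded = build_expanded_query("who is aziz hashim", "A pseudo-document, generated by the model!", repeat=2)
print(expanded)
```

Repeating the query is a simple form of term weighting: a repeated query term is counted multiple times when BM25 sums the per-term contributions, which keeps the original intent dominant even when the expansion drifts.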

In [117]:
pipeline = MyQueryTransformer() >> bm25

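# Quick check on a single query (transform expects a DataFrame, not a dict):
# import pandas as pd
# res = MyQueryTransformer().transform(pd.DataFrame([{"qid": "1", "query": "who is aziz hashim"}]))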
# print(res)
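The `>>` operator chains transformers so that the output of one stage becomes the input of the next. A toy illustration of that idea, using plain lists of dicts instead of DataFrames and the hypothetical classes `Stage`, `Compose`, `Lowercase` and `AppendTerm` (this is not PyTerrier's actual implementation):

```python
from typing import Dict, List

Rows = List[Dict[str, str]]


class Stage:
    """Minimal stand-in for a pipeline stage that maps rows to rows."""

    def transform(self, rows: Rows) -> Rows:
        raise NotImplementedError

    def __call__(self, rows: Rows) -> Rows:
        return self.transform(rows)

    def __rshift__(self, other: "Stage") -> "Stage":
        # `a >> b` builds a new stage that runs a first, then b on a's output.
        return Compose(self, other)


class Compose(Stage):
    def __init__(self, first: Stage, second: Stage) -> None:
        self.first = first
        self.second = second

    def transform(self, rows: Rows) -> Rows:
        return self.second.transform(self.first.transform(rows))


class Lowercase(Stage):
    def transform(self, rows: Rows) -> Rows:
        return [{**row, "query": row["query"].lower()} for row in rows]


class AppendTerm(Stage):
    def __init__(self, term: str) -> None:
        self.term = term

    def transform(self, rows: Rows) -> Rows:
        return [{**row, "query": row["query"] + " " + self.term} for row in rows]


toy_pipeline = Lowercase() >> AppendTerm("definition")
result = toy_pipeline([{"qid": "1", "query": "What Is BM25"}])
print(result)  # [{'qid': '1', 'query': 'what is bm25 definition'}]
```

In the same way, `MyQueryTransformer() >> bm25` first rewrites the `query` column and then passes the rewritten topics to BM25.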


Now, retrieve results for all the topics (this may take a while, since the model generates an expansion for every query):

In [106]:
run = pipeline(pt_dataset.get_topics('text'))

That's it for the retrieval. Here are the first 10 entries of the run:

In [None]:
run.head(10)
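Run files are conventionally stored in the six-column TREC format `qid Q0 docno rank score tag`. The stdlib sketch below writes such lines, assuming each row carries `qid`, `docno`, `rank` and `score` like the columns of `run`; the rows are toy values and the tag `my-run` is a placeholder. PyTerrier also provides `pt.io.write_results` for writing result frames, which is the more convenient choice inside the notebook.

```python
from typing import Dict, Iterable, List, Union

Row = Dict[str, Union[str, int, float]]


def to_trec_lines(rows: Iterable[Row], tag: str = "my-run") -> List[str]:
    """Format retrieval results as TREC run lines: qid Q0 docno rank score tag."""
    return [
        f"{row['qid']} Q0 {row['docno']} {row['rank']} {row['score']} {tag}"
        for row in rows
    ]


# Toy rows for illustration; with pandas, run.to_dict("records") yields rows of this shape.
rows = [
    {"qid": "1", "docno": "d2", "rank": 1, "score": 2.38},
    {"qid": "1", "docno": "d1", "rank": 2, "score": 0.41},
]
lines = to_trec_lines(rows)
with open("run.txt", "w", encoding="utf-8") as run_file:
    run_file.write("\n".join(lines) + "\n")
print(lines[0])  # 1 Q0 d2 1 2.38 my-run
```

Evaluation tools such as trec_eval order documents by score rather than by the rank column, so the scores must be consistent with the intended ranking.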

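### Step 6: Evaluate your run

We compare the BM25 baseline with the expanded pipeline using `pt.Experiment`, which runs each system on the topics and evaluates the results against the relevance judgments (qrels). Note that this runs the query expansion again for every topic.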

In [118]:
pt.Experiment(
    [bm25, pipeline],
    pt_dataset.get_topics('text'),
    pt_dataset.get_qrels(),
    eval_metrics=["map", "recip_rank", "ndcg_cut_10", "P_1", "P_5", "P_10"],
    names=["BM25", "BM25 + query2doc Query Expansion"],
)

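|   | name | map | recip_rank | ndcg_cut_10 | P_1 | P_5 | P_10 |
|---|------|-----|------------|-------------|-----|-----|------|
| 0 | BM25 | 0.412718 | 0.786653 | 0.489469 | 0.701031 | 0.62268 | 0.574227 |
| 1 | BM25 + query2doc Query Expansion | 0.420156 | 0.798746 | 0.497448 | 0.71134 | 0.639175 | 0.580412 |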
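`pt.Experiment` reports each metric averaged over all topics. The sketch below gives per-topic definitions of three of them (P@k, reciprocal rank, nDCG@k) for a toy ranking, assuming that any grade >= 1 counts as relevant and using the common nDCG formulation with linear gains and a log2(rank + 1) discount; trec_eval's implementation may differ in details such as relevance thresholds and tie handling.

```python
import math
from typing import Dict, List


def precision_at_k(ranking: List[str], qrels: Dict[str, int], k: int) -> float:
    """Fraction of the top-k documents with relevance grade >= 1."""
    return sum(1 for docno in ranking[:k] if qrels.get(docno, 0) >= 1) / k


def reciprocal_rank(ranking: List[str], qrels: Dict[str, int]) -> float:
    """1 / rank of the first relevant document, or 0 if none is retrieved."""
    for position, docno in enumerate(ranking, start=1):
        if qrels.get(docno, 0) >= 1:
            return 1.0 / position
    return 0.0


def ndcg_at_k(ranking: List[str], qrels: Dict[str, int], k: int) -> float:
    """nDCG@k with linear gains and a log2(rank + 1) discount."""

    def dcg(gains: List[int]) -> float:
        return sum(gain / math.log2(rank + 1) for rank, gain in enumerate(gains, start=1))

    actual = dcg([qrels.get(docno, 0) for docno in ranking[:k]])
    ideal = dcg(sorted(qrels.values(), reverse=True)[:k])
    return actual / ideal if ideal > 0 else 0.0


# Toy example: graded judgments for one query and a system ranking.
toy_qrels = {"d1": 2, "d3": 1}
toy_ranking = ["d2", "d1", "d4", "d3"]
print(precision_at_k(toy_ranking, toy_qrels, 1))  # 0.0
print(reciprocal_rank(toy_ranking, toy_qrels))    # 0.5
print(round(ndcg_at_k(toy_ranking, toy_qrels, 10), 4))
```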
