<a href="https://colab.research.google.com/github/thiagolaitz/IA368-search-engines/blob/main/Project%2008/multistage_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction - Multistage pipelines

Efficient and accurate information retrieval becomes more important as collections grow. Multistage pipelines address this by letting a cheap first stage retrieve a small candidate set from a large collection, so that a more expensive model only has to rerank those candidates. This Colab notebook examines the latency and quality of multistage pipelines for information retrieval. It compares five configurations to understand the trade-off between ranking speed and retrieval accuracy.

In [None]:
!pip install pyserini -q
!pip install faiss-cpu==1.7.2 -q

In [None]:
!pip install transformers accelerate==0.19.0 -q

In [None]:
from typing import List
from math import exp
import json

import torch
from tqdm.auto import tqdm
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BatchEncoding,
    AutoModelForSeq2SeqLM
)

from pyserini.search import get_topics
from pyserini.search.lucene import LuceneSearcher

import accelerate

In [None]:
# Gets the dictionary containing the IDs of the queries and their texts.
topics = get_topics('dl20')

# Gets a LuceneSearcher to execute the BM25 algorithm
searcher = LuceneSearcher.from_prebuilt_index('msmarco-passage')

# Models

The tested pipelines use BM25 as the first stage, followed by a MiniLM or T5 (monoT5) reranker.
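
As a rough illustration of the retrieve-then-rerank flow, the sketch below uses a toy corpus and toy scoring functions. Everything in it (`CORPUS`, `first_stage`, `rerank`, `toy_scorer`) is hypothetical and only stands in for BM25 and the neural rerankers used in this notebook.

```python
from typing import Callable, Dict, List, Tuple

# Toy corpus: hypothetical passages, used only to illustrate the data flow.
CORPUS: Dict[str, str] = {
    "d1": "bm25 is a lexical ranking function",
    "d2": "cross encoders rerank candidate passages",
    "d3": "cooking pasta takes ten minutes",
}

def first_stage(query: str, k: int) -> List[str]:
    """Cheap retrieval: rank passages by the number of query terms they share."""
    terms = set(query.lower().split())
    overlap = {docid: len(terms & set(text.split())) for docid, text in CORPUS.items()}
    ranked = sorted(overlap, key=lambda d: (-overlap[d], d))
    return ranked[:k]

def rerank(query: str, docids: List[str],
           scorer: Callable[[str, str], float]) -> List[Tuple[str, float]]:
    """Expensive stage: score only the k candidates and sort by the new score."""
    scored = [(docid, scorer(query, CORPUS[docid])) for docid in docids]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)

def toy_scorer(query: str, passage: str) -> float:
    """Stand-in for a neural reranker: fraction of passage tokens found in the query."""
    query_terms = set(query.lower().split())
    tokens = passage.split()
    return sum(token in query_terms for token in tokens) / len(tokens)

query = "rerank passages with cross encoders"
candidates = first_stage(query, k=2)
final_ranking = rerank(query, candidates, toy_scorer)
```

The point of the design is that the expensive scorer runs on `k` passages per query rather than on the whole collection, so `k` controls how much the second stage costs.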

## MiniLM

In [None]:
class Minilm():
    def __init__(self, model_path):
        """
        Loads the MiniLM model from the given path
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to(self.device)
        self.model.eval() # Put the model in evaluation mode        

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    def tokenize(self, query: str, doc: str):
        """
        Tokenize a query and document.
        Args:
            query: the query text
            doc: the passage text
        Returns:
            A dict containing the input_ids, token_type_ids and attention_mask
        """
        encoded_input = self.tokenizer(
            query,
            doc,
            add_special_tokens=True,
            max_length=256,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt"
        )

        # Removes extra dimensions
        for key in encoded_input.keys():
            encoded_input[key] = torch.squeeze(encoded_input[key])
        return encoded_input

    def rescore(self, query: str, doc_batch: List[str]):
        """
        Given a query and a batch of documents it returns a list 
        of scores.
        """
        # Tokenize the inputs
        encoded_inputs = [self.tokenize(query, doc) for doc in doc_batch]
        # Add pads to keep all inputs with the same length
        padded_inputs = BatchEncoding(self.tokenizer.pad(encoded_inputs, return_tensors="pt")).to(self.device)

        with torch.no_grad():
            outputs = self.model(**padded_inputs)

        # The cross-encoder emits one logit per pair; return plain Python floats
        return outputs.logits[:, 0].tolist()
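
Conceptually, the padding step in `rescore` brings every tokenized pair to the length of the longest one and marks the padded positions in the attention mask. The sketch below shows that idea on plain lists. `pad_batch` and the token ids are hypothetical, and the real `tokenizer.pad` also handles fields such as `token_type_ids` and returns tensors.

```python
from typing import Dict, List

def pad_batch(sequences: List[List[int]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Right-pad token-id lists to the longest length and build attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    # 1 marks a real token, 0 marks padding that the model should ignore.
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Two hypothetical tokenized query-passage pairs of different lengths.
batch = pad_batch([[101, 7, 8, 102], [101, 9, 102]])
```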

## T5

In [None]:
class MonoT5():
    def __init__(self, model_name_or_path: str = 'castorini/monot5-base-msmarco-10k', fp16: bool = False):
        """
        Loads the T5 model from the given path.
        Args:
            model_name_or_path: path to the model
            fp16: whether the model should be loaded using FP16
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # The training was carried out using two specific tokens for relevant and non-relevant passages
        self.token_false_id = self.tokenizer.get_vocab()['▁false']
        self.token_true_id  = self.tokenizer.get_vocab()['▁true']

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Loads the model with model_args
        model_args = {}
        if fp16:
            model_args["torch_dtype"] = torch.float16

        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, **model_args).to(self.device)

    @torch.no_grad()
    def rescore(self, query: str, batch: List[str]):
        """
        Adapted from PyGaggle's repo.
        Rescores all passages in the batch for the given query.
        Args:
            query: the query for ranking
            batch: list of passages for ranking
        """
        scores = []
        # Creates the inputs to the model
        queries_documents = [f"Query: {query} Document: {text} Relevant:" for text in batch]
        tokenized = self.tokenizer(
            queries_documents,
            padding=True,
            truncation="longest_first",
            return_tensors="pt",
            max_length=512,
        ).to(self.device)
        # `tokenized` is already on the target device
        input_ids = tokenized["input_ids"]
        attention_mask = tokenized["attention_mask"]
        _, batch_scores = self.greedy_decode(model=self.model,
                                             input_ids=input_ids,
                                             length=1,
                                             attention_mask=attention_mask,
                                             return_last_logits=True)
        batch_scores = batch_scores[:, [self.token_false_id, self.token_true_id]]
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        batch_log_probs = batch_scores[:, 1].tolist()
        batch_probs = [exp(log_prob) for log_prob in batch_log_probs]
        scores.extend(batch_probs)
        return scores

    @torch.no_grad()
    def greedy_decode(
        self,
        model,
        input_ids: torch.Tensor,
        length: int,
        attention_mask: torch.Tensor = None,
        return_last_logits: bool = True
    ):
        """
        Adapted from PyGaggle's repo.
        Greedily decodes `length` tokens from T5 and, if requested,
        also returns the logits of the last decoding step.
        """
        decode_ids = torch.full((input_ids.size(0), 1),
                                model.config.decoder_start_token_id,
                                dtype=torch.long).to(input_ids.device)
        encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask)
        next_token_logits = None
        for _ in range(length):
            model_inputs = model.prepare_inputs_for_generation(
                decode_ids,
                encoder_outputs=encoder_outputs,
                past=None,
                attention_mask=attention_mask,
                use_cache=True)
            outputs = model(**model_inputs)  # (batch_size, cur_len, vocab_size)
            next_token_logits = outputs[0][:, -1, :]  # (batch_size, vocab_size)
            decode_ids = torch.cat([decode_ids,
                                    next_token_logits.max(1)[1].unsqueeze(-1)],
                                dim=-1)
        if return_last_logits:
            return decode_ids, next_token_logits
        return decode_ids
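
The score returned by `MonoT5.rescore` is the softmax probability of the "true" token computed over just two logits ("false" and "true"). The sketch below reproduces that arithmetic without a model. `true_probability` and the logit values are hypothetical and serve only as a worked example.

```python
from math import exp, log

def true_probability(logit_false: float, logit_true: float) -> float:
    """Softmax restricted to the two logits, returning P(true)."""
    # Subtract the max for numerical stability before exponentiating.
    m = max(logit_false, logit_true)
    log_norm = m + log(exp(logit_false - m) + exp(logit_true - m))
    # exp(log_softmax) of the "true" entry, as in the rescoring code.
    return exp(logit_true - log_norm)

# Hypothetical (false, true) logit pairs for three passages.
examples = [(0.0, 0.0), (2.0, -1.0), (-3.0, 4.0)]
probs = [true_probability(f, t) for f, t in examples]
```

Because only two logits are involved, the result equals a sigmoid of the difference `logit_true - logit_false`, so equal logits give a score of 0.5.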

# Utility function

In [None]:
def get_run(path: str, model, batch_size: int, top_k: int = 100):
    """
    Creates a TREC run using BM25 + the given model.
    Args:
        path: result path for the run
        model: object with the model (minilm or t5)
        batch_size: batch_size used for inference
        top_k: number of passages returned by the first stage (BM25)
    """
    with open(path, "w") as fout:
        for qid, topic in tqdm(topics.items(), 'Rescoring'):
            # First stage (BM25)
            hits = searcher.search(topic["title"], top_k)
            # Separate in batches
            batches = [hits[i:i+batch_size] for i in range(0, len(hits), batch_size)]
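            # Rerank with the model, one batch at a time
            reranked = []
            for batch in batches:
                batch_content = [json.loads(hit.raw)["contents"] for hit in batch]
                scores = model.rescore(topic["title"], batch_content)
                reranked.extend(zip(batch, scores))
            # Sort by reranker score so the rank column reflects the final order
            reranked.sort(key=lambda pair: float(pair[1]), reverse=True)
            for rank, (hit, score) in enumerate(reranked, start=1):
                fout.write(f"{qid}\tQ0\t{hit.docid}\t{rank}\t{float(score)}\tRun\n")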
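
The cells in this notebook do not record timings explicitly. One possible way to measure reranking latency per query is sketched below. `time_per_query` and `dummy_rescore` are hypothetical helpers. The dummy scorer stands in for `model.rescore` so that the sketch runs without a model.

```python
import time
from statistics import mean
from typing import Callable, List, Sequence

def time_per_query(rescore: Callable[[str, List[str]], Sequence[float]],
                   queries: List[str],
                   passages: List[str],
                   batch_size: int) -> List[float]:
    """Return the wall-clock seconds spent reranking `passages` for each query."""
    latencies = []
    for query in queries:
        start = time.perf_counter()
        # Same batching scheme as the run-generation function.
        for i in range(0, len(passages), batch_size):
            rescore(query, passages[i:i + batch_size])
        latencies.append(time.perf_counter() - start)
    return latencies

def dummy_rescore(query: str, batch: List[str]) -> List[float]:
    """Stand-in for `model.rescore`; it only returns constant scores."""
    return [0.0 for _ in batch]

latencies = time_per_query(dummy_rescore, ["q1", "q2"], ["p"] * 100, batch_size=32)
mean_latency = mean(latencies)
```

When timing a real model on a GPU, it is usually necessary to call `torch.cuda.synchronize()` before reading the clock, because CUDA kernels are launched asynchronously.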

# Configurations

## BM25 + MiniLM-L-12 - FP32


In [None]:
minilm = Minilm("cross-encoder/ms-marco-MiniLM-L-12-v2")

In [None]:
batch_size = 32
top_k = 100

get_run("minilm.tsv", minilm, batch_size, top_k)

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [None]:
!python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 -mmap -l 2 dl20-passage minilm.tsv

Downloading https://search.maven.org/remotecontent?filepath=uk/ac/gla/dcs/terrierteam/jtreceval/0.0.5/jtreceval-0.0.5-jar-with-dependencies.jar to /root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar...
jtreceval-0.0.5-jar-with-dependencies.jar: 1.79MB [00:00, 5.57MB/s]                
Running command: ['java', '-jar', '/root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-c', '-m', 'ndcg_cut.10', '-mmap', '-l', '2', '/root/.cache/pyserini/topics-and-qrels/qrels.dl20-passage.txt', 'minilm.tsv']
Results:
map                   	all	0.3956
ndcg_cut_10           	all	0.6710
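
For readers unfamiliar with the `ndcg_cut_10` metric reported above, the sketch below computes nDCG@10 for a single query. It assumes gains equal to the graded relevance labels and a log2 rank discount, which is my understanding of the `trec_eval` default. The function names and the judgments are hypothetical.

```python
from math import log2
from typing import List

def dcg_at_k(gains: List[int], k: int) -> float:
    """Discounted cumulative gain over the first k positions (rank is 1-based)."""
    return sum(gain / log2(rank + 1) for rank, gain in enumerate(gains[:k], start=1))

def ndcg_at_k(ranked_gains: List[int], all_gains: List[int], k: int = 10) -> float:
    """DCG of the system ranking divided by the DCG of the ideal ordering."""
    ideal = dcg_at_k(sorted(all_gains, reverse=True), k)
    return dcg_at_k(ranked_gains, k) / ideal if ideal > 0 else 0.0

# Hypothetical graded judgments (0-3) in the order a system returned the passages.
system_order = [3, 0, 2, 1]
score = ndcg_at_k(system_order, all_gains=system_order, k=10)
```

The `-l 2` flag in the evaluation command sets the minimum relevance level for binary measures such as MAP, so only passages judged 2 or higher count as relevant there.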


## BM25 + t5-base - FP32

In [None]:
t5_base = MonoT5('castorini/monot5-base-msmarco-10k')

In [None]:
batch_size = 32
top_k = 100

get_run("t5_base.tsv", t5_base, batch_size, top_k)

In [None]:
!python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 -mmap -l 2 dl20-passage t5_base.tsv

Running command: ['java', '-jar', '/root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-c', '-m', 'ndcg_cut.10', '-mmap', '-l', '2', '/root/.cache/pyserini/topics-and-qrels/qrels.dl20-passage.txt', 't5_base.tsv']
Results:
map                   	all	0.3869
ndcg_cut_10           	all	0.6699


## BM25 + t5-base - FP16

In [None]:
t5_base = MonoT5('castorini/monot5-base-msmarco-10k', fp16=True)

In [None]:
batch_size = 32
top_k = 100

get_run("t5_base_fp16.tsv", t5_base, batch_size, top_k)

In [None]:
!python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 -mmap -l 2 dl20-passage t5_base_fp16.tsv

Running command: ['java', '-jar', '/root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-c', '-m', 'ndcg_cut.10', '-mmap', '-l', '2', '/root/.cache/pyserini/topics-and-qrels/qrels.dl20-passage.txt', 't5_base_fp16.tsv']
Results:
map                   	all	0.3870
ndcg_cut_10           	all	0.6698
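
The FP16 run above scores almost the same as the FP32 run (nDCG@10 of 0.6698 versus 0.6699) while storing each weight in 2 bytes instead of 4. The sketch below gives a back-of-the-envelope estimate of weight memory. The parameter counts are assumed nominal model sizes, not values measured in this notebook, and activation memory is ignored.

```python
# Approximate parameter counts (assumed nominal sizes, not measured here).
PARAMS = {"t5-base": 220e6, "t5-large": 770e6, "t5-3b": 3e9}
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2}

def weight_memory_gib(num_params: float, dtype: str) -> float:
    """Memory needed for the weights alone, in GiB (activations are ignored)."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30

estimates = {
    (name, dtype): weight_memory_gib(n, dtype)
    for name, n in PARAMS.items()
    for dtype in BYTES_PER_PARAM
}

for (name, dtype), gib in sorted(estimates.items()):
    print(f"{name:<9} {dtype}: {gib:.2f} GiB")
```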


## BM25 + t5-large - FP16

In [None]:
t5_large = MonoT5('castorini/monot5-large-msmarco-10k', fp16=True)

In [None]:
batch_size = 32
top_k = 100

get_run("t5_large_fp16.tsv", t5_large, batch_size, top_k)

In [None]:
!python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 -mmap -l 2 dl20-passage t5_large_fp16.tsv

Running command: ['java', '-jar', '/root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-c', '-m', 'ndcg_cut.10', '-mmap', '-l', '2', '/root/.cache/pyserini/topics-and-qrels/qrels.dl20-passage.txt', 't5_large_fp16.tsv']
Results:
map                   	all	0.3970
ndcg_cut_10           	all	0.6692


## BM25 + t5-3B - FP16

In [None]:
t5_3b = MonoT5('castorini/monot5-3b-msmarco-10k', fp16=True)

In [None]:
batch_size = 32
top_k = 100

get_run("t5_3b_fp16.tsv", t5_3b, batch_size, top_k)

In [None]:
!python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 -mmap -l 2 dl20-passage t5_3b_fp16.tsv

Running command: ['java', '-jar', '/root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-c', '-m', 'ndcg_cut.10', '-mmap', '-l', '2', '/root/.cache/pyserini/topics-and-qrels/qrels.dl20-passage.txt', 't5_3b_fp16.tsv']
Results:
map                   	all	0.4143
ndcg_cut_10           	all	0.6907
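
To compare the five configurations side by side, the `trec_eval` numbers reported in this notebook can be collected into a small table. The sketch below copies the MAP and nDCG@10 values from the outputs above and sorts them by nDCG@10. The `RESULTS` dictionary and the formatting are only one possible layout.

```python
# (MAP, nDCG@10) copied from the trec_eval outputs reported in this notebook.
RESULTS = {
    "BM25 + MiniLM-L-12 (FP32)": (0.3956, 0.6710),
    "BM25 + t5-base (FP32)": (0.3869, 0.6699),
    "BM25 + t5-base (FP16)": (0.3870, 0.6698),
    "BM25 + t5-large (FP16)": (0.3970, 0.6692),
    "BM25 + t5-3B (FP16)": (0.4143, 0.6907),
}

# Sort configurations from best to worst nDCG@10.
ranked = sorted(RESULTS.items(), key=lambda item: item[1][1], reverse=True)

lines = [f"{'Configuration':<28}{'MAP':>8}{'nDCG@10':>10}"]
for name, (map_score, ndcg) in ranked:
    lines.append(f"{name:<28}{map_score:>8.4f}{ndcg:>10.4f}")
print("\n".join(lines))
```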
