# Tutorial: Rerank search results using the ColBERT Reranker #

In this tutorial, we will learn how to use a neural reranker to rerank results from a BM25 search. The reranker is based on the ColBERT algorithm described in Khattab et al., ["ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT"](https://arxiv.org/pdf/2004.12832.pdf).
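
As a rough illustration of the late-interaction idea from the paper, the sketch below scores a passage by summing, over the query token vectors, the maximum similarity with any passage token vector (the MaxSim operator). The tiny 2-dimensional vectors and the helper names `dot` and `maxsim_score` are made up for this example; a real ColBERT model produces the token embeddings with BERT.

```python
from typing import List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    """Dot product of two equally sized vectors."""
    return sum(a * b for a, b in zip(u, v))


def maxsim_score(query_vecs: List[Vector], doc_vecs: List[Vector]) -> float:
    """For each query token, take its best-matching document token, then sum."""
    return sum(max(dot(q, d) for d in doc_vecs) for q in query_vecs)


# Toy "token embeddings" (illustrative values only, not model output).
query = [[1.0, 0.0], [0.0, 1.0]]
doc_a = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
doc_b = [[1.0, 0.0], [0.5, 0.5]]

print(maxsim_score(query, doc_a))  # 2.0
print(maxsim_score(query, doc_b))  # 1.5
```

Here `doc_a` scores higher because every query token finds an exact match among its token vectors.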

To keep this tutorial easy to follow, we show the steps on a very small document collection. The same technique scales to millions of documents: we have tested it on up to 21 million Wikipedia passages.

The tutorial will take you through these three steps:

1. Build a BM25 index over a small sample collection
2. Query the BM25 index to obtain initial search results
3. Rerank the initial results with a neural reranker to obtain the final search results


## Preparing a Colab Environment to run this tutorial ##

Make sure to enable a GPU runtime. In Colab, select Runtime > Change runtime type and choose a GPU hardware accelerator.
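
One way to confirm from Python that a GPU is visible to the runtime is sketched below. It assumes NVIDIA hardware with the `nvidia-smi` tool on the `PATH`; the helper name `gpu_visible` is hypothetical and not part of PrimeQA.

```python
import shutil
import subprocess


def gpu_visible() -> bool:
    """Return True if nvidia-smi is available and lists at least one GPU."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return False
    try:
        # "nvidia-smi -L" prints one line per detected GPU.
        result = subprocess.run([exe, "-L"], capture_output=True, text=True, timeout=10)
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0 and "GPU" in result.stdout


print("GPU visible:", gpu_visible())
```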

## Installing PrimeQA ##

First, we need to install the required packages.

In [None]:
!pip install --upgrade pip

# install primeqa
!pip install --upgrade primeqa install-jdk gdown

# Java 11 is required
# Note: jdk.install will fail with a "Permission denied" error if the JDK is already installed
# Comment out the next two lines if it is already installed
import jdk
jdk_dir = jdk.install('11', path="/tmp/primeqa-jdk")

# set the JAVA_HOME environment variable to point to Java 11
%env JAVA_HOME=/tmp/primeqa-jdk/jdk-11.0.19+7/
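
The `JAVA_HOME` path above hard-codes a specific JDK build directory. If the installed build differs, a sketch like the following could locate the directory instead of relying on the exact name. The helper `find_jdk_home` is hypothetical, and the `jdk-11*` naming pattern is an assumption based on the path used in this tutorial.

```python
import glob
import os
from typing import Optional


def find_jdk_home(install_root: str, major: str = "11") -> Optional[str]:
    """Return the first directory under install_root that looks like a JDK of the given major version."""
    pattern = os.path.join(install_root, f"jdk-{major}*")
    for candidate in sorted(glob.glob(pattern), reverse=True):
        # A usable JDK home contains bin/java (Linux/macOS layout assumed).
        if os.path.isfile(os.path.join(candidate, "bin", "java")):
            return candidate
    return None


jdk_home = find_jdk_home("/tmp/primeqa-jdk")
if jdk_home is not None:
    os.environ["JAVA_HOME"] = jdk_home
print("JAVA_HOME candidate:", jdk_home)
```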


Next we set up some paths.  Please update the `output_dir` path to a location where you have write permissions.

In [None]:
# Setup paths 
output_dir = "/tmp/primeqa-tutorial" 

import os
# create output directory if it does not exist
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# setup some paths
downloaded_corpus_file = os.path.join(output_dir,"sample-document-store.csv")
collection_file = os.path.join(output_dir,"sample_collection.tsv")
# the reranker model is stored in a "models" subdirectory of output_dir
reranker_model_path = os.path.join(output_dir, "models", "DrDecr.dnn")


## Download the sample corpus and set up the model directory ##

In [None]:
# download the sample collection
! gdown  --id 1LULJRPgN_hfuI2kG-wH4FUwXCCdDh9zh --output {output_dir}/

# the sample corpus was downloaded into output_dir above
downloaded_corpus_file = os.path.join(output_dir,"sample-document-store.csv")
index_dir = os.path.join(output_dir,"sample_bm25_index")
collection_file = os.path.join(output_dir,"sample_collection.tsv")

# create the model directory
model_dir = os.path.join(output_dir,'models')
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
reranker_model_path = os.path.join(model_dir, "DrDecr.dnn")


## Pre-process the document collection for the BM25 search index ##


In [None]:
# format the sample collection for indexing
import csv
with open(collection_file, 'w', newline='', encoding='utf-8') as out_f:
    tsv_writer = csv.writer(out_f, quoting=csv.QUOTE_MINIMAL, lineterminator='\n', delimiter='\t')
    tsv_writer.writerow( ('id','text','title') )
    with open(downloaded_corpus_file, newline='', encoding='utf-8') as in_f:
        reader = csv.DictReader(in_f, delimiter=',')
        # assign sequential 1-based document ids
        for i, row in enumerate(reader, start=1):
            tsv_writer.writerow( (str(i), row['text'], row['title']) )
! head -2 {collection_file}
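
To see what this conversion produces without touching any files, here is a minimal in-memory sketch of the same step. The two sample rows are invented, and the helper name `csv_to_collection_tsv` is hypothetical; only the `title` and `text` column names are taken from the code above.

```python
import csv
import io


def csv_to_collection_tsv(csv_text: str) -> str:
    """Convert CSV text with 'title' and 'text' columns into an id/text/title TSV string."""
    reader = csv.DictReader(io.StringIO(csv_text))
    out = io.StringIO()
    writer = csv.writer(out, delimiter="\t", quoting=csv.QUOTE_MINIMAL, lineterminator="\n")
    writer.writerow(("id", "text", "title"))
    # Document ids are sequential and start at 1, as in the tutorial code.
    for i, row in enumerate(reader, start=1):
        writer.writerow((str(i), row["text"], row["title"]))
    return out.getvalue()


# Invented sample data for illustration.
sample_csv = (
    "title,text\n"
    'Albert Einstein,"Physicist, born in Ulm."\n'
    "Marie Curie,Pioneer of radioactivity research.\n"
)
tsv_text = csv_to_collection_tsv(sample_csv)
print(tsv_text)
```

Note that the comma inside the first text field needs no quoting in the output, because the delimiter is now a tab.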

## Build an index with the PrimeQA BM25 Indexer ##

In [None]:
from primeqa.components.indexer.sparse import BM25Indexer

# Instantiate and configure the indexer
indexer = BM25Indexer(index_root=output_dir, index_name="sample_index_bm25")
indexer.load()   # IMPORTANT: required to configure

# Index the collection
indexer.index(collection=collection_file)


## Start asking Questions ##

We're now ready to query the index we created.  

In [None]:
from primeqa.components.retriever.sparse import BM25Retriever
import json

# Example question
question = ["Why was Einstein awarded the Nobel Prize?"]
# Instantiate and configure the retriever
print(output_dir)
retriever = BM25Retriever(index_root=output_dir, index_name="sample_index_bm25", collection=collection_file)
retriever.load()

# Search
search_results = retriever.predict(question, max_num_documents=5)
print(json.dumps(search_results,indent=2))
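
To give some intuition for the scores in the search output, the sketch below computes a textbook BM25 score over a toy, whitespace-tokenized collection. It is not the indexer's implementation, which may differ in tokenization and in the exact formula variant. The function name `bm25_scores`, the sample documents, and the parameter values `k1=0.9` and `b=0.4` are assumptions chosen for illustration.

```python
import math
from collections import Counter
from typing import List


def bm25_scores(query: List[str], docs: List[List[str]], k1: float = 0.9, b: float = 0.4) -> List[float]:
    """Score every tokenized document against the tokenized query with BM25."""
    n_docs = len(docs)
    avg_len = sum(len(d) for d in docs) / n_docs
    # Document frequency: number of documents containing each term.
    df = Counter(term for d in docs for term in set(d))
    scores = []
    for d in docs:
        tf = Counter(d)
        score = 0.0
        for term in query:
            if tf[term] == 0:
                continue
            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
            # Term frequency saturation with document length normalization.
            norm = tf[term] + k1 * (1 - b + b * len(d) / avg_len)
            score += idf * tf[term] * (k1 + 1) / norm
        scores.append(score)
    return scores


# Invented toy collection.
docs = [
    "einstein won the nobel prize in physics".split(),
    "the nobel prize is awarded annually".split(),
    "bm25 is a ranking function".split(),
]
scores = bm25_scores("einstein nobel prize".split(), docs)
for doc, score in sorted(zip(docs, scores), key=lambda pair: -pair[1]):
    print(round(score, 3), " ".join(doc))
```

Documents that share no terms with the query score zero, which is one reason a neural reranker that compares contextual representations can improve on the lexical ranking.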


## Rerank the BM25 search results with a Neural Reranker ##

We will use the DrDecr model, trained on Natural Questions and XOR-TyDi. This model has obtained state-of-the-art results on the XOR-TyDi retrieval task.

Here are the steps we will take:

1. Download the reranker model
2. Format search results for input to the Reranker
3. Initialize the PrimeQA ColBERTReranker
4. Rerank the BM25 search results

In [None]:
# Download the reranker model
! wget -P {model_dir} https://huggingface.co/PrimeQA/DrDecr_XOR-TyDi_whitebox/resolve/main/DrDecr.dnn

## Format Reranker input ##

The reranker encodes the question and passage texts with the reranker model and uses the resulting representations to compute a similarity score.

We will take the BM25 search results, look up the passage texts, and format them as input to the reranker.

In [None]:
# Load the corpus
import csv
id_to_document = {}
with open(collection_file) as in_f:
    reader = csv.DictReader(in_f, delimiter='\t')
    for row in reader:
        id_to_document[row['id']] = {
            'document_id': row['id'],
            'text': row['text'],
            'title': row['title']
        }
        
reranker_input = []
for result in search_results[0]:
    
    reranker_input.append({
        'document': id_to_document[result[0]],
        'score': result[1]
    })
print(json.dumps(reranker_input,indent=2))
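
The shape of the reranker input is easier to see on mock data. The sketch below repeats the formatting step as a small function, assuming (as the code above does) that each search hit is a `(document_id, score)` pair. The documents, scores, and the helper name `build_reranker_input` are invented for illustration.

```python
import json
from typing import Dict, List, Tuple


def build_reranker_input(hits: List[Tuple[str, float]], id_to_document: Dict[str, dict]) -> List[dict]:
    """Attach the full document record to each (document_id, score) hit."""
    return [{"document": id_to_document[doc_id], "score": score} for doc_id, score in hits]


# Mock corpus and mock BM25 hits for a single question (illustrative values).
mock_documents = {
    "1": {"document_id": "1", "text": "Physicist, born in Ulm.", "title": "Albert Einstein"},
    "2": {"document_id": "2", "text": "Awarded for the photoelectric effect.", "title": "Nobel Prize in Physics 1921"},
}
mock_hits = [("2", 7.1), ("1", 3.4)]

mock_input = build_reranker_input(mock_hits, mock_documents)
print(json.dumps(mock_input, indent=2))
```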

## Run the Reranker ##
Next, we will initialize the ColBERT Reranker with the DrDecr model and rerank the BM25 search results.

In [None]:
# Import the ColBERT Reranker
from primeqa.components.reranker.colbert_reranker import ColBERTReranker

# Instantiate the ColBERTReranker
reranker = ColBERTReranker(reranker_model_path)
reranker.load()

# rerank the BM25 search results (one list of hits per question) and output the top 3 hits
reranked_results = reranker.rerank(question, [reranker_input], max_num_documents=3)
print(json.dumps(reranked_results,indent=2))


## Print the top ranked result before and after reranking ##

In [None]:
# print top ranked results
print(f'QUESTION: {question[0]}')
      
print("\n========Top search result before reranking")
print(reranker_input[0]['document']['title'],  reranker_input[0]['document']['text'])

print("\n========Top search result after reranking")
print(reranked_results[0][0]['document']['title'],  reranked_results[0][0]['document']['text'])
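
Beyond the single top hit, it can be useful to see how every reranked document moved. The sketch below compares two ordered lists of document ids; the ids are invented, and the helper name `rank_changes` is hypothetical. With the tutorial's data, the ids would have to be extracted from `reranker_input` and `reranked_results`, assuming the reranker output keeps the `document_id` field.

```python
from typing import List, Optional, Tuple


def rank_changes(before_ids: List[str], after_ids: List[str]) -> List[Tuple[str, Optional[int], int]]:
    """Return (document_id, rank_before, rank_after) for each reranked document, using 1-based ranks."""
    before_rank = {doc_id: rank for rank, doc_id in enumerate(before_ids, start=1)}
    return [
        (doc_id, before_rank.get(doc_id), rank)
        for rank, doc_id in enumerate(after_ids, start=1)
    ]


# Invented example: five BM25 hits, top three kept after reranking.
bm25_order = ["12", "7", "3", "44", "9"]
reranked_order = ["3", "12", "7"]

changes = rank_changes(bm25_order, reranked_order)
for doc_id, before, after in changes:
    print(f"document {doc_id}: BM25 rank {before} -> reranked rank {after}")
```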
