# Re-Rank

Given a search query, we first use a retrieval system that returns a large list of candidate hits (e.g. 100) that are potentially relevant to the query. For retrieval, we can use either lexical search, e.g. BM25 in a search engine like Elasticsearch, or dense retrieval with a bi-encoder. However, the retriever may also return documents that are not very relevant to the query. Hence, in a second stage, we use a re-ranker based on a cross-encoder that scores the relevance of every candidate for the given query. The output is a ranked list of hits that we can present to the user.
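The two-stage flow can be sketched independently of any model. Here `cheap_score` stands in for the retriever and `expensive_score` for the cross-encoder. Both toy scorers (word overlap and overlap ratio) are hypothetical placeholders chosen only so the example runs without downloading a model.

```python
from typing import Callable, List, Tuple


def retrieve_then_rerank(
    query: str,
    corpus: List[str],
    cheap_score: Callable[[str, str], float],
    expensive_score: Callable[[str, str], float],
    top_k: int = 100,
    final_k: int = 3,
) -> List[Tuple[str, float]]:
    """Two-stage search: a cheap scorer shortlists top_k docs, an expensive scorer re-orders them."""
    # Stage 1: score every document with the cheap retriever and keep the top_k
    shortlist = sorted(corpus, key=lambda doc: cheap_score(query, doc), reverse=True)[:top_k]
    # Stage 2: score only the shortlisted (query, doc) pairs with the expensive re-ranker
    reranked = [(doc, expensive_score(query, doc)) for doc in shortlist]
    reranked.sort(key=lambda pair: pair[1], reverse=True)
    return reranked[:final_k]


# Hypothetical toy "retriever": number of distinct query words found in the document
def overlap(query: str, doc: str) -> float:
    return float(len(set(query.lower().split()) & set(doc.lower().split())))


# Hypothetical toy "re-ranker": overlap normalized by document length
def overlap_ratio(query: str, doc: str) -> float:
    words = doc.lower().split()
    return overlap(query, doc) / len(words) if words else 0.0


corpus = [
    "cats live long",
    "how long do dogs live and how long do cats live in a house",
    "pandas eat bamboo",
]
print(retrieve_then_rerank("how long do cats live", corpus, overlap, overlap_ratio, top_k=2, final_k=2))
```

With these toy scorers, the long document wins the first stage on raw overlap, and the second stage moves the short, focused document to the top, which is the kind of correction a real cross-encoder is meant to provide.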

The retriever has to be efficient for large document collections with millions of entries, but it may return irrelevant candidates. A re-ranker based on a cross-encoder can substantially improve the final results for the user. The query and a candidate document are passed together through a transformer network, which outputs a single score indicating how relevant the document is to the query. Depending on the model, this score is either a probability between 0 and 1 or an unbounded logit; the ms-marco-MiniLM-L-6-v2 model used in this notebook returns logits (e.g. the score of 10.431 in the last output).
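When a 0 to 1 score is needed from a cross-encoder that outputs logits, the logistic sigmoid maps it into that range. Because the sigmoid is monotonic, it changes the scale but not the order of the candidates. A minimal sketch (the helper name `logit_to_probability` is hypothetical):

```python
import math


def logit_to_probability(logit: float) -> float:
    """Map an unbounded relevance logit to (0, 1) with the logistic sigmoid."""
    # Numerically stable form: avoids overflow in exp() for large negative logits
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


# 10.431 is the top re-ranker score printed at the end of this notebook
print(logit_to_probability(10.431))
```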


Characteristics of Cross Encoder (a.k.a. reranker) models:

- Calculates a similarity score given pairs of texts.
- Generally provides superior performance compared to a Sentence Transformer (a.k.a. bi-encoder) model.
- Often slower than a Sentence Transformer model, as it requires computation for each pair rather than each text.
- Due to the previous 2 characteristics, Cross Encoders are often used to re-rank the top-k results from a Sentence Transformer model.
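To make the cost argument concrete, the following sketch counts transformer forward passes per query. It assumes corpus embeddings are precomputed offline and ignores batching, so it counts passes, not wall-clock time. The numbers plugged in (509,663 passages, `top_k = 32`) come from the Simple English Wikipedia run later in this notebook; the function name is hypothetical.

```python
def forward_passes_per_query(num_docs: int, top_k: int) -> dict:
    """Count transformer forward passes needed at query time."""
    return {
        # Bi-encoder: only the query is encoded; corpus embeddings already exist
        "bi_encoder": 1,
        # Cross-encoder over the whole corpus: one pass per (query, doc) pair
        "cross_encoder_full": num_docs,
        # Retrieve-then-rerank: one query encoding plus top_k cross-encoder passes
        "retrieve_then_rerank": 1 + top_k,
    }


print(forward_passes_per_query(509_663, 32))
# {'bi_encoder': 1, 'cross_encoder_full': 509663, 'retrieve_then_rerank': 33}
```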

## load and use

In [None]:
from sentence_transformers import CrossEncoder

model_name = '../../cache/officials/ms-marco-MiniLM-L-6-v2'
model = CrossEncoder(model_name)

# Pairs can be given as lists or as tuples; the model treats both the same way
pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
# pairs = [('what is panda?', 'hi'), ('what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.')]
scores = model.predict(pairs, show_progress_bar=True)
print(scores)


In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model_name = '../../cache/officials/ms-marco-MiniLM-L-6-v2'
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]

model.eval()
with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")
    # Raw relevance logits of shape (num_pairs, 1); a higher value means more relevant
    scores = model(**inputs).logits
    print(scores)
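Both cells above produce one score per pair, in the same order as `pairs`. To turn those scores into a ranking (for example from `scores.squeeze(-1).tolist()` in the Transformers cell), sort the documents by score. A minimal sketch with a hypothetical helper and made-up scores:

```python
from typing import List, Sequence, Tuple


def rank_documents(docs: Sequence[str], scores: Sequence[float]) -> List[Tuple[str, float]]:
    """Pair each document with its score and sort from most to least relevant."""
    if len(docs) != len(scores):
        raise ValueError("docs and scores must have the same length")
    ranked = [(doc, float(score)) for doc, score in zip(docs, scores)]
    # Higher cross-encoder score means more relevant, so sort in descending order
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked


# Made-up scores for illustration only, not real model output
docs = ["hi", "The giant panda is a bear species endemic to China."]
for doc, score in rank_documents(docs, [-11.0, 10.0]):
    print(f"{score:.3f}\t{doc}")
```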

## retrieval and re-rank
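The retrieval step in the next cell relies on `util.semantic_search`, which returns, for each query, a list of hits as dicts with `corpus_id` and `score`, sorted by decreasing cosine similarity. As a rough mental model (not the library's implementation, which works on batched tensors), a pure-Python version for a single query might look like this; the helper names are hypothetical:

```python
import heapq
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two vectors (0.0 if either has zero length)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def top_k_cosine(query_emb: Sequence[float], corpus_embs: Sequence[Sequence[float]], top_k: int) -> List[dict]:
    """Return the top_k corpus entries by cosine similarity, shaped like semantic_search hits."""
    scored = ((cosine(query_emb, emb), idx) for idx, emb in enumerate(corpus_embs))
    best = heapq.nlargest(top_k, scored)  # highest similarity first
    return [{"corpus_id": idx, "score": score} for score, idx in best]


print(top_k_cosine([1.0, 0.0], [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], top_k=2))
```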

In [1]:
import json
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import gzip
import os
import torch

if not torch.cuda.is_available():
    print("Warning: No GPU found. Please add GPU to your notebook")
    
def search(query, bi_encoder, cross_encoder):
    print("Input question:", query)

    ##### Retrieval / Semantic Search #####
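    # Encode the query using the bi-encoder and find potentially relevant passages
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    # Keep the query on the same device as the corpus embeddings (works with or without a GPU)
    question_embedding = question_embedding.to(corpus_embeddings.device)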
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first query, or the only query
    # Each hit is a dict with 'corpus_id' and 'score', sorted by decreasing score

    ##### Re-Ranking #####
    # Now, score all retrieved passages with the cross_encoder
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    # cross_scores is a 1-D array with one score per (query, passage) pair, in the same order as hits

    # Add the cross-encoder scores to hits, so each hit carries both its retrieval and its re-ranking score
    for idx in range(len(cross_scores)):
        hits[idx]['cross-score'] = cross_scores[idx]

    # Output of top-3 hits from bi-encoder
    print("\n-------------------------\n")
    print("Top-3 Bi-Encoder Retrieval hits")
    hits = sorted(hits, key=lambda x: x['score'], reverse=True)
    for hit in hits[0:3]:
        print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " ")))

    # Output of top-3 hits from re-ranker
    print("\n-------------------------\n")
    print("Top-3 Cross-Encoder Re-ranker hits")
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
    for hit in hits[0:3]:
        print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))


#We use the Bi-Encoder to encode all passages, so that we can use it with semantic search
bi_encoder = SentenceTransformer('../../cache/officials/all-MiniLM-L6-v2')
bi_encoder.max_seq_length = 256     #Truncate long passages to 256 tokens
top_k = 32                          #Number of passages we want to retrieve with the bi-encoder
# The bi-encoder retrieves top_k passages; a cross-encoder then re-ranks this list to improve its quality
cross_encoder = CrossEncoder('../../cache/officials/ms-marco-MiniLM-L-6-v2')

# As dataset, we use Simple English Wikipedia. Compared to the full English wikipedia, it has only
# about 170k articles. We split these articles into paragraphs and encode them with the bi-encoder

wikipedia_filepath = 'simplewiki-2020-11-01.jsonl.gz'

if not os.path.exists(wikipedia_filepath):
    util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz', wikipedia_filepath)

passages = []
with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        data = json.loads(line.strip())

        # add all paragraphs
        passages += data['paragraphs']
        
        # Alternatively, add only the first paragraph of each article to save time
        # passages.append(data['paragraphs'][0])

print("Passages:", len(passages))

# We encode all passages into our vector space. This takes about 5 minutes (depends on your GPU speed)
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)

No sentence-transformers model found with name ../../cache/officials/all-MiniLM-L6-v2. Creating a new one with mean pooling.


Passages: 509663


Batches:   0%|          | 0/15927 [00:00<?, ?it/s]
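The log message above means the local `all-MiniLM-L6-v2` folder was loaded as a plain Transformers checkpoint with a mean-pooling layer added on top. Mean pooling averages the token embeddings while skipping padding tokens via the attention mask. A pure-Python sketch of that operation (hypothetical helper name; the library computes a masked mean along these lines with tensors):

```python
from typing import List, Sequence


def mean_pool(token_embeddings: Sequence[Sequence[float]], attention_mask: Sequence[int]) -> List[float]:
    """Average the token vectors whose attention-mask entry is 1 (padding tokens are skipped)."""
    dim = len(token_embeddings[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(token_embeddings, attention_mask):
        if keep:
            for i, value in enumerate(vec):
                total[i] += value
            count += 1
    # Guard against an all-padding input so we never divide by zero
    return [t / max(count, 1) for t in total]


# Two real tokens and one padding token (mask 0) that must not affect the average
print(mean_pool([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]], [1, 1, 0]))
# [2.0, 3.0]
```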

In [2]:
search("What is the capital of the United States?", bi_encoder, cross_encoder)

Input question: What is the capital of the United States?

-------------------------

Top-3 Bi-Encoder Retrieval hits
	0.708	United States Capitol.
	0.623	Its capital and largest city is Boston. It is on the east coast of the United States. It is next to the Atlantic Ocean and the states of Rhode Island, Connecticut, New York, Vermont, and New Hampshire. The word "Massachusetts" comes from Native American language. It means "place with hills."
	0.605	The continental United States is the area of the United States of America that is located in the continent of North America. It includes 49 of the 50 states (48 of which are located south of Canada and north of Mexico, known as the "lower 48 states", the other being Alaska) and the District of Columbia, which contains the federal capital, Washington, D.C. The only state which is not part of this is Hawaii (as they are islands in the Pacific Ocean and not part of North America).

-------------------------

Top-3 Cross-Encoder Re-ranker hits

In [3]:
search("How long do cats live?", bi_encoder, cross_encoder)


Input question: How long do cats live?

-------------------------

Top-3 Bi-Encoder Retrieval hits
	0.724	Reliable information on the lifespans of house cats is hard to find. However, research has been done to get an estimate (an educated guess) on how long cats usually live. Cats usually live for 13 to 20 years. Sometimes cats can live for 22 to 30 years but there are claims of cats dying at ages of more than 30 years old.
	0.711	The life expectancy of an indoor cat is around 17 years, but the life expectancy of outdoor cats is 5.6 years.
	0.663	Very little is known about how long they live. Some reports say "30 years or more", while others say "50 years or more".

-------------------------

Top-3 Cross-Encoder Re-ranker hits
	10.431	Reliable information on the lifespans of house cats is hard to find. However, research has been done to get an estimate (an educated guess) on how long cats usually live. Cats usually live for 13 to 20 years. Sometimes cats can live for 22 to 30 years but
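The re-ranking half of `search` is pure bookkeeping: it attaches one cross-encoder score to each bi-encoder hit and re-sorts. A minimal model-free sketch of that step, with a hypothetical helper name and made-up numbers; unlike `search`, it copies the hits instead of mutating them:

```python
from typing import Dict, List, Sequence


def attach_and_rerank(hits: List[Dict], cross_scores: Sequence[float]) -> List[Dict]:
    """Copy each hit, add its 'cross-score', and sort by that score (highest first)."""
    if len(hits) != len(cross_scores):
        raise ValueError("need exactly one cross-encoder score per hit")
    merged = [{**hit, "cross-score": float(s)} for hit, s in zip(hits, cross_scores)]
    return sorted(merged, key=lambda h: h["cross-score"], reverse=True)


# Made-up bi-encoder hits and cross-encoder scores, for illustration only
bi_hits = [{"corpus_id": 0, "score": 0.9}, {"corpus_id": 1, "score": 0.5}]
for hit in attach_and_rerank(bi_hits, [-2.0, 7.5]):
    print(hit["corpus_id"], hit["score"], hit["cross-score"])
```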