In [49]:
!pip install -U sentence-transformers rank_bm25

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [50]:
# import all needed libraries

import gzip  # for reading the gzip file
import json  # for reading the json file
import os  # for reading the file path
import string

import numpy as np
import torch  # for using the GPU
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder, util  # for sentence embedding
from sklearn.feature_extraction import _stop_words
from tabulate import tabulate  # for printing search results as tables
from tqdm.autonotebook import tqdm



In [51]:
# The bi-encoder embeds the query and every passage independently into the same
# vector space, so candidate passages can be found with semantic (vector) search.
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
bi_encoder.max_seq_length = 256  # Truncate long passages to 256 tokens
top_k = 32  # Number of candidate passages the bi-encoder retrieves per query

# The bi-encoder retrieves `top_k` candidates (32 above). The cross-encoder then
# scores each (query, passage) pair jointly and is used to re-rank that shortlist
# to improve result quality.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

In [52]:
# As dataset, we use Simple English Wikipedia. Each article is split into
# paragraphs, and the selected paragraphs are encoded with the bi-encoder.

wikipedia_filepath = 'simplewiki-2020-11-01.jsonl.gz'
# False: index only the first paragraph of each article (smaller, faster to encode).
# True: index every paragraph of every article.
use_all_paragraphs = False

if not os.path.exists(wikipedia_filepath):
    util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz', wikipedia_filepath)

passages = []
with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines (e.g. a trailing newline) instead of failing in json.loads
        # Articles with a missing or empty 'paragraphs' list are skipped rather than
        # raising KeyError/IndexError.
        paragraphs = json.loads(line).get('paragraphs') or []
        if use_all_paragraphs:
            passages.extend(paragraphs)
        elif paragraphs:
            passages.append(paragraphs[0])

assert passages, f"No passages were loaded from {wikipedia_filepath}"
print("Passages:", len(passages))

# We encode all passages into our vector space. This takes about 5 minutes (depends on your GPU speed)
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)


Passages: 169597


Batches:   0%|          | 0/5300 [00:00<?, ?it/s]

In [53]:
# We also compare the results to lexical search (keyword search). Here, we use
# the BM25 algorithm which is implemented in the rank_bm25 package.

# We lower case our text and remove stop-words from indexing
def bm25_tokenizer(text):
    """Split ``text`` into BM25 index terms.

    The text is lower-cased and split on whitespace. Leading/trailing
    punctuation is stripped from each piece. Pieces that end up empty, or that
    are scikit-learn English stop words, are discarded.

    Returns:
        list[str]: the remaining terms, in their original order.
    """
    candidates = (piece.strip(string.punctuation) for piece in text.lower().split())
    return [term for term in candidates
            if term and term not in _stop_words.ENGLISH_STOP_WORDS]


# Tokenize every passage once (with a progress bar) and build the BM25 index over them.
tokenized_corpus = [bm25_tokenizer(passage_text) for passage_text in tqdm(passages)]

bm25 = BM25Okapi(tokenized_corpus)

  0%|          | 0/169597 [00:00<?, ?it/s]

In [70]:
from tabulate import tabulate

def search(query):
    print("Input question:", query)
    print()

    ##### BM25 search (lexical search) #####
    bm25_scores = bm25.get_scores(bm25_tokenizer(query))
    top_n = np.argpartition(bm25_scores, -5)[-5:]
    bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
    bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
    
    bm25_table = [["BM25", hit['score'], passages[hit['corpus_id']].replace("\n", " ")] for hit in bm25_hits[0:3]]

    print("Top-3 lexical search (BM25) hits:\n")
    print(tabulate(bm25_table, headers=["Method", "Score", "Passage"]))

    ##### Semantic Search #####
    # Encode the query using the bi-encoder and find potentially relevant passages
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    question_embedding = question_embedding.cuda()
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first query

    ##### Re-Ranking #####
    # Now, score all retrieved passages with the cross_encoder
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Sort results by the cross-encoder scores
    for idx in range(len(cross_scores)):
        hits[idx]['cross-score'] = cross_scores[idx]

    # Create tables for bi-encoder and cross-encoder results
    bi_encoder_table = [["Bi-Encoder", hit['score'], passages[hit['corpus_id']].replace("\n", " ")] for hit in hits[0:3]]
    cross_encoder_table = [["Cross-Encoder", hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")] for hit in hits[0:3]]

    # Output of top-3 hits from bi-encoder
    print("\n-------------------------\n")
    print("Top-3 Bi-Encoder Retrieval hits:")
    print(tabulate(bi_encoder_table, headers=["Method", "Score", "Passage"]))

    # Output of top-3 hits from cross-encoder
    print("\n-------------------------\n")
    print("Top-3 Cross-Encoder Re-ranker hits:")
    print(tabulate(cross_encoder_table, headers=["Method", "Score", "Passage"]))

# Example usage: run the lexical / semantic / re-ranked search demo for a few sample questions.
example_queries = [
    "What is the capital of the United States?",
    "What is the best orchestra in the world?",
    "Number of countries Europe",
]
for example_query in example_queries:
    search(query=example_query)


Input question: What is the capital of the United States?

Top-3 lexical search (BM25) hits:

Method      Score  Passage
--------  -------  ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
BM25      13.3159  Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. The federal government (including the United States military) also uses capital punishment.
BM25      11.434   Ohio is one of the 50 states in the United States. Its capital is Columbus. Columbus also is the largest city in Ohio.
BM25      11.1793  Nevada is one of the United States' states. Its capital is Carson City. Other big cities are Las Vegas and Reno.


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]


-------------------------

Top-3 Bi-Encoder Retrieval hits:
Method         Score  Passage
----------  --------  ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Bi-Encoder  0.621622  Cities in the United States:
Bi-Encoder  0.596905  The United States Capitol is the building where the United States Congress meets. It is the center of the legislative branch of the U.S. federal government. It is in Washington, D.C., on top of Capitol Hill at the east end of the National Mall.
Bi-Encoder  0.595733  In the United States:

-------------------------

Top-3 Cross-Encoder Re-ranker hits:
Method            Score  Passage
-------------  --------  --------------------------------------------------------------------------------------------------------------------------------------------------------------

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]


-------------------------

Top-3 Bi-Encoder Retrieval hits:
Method         Score  Passage
----------  --------  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Bi-Encoder  0.700826  The Vienna Philharmonic (in German: die Wiener Philharmoniker) is an orchestra based in Vienna, Austria. It is thought of as one of the greatest orchestras in the world.
Bi-Encoder  0.640915  The Vienna Symphony () is an orchestra in Vienna, Austria.
Bi-Encoder  0.639719  The Berlin Philharmonic (in German: Die Berliner Philharmoniker), is an orchestra from Berlin, Germany. It is one of the greatest orchestras in the world. The conductor of the orchestra is Sir Simon Rattle.

-------------------------

Top-3 Cross-Encoder Re-ranker hits:
Method            Score  Passage
-------------  --------  --------------------------------------------------------

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]


-------------------------

Top-3 Bi-Encoder Retrieval hits:
Method         Score  Passage
----------  --------  ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Bi-Encoder  0.537658  The Council of Europe (, ) is an international organization of 47 member states in the European region. One of its first successes was the European Convention on Human Rights in 1950, which serves as the basis for the European Court of Human Rights.
Bi-Encoder  0.530549  England is a country in Europe. It is a country with over sixty cities in it. It is in a union with Scotland, Wales and Northern Ireland. All four countries are in the British Isles and are part of the United Kingdom (UK).
Bi-Encoder  0.506799  Europe is a Swedish rock band. The band was started by Joey Tempest and John Norum in 1979. Their 