In [2]:
import json
from sentence_transformers import SentenceTransformer, util
import time
import gzip
import os
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Bi-encoder that embeds both the user questions and the Wikipedia passages.
model_name = 'nq-distilbert-base-v1'
top_k = 5  # number of passages returned per query
bi_encoder = SentenceTransformer(model_name)

Downloading: 100%|██████████| 690/690 [00:00<00:00, 234kB/s]
Downloading: 100%|██████████| 190/190 [00:00<00:00, 95.2kB/s]
Downloading: 100%|██████████| 3.69k/3.69k [00:00<00:00, 922kB/s]
Downloading: 100%|██████████| 540/540 [00:00<00:00, 182kB/s]
Downloading: 100%|██████████| 122/122 [00:00<00:00, 115kB/s]
Downloading: 100%|██████████| 265M/265M [02:04<00:00, 2.14MB/s] 
Downloading: 100%|██████████| 53.0/53.0 [00:00<00:00, 50.5kB/s]
Downloading: 100%|██████████| 112/112 [00:00<00:00, 37.4kB/s]
Downloading: 100%|██████████| 466k/466k [00:02<00:00, 188kB/s]  
Downloading: 100%|██████████| 554/554 [00:00<00:00, 138kB/s]
Downloading: 100%|██████████| 232k/232k [00:01<00:00, 218kB/s]  
Downloading: 100%|██████████| 229/229 [00:00<00:00, 116kB/s]


In [5]:

# simplewiki dump as gzip-compressed JSON lines; it is downloaded only when not already on disk.
wikipedia_filepath = 'data/simplewiki-2020-11-01.jsonl.gz'
wikipedia_missing = not os.path.exists(wikipedia_filepath)
if wikipedia_missing:
    util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz', wikipedia_filepath)


In [6]:
# Flatten the corpus into one [title, paragraph] pair per paragraph.
# Each line of the gzip'd JSONL file is one article with 'title' and 'paragraphs' keys.
passages = []
with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as wiki_file:
    for raw_line in wiki_file:
        article = json.loads(raw_line.strip())
        title = article['title']
        passages.extend([title, paragraph] for paragraph in article['paragraphs'])

In [18]:
# Peek at a single entry to confirm the [title, paragraph] structure built above.
passages[1]

['Aileen Wuornos',
 'Aileen Carol Wuornos Pralle (born Aileen Carol Pittman; February 29, 1956\xa0– October 9, 2002) was an American serial killer. She was born in Rochester, Michigan. She confessed to killing six men in Florida and was executed in Florida State Prison by lethal injection for the murders. Wuornos said that the men she killed had raped her or tried to rape her while she was working as a prostitute.']

In [14]:
# Load (or compute) one embedding per passage. Row i of corpus_embeddings must
# correspond to passages[i], because search hits are mapped back via hit['corpus_id'].
print("Passages:", len(passages))

# To speed things up, pre-computed embeddings are downloaded.
# The provided file encoded the passages with the model 'nq-distilbert-base-v1'.
# It only lines up with the full, unmodified passage list, so do not subsample
# `passages` when using it (the length check at the end of this cell enforces this).
if model_name == 'nq-distilbert-base-v1':
    embeddings_filepath = 'simplewiki-2020-11-01-nq-distilbert-base-v1.pt'
    if not os.path.exists(embeddings_filepath):
        # NOTE(review): torch.load below unpickles this file, and unpickling can execute
        # arbitrary code. The file is fetched over plain HTTP. Consider HTTPS plus a checksum,
        # or torch.load(..., weights_only=True) on torch versions that support it.
        util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01-nq-distilbert-base-v1.pt', embeddings_filepath)

    # Map to CPU first so the file can be opened on machines without a GPU.
    corpus_embeddings = torch.load(embeddings_filepath, map_location=torch.device('cpu'))
    # Cast to float32. Presumably the file is stored in reduced precision -- TODO confirm.
    corpus_embeddings = corpus_embeddings.float()
    if torch.cuda.is_available():
        corpus_embeddings = corpus_embeddings.to('cuda')
else:  # Here, we compute the corpus_embeddings from scratch (which can take a while depending on the GPU)
    corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)

# A mismatch would make the corpus_id values returned by semantic_search point at the
# wrong passage, or past the end of `passages`. Fail loudly instead.
if len(corpus_embeddings) != len(passages):
    raise ValueError(
        "Embedding/passage mismatch: {} embeddings for {} passages".format(
            len(corpus_embeddings), len(passages)
        )
    )


Passages: 509663


In [16]:
# Interactive retrieval loop: embed each question with the bi-encoder and print the
# top_k most similar passages with their scores.
# Submit an empty line, "exit" or "quit" to leave the loop. This lets the cell finish
# normally instead of requiring a kernel interrupt.
while True:
    query = input("Please enter a question: ").strip()
    if query.lower() in ("", "exit", "quit"):
        break

    # Encode the query using the bi-encoder and find potentially relevant passages.
    # perf_counter is monotonic and high-resolution, which suits elapsed-time measurement.
    start_time = time.perf_counter()
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first (and only) query
    end_time = time.perf_counter()

    # Output of top-k hits; corpus_id indexes into `passages`.
    print("Input question:", query)
    print("Results (after {:.3f} seconds):".format(end_time - start_time))
    for hit in hits:
        print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']]))

    print("\n\n========\n")

Input question: Why is ESG important for companies?
Results (after 0.179 seconds):
	0.515	['Computer', 'One of the most important jobs that computers do for people is helping with communication. Communication is how people share information. Computers have helped people move forward in science, medicine, business, and learning, because they let experts from anywhere in the world work with each other and share information. They also let other people communicate with each other, do their jobs almost anywhere, learn about almost anything, or share their opinions with each other. The Internet is the thing that lets people communicate between their computers.']
	0.507	['Industrial engineering', 'Industrial engineering is a type of engineering. It is one of the fastest growing areas of engineering. It looks at what makes organizations work best. An industrial engineer tries to find the right combination of human and natural resources, technology, equipment, information and finance to do the 

KeyboardInterrupt: Interrupted by user