In [1]:
from sentence_transformers import SentenceTransformer, models, util
import pandas as pd
import numpy as np
import torch
import os
import json
import pickle
import time
import faiss

# ---- Configuration ---------------------------------------------------------
# Filesystem locations.
dataset_path = './dureader/passage-collection/'   # directory prefix of the part-XX shard files
model_name = './output/bi-encoder-sup_hlf-rbtl3-2022-06-09'   # path handed to SentenceTransformer
embedding_cache_path = 'embeddings-hfl-rbtl3.pkl'   # pickle holding the sentences and their embeddings

# Index / retrieval parameters.
embedding_size = 1024   # vector dimensionality given to the FAISS index; must match the encoder output
top_k_hits = 50         # number of hits to output per query

In [2]:
# Instantiate the bi-encoder from the path configured in `model_name`.
# The resulting `model` is what the corpus-encoding cell uses to embed passages.
model = SentenceTransformer(model_name)

In [3]:
%%time
corpus_sentences = []
for part in ['part-00', 'part-01', 'part-02', 'part-03']:
    with open(dataset_path + part, 'r', encoding='utf-8') as f:
        for line in f.readlines():
            data = line.rstrip().split('\t')
            corpus_sentences.append(data[2][:256])
    
print('corpus_sentences: ', len(corpus_sentences))

corpus_sentences:  8096668
CPU times: user 25.1 s, sys: 4.53 s, total: 29.6 s
Wall time: 29.6 s


In [4]:
# Embed every passage in `corpus_sentences` using a pool of worker processes,
# then cache sentences + embeddings in one pickle so the encoding step does not
# have to be repeated after a kernel restart.
print("Encode the corpus. This might take a while")
pool = model.start_multi_process_pool()
try:
    corpus_embeddings = model.encode_multi_process(corpus_sentences, batch_size=128, pool=pool)
finally:
    # Always terminate the worker processes, including when encoding fails;
    # otherwise they outlive this cell and keep holding their resources.
    model.stop_multi_process_pool(pool)

print("Store file on disc")
with open(embedding_cache_path, "wb") as fOut:
    # Request the newest pickle protocol explicitly: protocols below 4 cannot
    # serialize objects larger than 4 GiB, and the embedding matrix for a
    # multi-million passage corpus is presumably far beyond that.
    pickle.dump(
        {'sentences': corpus_sentences, 'embeddings': corpus_embeddings},
        fOut,
        protocol=pickle.HIGHEST_PROTOCOL,
    )

Encode the corpus. This might take a while
Store file on disc


In [2]:
# Restore the cached encoding results. Only `corpus_embeddings` is kept:
# the cache dict and the sentence list are unbound again right away.
# NOTE(review): pickle.load executes arbitrary code from the file; only load
# cache files that this notebook produced itself.
print("Load pre-computed embeddings from disc")
with open(embedding_cache_path, "rb") as fIn:
    cache_data = pickle.load(fIn)
    corpus_sentences, corpus_embeddings = (
        cache_data['sentences'],
        cache_data['embeddings'],
    )
del cache_data, corpus_sentences

Load pre-computed embeddings from disc


In [3]:
### Create the FAISS index
# Build an exact (flat) inner-product index over the corpus embeddings.
# With unit-length vectors the inner product equals cosine similarity.
print("Start creating FAISS index")

# FAISS works on C-contiguous float32 matrices. ascontiguousarray returns the
# input unchanged (no copy) when it already has that dtype and layout.
corpus_embeddings = np.ascontiguousarray(corpus_embeddings, dtype=np.float32)

# Catch a stale `embedding_size` setting here, with a readable message,
# instead of failing inside FAISS when the vectors are added.
if corpus_embeddings.ndim != 2 or corpus_embeddings.shape[1] != embedding_size:
    raise ValueError(
        f"corpus_embeddings has shape {corpus_embeddings.shape}, "
        f"expected (n, {embedding_size})"
    )

index = faiss.IndexFlatIP(embedding_size)

# First, we need to normalize vectors to unit length.
# faiss.normalize_L2 rescales the rows in place, so no second full-size copy
# of the matrix is allocated (the broadcast division `x / norm` creates one),
# and all-zero rows are left as zeros instead of becoming NaN.
faiss.normalize_L2(corpus_embeddings)

# Finally we add all embeddings to the index
index.add(corpus_embeddings)

Start creating FAISS index


In [4]:
# Serialize `index` to the file 'faiss.index' (relative path, so it is created
# in the kernel's current working directory).
faiss.write_index(index, 'faiss.index') 

In [None]:
# Runs the operating system's `shutdown` command through the shell.
# NOTE(review): presumably meant to power the machine off once the long-running
# job has finished — confirm. Because this is a regular cell, "Run All" will
# execute it too; the return code of the command is ignored.
os.system("shutdown")