In [None]:
import pickle
import numpy as np
import faiss
import json
import os
from tqdm import tqdm

from sentence_transformers import SentenceTransformer

## Load text data
In this section, we load the HotpotQA data: the corpus documents (facts) and the queries. Keep in mind that you may have to change the paths and the loading code to match your own data.

In [2]:
dataset_name = "hotpot"

# Corpus: one JSON object per line; only its "text" field is kept.
facts_path = "corpus.jsonl"
with open(facts_path, encoding="utf-8") as f:
    # Iterate the handle directly (no intermediate list of raw lines) and skip
    # blank lines, which json.loads would otherwise reject.
    fact_sets = [json.loads(line) for line in f if line.strip()]
facts = [fact_set["text"] for fact_set in fact_sets]

# Queries: same layout. A dedicated path variable keeps `facts_path` pointing
# at the corpus file instead of being silently overwritten.
queries_path = "queries.jsonl"
with open(queries_path, encoding="utf-8") as f:
    questions_sets = [json.loads(line) for line in f if line.strip()]
questions = [questions_set["text"] for questions_set in questions_sets]

# Encode dataset
Next, the pre-trained model is loaded and the dataset is encoded. Encoding runs on the GPU if one is available; even so, it takes a few hours and a lot of RAM - around 15 GB for HotpotQA. If your resources are limited, the section below shows how to store and load the embeddings as smaller batch files.

In [15]:
model_name = "multi-qa-MiniLM-L6-cos-v1"
# No explicit device: SentenceTransformer then uses a GPU when one is available
# and falls back to the CPU otherwise, instead of failing on machines without CUDA.
model = SentenceTransformer(model_name)

In [4]:
# Encode every fact text. No device override is passed, so encoding runs on the
# device the model was loaded on rather than unconditionally requiring CUDA.
embeddings = model.encode(facts, show_progress_bar=True)

Batches:   0%|          | 0/163542 [00:00<?, ?it/s]

### Save embeddings
The embeddings are saved to disk for later use.

In [5]:
# Write the complete embedding array to a single pickle file whose name
# encodes the dataset and the model that produced it.
embeddings_file = f'emb_{dataset_name}_{model_name}.pickle'
with open(embeddings_file, 'wb') as out_file:
    pickle.dump(embeddings, out_file, protocol=pickle.HIGHEST_PROTOCOL)

### Dump/load embeddings in batch files
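If you don't have enough RAM to keep the whole embedding database in memory at once, the embeddings can be saved as several batch files, each holding roughly `batch_volume` bytes. The index can then also be built from these files one batch at a time. Note that `save_batched_embs` itself still expects the full embedding array as input.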

In [116]:
def save_batched_embs(embeddings, directory, batch_volume):
    """Split an embedding matrix into row batches and pickle each batch to its own file.

    Args:
        embeddings: 2-D array exposing ``.shape`` and row slicing (e.g. a NumPy
            array). Its ``itemsize`` is used for size estimates; 4 bytes is
            assumed when the attribute is missing.
        directory: Folder that receives the batch files; created if it does
            not exist yet.
        batch_volume: Target size of one batch file in bytes.

    Batch ``i`` is written to ``<directory>/<i zero-padded>.pickle``. The padding
    is at least three digits and grows with the number of batches so that a
    plain lexicographic sort of the file names is also the numeric batch order.
    """
    n_rows, dim = embeddings.shape[0], embeddings.shape[1]
    row_bytes = dim * getattr(embeddings, "itemsize", 4)

    # Total payload in GB (independent of the chosen batch volume).
    required_space_gb = n_rows * row_bytes / 1e9
    # Rows per file: enough rows to reach batch_volume bytes, never fewer than one.
    batch_size = max(1, int(np.ceil(batch_volume / row_bytes)))
    n_batches = int(np.ceil(n_rows / batch_size))
    print(
        f"Required space: {required_space_gb} GB, {n_batches} files with up to "
        f"{batch_size} embeddings each (about {batch_volume / 1000000000} GB per file)"
    )

    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(directory, exist_ok=True)

    # Three digits as before, widened when there are more than 1000 batches.
    width = max(3, len(str(max(n_batches - 1, 0))))

    # Iterating a range lets tqdm know the total and show a real progress bar.
    for i in tqdm(range(n_batches)):
        lower = batch_size * i
        upper = lower + batch_size  # slicing clips the last, shorter batch
        batch_path = os.path.join(directory, f"{str(i).zfill(width)}.pickle")
        with open(batch_path, 'wb') as batch_file:
            pickle.dump(embeddings[lower:upper], batch_file, protocol=pickle.HIGHEST_PROTOCOL)

def load_all_batched_embs(directory):
    """Load every batch file in ``directory`` and stack them into one array.

    Files are read in lexicographic name order. Hidden entries (names starting
    with a dot, such as ``.DS_Store`` or ``.ipynb_checkpoints``) and
    sub-directories are skipped, since they are not batch files.

    Args:
        directory: Folder containing the pickled batch files.

    Returns:
        The batches concatenated along the first axis.

    Raises:
        ValueError: If the directory contains no batch files.
    """
    batch_names = sorted(
        name
        for name in os.listdir(directory)
        if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
    )
    if not batch_names:
        # np.concatenate would raise ValueError as well, but with an opaque message.
        raise ValueError(f"No embedding batch files found in {directory!r}")

    batches = []
    for name in tqdm(batch_names):
        # NOTE: pickle.load can execute arbitrary code; only read directories you trust.
        with open(os.path.join(directory, name), 'rb') as batch_file:
            batches.append(pickle.load(batch_file))
    return np.concatenate(batches)

# generator returning batches of embeddings from directory
def batched_embs_generator(directory):
    """Yield the pickled embedding batches stored in ``directory`` one at a time.

    Files are visited in lexicographic name order and each one is loaded only
    when the consumer asks for it, so a single batch is held by the generator
    at any moment. Hidden entries (names starting with a dot, such as
    ``.DS_Store`` or ``.ipynb_checkpoints``) and sub-directories are skipped.

    Args:
        directory: Folder containing the pickled batch files.

    Yields:
        The unpickled content of each batch file.
    """
    batch_names = sorted(
        name
        for name in os.listdir(directory)
        if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
    )
    for name in batch_names:
        # NOTE: pickle.load can execute arbitrary code; only read directories you trust.
        with open(os.path.join(directory, name), 'rb') as batch_file:
            yield pickle.load(batch_file)

In [117]:
# Settings for the batched export; the directory name encodes dataset and model.
model_name = "multi-qa-MiniLM-L6-cos-v1"
batch_volume = 1_000_000_000  # size in bytes
batch_directory = f"./emb_{dataset_name}_{model_name}_b"

save_batched_embs(
    embeddings,
    directory=batch_directory,
    batch_volume=batch_volume,
)

### Load embeddings
Embeddings can be loaded if they have been saved before.

In [None]:
# Read the embedding array back from its single-file pickle.
# pickle.load can run arbitrary code, so only open files you created yourself.
embeddings_file = f'emb_{dataset_name}_{model_name}.pickle'
with open(embeddings_file, 'rb') as in_file:
    embeddings = pickle.load(in_file)

#### Batched version

In [None]:
# Read every batch file back and join them into a single in-memory array.
embeddings = load_all_batched_embs(directory=batch_directory)

# Create and train index
We create an IVF-PQ index for the embeddings with fixed parameters (100 inverted lists, 8 sub-quantizers, 8 bits per code). This will take a few minutes.

In [6]:
# Embedding dimensionality: the size of the second axis of the embedding array.
d = embeddings.shape[1]

In [7]:
# IVF-PQ settings, named after the positional arguments of faiss.IndexIVFPQ.
n_lists = 100           # number of inverted lists (coarse clusters)
n_subquantizers = 8     # product-quantizer sub-vectors per embedding
bits_per_code = 8       # bits used to encode each sub-vector

# A flat L2 index serves as the coarse quantizer of the IVF-PQ index.
quant = faiss.IndexFlatL2(d)
index_ivfpq = faiss.IndexIVFPQ(quant, d, n_lists, n_subquantizers, bits_per_code)

# Train on the full embedding set, then add the same vectors to the index.
index_ivfpq.train(embeddings)
index_ivfpq.add(embeddings)

# Train index from batched embedding files

In [118]:
# Build the index from the batch files, one batch at a time.
gen = batched_embs_generator(batch_directory)
index_ivfpq = None
for emb_batch in tqdm(gen):
    if index_ivfpq is None:
        # Create the index lazily so the embedding length comes from the data
        # itself rather than from a hard-coded value tied to one model.
        d = emb_batch.shape[1]  # embedding length
        quant = faiss.IndexFlatL2(d)
        index_ivfpq = faiss.IndexIVFPQ(quant, d, 100, 8, 8)
        # Training uses the first batch only; later batches are just added.
        index_ivfpq.train(emb_batch)
    index_ivfpq.add(emb_batch)

if index_ivfpq is None:
    # Nothing was read: fail here instead of saving an empty, untrained index later.
    raise FileNotFoundError(f"No embedding batches found in {batch_directory!r}")

17it [00:46,  2.73s/it]


### Save index

In [8]:
index_name = "index_ivfpq"
# Build the file name from dataset_name, as the embedding files do, instead of
# a hard-coded "hotpot"; with dataset_name = "hotpot" the name is unchanged.
# NOTE(review): pickling a faiss index depends on pickle support in the
# installed faiss build; faiss.write_index is the library's own serializer and
# is worth considering if this dump fails - confirm what downstream loaders expect.
with open(f'{index_name}_{dataset_name}_{model_name}.pickle', 'wb') as f:
    pickle.dump(index_ivfpq, f, protocol=pickle.HIGHEST_PROTOCOL)