In [1]:
import pickle
import numpy as np
import time
import faiss
import json
import random
import sys
import os
from tqdm import tqdm

import torch.cuda
from sentence_transformers import SentenceTransformer

## Load hotpot data

In [19]:
facts_path = "corpus.jsonl"
with open(facts_path) as f:
    fact_sets = [json.loads(line) for line in f.readlines()]
facts = [fact_set["text"] for fact_set in fact_sets]

facts_path = "queries.jsonl"
with open(facts_path) as f:
    questions_sets = [json.loads(line) for line in f.readlines()]
questions = [questions_set["text"] for questions_set in questions_sets]

In [3]:
len(facts)

5233329

In [4]:
len(facts[0])

349

In [28]:
# check min and max length of fact, not necessary to run
max_l, min_l = 0, 999999999999
max_words, min_words = 0, 999999999999
for fact in facts:
    max_l, min_l = max(max_l, len(fact)), min(min_l, len(fact))
    c = fact.count(" ")
    max_words, min_words = max(max_words, c), min(min_words, c)
print(max_l, min_l)
print(max_words, min_words)

8237 0
1377 0


# Encode dataset

In [103]:
# model_name = 'msmarco-distilbert-base-v4'
# model_name = "multi-qa-MiniLM-L6-cos-v1"
model_name = 'all-mpnet-base-v2'
model = SentenceTransformer(model_name, device='cuda')

In [4]:
embeddings = model.encode(facts, device='cuda', show_progress_bar=True)

Batches:   0%|          | 0/163542 [00:00<?, ?it/s]

### Save embeddings

In [5]:
with open(f'emb_hotpot_{model_name}.pickle', 'wb') as f:
    pickle.dump(embeddings, f, protocol=pickle.HIGHEST_PROTOCOL)

### Load embeddings

In [None]:
model_name = 'all-mpnet-base-v2'
with open(f'emb_hotpot_{model_name}.pickle', 'rb') as f:
    embeddings = pickle.load(f)

# Create index

In [12]:
d = embeddings.shape[1]

In [9]:
# index_flatip = faiss.IndexFlatIP(embeddings.shape[1])
# index_flatip.add(embeddings)
# sys.getsizeof(index_flatip)

48

In [14]:
quant = faiss.IndexFlatL2(d)
index_ivfpqflat = faiss.IndexIVFPQ(quant, d, 100, 8, 8)
index_ivfpqflat.train(embeddings)
index_ivfpqflat.add(embeddings)

In [None]:
# quant = faiss.IndexFlatIP(embeddings.shape[1])
# index_ivfflat = faiss.IndexIVFFlat(quant, embeddings.shape[1], 100)
# index_ivfflat.train(embeddings)
# index_ivfflat.add(embeddings)

### Save index

In [13]:
index_name = "index_ivfpqflat"
with open(f'{index_name}_hotpot_{model_name}.pickle', 'wb') as f:
    pickle.dump(index_ivfpqflat, f, protocol=pickle.HIGHEST_PROTOCOL)

In [82]:
index_name = "index_ivfpqflat"
model_name = 'all-mpnet-base-v2'
with open(f'{index_name}_hotpot_{model_name}.pickle', 'rb') as f:
    index = pickle.load(f)

## Test retrievers against each other

In [164]:
def test_sample(sample, fact_sets, questions_sets, index1, index2, model1, model2, prtlvl=0):
    hits_sum = {'1': 0, '2': 0}
    versions = [(model1, index1), (model2, index2)]

    for i in sample:
        q, target_facts = questions_sets[i]["text"], questions_sets[i]["metadata"]["supporting_facts"]

        for vi, (model, index) in enumerate(versions):
            D, I = index.search(model.encode([q]), k)
            f = [fact_sets[j]['title'] for j in I[0]]
            ft = [fact[0] for fact in target_facts]
            hits = len(list(set(f) & set(ft)))
            hits_sum[str(vi+1)] += hits
            if prtlvl > 0:
                if vi == 0 or prtlvl > 2: print(q)
                print("  retrieved: " + str(f))
                print("  target: " + str(ft))
                print("  hits: " + str(hits))

                if prtlvl > 1:
                    for j in I[0]:
                        print(f"    -> {fact_sets[j]['title']} - {fact_sets[j]['text']}\n")


                if vi < len(versions) - 1:
                    print("  " + "*" * 50)
                else:
                    print("-" * 100)
    print(hits_sum)

In [83]:
model2_name = 'multi-qa-MiniLM-L6-cos-v1'
model2 = SentenceTransformer(model2_name, device='cuda')
index2_name = "index_ivfpq"

with open(f'{index2_name}_hotpot_{model2_name}.pickle', 'rb') as f:
    index2 = pickle.load(f)

In [167]:
k = 10
n_qs = 50
sample = random.sample(range(len(questions)), n_qs)

In [170]:
test_sample(sample, fact_sets, questions_sets, index, index2, model, model2, prtlvl=0)

{'1': 11, '2': 12}


# Dump/load embeddings in batch files

In [116]:
def save_batched_embs(embeddings, directory, batch_volume):
    required_space_gb = embeddings.shape[0] * embeddings.shape[1] * 4 / batch_volume
    batch_size =  int(np.ceil(batch_volume / (embeddings.shape[1] * 4))) # n of rows that is 1 gbs of data
    n_batches = int(np.ceil(embeddings.shape[0] / batch_size))
    print(f"Required space: {required_space_gb}, {n_batches} of files with {batch_size} embeddings each and filesize {batch_volume / 1000000000} GB")
    batch_indexes = [(batch_size*x, batch_size*(x+1)) for x in range(n_batches)]
    if not os.path.exists(directory):
        os.makedirs(directory)

    for i, (lower, upper) in tqdm(enumerate(batch_indexes)):
        fdir = f"{directory}/{str(i).zfill(3)}.pickle"
        emb_batch = embeddings[lower:upper]

        with open(fdir, 'wb') as f:
            pickle.dump(emb_batch, f, protocol=pickle.HIGHEST_PROTOCOL)

def load_all_batched_embs(directory):
    files = os.listdir(directory)
    files.sort()
    embeddings = []
    for f in tqdm(files):
        with open(f"{directory}/{f}", 'rb') as f:
            embeddings.append(pickle.load(f))
    return np.concatenate(embeddings)

# generator returning batches of embeddings from directory
def batched_embs_generator(directory):
    files = os.listdir(directory)
    files.sort()
    for f in files:
        with open(f"{directory}/{f}", 'rb') as f:
            yield pickle.load(f)

In [117]:
model_name = 'all-mpnet-base-v2'
batch_volume = 1000000000 # size in bytes
directory = f"./emb_hotpot_{model_name}_b"

save_batched_embs(embeddings, directory, batch_volume)

# Train index from batched embeddings files

In [118]:
# initiate index
d = 768 # embedding length
quant = faiss.IndexFlatL2(d)
index_batched = faiss.IndexIVFPQ(quant, d, 100, 8, 8)

# train in batches
gen = batched_embs_generator(directory)
for i, emb_batch in tqdm(enumerate(gen)):
    if i==0:
        index_batched.train(emb_batch)
    index_batched.add(emb_batch)

17it [00:46,  2.73s/it]
