In [1]:
import os
import time
import json
import numpy as np
import pandas as pd
import faiss

from tqdm.notebook import tqdm
import math
from itertools import product

from dotenv import load_dotenv

In [2]:
# Load environment variables from the project-level .env file (path is relative to this notebook).
load_dotenv(dotenv_path="../.env")

True

In [3]:
# Root directory of the embedding files. Set FAISS_DATA_PATH (e.g. in the ../.env file
# loaded above) to point elsewhere instead of editing this user-specific default.
# os.path.join(..., "") guarantees a trailing separator, since file names are appended to it.
data_path = os.path.join(os.environ.get("FAISS_DATA_PATH", "/storage/st079069/dataset_for_faiss/"), "")

In [4]:
# Input files: chunk embeddings (JSON tables keyed by document id) and question embeddings (.npy).
# os.path.join keeps these paths valid whether or not data_path ends with a separator.
chunks_base_path = os.path.join(data_path, "chunks_embeddings_jina_base.json")
chunks_pass_path = os.path.join(data_path, "chunks_embeddings_jina_passage.json")

questions_base_path = os.path.join(data_path, "questions_embeddings_jina_base.npy")
questions_query_path = os.path.join(data_path, "questions_embeddings_jina_query.npy")

In [5]:
# Chunk-embedding tables: one row per document id (orient="index"). Each cell appears to hold
# one chunk embedding, with missing cells for documents that have fewer chunks.
chunks_base, chunks_pass = (
    pd.read_json(path, orient='index') for path in (chunks_base_path, chunks_pass_path)
)

In [6]:
# Question embedding matrices; slices of these are used as search queries in the benchmark.
# "query" presumably denotes the query-side jina encoding — TODO confirm.
questions_base_emb, questions_query_emb = (
    np.load(npy_path) for npy_path in (questions_base_path, questions_query_path)
)

In [7]:
# Flatten the per-document chunk embeddings into one matrix, and record for each
# document id the contiguous range of matrix rows that holds its chunks.
# One missing-value test (pd.notna) drives both the per-row counts and the flattening.
# Counting with `dropna()` (drops None *and* NaN) but flattening with `is not None`
# (drops only None) would misalign ids and rows as soon as a NaN cell appears.
chunks_base_mask = chunks_base.notna().to_numpy()
chunks_base_bounds = np.concatenate(([0], np.cumsum(chunks_base_mask.sum(axis=1))))
chunks_base_ids = {
    doc_id: np.arange(start, stop)
    for doc_id, start, stop in zip(chunks_base.index, chunks_base_bounds[:-1], chunks_base_bounds[1:])
}

# Boolean-mask indexing visits cells in row-major order, which matches the bounds above.
chunks_base_emb = np.array(chunks_base.to_numpy()[chunks_base_mask].tolist())

0it [00:00, ?it/s]

  0%|          | 0/226450 [00:00<?, ?it/s]

In [8]:
# Flatten the per-document passage-chunk embeddings into one matrix, and record for each
# document id the contiguous range of matrix rows that holds its chunks.
# One missing-value test (pd.notna) drives both the per-row counts and the flattening.
# Counting with `dropna()` (drops None *and* NaN) but flattening with `is not None`
# (drops only None) would misalign ids and rows as soon as a NaN cell appears.
chunks_pass_mask = chunks_pass.notna().to_numpy()
chunks_pass_bounds = np.concatenate(([0], np.cumsum(chunks_pass_mask.sum(axis=1))))
chunks_pass_ids = {
    doc_id: np.arange(start, stop)
    for doc_id, start, stop in zip(chunks_pass.index, chunks_pass_bounds[:-1], chunks_pass_bounds[1:])
}

# Boolean-mask indexing visits cells in row-major order, which matches the bounds above.
chunks_pass_emb = np.array(chunks_pass.to_numpy()[chunks_pass_mask].tolist())

0it [00:00, ?it/s]

  0%|          | 0/226450 [00:00<?, ?it/s]

In [32]:
def name_w_params(name: str, params: list[object]) -> str:
    """Build a configuration label such as ``"hnsw_16_32_64"``.

    Classes (e.g. a faiss index class like ``faiss.IndexFlatIP``) are rendered by their
    ``__name__``. faiss index *instances* are rendered by their class name. Everything
    else goes through ``str()``.

    Args:
        name: Prefix of the label (usually the group name).
        params: Parameter values in order; any iterable works (tuples from
            ``itertools.product`` included).

    Returns:
        The prefix and the parameter labels joined with ``"_"``.
    """
    def _label(p: object) -> str:
        # Index classes are passed uninstantiated, so check for a type first;
        # isinstance(cls, faiss.Index) is False for a class object.
        if isinstance(p, type):
            return p.__name__
        if isinstance(p, faiss.Index):
            return type(p).__name__
        return str(p)

    return "_".join([name, *map(_label, params)])

In [33]:
# Vector dimensionality shared by every index built below.
embedding_dim = chunks_pass_emb.shape[1]
# The benchmark feeds the *base* chunk and question embeddings to these indexes, so verify
# that every loaded embedding matrix has the same width as the passage embeddings.
assert chunks_base_emb.shape[1] == embedding_dim, "chunks_base_emb has a different dimensionality"
assert questions_base_emb.shape[1] == embedding_dim, "questions_base_emb has a different dimensionality"
assert questions_query_emb.shape[1] == embedding_dim, "questions_query_emb has a different dimensionality"

In [34]:
# Exact brute-force L2 index over `embedding_dim`-dimensional vectors; registered below as the "flat" baseline group.
flat_index = faiss.IndexFlatL2(embedding_dim)

In [35]:
# Registry of candidate indexes: group name -> {configuration name -> freshly constructed index}.
# The exact flat index forms its own single-entry baseline group.
indexes_with_params = dict(flat=dict(flat=flat_index))

In [36]:
# HNSW parameters swept in the grid search
M_list = [16, 32, 64]  # M: number of graph neighbours kept per vertex
ef_search_list = [16, 32, 64]  # efSearch: candidate-list size explored at query time
ef_construction_list = [16, 32, 64]  # efConstruction: candidate-list size explored while building the graph

In [37]:
indexes_with_params["hnsw"] = {}

# One HNSW-Flat index per (M, efSearch, efConstruction) combination,
# keeping only combinations where efSearch does not exceed efConstruction.
for params in product(M_list, ef_search_list, ef_construction_list):
    M, ef_search, ef_construction = params
    if ef_search <= ef_construction:
        index = faiss.IndexHNSWFlat(embedding_dim, M)
        index.hnsw.efConstruction = ef_construction
        index.hnsw.efSearch = ef_search
        indexes_with_params["hnsw"][name_w_params("hnsw", params)] = index

In [38]:
# SQ parameters: scalar-quantizer codecs swept for HNSW-SQ.
# The *_uniform variants use a single value range shared by all dimensions.
quantizer_list = [
    getattr(faiss.ScalarQuantizer, qtype_name)
    for qtype_name in ("QT_8bit", "QT_4bit", "QT_8bit_uniform", "QT_4bit_uniform")
]

In [39]:
indexes_with_params["hnsw_sq"] = {}

# HNSW graph over scalar-quantized vectors: the HNSW grid crossed with every SQ codec,
# again skipping combinations where efSearch exceeds efConstruction.
for params in product(M_list, ef_search_list, ef_construction_list, quantizer_list):
    M, ef_search, ef_construction, scalar_quantizer = params
    if ef_search <= ef_construction:
        index = faiss.IndexHNSWSQ(embedding_dim, scalar_quantizer, M)
        index.hnsw.efConstruction = ef_construction
        index.hnsw.efSearch = ef_search
        indexes_with_params["hnsw_sq"][name_w_params("hnsw_sq", params)] = index

In [40]:
# PQ parameters
M_pq_list = [4, 8, 16, 32]  # sub-quantizers per vector; each must divide embedding_dim
nbits_list = [6, 8, 9]  # bits per sub-quantizer code; training wants >= 39 * 2**nbits points

In [41]:
indexes_with_params["pq"] = {}

# Plain product-quantizer indexes: one per (sub-quantizer count, bits per code) pair.
for params in product(M_pq_list, nbits_list):
    M_pq, nbits = params
    # The vector is split into M_pq equal-width sub-vectors.
    assert embedding_dim % M_pq == 0
    index = faiss.IndexPQ(embedding_dim, M_pq, nbits)
    indexes_with_params["pq"][name_w_params("pq", params)] = index

In [42]:
indexes_with_params["hnsw_pq"] = {}

# HNSW graph storing PQ codes: the PQ grid crossed with the HNSW grid.
for params in tqdm(product(M_pq_list, nbits_list, M_list, ef_search_list, ef_construction_list)):
    M_pq, nbits, M, ef_search, ef_construction = params
    assert embedding_dim % M_pq == 0
    if ef_search <= ef_construction:
        index = faiss.IndexHNSWPQ(embedding_dim, M_pq, M, nbits)
        index.hnsw.efConstruction = ef_construction
        index.hnsw.efSearch = ef_search
        indexes_with_params["hnsw_pq"][name_w_params("hnsw_pq", params)] = index

# Number of HNSW-PQ configurations that survived the filter.
print(len(indexes_with_params["hnsw_pq"]))

0it [00:00, ?it/s]

216


In [43]:
# IVF parameters
# Coarse-quantizer classes, instantiated once per configuration below.
# NOTE(review): IndexFlatIP is used as the coarse quantizer of IVF indexes whose metric is L2
# (the IndexIVFFlat default / explicit METRIC_L2), so lists are ranked by inner product while
# final distances are L2 — confirm this comparison is intended.
quantizer_ivf_list = [faiss.IndexFlatIP, faiss.IndexFlatL2]
nlist_list = [32, 64, 128, 256]  # inverted lists; k-means training wants >= 39 * nlist points
nprobe_list = [16, 32, 64]  # inverted lists visited per query

In [44]:
indexes_with_params["ivf"] = {}

# IVF-Flat per (quantizer class, nlist, nprobe), skipping nprobe values larger than nlist.
for params in product(quantizer_ivf_list, nlist_list, nprobe_list):
    quantizer, nlist, nprobe = params
    if nprobe <= nlist:
        coarse = quantizer(embedding_dim)
        index = faiss.IndexIVFFlat(coarse, embedding_dim, nlist)
        index.nprobe = nprobe

        # The configuration key uses the quantizer's class name, not its repr.
        quantizer = quantizer.__name__
        indexes_with_params["ivf"][name_w_params("ivf", [quantizer, nlist, nprobe])] = index

In [45]:
indexes_with_params["ivf_pq"] = {}

# IVF with product-quantized codes: quantizer class x nlist x nprobe x PQ sub-quantizers x bits.
for params in tqdm(product(quantizer_ivf_list, nlist_list, nprobe_list, M_pq_list, nbits_list)):
    quantizer, nlist, nprobe, M, nbits = params  # here M counts PQ sub-quantizers, not HNSW links
    assert embedding_dim % M == 0
    if nprobe <= nlist:
        coarse = quantizer(embedding_dim)
        index = faiss.IndexIVFPQ(coarse, embedding_dim, nlist, M, nbits, faiss.METRIC_L2)
        index.nprobe = nprobe

        # The configuration key uses the quantizer's class name, not its repr.
        quantizer = quantizer.__name__
        indexes_with_params["ivf_pq"][name_w_params("ivf_pq", [quantizer, nlist, nprobe, M, nbits])] = index

0it [00:00, ?it/s]

In [46]:
indexes_with_params["ivfHNSW"] = {}

# IVF-Flat whose coarse quantizer is an HNSW graph, swept over IVF and HNSW parameters.
for params in tqdm(product(nlist_list, nprobe_list, M_list, ef_search_list, ef_construction_list)):
    nlist, nprobe, M, ef_search, ef_construction = params
    if ef_search <= ef_construction and nprobe <= nlist:
        quantizer = faiss.IndexHNSWFlat(embedding_dim, M)
        quantizer.hnsw.efConstruction = ef_construction
        quantizer.hnsw.efSearch = ef_search

        index = faiss.IndexIVFFlat(quantizer, embedding_dim, nlist, faiss.METRIC_L2)
        index.nprobe = nprobe
        # Mode 2: train the centroids with flat k-means, then add them to the HNSW quantizer.
        index.quantizer_trains_alone = 2

        indexes_with_params["ivfHNSW"][name_w_params("ivfHNSW", params)] = index

0it [00:00, ?it/s]

In [47]:
k: int = 16  # neighbours retrieved per query when scoring each index

In [48]:
def get_top_k(index: faiss.Index, queries, k):
    """Search ``index`` for the ``k`` nearest neighbours of every row of ``queries``.

    Args:
        index: A populated faiss index.
        queries: Query matrix; ``queries.shape[0]`` is the number of queries.
        k: Number of neighbours to return per query.

    Returns:
        ``(D, I, seconds_per_query)``: the distance and id matrices from
        ``index.search`` and the search time divided by the number of queries.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump with
    # system clock adjustments and is too coarse for short timings.
    start = time.perf_counter()
    D, I = index.search(queries, k)
    elapsed = time.perf_counter() - start
    return D, I, elapsed / queries.shape[0]


def dcg(scores):
    """Discounted cumulative gain with exponential gain.

    Computes ``sum_i (2**s_i - 1) / log2(i + 2)`` over 0-based ranks ``i``.
    An empty ``scores`` yields 0.
    """
    discounts = np.log2(np.arange(len(scores)) + 2)
    base = np.full_like(scores, 2)
    gains = np.power(base, scores) - 1
    return np.sum(gains / discounts)


def ndcg_with_mismatch(true_indices, true_distances, approx_indices, verbose=False, decay=0.01):
    """nDCG of an approximate neighbour list against exact ground truth for one query.

    A ground-truth neighbour's relevance is ``exp(-decay * distance)``. Any id returned by
    the approximate search that is absent from the ground truth gets relevance 0, so the
    two id lists may differ ("mismatch").

    Args:
        true_indices: Ids returned by the exact search.
        true_distances: Distances aligned with ``true_indices``.
        approx_indices: Ids returned by the approximate search for the same query.
        verbose: Print a warning if ``true_distances`` contains NaN or negative values.
        decay: Rate of the exponential distance-to-relevance mapping (0.01 = previous fixed value).

    Returns:
        ``DCG(approx) / DCG(ideal)``, or 0 when the ideal DCG is not positive.
    """
    true_distances = np.asarray(true_distances)
    true_relevance = dict(zip(true_indices, np.exp(-decay * true_distances)))

    if verbose:
        # Warn when *any* distance is NaN (not when "not all" of them are).
        if np.any(np.isnan(true_distances)):
            print("WARNING: there is an nan distance")
        if np.any(true_distances < 0):
            print("WARNING: there is a < 0 distance")

    # Relevance of each returned id, in the approximate ranking order.
    approx_relevance_scores = [true_relevance.get(idx, 0) for idx in approx_indices]

    # Ideal ranking: ground-truth relevances in descending order, cut or zero-padded
    # to the length of the approximate list so both DCGs cover the same k.
    n = len(approx_relevance_scores)
    ideal_relevance_scores = sorted(true_relevance.values(), reverse=True)[:n]
    ideal_relevance_scores.extend([0] * (n - len(ideal_relevance_scores)))

    dcg_approx = dcg(approx_relevance_scores)
    dcg_ideal = dcg(ideal_relevance_scores)

    return dcg_approx / dcg_ideal if dcg_ideal > 0 else 0


def calculate_mean_ndcg_mismatch(ideal_index: faiss.Index, index: faiss.Index, queries: np.ndarray, k: int = 10):
    """Mean nDCG@k of ``index`` against exact results from ``ideal_index``, plus search latency.

    Both indexes are searched with the same ``queries`` (ground truth first). Each query's
    result lists are scored with ``ndcg_with_mismatch``.

    Returns:
        ``(mean_ndcg, mean_time)``: nDCG averaged over queries, and the per-query search
        time of ``index`` as reported by ``get_top_k``. The ground-truth search time is discarded.
    """
    exact_dist, exact_ids, _ = get_top_k(ideal_index, queries, k)
    _, found_ids, mean_time = get_top_k(index, queries, k)

    per_query_ndcg = []
    for q_exact_ids, q_exact_dist, q_found_ids in zip(exact_ids, exact_dist, found_ids):
        per_query_ndcg.append(ndcg_with_mismatch(q_exact_ids, q_exact_dist, q_found_ids))

    return np.mean(per_query_ndcg), mean_time

In [49]:
def calculate_disk_usage(index: faiss.Index, index_name: str, index_dir: str = "../data/index/"):
    """Return the on-disk size of ``index`` in MiB.

    The index is written to ``<index_dir>/<index_name>.index`` with ``faiss.write_index``.
    The file's size is read and the file is then deleted. Deletion also happens if writing
    or measuring raises.

    Args:
        index: The faiss index to measure.
        index_name: Base name for the temporary file.
        index_dir: Scratch directory for the temporary file; created if missing.

    Returns:
        File size in MiB (bytes / 1024**2).
    """
    os.makedirs(index_dir, exist_ok=True)
    file_name = os.path.join(index_dir, index_name + ".index")
    try:
        faiss.write_index(index, file_name)
        index_size = os.path.getsize(file_name)
    finally:
        # Never leave the scratch file behind, even on failure.
        if os.path.exists(file_name):
            os.remove(file_name)

    return index_size / (1024 * 1024)

In [50]:
def train_index(index: faiss.Index, data):
    """Train ``index`` on ``data`` and return the elapsed time in seconds.

    Uses ``time.perf_counter`` (monotonic, high resolution) rather than ``time.time``,
    which can jump with clock adjustments.
    """
    start = time.perf_counter()
    index.train(data)
    return time.perf_counter() - start

In [51]:
def construct_index(index: faiss.Index, data):
    """Add ``data`` to an already-trained ``index`` and return the elapsed time in seconds.

    Uses ``time.perf_counter`` (monotonic, high resolution) rather than ``time.time``.

    Raises:
        AssertionError: If ``index`` is not trained yet.
    """
    assert index.is_trained
    start = time.perf_counter()
    index.add(data)
    return time.perf_counter() - start

In [52]:
def round_down(a: float, ndigits: int = 2) -> float:
    """Truncate ``a`` toward negative infinity to ``ndigits`` decimal places.

    Examples: ``2.349 -> 2.34``, ``-1.001 -> -1.01``, ``0.29 -> 0.29``.

    The scaled value is first rounded to 9 decimals so binary representation error
    cannot drop an exact boundary below itself. Without this, ``0.29 * 100`` is
    ``28.999999999999996``, which floors to 28 and gives 0.28. As a consequence, values
    within ~1e-9 (in scaled units) below a step are treated as reaching it. NaN and ±inf
    are returned unchanged instead of making ``math.floor`` raise.

    Args:
        a: Value to truncate.
        ndigits: Number of decimal places to keep (default 2, the original behaviour).
    """
    if not math.isfinite(a):
        return a
    scale = 10 ** ndigits
    return math.floor(round(a * scale, 9)) / scale

In [53]:
# Benchmark subset sizes — raise them for a fuller (slower) run.
N_DATABASE_VECTORS = 20_000  # chunk vectors used both to train and to populate every index
N_QUERY_VECTORS = 3_000  # question vectors used as search queries

# faiss indexes operate on float32 vectors; converting once here means every
# train/add/search call in the benchmark loop receives contiguous float32 input.
embeddings_np_train = np.ascontiguousarray(chunks_base_emb[:N_DATABASE_VECTORS], dtype=np.float32)
embeddings_np_test = np.ascontiguousarray(questions_base_emb[:N_QUERY_VECTORS], dtype=np.float32)

In [54]:
# Benchmark every configuration group:
#   - build each index on the training subset;
#   - score it against exact L2 search;
#   - write one CSV per group plus an all.csv summary.
RESULTS_DIR = "../data/index_res_base/"
RESULT_COLUMNS = ["index", "NDCG", "size_mb", "mean_time_query_ms", "time_train_ms", "time_construct_ms"]
MAX_FULL_GROUP_SIZE = 50  # larger groups are subsampled to max(20 %, this many) configurations
SAMPLING_SEED = 0  # fixes which configurations the subsample picks, so reruns agree

os.makedirs(RESULTS_DIR, exist_ok=True)
sampling_rng = np.random.default_rng(SAMPLING_SEED)

# The exact ground truth is the same for every configuration, so build it once
# instead of re-adding the whole training set on every iteration.
ground_truth_index = faiss.IndexFlatL2(embedding_dim)
ground_truth_index.add(embeddings_np_train)

results: list[list] = list()

for group_name in indexes_with_params.keys():
    group_csv = os.path.join(RESULTS_DIR, group_name + ".csv")
    if os.path.exists(group_csv):
        # Finished groups are not recomputed, but their rows are kept so all.csv stays complete.
        cached_rows = pd.read_csv(group_csv, index_col=0)
        results.extend(cached_rows[RESULT_COLUMNS].values.tolist())
        continue

    group_results: list[list] = list()
    group_best: list = ["", -1, -1, -1, -1]

    list_indexes = list(indexes_with_params[group_name].keys())
    if len(list_indexes) > MAX_FULL_GROUP_SIZE:
        n_sampled = max(len(list_indexes) // 5, MAX_FULL_GROUP_SIZE)
        list_indexes = np.sort(sampling_rng.choice(list_indexes, n_sampled, replace=False))

    pbar = tqdm(list_indexes)
    for name in pbar:
        index = indexes_with_params[group_name][name]
        pbar.set_description(name)
        # -1 marks indexes that already report is_trained, i.e. no training was run.
        train_time = -1 if index.is_trained else train_index(index, embeddings_np_train) * 1000
        assert index.is_trained
        construct_time = construct_index(index, embeddings_np_train) * 1000
        assert index.ntotal > 0
        mean_ndcg, mean_time = calculate_mean_ndcg_mismatch(ground_truth_index, index, embeddings_np_test, k)
        index_size_mb = calculate_disk_usage(index, name)
        if mean_ndcg > 1:
            raise ValueError(f"nDCG must not exceed 1, got {mean_ndcg} for {name}")
        # Empty the index before moving on so stored vectors do not accumulate in memory.
        index.reset()

        res = [
            name,
            round_down(mean_ndcg),
            round_down(index_size_mb),
            round_down(mean_time * 1000),
            round_down(train_time),
            round_down(construct_time),
        ]
        group_results.append(res)
        if res[1] > group_best[1]:
            group_best = res
            pbar.set_postfix({"NDCG": res[1], "index": name})

    df = pd.DataFrame(group_results, columns=RESULT_COLUMNS)
    df.to_csv(group_csv)
    results.extend(group_results)
    pbar.set_description(group_name)

res_df = pd.DataFrame(results, columns=RESULT_COLUMNS)
res_df.to_csv(os.path.join(RESULTS_DIR, "all.csv"))

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/52 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]