In [8]:
import os
import json
import numpy as np
import pandas as pd
import faiss
import torch
from tqdm import tqdm
import scml

In [2]:
faiss_c: int = 50
faiss_k: int = 100
faiss_nlist: int = 1000        

In [3]:
tim = scml.Timer()
tim.start()
percentiles=[.01, .05, .1, .2, .3, .4, .5, .6, .7, .8, .9, .95, .99]
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
pd.set_option("use_inf_as_na", True)
pd.set_option("max_info_columns", 9999)
pd.set_option("display.max_columns", 9999)
pd.set_option("display.max_rows", 9999)
pd.set_option('max_colwidth', 9999)
tqdm.pandas()
scml.seed_everything()

In [4]:
ckp = torch.load("models/word2vec/20230127_031513/lightning_logs/version_0/checkpoints/epoch=1-step=12110.ckpt")
print(ckp)

{'epoch': 1, 'global_step': 12110, 'pytorch-lightning_version': '1.7.7', 'state_dict': OrderedDict([('word_embeddings.weight', tensor([[-0.4697, -0.0902, -0.2578,  ..., -0.4001,  0.0420,  0.0653],
        [ 0.4781, -0.0490,  0.7954,  ...,  0.7575, -0.1404,  0.7434],
        [-0.3789,  0.1923, -0.2426,  ...,  0.0597, -0.1023,  0.2375],
        ...,
        [ 0.1751, -0.0444, -0.3889,  ..., -0.2694,  0.2358, -0.2691],
        [ 0.4944, -0.0559,  0.0382,  ..., -0.2361, -0.2321, -0.2804],
        [-0.6416,  0.0252,  0.3920,  ..., -0.2325,  0.4232,  0.0895]],
       device='cuda:0')), ('type_embeddings.weight', tensor([[-2.8271e-02,  2.5899e-02, -3.8499e-03, -1.2672e-02,  1.8102e-02,
         -7.6821e-03,  6.0259e-03,  2.0547e-02,  8.2483e-03,  2.2127e-02,
         -3.4239e-02, -3.6754e-02, -2.0670e-03,  6.0558e-03, -1.1508e-02,
         -1.6790e-02, -6.7824e-03,  3.1600e-02, -9.1243e-03,  2.4465e-02,
         -4.9653e-03, -3.1446e-02,  1.8366e-02, -2.0687e-02,  4.1522e-02,
         -2.8023

In [5]:
word_em = ckp["state_dict"]["word_embeddings.weight"].cpu().numpy()
type_em = ckp["state_dict"]["type_embeddings.weight"].cpu().numpy()
em = []
for wem in tqdm(word_em):
    for tem in type_em:
        em.append(wem + tem)
em = np.array(em, dtype=np.float32)
print(f"word_em.shape={word_em.shape}\ntype_em.shape={type_em.shape}\nem.shape={em.shape}")
assert word_em.shape[0] * type_em.shape[0] == em.shape[0]

100%|██████████████████████| 1855603/1855603 [00:06<00:00, 305818.16it/s]


word_em.shape=(1855603, 64)
type_em.shape=(3, 64)
em.shape=(5566809, 64)


In [9]:
%%time
faiss.normalize_L2(em)
d = em.shape[1]
m = 8  # number of subquantizers
quantizer = faiss.IndexFlatIP(d)
# 8 specifies that each sub-vector is encoded as 8 bits
index = faiss.IndexIVFPQ(quantizer, d, faiss_nlist, m, 8)
index.verbose = True
index.train(em)
index.add(em)
filepath = f"output/item.index"
faiss.write_index(index, filepath)
print(f"Saved {filepath}")

Saved output/item.index
Wall time: 22.5 s


In [10]:
index = faiss.read_index(filepath)
assert index.is_trained and index.ntotal == em.shape[0]
# sanity check
index.nprobe = 1
k = 4
n = 1000
distances, ids = index.search(em[:n], k)
for i in range(n):
    if ids[i][0] != i:
        print(f"Expecting id {i} but found {ids[i][0]}")

Expecting id 5 but found 4
Expecting id 275 but found 274
Expecting id 488 but found 487
Expecting id 875 but found 874


In [11]:
tim.stop()
print(f"Total time taken {str(tim.elapsed)}")

Total time taken 0:00:58.231592
