In [1]:
import os
from pathlib import Path
import json
import numpy as np
import pandas as pd
import faiss
import torch
from tqdm import tqdm
import scml

In [2]:
faiss_nlist: int = 1000
model_file = "models/word2vec/20230129_000726/lightning_logs/version_0/checkpoints/epoch=5-step=76674.ckpt"
index_file = "output/item.index"
em_file = "output/em.npy"

In [3]:
tim = scml.Timer()
tim.start()
percentiles=[.01, .05, .1, .2, .3, .4, .5, .6, .7, .8, .9, .95, .99]
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
pd.set_option("use_inf_as_na", True)
pd.set_option("max_info_columns", 9999)
pd.set_option("display.max_columns", 9999)
pd.set_option("display.max_rows", 9999)
pd.set_option('max_colwidth', 9999)
tqdm.pandas()
scml.seed_everything()

In [4]:
ckp = torch.load(str(model_file))
print(ckp)

{'epoch': 5, 'global_step': 76674, 'pytorch-lightning_version': '1.7.7', 'state_dict': OrderedDict([('word_embeddings.weight', tensor([[ 0.2808, -0.1675, -0.0884,  ...,  0.4609,  0.0819, -0.6246],
        [-1.0734, -0.5885, -0.1136,  ...,  0.8975, -0.0801, -0.7135],
        [ 0.2650,  0.7763, -0.1735,  ..., -0.4250, -0.3904,  0.3702],
        ...,
        [-0.1343, -0.0453, -0.5101,  ..., -0.2197, -0.2182, -0.2723],
        [ 0.2942,  0.1648, -0.1698,  ..., -0.4534, -0.2354, -0.1568],
        [ 0.0597, -0.2618, -0.5689,  ...,  0.0629, -0.1870,  0.0186]],
       device='cuda:0')), ('type_embeddings.weight', tensor([[-6.0013e-02, -2.0535e-02,  2.1018e-01, -1.5173e-01,  1.8060e-01,
          1.5254e-01, -2.0961e-01,  8.7165e-02,  1.9033e-01,  8.5203e-02,
         -1.9330e-02,  1.0078e-03,  3.9269e-02,  8.3006e-02,  1.5975e-01,
         -1.4394e-01,  2.7029e-02,  1.0162e-01, -2.7421e-02,  3.4179e-02,
         -2.0325e-01,  4.9373e-02,  1.1454e-01, -1.0582e-01, -1.9734e-01,
         -6.0054

In [5]:
word_em = ckp["state_dict"]["word_embeddings.weight"].cpu().numpy()
type_em = ckp["state_dict"]["type_embeddings.weight"].cpu().numpy()
em = []
for wem in tqdm(word_em):
    for tem in type_em:
        em.append(wem + tem)
em = np.array(em, dtype=np.float32)
print(f"word_em.shape={word_em.shape}\ntype_em.shape={type_em.shape}\nem.shape={em.shape}")
assert word_em.shape[0] * type_em.shape[0] == em.shape[0]
faiss.normalize_L2(em)

100%|██████████████████████| 1855603/1855603 [00:06<00:00, 300042.53it/s]


word_em.shape=(1855603, 32)
type_em.shape=(3, 32)
em.shape=(5566809, 32)


In [6]:
%%time
np.save(em_file, em)

Wall time: 1.69 s


In [7]:
%%time
d = em.shape[1]
m = 8  # number of subquantizers
quantizer = faiss.IndexFlatIP(d)
# 8 specifies that each sub-vector is encoded as 8 bits
index = faiss.IndexIVFPQ(quantizer, d, faiss_nlist, m, 8)
index.verbose = True
index.train(em)
index.add(em)
faiss.write_index(index, str(index_file))

Wall time: 23.5 s


In [8]:
index = faiss.read_index(str(index_file))
assert index.is_trained and index.ntotal == em.shape[0]
# sanity check
index.nprobe = 1
k = 4
n = 1000
distances, ids = index.search(em[:n], k)
for i in range(n):
    if ids[i][0] != i:
        print(f"Expecting id {i} but found {ids[i][0]}")

Expecting id 677 but found 676


In [9]:
tim.stop()
print(f"Total time taken {str(tim.elapsed)}")
print(f"Saved {str(index_file)}")

Total time taken 0:00:38.906730
Saved output/item.index
