In [1]:
import os
from pathlib import Path
import json
import numpy as np
import pandas as pd
import faiss
import torch
from tqdm import tqdm
import scml

In [2]:
faiss_nlist: int = 1000
model_file = "models/word2vec/20230131_030406/lightning_logs/version_0/checkpoints/epoch=17-step=45036.ckpt"
index_file = "output/m8_w7_i10.index"
em_file = "output/m8_w7_i10.npy"

In [3]:
tim = scml.Timer()
tim.start()
percentiles=[.01, .05, .1, .2, .3, .4, .5, .6, .7, .8, .9, .95, .99]
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
pd.set_option("use_inf_as_na", True)
pd.set_option("max_info_columns", 9999)
pd.set_option("display.max_columns", 9999)
pd.set_option("display.max_rows", 9999)
pd.set_option('max_colwidth', 9999)
tqdm.pandas()
scml.seed_everything()

In [4]:
ckp = torch.load(str(model_file))
print(ckp)

{'epoch': 17, 'global_step': 45036, 'pytorch-lightning_version': '1.7.7', 'state_dict': OrderedDict([('word_embeddings.weight', tensor([[ 0.4950, -0.4440,  0.6524,  ...,  0.7314,  0.6554,  0.0495],
        [ 0.6121, -1.4272,  1.2639,  ..., -0.4463,  2.4568,  0.7586],
        [ 1.5778,  0.1984, -0.1209,  ..., -0.0077,  0.2741, -0.9397],
        ...,
        [ 1.7135,  0.4625, -2.1386,  ...,  2.8617,  0.6442, -1.3882],
        [ 1.7531,  0.5239, -2.5799,  ...,  2.7967,  0.4245, -1.2093],
        [ 1.5163,  0.4189, -2.2055,  ...,  2.8239,  0.6003, -1.2564]],
       device='cuda:0')), ('type_embeddings.weight', tensor([[-6.9059e-02,  4.6103e-02,  9.0372e-02, -8.6478e-02, -9.6168e-02,
          7.5709e-02, -7.6599e-02, -1.3479e-02,  9.9222e-02, -2.9230e-02,
         -5.8211e-02,  1.6897e-03, -1.9739e-02,  3.3352e-02,  1.2558e-02,
         -1.2016e-01, -6.6455e-02,  1.0979e-01,  5.7964e-02,  9.7062e-02,
         -1.7890e-02,  8.1899e-02,  3.3800e-02, -2.1709e-02, -1.1737e-01,
         -5.637

In [5]:
word_em = ckp["state_dict"]["word_embeddings.weight"].cpu().numpy()
type_em = ckp["state_dict"]["type_embeddings.weight"].cpu().numpy()
em = []
for wem in tqdm(word_em):
    for tem in type_em:
        em.append(wem + tem)
em = np.array(em, dtype=np.float32)
print(f"word_em.shape={word_em.shape}\ntype_em.shape={type_em.shape}\nem.shape={em.shape}")
assert word_em.shape[0] * type_em.shape[0] == em.shape[0]
faiss.normalize_L2(em)

100%|██████████████████████| 1855603/1855603 [00:05<00:00, 310753.50it/s]


word_em.shape=(1855603, 32)
type_em.shape=(3, 32)
em.shape=(5566809, 32)


In [6]:
%%time
np.save(em_file, em)

Wall time: 1.94 s


In [7]:
%%time
d = em.shape[1]
m = 8  # number of subquantizers
quantizer = faiss.IndexFlatIP(d)
# 8 specifies that each sub-vector is encoded as 8 bits
index = faiss.IndexIVFPQ(quantizer, d, faiss_nlist, m, 8)
index.verbose = True
index.train(em)
index.add(em)
faiss.write_index(index, str(index_file))

Wall time: 17.7 s


In [8]:
index = faiss.read_index(str(index_file))
assert index.is_trained and index.ntotal == em.shape[0]
# sanity check
index.nprobe = 1
k = 4
n = 1000
distances, ids = index.search(em[:n], k)
for i in range(n):
    if ids[i][0] != i:
        print(f"Expecting id {i} but found {ids[i][0]}")

In [9]:
tim.stop()
print(f"Total time taken {str(tim.elapsed)}")
print(f"Saved {str(index_file)}")

Total time taken 0:00:30.362810
Saved output/m8_w7_i10.index
