In [1]:
import os
from pathlib import Path
import json
import numpy as np
import pandas as pd
import faiss
import torch
from tqdm import tqdm
import scml

In [2]:
faiss_nlist: int = 1000
model_file = "models/word2vec/20230130_025602/lightning_logs/version_0/checkpoints/epoch=0-step=10697.ckpt"
index_file = "output/20230130_025602.index"
em_file = "output/20230130_025602.npy"

In [3]:
tim = scml.Timer()
tim.start()
percentiles=[.01, .05, .1, .2, .3, .4, .5, .6, .7, .8, .9, .95, .99]
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
pd.set_option("use_inf_as_na", True)
pd.set_option("max_info_columns", 9999)
pd.set_option("display.max_columns", 9999)
pd.set_option("display.max_rows", 9999)
pd.set_option('max_colwidth', 9999)
tqdm.pandas()
scml.seed_everything()

In [4]:
ckp = torch.load(str(model_file))
print(ckp)

{'epoch': 0, 'global_step': 10697, 'pytorch-lightning_version': '1.7.7', 'state_dict': OrderedDict([('word_embeddings.weight', tensor([[-0.2121,  0.3218,  0.1166,  ..., -0.1201,  0.0314,  0.1183],
        [-0.0597,  0.4598, -0.0143,  ...,  0.3505, -0.1453, -0.0897],
        [-0.1908,  0.0925,  0.1736,  ..., -0.1001,  0.0232, -0.1467],
        ...,
        [-0.1057, -0.2042,  0.1163,  ...,  0.0912,  0.0080, -0.2655],
        [ 0.0415,  0.0685,  0.1066,  ..., -0.2149, -0.5060,  0.0733],
        [-0.1519, -0.3222, -0.2716,  ...,  0.0119, -0.3107,  0.1295]],
       device='cuda:0')), ('type_embeddings.weight', tensor([[ 5.0394e-02,  6.8741e-02, -1.3322e-02, -1.0164e-02,  4.5324e-02,
         -3.1418e-02,  3.4683e-03, -2.5713e-03, -3.4160e-02,  3.5731e-02,
          2.9226e-02,  3.9432e-02, -1.0001e-02,  7.0614e-02,  1.7031e-02,
          2.2364e-02,  3.6731e-02, -2.7171e-02, -3.5089e-02,  1.9941e-02,
          1.0660e-01, -5.0294e-02, -3.6797e-02,  4.6865e-02,  1.9488e-02,
          2.3824

In [5]:
word_em = ckp["state_dict"]["word_embeddings.weight"].cpu().numpy()
type_em = ckp["state_dict"]["type_embeddings.weight"].cpu().numpy()
em = []
for wem in tqdm(word_em):
    for tem in type_em:
        em.append(wem + tem)
em = np.array(em, dtype=np.float32)
print(f"word_em.shape={word_em.shape}\ntype_em.shape={type_em.shape}\nem.shape={em.shape}")
assert word_em.shape[0] * type_em.shape[0] == em.shape[0]
faiss.normalize_L2(em)

100%|██████████████████████| 1855603/1855603 [00:06<00:00, 300870.98it/s]


word_em.shape=(1855603, 32)
type_em.shape=(3, 32)
em.shape=(5566809, 32)


In [6]:
%%time
np.save(em_file, em)

Wall time: 1.88 s


In [7]:
%%time
d = em.shape[1]
m = 8  # number of subquantizers
quantizer = faiss.IndexFlatIP(d)
# 8 specifies that each sub-vector is encoded as 8 bits
index = faiss.IndexIVFPQ(quantizer, d, faiss_nlist, m, 8)
index.verbose = True
index.train(em)
index.add(em)
faiss.write_index(index, str(index_file))

Wall time: 18 s


In [8]:
index = faiss.read_index(str(index_file))
assert index.is_trained and index.ntotal == em.shape[0]
# sanity check
index.nprobe = 1
k = 4
n = 1000
distances, ids = index.search(em[:n], k)
for i in range(n):
    if ids[i][0] != i:
        print(f"Expecting id {i} but found {ids[i][0]}")

Expecting id 5 but found 4
Expecting id 47 but found 46
Expecting id 182 but found 181
Expecting id 185 but found 184
Expecting id 191 but found 190
Expecting id 236 but found 235
Expecting id 296 but found 295
Expecting id 380 but found 379
Expecting id 389 but found 388
Expecting id 524 but found 523
Expecting id 527 but found 526
Expecting id 572 but found 571
Expecting id 590 but found 589
Expecting id 752 but found 751
Expecting id 788 but found 787
Expecting id 866 but found 865
Expecting id 938 but found 937


In [9]:
tim.stop()
print(f"Total time taken {str(tim.elapsed)}")
print(f"Saved {str(index_file)}")

Total time taken 0:00:33.115153
Saved output/20230130_025602.index
