In [1]:
import os
from pathlib import Path
import json
import numpy as np
import pandas as pd
import faiss
import torch
from tqdm import tqdm
import scml

In [2]:
faiss_nlist: int = 1000
model_file = "models/word2vec/20230131_055311/lightning_logs/version_0/checkpoints/epoch=15-step=64208.ckpt"
index_file = "output/m8_w7_i20.index"
em_file = "output/m8_w7_i20.npy"

In [3]:
tim = scml.Timer()
tim.start()
percentiles=[.01, .05, .1, .2, .3, .4, .5, .6, .7, .8, .9, .95, .99]
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
pd.set_option("use_inf_as_na", True)
pd.set_option("max_info_columns", 9999)
pd.set_option("display.max_columns", 9999)
pd.set_option("display.max_rows", 9999)
pd.set_option('max_colwidth', 9999)
tqdm.pandas()
scml.seed_everything()

In [4]:
ckp = torch.load(str(model_file))
print(ckp)

{'epoch': 15, 'global_step': 64208, 'pytorch-lightning_version': '1.7.7', 'state_dict': OrderedDict([('word_embeddings.weight', tensor([[ 0.7196, -0.2523,  0.0447,  ...,  0.8360,  0.6226,  0.3175],
        [-0.0872, -1.8396,  1.0697,  ..., -0.0081,  2.1950,  1.7129],
        [ 1.5607,  0.3383, -0.2469,  ..., -0.0354,  0.6428, -0.8423],
        ...,
        [ 1.5420,  0.1625, -2.4177,  ...,  2.9724,  0.2030, -1.2735],
        [ 1.6230,  0.1426, -2.6091,  ...,  2.6237,  0.1537, -1.2714],
        [ 1.8226,  0.2791, -2.6857,  ...,  2.6923,  0.1305, -1.2754]],
       device='cuda:0')), ('type_embeddings.weight', tensor([[-2.0791e-01,  5.1622e-02,  2.7370e-01, -2.6897e-01, -1.7714e-01,
          2.4166e-01, -2.6782e-01, -2.2856e-02,  2.5493e-01, -8.1043e-02,
         -2.3945e-01, -7.7283e-02,  2.3141e-02, -1.7718e-02,  6.7999e-02,
         -2.9145e-01, -2.5474e-01,  2.7317e-01,  2.0982e-01,  2.8213e-01,
         -8.1954e-02,  2.3246e-01,  1.3225e-01, -1.3763e-01, -2.9451e-01,
         -2.434

In [5]:
word_em = ckp["state_dict"]["word_embeddings.weight"].cpu().numpy()
type_em = ckp["state_dict"]["type_embeddings.weight"].cpu().numpy()
em = []
for wem in tqdm(word_em):
    for tem in type_em:
        em.append(wem + tem)
em = np.array(em, dtype=np.float32)
print(f"word_em.shape={word_em.shape}\ntype_em.shape={type_em.shape}\nem.shape={em.shape}")
assert word_em.shape[0] * type_em.shape[0] == em.shape[0]
faiss.normalize_L2(em)

100%|██████████████████████| 1855603/1855603 [00:05<00:00, 314479.04it/s]


word_em.shape=(1855603, 32)
type_em.shape=(3, 32)
em.shape=(5566809, 32)


In [6]:
%%time
np.save(em_file, em)

Wall time: 2.03 s


In [7]:
%%time
d = em.shape[1]
m = 8  # number of subquantizers
quantizer = faiss.IndexFlatIP(d)
# 8 specifies that each sub-vector is encoded as 8 bits
index = faiss.IndexIVFPQ(quantizer, d, faiss_nlist, m, 8)
index.verbose = True
index.train(em)
index.add(em)
faiss.write_index(index, str(index_file))

Wall time: 19 s


In [8]:
index = faiss.read_index(str(index_file))
assert index.is_trained and index.ntotal == em.shape[0]
# sanity check
index.nprobe = 1
k = 4
n = 1000
distances, ids = index.search(em[:n], k)
for i in range(n):
    if ids[i][0] != i:
        print(f"Expecting id {i} but found {ids[i][0]}")

Expecting id 12 but found 9


In [9]:
tim.stop()
print(f"Total time taken {str(tim.elapsed)}")
print(f"Saved {str(index_file)}")

Total time taken 0:00:33.695480
Saved output/m8_w7_i20.index
