In [1]:
import os
from pathlib import Path
import json
import numpy as np
import pandas as pd
import faiss
import torch
from tqdm import tqdm
import scml

In [2]:
faiss_nlist: int = 1000
model_file = "models/word2vec/20230126_103746/lightning_logs/version_0/checkpoints/pytorch_model.bin"
index_file = "output/item.index"
em_file = "output/em.npy"

In [3]:
tim = scml.Timer()
tim.start()
percentiles=[.01, .05, .1, .2, .3, .4, .5, .6, .7, .8, .9, .95, .99]
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
pd.set_option("use_inf_as_na", True)
pd.set_option("max_info_columns", 9999)
pd.set_option("display.max_columns", 9999)
pd.set_option("display.max_rows", 9999)
pd.set_option('max_colwidth', 9999)
tqdm.pandas()
scml.seed_everything()

In [4]:
ckp = torch.load(str(model_file))
print(ckp)

{'epoch': 1, 'global_step': 12110, 'pytorch-lightning_version': '1.7.7', 'state_dict': OrderedDict([('word_embeddings.weight', tensor([[ 0.2630, -0.0449,  0.6082,  ...,  0.4589,  0.3399, -0.3081],
        [ 0.0128,  0.6311,  0.2702,  ...,  0.2637, -0.2946, -0.0521],
        [-0.1858, -0.0423,  0.1846,  ..., -0.1649, -0.1736, -0.2905],
        ...,
        [ 0.0516, -0.3595, -0.1154,  ..., -0.4182, -0.4185, -0.0857],
        [ 0.2802,  0.5586,  0.0309,  ..., -0.2533,  0.0147, -0.5825],
        [-0.7261,  0.1738, -0.0546,  ..., -0.0231, -0.9209, -0.0161]])), ('type_embeddings.weight', tensor([[-9.4777e-03,  2.8200e-03,  1.7705e-02, -4.7742e-02,  5.0185e-02,
          2.3823e-02, -4.4638e-02,  3.4346e-02,  2.0280e-02,  2.7702e-02,
         -1.6176e-02,  8.0596e-03,  7.8067e-03,  5.3401e-02,  6.1753e-02,
          2.4924e-01,  2.1415e-02,  2.8027e-02, -2.0764e-02,  7.1716e-03,
         -3.8487e-02,  8.6159e-04,  3.1261e-02, -1.7734e-02, -3.2709e-02,
          1.9222e-03, -1.5777e-02, -2.91

In [5]:
word_em = ckp["state_dict"]["word_embeddings.weight"].cpu().numpy()
type_em = ckp["state_dict"]["type_embeddings.weight"].cpu().numpy()
em = []
for wem in tqdm(word_em):
    for tem in type_em:
        em.append(wem + tem)
em = np.array(em, dtype=np.float32)
print(f"word_em.shape={word_em.shape}\ntype_em.shape={type_em.shape}\nem.shape={em.shape}")
assert word_em.shape[0] * type_em.shape[0] == em.shape[0]
faiss.normalize_L2(em)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1855603/1855603 [00:11<00:00, 166527.93it/s]


word_em.shape=(1855603, 32)
type_em.shape=(3, 32)
em.shape=(5566809, 32)


In [6]:
%%time
np.save(em_file, em)

CPU times: user 956 ms, sys: 2.13 s, total: 3.08 s
Wall time: 2.28 s


In [7]:
%%time
d = em.shape[1]
m = 8  # number of subquantizers
quantizer = faiss.IndexFlatIP(d)
# 8 specifies that each sub-vector is encoded as 8 bits
index = faiss.IndexIVFPQ(quantizer, d, faiss_nlist, m, 8)
index.verbose = True
index.train(em)
index.add(em)
faiss.write_index(index, str(index_file))

Training level-1 quantizer
Training level-1 quantizer on 5566809 vectors in 32D
Training IVF residual
  Input training set too big (max size is 65536), sampling 65536 / 5566809 vectors
computing residuals
training 8x256 product quantizer on 65536 vectors in 32D
Training PQ slice 0/8
Clustering 65536 points in 4D to 256 clusters, redo 1 times, 25 iterations
  Preprocessing in 0.00 s
  Iteration 24 (0.44 s, search 0.40 s): objective=238.139 imbalance=1.654 nsplit=0       
Training PQ slice 1/8
Clustering 65536 points in 4D to 256 clusters, redo 1 times, 25 iterations
  Preprocessing in 0.00 s
  Iteration 24 (0.43 s, search 0.39 s): objective=246.085 imbalance=1.678 nsplit=0       
Training PQ slice 2/8
Clustering 65536 points in 4D to 256 clusters, redo 1 times, 25 iterations
  Preprocessing in 0.00 s
  Iteration 24 (0.48 s, search 0.44 s): objective=234.466 imbalance=1.682 nsplit=0       
Training PQ slice 3/8
Clustering 65536 points in 4D to 256 clusters, redo 1 times, 25 iterations
  

In [8]:
index = faiss.read_index(str(index_file))
assert index.is_trained and index.ntotal == em.shape[0]
# sanity check
index.nprobe = 1
k = 4
n = 1000
distances, ids = index.search(em[:n], k)
for i in range(n):
    if ids[i][0] != i:
        print(f"Expecting id {i} but found {ids[i][0]}")

In [9]:
tim.stop()
print(f"Total time taken {str(tim.elapsed)}")
print(f"Saved {str(index_file)}")

Total time taken 0:00:40.382359
Saved output/item.index
