# Token Embedding

## 1. Setup

In [95]:
import faiss

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE

import torch

from transformers import AutoModel

In [46]:
# Load the pretrained `bert-base-uncased` encoder. The warning printed below lists the
# checkpoint weights (the `cls.*` heads) that BertModel does not use.
model = AutoModel.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## 2. Embedding

In [214]:
# Load the embedding database from disk (path is relative to the notebook's directory).
db = np.load('../db.npy')
# Center every column by subtracting its mean over the rows; this modifies `db` in place.
db -= db.mean(axis=0, keepdims=True)

In [303]:
# Number of contiguous column slices ("sections") the feature dimension is split into.
num_sections = 48
# Width of one section. Floor division: if the column count is not a multiple of
# num_sections, the trailing remainder columns are never covered by any section.
cluster_dim = db.shape[1] // num_sections
# Number of k-means centroids fitted per section.
num_clusters = 10

In [307]:
# Fit a k-means model (num_clusters centroids, 30 iterations, GPU) on one contiguous
# block of cluster_dim columns of `db` per section.
# NOTE(review): the unconditional `break` stops the loop after the first pass, so only
# section 0 is trained; `i` (== 0), `X` and `kmeans` stay bound for the following cells.
for i in range(num_sections):
    kmeans = faiss.Kmeans(cluster_dim, num_clusters, niter=30, nredo=1, spherical=False, verbose=True, gpu=True)
    # Column slice belonging to section i.
    X = db[:, i*cluster_dim:(i+1)*cluster_dim]
    # The slice is copied into a contiguous array before being handed to faiss.
    kmeans.train(np.ascontiguousarray(X))
    break


Sampling a subset of 2560 / 1000050 for training
Clustering 2560 points in 16D to 10 clusters, redo 1 times, 30 iterations
  Preprocessing in 0.02 s
  Iteration 0 (0.00 s, search 0.00 s): objective=10958.5 imbalance=1.829 nsplit=0         Iteration 1 (0.00 s, search 0.00 s): objective=7701.24 imbalance=1.281 nsplit=0         Iteration 2 (0.00 s, search 0.00 s): objective=7356.4 imbalance=1.154 nsplit=0         Iteration 3 (0.00 s, search 0.00 s): objective=7166.87 imbalance=1.127 nsplit=0         Iteration 4 (0.00 s, search 0.00 s): objective=7056.63 imbalance=1.085 nsplit=0         Iteration 5 (0.00 s, search 0.00 s): objective=7027.81 imbalance=1.070 nsplit=0         Iteration 6 (0.00 s, search 0.00 s): objective=7018.85 imbalance=1.061 nsplit=0         Iteration 7 (0.00 s, search 0.00 s): objective=7011.61 imbalance=1.057 nsplit=0         Iteration 8 (0.00 s, search 0.00 s): objective=7007.13 imbalance=1.056 nsplit=0         Iteration 9 (0.00 s, search 0.00 s): objective=7

In [308]:
# For every row of `sample`, query the trained centroid index with the column slice of
# section `i` (the value left over from the training loop) and request num_clusters
# results per row: D receives the scores, I the matching centroid ids.
# NOTE(review): `sample` is only assigned in a later cell (In [185]); on a fresh
# top-to-bottom run this cell would raise NameError — confirm the intended cell order.
D, I = kmeans.index.search(np.ascontiguousarray(sample[:, i*cluster_dim:(i+1)*cluster_dim]), num_clusters)

In [312]:
# Row-wise softmax over the negated scores in D; only the first-ranked column is shown.
torch.softmax(-torch.tensor(D), dim=1)[:, 0]

tensor([0.9339, 0.2782, 0.2518,  ..., 0.2301, 0.2220, 0.2283])

In [285]:
# Display the raw score matrix returned by the most recent search.
D

array([[ 2.140842  ,  0.2796097 ,  0.2506493 , ..., -0.742604  ,
        -0.79542494, -0.9803657 ],
       [ 0.788064  ,  0.5848533 ,  0.37754714, ..., -0.4165491 ,
        -0.7380801 , -0.7636842 ],
       [ 1.4701922 ,  1.1750776 ,  1.0496389 , ..., -0.7477314 ,
        -1.335281  , -1.5062048 ],
       ...,
       [ 1.0759405 ,  0.838409  ,  0.250079  , ..., -0.4382284 ,
        -0.6448281 , -0.9668639 ],
       [ 0.39732748,  0.3572457 ,  0.34227183, ..., -0.33942476,
        -0.7597793 , -0.7751605 ],
       [ 0.475685  ,  0.37753665,  0.22459699, ..., -0.13650276,
        -0.182647  , -0.817257  ]], dtype=float32)

In [247]:
# Row-wise softmax over the negated scores in D, shown for every column.
torch.softmax(torch.tensor(D).neg(), dim=1)

tensor([[6.9467e-02, 6.9420e-02, 5.8833e-02,  ..., 5.4045e-08, 4.7684e-08,
         3.6143e-08],
        [1.1384e-02, 1.1080e-02, 1.1028e-02,  ..., 4.6887e-07, 3.0408e-07,
         1.2915e-07],
        [3.6143e-02, 3.1051e-02, 2.7092e-02,  ..., 5.1974e-08, 4.8904e-08,
         1.3046e-08],
        ...,
        [2.7997e-02, 2.6732e-02, 1.6230e-02,  ..., 8.5449e-07, 7.9209e-07,
         6.2996e-07],
        [9.4049e-03, 7.7874e-03, 7.6398e-03,  ..., 1.2306e-06, 1.0933e-06,
         1.0795e-06],
        [7.5814e-03, 6.9586e-03, 6.5871e-03,  ..., 2.2890e-06, 2.1200e-06,
         1.1616e-06]])

In [248]:
# Display the centroid-id matrix returned by the most recent search (one row per query).
I

array([[664, 733, 941, ..., 343, 867, 105],
       [112,  92, 514, ..., 228, 135, 445],
       [918, 362, 667, ..., 988, 105, 867],
       ...,
       [362, 900, 773, ..., 432, 765, 738],
       [539, 411, 229, ..., 867, 505, 343],
       [ 37, 938, 690, ..., 781, 228, 445]])

In [185]:
# Draw a random evaluation subset of up to 10000 rows from `embed`.
# - replace=False guarantees distinct rows; np.random.choice samples WITH replacement
#   by default, which could put the same row into the subset more than once.
# - The population is given as an int, so no explicit range object has to be converted.
# - The size is capped at len(embed) so the draw stays valid for small inputs.
# The indices are kept as a sorted list, so `sample` follows the row order of `embed`.
# NOTE(review): the draw is unseeded, so the subset differs between runs.
num_sample_rows = min(10000, len(embed))
rand_idxs = sorted(np.random.choice(len(embed), size=num_sample_rows, replace=False))
sample = embed[rand_idxs]

In [202]:
# Spherical k-means with 1000 centroids (30 iterations, single run, GPU).
# The dimensionality is read from `embed` -- the matrix this model is trained on in the
# next cell -- instead of hard-coding 768, so the cell keeps working if the embedding
# width changes.
kmeans = faiss.Kmeans(embed.shape[1], 1000, niter=30, nredo=1, spherical=True, verbose=False, gpu=True)

In [203]:
# Fit the centroids on `embed`. The cell displays train()'s return value, which matches
# the final objective in the log below. Per that log, faiss trains on a 256000-point
# subsample of the 1000050 rows.
kmeans.train(embed)


Clustering 1000050 points in 768D to 100000 clusters, redo 1 times, 30 iterations
  Preprocessing in 0.39 s
Clustering 1000050 points in 768D to 50000 clusters, redo 1 times, 30 iterations
  Preprocessing in 0.40 s
Clustering 1000050 points in 768D to 25000 clusters, redo 1 times, 30 iterations
  Preprocessing in 0.39 s
Sampling a subset of 256000 / 1000050 for training
Clustering 256000 points in 768D to 1000 clusters, redo 1 times, 30 iterations
  Preprocessing in 0.65 s
  Iteration 29 (101.78 s, search 100.81 s): objective=2.18846e+06 imbalance=1.219 nsplit=0       

2188461.5

In [204]:
# Query the trained centroid index with every row of `sample`, requesting 1000 results
# per row: D receives the scores, I the matching centroid ids.
# NOTE(review): 1000 presumably mirrors the number of centroids passed to faiss.Kmeans,
# i.e. every centroid is ranked for every row — confirm.
D, I = kmeans.index.search(sample, 1000)

In [208]:
# Persist the trained centroid index to a file named 'kmeans' in the working directory.
faiss.write_index(kmeans.index, 'kmeans')

In [210]:
# Load the centroid index back from the 'kmeans' file.
index = faiss.read_index('kmeans')

In [211]:
# Top-1 query against the reloaded index; the cell displays the (scores, ids) pair.
index.search(sample, 1)

(array([[14.185184 ],
        [ 6.6428385],
        [ 5.625658 ],
        ...,
        [ 8.343491 ],
        [ 6.2191877],
        [ 5.5191755]], dtype=float32),
 array([[ 19],
        [105],
        [ 93],
        ...,
        [989],
        [862],
        [913]]))

In [212]:
# First-ranked centroid id per sample row, taken from the earlier 1000-result search
# (compare with the ids returned by the reloaded index).
I[:,0]

array([ 19, 105,  93, ..., 989, 862, 913])

In [None]:
# Inspect the trained centroid index object. The original cell was an unfinished
# attribute access (`kmeans.index.` with nothing after the dot), which is a SyntaxError
# and prevents the notebook from being run or exported as a script.
kmeans.index

In [206]:
# Row-wise softmax over the scores in D (not negated here); show the first query row.
torch.softmax(torch.tensor(D), dim=1)[0]

tensor([1.9240e-01, 1.3758e-01, 1.1031e-01, 1.0508e-01, 1.0476e-01, 7.2238e-02,
        5.7052e-02, 5.5457e-02, 4.4753e-02, 4.0526e-02, 1.7296e-02, 1.0706e-02,
        9.8267e-03, 8.0082e-03, 7.6235e-03, 4.5702e-03, 3.3738e-03, 3.0598e-03,
        2.0924e-03, 1.5668e-03, 1.5500e-03, 1.4151e-03, 1.2332e-03, 1.1718e-03,
        9.7523e-04, 8.0928e-04, 7.9776e-04, 6.2521e-04, 5.2402e-04, 5.1694e-04,
        4.9166e-04, 4.7253e-04, 1.8533e-04, 1.8207e-04, 8.0691e-05, 7.3702e-05,
        6.7691e-05, 5.2238e-05, 5.0595e-05, 4.9196e-05, 4.4507e-05, 4.2352e-05,
        4.0040e-05, 3.9040e-05, 3.2777e-05, 2.9015e-05, 1.8160e-05, 8.5674e-06,
        8.5486e-06, 6.8867e-06, 6.5506e-06, 6.2771e-06, 5.3578e-06, 5.2873e-06,
        5.2236e-06, 4.8666e-06, 4.3608e-06, 4.1106e-06, 3.1903e-06, 2.1702e-06,
        1.9264e-06, 1.7227e-06, 1.4430e-06, 1.4374e-06, 1.3586e-06, 1.2727e-06,
        1.2009e-06, 1.1762e-06, 1.0922e-06, 1.0562e-06, 1.0462e-06, 1.0367e-06,
        9.5198e-07, 9.4423e-07, 9.2607e-