In [1]:
import pandas as pd
import numpy as np
import altair as alt
from random import sample
import torch
from tqdm import tqdm
from collections import Counter
from sklearn.cluster import DBSCAN, KMeans, Birch
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.metrics import euclidean_distances

from doc_embed_torch import DocumentEmbeddingTrainer, load_run_config, VOCAB_SIZE

In [2]:
run_code = "CjP6bpvz"
trainer = DocumentEmbeddingTrainer(run_code=run_code)
trainer.load_mlm(run_code, VOCAB_SIZE)

Preparing the masked dataset ...
Done preparing the masked dataset.
Preparing the model for quantization ...


In [3]:
indices = sample(range(len(trainer.train_dataset)), 256)
print(indices[:10])

[10635, 5113, 13261, 6271, 48800, 35451, 51626, 1076, 41618, 46530]


In [4]:
distances = dict()
embeddings = dict()
for i in tqdm(indices):
    for j in indices:
        if (i, j) in distances or (j, i) in distances or i == j:
            continue
        
        doc_i = trainer.train_dataset[i]
        doc_j = trainer.train_dataset[j]
        if i not in embeddings:
            embeddings[i] = trainer.model(doc_i, return_doc_embedding=True)
        if j not in embeddings:
            embeddings[j] = trainer.model(doc_j, return_doc_embedding=True)
        
        # p determines the Minkowski order. 2 is Euclidean, 1 is Manhattan. etc.
        distances[(i, j)] = torch.cdist(embeddings[i].unsqueeze(0), embeddings[j].unsqueeze(0), p=2).item()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [00:49<00:00,  5.18it/s]


In [5]:
s = pd.Series(distances.values())
s.describe()

count    32640.000000
mean         0.052520
std          0.020543
min          0.020587
25%          0.038996
50%          0.042784
75%          0.063072
max          0.123312
dtype: float64

In [6]:
def get_percentile(percent, embeddings=None):
    if embeddings is None:
        embeddings = list(distances.values())
    
    return np.percentile(embeddings, percent)

get_percentile(25)

0.0389955285936594

In [7]:
# create a DataFrame from your Series s
df = pd.DataFrame(s)
df.columns=['value']
df.head()

Unnamed: 0,value
0,0.091463
1,0.09171
2,0.038493
3,0.036431
4,0.038532


In [8]:
embeddings_df = pd.DataFrame(
    [{'idx': key, 'embedding': value.detach().numpy()} for key, value in embeddings.items()]
)
embeddings_df.head()

Unnamed: 0,idx,embedding
0,10635,"[0.0042035063, 0.002961846, 0.0045745405, 0.00..."
1,5113,"[-0.0014725369, -0.001475839, -0.0029670973, 0..."
2,13261,"[-0.0043907985, -1.4081597e-06, -0.0016197367,..."
3,6271,"[0.003861553, 0.008717964, 0.003665205, 0.0032..."
4,48800,"[0.0027450407, 0.005412426, 0.004454517, 0.008..."


In [9]:
emb_array = np.array([emb for emb in embeddings_df.embedding.values])
emb_array.shape, emb_array[0].shape

((256, 128), (128,))

In [22]:
def get_predictions(embedding_array, percent, min_samples=3):
    # Reduce to 2 components (i.e. 2-dimensional space)
    embed_pca = PCA(n_components=64).fit_transform(embedding_array)
    
    embed_reduced = TSNE(
        n_components=2, learning_rate=200, init='random', perplexity=10, early_exaggeration=percent,
    ).fit_transform(embed_pca)
    
    # Calculate the distance matrix
    embed_distances = euclidean_distances(embed_pca)
    
    # Calculate eps as a percentile of our distance matrix values
    eps = np.percentile(embed_distances, percent)
    dbscan = DBSCAN(eps=eps, min_samples=min_samples)
    pred = dbscan.fit_predict(embed_pca)
    return embed_reduced, pred

def get_num_topics(predictions):
    pred_set = set(predictions)
    num_topics = len(pred_set) - int(-1 in pred_set)
    counter = Counter(predictions)
    return num_topics, counter

In [23]:
# get_num_topics(get_predictions(emb_array, 20, 5))
er, pred = get_predictions(emb_array, 20, 5)
get_num_topics(pred)

(1, Counter({0: 256}))

In [24]:
er_df = pd.DataFrame(er)
er_df.columns = ['x', 'y']
er_df['pred'] = pred

er_df.head()

Unnamed: 0,x,y,pred
0,-6.830966,16.923643,0
1,43.918655,-16.57724,0
2,48.581882,-14.067924,0
3,-9.912354,8.771676,0
4,-11.76837,22.128508,0


In [25]:
alt.Chart(er_df).mark_point().encode(
    x=alt.X('x', title=None),
    y=alt.Y('y', title=None),
    color=alt.Color('pred:N'),
)