In [25]:
import pandas as pd
import numpy as np
import altair as alt
from random import sample
import torch
from tqdm import tqdm
from collections import Counter
from sklearn.cluster import DBSCAN, KMeans, Birch

from doc_embed_torch import DocumentEmbeddingTrainer, load_run_config, VOCAB_SIZE

In [2]:
run_code = "CjP6bpvz"
trainer = DocumentEmbeddingTrainer(run_code=run_code)
trainer.load_mlm(run_code, VOCAB_SIZE)

Preparing the masked dataset ...
Done preparing the masked dataset.
Preparing the model for quantization ...


In [36]:
indices = sample(range(len(trainer.train_dataset)), 256)
print(indices[:10])

[2887, 37168, 1352, 41623, 29038, 72186, 57613, 8937, 9221, 52895]


In [37]:
distances = dict()
embeddings = dict()
for i in tqdm(indices):
    for j in indices:
        if (i, j) in distances or (j, i) in distances or i == j:
            continue
        
        doc_i = trainer.train_dataset[i]
        doc_j = trainer.train_dataset[j]
        if i not in embeddings:
            embeddings[i] = trainer.model(doc_i, return_doc_embedding=True)
        if j not in embeddings:
            embeddings[j] = trainer.model(doc_j, return_doc_embedding=True)
        
        # p determines the Minkowski order. 2 is Euclidean, 1 is Manhattan. etc.
        distances[(i, j)] = torch.cdist(embeddings[i].unsqueeze(0), embeddings[j].unsqueeze(0), p=2).item()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 256/256 [00:43<00:00,  5.83it/s]


In [38]:
s = pd.Series(distances.values())
s.describe()

count    32640.000000
mean         0.049830
std          0.018750
min          0.021950
25%          0.038877
50%          0.041922
75%          0.048554
max          0.133237
dtype: float64

In [39]:
def get_percentile(percent):
    return np.percentile(list(distances.values()), percent)

get_percentile(25)

0.03887707367539406

In [40]:
# create a DataFrame from your Series s
df = pd.DataFrame(s)
df.columns=['value']
df.head()

Unnamed: 0,value
0,0.092948
1,0.038697
2,0.044966
3,0.038087
4,0.042275


In [41]:
embeddings_df = pd.DataFrame(
    [{'idx': key, 'embedding': value.detach().numpy()} for key, value in embeddings.items()]
)
embeddings_df.head()

Unnamed: 0,idx,embedding
0,2887,"[0.008524248, 0.0032130228, 0.0064429687, 0.00..."
1,37168,"[-0.0033808919, -0.00012160279, -0.0046389946,..."
2,1352,"[0.004151065, 0.010614921, 0.0033359602, 0.002..."
3,41623,"[0.0010293359, 0.0025295867, 0.0060312664, 0.0..."
4,29038,"[0.0028249985, 0.0060028057, 0.0059975833, 0.0..."


In [17]:
emb_array = np.array([emb for emb in embeddings_df.embedding.values])
emb_array.shape, emb_array[0].shape

((100, 128), (128,))

In [42]:
def get_predictions(embedding_array, percent, min_samples=3):
    dbscan = DBSCAN(eps=get_percentile(percent), min_samples=min_samples)
    pred = dbscan.fit_predict(embedding_array)
    return pred

def get_num_topics(predictions):
    pred_set = set(predictions)
    num_topics = len(pred_set) - int(-1 in pred_set)
    counter = Counter(predictions)
    return num_topics, counter

In [43]:
for percentile in (5, 10, 25, 50, 80, 90):
    for ms in (2, 3, 5, 10, 20):
        print(percentile, ms, get_num_topics(get_predictions(emb_array, percentile, ms)))

5 2 (2, Counter({1: 70, -1: 17, 0: 13}))
5 3 (2, Counter({1: 70, -1: 17, 0: 13}))
5 5 (2, Counter({1: 69, -1: 18, 0: 13}))
5 10 (2, Counter({1: 66, -1: 21, 0: 13}))
5 20 (1, Counter({0: 62, -1: 38}))
10 2 (2, Counter({1: 80, 0: 13, -1: 7}))
10 3 (2, Counter({1: 80, 0: 13, -1: 7}))
10 5 (2, Counter({1: 80, 0: 13, -1: 7}))
10 10 (2, Counter({1: 79, 0: 13, -1: 8}))
10 20 (1, Counter({0: 76, -1: 24}))
25 2 (2, Counter({1: 86, 0: 13, -1: 1}))
25 3 (2, Counter({1: 86, 0: 13, -1: 1}))
25 5 (2, Counter({1: 86, 0: 13, -1: 1}))
25 10 (2, Counter({1: 86, 0: 13, -1: 1}))
25 20 (1, Counter({0: 85, -1: 15}))
50 2 (2, Counter({1: 87, 0: 13}))
50 3 (2, Counter({1: 87, 0: 13}))
50 5 (2, Counter({1: 87, 0: 13}))
50 10 (2, Counter({1: 87, 0: 13}))
50 20 (1, Counter({0: 87, -1: 13}))
80 2 (1, Counter({0: 100}))
80 3 (1, Counter({0: 100}))
80 5 (1, Counter({0: 100}))
80 10 (1, Counter({0: 100}))
80 20 (1, Counter({0: 95, -1: 5}))
90 2 (1, Counter({0: 100}))
90 3 (1, Counter({0: 100}))
90 5 (1, Counter({0: 