In [None]:
import pandas as pd
import numpy as np
import altair as alt
from random import sample
import torch
from tqdm import tqdm
from collections import Counter
from hdbscan import HDBSCAN
from sklearn.cluster import DBSCAN, KMeans, Birch
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.metrics import euclidean_distances
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler

from doc_embed_torch import DocumentEmbeddingTrainer

In [None]:
# Model-type identifiers; only DUAL is selected below.
DUAL = "dual"
MLM = "mlm"

# Run identifier and model type handed to the trainer.
run_code = "MlnsPLul"
model_type = DUAL

# Build the trainer for this run, then ask it to load the model for `run_code`.
trainer = DocumentEmbeddingTrainer(run_code=run_code, model_type=model_type)
trainer.load_model(run_code)

In [None]:
# How many documents to draw, and the seed that makes the draw repeatable
# across Restart & Run All.
N_SAMPLES = 256
SAMPLE_SEED = 42

# Draw distinct dataset positions without replacement. The draw is capped at the
# dataset size so a dataset with fewer than N_SAMPLES documents does not raise.
n_docs = len(trainer.doc_dataset)
indices = (
    np.random.default_rng(SAMPLE_SEED)
    .choice(n_docs, size=min(N_SAMPLES, n_docs), replace=False)
    .tolist()  # plain Python ints, matching what random.sample returned
)
print(indices[:10])

In [None]:
# distances: (i, j) -> Euclidean distance between the embeddings of documents i and j,
#            stored once per unordered pair.
# embeddings: dataset index -> document embedding returned by the model.
distances = dict()
embeddings = dict()

# Drop repeated indices while keeping first-seen order, so each document is
# embedded once and each unordered pair is visited exactly once.
unique_indices = list(dict.fromkeys(indices))

# Inference only: disabling autograd avoids building (and retaining) a graph
# for every stored embedding.
# NOTE(review): if `load_model` does not already put the model in eval mode,
# layers such as dropout could make these embeddings non-deterministic - confirm.
with torch.no_grad():
    for i in tqdm(unique_indices, desc="embeddings"):
        # `doc_i` intentionally stays defined after the loop; a later cell inspects its shape.
        doc_i = trainer.doc_dataset[i].unsqueeze(0)
        embeddings[i] = trainer.model(doc_i, return_doc_embedding=True)

    for pos, i in enumerate(tqdm(unique_indices, desc="distances")):
        for j in unique_indices[pos + 1:]:
            # p determines the Minkowski order. 2 is Euclidean, 1 is Manhattan, etc.
            distances[(i, j)] = torch.cdist(
                embeddings[i].unsqueeze(0), embeddings[j].unsqueeze(0), p=2
            ).item()

In [None]:
# Shape of the most recent `doc_i` tensor left over from the loop in the previous cell.
doc_i.shape

In [None]:
# Put every stored pairwise distance into a Series and show its summary statistics.
s = pd.Series(distances.values())
s.describe()

In [None]:
def get_percentile(percent, embeddings=None):
    """Return the `percent`-th percentile of a collection of values.

    Args:
        percent: Percentile to compute, forwarded to `np.percentile`.
        embeddings: Values to take the percentile of. When omitted, the values
            of the notebook-level `distances` dict are used instead.

    Returns:
        The result of `np.percentile` over the chosen values.
    """
    values = list(distances.values()) if embeddings is None else embeddings
    return np.percentile(values, percent)

# 25th percentile of the stored pairwise distances (default source when no values are passed).
get_percentile(25)

In [None]:
# Turn the distance Series `s` into a one-column frame named 'value'.
df = s.to_frame(name='value')
df.head()

In [None]:
# One row per sampled document: its dataset index and its embedding as a NumPy array.
# `.cpu()` is a no-op for CPU tensors but is required before `.numpy()` when the
# model produced the embedding on another device (e.g. CUDA).
embeddings_df = pd.DataFrame(
    [{'idx': key, 'embedding': value.detach().cpu().numpy()} for key, value in embeddings.items()]
)
embeddings_df.head()

In [None]:
# Join the per-document embedding arrays along axis 0 into a single array.
# NOTE(review): assumes each embedding is 2-D with a leading batch dimension of 1,
# so the result has one row per document - confirm against the model output.
emb_array = np.concatenate(embeddings_df.embedding.values)
emb_array.shape

In [None]:
# Peek at the first three values of the first row.
emb_array[0][:3]

In [None]:
# Min-max scale the embedding array and show the top-left 5x5 corner of the result.
mms_emb = MinMaxScaler().fit_transform(emb_array)
mms_emb[:5, :5]

In [None]:
# Shape of the min-max scaled array.
mms_emb.shape

In [None]:
def get_predictions(embedding_array, percent, min_samples=3, n_components=20,
                    early_exaggeration=None, random_state=None):
    """Cluster embeddings with DBSCAN and build a 2-D t-SNE projection for plotting.

    Pipeline: RobustScaler -> PCA -> (a) t-SNE to 2-D for visualisation and
    (b) DBSCAN on the PCA output, with `eps` set to the `percent`-th percentile
    of the pairwise Euclidean distances between PCA-reduced points.

    Args:
        embedding_array: 2-D array of shape (n_samples, n_features).
        percent: Percentile (0-100) of pairwise distances used as DBSCAN `eps`.
        min_samples: DBSCAN `min_samples`.
        n_components: Requested number of PCA components; clipped to what the
            data can support (at most min(n_samples, n_features)).
        early_exaggeration: t-SNE `early_exaggeration`. Defaults to `percent`,
            which is how this function has always behaved.
            NOTE(review): coupling the t-SNE setting to the eps percentile looks
            accidental - pass an explicit value to decouple them.
        random_state: Seed forwarded to PCA and t-SNE for repeatable output.

    Returns:
        Tuple `(embed_reduced, pred)`: the (n_samples, 2) t-SNE coordinates and
        the DBSCAN label of each sample.
    """
    scaled_emb = RobustScaler().fit_transform(embedding_array)
    n_samples, n_features = scaled_emb.shape

    # Reduce to a moderate number of dimensions before t-SNE and clustering.
    # PCA cannot produce more components than min(n_samples, n_features).
    embed_pca = PCA(
        n_components=min(n_components, n_samples, n_features),
        random_state=random_state,
    ).fit_transform(scaled_emb)

    if early_exaggeration is None:
        early_exaggeration = percent

    # Reduce to 2 components (i.e. 2-dimensional space) for plotting only.
    # t-SNE requires perplexity < n_samples, so clip it for small inputs.
    embed_reduced = TSNE(
        n_components=2, learning_rate=200, init='random',
        perplexity=min(25, max(1, n_samples - 1)),
        early_exaggeration=early_exaggeration, random_state=random_state,
    ).fit_transform(embed_pca)

    # Calculate the distance matrix, then keep only the distinct pairs (upper
    # triangle). The diagonal holds zero self-distances, which would drag low
    # percentiles down to 0 - an eps that DBSCAN rejects.
    embed_distances = euclidean_distances(embed_pca)
    pair_distances = embed_distances[np.triu_indices(n_samples, k=1)]

    # Calculate eps as a percentile of the pairwise distances
    eps = np.percentile(pair_distances, percent)
    dbscan = DBSCAN(eps=eps, min_samples=min_samples)
    pred = dbscan.fit_predict(embed_pca)
    return embed_reduced, pred

def get_num_topics(predictions):
    """Count the clusters in a sequence of labels, ignoring the -1 label.

    Args:
        predictions: Sequence of cluster labels, one per sample.

    Returns:
        Tuple `(num_topics, counter)`: the number of distinct labels other than
        -1, and a `Counter` with the size of every label (including -1).
    """
    label_counts = Counter(predictions)
    n_clusters = sum(1 for label in label_counts if label != -1)
    return n_clusters, label_counts

In [None]:
# Run the DBSCAN pipeline with percent=10 and min_samples=5, keeping both the
# 2-D projection (`er`) and the labels (`pred`), then count the clusters found.
er, pred = get_predictions(emb_array, 10, 5)
get_num_topics(pred)

In [None]:
# Combine the 2-D projection and the cluster labels into one frame for plotting.
er_df = pd.DataFrame(er, columns=['x', 'y']).assign(pred=pred)

er_df.head()

In [None]:
# Scatter plot of the 2-D projection, coloured by cluster label
# (`:N` makes Altair treat `pred` as a nominal/categorical field).
alt.Chart(er_df).mark_point().encode(
    x=alt.X('x', title=None),
    y=alt.Y('y', title=None),
    color=alt.Color('pred:N'),
)

In [None]:
def predict_generic(embedding_array, model_cls, percent=20, **kwargs):
    """Cluster embeddings with an arbitrary estimator and build a 2-D t-SNE projection.

    Pipeline: RobustScaler -> PCA (up to 50 components) -> (a) t-SNE to 2-D for
    visualisation and (b) `model_cls(**kwargs).fit_predict` on the PCA output.

    Args:
        embedding_array: 2-D array of shape (n_samples, n_features).
        model_cls: Clustering estimator class exposing `fit_predict`
            (e.g. HDBSCAN, KMeans).
        percent: Unused; kept only so existing callers keep working.
        **kwargs: Forwarded unchanged to `model_cls`.

    Returns:
        Tuple `(embed_reduced, pred)`: the (n_samples, 2) t-SNE coordinates and
        the label `model_cls` assigned to each sample.
    """
    scaled_emb = RobustScaler().fit_transform(embedding_array)
    n_samples, n_features = scaled_emb.shape

    # Reduce dimensionality before t-SNE and clustering. PCA cannot produce more
    # components than min(n_samples, n_features), so clip the target of 50.
    embed_pca = PCA(n_components=min(50, n_samples, n_features)).fit_transform(scaled_emb)

    # Reduce to 2 components (i.e. 2-dimensional space) for plotting only.
    # Fit on the PCA output - the same features the clustering model sees, and the
    # same convention get_predictions follows - rather than on the raw scaled data.
    # t-SNE requires perplexity < n_samples, so clip it for small inputs.
    embed_reduced = TSNE(
        n_components=2, learning_rate=100, init='random',
        perplexity=min(100, max(1, n_samples - 1)), early_exaggeration=20,
    ).fit_transform(embed_pca)

    cluster_model = model_cls(**kwargs)
    pred = cluster_model.fit_predict(embed_pca)
    return embed_reduced, pred

In [None]:
# Cluster with HDBSCAN and count the clusters in *its* labels (`hpred`),
# not the DBSCAN labels (`pred`) from the earlier cell.
her, hpred = predict_generic(emb_array, HDBSCAN, min_cluster_size=5)
get_num_topics(hpred)

In [None]:
# Plot the HDBSCAN run: use its own projection (`her`) and labels (`hpred`).
# A separate frame name keeps the DBSCAN frame `er_df` intact.
her_df = pd.DataFrame(her)
her_df.columns = ['x', 'y']
her_df['pred'] = hpred
alt.Chart(her_df).mark_point().encode(
    x=alt.X('x', title=None),
    y=alt.Y('y', title=None),
    color=alt.Color('pred:N'),
)