In [1]:
import statistics

import nltk
import numpy as np
import plotly.express as px
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from src.misc import import_data

# Seed NumPy's global random number generator so that code drawing from it
# gives repeatable results across fresh runs of the notebook.
np.random.seed(22)

# Make sure the NLTK "stopwords" corpus is available locally; the call reports
# the package as up to date when it has already been downloaded.
nltk.download("stopwords")

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/tadejkrivec/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [2]:
# How many times the whole cluster-count sweep is repeated; increase for better
# results (each repetition took roughly 97 s in the recorded run).
EXPERIMENT_REPETITIONS = 10 # increase for better results
# Reference value drawn on the histogram with the label "The number of articles".
N_TRUE_CLUSTERS = 35
# Dimensionality of the t-SNE projection (passed as `n_components`).
TSNE_COMPONENTS = 2
# Inputs handed to `import_data`: the dataset file and the precalculated
# embeddings file.
DATASET_FILE = "./data/dev-v2.0.json"
PRECALCULATED_EMBEDDINGS_FILE = "./data/dev-v2.0-precalculated.csv"
# Output folder for the exported figures (HTML and PNG).
FIGURES_FOLDER = "./figures"

In [3]:
# Load the dataset together with its embeddings through the project helper
# `src.misc.import_data`.
# NOTE(review): presumably `data` is a table-like object (later cells assign
# `data["cluster"]`) and `embeddings` is a 2-D numeric array with one row per
# record — confirm against `import_data`.
data, embeddings = import_data(
    file=DATASET_FILE, precalculated_file=PRECALCULATED_EMBEDDINGS_FILE
)

In [4]:
# t-SNE manifold projection
#
# The embeddings are standardised first and then projected down to
# TSNE_COMPONENTS dimensions. Note that `embeddings` is rebound to its scaled
# version here.
scaler = StandardScaler().fit(embeddings)
embeddings = scaler.transform(embeddings)

# Fixed random_state keeps the projection repeatable between runs.
tsne_params = {
    "n_components": TSNE_COMPONENTS,
    "perplexity": 15,
    "random_state": 42,
    "init": "random",
    "learning_rate": 200,
}
tsne = TSNE(**tsne_params)

# Project and cast to float64 in one step.
vis_dims = tsne.fit_transform(embeddings).astype(np.float64)

In [5]:
# number of clusters selection via silhouette score
#
# In every repetition, K-Means is fitted on the t-SNE projection once per
# candidate cluster count and the count with the highest mean silhouette score
# wins that repetition. `clusters_mode` is the most frequent winner across all
# repetitions.
candidate_cluster_counts = range(20, 45)  # candidates 20..44 inclusive

selected_number_of_clusters = []
for _ in tqdm(range(EXPERIMENT_REPETITIONS)):
    res_iter = []
    for n_clusters in candidate_cluster_counts:
        # No random_state is passed, so repetitions are not forced onto the
        # same initialisation; the number of restarts (n_init) grows with the
        # number of clusters.
        kmeans = KMeans(n_clusters=n_clusters, n_init=n_clusters)
        # fit_predict returns the labels of the training points directly,
        # avoiding a separate predict() pass over the very same data.
        cluster = kmeans.fit_predict(vis_dims)
        # NOTE(review): this column is overwritten on every fit, so after the
        # loops it holds the labels of the last fit only (largest candidate,
        # last repetition), not those of the selected count — confirm that
        # later cells reassign it before using it.
        data["cluster"] = cluster
        silhouette_avg = silhouette_score(vis_dims, cluster)
        res_iter.append((n_clusters, silhouette_avg))

    # Winner of this repetition; on ties max() keeps the first, i.e. smallest,
    # candidate.
    selected_cluster = max(res_iter, key=lambda x: x[1])[0]
    selected_number_of_clusters.append(selected_cluster)
clusters_mode = statistics.mode(selected_number_of_clusters)

100%|██████████| 10/10 [16:12<00:00, 97.20s/it]


In [6]:
# Histogram of the cluster counts selected across repetitions, with two
# vertical reference lines: solid for N_TRUE_CLUSTERS, dashed for the most
# frequently selected count.
fig = px.histogram(x=selected_number_of_clusters)

fig.add_vline(
    x=N_TRUE_CLUSTERS,
    annotation_text="The number of articles",
    line_color="orange",
)
fig.add_vline(
    x=clusters_mode,
    annotation_text="The number of clusters",
    line_color="orange",
    line_dash="dash",
)

# Restyle the annotations created by add_vline: rotate the text and give it an
# orange background, keeping every other annotation property as is.
styled_annotations = [
    dict(annotation, textangle=-90, bgcolor="orange")
    for annotation in fig.to_dict()["layout"]["annotations"]
]

# Transparent backgrounds, no margins, fixed size.
transparent = "rgba(0, 0, 0, 0)"
fig.update_layout(
    plot_bgcolor=transparent,
    paper_bgcolor=transparent,
    margin={"l": 0, "r": 0, "t": 0, "b": 0},
    height=500,
    width=600,
    annotations=styled_annotations,
)
fig.update_coloraxes(showscale=False)
fig.update_yaxes(showgrid=False, showticklabels=False)
fig.show()

# Export the figure both as interactive HTML and as a static PNG.
fig.write_html(f"{FIGURES_FOLDER}/histogram_nclusters.html")
fig.write_image(f"{FIGURES_FOLDER}/histogram_nclusters.png")