In [None]:
"""
This notebook generates a figure comparing the individual silhouette widths of k-means clusterings with their corresponding upper bounds, for several unlabeled datasets.

Notes:
    - The unlabeled datasets are available from the UCI Machine Learning Repository: https://archive.ics.uci.edu/
"""

In [1]:
from silhouette_upper_bound import upper_bound_samples
from sklearn.metrics import pairwise_distances
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
from utils import load_unlabeled_data, asw_optimization, algorithm_kmeans, get_silhouette_plot_data

In [2]:
def ub_fcn(dataset, kappa=1, metric='euclidean'):
    """
    Compute per-sample silhouette upper bounds for one unlabeled dataset.

    The dataset is loaded with ``load_unlabeled_data`` ("conference_papers" is
    loaded with ``transpose=True``). It is then converted into a full pairwise
    dissimilarity matrix with ``pairwise_distances`` and handed to
    ``upper_bound_samples``.

    Parameters
    ----------
    dataset : str
        Dataset name understood by ``load_unlabeled_data``.
    kappa : int, default 1
        Forwarded unchanged to ``upper_bound_samples`` (the plotting cell
        passes 1 and the size of the smallest k-means cluster).
    metric : str, default 'euclidean'
        Metric used to build the dissimilarity matrix.

    Returns
    -------
    The value returned by ``upper_bound_samples`` (one bound per sample,
    presumably an array -- averaged with ``np.mean`` by the plotting cell).
    """
    # Only conference_papers needs the extra transpose flag.
    load_kwargs = {"dataset": dataset}
    if dataset == "conference_papers":
        load_kwargs["transpose"] = True
    features = load_unlabeled_data(**load_kwargs)

    dissimilarities = pairwise_distances(features, metric=metric)
    return upper_bound_samples(dissimilarities, kappa=kappa)


In [3]:
def cl_fcn(dataset, n_clusters, metric="euclidean"):
    """
    Cluster one unlabeled dataset with k-means for a single, fixed k.

    Parameters
    ----------
    dataset : str
        Dataset name understood by ``load_unlabeled_data``.
        "conference_papers" is loaded with ``transpose=True``.
    n_clusters : int
        Number of clusters. ``asw_optimization`` is given
        ``range(n_clusters, n_clusters + 1)``, i.e. only this one value of k.
    metric : str, default "euclidean"
        Forwarded to ``asw_optimization`` as ``asw_metric``. Keep it equal to
        the metric passed to ``ub_fcn`` so scores and bounds are comparable.

    Returns
    -------
    labels
        ``best_labels`` of the k-means solution.
    scores
        ``best_scores`` of the k-means solution (presumably per-sample
        silhouette widths -- the plotting cell averages them into the ASW).
    min_size : int
        Number of points in the smallest cluster of ``labels``.
    data_shape : tuple
        Shape of the loaded data matrix.
    """

    if dataset == "conference_papers":
        data = load_unlabeled_data(dataset=dataset, transpose=True)
    else:
        data = load_unlabeled_data(dataset=dataset)

    kmeans_solution = asw_optimization(
        algorithm=algorithm_kmeans,
        data=data,
        k_range=range(n_clusters, n_clusters + 1),  # a single k, not a search
        asw_metric=metric,
    )

    kmeans_labels = kmeans_solution['best_labels']
    kmeans_scores = kmeans_solution['best_scores']

    # Smallest cluster size; the plotting cell uses it as the second kappa.
    min_size = min(Counter(kmeans_labels).values())

    return kmeans_labels, kmeans_scores, min_size, data.shape

In [4]:
# (dataset name, number of k-means clusters), in plotting order.
datasets = [
    ("religious_texts", 3),
    ("ceramic", 2),
    ("conference_papers", 2),
    ("rna", 6),
]

In [5]:
# -------------------------------------------------
# Plot grid with silhouette plots
# -------------------------------------------------
# Each dataset gets n_kappas adjacent panels: kappa = 1 and
# kappa = size of its smallest k-means cluster (from cl_fcn).
n_kappas = 2
cols = 4
n_panels = n_kappas * len(datasets)
rows = max(1, -(-n_panels // cols))  # ceil(n_panels / cols): grid grows with `datasets`
# squeeze=False keeps `axes` 2-D even for a single row/column, so flatten() always works.
fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)
axes = axes.flatten()

for i, (dataset, n_clusters) in enumerate(datasets):

    # clustering (once per dataset; shared by every kappa panel)
    labels, scores, min_size, data_shape = cl_fcn(dataset=dataset, n_clusters=n_clusters)
    score = np.mean(scores)  # ASW of the k-means solution

    # Silhouette widths can be as low as -1: widen the left x-limit beyond
    # the default -0.1 only when some width would otherwise be clipped.
    x_min = min(-0.1, float(np.min(scores)) - 0.05)

    kappas = [1, min_size]
    assert len(kappas) == n_kappas, "grid layout assumes n_kappas panels per dataset"

    for j, kappa in enumerate(kappas):
        ub_samples = ub_fcn(dataset=dataset, kappa=kappa)

        plot_data = get_silhouette_plot_data(labels, scores, n_clusters, ub_samples)

        ub = np.mean(ub_samples)

        ax = axes[i * n_kappas + j]

        for cluster in plot_data.keys():
            cluster_data = plot_data[cluster]
            y_range = np.arange(cluster_data['y_lower'], cluster_data['y_upper'])

            # Cluster silhouette widths
            ax.fill_betweenx(
                y_range,
                0,
                cluster_data['sorted_silhouettes'],
                facecolor=cluster_data['color'],
                edgecolor='black',
                alpha=0.8,
            )

            # Cluster silhouette upper bounds
            ax.fill_betweenx(
                y_range,
                0,
                cluster_data['sorted_ub_values'],
                facecolor=cluster_data['color'],
                edgecolor=cluster_data['color'],
                alpha=0.5,
            )

            # Label cluster number just inside the left x-limit
            ax.text(
                x_min + 0.05,
                cluster_data['y_lower'] + 0.5 * cluster_data['size_cluster_i'],
                str(cluster),
            )

        ax.axvline(x=ub, color="black", linestyle="--", label=rf"upper bound ($\kappa$={kappa})")
        ax.axvline(x=score, color="black", linestyle="-", label="ASW")
        ax.set_title(f"{dataset} {data_shape}")
        ax.set_xlabel("Silhouette width")
        ax.set_xlim([x_min, 1.1])
        ax.set_yticks([])
        ax.legend(fontsize=8, loc="upper right")

# Hide grid cells left over when n_panels is not a multiple of cols.
for unused_ax in axes[n_panels:]:
    unused_ax.set_visible(False)

fig.savefig("silhouette_grid_unlabeled_data.pdf", bbox_inches="tight")
print("Silhouette grid plot generated!")
plt.close(fig)

2025-09-03 20:25:11 | utils | INFO | ==== Running dataset: religious_texts ====

2025-09-03 20:25:12 | utils | INFO | Data shape: (590, 8266)
2025-09-03 20:25:12 | utils | INFO | Data shape (zeros removed): (589, 8266)
2025-09-03 20:25:12 | utils | INFO | Optimizing ASW
100%|██████████| 1/1 [00:00<00:00,  2.51it/s]
2025-09-03 20:25:12 | utils | INFO | ==== Running dataset: religious_texts ====

2025-09-03 20:25:13 | utils | INFO | Data shape: (590, 8266)
2025-09-03 20:25:13 | utils | INFO | Data shape (zeros removed): (589, 8266)
2025-09-03 20:25:13 | utils | INFO | ==== Running dataset: religious_texts ====

2025-09-03 20:25:14 | utils | INFO | Data shape: (590, 8266)
2025-09-03 20:25:14 | utils | INFO | Data shape (zeros removed): (589, 8266)
2025-09-03 20:25:14 | utils | INFO | ==== Running dataset: ceramic ====

2025-09-03 20:25:14 | utils | INFO | Data shape: (88, 17)
2025-09-03 20:25:14 | utils | INFO | Data shape (zeros removed): (88, 17)
2025-09-03 20:25:14 | utils | INFO | Opt

Silhouette grid plot generated!
