In [2]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from ipywidgets import interact, FloatSlider, VBox, HTML

!pip install -qq scikit-learn-extra
from sklearn_extra.cluster import KMedoids
from sklearn_extra.cluster import KMedoids

# For reproducibility: seed NumPy's global RNG once when this cell/module runs.
# Note that clustering_demo() reseeds the global RNG with 0 on every call, so
# this seed does not determine the samples drawn inside that function.
np.random.seed(42)


def plot_description(text):
    """Print *text* under a ``Description:`` heading, framed by blank lines."""
    banner = "\nDescription:\n{}\n".format(text)
    print(banner)


def clustering_demo(class_separation=0.0):
    """Plot two-cluster fits of KMeans and Manhattan-metric KMedoids side by side.

    Two rectangles of 200 uniformly distributed 2-D samples are generated: one
    fixed on [0, 4] x [0, 4], the other on
    [4 - class_separation, 6 - class_separation] x [0, 4]. Both models are
    asked for two clusters, and each result is drawn in its own subplot with
    the cluster centers highlighted in red.

    The global NumPy RNG is reseeded with 0 on every call, so the same
    ``class_separation`` always yields the same samples.

    NOTE(review): the right-hand panel is titled "K-Medians" and its centers
    are labelled "Medians", but the model is ``KMedoids(metric='manhattan')``;
    presumably its centers are medoids rather than coordinate-wise medians --
    confirm that the labelling is intended.

    Args:
        class_separation: Amount by which the second rectangle is shifted to
            the left along the x-axis.

    Returns:
        None. The figure is shown via ``plt.show()``.
    """
    np.random.seed(0)

    # Sample the fixed rectangle first, then the shifted one (draw order matters
    # for reproducing the same points from the seeded global RNG).
    fixed_block = np.random.uniform(low=[0, 0], high=[4, 4], size=(200, 2))
    shifted_block = np.random.uniform(
        low=[4 - class_separation, 0],
        high=[6 - class_separation, 4],
        size=(200, 2),
    )
    X = np.vstack([fixed_block, shifted_block])

    def relabel_left_to_right(model):
        # Renumber clusters by ascending center x so that cluster 0 is always
        # the left-most one and colors stay stable across slider positions.
        centers = model.cluster_centers_
        order = np.argsort(centers[:, 0])
        remap = np.empty(order.size, dtype=model.labels_.dtype)
        remap[order] = np.arange(order.size)  # old label -> new label
        return remap[model.labels_], centers[order]

    l2_model = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)
    l1_model = KMedoids(n_clusters=2, metric='manhattan', random_state=0).fit(X)

    # One entry per subplot: the fitted result plus everything that differs
    # visually between the two panels.
    panels = [
        {
            "title": "K-Means (L2 loss)",
            "fit": relabel_left_to_right(l2_model),
            "point_colors": ("blue", "orange"),
            "center_style": {"s": 200, "marker": "X", "label": "Means"},
        },
        {
            "title": "K-Medians (L1 loss)",
            "fit": relabel_left_to_right(l1_model),
            "point_colors": ("green", "purple"),
            "center_style": {"s": 100, "marker": "D", "label": "Medians"},
        },
    ]

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    for axis, panel in zip(axes, panels):
        labels, centers = panel["fit"]

        # Samples of cluster 0 are drawn first, then cluster 1.
        for cluster_id, point_color in enumerate(panel["point_colors"]):
            members = X[labels == cluster_id]
            axis.scatter(
                members[:, 0],
                members[:, 1],
                color=point_color,
                label=f"Cluster {cluster_id}",
                s=40,
                alpha=0.7,
            )

        # Cluster centers go on top of the samples.
        axis.scatter(
            centers[:, 0],
            centers[:, 1],
            c="red",
            edgecolor="black",
            **panel["center_style"],
        )

        axis.set_title(panel["title"], fontsize=14)
        axis.set_xlim(-1, 7)
        axis.set_ylim(-1, 6)
        axis.grid(True)
        axis.legend()

    plt.show()

def clustering_demo_interact():
    """Print the demo description and show the slider-driven clustering plot.

    The explanatory text is printed through ``plot_description``. A
    "Controls" label is then displayed, followed by an ``ipywidgets``
    interactive wrapper that re-runs ``clustering_demo`` whenever the
    "Move Samples" slider changes ``class_separation``.

    NOTE(review): ``display`` is not imported anywhere in this file; it is
    presumably provided by the IPython/Jupyter session this notebook code runs
    in -- confirm before using this function outside a notebook.

    Returns:
        None.
    """
    # These names are used below but are missing from the module-level
    # ``from ipywidgets import ...`` line; import them here so the function
    # does not fail with NameError.
    from ipywidgets import Label, Layout, interactive

    plot_description(
        "Clustering Demo: K-Means (trained using L2 loss) vs. K-Medians (trained using L1 loss). Note that "
        "these plots demonstrate an example of unsupervised learning (in contrast to previous plots). "
        "The task of the K-Means and K-Medians algorithms is to find clusters independent of any pre-defined ground truth labels.\n\n"
        "Left Plot: Clusters found by K-Means with respective cluster means (red crosses).\n\n"
        # The medians are drawn with matplotlib marker "D", i.e. diamonds.
        "Right Plot: Clusters found by K-Medians with respective cluster medians (red diamonds).\n\n"
        "In this case, we instructed both algorithms to find two distinct clusters. Notice how the identified\n"
        " clusters differ as the sample distribution changes (play around with the slider below)."
    )

    # --- Interactive control ---
    sep_slider2 = FloatSlider(
        value=0.0, min=-1.0, max=4.0, step=0.1,
        description="Move Samples",
        style={'description_width': '150px'},
        layout=Layout(width='500px'),
    )
    ui_box = VBox([
        # U+1F4CA is the bar-chart emoji; written as an escape so the label
        # survives files being re-saved with the wrong encoding.
        Label(value="\U0001F4CA Controls", layout=Layout(margin="0 0 0 0")),
    ])
    interactive_plot = interactive(clustering_demo, class_separation=sep_slider2)
    display(ui_box, interactive_plot)