In [2]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

In [32]:
# Seed NumPy's global RNG for reproducibility
# (make_blobs below is additionally seeded through its own random_state)
np.random.seed(42)

# Generate a dataset with 4 clusters of different sizes:
# 100, 50, 500, and 150 points respectively
cluster_sizes = [100, 50, 500, 150]
n_samples = sum(cluster_sizes)

# Hand-picked cluster centers, one (x, y) pair per cluster
centers = [(2, 2), (4.75, 3), (3.5, 4.5), (1, 4.5)]

# Draw the points: cluster k gets cluster_sizes[k] samples around centers[k]
# with its own standard deviation from cluster_std
X, y_true = make_blobs(n_samples=cluster_sizes,
                       centers=centers,
                       cluster_std=[0.5, 0.8, 0.3, 0.6],
                       random_state=42)

# Custom K-Means class that records centroids and assignments at each iteration
class VisualKMeans:
    """Small K-Means implementation that keeps a per-iteration history.

    After ``fit`` the ``history`` attribute is a list with one dict per
    iteration, each holding:

    * ``'iteration'`` -- 1-based iteration number,
    * ``'centroids'`` -- copy of the centroids used for that iteration's
      assignment step,
    * ``'labels'``    -- copy of the cluster index assigned to every sample.

    Parameters
    ----------
    n_clusters : int
        Number of clusters (and centroids).
    max_iter : int
        Maximum number of assignment/update iterations.
    random_state : int
        Seed used to pick the initial centroids among the samples.
    """

    def __init__(self, n_clusters=4, max_iter=10, random_state=42):
        self.n_clusters = n_clusters
        self.max_iter = max_iter
        self.random_state = random_state
        self.history = []
        # pyplot.get_cmap takes the same (name, lut) arguments as the
        # matplotlib.cm.get_cmap helper, which was deprecated in
        # Matplotlib 3.7 and removed in 3.9.
        self.colors = plt.get_cmap('viridis', n_clusters)  # Distinct colors for each cluster

    def fit(self, X):
        """Run K-Means on ``X`` and record every iteration in ``history``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to cluster.

        Returns
        -------
        VisualKMeans
            ``self``, so calls can be chained.

        Raises
        ------
        ValueError
            If ``X`` is not two-dimensional or ``n_clusters`` is not in
            ``1..n_samples``.
        """
        X = np.asarray(X, dtype=float)
        if X.ndim != 2:
            raise ValueError("X must be a 2-D array of shape (n_samples, n_features)")
        if not 1 <= self.n_clusters <= X.shape[0]:
            raise ValueError("n_clusters must be between 1 and the number of samples")

        # Start from a clean slate so that refitting does not mix the
        # snapshots of several runs.
        self.history = []

        # Initial centroids: n_clusters distinct samples picked at random
        rng = np.random.RandomState(self.random_state)
        i = rng.permutation(X.shape[0])[:self.n_clusters]
        centroids = X[i]

        for iteration in range(self.max_iter):
            # Assign points to nearest centroid.
            # distances[k, j] is the Euclidean distance from centroid k to sample j.
            distances = np.sqrt(((X - centroids[:, np.newaxis])**2).sum(axis=2))
            labels = np.argmin(distances, axis=0)

            # Record current state
            self.history.append({
                'iteration': iteration + 1,
                'centroids': centroids.copy(),
                'labels': labels.copy()
            })

            # Update centroids. A cluster that received no points keeps its
            # previous centroid; taking the mean of an empty selection would
            # produce NaNs that poison every later iteration.
            new_centroids = centroids.copy()
            for k in range(self.n_clusters):
                members = X[labels == k]
                if len(members):
                    new_centroids[k] = members.mean(axis=0)

            # Check for convergence: once the assignment stops changing the
            # recomputed means are bit-for-bit identical.
            if np.array_equal(centroids, new_centroids):
                break

            centroids = new_centroids

        return self

# Fit the recording K-Means on the generated data; fit() returns the
# estimator itself, so construction and fitting can be chained
kmeans = VisualKMeans(
    n_clusters=4,
    max_iter=10,
    random_state=42,
).fit(X)

# Build the figure and the three artists the animation updates per frame:
# the data points, the centroid markers and the iteration label
fig, ax = plt.subplots(figsize=(12, 9))
scatter = ax.scatter(X[:, 0], X[:, 1], s=30, alpha=0.7)
centroid_marks = ax.scatter(
    [], [],
    c='red', marker='X', s=100, linewidth=2, edgecolor='red',
)
iteration_text = ax.text(
    0.02, 0.95, '',
    transform=ax.transAxes,
    fontsize=12,
    bbox={'facecolor': 'white', 'alpha': 0.8},
)
# Fixed axis limits with a one-unit margin around the data on every side
ax.set(
    xlim=(X[:, 0].min() - 1, X[:, 0].max() + 1),
    ylim=(X[:, 1].min() - 1, X[:, 1].max() + 1),
    xlabel="Feature 1",
    ylabel="Feature 2",
)

# Animation update function
def update(iteration):
    """Draw one animation frame and return the artists that were modified.

    Parameters
    ----------
    iteration : int
        Frame index. Frame 0 shows the unclustered data; frame ``k >= 1``
        shows the snapshot stored in ``kmeans.history[k - 1]``.

    Returns
    -------
    tuple
        ``(scatter, centroid_marks, iteration_text)``.
    """
    if iteration == 0:
        # Initial state before first iteration
        scatter.set_array(np.zeros(len(X)))
        scatter.set_cmap('gray')
        centroid_marks.set_offsets(np.empty((0, 2)))
        iteration_text.set_text('Initial State')
        return scatter, centroid_marks, iteration_text

    state = kmeans.history[iteration-1]

    # Frame 0 attached a scalar array to the scatter. While such an array is
    # set, Matplotlib derives the face colors from the colormap at draw time
    # and overrides explicitly assigned colors, so it has to be cleared first.
    scatter.set_array(None)

    # Update scatter plot with cluster colors; the palette is cycled so that
    # more clusters than palette entries do not raise an IndexError.
    colors = ['crimson', 'darkorange', 'lime', 'cornflowerblue']
    scatter.set_color([colors[label % len(colors)] for label in state['labels']])

    # Update centroids
    centroid_marks.set_offsets(state['centroids'])

    # Update iteration text
    iteration_text.set_text(f'Iteration: {state["iteration"]}')

    return scatter, centroid_marks, iteration_text

# Build the animation: one frame per recorded iteration plus the leading
# "initial state" frame, shown one second apart
ani = FuncAnimation(
    fig,
    update,
    frames=1 + len(kmeans.history),
    interval=1000,
    blit=True,
)
# Close the figure so that only the animation is displayed, not a static plot
plt.close()

# Render the animation as an interactive HTML/JS widget (Jupyter notebook)
HTML(ani.to_jshtml())

# Alternative: Save as GIF
# ani.save('kmeans_iterations.gif', writer='pillow', fps=1)

# If not in Jupyter, you can display each iteration as separate plots
# if False:  # Change to True to see static plots
#     for i, state in enumerate(kmeans.history):
#         plt.figure(figsize=(8, 6))
#         plt.scatter(X[:, 0], X[:, 1], c=state['labels'], cmap='viridis', s=30, alpha=0.7)
#         plt.scatter(state['centroids'][:, 0], state['centroids'][:, 1],
#                   c='red', marker='x', s=100, linewidth=2)
#         plt.title(f"K-Means Clustering - Iteration {state['iteration']}")
#         plt.xlabel("Feature 1")
#         plt.ylabel("Feature 2")
#         plt.show()

# [notebook cell output, not code] self.colors = plt.cm.get_cmap('viridis', n_clusters)  # Distinct colors for each cluster
