<a href="https://colab.research.google.com/github/zostaw/learning-regression-methods/blob/main/kmeans.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install jaxtyping



In [None]:
from jaxtyping import Array, Float, PyTree, Int
import jax.numpy as jnp
import jax.random as jrandom
import jax
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.colors as mcolors
import numpy as np
from IPython.display import HTML

In [None]:
# Root PRNG key (seed 70). Later cells call jrandom.split on it to obtain a subkey per experiment.
key = jrandom.key(70)

# Starting from scratch

As explained in "The Elements of Statistical Learning" by T. Hastie, R. Tibshirani, J. Friedman, the algorithm for K-Means goes as follows:
1. For each data point, the closest cluster center is identified (Euclidean distance)
2. Each cluster center is replaced by the coordinate-wise average of all data points that are closest to it

These two steps are repeated until the assignments no longer change.

## First step

We want to calculate the Euclidean distance between each data point and each centroid.  
In the basic example with a single data point and a single centroid, the formula would be:
$$
distance = \sqrt{ \sum_{i=1}^{n} (point_i - centroid_i)^2 }
$$
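where $distance \in \mathbb{R}$, $point, centroid \in \mathbb{R}^n$ and $n$ is the dimensionality of the vector space.  
We want to calculate the $distance$ to each centroid, then take the argmin for each data point in order to pick its associated centroid. Because the square root is monotonic, skipping it does not change the argmin, so the code below compares squared distances.  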

In [None]:
def closest_centroid(
    datapoint: Float[Array, "feature_dims"],
    centroid: Float[Array, "centroid_dims feature_dims"],
) -> Int[Array, ""]:
    """Return the index of the centroid nearest to a single datapoint.

    Args:
        datapoint: one point, shape ``(feature_dims,)``.
        centroid: all centroids stacked row-wise, shape
            ``(centroid_dims, feature_dims)``.

    Returns:
        Scalar integer array holding the row index of the closest centroid.

    The square root of the Euclidean distance is skipped on purpose: it is
    monotonic, so the argmin over squared distances is the same.
    """
    # Broadcasting (feature_dims,) against (centroid_dims, feature_dims)
    # yields one difference vector per centroid.
    diff: Float[Array, "centroid_dims feature_dims"] = datapoint - centroid
    squared_distance_per_centroid: Float[Array, "centroid_dims"] = jnp.sum(diff**2, axis=1)
    return jnp.argmin(squared_distance_per_centroid, axis=0)

# Toy data: 5 random 2-D datapoints and 3 random 2-D centroids.
p = np.random.random((5, 2))
c = np.random.random((3, 2))

# vmap maps closest_centroid over the rows of p (in_axes=0) while every call
# sees the full centroid matrix c (in_axes=None); the result is one centroid
# index per datapoint.
jax.vmap(closest_centroid, in_axes=(0, None))(p, c)

Array([0, 0, 1, 0, 2], dtype=int32)

## Second step

Once we have found the association between all datapoints and centroids, we calculate the new position of each centroid by averaging the positions of the points that are closest to it.  

In [None]:
def k_means_verbose(datapoints: Float[Array, "N feature_dims"], k: int, num_iterations: int, key1):
    """Run K-Means for a fixed number of iterations and print every step.

    The ``k`` initial centroids are sampled with ``jrandom.uniform`` using
    ``key1``. Each iteration prints the centroid index assigned to every
    datapoint, then, cluster by cluster, the member datapoints and the
    recomputed centroid position. Nothing is returned; the function only
    traces the algorithm on stdout.
    """
    feature_dims = datapoints.shape[-1]
    centroids = jrandom.uniform(key1, shape=(k, feature_dims))
    assign_nearest = jax.vmap(closest_centroid, in_axes=(0, None))

    for i in range(num_iterations):
        print("___________________________\nIteration:", i)

        # Step 1: index of the nearest centroid for each datapoint.
        nearest_ids = assign_nearest(datapoints, centroids)
        print("Distances:", nearest_ids, "\n")

        # Step 2: move each centroid to the mean of its members.
        num_clusters = centroids.shape[0]
        for cluster_id in range(num_clusters):
            print("\nCluster id", cluster_id)
            members = datapoints[jnp.where(nearest_ids == cluster_id)]
            print("Datapoints: ", members, members.shape)
            # A cluster without members keeps its previous centroid.
            if jnp.size(members) != 0:
                new_position = np.mean(members, axis=0)
                print("Centroid new pos:", new_position)
                centroids = centroids.at[cluster_id].set(new_position)

        print("___________________________\n\n")

# Split the root key: key1 seeds this run, the new `key` is kept for later cells.
key, key1 = jrandom.split(key)
# Trace K-Means on the 5 toy points `p` with k=3 centroids for 4 iterations.
k_means_verbose(p, 3, 4, key1)

___________________________
Iteration: 0
Distances: [0 0 1 0 1] 


Cluster id 0
Datapoints:  [[0.091053   0.75945372]
 [0.31835557 0.8426555 ]
 [0.22677102 0.84457623]] (3, 2)
Centroid new pos: [0.21205986 0.81556182]

Cluster id 1
Datapoints:  [[0.62605219 0.09644724]
 [0.70605345 0.08680608]] (2, 2)
Centroid new pos: [0.66605282 0.09162666]

Cluster id 2
Datapoints:  [] (0, 2)
___________________________


___________________________
Iteration: 1
Distances: [0 0 1 0 1] 


Cluster id 0
Datapoints:  [[0.091053   0.75945372]
 [0.31835557 0.8426555 ]
 [0.22677102 0.84457623]] (3, 2)
Centroid new pos: [0.21205986 0.81556182]

Cluster id 1
Datapoints:  [[0.62605219 0.09644724]
 [0.70605345 0.08680608]] (2, 2)
Centroid new pos: [0.66605282 0.09162666]

Cluster id 2
Datapoints:  [] (0, 2)
___________________________


___________________________
Iteration: 2
Distances: [0 0 1 0 1] 


Cluster id 0
Datapoints:  [[0.091053   0.75945372]
 [0.31835557 0.8426555 ]
 [0.22677102 0.84457623]] (3, 2)


# Visualization

Let us rewrite the function in a more concise way.  

In [None]:
def k_means_iteration(centroids, datapoints):
    """Perform one K-Means update.

    Every datapoint is assigned to its nearest centroid, then each centroid
    that received at least one datapoint is moved to the mean of its members.
    Centroids without members are left where they were.

    Returns:
        A pair ``(centroids, members_per_centroid)``: the updated centroid
        array and a list whose i-th entry holds the datapoints that were
        assigned to centroid i (using the centroid positions passed in).
    """
    # Step 1: nearest-centroid index for each datapoint.
    assign_nearest = jax.vmap(closest_centroid, in_axes=(0, None))
    nearest_ids = assign_nearest(datapoints, centroids)

    # Step 2: recompute each centroid from its members.
    members_per_centroid = []
    num_centroids = centroids.shape[0]
    for idx in range(num_centroids):
        members = datapoints[jnp.where(nearest_ids == idx)]
        members_per_centroid.append(members)
        if jnp.size(members) != 0:
            centroids = centroids.at[idx].set(np.mean(members, axis=0))
    return centroids, members_per_centroid

def k_means(datapoints: Float[Array, "N feature_dims"], k: int, key1, max_num_iterations=100):
    """Iterate K-Means and record the history needed for the animation.

    Only 2-D datapoints are accepted (checked with an assert). The ``k``
    initial centroids are sampled with ``jrandom.uniform`` using ``key1``.
    Iteration stops as soon as an update leaves every centroid unchanged, or
    after ``max_num_iterations`` updates.

    Returns:
        ``(centroids_history, centroids_datapoints_history)``. Entry ``t`` of
        the first list holds the centroids *before* update ``t``; entry ``t``
        of the second holds, per centroid, the datapoints assigned to those
        centroids during update ``t``.
    """
    assert datapoints.shape[-1] == 2
    centroids = jrandom.uniform(key1, shape=(k, datapoints.shape[-1]))
    centroids_history, centroids_datapoints_history = [], []

    for _ in range(max_num_iterations):
        previous = centroids
        centroids_history.append(previous)

        centroids, members_per_centroid = k_means_iteration(previous, datapoints)
        centroids_datapoints_history.append(members_per_centroid)

        # Converged: the update did not move any centroid.
        if (previous == centroids).all():
            break

    return centroids_history, centroids_datapoints_history

In [None]:
# Palette indexed by centroid id: entry i colours centroid i and, in the
# animation, the datapoints assigned to it.
colors = ['tab:BLUE', 'tab:ORANGE', 'tab:BROWN', 'tab:GREY', 'tab:PINK']

def draw_centroids(centroids, datapoints, point_colors=None, ax=None):
    """Scatter-plot the centroids and datapoints of a single iteration.

    Args:
        centroids: array of shape ``(num_centroids, 2)``.
        datapoints: array of shape ``(num_datapoints, 2)``.
        point_colors: optional colours, centroids first and datapoints after,
            one entry per plotted point. When omitted, centroids take their
            colour from the module palette ``colors`` and every datapoint is
            drawn in "blue".
        ax: Matplotlib axes to draw on; a new 8x8 figure is created when
            omitted.

    Returns:
        The axes that were drawn on. Both axis limits are fixed to [0, 1].
    """
    # Identity check: do not rely on the truthiness of an Axes object.
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))

    num_centroids = centroids.shape[0]
    num_datapoints = datapoints.shape[0]

    if point_colors is None:
        # Cycle through the palette so that more centroids than palette
        # entries does not raise IndexError.
        centroid_colors = [colors[i % len(colors)] for i in range(num_centroids)]
        point_colors = np.array(centroid_colors + ["blue"] * num_datapoints)

    # Centroids are drawn larger than datapoints.
    point_sizes = np.array([100] * num_centroids + [15] * num_datapoints)
    ax.scatter(jnp.concatenate((centroids[:, 0], datapoints[:, 0])),
               jnp.concatenate((centroids[:, 1], datapoints[:, 1])),
               c=point_colors,
               s=point_sizes)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    return ax


def generate_centroid_animation(centroids_history, centroids_datapoints_history):
    """Build an animation with one frame per recorded K-Means iteration.

    Args:
        centroids_history: list of centroid arrays, one per frame.
        centroids_datapoints_history: list with one entry per frame; each
            entry is a list whose i-th element is the ``(m_i, 2)`` array of
            datapoints assigned to centroid i in that frame.

    Returns:
        The ``matplotlib.animation.FuncAnimation`` object. The backing figure
        is closed before returning so the notebook does not also display it
        as a static plot.
    """
    fig, ax = plt.subplots(figsize=(8, 8))

    def update(frame):
        ax.clear()
        centroids = centroids_history[frame]
        clusters = centroids_datapoints_history[frame]

        # Colours: centroids first, then each datapoint in the colour of its
        # cluster. The palette is cycled so more clusters than palette
        # entries does not raise IndexError.
        point_colors = [colors[i % len(colors)] for i in range(len(centroids))]
        for i, members in enumerate(clusters):
            point_colors.extend([colors[i % len(colors)]] * len(members))

        # Concatenate whole clusters at once. Empty clusters have shape
        # (0, 2) and contribute nothing, so an empty first cluster no longer
        # leaves `datapoints` undefined, and the per-point vstack is avoided.
        datapoints = jnp.concatenate([jnp.asarray(members) for members in clusters], axis=0)

        draw_centroids(centroids, datapoints, np.array(point_colors), ax)
        return ax,

    ani = animation.FuncAnimation(fig=fig, func=update, frames=len(centroids_history),
                                  interval=1000, blit=False, repeat=True)
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)
    return ani


# Fresh subkey for the centroid initialisation inside k_means.
key, key1 = jrandom.split(key)

num_samples, num_centroids = 2000, 4
# Cluster random 2-D points and keep the per-iteration history.
centroids_history, centroids_datapoints_history = k_means(np.random.random((num_samples, 2)), num_centroids, key1)
# Turn the history into an animation and embed it in the notebook as HTML/JS.
ani = generate_centroid_animation(centroids_history, centroids_datapoints_history)
HTML(ani.to_jshtml())