In [None]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
from plotnine import *
from plotnine.animation import PlotnineAnimation
from IPython.core.display import HTML
from scipy.spatial.distance import cdist


def make_plot(samples: torch.Tensor, centroids: torch.Tensor):
    """
    Plot samples and centroids in a 2D scatter plot.

    Every sample is filled with the colour of its nearest centroid (Euclidean
    distance over all columns); centroids are drawn after the samples, as
    larger points, in their own colour.

    Args:
        samples: Tensor of shape (num_samples, d). Columns 0 and 1 are used as
            the x and y coordinates.
        centroids: Tensor of shape (num_centroids, d). May require grad; it
            is detached before plotting.

    Returns:
        A plotnine ``ggplot`` object with the legend hidden.
    """
    # Detach and copy to host memory so grad-tracking / GPU tensors are fine.
    samples = samples.detach().cpu().numpy()
    centroids = centroids.detach().cpu().numpy()
    num_samples, num_centroids = samples.shape[0], centroids.shape[0]

    # Index of the nearest centroid for each sample.
    closest_centroids = cdist(samples, centroids).argmin(axis=1)

    # One label per centroid; samples reuse the label of their nearest
    # centroid so they share its fill colour.
    labels = [f"Centroid {i + 1}" for i in range(num_centroids)]

    # Samples first, then centroids, so the centroids are drawn on top.
    df = pd.DataFrame({
        "x": np.concatenate([samples[:, 0], centroids[:, 0]]),
        "y": np.concatenate([samples[:, 1], centroids[:, 1]]),
        "size": [0.5] * num_samples + [2] * num_centroids,
        "color": [labels[c] for c in closest_centroids] + labels,
    })

    # Evenly spaced stops on the "rainbow" colour map. The 0.0 stop is
    # skipped, so centroid i (0-based) uses stop (i + 1) / num_centroids.
    # NumPy floats are passed to the colour map instead of torch scalars.
    color_map = plt.get_cmap("rainbow")
    stops = np.linspace(0.0, 1.0, num_centroids + 1)[1:]
    color_dict = {
        label: to_hex(color_map(stop)) for label, stop in zip(labels, stops)
    }

    return (
        ggplot(df)
        + geom_point(aes(x="x", y="y", fill="color", size="size"), show_legend=False)
        + scale_fill_manual(color_dict)
    )


# Hyperparameters
learning_rate = 0.0002
num_training_iterations = 100
num_training_samples = 1000
num_centroids = 4

# Training data: a mixture of two 2D point clouds. The first half is Gaussian
# (std 0.5) centred at (-0.5, -0.5); the rest is uniform on [0, 1)^2. The
# Gaussian draw is taken before the uniform one, and both before the centroid
# draw below, so the global random stream is consumed in a fixed order.
num_gaussian_samples = num_training_samples // 2
num_uniform_samples = num_training_samples - num_gaussian_samples
gaussian_cloud = torch.randn((num_gaussian_samples, 2)).mul(0.5).sub(0.5)
uniform_cloud = torch.rand((num_uniform_samples, 2))
samples = torch.cat([gaussian_cloud, uniform_cloud])

# Initial centroids: a tight Gaussian blob (std 0.2) around the origin. Keeping
# them close together matters. This simple implementation can leave "zombie"
# centroids: the loss only sums each sample's distance to its nearest
# centroid, so a centroid that starts with no samples assigned gets no gradient
# and never moves.
centroids = torch.randn((num_centroids, 2)).mul(0.2).requires_grad_()


def step():
    """
    Perform one gradient-descent step on the centroids and return a plot.

    The loss is the sum over all samples of the Euclidean distance from the
    sample to its nearest centroid. Each sample therefore sends gradient only
    to its nearest centroid, and centroids with no assigned samples stay put.

    Reads the module-level ``samples`` and ``learning_rate`` and updates the
    module-level ``centroids`` tensor in place.

    Returns:
        The plot produced by ``make_plot`` for the updated centroids.
    """
    # Needed because of the augmented assignment below; learning_rate and
    # samples are only read, so they need no declaration.
    global centroids

    # All sample-to-centroid distances at once, shape
    # (num_samples, num_centroids), in a single vectorised op.
    distances = (samples[:, None, :] - centroids[None, :, :]).norm(2, dim=2)
    # min(dim=1) keeps, per sample, the first minimal distance (its nearest
    # centroid); gradient flows back only through that entry. With no samples
    # the sum is still connected to `centroids`, so .grad is zeros, not None.
    loss = distances.min(dim=1).values.sum()

    loss.backward()
    with torch.no_grad():
        centroids -= learning_rate * centroids.grad

    # Clear the accumulated gradient so the next step starts from zero.
    centroids.grad.zero_()
    return make_plot(samples, centroids)


# Produce an animation. Frames are generated on demand: each frame pulled
# from `plots` runs one training step and yields the resulting plot.
def _training_frames(num_frames):
    """
    Yield one plot per training step, for `num_frames` steps.
    """
    for _ in range(num_frames):
        yield step()


plots = _training_frames(num_training_iterations)
ani = PlotnineAnimation(plots, interval=100, repeat_delay=500)
HTML(ani.to_jshtml())