In [None]:
import pandas as pd
import torch
from matplotlib.colors import to_hex
from matplotlib import colormaps
from plotnine import *
from plotnine.animation import PlotnineAnimation
from IPython.core.display import HTML


def make_plot(samples: torch.Tensor, centroids: torch.Tensor):
    """
    Plots a set of samples and centroids in a 2D scatter plot.

    Every sample is colored like its nearest centroid (Euclidean distance via
    ``torch.cdist``). Centroids are drawn as larger points in their own color.
    Gradients are never tracked through this function.

    Args:
        samples: 2D tensor with one sample per row. Distances use all columns,
            but only columns 0 (x) and 1 (y) are plotted.
        centroids: 2D tensor with one centroid per row, with the same number
            of columns as ``samples``.

    Returns:
        A plotnine ``ggplot`` object with the legend hidden.
    """

    # Work on detached tensors so plotting never records autograd history.
    samples = samples.detach()
    centroids = centroids.detach()
    num_centroids = centroids.shape[0]

    # Index of the nearest centroid for every sample, as a plain Python list.
    # Converting once avoids per-element 0-d tensor indexing and formatting.
    closest_centroids = torch.cdist(samples, centroids).argmin(dim=1).tolist()

    # One label per centroid. Samples reuse the label of their nearest
    # centroid, so both are mapped to the same fill color.
    labels = [f"Centroid {i + 1}" for i in range(num_centroids)]

    # Set up plot data, concatenating sample and centroid coordinates.
    # Centroids get a larger size value than samples.
    data = {
        "x": samples[:, 0].tolist() + centroids[:, 0].tolist(),
        "y": samples[:, 1].tolist() + centroids[:, 1].tolist(),
        "size": [0.5] * samples.shape[0] + [2] * num_centroids,
        "color": [labels[idx] for idx in closest_centroids] + labels,
    }

    # A unique color per centroid, evenly spaced along the "rainbow" colormap.
    # Plain Python floats (not 0-d tensors) are passed to the colormap.
    color_map = colormaps.get_cmap("rainbow")
    positions = torch.linspace(0, 1, num_centroids).tolist()
    color_dict = {
        label: to_hex(color_map(pos)) for label, pos in zip(labels, positions)
    }

    df = pd.DataFrame(data)
    return (
        ggplot(df)
        + geom_point(aes(x="x", y="y", fill="color", size="size"), show_legend=False)
        + scale_fill_manual(color_dict)
    )


# Hyperparameters.
learning_rate = 0.0002  # Step size of each centroid gradient-descent update.
num_training_iterations = 100  # Number of optimization steps to run.
num_training_samples = 1000
num_centroids = 4

# Training data: a superposition of two point clouds.
# - The first half is Gaussian, with standard deviation 0.5, centered at
#   (-0.5, -0.5).
# - The remaining points are uniform on the unit square [0, 1)^2.
# Neither part tracks gradients.
samples = torch.cat(
    (
        torch.randn((num_training_samples // 2, 2)) * 0.5 - 0.5,
        torch.rand((num_training_samples - num_training_samples // 2, 2)),
    )
)

# Initial centroids: Gaussian draws around the origin (standard deviation
# 0.2), marked as leaf tensors that require gradients.
# Keeping them close together matters. The loss only sums each sample's
# distance to its own nearest centroid, so a centroid with no samples
# assigned receives zero gradient. Such a "zombie" centroid stays put for as
# long as no sample is closest to it.
centroids = torch.randn((num_centroids, 2)).mul(0.2).requires_grad_()


def step():
    """
    Performs a single optimization step, returning a plot.

    Each sample is assigned to its nearest centroid. The function then takes
    one gradient-descent step, scaled by the module-level ``learning_rate``,
    on the sum of sample-to-assigned-centroid distances. The module-level
    ``centroids`` tensor is updated in place, and the result of
    ``make_plot(samples, centroids)`` for the new state is returned.
    """

    # The augmented assignment below would make `centroids` a local name
    # without this declaration; the tensor itself is modified in place.
    # `learning_rate` is only read, so it needs no `global` statement.
    global centroids

    distances = torch.cdist(samples, centroids)
    # Sum of each sample's distance to its closest centroid. The gradient of
    # `min` flows only to the minimising centroid, so every centroid is pulled
    # toward its own samples only.
    loss = distances.min(dim=1).values.sum()
    loss.backward()

    # Apply the update outside autograd so the step itself is not recorded,
    # then clear the accumulated gradient for the next call.
    with torch.no_grad():
        centroids -= learning_rate * centroids.grad
    centroids.grad.zero_()

    return make_plot(samples, centroids)


# Produce the animation frames. Each call to step() advances the optimization
# by one iteration and returns the plot of the updated state.
plots = [step() for _iteration in range(num_training_iterations)]
ani = PlotnineAnimation(plots, repeat_delay=500, interval=100)

# Render the frames as an inline JavaScript/HTML animation player.
HTML(ani.to_jshtml())