In [3]:
import matplotlib.animation
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML

In [4]:
# One plotting colour per cluster / centroid, looked up by cluster index.
# NOTE(review): only three entries, so plotting more than three clusters
# would raise IndexError.
colors = [
    "#fcba03",  # amber
    "#0318fc",  # blue
    "#ad05a5",  # magenta
]

## generate_cluster_data()

Generates a reproducible, shuffled (450, 2) array of synthetic 2-D points drawn from three Gaussian clusters.

In [5]:
def generate_cluster_data(seed=10293847):
    """Generate a shuffled array of synthetic 2-D points from three Gaussian blobs.

    :param seed: seed for ``numpy.random.default_rng``. The default reproduces
        the fixed dataset used throughout this notebook.
    :returns: float array of shape (450, 2); column 0 is x, column 1 is y.
    """
    rng = np.random.default_rng(seed)

    # (x mean, x std, y mean, y std, number of points) for each cluster.
    cluster_specs = [
        (100, 30, 100, 30, 100),
        (800, 25, 700, 100, 200),
        (300, 80, 600, 120, 150),
    ]

    xs, ys = [], []
    for x_mean, x_std, y_mean, y_std, size in cluster_specs:
        # Draw x then y per cluster: changing this order changes the
        # generated data for a given seed.
        xs.append(rng.normal(x_mean, x_std, size))
        ys.append(rng.normal(y_mean, y_std, size))

    points = np.column_stack((np.concatenate(xs), np.concatenate(ys)))
    # Shuffle rows so points from different clusters are interleaved.
    rng.shuffle(points)
    return points

## KMeans class
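A minimal k-means implementation whose steps (choose initial centroids, assign points to clusters, update centroids) are separate methods, so the algorithm can be stepped through one phase at a time — for example, to animate it.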

In [6]:
class KMeans:
    """Step-by-step k-means clustering over the rows of a 2-D array.

    Each phase of the algorithm is a separate method so callers (e.g. an
    animation) can observe the state between steps:

    1. ``initial_centroids`` - choose the starting centroids.
    2. ``assign_clusters``   - group every point with its nearest centroid.
    3. ``update_centroids``  - move each centroid to the mean of its cluster.

    Steps 2 and 3 are repeated until ``update_centroids`` returns False.
    """

    def __init__(self, data, n_centroids=3, seed=None):
        """
        :param data: array of shape (n_points, n_features), one point per row.
        :param n_centroids: number of clusters (k).
        :param seed: seed for the generator used by random centroid selection.
        """
        self._rng = np.random.default_rng(seed)
        self._n_centroids = n_centroids
        self._data = data
        self._width = data.shape[1]
        self._length = data.shape[0]
        self._centroids = np.zeros((self._n_centroids, self._width))
        self._clusters = []

    @property
    def centroids(self):
        """Current centroids, shape (n_centroids, n_features)."""
        return self._centroids

    @property
    def clusters(self):
        """Points per centroid from the last ``assign_clusters`` call (empty list before)."""
        return self._clusters

    def initial_centroids(self, centroids=None):
        """Set the starting centroids.

        :param centroids: optional array of shape (n_centroids, n_features).
            When omitted, ``n_centroids`` distinct data rows are picked at random.
        :raises ValueError: if ``centroids`` has the wrong shape, or if random
            selection needs more centroids than there are data points.
        """
        if centroids is None:
            self._random_centroids()
            return

        # np.array always copies, so later updates never alias the caller's array.
        centroids = np.array(centroids, dtype=float)
        expected = (self._n_centroids, self._width)
        if centroids.shape != expected:
            raise ValueError(
                f"centroids must have shape {expected}, got {centroids.shape}"
            )
        self._centroids = centroids

    def _random_centroids(self):
        """Use ``n_centroids`` distinct, randomly chosen data rows as centroids."""
        if self._n_centroids > self._length:
            # Without this guard the rejection loop below would never terminate.
            raise ValueError(
                f"cannot pick {self._n_centroids} distinct centroids "
                f"from {self._length} data points"
            )

        picked = []
        while len(picked) < self._n_centroids:
            rx = int(self._rng.integers(0, self._length))
            # Reject repeats so no two centroids start on the same data row.
            if rx not in picked:
                picked.append(rx)
        self._centroids = np.array(self._data[picked], dtype=float)

    def assign_clusters(self):
        """Group every data point with its nearest centroid (Euclidean distance).

        Afterwards ``clusters[i]`` holds the points closest to centroid ``i``,
        in their original data order. Ties go to the lowest-numbered centroid.
        An empty cluster is an array of shape (0, n_features), so it can still
        be sliced like ``cluster[:, 0]``.
        """
        # distances[p, c] = ||data[p] - centroids[c]||
        diffs = self._data[:, np.newaxis, :] - self._centroids[np.newaxis, :, :]
        distances = np.linalg.norm(diffs, axis=2)
        # argmin returns the first minimum, i.e. the lowest centroid index on ties.
        labels = np.argmin(distances, axis=1)
        self._clusters = [
            self._data[labels == idx] for idx in range(self._n_centroids)
        ]

    def update_centroids(self):
        """Move each centroid to the mean of the points currently assigned to it.

        A centroid whose cluster is empty keeps its previous position, so the
        number of centroids never changes.

        :returns: True if any centroid moved, False otherwise (converged).
        """
        new_centroids = np.array(self._centroids, dtype=float)
        for idx, cluster in enumerate(self._clusters):
            if len(cluster) > 0:
                new_centroids[idx] = cluster.sum(axis=0) / len(cluster)

        changed = not np.array_equal(new_centroids, self._centroids)
        self._centroids = new_centroids
        return changed

In [2]:
def cluster_animation(frames=7):
    """Animate k-means on the synthetic cluster data, one algorithm step per frame.

    Frame 0 shows the raw points and frame 1 adds the starting centroids.
    After that, even frames (2, 4, ...) re-assign clusters and odd frames
    (3, 5, ...) move the centroids to their cluster means.

    :param frames: total number of frames to render (default 7).
    :returns: the ``matplotlib.animation.FuncAnimation``.
    """
    points = generate_cluster_data()
    # Fixed starting centroids so every rendering shows the same run.
    start_centroids = np.array([[900, 800], [200, 600], [600, 100]])

    def new_model():
        # Fresh model in its initial state (no clusters assigned yet).
        model = KMeans(points, seed=81828)
        model.initial_centroids(start_centroids)
        return model

    km = new_model()
    fig, ax = plt.subplots()

    def draw_centroids():
        for idx, centroid in enumerate(km.centroids):
            ax.scatter(centroid[0], centroid[1], s=100, c=colors[idx], marker="*")

    def animate(t):
        nonlocal km
        ax.clear()
        ax.set_xlim([0, 1100])
        ax.set_ylim([0, 1000])

        if t == 0:
            # animate() mutates km, so restart from the initial state whenever
            # the animation is (re)rendered from its first frame.
            km = new_model()
            ax.scatter(points[:, 0], points[:, 1], s=1)
            return

        if t == 1:
            ax.scatter(points[:, 0], points[:, 1], s=1)
            draw_centroids()
            return

        if t % 2 == 0:
            km.assign_clusters()
        else:
            km.update_centroids()

        for idx, cluster in enumerate(km.clusters):
            if len(cluster) == 0:
                # Nothing to draw; also avoids slicing an empty 1-D array.
                continue
            ax.scatter(cluster[:, 0], cluster[:, 1], s=1, c=colors[idx])
        draw_centroids()

    anim = matplotlib.animation.FuncAnimation(fig, animate, frames=frames)
    plt.close(fig)
    return anim

In [86]:
class Animation:
    """Declarative builder for simple frame-by-frame matplotlib animations.

    Register drawing callbacks with ``add``. On each frame, every callback
    registered for that frame index is called with the frame's ``Axes``.
    ``play`` renders the result as a ``FuncAnimation``.
    """

    def __init__(self, plot_init=None):
        """
        :param plot_init: optional callable ``plot_init(ax)`` run at the start
            of every frame, after the axes are cleared (e.g. to set limits).
        """
        self._plot_init = plot_init
        self._frames = []

    def add(self, frames, callback):
        """Call ``callback(ax)`` on each frame index in ``frames``.

        :param frames: iterable of integer frame indices.
        :param callback: callable taking the ``Axes`` to draw on.
        :returns: ``self``, so calls can be chained.
        """
        # A frozenset gives O(1) membership checks per frame, and a one-shot
        # iterable (e.g. a generator) is not exhausted after the first frame.
        self._frames.append((frozenset(frames), callback))
        return self

    def _frame_count(self):
        """Return one past the highest registered frame index (at least 1)."""
        last = max((max(frames) for frames, _ in self._frames if frames), default=0)
        return last + 1

    def play(self, n_frames=None):
        """Render the registered callbacks as an animation.

        :param n_frames: number of frames to render. Defaults to one past the
            highest frame index passed to ``add``.
        :returns: the ``matplotlib.animation.FuncAnimation``. The figure is
            closed before returning, presumably so the notebook does not also
            display a static copy of it.
        """
        if n_frames is None:
            # The frame count must come from the registered frame indices, not
            # from the number of callbacks.
            n_frames = self._frame_count()

        fig, ax = plt.subplots()

        def animate(t):
            ax.clear()
            if self._plot_init is not None:
                self._plot_init(ax)

            for frames, callback in self._frames:
                if t in frames:
                    callback(ax)

        anim = matplotlib.animation.FuncAnimation(fig, animate, frames=n_frames)
        plt.close(fig)
        return anim


In [87]:
def fnplot(fn, start=-1, end=1, steps=100):
    """Plot ``fn`` at ``steps`` evenly spaced points on [start, end], then show the figure.

    ``fn`` is applied to one scalar at a time, so it need not be vectorised.
    """
    xs = np.linspace(start, end, steps)
    ys = list(map(fn, xs))
    plt.plot(xs, ys)
    plt.show()

In [1]:
def anim_init(ax):
    """Fix both axes of ``ax`` to the range 0-10."""
    for set_limits in (ax.set_xlim, ax.set_ylim):
        set_limits([0, 10])

def points(ax):
    """Plot the two example points (3, 3) and (6, 7) and label each with its coordinates."""
    ax.scatter([3, 6], [3, 7])
    labels = (
        (3, 2.5, "(3,3)"),
        (6.2, 7, "(6,7)"),
    )
    for label_x, label_y, label in labels:
        ax.text(label_x, label_y, label)

def seg_1(ax):
    """Draw the horizontal leg from (3, 3) to (6, 3)."""
    ax.arrow(x=3, y=3, dx=3, dy=0)
    
def seg_2(ax):
    """Draw the vertical leg from (6, 3) to (6, 7)."""
    ax.arrow(x=6, y=3, dx=0, dy=4)
def seg_3(ax):
    """Draw the diagonal from (3, 3) to (6, 7), i.e. the distance being measured."""
    ax.arrow(x=3, y=3, dx=3, dy=4)
def seg_4(ax):
    """Label the horizontal leg with its length, x_2 - x_1."""
    ax.text(x=4, y=2, s="$x_2 - x_1$")
def seg_5(ax):
    """Label the vertical leg with its length, y_2 - y_1."""
    ax.text(x=6.5, y=5, s="$y_2 - y_1$")

def seg_6(ax):
    """Write the Euclidean distance formula at (1, 5)."""
    # Raw string: "\s" in a normal literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+). The resulting text is unchanged.
    ax.text(1, 5, r"$ \sqrt{ (x_2 - x_1)^2 + (y_2 - y_1)^2}$", rotation=0)

def distance_animation():
    """Build the step-by-step Euclidean distance animation.

    Step ``i`` in ``steps`` first appears on frame ``i`` and stays visible on
    every later frame, so the diagram is assembled one element at a time.
    """
    frames = 7
    steps = [points, seg_1, seg_2, seg_3, seg_4, seg_5, seg_6]

    animation = Animation(anim_init)
    for first_frame, draw in enumerate(steps):
        animation.add(list(range(first_frame, frames)), draw)
    return animation.play()



In [2]:
def cosine_similarity(a, b):
    """Return the cosine of the angle between vectors ``a`` and ``b``.

    The result lies in [-1, 1] up to floating-point rounding. A zero-length
    input makes the denominator zero (NumPy typically returns nan and emits
    a RuntimeWarning).
    """
    magnitude_product = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / magnitude_product