# **Animation of K-Means Clustering**

Source:  [https://github.com/d-insight/code-bank.git](https://github.com/d-insight/code-bank.git)  
License: [MIT License](https://opensource.org/licenses/MIT). See open source [license](LICENSE) in the Code Bank repository. 

-------------

## Overview

This illustration shows the internal, iterative workings of the K-Means algorithm for different distributions of data. Code is adapted from: https://nrsyed.com/2017/11/20/animating-k-means-clustering-in-2d-with-matplotlib/

-------------

## **Part 0**: Setup

In [None]:
# Import all packages 

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import colorsys
from helper.i3_kmeans_helper import KMeansND

from sklearn.metrics import silhouette_score
from sklearn.cluster import KMeans

from IPython.display import HTML

%matplotlib inline


In [None]:
# define constants

# Synthetic-data parameters
PTSPERCLUSTER = 100     # Number of points drawn for each generated cluster
VARIANCECOEFF = 0.05    # This controls the spread of clustered points

# Figure size passed to matplotlib when the plot is created
FIGSIZE = (24, 16)

# (min, max) limits of the region used for cluster centres and initial centroids
XMIN, XMAX = -4, 4
YMIN, YMAX = -4, 4

## **Part 1**: Define all functions


In [None]:
# define covariance matrix

# (min, max) range in which cluster centres may fall, one pair per axis
xCenterBounds = (XMIN, XMAX)
yCenterBounds = (YMIN, YMAX)

# Diagonal 2x2 covariance: each axis variance is VARIANCECOEFF times the width
# of that axis' centre range; the off-diagonal terms are zero.
covariance = np.diag([VARIANCECOEFF * (high - low) for low, high in (xCenterBounds, yCenterBounds)])

# randomly generate clusters 

def generateClusters(NUMCLUSTERS):
    """Generate random 2-D points grouped around NUMCLUSTERS random centres.

    Each centre is drawn uniformly inside xCenterBounds / yCenterBounds, then
    PTSPERCLUSTER points are sampled around it from a bivariate normal
    distribution with the module-level `covariance` matrix.

    Parameters
    ----------
    NUMCLUSTERS : int
        Number of clusters to generate.

    Returns
    -------
    numpy.ndarray
        Array of shape (NUMCLUSTERS * PTSPERCLUSTER, 2); rows belonging to the
        same cluster are stored contiguously.
    """
    xLow, xHigh = xCenterBounds
    yLow, yHigh = yCenterBounds

    # Uniform samples in [0, 1), rescaled onto the permitted range of each axis.
    centers = np.random.random_sample((NUMCLUSTERS, 2))
    centers[:, 0] = centers[:, 0] * (xHigh - xLow) + xLow
    centers[:, 1] = centers[:, 1] * (yHigh - yLow) + yLow

    points = np.zeros((NUMCLUSTERS * PTSPERCLUSTER, 2))

    for clusterIdx, center in enumerate(centers):
        firstRow = clusterIdx * PTSPERCLUSTER
        lastRow = firstRow + PTSPERCLUSTER
        points[firstRow:lastRow, :] = np.random.multivariate_normal(center, covariance, PTSPERCLUSTER)

    return points

# randomly initialize cluster centroids

def initializeCentroids(K, points):
    """Return K randomly placed starting centroids.

    Each centroid is drawn uniformly from the rectangle
    [XMIN, XMAX) x [YMIN, YMAX), so the y coordinate respects the y bounds
    rather than reusing the x bounds for both axes.

    Parameters
    ----------
    K : int
        Number of centroids to create.
    points : numpy.ndarray
        Not used by the sampling; kept so existing callers that pass the data
        set continue to work.

    Returns
    -------
    numpy.ndarray
        Array of shape (K, 2) with one (x, y) centroid per row.
    """
    # low/high are per-column bounds that broadcast against the (K, 2) output.
    return np.random.uniform(low=(XMIN, YMIN), high=(XMAX, YMAX), size=(K, 2))

# update cluster information

def animate(clusterInfo):
    """Update the plot for one K-Means iteration (animation frame callback).

    Parameters
    ----------
    clusterInfo : tuple
        (currentCentroids, classifications, iteration): the centroid
        coordinates indexed as [k, axis], the cluster index assigned to each
        row of `points`, and the integer iteration counter.

    Notes
    -----
    Reads the module-level K, points, clusterObjList, centroidObjList and
    iterText, so those must be defined before the animation runs.
    """
    (currentCentroids, classifications, iteration) = clusterInfo

    for k in range(K):

        # Rows of `points` currently assigned to cluster k.
        updatedClusterData = points[classifications == k, :]
        clusterObjList[k].set_data(updatedClusterData[:, 0], updatedClusterData[:, 1])

        # Line2D.set_data expects sequences; bare scalars are rejected by newer
        # Matplotlib releases, so wrap the single centroid coordinate in lists.
        centroidObjList[k].set_data([currentCentroids[k, 0]], [currentCentroids[k, 1]])

    iterText.set_text('i = {:d}'.format(iteration))


def setColors(K):
    """Create the figure and one pair of empty, colour-matched artists per cluster.

    Parameters
    ----------
    K : int
        Number of clusters to prepare artists for.

    Returns
    -------
    tuple
        (fig, ax, iterText, clusterObjList, centroidObjList), where the two
        lists hold, per cluster, the line object for its points ('x' markers)
        and the line object for its centroid ('o' marker with black edge), and
        iterText is an initially empty annotation in the lower-left corner of
        the axes.
    """
    # Tick label size must be configured before the axes are created.
    for tickGroup in ('xtick', 'ytick'):
        plt.rc(tickGroup, labelsize=16)

    fig, ax = plt.subplots(figsize=FIGSIZE)

    clusterObjList, centroidObjList = [], []

    for k in range(K):
        # Spread the hues evenly around the colour wheel, one per cluster.
        hue = k / K
        clusterColor = tuple(colorsys.hsv_to_rgb(hue, 0.8, 0.8))

        # Points of the cluster: markers only, no connecting line.
        pointArtist = ax.plot([], [], ls='None', marker='x', markersize=12, color=clusterColor)[0]
        clusterObjList.append(pointArtist)

        # Centroid of the cluster: same colour, larger marker with a black edge.
        centroidArtist = ax.plot([], [], ls='None', marker='o', markersize=16, markeredgecolor='k', color=clusterColor)[0]
        centroidObjList.append(centroidArtist)

    iterText = ax.annotate('', xy=(0.01, 0.01), xycoords='axes fraction')

    return fig, ax, iterText, clusterObjList, centroidObjList

def setAxisLimits(ax, points):
    """Fit the axes to the data, leaving a 5 % margin on every side.

    Parameters
    ----------
    ax : object
        Axes-like object providing set_xlim(low, high) and set_ylim(low, high).
    points : numpy.ndarray
        2-D array whose column 0 holds x values and column 1 holds y values.
    """
    pad = 0.05

    xValues, yValues = points[:, 0], points[:, 1]

    xLow, xHigh = np.amin(xValues), np.amax(xValues)
    yLow, yHigh = np.amin(yValues), np.amax(yValues)

    # The margin is a fixed fraction of the data extent along each axis.
    xMargin = pad * (xHigh - xLow)
    yMargin = pad * (yHigh - yLow)

    ax.set_xlim(xLow - xMargin, xHigh + xMargin)
    ax.set_ylim(yLow - yMargin, yHigh + yMargin)
    
def computeSilhouette(points, K):
    """Fit K-Means with K clusters on `points` and print the silhouette score.

    Parameters
    ----------
    points : numpy.ndarray
        Data to cluster, one sample per row.
    K : int
        Number of clusters passed to scikit-learn's KMeans.

    The score is rounded to four decimals and printed; nothing is returned.
    """
    labels = KMeans(K).fit(points).labels_
    roundedScore = round(silhouette_score(points, labels), 4)

    print('Silhouette score for K={}: \t{}'.format(K, roundedScore))
    

## **Part 2**: Animate K-Means


In [None]:
# Choose K (clusters to find), NUMCLUSTERS (clusters used to generate the data)
# and whether to generate a fresh data set.

K             = 6      # Number of clusters to find
NUMCLUSTERS   = 10     # Number of clusters to generate
generateData  = True  # True = randomly generate data, False = keep data currently in state

# -----------------------------------------------------------------------

# Generate random data, or reuse `points` left over from an earlier run of this cell
if generateData:
    points       = generateClusters(NUMCLUSTERS)
elif 'points' not in globals():
    # Nothing to reuse: fail with an explicit message.
    raise NameError('No data available. Randomly generate data by setting generateData = True.')
    
# Random starting centroids for the K clusters to find
initialCentroids = initializeCentroids(K, points)
# NOTE(review): KMeansND comes from the helper module; its generator presumably
# yields the (centroids, classifications, iteration) tuples that animate() unpacks - confirm.
genFunc          = KMeansND(initialCentroids, points).getGeneratorFunc()
# Figure, axes, iteration label and the per-cluster artists that animate() updates
fig, ax, iterText, clusterObjList, centroidObjList = setColors(K)
setAxisLimits(ax, points)
# One frame per item produced by genFunc, 500 ms between frames, looping when finished
animObj          = animation.FuncAnimation(fig, animate, frames=genFunc, repeat=True, interval=500)
# Show silhouette score
computeSilhouette(points, K)

# Close the figure; presumably so the notebook does not also show a static copy of it
plt.close(fig)

# Render the animation as HTML/JavaScript; being the last expression, it becomes the cell output
HTML(animObj.to_jshtml())

In [None]:
# Silhouette scores from K=2 to K=20
# Each call fits a new KMeans model on the current `points` and prints one line,
# so the scores for the different values of K can be compared.
for k in range(2, 21):
    computeSilhouette(points, k)