Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
109 lines (96 sloc) 3.59 KB
import numpy as np
import matplotlib.pyplot as plt
# PARAMETERS AND VARIABLES
# Problem size: number of data points, their dimensionality, and cluster count
N = 500
D = 2
C = 8
# Marker size used for the data points in the plot
pointSize = 3
# Data points: an (N, D) array of random values
X = np.random.random(size=(N, D))
# Initial cluster centers: a (C, D) array of random values
V = np.random.random(size=(C, D))
# One random RGB triple per cluster
clusterColors = np.random.random(size=(C, 3))
# Cluster index currently assigned to each point of X (random to begin with)
assignments = np.random.randint(0, C, size=N)
# Number of figures saved so far; used to number the output images
figCount = 0
# FUNCTIONS
# Return distance between two points
def distance(p1, p2):
    """Return the Euclidean distance between points ``p1`` and ``p2``.

    Both arguments must support element-wise subtraction (e.g. NumPy
    arrays of the same shape).
    """
    # The norm of the difference vector is the Euclidean distance
    return np.linalg.norm(p2 - p1)
# Return index of closest cluster center
def closestClusterCenter(p, centers=None):
    """Return the index of the cluster center nearest to point ``p``.

    Parameters
    ----------
    p : numpy.ndarray
        The point to classify.
    centers : sequence of numpy.ndarray, optional
        Candidate cluster centers. Defaults to the module-level ``V``.

    Returns
    -------
    int
        Index of the nearest center. On ties the lowest index wins.
        Returns 0 when ``centers`` is empty.
    """
    if centers is None:
        centers = V
    # Start from infinity rather than a fixed number so the nearest center
    # is found no matter how far apart the data and the centers are.
    minDistance = np.inf
    minIndex = 0
    for index, clusterCenter in enumerate(centers):
        # Euclidean distance between the point and this center
        d = np.linalg.norm(clusterCenter - p)
        # Strict comparison keeps the first center in case of a tie
        if d < minDistance:
            minDistance = d
            minIndex = index
    return minIndex
# Check distances to cluster centers and updates assignments
# Returns true if there was a change in assignments
def updateAssignments():
    """Reassign every point in ``X`` to its closest cluster center.

    The module-level ``assignments`` array is overwritten in place.

    Returns
    -------
    bool
        True if at least one assignment changed, False otherwise.
    """
    # Snapshot of the assignments before this pass, for the final comparison
    previous = assignments.copy()
    # Nearest center for every point, written back into the shared array
    assignments[:] = [closestClusterCenter(point) for point in X]
    # Any difference from the snapshot means the clustering has not settled
    return not np.array_equal(previous, assignments)
# Make cluster centers the centroid of the points that belong to it
def updateClusterCenters(points=None, centers=None, labels=None):
    """Move each cluster center to the centroid of its assigned points.

    Centers are updated in place. A cluster with no assigned points keeps
    its current position instead of becoming NaN from a 0/0 division.

    Parameters
    ----------
    points : numpy.ndarray, optional
        Data points, one per row. Defaults to the module-level ``X``.
    centers : numpy.ndarray, optional
        Cluster centers, one per row; modified in place. Defaults to the
        module-level ``V``.
    labels : array-like of int, optional
        Cluster index of each point. Defaults to the module-level
        ``assignments``.

    Returns
    -------
    None
    """
    if points is None:
        points = X
    if centers is None:
        centers = V
    if labels is None:
        labels = assignments
    labels = np.asarray(labels, dtype=int)

    # Count how many points there are per cluster center
    counts = np.bincount(labels, minlength=len(centers))
    # Vector sum of the points in each cluster; np.add.at accumulates
    # correctly even though the same cluster index appears many times
    sums = np.zeros(centers.shape)
    np.add.at(sums, labels, points)

    # We can now calculate the centroids, but only for clusters that
    # actually own points, to avoid dividing by zero
    occupied = counts > 0
    centers[occupied] = sums[occupied] / counts[occupied][:, np.newaxis]
# PLOTTING!
def plot():
    """Draw the current clustering and save it as ``FigureNN.png``.

    Reads the module-level ``X``, ``V``, ``assignments``, ``clusterColors``,
    ``pointSize`` and ``figCount``. Data points are drawn in the color of
    the cluster they belong to, and each cluster center is drawn as a star
    in its cluster's color. Only the first two coordinates are plotted.
    """
    plt.clf()
    # Look up the color of every point's cluster in one step. Passing a
    # 2-D array of RGB rows keeps matplotlib from guessing whether a bare
    # length-3 array is a single color or values to be color-mapped.
    pointColors = clusterColors[assignments.astype(int)]
    # One scatter call for all data points instead of one artist per point
    plt.scatter(X[:, 0], X[:, 1], c=pointColors, s=pointSize)
    # Cluster centers are drawn afterwards so the stars sit on top
    plt.scatter(V[:, 0], V[:, 1], c=clusterColors, marker='*')
    # Generate a name for the image
    figname = 'Figure%02d.png' % (figCount)
    # Save it as an image
    plt.savefig(figname)
# SCRIPT
# Keep iterating until a reassignment pass leaves every point where it was
while True:
    # Reassign the points; stop as soon as nothing changes
    if not updateAssignments():
        break
    # Move each center to the centroid of the points it now owns
    updateClusterCenters()
    # Save a snapshot of this iteration
    plot()
    figCount += 1
    print( "Iteration %d..." % figCount )
# Display the final figure
plt.show()