K-Means for the Iris dataset on Apache Spark (based on https://github.com/apache/spark/blob/master/examples/src/main/python/kmeans.py)

In [None]:
import numpy as np

In [None]:
# NOTE: The distributed file system (HDFS) must be running before this
# cell is executed.
# Read the input data from HDFS. `sc` is not defined in this file; it is
# presumably the SparkContext provided by the notebook / pyspark session
# (TODO confirm). The later cells treat every element of `lines` as one
# comma-separated text record.
lines = sc.textFile("hdfs:///user/lsda/iris.data")

In [None]:
def parseVector(line):
    """Convert one comma-separated record into a numeric feature vector.

    Every field except the last one is converted with ``float``. The last
    field is the class label, which K-Means does not use, so it is dropped.

    Args:
        line: One text record whose fields are separated by commas.

    Returns:
        A numpy array holding the parsed feature values (empty if the
        record consists of a single field only).

    Raises:
        ValueError: If one of the feature fields is not a valid float.
    """
    fields = line.split(',')

    # everything before the trailing label is a feature
    features = fields[:-1]

    return np.array(list(map(float, features)))

In [None]:
def closestPoint(p):
    """ Returns the index of the cluster center that is closest to
    the point p.

    The candidate centers are read from the broadcast variable
    ``centers_bc`` and compared via the squared Euclidean distance.
    On ties the lowest index wins; 0 is returned if no center is
    strictly closer than infinity (e.g., if there are no centers).
    """

    nearest_idx = 0
    smallest_sq_dist = float("+inf")

    # walk over the broadcasted centers and remember the best one so far
    for idx, center in enumerate(centers_bc.value):

        offset = p - center
        sq_dist = np.sum(offset ** 2)

        # strict comparison: an equally distant later center never wins
        if sq_dist < smallest_sq_dist:
            nearest_idx, smallest_sq_dist = idx, sq_dist

    return nearest_idx

In [None]:
# important: drop malformed records beforehand! A valid line consists of
# exactly five comma-separated fields, i.e., it contains exactly four commas.
lines_filtered = lines.filter(lambda line: line.count(",") == 4)

# parse each remaining line into a feature vector and cache the result,
# since the data is scanned again in every K-Means iteration
data = (
    lines_filtered
    .map(parseVector)
    .cache()
)

In [None]:
# number of clusters (and of initial centers that are sampled)
K = 3
# the iteration stops as soon as the summed squared movement of the
# cluster centers is no longer larger than this threshold
convergeDist = 1e-4

In [None]:
# take K points without replacement as initial centers; the fixed seed 1
# makes the run reproducible
centers = data.takeSample(False, K, 1)

# summed squared movement of the centers in the latest iteration; start
# with infinity so that the loop body is executed at least once, no
# matter how large the convergence threshold is
d = float("inf")
iteration = 0

while d > convergeDist:

    # broadcast the current cluster centers
    # (not optimal to use here: see discussion during lecture)
    centers_bc = sc.broadcast(centers)

    # map each point to (closest cluster index, (point, 1))
    closest = data.map(lambda p: (closestPoint(p), (p, 1)))

    # for each key (cluster index), compute the element-wise sum of the
    # assigned points and the number of assigned points
    stats = closest.reduceByKey(lambda p1_c1, p2_c2: (p1_c1[0] + p2_c2[0], p1_c1[1] + p2_c2[1]))

    # for each (key,value) == (cluster_index, (sum, count)), compute the
    # new center as the mean of the assigned points
    new_centers_pairs = stats.map(lambda st: (st[0], st[1][0] / st[1][1])).collect()

    # the job that needed this broadcast is done and the next iteration
    # creates a new one, so release the copies cached on the executors
    # (otherwise one stale broadcast per iteration piles up)
    centers_bc.unpersist()

    # compute distance between old and new cluster centers
    d = sum(np.sum((centers[index] - c) ** 2) for (index, c) in new_centers_pairs)

    # the (key,value) pairs are not necessarily sorted; we have to
    # overwrite the current cluster centers in the right way (based on
    # the provided cluster index); a cluster that received no points
    # does not show up here and keeps its previous center
    for (iK, p) in new_centers_pairs:
        centers[iK] = p

    print("Iteration {}: {}".format(iteration, d))
    iteration += 1

In [None]:
# report the cluster centers the iteration ended with, one per line
print("Final cluster centers:\n")
for center in centers:
    print(center)