In [None]:
import random, operator, subprocess
from pyspark.sql.types import *

# Parse the CSV into an RDD holding one tuple of ints per input line.
# Every comma-separated field goes through int(), so a non-integer field
# raises when the RDD is evaluated. The parsed RDD is marked for caching.
rdd = (
    sc.textFile('data10k_6attr.csv')
    .map(lambda line: tuple(int(field) for field in line.split(',')))
    .cache()
)

# Minimum number of points a cluster must hold (the main loop refers to this
# as k-anonymity).
k = 10
# Number of leading coordinates of each point used by the distance, sum and
# centroid helpers.
dimension = 6
# Column names for the output DataFrame (only the first `dimension` are used).
headers = ['age', 'height', 'weight', 'blood_sugar_level', 'child', 'exercise_hours']

# Count once: every rdd.count() call is a separate Spark action.
total_pts = rdd.count()
# Inclusive bounds on the number of clusters to try. Floor division keeps
# both bounds integers (range() needs ints) regardless of which division
# semantics `/` has.
# NOTE(review): presumably the bounds assume clusters of k .. 2k-1 points -
# confirm.
max_cluster = total_pts // k
min_cluster = total_pts // (2*k-1)

# Upper limit on centroid-update iterations per initialisation.
loop_for_converge = 20
# Number of random initial centroid sets tried for each cluster count.
different_combination = 30

def dist(x, y):
    """Return the Manhattan (L1) distance between two points.

    Only the first `dimension` coordinates of `x` and `y` are compared.
    """
    return sum(abs(x[axis] - y[axis]) for axis in range(dimension))

def get_nearest_centroid_idx(x, centroids):
    """Return the key of the centroid closest to point `x`.

    `centroids` maps a cluster key to its centroid point; closeness is
    measured with `dist`. Raises ValueError when `centroids` is empty.
    """
    # Built in the iteration order of `centroids`, so ties resolve the same
    # way as a plain insertion loop would.
    dists = {cluster: dist(x, centroids[cluster]) for cluster in centroids}
    return min(dists, key=dists.get)

def assign_to_cluster(pt, available_centroids):
    """Pair a point with its nearest available centroid.

    Returns `(cluster_key, ([pt], [distance]))`. The point and its distance
    are wrapped in single-element lists so that records sharing a key can be
    merged by list concatenation.
    """
    idx = get_nearest_centroid_idx(pt, available_centroids)
    gap = dist(pt, available_centroids[idx])
    return (idx, ([pt], [gap]))

def calculate_pts_sum(pts):
    """Return the coordinate-wise sum of `pts` as a list of length `dimension`.

    An empty `pts` yields a list of zeros.
    """
    totals = [0] * dimension
    for pt in pts:
        # Fold this point into the running per-coordinate totals.
        totals = [running + pt[axis] for axis, running in enumerate(totals)]
    return totals

def calculate_centroid(pts_sum, nb_pts):
    """Return the mean point given coordinate sums and the number of points.

    The count is converted to float first, so the division is always a true
    division. Raises ZeroDivisionError when `nb_pts` is 0 (and `dimension`
    is non-zero).
    """
    denominator = float(nb_pts)
    return [pts_sum[axis] / denominator for axis in range(dimension)]

def popup_available_pts(pts, dists):
    """Yield the points that remain after the `k` nearest ones are set aside.

    `dists[i]` is the distance belonging to `pts[i]`. Points are yielded in
    ascending distance order; equal distances keep their original relative
    order because the sort is stable.
    """
    ranked = sorted(enumerate(dists), key=lambda pair: pair[1])
    for idx, _ in ranked[k:]:
        yield pts[idx]

def keep_pts(pts, dists):
    """Return the (at most) `k` points with the smallest distances.

    `dists[i]` is the distance belonging to `pts[i]`. The result is ordered
    by ascending distance; equal distances keep their original relative
    order because the sort is stable.
    """
    ranked = sorted(enumerate(dists), key=lambda pair: pair[1])
    return [pts[idx] for idx, _ in ranked[:k]]

def calculate_cost(pts, centroid):
    """Return the total `dist` from every point in `pts` to `centroid`.

    An empty `pts` costs 0.
    """
    return sum(dist(pt, centroid) for pt in pts)

def is_converge(old_cens, new_cens):
    """Return True when the centroids barely moved between two iterations.

    The summed absolute coordinate change between matching centroids is
    divided by the sum of all old coordinates; the centroids count as
    converged when the absolute value of that ratio is below 1e-6.

    Args:
        old_cens: previous centroids, either a dict keyed by cluster id or a
            sequence indexed by position.
        new_cens: updated centroids, addressable by the same keys/indices.

    Returns:
        bool: whether the relative change is below the threshold. When the
        old coordinates sum to zero (including empty input) the ratio is
        undefined, so the result is True only if nothing moved at all.

    Raises:
        KeyError / IndexError: if `new_cens` lacks an entry that `old_cens`
            has.
    """
    # A dict is walked by its own keys so cluster ids need not be the
    # contiguous range 0..n-1; sorting keeps the 0..n-1 visiting order for
    # the contiguous case. Sequences keep the positional walk.
    if isinstance(old_cens, dict):
        keys = sorted(old_cens)
    else:
        keys = range(len(old_cens))

    diff = 0
    old_sum = 0
    for key in keys:
        old_cen = old_cens[key]
        new_cen = new_cens[key]
        for j in range(dimension):
            diff += abs(new_cen[j] - old_cen[j])
            old_sum += old_cen[j]

    if old_sum == 0:
        # Avoid dividing by zero: with no scale to compare against, only an
        # exact match counts as converged.
        return diff == 0
    return abs(float(diff) / old_sum) < 0.000001

def write_to_output(assignment, centroids):
    """Save one row per assigned point, each row being that point's centroid.

    `assignment` is an RDD of `(cluster, pts)` pairs and `centroids` maps a
    cluster key to its centroid. The rows are turned into a DataFrame whose
    columns are the first `dimension` names of `headers`, then saved to
    'output.txt', overwriting any previous output.
    """
    def repeat_centroid(entry):
        # Emit the cluster's centroid once for every point in the cluster.
        cluster, pts = entry
        return [centroids[cluster]] * len(pts)

    rows = assignment.flatMap(repeat_centroid)
    frame = sqlContext.createDataFrame(rows, headers[:dimension])
    frame.save('output.txt', mode='overwrite')
    
# Search driver: for every candidate cluster count and several random
# initialisations, repeatedly assign points to centroids (forcing clusters up
# to k members), recompute centroids, and write out the assignment whenever
# its cost is the lowest seen so far.
#
# min_cost / min_cost_rdd are initialised once here, so the best result is
# tracked across ALL cluster counts and initialisations, not per count.
min_cost_rdd = None
min_cost = float('inf')
for want_cluster in range(min_cluster, max_cluster+1):
    print "trying " + str(want_cluster) + " clusters"
    for combination_idx in range(different_combination): # try different combination of initial centroid
        print "trying combination " + str(combination_idx) + "/" + str(different_combination)
        
        # Pick `want_cluster` data points (sampled without replacement) as the
        # initial centroids and store them as a dictionary having the index as
        # key and the centroid point as value.
        tmp_centroids = rdd.takeSample(False, want_cluster)
        centroids = {}
        for i in range(len(tmp_centroids)):
            centroids[i] = tmp_centroids[i]
        
        for converge_idx in range(loop_for_converge): # try to converge
            print "trying to converge " + str(converge_idx) + "/" + str(loop_for_converge)
            
            # Each centroid-update iteration starts from scratch: all points
            # and all centroids are available, nothing is assigned yet.
            available_pts_rdd = rdd
            available_centroids = centroids
            assignment = None
            
            # ensure each cluster has at least k members (k-anonymity)
            while(True):
                # Seed every available centroid with empty point/distance
                # lists so that clusters which attract no point still show up
                # after the union + reduceByKey below.
                cluster_rdd = sc.parallelize([(i, ([], [])) for i in available_centroids])
                # (cluster, (points, distances)) for the points still
                # available, each sent to its nearest available centroid.
                dist_to_cluster_rdd = available_pts_rdd.map(lambda pt: assign_to_cluster(pt, available_centroids)) \
                                         .reduceByKey(lambda (pt1, dist1), (pt2, dist2): (pt1+pt2,dist1+dist2)) \
                                         .cache()
                
                assignment_for_all_rdd = cluster_rdd.union(dist_to_cluster_rdd) \
                                           .reduceByKey(lambda (pts1, dists1), (pts2, dists2): (pts1+pts2, dists1+dists2)) \
                                            .cache()
            
                # Clusters that received fewer than k points IN THIS PASS.
                clusters_require_more_rdd = assignment_for_all_rdd.filter(lambda (cluster, (pts, dists)): len(dists) < k).cache()
#                 print "Require more: " + str(clusters_require_more_rdd.count())
#                 print "Assignemtn: " + str(assignment_for_all_rdd.collect())
                # NOTE(review): the test is `> 1`, so exactly one under-filled
                # cluster falls through to the else branch even though the
                # comment says "some cluster" - confirm whether `> 0` was meant.
                if(clusters_require_more_rdd.count() > 1): # some cluster has less than k members
                    # Every cluster (under-filled ones included) keeps at most
                    # its k nearest points from this pass.
                    completed_pts_rdd = assignment_for_all_rdd.map(lambda (cluster, (pts, dists)): (cluster, keep_pts(pts, dists)))
                    assignment = completed_pts_rdd if assignment == None else assignment.union(completed_pts_rdd).cache()
                    
                    # Only the under-filled clusters stay available for the
                    # next pass.
                    # NOTE(review): the next pass counts only the points a
                    # cluster receives in that pass, not the ones it already
                    # kept above - confirm this is intended.
                    available_centroids = clusters_require_more_rdd.map(lambda (cluster, _): (cluster, available_centroids[cluster])) \
                                            .collectAsMap()

                    # The points beyond the k nearest of every over-full
                    # cluster are released for re-assignment.
                    available_pts_rdd = assignment_for_all_rdd.filter(lambda (cluster, (pts, dists)): len(dists) > k) \
                                            .flatMap(lambda (cluster, (pts, dists)): popup_available_pts(pts, dists))
                else: # each cluster has at least k elements
                    # Accept this pass as-is: every cluster keeps all points
                    # it received. Clusters that received no point are absent
                    # from dist_to_cluster_rdd and so add nothing here.
                    completed_pts_rdd = dist_to_cluster_rdd.map(lambda (cluster, (pts, dists)): (cluster, pts))
                    assignment = completed_pts_rdd if assignment == None else assignment.union(completed_pts_rdd).cache()
                    break
            
            # Merge the per-pass point lists of each cluster.
            assignment = assignment.reduceByKey(lambda x, y: x+y).cache() # final assignment rdd (cluster, pts)
            # calculate new centroids based on all points inside a cluster
            new_centroids = assignment.map(lambda (cluster, pts): (cluster, calculate_centroid(calculate_pts_sum(pts), len(pts))))\
                    .collectAsMap()
#             print assignment.collect()

            # Cost = total distance of every point to its cluster's new
            # centroid; update the saved output if it achieves a smaller cost.
            cost = assignment.map(lambda (cluster, pts): calculate_cost(pts, new_centroids[cluster])).reduce(lambda x,y: x+y)
            print "cost: " + str(cost)
            if(cost < min_cost):
                min_cost_rdd = assignment
                min_cost = cost
                write_to_output(min_cost_rdd, new_centroids)
                
            # Stop iterating this initialisation once the centroids have
            # stopped moving; otherwise continue from the new centroids.
            if(is_converge(centroids, new_centroids)):
                break
            centroids = new_centroids

trying 526 clusters
trying combination 0/30
trying to converge 0/20
cost: 223139.460424
trying to converge 1/20
cost: 260233.705755
trying to converge 2/20
cost: 283736.998457
trying to converge 3/20
cost: 279588.768123
trying to converge 4/20
cost: 279438.418918
trying to converge 5/20
cost: 295931.386301
trying to converge 6/20
cost: 152624.305212
trying to converge 7/20
cost: 352318.315347
trying to converge 8/20
cost: 320881.295529
trying to converge 9/20
cost: 320719.295107