In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def D(a, b):
    """Return the absolute difference ``|a - b|`` of two intensities.

    The subtraction is carried out in (at least) float64 so that unsigned
    inputs such as ``uint8`` pixels cannot wrap around: with plain
    ``abs(a - b)``, ``uint8(5) - uint8(10)`` evaluates to 251 instead of 5.

    Args:
        a: Scalar or array-like intensity value(s).
        b: Scalar or array-like intensity value(s), broadcastable with *a*.

    Returns:
        The element-wise absolute difference as floating point.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # Widen to float64 (or wider, e.g. complex) before subtracting.
    wide = np.result_type(a, b, np.float64)
    return abs(np.subtract(a, b, dtype=wide))

def normalize(data):
    """Linearly rescale *data* so that its values span ``[0, 255]``.

    The input is never modified; the computation runs on a float copy, so
    integer arrays are accepted as well.  A z-score step before the min-max
    rescale is unnecessary because min-max scaling is invariant under any
    positive affine transform, so the values are rescaled directly.

    Args:
        data: Array-like of numeric values.

    Returns:
        A new float array with the same shape as *data*, with the minimum
        mapped to 0 and the maximum mapped to 255.  If every value is the
        same, an array of zeros is returned instead of dividing by zero.
    """
    lo, hi = 0., 255.
    values = np.array(data, dtype=float)  # np.array copies: caller's data stays intact
    v_min, v_max = values.min(), values.max()
    span = v_max - v_min
    if span == 0:
        # Constant input: there is no spread to rescale.
        return np.full_like(values, lo)
    return (values - v_min) * (hi - lo) / span + lo

def init_c(data, k):
    """Choose *k* initial centroids by farthest-point seeding.

    The first centroid is the value of a uniformly random pixel.  Every
    further centroid is the pixel value whose distance to its nearest
    already-chosen centroid is largest (ties go to the first pixel in
    row-major order).

    Distances are computed in floating point so unsigned image dtypes
    cannot wrap around, and no ``sys`` sentinel is needed.

    Args:
        data: 2-D array of intensities.
        k: Number of centroids requested; values below 2 yield one centroid.

    Returns:
        1-D array of the chosen centroid values, in the dtype of *data*.
    """
    R, C = data.shape[0], data.shape[1]
    last = (np.random.randint(R), np.random.randint(C))
    centroid = [data[last]]

    values = np.asarray(data, dtype=float)
    # Distance from every pixel to its closest centroid chosen so far.
    nearest = np.full(values.shape, np.inf)
    for _ in range(k - 1):
        # Only the most recently added centroid can shrink a distance.
        nearest = np.minimum(nearest, np.abs(values - values[last]))
        last = np.unravel_index(nearest.argmax(), nearest.shape)
        centroid.append(data[last])
    return np.array(centroid)

def assign_to_clusters(data, centroid):
    """Label every pixel with the index of its nearest centroid.

    The comparison is vectorized and done in floating point, so unsigned
    image dtypes do not wrap around when a pixel is darker than a centroid.

    Args:
        data: 2-D array of intensities.
        centroid: Sequence or 1-D array of centroid intensities.

    Returns:
        Integer array with the shape of *data*; each entry is the index into
        *centroid* of the closest centroid (lowest index on ties).
    """
    values = np.asarray(data, dtype=float)
    centers = np.asarray(centroid, dtype=float).ravel()
    # Broadcast to shape (..., k): distance of each pixel to each centroid.
    gaps = np.abs(values[..., np.newaxis] - centers)
    return gaps.argmin(axis=-1)

def cluster_filter(data,label):
    """Build one 0/1 integer mask per distinct label value.

    Args:
        data: Array whose first dimension fixes the row count of each mask.
        label: Array of cluster labels.

    Returns:
        List of integer arrays, one per value of ``np.unique(label)`` in
        ascending order; each holds 1 where *label* equals that value and 0
        elsewhere, reshaped to ``(data.shape[0], -1)``.
    """
    n_rows = data.shape[0]
    masks = []
    for cluster_id in np.unique(label):
        hits = (label == cluster_id).flatten()
        as_ints = np.array(list(map(int, hits)))
        masks.append(as_ints.reshape(n_rows, -1))
    return masks

def plot_clusters(data,label):
    """Draw one grayscale panel per cluster and return the cluster masks.

    Each panel shows *data* multiplied by that cluster's 0/1 mask, so only
    the pixels belonging to the cluster keep their intensity.  Panels are
    laid out five per row.

    Args:
        data: 2-D array of intensities.
        label: Array of cluster labels with the same shape as *data*.

    Returns:
        The list of masks produced by ``cluster_filter(data, label)``.
    """
    masks = cluster_filter(data,label)
    k = len(masks)
    col = 5
    row = k // col + (k % col > 0)  # ceil(k / col)
    plt.figure(figsize=(15,15))
    # Figure-level title: plt.title() here would attach the text to an
    # implicit full-figure axes created before any subplot exists.
    plt.suptitle(f'All {k} Clusters')
    for i, mask in enumerate(masks):
        plt.subplot(row,col,i+1)
        plt.imshow(data*mask, cmap = 'gray')
        plt.title(f'Cluster - {i+1}')
    return masks

def distortion(data,label,centroid):
    """Return the mean over clusters of the mean absolute deviation.

    For each distinct label ``l`` the average of ``|x - centroid[l]|`` over
    the pixels carrying that label is computed; the result is the unweighted
    average of those per-cluster values.

    The arithmetic is vectorized and done in floating point, so unsigned
    image or centroid dtypes cannot wrap around.

    Args:
        data: Array of intensities.
        label: Array of cluster labels, same shape as *data*; each label is
            used as an index into *centroid*.
        centroid: Sequence or array of centroid intensities.

    Returns:
        The distortion as a float; ``0.0`` when *label* is empty.
    """
    values = np.asarray(data, dtype=float)
    label = np.asarray(label)
    clusters = np.unique(label)
    if clusters.size == 0:
        return 0.0
    per_cluster = [
        np.abs(values[label == cid] - float(centroid[cid])).mean()
        for cid in clusters
    ]
    return np.mean(per_cluster)


def silhouette_score(data,label):
    """Estimate a silhouette score from one random sample per cluster.

    For every distinct label one member ``s`` is drawn with
    ``np.random.choice``.  ``a`` is the mean distance from ``s`` to all
    members of its own cluster (``s`` itself included), ``b`` is the smallest
    mean distance from ``s`` to the members of any other cluster, and the
    cluster contributes ``(b - a) / max(a, b)``.  The score is the average
    contribution over all clusters.

    Args:
        data: Array of intensities.
        label: Array of cluster labels with the same shape as *data*; labels
            need not be contiguous.

    Returns:
        The sampled silhouette estimate as a float.  ``0.0`` is returned
        when fewer than two clusters are present, since no neighbouring
        cluster exists to compare against.
    """
    values = np.asarray(data, dtype=float)
    label = np.asarray(label)
    clusters = np.unique(label)
    if len(clusters) < 2:
        return 0.0

    total = 0.0
    for cid in clusters:
        own = values[label == cid]
        sample = np.random.choice(own)
        a = np.abs(own - sample).mean()
        # Nearest neighbouring cluster: minimum over ALL other clusters.
        b = min(
            np.abs(values[label == other] - sample).mean()
            for other in clusters
            if other != cid
        )
        scale = max(a, b)
        if scale > 0:
            # Normalise this cluster's term only, not the running total.
            total += (b - a) / scale
    return total / len(clusters)
    
def optimum_k(data, show_elbow = False, min_ = 1, max_ = 5):
    """Pick the cluster count in ``[min_, max_]`` with the best silhouette.

    ``kmeans`` is run once for every candidate cluster count; its
    second-to-last return value is taken as the silhouette score and its
    last as the distortion.

    Args:
        data: Image array forwarded to ``kmeans``.
        show_elbow: When true, plot distortion against the cluster count.
        min_: Smallest cluster count to try (inclusive).
        max_: Largest cluster count to try (inclusive).

    Returns:
        The candidate cluster count whose run had the highest silhouette
        score (the first such candidate on ties).

    Raises:
        ValueError: If the range ``[min_, max_]`` is empty.
    """
    candidates = list(range(min_, max_ + 1))
    if not candidates:
        raise ValueError(f"empty range of cluster counts: min_={min_}, max_={max_}")

    scores = []
    distortions = []
    for n_clusters in candidates:
        # Track the cluster count itself: the first element returned by
        # kmeans is the iteration counter, not k.
        *_, score, dist = kmeans(data, k=n_clusters)
        scores.append(score)
        distortions.append(dist)

    if show_elbow:
        plt.plot(candidates, distortions)
        plt.xlabel('Number of cluster')
        plt.ylabel('Distortion')
        plt.show()
    return candidates[int(np.argmax(scores))]


def kmeans(data,k=5, auto_ = False, max_iters = 50,threshold=0.12, plot_cluster = False):
    """Cluster pixel intensities into *k* groups with Lloyd-style iterations.

    Args:
        data: 2-D array of intensities.  It is rescaled with ``normalize``
            when any value lies outside ``[0, 255]``; the caller's array is
            left untouched.
        k: Number of clusters.
        auto_: Accepted for interface compatibility; not used here.
        max_iters: Maximum number of update iterations.
        threshold: Iteration stops once every centroid moved by less than
            this amount between two consecutive updates.
        plot_cluster: When true, draw the clusters with ``plot_clusters``.

    Returns:
        Tuple ``(i, k, centroid_, labels_, filter_, sc_score, distortion_)``:
        the value of the iteration counter when the loop ended, the
        requested cluster count, the final centroids (integer-truncated
        means), the label array, the masks from ``plot_clusters`` (``None``
        unless *plot_cluster* is true), the silhouette score and the
        distortion.
    """
    if data.min() < 0 or data.max() > 255:
        # astype() yields a private float copy, so rescaling can neither
        # modify the caller's array nor fail on integer input.
        data = normalize(data.astype(float))
    prev_c , i = init_c(data,k) , 0
    new_c = prev_c  # keeps the result defined when max_iters is 0
    filter_ = None

    while i < max_iters:
        print(f"Iteration {i+1} -- Centroids {prev_c} ")
        label = assign_to_clusters(data,prev_c)
        updated = []
        for idx, old in enumerate(prev_c):
            members = data[label == idx]
            # An empty cluster keeps its previous centroid so the centroid
            # array stays aligned with the label indices.
            updated.append(int(members.mean()) if members.size else int(old))
        new_c = np.array(updated)
        # Converged only when ALL centroids are stable, not merely one.
        if np.all(np.abs(new_c - np.asarray(prev_c, dtype=float)) < threshold): break
        prev_c = new_c
        i += 1
    centroid_ = new_c
    labels_ = assign_to_clusters(data,new_c)
    sc_score = silhouette_score(data,labels_)
    distortion_ = distortion(data,labels_,centroid_)

    if plot_cluster: filter_ = plot_clusters(data,labels_)

    return i, k, centroid_, labels_, filter_, sc_score, distortion_