In [None]:
import matplotlib.pyplot as plt
import PIL as pil
import numpy as np

from PIL import Image

In [None]:
def init_image(img):
    """Return the (height, width) of an image.

    ``img`` can be anything NumPy turns into an array with at least two
    dimensions, such as a PIL image or an (H, W[, C]) ndarray. The first two
    axes of that array are reported as height and width.
    """
    dims = np.array(img).shape
    height, width = dims[0], dims[1]
    return height, width

In [None]:
def convert_1d(img, size):
    """Flatten an (H, W, C) image into an (H * W, C) uint8 pixel matrix.

    Parameters
    ----------
    img : PIL image or array-like
        Image whose array form has a trailing channel axis.
    size : tuple of int
        (height, width), as returned by ``init_image``.

    Returns
    -------
    np.ndarray
        uint8 array of shape (size[0] * size[1], C), one row per pixel.

    Raises
    ------
    ValueError
        If ``size`` does not match the number of pixels in ``img``. (The
        previous buffer-based construction silently dropped pixels instead.)
    """
    # Convert once; astype() yields a fresh, writable uint8 copy.
    pixels = np.asarray(img).astype('uint8')
    return pixels.reshape(size[0] * size[1], pixels.shape[2])

In [None]:
def random_centroids(k, img_1d, init_centroids):
    """Pick ``k`` initial centroids for k-means.

    Parameters
    ----------
    k : int
        Number of centroids to create.
    img_1d : np.ndarray
        Pixel matrix of shape (n_pixels, n_channels).
    init_centroids : str
        ``'random'``: uniformly random uint8 colours in [0, 255].
        ``'in_pixels'``: colours sampled from the distinct rows of ``img_1d``,
        without repetition whenever at least ``k`` distinct colours exist.

    Returns
    -------
    np.ndarray
        Array of shape (k, n_channels).

    Raises
    ------
    ValueError
        If ``init_centroids`` is not a supported mode.
    """
    if init_centroids == 'random':
        return np.random.randint(0, 256, (k, img_1d.shape[1]), dtype='uint8')
    if init_centroids == 'in_pixels':
        # Only this mode needs the (sorting) de-duplication.
        distinct = np.unique(img_1d, axis=0)
        n_distinct = distinct.shape[0]
        # Duplicate starting centroids would leave clusters permanently empty,
        # so sample without replacement unless there are too few colours.
        chosen = np.random.choice(n_distinct, size=k, replace=k > n_distinct)
        return distinct[chosen]
    raise ValueError(
        "init_centroids must be 'random' or 'in_pixels', got %r" % (init_centroids,)
    )

In [None]:
def init_labels(img_1d, k, init_centroids):
    """Label each pixel with its nearest centroid from a fresh random draw.

    Draws ``k`` centroids via ``random_centroids(k, img_1d, init_centroids)``.
    For every row of ``img_1d`` it returns the index (0..k-1) of the centroid
    with the smallest squared Euclidean distance. Values are cast to int64
    before subtracting, so uint8 input cannot wrap around.
    """
    n_channels = img_1d.shape[1]
    points = img_1d.reshape(img_1d.shape[0], 1, n_channels).astype('int64')
    seeds = random_centroids(k, img_1d, init_centroids)
    seeds = seeds.reshape(1, k, n_channels).astype('int64')
    # Broadcasting (n, 1, c) against (1, k, c) gives all pixel-centroid deltas.
    delta = points - seeds
    sq_dist = (delta * delta).sum(axis=-1)
    return sq_dist.argmin(axis=-1)

In [None]:
def _assign_labels(pixels, centroids):
    """Return, per pixel row, the index of the nearest centroid (squared Euclidean)."""
    diff = pixels[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    return np.argmin(np.einsum('nkc,nkc->nk', diff, diff), axis=1)


def kmeans(img_1d, k_clusters, max_iter, init_centroids='random'):
    """Cluster pixel colours with Lloyd's k-means algorithm.

    Parameters
    ----------
    img_1d : np.ndarray
        Pixel matrix of shape (n_pixels, n_channels), e.g. from ``convert_1d``.
    k_clusters : int
        Number of clusters (output colours).
    max_iter : int
        Maximum number of update/re-assign rounds.
    init_centroids : str
        Initialisation mode forwarded to ``random_centroids``
        (``'random'`` or ``'in_pixels'``).

    Returns
    -------
    centroids : np.ndarray
        uint8 array of shape (k_clusters, n_channels). Cluster means are
        rounded to the nearest integer.
    labels : np.ndarray
        Integer array of shape (n_pixels,). ``labels[i]`` is the index of the
        returned centroid nearest to pixel ``i``.
    """
    # Work in float so cluster means are not truncated to uint8 between rounds.
    pixels = np.asarray(img_1d, dtype=np.float64)
    n_channels = pixels.shape[1]
    centroids = np.asarray(
        random_centroids(k_clusters, img_1d, init_centroids), dtype=np.float64
    ).reshape(k_clusters, n_channels)

    # Labels always come from the *current* centroids. The previous version
    # re-drew random centroids each round, so it never actually converged.
    labels = _assign_labels(pixels, centroids)
    for _ in range(max_iter):
        for i in range(k_clusters):
            members = pixels[labels == i]
            if members.size > 0:  # an empty cluster keeps its old centroid
                centroids[i] = members.mean(axis=0)
        new_labels = _assign_labels(pixels, centroids)
        if np.array_equal(new_labels, labels):
            break  # assignments are stable, so the centroids are too
        labels = new_labels

    palette = np.clip(np.rint(centroids), 0, 255).astype('uint8')
    # Re-assign against the rounded palette so each label points at the
    # nearest colour that is actually returned.
    return palette, _assign_labels(pixels, palette.astype(np.float64))

    

In [None]:
def convert_2d(size, centroids, labels):
    """Rebuild an (H, W, C) uint8 image by painting each pixel with its centroid.

    Parameters
    ----------
    size : tuple of int
        (height, width) of the output image.
    centroids : array-like
        Palette of shape (k, C). It is cast (not reinterpreted) to uint8.
    labels : array-like of int
        Length H * W. ``labels[i]`` selects the palette row for pixel ``i``.

    Returns
    -------
    np.ndarray
        uint8 array of shape (size[0], size[1], C).
    """
    palette = np.asarray(centroids).astype('uint8')
    # Fancy indexing replaces a per-pixel Python loop.
    flat = palette[np.asarray(labels, dtype=np.intp)]
    return flat.reshape(size[0], size[1], palette.shape[1])

In [None]:
def print_img(imgs, row, col):
    """Draw each image on a ``row`` x ``col`` subplot grid, then show the figure.

    Image number n (0-based) goes into grid cell n + 1, because matplotlib's
    subplot index is 1-based. Callers should keep len(imgs) <= row * col.
    """
    for slot, picture in enumerate(imgs, start=1):
        plt.subplot(row, col, slot)
        plt.imshow(picture)
    plt.show()

In [None]:
def main():
    """Interactively colour-quantise an image with k-means and save the result.

    Prompts for five values:
    - the input image path
    - the number of clusters
    - the maximum number of k-means iterations
    - the centroid initialisation mode (an empty answer uses 'random')
    - the output file extension (an empty answer uses 'png')

    Shows the original and quantised images side by side, then writes the
    quantised image to ``new_img.<extension>``.
    """
    img_name = input('Enter name of image')
    k_clusters = int(input('Enter the number of clusters'))
    max_iter = int(input('Enter the max iteration'))
    init_centroids = input('Initialize centroids').strip() or 'random'
    # Accept both "png" and ".png", so the name never becomes "new_img..png".
    file_type = input('Enter type of saved type:').strip().lstrip('.') or 'png'

    save_name = 'new_img.' + file_type
    # The pipeline needs an (H, W, 3) array. Grayscale ('L') and palette ('P')
    # images have no channel axis and would crash convert_1d. RGBA/CMYK would
    # carry extra channels. convert() returns an independent copy, so the
    # file can be closed right away.
    with Image.open(img_name) as source:
        img = source.convert('RGB')

    size = init_image(img)
    img_1d = convert_1d(img, size)

    centroids, labels = kmeans(img_1d, k_clusters, max_iter, init_centroids)
    res = convert_2d(size, centroids, labels)

    print_img([img, res], 1, 2)

    Image.fromarray(res).save(save_name)

In [None]:
# Entry point: runs the interactive colour-quantisation pipeline (reads its parameters via input()).
main()