In [156]:
import numpy as np
from PIL import Image
import os
import time

In [157]:
def getcentroids(centroids,clusters):
    """
    Recompute every centroid as the mean colour of the pixels assigned to it.

    The pixels are read from the module-level array ``imgarr``, which is
    indexed with the per-pixel labels in ``clusters``.

    Parameters
    ----------
    centroids : array-like, one row per centroid
        Current centroids. They determine the number of clusters and are
        reused unchanged for any cluster that received no pixels.
    clusters : array-like of int
        Cluster index of every pixel, as returned by ``getclusters``.

    Returns
    -------
    numpy.ndarray of float
        New array with the same shape as ``centroids``; the input is not
        modified.
    """
    # Work on a float copy so the caller's array is never mutated and the
    # result keeps one row per centroid whatever the channel count is.
    new_centroids = np.array(centroids, dtype=float)
    labels = np.asarray(clusters)
    for i in range(len(new_centroids)):
        members = imgarr[labels == i]
        # The mean of an empty selection is NaN, and a NaN centroid would
        # corrupt every later distance computation; keep the old value.
        if len(members):
            new_centroids[i] = members.mean(axis=0)
    return new_centroids

In [158]:
def getclusters(centroids):
    """
    Label each pixel of the module-level ``imgarr`` with its nearest centroid.

    For every centroid the Euclidean distance to each pixel is computed over
    the colour axis (axis 2); a pixel receives the index of the centroid with
    the smallest distance, the lowest index winning ties.
    """
    per_centroid = [
        np.sqrt(np.sum((imgarr - centre) ** 2, axis=2))
        for centre in centroids
    ]
    # argmin over the leading (centroid) axis yields one label per pixel.
    return np.argmin(per_centroid, axis=0)

In [159]:
def converged(centroids,old_centroids,threshold=0.1):
    """
    Check whether the centroids have converged.

    Parameters
    ----------
    centroids, old_centroids : array-like
        Current and previous centroids (same shape).
    threshold : float, optional
        Convergence is declared when the summed squared difference between
        the two sets of centroids is strictly below this value (default 0.1).

    Returns
    -------
    bool
        True if the total squared movement is below ``threshold``.
    """
    # Cast to float first: with unsigned integer inputs the subtraction and
    # the squaring would wrap around and could report a bogus zero movement.
    current = np.asarray(centroids, dtype=float)
    previous = np.asarray(old_centroids, dtype=float)
    movement = np.sum((current - previous) ** 2)
    return bool(movement < threshold)

In [160]:
def getsegmentation(centroids):
    """
    Build the segmented grey-level image for the given centroids.

    The centroids are ordered by brightness (sum of their channels), every
    pixel is labelled with ``getclusters`` and the label is mapped to a grey
    value spread evenly over 0..255, so the darkest cluster becomes 0 and the
    brightest becomes 255. With two centroids the output is a 0/255 mask.

    Parameters
    ----------
    centroids : array-like, one row per centroid
        Final cluster centres.

    Returns
    -------
    numpy.ndarray of uint8
        Writable array with one grey value per pixel, shaped like the label
        map returned by ``getclusters``.
    """
    centroids = np.asarray(centroids)
    # The cluster count comes from the argument; no module-level ``k`` exists.
    k = len(centroids)
    brightness_order = np.sum(centroids, axis=1).argsort()
    clusters = getclusters(centroids[brightness_order, :])
    # A freshly allocated array is always writable, unlike an array view of
    # a PIL image, and takes its size from the label map itself.
    segmented = np.zeros(clusters.shape, dtype=np.uint8)
    for i in range(1, k):
        # Integer scaling keeps every level inside the uint8 range even when
        # there are more than two clusters (i * 255 would overflow).
        segmented[clusters == i] = (i * 255) // (k - 1)
    return segmented

In [161]:
def dice_coef(y_true, y_pred,smooth=1):
    """
    Dice coefficient between two arrays of the same size.

    Both arrays are flattened; the score is
    ``(2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)``.
    ``smooth`` keeps the ratio defined when both arrays sum to zero.
    """
    truth = y_true.flatten()
    guess = y_pred.flatten()
    overlap = (truth * guess).sum()
    numerator = 2. * overlap + smooth
    denominator = truth.sum() + guess.sum() + smooth
    return numerator / denominator

In [162]:
def start(k=2, max_iter=50, seed=None):
    """
    Run k-means on the module-level image ``imgarr`` and return its segmentation.

    Parameters
    ----------
    k : int, optional
        Number of clusters (default 2).
    max_iter : int, optional
        Iteration cap (default 50). The update loop runs while its counter is
        ``<= max_iter``, i.e. at most ``max_iter + 1`` update steps.
    seed : int or None, optional
        When given, centroid initialisation uses a private
        ``numpy.random.RandomState(seed)``; when None (default) the global
        ``numpy.random`` state is used.

    Returns
    -------
    The value returned by ``getsegmentation`` for the final centroids
    (truncated to integers).

    Raises
    ------
    ValueError
        If ``k`` is smaller than 1 or the image holds fewer than ``k``
        distinct colours.
    """
    if k < 1:
        raise ValueError("k must be at least 1")

    height, width, channels = imgarr.shape
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Zero-initialised so no comparison ever reads uninitialised memory.
    centroids = np.zeros(shape=(k, channels))
    chosen = 0
    attempts_left = 100 * k
    while chosen < k and attempts_left > 0:
        attempts_left -= 1
        candidate = imgarr[rng.randint(0, height), rng.randint(0, width)]
        # Reject only a full-row match against already chosen centroids; a
        # plain ``in`` test would match on any single equal channel value.
        if np.any(np.all(centroids[:chosen] == candidate, axis=1)):
            continue
        centroids[chosen] = candidate
        chosen += 1

    if chosen < k:
        # Random probing kept hitting known colours: enumerate the distinct
        # colours once instead of looping forever.
        distinct = np.unique(imgarr.reshape(-1, channels), axis=0)
        if len(distinct) < k:
            raise ValueError(
                "image has %d distinct colours, cannot form %d clusters"
                % (len(distinct), k))
        picks = rng.choice(len(distinct), size=k, replace=False)
        centroids = distinct[picks].astype(float)

    old_centroids = np.zeros_like(centroids)
    iteration = 0
    while not converged(centroids, old_centroids) and iteration <= max_iter:
        old_centroids = centroids
        clusters = getclusters(centroids)
        centroids = getcentroids(centroids, clusters)
        iteration += 1

    predict = getsegmentation(centroids.astype(int))
    return predict



In [163]:
# Segment every image in img/images with k-means, save the prediction to
# img/predicts and report the mean Dice score against the "*_mask.tif" files.
start_time = time.time()
# start() draws its initial centroids from the global NumPy random state;
# seeding it makes the reported score repeatable between runs.
np.random.seed(0)
path = os.path.join("img","images")
predict_dir = os.path.join("img","predicts")
# Image.save does not create missing directories.
os.makedirs(predict_dir, exist_ok=True)
target_size = (int(512*1.5),512)
# Sorted so the processing order (and thus the random draws) does not depend
# on the order in which the OS lists the directory.
files = sorted(os.listdir(path))
dice = []
for file in files:
    if "mask" in file:
        continue
    pathimg= os.path.join(path,file)
    # The context manager releases the file handle; convert/resize return
    # independent in-memory images. RGB is enforced because the clustering
    # functions index the colour axis of a 3-D array.
    with Image.open(pathimg) as source:
        img = source.convert('RGB').resize(target_size)
    # These module-level names are read by the functions defined above.
    img_width, img_height = img.size
    imgarr = np.asarray(img)
    # splitext strips only the extension, so names containing dots survive.
    name = os.path.splitext(file)[0]
    maskname = name+"_mask.tif"
    pathmask= os.path.join(path,maskname)
    with Image.open(pathmask) as source:
        mask = source.convert('L').resize(target_size)
    maskarr = np.asarray(mask)
    predict = start()
    output=Image.fromarray(predict)
    # NOTE: JPEG is lossy, so the saved file is not an exact 0/255 mask.
    pathsave = os.path.join(predict_dir,name+'_predict.jpg')
    output.save(pathsave)
    # Both arrays are scaled to the 0..1 range before scoring.
    dice.append(dice_coef(maskarr/255,predict/255))
print(np.mean(dice))
print("--- %s seconds ---" % (time.time() - start_time))

0.9976832773994307
--- 37.2862343788147 seconds ---


0.9976748067708263
