## Image segmentation using EM 

Segment images using a clustering method: each segment corresponds to the cluster to which a pixel is assigned, and is represented by that cluster's center.

Image pixels are represented by their R, G, and B values. The EM algorithm is applied to a mixture of normal distributions to cluster the image pixels; the image is then segmented by mapping each pixel to the cluster center with the highest posterior probability for that pixel.
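
For reference, the updates can be written out as follows. This is a sketch that assumes every mixture component has identity covariance (the code in this notebook only uses squared distances to the means), with $x_i$ the color of pixel $i$, $n$ the number of pixels, $\mu_j$ the cluster means, and $\pi_j$ the mixture weights.

$$
w_{ij} = \frac{\pi_j \exp\left(-\tfrac{1}{2}\lVert x_i - \mu_j \rVert^2\right)}{\sum_k \pi_k \exp\left(-\tfrac{1}{2}\lVert x_i - \mu_k \rVert^2\right)}
$$

$$
\mu_j = \frac{\sum_i w_{ij}\, x_i}{\sum_i w_{ij}}, \qquad \pi_j = \frac{1}{n} \sum_i w_{ij}
$$

The first line is the E-step (posterior probability that pixel $i$ belongs to cluster $j$); the second line is the M-step. Pixel $i$ is finally assigned to the segment $\arg\max_j w_{ij}$.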

#### Segment each of the images into 10, 20, and 50 segments. The segmented images are displayed as images in which each pixel's color is replaced with the mean color of the segment it is assigned to.
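
The recoloring step can be illustrated on a toy example. This is only a sketch: the helper name `recolor`, the 2 x 2 "image", and the posterior values are made up for illustration and are not part of the notebook's pipeline.

```python
import numpy as np

def recolor(w, means, shape):
    """Replace each pixel with the mean color of its most probable segment.

    w:     (num_pixels, J) posterior probabilities
    means: (J, 3) mean RGB color of each segment
    shape: (H, W, 3) shape of the output image
    """
    labels = w.argmax(axis=1)              # most probable segment per pixel
    segmented = means[labels].reshape(shape)
    return np.clip(segmented, 0, 255).astype(np.uint8)

# Hypothetical posteriors for a 2 x 2 image and two segments (red and blue)
w = np.array([[0.9, 0.1],
              [0.2, 0.8],
              [0.6, 0.4],
              [0.3, 0.7]])
means = np.array([[255.0, 0.0, 0.0],
                  [0.0, 0.0, 255.0]])

img = recolor(w, means, (2, 2, 3))
print(img)
```

The array returned by `recolor` can be passed directly to `plt.imshow` or `plt.imsave`.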

In [30]:
# import libs
import sys
import numpy as np
import matplotlib.pyplot as plt

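from sklearn.cluster import KMeans

# read data
# scipy.misc.imread was removed in SciPy 1.2; plt.imread (which needs Pillow for JPEG) is used instead
goby_img = plt.imread('images/goby.jpg') # 480 x 640 x 3 array
(H, W, N) = goby_img.shape
# cast to float so that later arithmetic is not done on uint8 values
data = goby_img.reshape((H * W, N)).astype(float)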

In [31]:
def logsumexp(X):
    """
    log-sum-exp trick: computes log(sum(exp(X))) for each row while avoiding underflow/overflow
    :param X: data matrix
    :return: log sum exp applied to each row of input matrix X
    """
    x_max = X.max(1)
    return x_max + np.log(np.exp(X - x_max[:, None]).sum(1))
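
A quick check of why the trick matters. The sketch below repeats the function so it runs on its own and compares it with the naive computation on hypothetical log-likelihood values chosen to trigger underflow.

```python
import numpy as np

def logsumexp(X):
    # same function as above, repeated so this cell is self-contained
    x_max = X.max(1)
    return x_max + np.log(np.exp(X - x_max[:, None]).sum(1))

# Hypothetical log-likelihoods; exp(-1000) underflows to 0 in float64
log_vals = np.array([[-1000.0, -1001.0],
                     [-5.0, -6.0]])

with np.errstate(divide='ignore'):
    naive = np.log(np.exp(log_vals).sum(1))  # first row becomes log(0) = -inf

stable = logsumexp(log_vals)                 # first row stays finite

print(naive)
print(stable)
```

For rows with moderate values both versions agree; only the stable version gives a usable result for the first row.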

We use k-means to compute rough cluster centers and mixture weights.

*Caution*: pi should not have a zero element. If it does, a small smoothing constant must be added and pi renormalized; otherwise no pixels can be assigned to the corresponding segment.
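
One possible way to apply such smoothing is sketched below. The helper name `smooth_weights` and the value of `eps` are assumptions made for illustration; the notebook itself only checks for zeros.

```python
import numpy as np

def smooth_weights(pi, eps=1e-6):
    """Add a small constant to every mixture weight and renormalize to sum to 1."""
    pi = np.asarray(pi, dtype=float) + eps
    return pi / pi.sum()

# Hypothetical weights where the third cluster received no pixels
pi = smooth_weights([0.5, 0.5, 0.0])
print(pi)
```

After smoothing, `np.log(pi)` is finite for every cluster, so an empty cluster can still pick up pixels in later iterations.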

*We define a function so we can repeat the process for different images*

In [35]:
CONVERGENCE_THRESHOLD = 0.0001

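def EM(X, img_name, shape):
    """
    Function that performs EM for image segmentation for [10, 20, 50] segments and saves each result as an image
    :param X: data matrix where each row is [R,G,B]
    :param img_name: base name used for the output files
    :param shape: (H, W, N) shape of the original image, needed to rebuild the segmented image
    """
    estimators = {'k_means_10': KMeans(n_clusters=10),
                  'k_means_20': KMeans(n_clusters=20),
                  'k_means_50': KMeans(n_clusters=50)}
    
    for name, est in estimators.items():
        print(name)
        J = int(name[8:])
        NUM_PIXELS = X.shape[0]
        
        # perform k means
        est.fit(X)
        segments_id = est.labels_
        
        # get initial cluster centers/ means from k-means
        means = est.cluster_centers_
        
        # get initial pi from k-means
        pi = np.array([np.sum(segments_id == i) for i in range(J)])
        pi = pi / float(NUM_PIXELS)

        # check that there are no zero values in pi's
        print("check there are no zero values in pi", 0 not in pi)
        
        ### EM ###
        prev_Q = sys.maxsize
        
        while True:
            
            ## E-Step ##
            ll = np.zeros((NUM_PIXELS, J))
            for j in range(J):
                ll[:,j] = -0.5 * np.sum((X - means[j,])**2, 1)
            
            # compute w_ij in the log domain using logsumexp (defined above);
            # np.exp(ll) on its own underflows to 0 for typical RGB distances
            log_w = ll + np.log(pi)
            w = np.exp(log_w - logsumexp(log_w)[:, None])
            
            # compute Q without constant K (the log pi term is included via log_w)
            Q = np.sum(w * log_w)
            print(Q)
            
            # check for convergence
            if abs(Q - prev_Q) <= CONVERGENCE_THRESHOLD:
                break
            else:
                prev_Q = Q
            
            ## M-Step ##
            
            # update means
            for j in range(J):
                means[j,] = np.sum((X.T * w[:,j]).T, 0) / np.sum(w[:,j])

            # update pi
            pi = np.sum(w, 0)/ NUM_PIXELS
            
        # build the segmented image once EM has converged: each pixel takes the
        # mean color of the cluster with the highest posterior probability
        segmented = means[np.argmax(w, 1)].reshape(shape)
        segmented = np.clip(segmented, 0, 255).astype(np.uint8)
        # the output file name is a choice made here so the input image is not overwritten
        plt.imsave('images/%s_%d_segments.jpg' % (img_name, J), segmented)
        
#EM(data, 'goby', (H, W, N))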

#### Segment the sunset image into 20 segments using five different start points, and display the result for each case.

Is there much variation in the result?
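
One way to set up this experiment is sketched below. It rests on several assumptions: the start points are random pixels chosen with different seeds (rather than k-means), the helper `em_segment` is a compact stand-in for the EM routine above that runs a fixed number of iterations, and random data stands in for the sunset image, whose file is not shown here.

```python
import numpy as np

def em_segment(X, J, seed, n_iter=20):
    """Sketch of EM for a unit-covariance normal mixture, started from J random pixels."""
    rng = np.random.default_rng(seed)
    means = X[rng.choice(len(X), size=J, replace=False)].astype(float)
    pi = np.full(J, 1.0 / J)
    for _ in range(n_iter):
        # E-step in the log domain to avoid underflow
        sq_dist = ((X ** 2).sum(1)[:, None]
                   - 2.0 * X @ means.T
                   + (means ** 2).sum(1)[None, :])
        log_w = -0.5 * sq_dist + np.log(pi)
        log_w -= log_w.max(1, keepdims=True)
        w = np.exp(log_w)
        w /= w.sum(1, keepdims=True)
        # M-step; the small constant keeps empty clusters from dividing by zero
        totals = w.sum(0) + 1e-12
        means = (w.T @ X) / totals[:, None]
        pi = totals / totals.sum()
    return means, w.argmax(1)

# Random stand-in for the flattened sunset pixels (rows of [R, G, B])
X = np.random.default_rng(0).integers(0, 256, size=(500, 3)).astype(float)

# Five runs that differ only in their start points
results = [em_segment(X, J=20, seed=s) for s in range(5)]
for s, (means, labels) in enumerate(results):
    print("seed", s, "segments in use:", len(np.unique(labels)))
```

With the real image, `means[labels].reshape(H, W, 3)` from each run can be shown side by side to judge how much the segmentation depends on the start point.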