In [1]:
import numpy as np
import math
import matplotlib
matplotlib.use('TkAgg') 
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import cv2

In [2]:
# EM algorithm for a Gaussian Mixture Model (GMM).
# The training data (target pixel values) are stored in "ball_values.txt".

# RGB: N x C data matrix (C = 3 here)
# K  : number of Gaussian components
# Z  : latent variables (N x K)

In [3]:
# Load the training samples: one sample per line, values separated by whitespace.
with open("ball_values.txt") as f:
    # str.split() with no argument collapses repeated whitespace, and blank
    # lines (e.g. a trailing newline at end of file) are skipped so that they
    # cannot produce an empty row and a ragged array.
    RGB = [[float(tok) for tok in line.split()] for line in f if line.strip()]
# Divide by 255 (presumably 8-bit channel values being scaled to [0, 1]).
RGB = np.array(RGB)/255

In [4]:



class EM_gauss():
    """One Gaussian component: its mean, covariance and per-sample latent values."""

    def __init__(self, data):
        """Set up the component from the shape of ``data`` (rows x columns).

        The covariance starts as the identity matrix; the mean (one value per
        column) and the latent values (one per row) are drawn with
        ``np.random.rand``.
        """
        n_dims = data.shape[1]
        self.sigma = np.identity(n_dims)
        # Keep this draw order (mean first, then z) so seeded runs are unchanged.
        self.mean = np.random.rand(n_dims)
        self.z = np.random.rand(data.shape[0])

    def get_mean_sig(self):
        """Return the current ``(mean, sigma)`` pair."""
        return self.mean, self.sigma

    def get_latent(self):
        """Return the current latent values."""
        return self.z

    def update_latent(self, new_z):
        """Replace the latent values with ``new_z``; always returns 0."""
        self.z = new_z
        return 0

    def update_m_s(self, m, s):
        """Replace the mean with ``m`` and the covariance with ``s``; always returns 0."""
        self.mean = m
        self.sigma = s
        return 0

    def norm_pdf_multivariate(self, x, mu, sigma):
        """Evaluate the multivariate normal density at ``x``.

        Parameters:
            x:     point to evaluate (length d).
            mu:    mean vector (length d).
            sigma: covariance matrix with shape (d, d).

        Returns:
            The density value as a float.

        Raises:
            NameError: if the dimensions disagree, or if the determinant of
                ``sigma`` is exactly zero.
        """
        dim = len(x)

        # Guard clause: reject inconsistent shapes up front.
        if dim != len(mu) or (dim, dim) != sigma.shape:
            raise NameError("The dimensions of the input don't match")

        determinant = np.linalg.det(sigma)
        if determinant == 0:
            raise NameError("The covariance matrix can't be singular")

        # Normalisation term: (2*pi)^(d/2) * sqrt(det(sigma)).
        scale = math.pow((2*math.pi), float(dim)/2) * math.pow(determinant, 1.0/2)

        centered = x - mu
        precision = np.linalg.inv(sigma)
        # Quadratic form (x - mu) * sigma^-1 * (x - mu)^T.
        quad_form = np.dot(np.dot(centered, precision), centered.T)
        density = math.pow(math.e, -0.5 * quad_form)

        return (1.0 / scale) * density

class EM_comp():
    """Runs the EM updates over a list of Gaussian components.

    NOTE: the method names are swapped relative to the usual textbook naming
    and are kept as-is for backward compatibility: ``m_step`` computes the
    per-sample responsibilities (latent values), while ``e_step`` re-estimates
    each component's mean and covariance from those responsibilities.

    Each component object is expected to provide ``get_latent``,
    ``get_mean_sig``, ``update_latent``, ``update_m_s`` and
    ``norm_pdf_multivariate``.
    """

    def __init__(self, data):
        """Remember the training data (N x C) and its number of samples."""
        # Stored so that e_step no longer depends on a module-level global.
        self.data = np.asarray(data, dtype=float)
        self.num_of_data = self.data.shape[0]

    def e_step(self, gausses, data=None):
        """Re-estimate mean and covariance of every component.

        Parameters:
            gausses: list of component objects.
            data:    optional (N x C) samples; defaults to the data passed to
                     the constructor.

        Returns:
            0 (kept for backward compatibility).
        """
        samples = self.data if data is None else np.asarray(data, dtype=float)

        for g in gausses:
            z = np.asarray(g.get_latent(), dtype=float)
            total = np.sum(z)
            if total <= 0:
                # No sample is assigned to this component: dividing by zero
                # would yield NaN parameters, so keep the previous estimate.
                continue

            # Responsibility-weighted mean: sum_i z_i * x_i / sum_i z_i.
            mean = np.dot(z, samples) / total
            centered = samples - mean
            # Weighted covariance: sum_i z_i (x_i - m)(x_i - m)^T / sum_i z_i.
            # centered.T has shape (C, N), so multiplying by z scales column i by z_i.
            sigma = np.dot(centered.T * z, centered) / total
            g.update_m_s(mean, sigma)

        return 0

    def m_step(self, data, gausses):
        """Recompute the responsibilities (latent values) of every component.

        For each sample x the responsibility of component k is
        ``pi_k * N(x | mu_k, sigma_k) / sum_j pi_j * N(x | mu_j, sigma_j)``,
        where ``pi_k`` is the mixing weight derived from the component's
        current latent values (the same definition ``inference`` uses).

        Parameters:
            data:    iterable of samples (N x C).
            gausses: list of component objects.

        Returns:
            0 (kept for backward compatibility).
        """
        n_comp = len(gausses)
        # Parameters and mixing weights are constant during this pass, so
        # fetch them once instead of once per sample.
        params = [g.get_mean_sig() for g in gausses]
        weights = [np.sum(g.get_latent()) / float(self.num_of_data) for g in gausses]

        z = np.empty((len(data), n_comp))
        for i, x in enumerate(data):
            weighted = [
                w * g.norm_pdf_multivariate(x, mu, sigma)
                for g, w, (mu, sigma) in zip(gausses, weights, params)
            ]
            denom = sum(weighted)
            if denom > 0:
                z[i] = np.array(weighted) / denom
            else:
                # Every density underflowed to zero (sample far from all
                # components): fall back to an even split instead of dividing
                # by zero.
                z[i] = 1.0 / n_comp

        for k, g in enumerate(gausses):
            g.update_latent(z[:, k])

        return 0

    def inference(self, input, gausses):
        """Return the mixture density at ``input``.

        Computes ``sum_k pi_k * N(input | mu_k, sigma_k)`` where ``pi_k`` is
        the sum of component k's latent values divided by the number of
        training samples.
        """
        ans = 0
        for g in gausses:
            pi = np.sum(g.get_latent()) / self.num_of_data
            mean, sigma = g.get_mean_sig()
            ans += pi * g.norm_pdf_multivariate(input, mean, sigma)
        return ans

    
if __name__=='__main__':

    # Fit a two-component mixture to the training pixel values in RGB.
    em_comp = EM_comp(RGB)
    gauss1 = EM_gauss(RGB)
    gauss2 = EM_gauss(RGB)
    components = [gauss1, gauss2]
    em_iter = 50
    # Mixture-density value above which a pixel is kept in the mask.
    threshold = 200

    # Repeat the EM steps. m_step refreshes the latent values from the current
    # parameters; e_step then re-estimates mean/covariance from those values.
    for _ in range(em_iter):
        em_comp.m_step(RGB, components)
        em_comp.e_step(components)

    # Load the image and apply smoothing to it.
    image_path = "001.png"
    raw = cv2.imread(image_path)
    if raw is None:
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with a clear message instead of a TypeError on the division.
        raise FileNotFoundError('could not read image "%s"' % image_path)
    img = raw/float(255)
    img = cv2.GaussianBlur(img,(5,5),0)
    binary = np.zeros(img.shape)

    # Note that OpenCV loads colour images as BGR: reverse the channel axis
    # once so each pixel is presented to the model in RGB order.
    rgb_view = img[:, :, ::-1]

    for h in range(len(img)):
        for w in range(len(img[h])):
            if em_comp.inference(rgb_view[h, w], components) > threshold:
                # Sets every channel of the mask at this pixel to 1.
                binary[h][w] = 1

    # Apply the inferred binary mask and show the result.
    img = img*binary
    cv2.imshow('image',img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()