In [47]:
import sys
import os
import numpy as np
import math
import cv2 
import igraph as ig
from sklearn.cluster import KMeans
import time
import matplotlib.pyplot as plt 

In [48]:
class EventHandler:
    """
    Class for handling user input during segmentation iterations 
    """
    
    def __init__(self, flags, img, _mask, colors):
        
        self.FLAGS = flags
        self.ix = -1
        self.iy = -1
        self.img = img
        self.img2 = self.img.copy()
        self._mask = _mask
        self.COLORS = colors

    @property
    def image(self):
        return self.img
    
    @image.setter
    def image(self, img):
        self.img = img
        
    @property
    def mask(self):
        return self._mask
    
    @mask.setter
    def mask(self, _mask):
        self._mask = _mask
    
    @property
    def flags(self):
        return self.FLAGS 
    
    @flags.setter
    def flags(self, flags):
        self.FLAGS = flags
    
    def handler(self, event, x, y, flags, param):

        # Draw the rectangle first
        if event == cv2.EVENT_RBUTTONDOWN:
            self.FLAGS['DRAW_RECT'] = True
            self.ix, self.iy = x,y

        elif event == cv2.EVENT_MOUSEMOVE:
            if self.FLAGS['DRAW_RECT'] == True:
                self.img = self.img2.copy()
                cv2.rectangle(self.img, (self.ix, self.iy), (x, y), self.COLORS['BLUE'], 2)
                self.FLAGS['RECT'] = (min(self.ix, x), min(self.iy, y), abs(self.ix - x), abs(self.iy - y))
                self.FLAGS['rect_or_mask'] = 0

        elif event == cv2.EVENT_RBUTTONUP:
            self.FLAGS['DRAW_RECT'] = False
            self.FLAGS['rect_over'] = True
            cv2.rectangle(self.img, (self.ix, self.iy), (x, y), self.COLORS['BLUE'], 2)
            self.FLAGS['RECT'] = (min(self.ix, x), min(self.iy, y), abs(self.ix - x), abs(self.iy - y))
            self.FLAGS['rect_or_mask'] = 0

        
        # Draw strokes for refinement 

        if event == cv2.EVENT_LBUTTONDOWN:
            if self.FLAGS['rect_over'] == False:
                print('Draw the rectangle first.')
            else:
                self.FLAGS['DRAW_STROKE'] = True
                cv2.circle(self.img, (x,y), 3, self.FLAGS['value']['color'], -1)
                cv2.circle(self._mask, (x,y), 3, self.FLAGS['value']['val'], -1)

        elif event == cv2.EVENT_MOUSEMOVE:
            if self.FLAGS['DRAW_STROKE'] == True:
                cv2.circle(self.img, (x, y), 3, self.FLAGS['value']['color'], -1)
                cv2.circle(self._mask, (x, y), 3, self.FLAGS['value']['val'], -1)

        elif event == cv2.EVENT_LBUTTONUP:
            if self.FLAGS['DRAW_STROKE'] == True:
                self.FLAGS['DRAW_STROKE'] = False
                cv2.circle(self.img, (x, y), 3, self.FLAGS['value']['color'], -1)
                cv2.circle(self._mask, (x, y), 3, self.FLAGS['value']['val'], -1)


In [49]:
class GMM:

    def __init__(self, k = 5):
        '''k is the number of components of GMM'''
        self.k = k
        self.weights = np.asarray([0. for i in range(k)]) # Weight of each component
        self.means = np.asarray([[0., 0., 0.] for i in range(k)]) # Means of each component
        self.covs = np.asarray([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]] for i in range(k)]) # Covs of each component
        self.cov_inv = np.asarray([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]] for i in range(k)])
        self.cov_det = np.asarray([0. for i in range(k)])
        self.pixel_counts = np.asarray([0. for i in range(k)]) # Count of pixels in each components
        self.pixel_total_count = 0 # The total number of pixels in the GMM

        # The following two parameters are assistant parameters for counting pixels and calc. pars.
        self._sums = np.asarray([[0., 0., 0.] for i in range(k)])
        self._prods = np.asarray([[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]] for i in range(k)])

    
    def calc_prob(self, X):
        prob = []
        for ci in range(self.k):
            score = np.zeros(X.shape[0])
            if self.weights[ci] > 0:
                diff = X - self.means[ci]
                Tdiff = diff.T
                #inv_cov = np.linalg.inv(self.covariances[ci])
                dot = np.dot(self.cov_inv[ci], Tdiff)
                Tdot = dot.T
                mult = np.einsum('ij,ij->i', diff, Tdot)
                score = np.exp(-.5 * mult) / np.sqrt(2 * np.pi)/np.sqrt(np.linalg.det(self.covs[ci]))
            prob.append(score)
        ans = np.dot(self.weights, prob)
        return ans
            
    
    def calc_score(self, X, ci):
        score = np.zeros(X.shape[0])
        
        if self.weights[ci] > 0:
            diff = X - self.means[ci]
            Tdiff = diff.T
            inv_cov = np.linalg.inv(self.covs[ci])
            dot = np.dot(inv_cov, Tdiff)
            Tdot = dot.T
            mult = np.einsum('ij,ij->i', diff, Tdot)
            score = np.exp(-.5 * mult) / np.sqrt(2 * np.pi)/np.sqrt(np.linalg.det(self.covs[ci]))
        return score

    def which_component(self, X):
        prob = []
        for ci in range(self.k):
            score = self.calc_score(X,ci)
            prob.append(score)
        prob = np.array(prob).T
        return np.argmax(prob, axis=1)

    def add_pixel(self, pixel, ci):
        '''Add a pixel to the ci_th component of GMM, and refresh the parameters'''
        # print(np.asarray(pixel))
        tp = pixel.copy().astype(np.float32)
        self._sums[ci] += tp
        tp.shape = (tp.size, 1)
        self._prods[ci] += np.dot(tp, np.transpose(tp))
        self.pixel_counts[ci] += 1
        self.pixel_total_count += 1

    def learning(self):
        variance = 0.01
        for ci in range(self.k):
            n = self.pixel_counts[ci]
            if n == 0:
                self.weights[ci] = 0
            else:
                self.weights[ci] = n/self.pixel_total_count
                self.means[ci] = self._sums[ci]/n
                nm = self.means[ci].copy()
                nm.shape = (nm.size, 1)
                self.covs[ci] = self._prods[ci]/n - np.dot(nm, np.transpose(nm))
                self.cov_det[ci] = np.linalg.det(self.covs[ci])
            while self.cov_det[ci] <= 0:
                self.covs[ci] += np.diag([variance for i in range(3)])
                self.cov_det[ci] = np.linalg.det(self.covs[ci])
            self.cov_inv[ci] = np.linalg.inv(self.covs[ci])

In [50]:

class GrabC:
    
    def __init__(self, img, k ):
        self.k = k # The number of components in each GMM model

        self.img = np.asarray(img, dtype = np.float32)
        self.img2 = img
        self.rows, self.cols = img.shape[0] ,img.shape[1]
        self.gamma = 50
        self.lam = 9*self.gamma
        self.beta = 0
        
        self._BLUE = [255,0,0]        # rectangle color
        self._RED = [0,0,255]         # PR BG
        self._GREEN = [0,255,0]       # PR FG
        self._BLACK = [0,0,0]         # sure BG
        self._WHITE = [255,255,255]   # sure FG
        
        self._DRAW_BG = {'color':self._BLACK, 'val':0}
        self._DRAW_FG = {'color':self._WHITE, 'val':1}
        self._DRAW_PR_FG = {'color':self._GREEN, 'val':3}
        self._DRAW_PR_BG = {'color':self._RED, 'val':2}
        
        self._GC_BGD = 0
        self._GC_FGD = 1
        self._GC_PR_BGD = 2
        self._GC_PR_FGD = 3
        
        self.calc_beta()
        self.calc_nearby_weight()
        
        self._DRAW_VAL = None

        self._mask = np.zeros([self.rows, self.cols], dtype = np.uint8) 
        
        self.gc_source = self.cols*self.rows
        self.gc_sink = self.gc_source + 1

    
    def calc_beta(self):
        '''Calculate Beta -- The Exp Term of Smooth Parameter in Gibbs Energy'''
        '''beta = 1/(2*average(sqrt(||pixel[i] - pixel[j]||)))'''
        '''Beta is used to adjust the difference of two nearby pixels in high or low contrast rate'''
        beta = 0
        
        self._left_diff    = self.img[:, 1:] - self.img[:, :-1] # Left-difference
        self._upleft_diff  = self.img[1:, 1:] - self.img[:-1, :-1] # Up-Left difference
        self._up_diff      = self.img[1:, :] - self.img[:-1, :] # Up-difference
        self._upright_diff = self.img[1:, :-1] - self.img[:-1, 1:] # Up-Right difference
        
        beta = (self._left_diff*self._left_diff).sum() + (self._upleft_diff*self._upleft_diff).sum() + (self._up_diff*self._up_diff).sum() + (self._upright_diff*self._upright_diff).sum() 
        self.beta = 1/(2*beta/(4*self.cols*self.rows - 3*self.cols - 3*self.rows + 2))
        
    def calc_nearby_weight(self):  
        self.left_V = self.gamma * np.exp(-self.beta * np.sum(np.square(self._left_diff), axis=2))
        self.upleft_V = self.gamma / np.sqrt(2) * np.exp(-self.beta * np.sum(np.square(self._upleft_diff), axis=2))
        self.up_V = self.gamma * np.exp(-self.beta * np.sum(np.square(self._up_diff), axis=2))
        self.upright_V = self.gamma / np.sqrt(2) * np.exp(-self.beta * np.sum(np.square(self._upright_diff), axis=2))
        
    
    def init_with_kmeans(self):
        '''Initialise the BGDGMM and FGDGMM, which are respectively background-model and foreground-model, using kmeans algorithm'''
        
        max_iter = 2 # Max-iteration count for Kmeans
        
        '''In the following two indexings, the np.logical_or is needed in place of or'''
        self._bgd = np.where(np.logical_or(self._mask == self._GC_BGD, self._mask == self._GC_PR_BGD)) # Find the places where pixels in the mask MAY belong to BGD.
        self._fgd = np.where(np.logical_or(self._mask == self._GC_FGD, self._mask == self._GC_PR_FGD)) # Find the places where pixels in the mask MAY belong to FGD.
        self._BGDpixels = self.img[self._bgd]
        self._FGDpixels = self.img[self._fgd]

        bg_label = KMeans(n_clusters=self.k, n_init=1).fit(self._BGDpixels).labels_
        fg_label = KMeans(n_clusters=self.k, n_init=1).fit(self._FGDpixels).labels_
        
        self._BGD_by_components = bg_label
        self._FGD_by_components = fg_label
        
        self.BGD_GMM = GMM() 
        self.FGD_GMM = GMM() 
        
        '''Add the pixels by components to GMM'''
        for i in range(len(self._BGD_by_components)):
            self.BGD_GMM.add_pixel(self._BGDpixels[i] ,self._BGD_by_components[i])
        
        for i in range(len(self._FGD_by_components)):
            self.FGD_GMM.add_pixel(self._FGDpixels[i] ,self._FGD_by_components[i])
            
        
        self.BGD_GMM.learning()
        self.FGD_GMM.learning()
    
    def assign_GMM_components(self):
        '''
        Componant_index = Containes Componant Labels of all pixels 
                        = Numpy array
        '''
        
        self.components_index = np.zeros([self.rows, self.cols], dtype = np.uint)

        bg_comp = self.BGD_GMM.which_component(self._BGDpixels)
        fg_comp = self.FGD_GMM.which_component(self._FGDpixels)
        
        self.components_index[ self._bgd] = bg_comp
        self.components_index[ self._fgd] = fg_comp
        
        
    
    def learn_GMM_parameters(self):
        '''
        Learns GMM 
        '''
        
        for ci in range(self.k):
            # The places where the pixel belongs to the ci_th model and background model.
            bgd_ci = np.where(np.logical_and(self.components_index == ci, np.logical_or(self._mask == self._GC_BGD, self._mask == self._GC_PR_BGD)))
            fgd_ci = np.where(np.logical_and(self.components_index == ci, np.logical_or(self._mask == self._GC_FGD, self._mask == self._GC_PR_FGD)))
            
            for pixel in self.img[bgd_ci]:
                self.BGD_GMM.add_pixel(pixel, ci)
            for pixel in self.img[fgd_ci]:
                self.FGD_GMM.add_pixel(pixel, ci)
        
        self.BGD_GMM.learning()
        self.FGD_GMM.learning()
    
    def construct_gcgraph(self):
        bgd_indexes = np.where(self._mask.reshape(-1) == self._DRAW_BG['val'])
        fgd_indexes = np.where(self._mask.reshape(-1) == self._DRAW_FG['val'])
        
        pr_indexes = np.where(np.logical_or(self._mask.reshape(-1) == self._DRAW_PR_BG['val'],self._mask.reshape(-1) == self._DRAW_PR_FG['val']))
        
        edges = []
        self.graph_capacity = []
        
        ### Adding Sink ,Source Nodes
        edges.extend(list(zip([self.gc_source] * pr_indexes[0].size, pr_indexes[0])))
        _D = -np.log(self.BGD_GMM.calc_prob(self.img.reshape(-1, 3)[pr_indexes]))
        self.graph_capacity.extend(_D.tolist())
        
        edges.extend(list(zip([self.gc_sink] * pr_indexes[0].size, pr_indexes[0])))
        _D = -np.log(self.FGD_GMM.calc_prob(self.img.reshape(-1, 3)[pr_indexes]))
        self.graph_capacity.extend(_D.tolist())
        e =time.time()

        
        edges.extend(list(zip([self.gc_source] * bgd_indexes[0].size, bgd_indexes[0])))
        self.graph_capacity.extend([0] * bgd_indexes[0].size)
        
        edges.extend(list(zip([self.gc_sink] * bgd_indexes[0].size, bgd_indexes[0])))
        self.graph_capacity.extend([9 * self.gamma] * bgd_indexes[0].size)
        
        edges.extend(list(zip([self.gc_source] * fgd_indexes[0].size, fgd_indexes[0])))
        self.graph_capacity.extend([9 * self.gamma] * fgd_indexes[0].size)
        
        edges.extend(list(zip([self.gc_sink] * fgd_indexes[0].size, fgd_indexes[0])))
        self.graph_capacity.extend([0] * fgd_indexes[0].size)

        ### Adding pixel nodes 
        img_indexes = np.arange(self.rows*self.cols,dtype=np.uint32).reshape(self.rows,self.cols)
        temp1 = img_indexes[:, 1:]
        temp2 = img_indexes[:, :-1]
        mask1 = temp1.reshape(-1)
        mask2 = temp2.reshape(-1)
        edges.extend(list(zip(mask1, mask2)))
        self.graph_capacity.extend(self.left_V.reshape(-1).tolist())
        
        temp1 = img_indexes[1:, 1:]
        temp2 = img_indexes[:-1, :-1]
        mask1 = temp1.reshape(-1)
        mask2 = temp2.reshape(-1)
        edges.extend(list(zip(mask1, mask2)))
        self.graph_capacity.extend(self.upleft_V.reshape(-1).tolist())
        
        temp1 = img_indexes[1:, :]
        temp2 = img_indexes[:-1, :]
        mask1 = temp1.reshape(-1)
        mask2 = temp2.reshape(-1)
        edges.extend(list(zip(mask1, mask2)))
        self.graph_capacity.extend(self.up_V.reshape(-1).tolist())
        
        temp1 = img_indexes[1:, :-1]
        temp2 = img_indexes[:-1, 1:]
        mask1 = temp1.reshape(-1)
        mask2 = temp2.reshape(-1)
        edges.extend(list(zip(mask1, mask2)))
        self.graph_capacity.extend(self.upright_V.reshape(-1).tolist())
        
        self.graph = ig.Graph(self.cols * self.rows + 2)
        self.graph.add_edges(edges)

    def estimate_segmentation(self):
        mincut = self.graph.st_mincut(self.gc_source,self.gc_sink, self.graph_capacity)
        
        pr_indexes = np.where(np.logical_or(self._mask == self._DRAW_PR_BG['val'], self._mask == self._DRAW_PR_FG['val']))
        
        img_indexes = np.arange(self.rows * self.cols,dtype=np.uint32).reshape(self.rows, self.cols)
        self._mask[pr_indexes] = np.where(np.isin(img_indexes[pr_indexes], mincut.partition[0]),self._DRAW_PR_FG['val'], self._DRAW_PR_BG['val'])
        
        bgd_indexes = np.where(np.logical_or(self._mask == self._DRAW_BG['val'],self._mask == self._DRAW_PR_BG['val']))
        fgd_indexes = np.where(np.logical_or(self._mask == self._DRAW_FG['val'],self._mask == self._DRAW_PR_FG['val']))
        
    
    def run(self ,n ,mask=None ,rect=None ,ismask =False):
        print('Init Kmeans')
        
        if ismask==False:
            self._mask[rect[1]+3:rect[1]+rect[3]-3, rect[0]+3:rect[0]+rect[2]-3] = self._GC_PR_FGD
            self.init_with_kmeans()
        else:
            self._mask = mask
        
        for i in range(n):
            print('Assign Componants')
            self.assign_GMM_components()
            
            print('Learn GMM')
            self.learn_GMM_parameters()
            
            print('Construct Graph')
            self.construct_gcgraph()
            
            print('Estimate Seg')
            self.estimate_segmentation() 

In [51]:
def run(filename: str):
    """
    Main loop that implements GrabCut. 
    
    Input
    -----
    filename (str) : Path to image
    """
    
    COLORS = {
    'BLACK' : [0,0,0],
    'RED'   : [0, 0, 255],
    'GREEN' : [0, 255, 0],
    'BLUE'  : [255, 0, 0],
    'WHITE' : [255,255,255]
    }

    DRAW_BG = {'color' : COLORS['BLACK'], 'val' : 0}
    DRAW_FG = {'color' : COLORS['WHITE'], 'val' : 1}

    FLAGS = {
        'RECT' : (0, 0, 1, 1),
        'DRAW_STROKE': False,         # flag for drawing strokes
        'DRAW_RECT' : False,          # flag for drawing rectangle
        'rect_over' : False,          # flag to check if rectangle is  drawn
        'rect_or_mask' : -1,          # flag for selecting rectangle or stroke mode
        'value' : DRAW_FG,            # drawing strokes initialized to mark foreground
    }

    img = cv2.imread(filename)
    img2 = img.copy()                                
    mask = np.zeros(img.shape[:2], dtype = np.uint8)  
    output = np.zeros(img.shape, np.uint8)           

    # Input and segmentation windows
    cv2.namedWindow('Input Image')
    cv2.namedWindow('Segmented output')
    
    EventObj = EventHandler(FLAGS, img, mask, COLORS)
    cv2.setMouseCallback('Input Image', EventObj.handler)
    cv2.moveWindow('Input Image', img.shape[1] + 10, 90)
    
    GC = GrabC(img, k = 5)
    while(1):
        
        img = EventObj.image
        mask = EventObj.mask
        FLAGS = EventObj.flags
        
        cv2.imshow('Segmented image', output)
        cv2.imshow('Input Image', img)
        
        k = cv2.waitKey(1)

        if k == 27:
            break
        
        elif k == ord('0'): 
            FLAGS['value'] = DRAW_BG
        
        elif k == ord('1'):
            FLAGS['value'] = DRAW_FG
                
        elif k == 13: 
            if FLAGS['rect_or_mask'] == 0:
                GC.run(1 ,mask ,FLAGS['RECT'] ,False)
                FLAGS['rect_or_mask'] = 1
            elif FLAGS['rect_or_mask'] == 1:
                GC.run(1 ,mask , None , True)
        
            EventObj.flags = FLAGS
            FGD = np.where((GC._mask == 1) + (GC._mask == 3), 255, 0).astype('uint8')
            output = cv2.bitwise_and(GC.img2, GC.img2, mask = FGD)
            EventObj.mask = GC._mask

In [52]:
if __name__ == '__main__':
    filename = '../images/banana1.jpg'               # Path to image file
    run(filename)
    cv2.destroyAllWindows()

Init Kmeans
Assign Componants
Learn GMM
Construct Graph
Estimate Seg
Init Kmeans
Assign Componants
Learn GMM
Construct Graph
Estimate Seg
Init Kmeans
Assign Componants
Learn GMM
Construct Graph
Estimate Seg
Init Kmeans
Assign Componants
Learn GMM
Construct Graph
Estimate Seg
