In [2]:
import numpy as np
import cv2 
from sklearn.mixture import GaussianMixture
import random
from networkx.algorithms.flow import shortest_augmenting_path
from matplotlib import pyplot as plt
import networkx as nx
import time

In [3]:
class EventHandler:
    """
    Mouse-event state for the interactive GrabCut window.

    The right mouse button drags out the initial rectangle; the left mouse
    button paints strokes (only once a rectangle has been completed) onto
    both the displayed image and the stroke mask. All UI state lives in the
    shared ``FLAGS`` dict, which is mutated in place.
    """

    def __init__(self, flags, img, _mask, colors):
        self.FLAGS = flags
        # Anchor corner of the rectangle currently being dragged.
        self.ix, self.iy = -1, -1
        # Image shown to the user; rectangles and strokes are drawn onto it.
        self.img = img
        # Untouched copy used to redraw the rectangle while it is dragged.
        self.img2 = self.img.copy()
        # Stroke mask; painted with FLAGS['value']['val'].
        self._mask = _mask
        self.COLORS = colors

    @property
    def image(self):
        """Currently displayed (annotated) image."""
        return self.img

    @image.setter
    def image(self, img):
        self.img = img

    @property
    def mask(self):
        """Stroke mask painted by the left mouse button."""
        return self._mask

    @mask.setter
    def mask(self, _mask):
        self._mask = _mask

    @property
    def flags(self):
        """Shared UI-state dict."""
        return self.FLAGS

    @flags.setter
    def flags(self, flags):
        self.FLAGS = flags

    def _record_rect(self, x, y):
        """Draw the rectangle anchor->(x, y) and store it as (x, y, w, h)."""
        cv2.rectangle(self.img, (self.ix, self.iy), (x, y), self.COLORS['BLUE'], 2)
        left, top = min(self.ix, x), min(self.iy, y)
        self.FLAGS['RECT'] = (left, top, abs(self.ix - x), abs(self.iy - y))
        self.FLAGS['rect_or_mask'] = 0

    def _paint_stroke(self, x, y):
        """Paint one brush dot at (x, y) on the display image and the mask."""
        brush = self.FLAGS['value']
        cv2.circle(self.img, (x, y), 3, brush['color'], -1)
        cv2.circle(self._mask, (x, y), 3, brush['val'], -1)

    def handler(self, event, x, y, flags, param):
        """OpenCV mouse callback (signature required by cv2.setMouseCallback)."""
        state = self.FLAGS

        # Rectangle: right button press / drag / release.
        if event == cv2.EVENT_RBUTTONDOWN:
            state['DRAW_RECT'] = True
            self.ix, self.iy = x, y
        elif event == cv2.EVENT_MOUSEMOVE:
            if state['DRAW_RECT']:
                # Restart from the clean image so only the latest rectangle shows.
                self.img = self.img2.copy()
                self._record_rect(x, y)
        elif event == cv2.EVENT_RBUTTONUP:
            state['DRAW_RECT'] = False
            state['rect_over'] = True
            self._record_rect(x, y)

        # Strokes: left button press / drag / release (independent of the
        # rectangle chain above, so MOUSEMOVE is seen by both).
        if event == cv2.EVENT_LBUTTONDOWN:
            if state['rect_over']:
                state['DRAW_STROKE'] = True
                self._paint_stroke(x, y)
            else:
                print('Draw the rectangle first.')
        elif event in (cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONUP):
            if state['DRAW_STROKE']:
                if event == cv2.EVENT_LBUTTONUP:
                    state['DRAW_STROKE'] = False
                self._paint_stroke(x, y)

In [11]:
class Params:
    """
    Maximum-likelihood Gaussian-mixture parameters estimated from pixels that
    have already been assigned to mixture components.

    Parameters
    ----------
    pixels : array-like, shape (n_pixels, 3)
        Colour samples.
    labels : array-like, shape (n_pixels,)
        Component index of each sample. May be a float array holding integral
        values; it is cast to int so it can be used for indexing.
    n_componants : int
        Number of mixture components (name kept for backward compatibility).
    """

    def __init__(self, pixels, labels, n_componants):
        self.n_componants = n_componants
        self.n_samples = np.zeros(n_componants)

        self.coefs = np.zeros(n_componants)
        self.means = np.zeros((n_componants, 3))
        self.covariances = np.zeros((n_componants, 3, 3))

        self.X = np.asarray(pixels, dtype=np.float64)
        # Labels index arrays below; float labels would raise IndexError.
        self.labels = np.asarray(labels).astype(int)

    def fit_para(self, variance=0.01):
        """
        Fill ``n_samples``, ``coefs``, ``means`` and ``covariances`` for every
        component that has at least one sample. Components with no samples
        keep zero weight and their previous mean/covariance.

        Parameters
        ----------
        variance : float
            Amount added to the diagonal of a covariance whose determinant is
            not positive (singular), so it can be inverted later.
        """
        self.n_samples[:] = 0
        self.coefs[:] = 0

        uni_labels, counts = np.unique(self.labels, return_counts=True)
        self.n_samples[uni_labels] = counts
        total = np.sum(self.n_samples)

        for ci in uni_labels:
            members = self.X[self.labels == ci]

            self.coefs[ci] = self.n_samples[ci] / total
            self.means[ci] = np.mean(members, axis=0)
            # np.cov is undefined for a single sample; start from zero instead.
            self.covariances[ci] = 0 if self.n_samples[ci] <= 1 else np.cov(members.T)

            if np.linalg.det(self.covariances[ci]) <= 0:
                # Add white noise to avoid a singular covariance matrix.
                self.covariances[ci] += np.eye(3) * variance

In [None]:
class GC:
    """
    Iterative GrabCut-style foreground segmentation.

    Each round (``iters`` per call to ``main``) fits foreground/background
    colour GMMs, re-estimates their parameters from the current labelling,
    builds an s-t graph of per-pixel data costs plus pairwise smoothness
    costs, and relabels every pixel with a minimum cut.

    Parameters
    ----------
    img : ndarray
        Image to segment, indexed ``img[row, col]`` -> colour vector
        (presumably BGR uint8 from ``cv2.imread`` -- TODO confirm).
    iters : int
        Number of rounds per ``main`` call.
    gamma : float
        Weight of the pairwise (smoothness) term.
    n_clusters : int
        Number of Gaussian components in each colour model.
    """

    def __init__(self, img, iters=1, gamma=50, n_clusters=5):
        self.img = img
        self.n_clusters = n_clusters
        self.iters = iters
        self.gamma = gamma

        # Current labelling: 1 = foreground, 0 = background (set by main).
        self.alphas = None

        self.f_gmm = None
        self.b_gmm = None

    def get_alphas(self, img, rect, mask, ismask):
        """
        Build the starting labelling (1 = foreground, 0 = background).

        ismask False: ``rect`` is ``(x, y, w, h)`` as produced by the UI; the
        pixels inside it (both corners inclusive) are 1, everything else 0.
        ismask True: start from a copy of the previous ``self.alphas``;
        ``mask == 100`` forces background, any other non-zero value forces
        foreground, zero keeps the previous label.

        Raises
        ------
        ValueError
            If ismask is True but no previous labelling exists.
        """
        h, w = img.shape[:2]

        if not ismask:
            alpha_arr = np.zeros((h, w))
            x, y, rw, rh = (int(v) for v in rect)
            # Clamp so negative coordinates (drag outside the window) never
            # turn into wrap-around slices.
            x0, y0 = max(x, 0), max(y, 0)
            x1, y1 = max(x + rw + 1, 0), max(y + rh + 1, 0)
            alpha_arr[y0:y1, x0:x1] = 1
            return alpha_arr

        if self.alphas is None:
            raise ValueError('mask refinement requires a previous rectangle segmentation')
        alpha_arr = np.array(self.alphas, dtype=np.float64)  # copy; keep previous intact
        strokes = np.asarray(mask)
        alpha_arr[strokes != 0] = 1
        alpha_arr[strokes == 100] = 0
        return alpha_arr

    def getMixtureModel(self, D, k):
        """Fit a full-covariance ``k``-component GaussianMixture to samples ``D``."""
        clf = GaussianMixture(n_components=k, covariance_type='full', warm_start=True)
        clf.fit(D)
        return clf

    def initGMM(self, img, alpha):
        """
        Fit one GMM to the pixels labelled 1 and one to all other pixels.

        Returns
        -------
        (fg_GMM, bg_GMM)
        """
        print('GMM init')
        fg_mask = np.asarray(alpha) == 1
        fg = img[fg_mask]
        bg = img[~fg_mask]

        # Background first, as before, so the random-state usage order is unchanged.
        bg_GMM = self.getMixtureModel(bg, self.n_clusters)
        fg_GMM = self.getMixtureModel(fg, self.n_clusters)
        return fg_GMM, bg_GMM

    def assign_comp(self, img, alpha, f_gmm, b_gmm):
        """
        Assign every pixel to a component of the GMM of its current label.

        Returns
        -------
        ndarray of int, shape img.shape[:2]
            Component index per pixel.
        """
        fg_mask = np.asarray(alpha) == 1
        k_id = np.zeros(img.shape[:2], dtype=int)
        for region, gmm in ((fg_mask, f_gmm), (~fg_mask, b_gmm)):
            # predict() rejects empty input, so skip empty regions.
            if region.any():
                k_id[region] = gmm.predict(img[region])
        return k_id

    def learn_params(self, img, alphas, k_id, n_clusters):
        """
        Estimate component weights, means and covariances separately for the
        background (alphas == 0) and foreground (alphas == 1) pixels.

        Returns
        -------
        (bg_para, fg_para) : (Params, Params)
        """
        bgd_idx = np.where(alphas == 0)
        bg_para = Params(img[bgd_idx], k_id[bgd_idx], n_clusters)
        bg_para.fit_para()

        fgd_idx = np.where(alphas == 1)
        fg_para = Params(img[fgd_idx], k_id[fgd_idx], n_clusters)
        fg_para.fit_para()

        return bg_para, fg_para

    def predict_pix(self, img, para, gmm):
        """
        Per-pixel cost under one colour model:
        0.5*log|Sigma_k| + 0.5*(z - mu_k)^T Sigma_k^-1 (z - mu_k) - log(pi_k),
        where k is the component ``gmm.predict`` picks for the pixel and the
        parameters come from ``para``.

        Returns
        -------
        ndarray, shape img.shape[:2]
        """
        h, w = img.shape[:2]
        pix = np.asarray(img, dtype=np.float64).reshape(h * w, -1)
        k = self.n_clusters

        # np.linalg.det / inv operate on the whole (k, 3, 3) stack at once.
        half_log_det = 0.5 * np.log(np.linalg.det(para.covariances[:k]))
        inv_cov = np.linalg.inv(para.covariances[:k])
        neg_log_w = -np.log(para.coefs[:k])

        labs = np.asarray(gmm.predict(pix), dtype=int)
        diff = pix - para.means[labs]
        maha = 0.5 * np.einsum('ni,nij,nj->n', diff, inv_cov[labs], diff)

        return (half_log_det[labs] + maha + neg_log_w[labs]).reshape(h, w)

    def getBinaryPotential(self, a, b, beta):
        """
        Smoothness weight gamma * exp(-beta * ||a - b||^2) for colour vectors
        (the last axis is the channel axis, so stacks of pixels also work).
        """
        # Cast first: uint8 pixels would wrap around on subtraction.
        diff = np.asarray(a, dtype=np.float64) - np.asarray(b, dtype=np.float64)
        sq = np.sum(diff ** 2, axis=-1) if diff.ndim else diff ** 2
        return self.gamma * np.exp(-1.0 * beta * sq)

    @staticmethod
    def _neighbour_offsets(four_neighbourhood):
        """Forward (dy, dx) offsets that enumerate each neighbour pair once."""
        offsets = [(0, 1), (1, 0)]
        if not four_neighbourhood:
            offsets += [(1, 1), (1, -1)]
        return offsets

    @staticmethod
    def _pair_slices(h, w, dy, dx):
        """(src, dst) slices so that img[src][r, c] and img[dst][r, c] are neighbours offset by (dy, dx), dy >= 0."""
        src = (slice(0, h - dy), slice(max(0, -dx), w - max(0, dx)))
        dst = (slice(dy, h), slice(max(0, dx), w + min(0, dx)))
        return src, dst

    def createGraph(self, img, fg_penalty, bg_penalty, alphas):
        """
        Build the undirected s-t graph whose minimum cut gives the new labelling.

        Nodes are 's', 't' and one (row, col) tuple per pixel; pixels on the
        source side of the cut become foreground.

        Parameters
        ----------
        img : ndarray
            Image, indexed ``img[row, col]``.
        fg_penalty, bg_penalty : ndarray, shape img.shape[:2]
            Cost of labelling each pixel foreground / background.
        alphas : ndarray, shape img.shape[:2]
            Current labelling; pixels equal to 0 are tied to the sink with an
            infinite-capacity edge (forced background).

        Returns
        -------
        networkx.Graph
        """
        start = time.time()
        four_neighbourhood = False
        print('creating Graph')

        # Work in float: uint8 differences would wrap around.
        imgf = np.asarray(img, dtype=np.float64)
        alphas = np.asarray(alphas)
        h, w = imgf.shape[:2]

        offsets = self._neighbour_offsets(four_neighbourhood)
        pair_slices = [self._pair_slices(h, w, dy, dx) for dy, dx in offsets]

        sq_diffs = []
        for src, dst in pair_slices:
            d2 = (imgf[src] - imgf[dst]) ** 2
            sq_diffs.append(d2.sum(axis=-1) if d2.ndim == 3 else d2)

        # GrabCut's beta = 1 / (2 * <||z_m - z_n||^2>) over neighbour pairs,
        # so the exponential adapts to the image's overall contrast.
        n_pairs = sum(d.size for d in sq_diffs)
        mean_sq = sum(float(d.sum()) for d in sq_diffs) / n_pairs if n_pairs else 0.0
        beta = 1.0 / (2.0 * mean_sq) if mean_sq > 0 else 0.0

        G = nx.Graph()
        G.add_node('s')
        G.add_node('t')

        # t-links (data term).
        for i in range(h):
            for j in range(w):
                if alphas[i, j] == 0:
                    s_cap, t_cap = 0.0, float('inf')
                else:
                    s_cap, t_cap = float(bg_penalty[i, j]), float(fg_penalty[i, j])
                    # Negative log-likelihoods can be negative. Exactly one of
                    # the two t-links is cut, so subtracting the same constant
                    # from both shifts every cut equally: the optimum is
                    # unchanged and all capacities become non-negative.
                    shift = min(s_cap, t_cap)
                    if np.isfinite(shift):
                        s_cap -= shift
                        t_cap -= shift
                G.add_edge('s', (i, j), capacity=s_cap)
                G.add_edge((i, j), 't', capacity=t_cap)

        # n-links (smoothness term) between *all* neighbouring pixels; skipping
        # same-label pairs would leave region interiors without any smoothing.
        for (dy, dx), (src, _dst), sq in zip(offsets, pair_slices, sq_diffs):
            weights = self.gamma * np.exp(-beta * sq)
            rows, cols = np.mgrid[src]
            for r, c, wgt in zip(rows.ravel().tolist(), cols.ravel().tolist(), weights.ravel().tolist()):
                G.add_edge((r, c), (r + dy, c + dx), capacity=wgt)

        print('\ttime' + str(time.time() - start))
        return G

    def main(self, rect=None, mask=None, ismask=False):
        """
        Run ``self.iters`` rounds of segmentation and write each round's
        masked image to ``'<round>.jpg'``.

        Parameters
        ----------
        rect : tuple
            ``(x, y, w, h)`` initial rectangle (used when ismask is False).
        mask : ndarray
            Stroke mask (used when ismask is True); see ``get_alphas``.
        ismask : bool
            Refine the previous labelling with strokes instead of a rectangle.

        Returns
        -------
        ndarray
            Final labelling (1 = foreground, 0 = background), also stored in
            ``self.alphas``.
        """
        # NOTE(review): createGraph treats every pixel labelled 0 as hard
        # background, so pixels a cut labels background can never return to
        # foreground in later rounds, and foreground strokes are not hard
        # constraints. Standard GrabCut keeps a separate trimap -- consider it.
        self.alphas = self.get_alphas(self.img, rect, mask, ismask)

        for it in range(self.iters):
            self.f_gmm, self.b_gmm = self.initGMM(self.img, self.alphas)

            ## 1) Assign GMM components to pixels (Use Inbuilt GMM)
            self.k_id = self.assign_comp(self.img, self.alphas, self.f_gmm, self.b_gmm)

            ## 2) Learn GMM parameters from data z:
            self.bg_para, self.fg_para = self.learn_params(self.img, self.alphas, self.k_id, self.n_clusters)

            ## 3) Estimate segmentation: use min cut to solve:
            fg_probs = self.predict_pix(self.img, self.fg_para, self.f_gmm)
            # Background costs use the background model's component assignment.
            bg_probs = self.predict_pix(self.img, self.bg_para, self.b_gmm)

            ##### 3.1) Create Graph
            G = self.createGraph(self.img, fg_probs, bg_probs, self.alphas)

            _cut_value, partition = nx.minimum_cut(G, 's', 't', flow_func=shortest_augmenting_path)

            print('Mincut Done')
            reachable, _non_reachable = partition

            temp_alpha = np.zeros((self.img.shape[0], self.img.shape[1]))

            print('Assinging Pixels')
            for px in reachable:
                if px != 's':
                    temp_alpha[px[0], px[1]] = 1

            # Keep foreground pixels, black out the rest (same dtype as img).
            keep = (temp_alpha == 1).reshape(temp_alpha.shape + (1,) * (np.ndim(self.img) - 2))
            final_img = np.where(keep, self.img, 0)

            self.alphas = temp_alpha
            cv2.imwrite(str(it) + '.jpg', final_img)

        return self.alphas

In [None]:
def run(filename):
    """
    Main loop that implements GrabCut.

    Controls (mouse in the 'Input Image' window, keys anywhere):
      right-drag : draw the initial rectangle
      left-drag  : paint strokes (only after a rectangle exists)
      '0' / '1'  : paint background / foreground strokes
      'n'        : segment (first press) or refine with strokes
      'r'        : reset everything
      Esc        : quit

    Input
    -----
    filename (str) : Path to image

    Raises
    ------
    FileNotFoundError : if OpenCV cannot read ``filename``.
    """

    COLORS = {
    'BLACK' : [0,0,0],
    'RED'   : [0, 0, 255],
    'GREEN' : [0, 255, 0],
    'BLUE'  : [255, 0, 0],
    'WHITE' : [255,255,255]
    }

    DRAW_BG = {'color' : COLORS['BLACK'], 'val' : 100}
    DRAW_FG = {'color' : COLORS['WHITE'], 'val' : 200}

    FLAGS = {
        'RECT' : (0, 0, 1, 1),
        'DRAW_STROKE': False,         # flag for drawing strokes
        'DRAW_RECT' : False,          # flag for drawing rectangle
        'rect_over' : False,          # flag to check if rectangle is  drawn
        'rect_or_mask' : -1,          # flag for selecting rectangle or stroke mode
        'value' : DRAW_FG,            # drawing strokes initialized to mark foreground
    }

    img = cv2.imread(filename)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError('could not read image: ' + str(filename))
    img2 = img.copy()  # clean copy; rectangle and strokes are drawn on `img`
    # Stroke mask: 0 = unmarked, DRAW_BG['val'] (100) = background stroke,
    # DRAW_FG['val'] (200) = foreground stroke.
    mask = np.zeros(img.shape[:2], dtype = np.uint8)
    output = np.zeros(img.shape, np.uint8)           # output image to be shown

    # Input and segmentation windows
    cv2.namedWindow('Input Image')
    cv2.namedWindow('Segmented output')

    EventObj = EventHandler(FLAGS, img, mask, COLORS)
    cv2.setMouseCallback('Input Image', EventObj.handler)
    cv2.moveWindow('Input Image', img.shape[1] + 10, 90)

    while True:

        img = EventObj.image
        mask = EventObj.mask
        FLAGS = EventObj.flags
        # Same window name as created above (previously a stray third window).
        cv2.imshow('Segmented output', output)
        cv2.imshow('Input Image', img)

        k = cv2.waitKey(1)

        # key bindings
        if k == 27:
            # esc to exit
            break

        elif k == ord('0'):
            # Strokes for background
            FLAGS['value'] = DRAW_BG

        elif k == ord('1'):
            # FG drawing
            FLAGS['value'] = DRAW_FG

        elif k == ord('r'):
            # reset everything
            FLAGS['RECT'] = (0, 0, 1, 1)
            FLAGS['DRAW_STROKE'] = False
            FLAGS['DRAW_RECT'] = False
            FLAGS['rect_or_mask'] = -1
            FLAGS['rect_over'] = False
            FLAGS['value'] = DRAW_FG
            img = img2.copy()
            mask = np.zeros(img.shape[:2], dtype = np.uint8)
            EventObj.image = img
            EventObj.mask = mask
            output = np.zeros(img.shape, np.uint8)

        elif k == ord('n'):

            if FLAGS['rect_or_mask'] == 0:
                print('Mask False')
                # Segment the clean image, not the display copy with the
                # rectangle/strokes painted into it.
                gb = GC(img2, iters=1)
                gb.main(FLAGS['RECT'], mask, ismask=False)
                FLAGS['rect_or_mask'] = 1

            elif FLAGS['rect_or_mask'] == 1:
                print('Mask True')
                gb.main(FLAGS['RECT'], mask, ismask=True)

            else:
                # No rectangle yet: nothing to segment.
                print('Draw the rectangle first.')
                continue

            EventObj.flags = FLAGS
            # GC.main leaves the final labelling (1 = foreground) in gb.alphas.
            seg_mask = np.asarray(gb.alphas).astype(np.uint8)

            mask2 = np.where((seg_mask == 1), 255, 0).astype('uint8')

            output = cv2.bitwise_and(img2, img2, mask = mask2)

    cv2.destroyAllWindows()
