# GrabCut Algorithm

This notebook demonstrates the GrabCut algorithm in Python. It implements the method described in this [paper](https://cvg.ethz.ch/teaching/cvl/2012/grabcut-siggraph04.pdf). GrabCut is essentially an iterative version of the GraphCut algorithm: it alternately re-fits Gaussian mixture colour models for the foreground and background and re-solves a min-cut segmentation.

## Procedure
The code below takes an input image and follows these steps:
- The user draws a bounding box (right-click and drag) that roughly encloses the foreground pixels
- Pressing Enter runs the min-cut optimization using this annotation
- The result of this optimization gives an initial segmentation
- To further refine this segmentation, the user paints two kinds of strokes with the left mouse button
    - strokes on the background pixels (press `0` before drawing)
    - strokes on the foreground pixels (press `1` before drawing)
- Pressing Enter again makes the algorithm use these strokes to refine the segmentation (press `r` to reset, `Esc` to quit)

In [9]:
import datetime
import os
import random
import time
from os import listdir

import cv2
import cv2 as cv
import igraph as ig
import matplotlib.pyplot as plt
import numpy as np
import skimage.io as sio
import tqdm
from cv2 import imwrite
from imageio import imread, imsave
from scipy.spatial.distance import cdist
from sklearn.mixture import GaussianMixture

In [10]:
class EventHandler:
    """
    Mouse-callback state for the interactive GrabCut window.

    * Right-button drag draws the bounding rectangle. It is shown in blue on
      the display image and filled into the mask with the current label
      (``flags['value']['val']``).
    * Left-button drag paints circular strokes (radius 3) with the current
      label and colour. Strokes are refused until a rectangle has been drawn.

    ``_mask`` and ``flags`` are mutated in place. The display image may be
    replaced by a fresh copy while the rectangle is dragged, so callers should
    read it back through the ``image`` property.
    """

    def __init__(self, flags, img, _mask, colors):
        """
        Parameters
        ----------
        flags : dict
            Shared UI state with keys 'DRAW_RECT', 'DRAW_STROKE', 'rect_over',
            'RECT', 'rect_or_mask' and 'value' (a dict with 'color' and 'val').
        img : ndarray
            Image shown in the input window; strokes and the rectangle are drawn on it.
        _mask : ndarray
            Label mask with the same height/width as ``img``.
        colors : dict
            Colour name -> colour triple; 'BLUE' is used for the rectangle outline.
        """
        self.FLAGS = flags
        self.ix = -1
        self.iy = -1
        self.img = img
        # Clean copy of the image, used to erase the rubber-band rectangle.
        self.img2 = self.img.copy()
        self._mask = _mask
        # Mask contents when the current rectangle drag started. Restored before
        # every redraw so that a shrinking rectangle leaves no stale labels.
        self._mask_before_rect = None
        self.COLORS = colors

    @property
    def image(self):
        return self.img

    @image.setter
    def image(self, img):
        self.img = img

    @property
    def mask(self):
        return self._mask

    @mask.setter
    def mask(self, _mask):
        self._mask = _mask
        # A snapshot of a different mask must never be copied into the new one.
        self._mask_before_rect = None

    @property
    def flags(self):
        return self.FLAGS

    @flags.setter
    def flags(self, flags):
        self.FLAGS = flags

    def _restore_mask_before_rect(self):
        """Undo the previous rubber-band fill of the mask, if a drag is active."""
        if self._mask_before_rect is not None:
            self._mask[...] = self._mask_before_rect

    def _update_rect(self, x, y):
        """Draw the rectangle from the anchor to (x, y) and record it in FLAGS."""
        self._restore_mask_before_rect()
        cv2.rectangle(self.img, (self.ix, self.iy), (x, y), self.COLORS['BLUE'], 2)
        cv2.rectangle(self._mask, (self.ix, self.iy), (x, y), self.FLAGS['value']['val'], -1)
        self.FLAGS['RECT'] = (min(self.ix, x), min(self.iy, y), abs(self.ix - x), abs(self.iy - y))
        self.FLAGS['rect_or_mask'] = 0

    def _draw_stroke(self, x, y):
        """Paint one stroke dot on both the display image and the mask."""
        cv2.circle(self.img, (x, y), 3, self.FLAGS['value']['color'], -1)
        cv2.circle(self._mask, (x, y), 3, self.FLAGS['value']['val'], -1)

    def handler(self, event, x, y, flags, param):
        """OpenCV mouse callback (signature required by cv2.setMouseCallback)."""

        # Rectangle (right mouse button)
        if event == cv2.EVENT_RBUTTONDOWN:
            self.FLAGS['DRAW_RECT'] = True
            self.ix, self.iy = x, y
            self._mask_before_rect = self._mask.copy()

        elif event == cv2.EVENT_MOUSEMOVE:
            if self.FLAGS['DRAW_RECT']:
                self.img = self.img2.copy()
                self._update_rect(x, y)

        elif event == cv2.EVENT_RBUTTONUP:
            self.FLAGS['DRAW_RECT'] = False
            self.FLAGS['rect_over'] = True
            self._update_rect(x, y)
            self._mask_before_rect = None

        # Refinement strokes (left mouse button). This is a separate chain on
        # purpose: a mouse move is handled by both chains.
        if event == cv2.EVENT_LBUTTONDOWN:
            if not self.FLAGS['rect_over']:
                print('Draw the rectangle first.')
            else:
                self.FLAGS['DRAW_STROKE'] = True
                self._draw_stroke(x, y)

        elif event == cv2.EVENT_MOUSEMOVE:
            if self.FLAGS['DRAW_STROKE']:
                self._draw_stroke(x, y)

        elif event == cv2.EVENT_LBUTTONUP:
            if self.FLAGS['DRAW_STROKE']:
                self.FLAGS['DRAW_STROKE'] = False
                self._draw_stroke(x, y)

In [11]:
class GMM:
    """
    Wrapper around scikit-learn's ``GaussianMixture``.

    It exposes the quantities the GrabCut data term needs: per-sample component
    labels, per-component weights (fraction of samples assigned), means and
    regularised covariances.
    """

    def __init__(self, X_data, n_components):
        """
        Fit a ``GaussianMixture`` with ``n_components`` components to
        ``X_data`` (one sample per row) and cache its parameters.
        """
        self.X_data = X_data
        self.n_components = n_components
        self.components = None

        # Initialise the actual GMM
        self.model = GaussianMixture(n_components=self.n_components).fit(self.X_data)

        # Label every training sample with its most likely component
        self.update_params(self.model.predict(self.X_data))

    def get_components(self, x_data):
        """Return the most likely component index for every row of ``x_data``."""
        return self.model.predict(x_data)

    def fit(self, x_data, labels):
        """Not implemented; kept as a no-op for interface compatibility."""

    def update_params(self, labels):
        """
        Cache component labels, counts, weights, means and covariances.

        ``labels`` holds one component index per training sample.
        """
        labels = np.asarray(labels)
        self.components = labels
        # bincount with minlength keeps one count per component, including
        # components that received no samples. np.unique would drop those,
        # leaving `weights` shorter than `means` and breaking compute_D.
        self.counts = np.bincount(labels, minlength=self.n_components)
        self.weights = self.counts / len(labels)

        # Save means and covariances from the model. The small ridge on the
        # diagonal keeps the covariances invertible.
        self.means = self.model.means_
        self.cov = self.model.covariances_ + 1e-08 * np.eye(self.means.shape[1])

    def compute_D(self, data=None):
        """
        Return the GrabCut data term for each sample.

        D = -log(pi_k) + 0.5 * log|Sigma_k| + 0.5 * (x - mu_k)^T Sigma_k^-1 (x - mu_k),
        where k is the sample's component.

        Parameters
        ----------
        data : array-like, shape (n_samples, n_features), optional
            Samples to score; their components are predicted by the model.
            When omitted, the training data and its stored labels are used.

        Returns
        -------
        ndarray, shape (n_samples,)
        """
        if data is None:
            data = self.X_data
            components = self.components
        else:
            components = self.model.predict(data)

        n_features = self.means.shape[1]
        # diffs[n, k] = x_n - mu_k -> shape (n_samples, n_components, n_features)
        samples = np.asarray(data, dtype=np.float64).reshape((-1, 1, n_features))
        diffs = samples - self.means[np.newaxis, :, :]

        inv_cov = np.linalg.inv(self.cov)
        # Mahalanobis term for every (sample, component) pair
        matrix_term = 0.5 * np.einsum('nki,kij,nkj->nk', diffs, inv_cov, diffs)
        # slogdet avoids under/overflow of the raw determinant
        det_term = 0.5 * np.linalg.slogdet(self.cov)[1]
        # Floor the weights so a component with no training samples gives a
        # large finite cost instead of -log(0) = inf.
        log_weights = np.log(np.maximum(self.weights, np.finfo(np.float64).tiny))

        D_val = -log_weights + det_term + matrix_term
        return D_val[np.arange(len(components)), components]

In [12]:
# Drawing colours, in OpenCV's BGR channel order.
COLORS = dict(
    BLACK=[0, 0, 0],
    RED=[0, 0, 255],
    GREEN=[0, 255, 0],
    BLUE=[255, 0, 0],
    WHITE=[255, 255, 255],
)

# Default hyper-parameters for the demo
gmm_components = 5
gamma = 50
neighbours = 4
color_space = 'RGB'
n_iters = 5

# Mask labels: 'val' is written into the label mask, 'color' is used to paint
# the matching stroke on the display image.
DRAW_BG = dict(color=COLORS['BLACK'], val=0)       # definite background
DRAW_FG = dict(color=COLORS['WHITE'], val=1)       # definite foreground
DRAW_PR_BG = dict(color=COLORS['BLACK'], val=2)    # probable background
DRAW_PR_FG = dict(color=COLORS['WHITE'], val=3)    # probable foreground

In [13]:
class GrabCut:
    """
    Iterative GrabCut segmentation driven by an initial label mask.

    All work happens in the constructor:
      1. the pairwise (smoothness) terms are computed once;
      2. foreground/background GMMs are fitted from the mask;
      3. ``n_iters`` rounds of GMM refitting, graph construction and s-t
         min-cut are run.

    The result is left in ``self.mask`` using the module-level DRAW_* labels.
    DRAW_BG / DRAW_FG pixels are hard constraints and are never rewritten.
    DRAW_PR_BG / DRAW_PR_FG pixels are relabelled by every min-cut.
    """

    def __init__(self, img, mask, n_iters, gamma=50, gmm_components=5, neighbours=8, rect=None):
        """
        Parameters
        ----------
        img : ndarray, shape (rows, cols, channels)
            Input image; copied, never modified.
        mask : ndarray, shape (rows, cols)
            Initial labels using the DRAW_*['val'] codes; copied.
        n_iters : int
            Number of refinement iterations.
        gamma : float
            Weight of the smoothness term. It also scales the capacity used for
            hard constraints (99 * gamma).
        gmm_components : int
            Number of components in each of the foreground/background GMMs.
        neighbours : int
            Pixel connectivity for the pairwise term (4 or 8).
        rect : tuple (x, y, w, h) or None
            If given, this region of the mask is set to probable foreground.
        """
        # Save the image and its size first
        self.img = img.copy()
        self.rows, self.cols, _ = self.img.shape
        self.mask = mask.copy()

        # Optionally mark the bounding box as probable foreground
        if rect is not None:
            self.mask[rect[1]:rect[1] + rect[3], rect[0]:rect[0] + rect[2]] = DRAW_PR_FG['val']

        # Split the pixels into (probable) background / foreground sets
        self.classify_pixels()

        self.gamma = gamma
        self.gmm_components = gmm_components
        self.n_iters = n_iters
        self.neighbours = neighbours

        # beta for the smoothness term; set by compute_smoothness()
        self.beta = 0

        self.gmm_fg = None
        self.gmm_bg = None
        # Per-pixel GMM component index, filled by assign_gmm_components()
        self.comp_idx = np.empty((self.rows, self.cols), dtype=np.uint32)

        # Graph layout: nodes 0 .. rows*cols-1 are pixels in row-major order,
        # followed by the source and sink terminals.
        self.graph = None
        self.graph_capacity = None
        self.graph_source = self.rows * self.cols
        self.graph_sink = self.graph_source + 1

        self.compute_smoothness()
        self.initialise_GMMs()
        self.run()

    def classify_pixels(self):
        """
        Store (row, col) index tuples for the two pixel sets used to fit the GMMs.

        ``bg_idx`` covers DRAW_BG and DRAW_PR_BG pixels; ``fg_idx`` covers
        DRAW_FG and DRAW_PR_FG pixels.
        """
        self.bg_idx = np.where(np.isin(self.mask, (DRAW_BG['val'], DRAW_PR_BG['val'])))
        self.fg_idx = np.where(np.isin(self.mask, (DRAW_FG['val'], DRAW_PR_FG['val'])))

    def _probable_mask(self):
        """Return a boolean (rows, cols) mask of pixels with a probable (soft) label."""
        return np.isin(self.mask, (DRAW_PR_BG['val'], DRAW_PR_FG['val']))

    def compute_smoothness(self):
        """
        Compute the pairwise smoothness weights of the Gibbs energy.

        V = gamma * exp(-beta * ||z_m - z_n||^2), where
        beta = 1 / (2 * mean ||z_m - z_n||^2) over all neighbouring pairs.

        NOTE(review): OpenCV's GrabCut additionally divides the diagonal
        (8-neighbour) weights by sqrt(2); that is not done here.
        """
        # Work in float. Differences of uint8 pixels would wrap around
        # (e.g. 0 - 10 -> 246), and their squares would overflow uint8.
        img = self.img.astype(np.float64)

        diff_left = img[:, 1:] - img[:, :-1]
        diff_up = img[1:, :] - img[:-1, :]
        sq_left = np.sum(np.square(diff_left), axis=2)
        sq_up = np.sum(np.square(diff_up), axis=2)

        beta_sum = sq_left.sum() + sq_up.sum()
        # Number of horizontal + vertical neighbour pairs
        n_avg = (2 * self.rows * self.cols) - self.cols - self.rows

        if self.neighbours == 8:
            diff_upleft = img[1:, 1:] - img[:-1, :-1]
            diff_upright = img[1:, :-1] - img[:-1, 1:]
            sq_upleft = np.sum(np.square(diff_upleft), axis=2)
            sq_upright = np.sum(np.square(diff_upright), axis=2)

            beta_sum += sq_upleft.sum() + sq_upright.sum()
            # Plus the two diagonal families of pairs
            n_avg += (2 * self.rows * self.cols) - (2 * self.cols) - (2 * self.rows) + 2

        # A perfectly uniform image has no contrast, so beta would be n/0.
        # Use 0, which makes every pairwise weight equal to gamma.
        self.beta = n_avg / (2 * beta_sum) if beta_sum > 0 else 0.0

        self.V_left = self.gamma * np.exp(-self.beta * sq_left)
        self.V_up = self.gamma * np.exp(-self.beta * sq_up)

        if self.neighbours == 8:
            self.V_upleft = self.gamma * np.exp(-self.beta * sq_upleft)
            self.V_upright = self.gamma * np.exp(-self.beta * sq_upright)

    def initialise_GMMs(self):
        """Fit fresh background and foreground GMMs on the current pixel sets."""
        self.gmm_bg = GMM(self.img[self.bg_idx], self.gmm_components)
        self.gmm_fg = GMM(self.img[self.fg_idx], self.gmm_components)

    def run(self):
        """Run ``n_iters`` rounds of GMM update, graph construction and min-cut."""
        for _ in range(self.n_iters):
            self.assign_gmm_components()
            self.initialise_GMMs()
            self.create_graph()
            self.estimate_segmentation()

    def assign_gmm_components(self):
        """Step 1 in Figure 3 of the paper: assign a GMM component to each pixel."""
        # NOTE(review): comp_idx is not read by initialise_GMMs(), which refits
        # the GMMs from scratch instead of using these assignments.
        self.comp_idx[self.bg_idx] = self.gmm_bg.get_components(self.img[self.bg_idx])
        self.comp_idx[self.fg_idx] = self.gmm_fg.get_components(self.img[self.fg_idx])

    @staticmethod
    def _add_tlinks(edges, capacities, terminal, nodes, weights):
        """Append one edge (terminal, node) per node, with the matching capacity."""
        edges.extend((terminal, int(node)) for node in nodes)
        capacities.extend(weights)

    def create_graph(self):
        """Build the s-t graph: t-links to both terminals plus n-links between neighbours."""
        flat_mask = self.mask.reshape(-1)
        bg_flat = np.flatnonzero(flat_mask == DRAW_BG['val'])
        fg_flat = np.flatnonzero(flat_mask == DRAW_FG['val'])
        pr_sel = self._probable_mask()
        pr_flat = np.flatnonzero(pr_sel)  # row-major, same order as self.img[pr_sel]

        edges = []
        self.graph_capacity = []
        source, sink = self.graph_source, self.graph_sink
        # Capacity large enough to stand in for "infinite" on hard constraints
        hard = 99 * self.gamma

        # t-links for probable pixels. Cutting the source edge (pixel -> BG)
        # costs D_bg; cutting the sink edge (pixel -> FG) costs D_fg.
        if pr_flat.size:
            pr_pixels = self.img[pr_sel]
            self._add_tlinks(edges, self.graph_capacity, source, pr_flat,
                             self.gmm_bg.compute_D(pr_pixels).tolist())
            self._add_tlinks(edges, self.graph_capacity, sink, pr_flat,
                             self.gmm_fg.compute_D(pr_pixels).tolist())

        # Hard constraints: BG pinned to the sink side, FG to the source side
        self._add_tlinks(edges, self.graph_capacity, source, bg_flat, [0] * bg_flat.size)
        self._add_tlinks(edges, self.graph_capacity, sink, bg_flat, [hard] * bg_flat.size)
        self._add_tlinks(edges, self.graph_capacity, source, fg_flat, [hard] * fg_flat.size)
        self._add_tlinks(edges, self.graph_capacity, sink, fg_flat, [0] * fg_flat.size)

        # n-links between neighbouring pixels. The slicing must match the
        # differences used in compute_smoothness().
        node = np.arange(self.rows * self.cols).reshape(self.rows, self.cols)
        neighbour_pairs = [
            (node[:, 1:], node[:, :-1], self.V_left),
            (node[1:, :], node[:-1, :], self.V_up),
        ]
        if self.neighbours == 8:
            neighbour_pairs += [
                (node[1:, 1:], node[:-1, :-1], self.V_upleft),
                (node[1:, :-1], node[:-1, 1:], self.V_upright),
            ]
        for first, second, weights in neighbour_pairs:
            edges.extend(zip(first.reshape(-1).tolist(), second.reshape(-1).tolist()))
            self.graph_capacity.extend(weights.reshape(-1).tolist())

        self.graph = ig.Graph(self.cols * self.rows + 2)
        self.graph.add_edges(edges)

    def estimate_segmentation(self):
        """Run the min-cut and relabel probable pixels by which side of the cut they land on."""
        mincut = self.graph.st_mincut(self.graph_source, self.graph_sink, self.graph_capacity)

        pr_sel = self._probable_mask()
        pr_flat = np.flatnonzero(pr_sel)
        # Pixels in the source partition become probable foreground
        on_source_side = np.isin(pr_flat, mincut.partition[0])
        self.mask[pr_sel] = np.where(on_source_side, DRAW_PR_FG['val'], DRAW_PR_BG['val'])

        self.classify_pixels()

In [14]:
def load_image(filename, scale=1.0):
    """
    Read an image, convert it to the module-level ``color_space`` and rescale it.

    Parameters
    ----------
    filename : str
        Path to the image file.
    scale : float
        Resize factor applied to both dimensions (1.0 = unchanged).

    Notes
    -----
    'RGB' currently performs no conversion (the BGR->RGB call is disabled), so
    the pixels stay in cv2.imread's BGR order. This keeps them displayable with
    cv2.imshow.

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read ``filename``.
    ValueError
        If the module-level ``color_space`` is not 'RGB', 'HSV' or 'LAB'.
    """
    im = cv2.imread(filename)
    # cv2.imread returns None instead of raising when it cannot read the file.
    if im is None:
        raise FileNotFoundError(f"Could not read image: {filename}")

    if color_space == "RGB":
        pass  # keep BGR, see Notes
    elif color_space == "HSV":
        im = cv2.cvtColor(im, cv2.COLOR_BGR2HSV)
    elif color_space == "LAB":
        im = cv2.cvtColor(im, cv2.COLOR_BGR2LAB)
    else:
        raise ValueError(f"Unsupported color_space: {color_space!r}")

    if scale != 1.0:
        # Clamp to at least one pixel so tiny scales do not give an empty image.
        new_size = (max(1, int(im.shape[1] * scale)), max(1, int(im.shape[0] * scale)))
        im = cv2.resize(im, new_size)
    return im

In [15]:
def run(filename, gamma=50, gmm_components=7, neighbours=8, color_space='RGB', n_iters=2):
    """
    Main interactive loop that drives GrabCut.

    Controls
    --------
    right-drag : draw the bounding box, filled with the current label
                 (probable foreground by default)
    left-drag  : paint refinement strokes
    '0' / '1'  : switch strokes to definite background / definite foreground
    'r'        : reset image, mask and UI state
    Enter      : run GrabCut on the current mask; the result replaces the mask
                 so that later strokes refine it
    Esc        : quit

    Input
    -----
    filename (str) : Path to image
    gamma (float) : Smoothness weight passed to GrabCut
    gmm_components (int) : Components per GMM passed to GrabCut
    neighbours (int) : Pixel connectivity (4 or 8) passed to GrabCut
    color_space (str) : Not used here; load_image reads the module-level color_space
    n_iters (int) : GrabCut iterations per Enter press
    """
    FLAGS = {
        'RECT': (0, 0, 1, 1),
        'DRAW_STROKE': False,         # flag for drawing strokes
        'DRAW_RECT': False,           # flag for drawing rectangle
        'rect_over': False,           # flag to check if rectangle is drawn
        'rect_or_mask': -1,           # flag for selecting rectangle or stroke mode
        'value': DRAW_PR_FG,          # drawing strokes initialized to mark foreground
    }

    img = load_image(filename, scale=0.8)
    img2 = img.copy()   # pristine copy: used for resets and as GrabCut input
    # Everything starts as probable background; the rectangle marks probable foreground.
    mask = np.full(img.shape[:2], DRAW_PR_BG['val'], dtype=np.uint8)
    output = np.zeros(img.shape, np.uint8)   # output image to be shown

    # Input and segmentation windows
    cv2.namedWindow('Input Image')
    cv2.namedWindow('Segmented image')

    EventObj = EventHandler(FLAGS, img, mask, COLORS)
    cv2.setMouseCallback('Input Image', EventObj.handler)
    cv2.moveWindow('Input Image', img.shape[1] + 10, 90)

    while True:
        img = EventObj.image
        mask = EventObj.mask
        FLAGS = EventObj.flags
        cv2.imshow('Segmented image', output)
        cv2.imshow('Input Image', img)

        # Keep only the low byte: some platforms set extra bits in waitKey's result.
        k = cv2.waitKey(1) & 0xFF

        if k == 27:
            # Esc to exit
            break

        elif k == ord('0'):
            # Strokes for background
            FLAGS['value'] = DRAW_BG

        elif k == ord('1'):
            # Strokes for foreground
            FLAGS['value'] = DRAW_FG

        elif k == ord('r'):
            # Reset everything to the initial state
            FLAGS['RECT'] = (0, 0, 1, 1)
            FLAGS['DRAW_STROKE'] = False
            FLAGS['DRAW_RECT'] = False
            FLAGS['rect_or_mask'] = -1
            FLAGS['rect_over'] = False
            FLAGS['value'] = DRAW_PR_FG
            img = img2.copy()
            # Same initial labels as at start-up (probable background), not
            # definite background.
            mask = np.full(img.shape[:2], DRAW_PR_BG['val'], dtype=np.uint8)
            EventObj.image = img
            EventObj.mask = mask
            output = np.zeros(img.shape, np.uint8)

        elif k == 13:
            # Enter: segment the pristine image using the current label mask
            gc = GrabCut(img=img2, mask=mask, n_iters=n_iters, gamma=gamma,
                         gmm_components=gmm_components, neighbours=neighbours)
            mask = gc.mask.copy()
            # Hand the refined labels back to the handler so that further
            # strokes and Enter presses refine this result.
            EventObj.mask = mask
            EventObj.flags = FLAGS

            is_fg = np.logical_or(mask == DRAW_PR_FG['val'], mask == DRAW_FG['val'])
            mask2 = np.where(is_fg, 255, 0).astype('uint8')
            output = cv2.bitwise_and(img2, img2, mask=mask2)

In [18]:
filename = './images/banana2.jpg'               # Path to image file
try:
    run(filename, gamma, gmm_components, neighbours, color_space)
finally:
    # Close the OpenCV windows even if run() raises (e.g. unreadable image).
    cv2.destroyAllWindows()