# GrabCut from scratch (Image Segmentation using MRFs)

Here we implement the GrabCut method described in this [paper](https://cvg.ethz.ch/teaching/cvl/2012/grabcut-siggraph04.pdf). It is essentially an iterative version of GraphCut, as shown in the figure below.
![graphcut.png](https://github.com/saiamrit/technical-blog/blob/master/images/graphcut.png)

The code below takes an input image and follows these steps:
- The user draws a bounding box that roughly encloses the foreground object
- An initial min-cut optimization is run using this annotation
- The result of this optimization gives an initial segmentation
- To refine this segmentation, the user provides two kinds of strokes
    - strokes on the background pixels
    - strokes on the foreground pixels
- The algorithm then uses these strokes to refine the original segmentation
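
The notebook tracks these annotations in a mask with four labels (0 = background, 1 = foreground, 2 = probable background, 3 = probable foreground). The toy sketch below shows how a rectangle and two strokes could populate such a mask. It uses plain Python lists instead of NumPy, and `make_trimap` and `apply_stroke` are hypothetical helper names that do not appear in the notebook.

```python
# Label values mirror DRAW_BG / DRAW_FG / DRAW_PR_BG / DRAW_PR_FG in the notebook.
BG, FG, PR_BG, PR_FG = 0, 1, 2, 3

def make_trimap(rows, cols, rect):
    """Start with probable background everywhere, probable foreground inside rect.

    rect is (x, y, width, height), the same layout as FLAGS['RECT'].
    """
    x, y, w, h = rect
    mask = [[PR_BG] * cols for _ in range(rows)]
    for r in range(y, min(y + h, rows)):
        for c in range(x, min(x + w, cols)):
            mask[r][c] = PR_FG
    return mask

def apply_stroke(mask, pixels, label):
    """Overwrite the given (row, col) pixels with a hard FG or BG label."""
    for r, c in pixels:
        mask[r][c] = label
    return mask

mask = make_trimap(4, 6, rect=(1, 1, 3, 2))
mask = apply_stroke(mask, [(1, 1)], BG)   # user marks a pixel as definite background
mask = apply_stroke(mask, [(2, 3)], FG)   # user marks a pixel as definite foreground
for row in mask:
    print(row)
```

Only pixels labelled 2 or 3 are free to change during the optimization; pixels labelled 0 or 1 act as hard constraints.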

You can view this [video](https://www.youtube.com/watch?v=aOqOwM-Qbtg) to get a better idea of the steps involved.

Image segmentation is one exciting application of MRFs. You can further read about other applications of MRFs for Computer Vision [here](https://cedar.buffalo.edu/~srihari/CSE574/Chap8/Ch8-PGM-Undirected/9.5-MRFinCV.pdf).
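
To make the link between an MRF and min-cut more concrete, here is a minimal sketch on a three-pixel "image". Each pixel is joined to a source (foreground) and a sink (background) by t-links, and to its neighbour by n-links; the cheapest cut gives the labelling. The capacities are made-up integers chosen for illustration, and the small Edmonds-Karp routine stands in for the `igraph` solver that the notebook uses.

```python
from collections import deque

def min_cut(capacity, source, sink):
    """Edmonds-Karp max-flow; returns (cut value, set of nodes on the source side)."""
    # Residual capacities, including zero-capacity reverse edges.
    residual = {u: dict(nbrs) for u, nbrs in capacity.items()}
    for u, nbrs in capacity.items():
        for v in nbrs:
            residual.setdefault(v, {}).setdefault(u, 0)
    flow = 0
    while True:
        # Breadth-first search for an augmenting path.
        parent = {source: None}
        queue = deque([source])
        while queue and sink not in parent:
            u = queue.popleft()
            for v, cap in residual[u].items():
                if cap > 0 and v not in parent:
                    parent[v] = u
                    queue.append(v)
        if sink not in parent:
            # No path left: nodes still reachable from the source form its side of the cut.
            return flow, set(parent)
        # Find the bottleneck along the path, then push that much flow.
        bottleneck = float('inf')
        v = sink
        while parent[v] is not None:
            bottleneck = min(bottleneck, residual[parent[v]][v])
            v = parent[v]
        v = sink
        while parent[v] is not None:
            residual[parent[v]][v] -= bottleneck
            residual[v][parent[v]] += bottleneck
            v = parent[v]
        flow += bottleneck

# 's' = source (foreground), 't' = sink (background), 'a'..'c' = pixels in a row.
capacity = {
    's': {'a': 9, 'b': 5, 'c': 1},   # cost of labelling each pixel background
    'a': {'t': 1, 'b': 3},           # cost of labelling 'a' foreground, plus n-link to 'b'
    'b': {'t': 4, 'a': 3, 'c': 1},
    'c': {'t': 9, 'b': 1},
}
cut_value, source_side = min_cut(capacity, 's', 't')
foreground = sorted(source_side - {'s'})
print(cut_value, foreground)
```

Pixels that remain connected to the source after the cut are labelled foreground, which is the same convention the notebook applies to `mincut.partition[0]`.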

#### Useful Links
* https://courses.engr.illinois.edu/cs543/sp2011/lectures/Lecture%2012%20-%20MRFs%20and%20Graph%20Cut%20Segmentation%20-%20Vision_Spring2011.pdf

In [1]:
import sys
import os
import numpy as np
import math
import cv2 as cv
import igraph as ig
from sklearn.cluster import KMeans

In [2]:
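def score_formula(mult, mat):
    # Multivariate normal density: exp(-0.5 * d^2) / sqrt((2*pi)^k * det(cov)),
    # where d^2 is the squared Mahalanobis distance and k the number of features.
    k = mat.shape[0]
    score = np.exp(-.5 * mult) / np.sqrt((2 * np.pi) ** k * np.linalg.det(mat))
    return score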

class GaussianMixture:
    def __init__(self, X, gmm_components):
        self.n_components = gmm_components
        self.n_features = X.shape[1]
        self.n_samples = np.zeros(self.n_components)

        self.coefs = np.zeros(self.n_components)
        self.means = np.zeros((self.n_components, self.n_features))
        self.covariances = np.zeros(
            (self.n_components, self.n_features, self.n_features))

        self.init_with_kmeans(X)

    def init_with_kmeans(self, X):
        label = KMeans(n_clusters=self.n_components, n_init=1).fit(X).labels_
        self.fit(X, label)

    def calc_score(self, X, ci):
        score = np.zeros(X.shape[0])
        if self.coefs[ci] > 0:
            diff = X - self.means[ci]
            Tdiff = diff.T
            inv_cov = np.linalg.inv(self.covariances[ci])
            dot = np.dot(inv_cov, Tdiff)
            Tdot = dot.T
            mult = np.einsum('ij,ij->i', diff, Tdot)
            score = score_formula(mult,self.covariances[ci])
        return score

    def calc_prob(self, X):
        # Mixture density: weighted sum of the per-component Gaussian scores.
        prob = [self.calc_score(X, ci) for ci in range(self.n_components)]
        return np.dot(self.coefs, prob)

    def which_component(self, X):
        prob = []
        for ci in range(self.n_components):
            score = self.calc_score(X,ci)
            prob.append(score)
        prob = np.array(prob).T
        return np.argmax(prob, axis=1)

    def fit(self, X, labels):
        assert self.n_features == X.shape[1]
        self.n_samples[:] = 0
        self.coefs[:] = 0
        uni_labels, count = np.unique(labels, return_counts=True)
        self.n_samples[uni_labels] = count
        variance = 0.01  # regulariser added to singular covariance matrices
        total = np.sum(self.n_samples)
        for ci in uni_labels:
            self.coefs[ci] = self.n_samples[ci] / total
            self.means[ci] = np.mean(X[ci == labels], axis=0)
            if self.n_samples[ci] <= 1:
                self.covariances[ci] = 0
            else:
                self.covariances[ci] = np.cov(X[ci == labels].T)
            # Keep the covariance invertible for calc_score.
            if np.linalg.det(self.covariances[ci]) <= 0:
                self.covariances[ci] += np.eye(self.n_features) * variance

In [3]:
def construct_gc_graph(img,mask,gc_source,gc_sink,fgd_gmm,bgd_gmm,gamma,rows,cols,left_V, up_V, neighbours, upleft_V=None,upright_V=None):
    bgd_indexes = np.where(mask.reshape(-1) == DRAW_BG['val'])
    fgd_indexes = np.where(mask.reshape(-1) == DRAW_FG['val'])
    pr_indexes = np.where(np.logical_or(mask.reshape(-1) == DRAW_PR_BG['val'],mask.reshape(-1) == DRAW_PR_FG['val']))
#     print('bgd count: %d, fgd count: %d, uncertain count: %d' % (len(bgd_indexes[0]), len(fgd_indexes[0]), len(pr_indexes[0])))
    edges = []
    gc_graph_capacity = []
    
    edges.extend(list(zip([gc_source] * pr_indexes[0].size, pr_indexes[0])))
    _D = -np.log(bgd_gmm.calc_prob(img.reshape(-1, 3)[pr_indexes]))
    gc_graph_capacity.extend(_D.tolist())
    
    edges.extend(list(zip([gc_sink] * pr_indexes[0].size, pr_indexes[0])))
    _D = -np.log(fgd_gmm.calc_prob(img.reshape(-1, 3)[pr_indexes]))
    gc_graph_capacity.extend(_D.tolist())
    
    edges.extend(list(zip([gc_source] * bgd_indexes[0].size, bgd_indexes[0])))
    gc_graph_capacity.extend([0] * bgd_indexes[0].size)
    edges.extend(list(zip([gc_sink] * bgd_indexes[0].size, bgd_indexes[0])))
    gc_graph_capacity.extend([9 * gamma] * bgd_indexes[0].size)
    edges.extend(list(zip([gc_source] * fgd_indexes[0].size, fgd_indexes[0])))
    gc_graph_capacity.extend([9 * gamma] * fgd_indexes[0].size)
    edges.extend(list(zip([gc_sink] * fgd_indexes[0].size, fgd_indexes[0])))
    gc_graph_capacity.extend([0] * fgd_indexes[0].size)

    img_indexes = np.arange(rows*cols,dtype=np.uint32).reshape(rows,cols)
    temp1 = img_indexes[:, 1:]
    temp2 = img_indexes[:, :-1]
    mask1 = temp1.reshape(-1)
    mask2 = temp2.reshape(-1)
    edges.extend(list(zip(mask1, mask2)))
    gc_graph_capacity.extend(left_V.reshape(-1).tolist())
    
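    # Vertical neighbours (each pixel and the one above it); these pair with up_V.
    temp1 = img_indexes[1:, :]
    temp2 = img_indexes[:-1, :]
    mask1 = temp1.reshape(-1)
    mask2 = temp2.reshape(-1)
    edges.extend(list(zip(mask1, mask2)))
    gc_graph_capacity.extend(up_V.reshape(-1).tolist())
    
    if neighbours == 8:
        # Diagonal neighbours (each pixel and its upper-left one); these pair with upleft_V.
        temp1 = img_indexes[1:, 1:]
        temp2 = img_indexes[:-1, :-1]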
        mask1 = temp1.reshape(-1)
        mask2 = temp2.reshape(-1)
        edges.extend(list(zip(mask1, mask2)))
        gc_graph_capacity.extend(upleft_V.reshape(-1).tolist())
        
        temp1 = img_indexes[1:, :-1]
        temp2 = img_indexes[:-1, 1:]
        mask1 = temp1.reshape(-1)
        mask2 = temp2.reshape(-1)
        edges.extend(list(zip(mask1, mask2)))
        gc_graph_capacity.extend(upright_V.reshape(-1).tolist())
        
    gc_graph = ig.Graph(cols * rows + 2)
    gc_graph.add_edges(edges)
    return gc_graph,gc_source,gc_sink,gc_graph_capacity

def estimate_segmentation(mask,gc_graph,gc_source,gc_sink,gc_graph_capacity,rows,cols):
    mincut = gc_graph.st_mincut(gc_source,gc_sink, gc_graph_capacity)
#     print('foreground pixels: %d, background pixels: %d' % (len(mincut.partition[0]), len(mincut.partition[1])))
    pr_indexes = np.where(np.logical_or(mask == DRAW_PR_BG['val'], mask == DRAW_PR_FG['val']))
    img_indexes = np.arange(rows * cols,dtype=np.uint32).reshape(rows, cols)
    mask[pr_indexes] = np.where(np.isin(img_indexes[pr_indexes], mincut.partition[0]),DRAW_PR_FG['val'], DRAW_PR_BG['val'])
    bgd_indexes = np.where(np.logical_or(mask == DRAW_BG['val'],mask == DRAW_PR_BG['val']))
    fgd_indexes = np.where(np.logical_or(mask == DRAW_FG['val'],mask == DRAW_PR_FG['val']))
#     print('probble background count: %d, probable foreground count: %d' % (bgd_indexes[0].size,fgd_indexes[0].size))
    return pr_indexes,img_indexes,mask,bgd_indexes,fgd_indexes

def classify_pixels(mask):
    bgd_indexes = np.where(np.logical_or(mask == DRAW_BG['val'], mask == DRAW_PR_BG['val']))
    fgd_indexes = np.where(np.logical_or(mask == DRAW_FG['val'], mask == DRAW_PR_FG['val']))
    return fgd_indexes, bgd_indexes

def compute_smoothness(img, rows, cols, neighbours, gamma):
    # Colour differences between each pixel and its left / upper neighbour.
    left_diff = img[:, 1:] - img[:, :-1]
    up_diff = img[1:, :] - img[:-1, :]
    beta_sum = np.sum(np.square(left_diff)) + np.sum(np.square(up_diff))
    n_pairs = (2 * rows * cols) - cols - rows

    if neighbours == 8:
        upleft_diff = img[1:, 1:] - img[:-1, :-1]
        upright_diff = img[1:, :-1] - img[:-1, 1:]
        beta_sum += np.sum(np.square(upleft_diff)) + np.sum(np.square(upright_diff))
        n_pairs += (2 * rows * cols) - (2 * cols) - (2 * rows) + 2

    # beta = 1 / (2 * mean squared colour difference over all neighbouring pairs)
    beta = n_pairs / (2 * beta_sum)
    left_V = gamma * np.exp(-beta * np.sum(np.square(left_diff), axis=2))
    up_V = gamma * np.exp(-beta * np.sum(np.square(up_diff), axis=2))

    if neighbours == 8:
        # Diagonal neighbours are sqrt(2) apart, so their weight is scaled down.
        upleft_V = gamma / np.sqrt(2) * np.exp(-beta * np.sum(np.square(upleft_diff), axis=2))
        upright_V = gamma / np.sqrt(2) * np.exp(-beta * np.sum(np.square(upright_diff), axis=2))
        return left_V, up_V, upleft_V, upright_V
    return left_V, up_V, None, None

def initialize_gmm(img, bgd_indexes, fgd_indexes, gmm_components):
    bgd_gmm = GaussianMixture(img[bgd_indexes], gmm_components)
    fgd_gmm = GaussianMixture(img[fgd_indexes], gmm_components)
    
    return fgd_gmm, bgd_gmm

def GrabCut(img, mask, rect, gmm_components, gamma, neighbours, n_iters):
    img = np.asarray(img, dtype=np.float64)
    rows, cols, _ = img.shape
    if rect is not None:
        # Mark the rectangle as probable foreground, but keep any hard
        # foreground/background strokes the user has already drawn inside it.
        region = mask[rect[1]:rect[1] + rect[3], rect[0]:rect[0] + rect[2]]
        strokes = np.logical_or(region == DRAW_BG['val'], region == DRAW_FG['val'])
        region[~strokes] = DRAW_PR_FG['val']

    fgd_indexes, bgd_indexes = classify_pixels(mask)

    # Source and sink are two extra nodes appended after the pixel nodes.
    gc_source = cols * rows
    gc_sink = gc_source + 1

    # The smoothness (n-link) weights depend only on the image, so compute them once.
    left_V, up_V, upleft_V, upright_V = compute_smoothness(img, rows, cols, neighbours, gamma)

    for _ in range(n_iters):
        # Re-fit the colour models to the current foreground/background estimate.
        fgd_gmm, bgd_gmm = initialize_gmm(img, bgd_indexes, fgd_indexes, gmm_components)

        gc_graph, gc_source, gc_sink, gc_graph_capacity = construct_gc_graph(
            img, mask, gc_source, gc_sink, fgd_gmm, bgd_gmm, gamma, rows, cols,
            left_V, up_V, neighbours, upleft_V, upright_V)

        pr_indexes, img_indexes, mask, bgd_indexes, fgd_indexes = estimate_segmentation(
            mask, gc_graph, gc_source, gc_sink, gc_graph_capacity, rows, cols)
    return mask

In [4]:
class EventHandler:
    """
    Class for handling user input during segmentation iterations 
    """
    
    def __init__(self, flags, img, _mask, colors):
        
        self.FLAGS = flags
        self.ix = -1
        self.iy = -1
        self.img = img
        self.img2 = self.img.copy()
        self._mask = _mask
        self.COLORS = colors

    @property
    def image(self):
        return self.img
    
    @image.setter
    def image(self, img):
        self.img = img
        
    @property
    def mask(self):
        return self._mask
    
    @mask.setter
    def mask(self, _mask):
        self._mask = _mask
    
    @property
    def flags(self):
        return self.FLAGS 
    
    @flags.setter
    def flags(self, flags):
        self.FLAGS = flags
    
    def handler(self, event, x, y, flags, param):

        # Draw the rectangle first
        if event == cv.EVENT_RBUTTONDOWN:
            self.FLAGS['DRAW_RECT'] = True
            self.ix, self.iy = x,y

        elif event == cv.EVENT_MOUSEMOVE:
            if self.FLAGS['DRAW_RECT'] == True:
                self.img = self.img2.copy()
                cv.rectangle(self.img, (self.ix, self.iy), (x, y), self.COLORS['BLUE'], 2)
                cv.rectangle(self._mask, (self.ix, self.iy), (x, y), self.FLAGS['value']['val'], -1)
                self.FLAGS['RECT'] = (min(self.ix, x), min(self.iy, y), abs(self.ix - x), abs(self.iy - y))
                self.FLAGS['rect_or_mask'] = 0

        elif event == cv.EVENT_RBUTTONUP:
            self.FLAGS['DRAW_RECT'] = False
            self.FLAGS['rect_over'] = True
            cv.rectangle(self.img, (self.ix, self.iy), (x, y), self.COLORS['BLUE'], 2)
            cv.rectangle(self._mask, (self.ix, self.iy), (x, y), self.FLAGS['value']['val'], -1)
            self.FLAGS['RECT'] = (min(self.ix, x), min(self.iy, y), abs(self.ix - x), abs(self.iy - y))
            self.FLAGS['rect_or_mask'] = 0

        
        # Draw strokes for refinement 

        if event == cv.EVENT_LBUTTONDOWN:
            if self.FLAGS['rect_over'] == False:
                print('Draw the rectangle first.')
            else:
                self.FLAGS['DRAW_STROKE'] = True
                cv.circle(self.img, (x,y), 3, self.FLAGS['value']['color'], -1)
                cv.circle(self._mask, (x,y), 3, self.FLAGS['value']['val'], -1)

        elif event == cv.EVENT_MOUSEMOVE:
            if self.FLAGS['DRAW_STROKE'] == True:
                cv.circle(self.img, (x, y), 3, self.FLAGS['value']['color'], -1)
                cv.circle(self._mask, (x, y), 3, self.FLAGS['value']['val'], -1)

        elif event == cv.EVENT_LBUTTONUP:
            if self.FLAGS['DRAW_STROKE'] == True:
                self.FLAGS['DRAW_STROKE'] = False
                cv.circle(self.img, (x, y), 3, self.FLAGS['value']['color'], -1)
                cv.circle(self._mask, (x, y), 3, self.FLAGS['value']['val'], -1)

In [5]:
def load_image(filename, color_space='RGB', scale=1.0):
    im = cv.imread(filename)
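    if color_space == "RGB":
        # cv.imread returns BGR; keep that channel order so cv.imshow displays the
        # image correctly (the colour models do not depend on channel order).
        pass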
    elif color_space == "HSV":
        im = cv.cvtColor(im, cv.COLOR_BGR2HSV)
    elif color_space == "LAB":
        im = cv.cvtColor(im, cv.COLOR_BGR2LAB)
    if not scale == 1.0:
        im = cv.resize(im, (int(im.shape[1]*scale), int(im.shape[0]*scale)))
    return im

In [6]:
COLORS = {
    'BLACK' : [0,0,0],
    'RED'   : [0, 0, 255],
    'GREEN' : [0, 255, 0],
    'BLUE'  : [255, 0, 0],
    'WHITE' : [255,255,255]
    }

gmm_components = 15
gamma = 30
neighbours = 8 
color_space = 'RGB'
n_iters = 5

DRAW_PR_BG = {'color' : COLORS['BLACK'], 'val' : 2}
DRAW_PR_FG = {'color' : COLORS['WHITE'], 'val' : 3}
DRAW_BG = {'color' : COLORS['BLACK'], 'val' : 0}
DRAW_FG = {'color' : COLORS['WHITE'], 'val' : 1}

In [7]:
def run(filename, gamma=50, gmm_components=7, neighbours=8, color_space='RGB'):
    """
    Main loop that implements GrabCut. 
    
    Input
    -----
    filename (str) : Path to image
    """
    
    COLORS = {
    'BLACK' : [0,0,0],
    'RED'   : [0, 0, 255],
    'GREEN' : [0, 255, 0],
    'BLUE'  : [255, 0, 0],
    'WHITE' : [255,255,255]
    }

#     gmm_components = 4
#     gamma = 30
#     neighbours = 8  
#     color_space = 'RGB'

    DRAW_PR_BG = {'color' : COLORS['BLACK'], 'val' : 2}
    DRAW_PR_FG = {'color' : COLORS['WHITE'], 'val' : 3}
    DRAW_BG = {'color' : COLORS['BLACK'], 'val' : 0}
    DRAW_FG = {'color' : COLORS['WHITE'], 'val' : 1}

    FLAGS = {
        'RECT' : (0, 0, 1, 1),
        'DRAW_STROKE': False,         # flag for drawing strokes
        'DRAW_RECT' : False,          # flag for drawing rectangle
        'rect_over' : False,          # flag to check if rectangle is  drawn
        'rect_or_mask' : -1,          # flag for selecting rectangle or stroke mode
        'value' : DRAW_PR_FG,            # drawing strokes initialized to mark foreground
    }

    img = load_image(filename, color_space, scale=0.75)
    img2 = img.copy()                                
    # mask holds one of four labels per pixel: 0 = background, 1 = foreground,
    # 2 = probable background, 3 = probable foreground
    mask = np.ones(img.shape[:2], dtype = np.uint8) * DRAW_PR_BG['val']
    output = np.zeros(img.shape, np.uint8)           # output image to be shown

    # Input and segmentation windows
    cv.namedWindow('Input Image')
    cv.namedWindow('Segmented image')
    
    
    EventObj = EventHandler(FLAGS, img, mask, COLORS)
    cv.setMouseCallback('Input Image', EventObj.handler)
    cv.moveWindow('Input Image', img.shape[1] + 10, 90)
    
    while(1):
        
        img = EventObj.image
        mask = EventObj.mask
        FLAGS = EventObj.flags
        cv.imshow('Segmented image', output)
        cv.imshow('Input Image', img)
        
        k = cv.waitKey(1)

        # key bindings
        if k == 27:
            # esc to exit
            break
        
        elif k == ord('0'): 
            # Strokes for background
            FLAGS['value'] = DRAW_BG
        
        elif k == ord('1'):
            # FG drawing
            FLAGS['value'] = DRAW_FG
        
        elif k == ord('r'):
            # reset everything
            FLAGS['RECT'] = (0, 0, 1, 1)
            FLAGS['DRAW_STROKE'] = False
            FLAGS['DRAW_RECT'] = False
            FLAGS['rect_or_mask'] = -1
            FLAGS['rect_over'] = False
            FLAGS['value'] = DRAW_PR_FG
            img = img2.copy()
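            # Reset to the same initial state as at start-up (probable background everywhere)
            mask = np.ones(img.shape[:2], dtype = np.uint8) * DRAW_PR_BG['val']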
            EventObj.image = img
            EventObj.mask = mask
            output = np.zeros(img.shape, np.uint8)
        
        elif k == 13:
            # Press Enter (carriage return) to run the segmentation.
            # GrabCut returns the updated mask; pixels labelled foreground (1)
            # or probable foreground (3) are kept in the output image below.
            rect = FLAGS['RECT']
            mask = GrabCut(img2,mask,rect, gmm_components, gamma, neighbours, n_iters)
        
            EventObj.flags = FLAGS
            mask2 = np.where((mask == 1) + (mask == 3), 255, 0).astype('uint8')
            output = cv.bitwise_and(img2, img2, mask=mask2)

In [10]:
if __name__ == '__main__':
    filename = '../images/sheep.jpg'               # Path to image file
    run(filename, gamma, gmm_components, neighbours, color_space)
    cv.destroyAllWindows()

# Report

### 1. Different Color Spaces

#### RGB Space
![](https://github.com/saiamrit/technical-blog/blob/master/images/rgb.png)

#### HSV Space
![](https://github.com/saiamrit/technical-blog/blob/master/images/hsv.png)

#### LAB Space
![](https://github.com/saiamrit/technical-blog/blob/master/images/lab.png)

**Conclusion** - The results are better in the RGB color space.
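
One property of HSV that may contribute to this, although it was not verified against the results above, is that hue is an angle and wraps around. Two nearly identical reds can therefore sit at opposite ends of the hue axis, which a Gaussian colour model treats as far apart. The sketch below uses Python's standard `colorsys` module, where every channel lies in [0, 1]; OpenCV's 8-bit HSV stores hue in [0, 179] instead, but the wrap-around is the same idea.

```python
import colorsys

# Two reds that differ only slightly in RGB (channels scaled to [0, 1]).
red_a = (1.0, 0.05, 0.0)
red_b = (1.0, 0.0, 0.05)

def euclidean(p, q):
    """Straight-line distance between two colour triples."""
    return sum((a - b) ** 2 for a, b in zip(p, q)) ** 0.5

hsv_a = colorsys.rgb_to_hsv(*red_a)
hsv_b = colorsys.rgb_to_hsv(*red_b)

rgb_distance = euclidean(red_a, red_b)
hsv_distance = euclidean(hsv_a, hsv_b)
print('RGB distance:', round(rgb_distance, 3))
print('HSV distance:', round(hsv_distance, 3))
```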

### 2. Different size of bounding box

#### Tight Bounding Box
![](https://github.com/saiamrit/technical-blog/blob/master/images/small.png)

#### Large Bounding Box
![](https://github.com/saiamrit/technical-blog/blob/master/images/large.png)

**Conclusion** - The results are better with a tighter bounding box.

### 3. Different value of Gamma

#### Gamma = 10
![](https://github.com/saiamrit/technical-blog/blob/master/images/gamma10.png)

#### Gamma = 50
![](https://github.com/saiamrit/technical-blog/blob/master/images/gamma50.png)

#### Gamma = 100
![](https://github.com/saiamrit/technical-blog/blob/master/images/gamma100.png)

**Conclusion** - The results are better with a gamma of around 50. With a lower gamma, a large part of the image is segmented out, and with a higher gamma, some foreground areas are missed.
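
For intuition, gamma scales the weight of every link between neighbouring pixels, `gamma * exp(-beta * ||z_p - z_q||^2)`, as in `compute_smoothness`. The sketch below evaluates that weight for one similar pair and one pair across a strong edge. The colours and the fixed `beta` are assumed values for illustration; the notebook derives `beta` from the image itself.

```python
import math

def smoothness_weight(colour_p, colour_q, gamma, beta):
    """N-link weight gamma * exp(-beta * ||z_p - z_q||^2) for two neighbouring pixels."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(colour_p, colour_q))
    return gamma * math.exp(-beta * sq_dist)

beta = 0.005                                   # assumed value, for illustration only
similar = ((120, 80, 60), (122, 81, 60))       # neighbours inside one region
edge = ((120, 80, 60), (200, 190, 180))        # neighbours across a strong edge

for gamma in (10, 50, 100):
    w_similar = smoothness_weight(*similar, gamma, beta)
    w_edge = smoothness_weight(*edge, gamma, beta)
    print(f'gamma={gamma:>3}: similar={w_similar:.2f}, edge={w_edge:.2e}')
```

A larger gamma makes it more expensive to separate similar neighbours relative to the colour-model term, so the cut prefers smoother regions, while links across strong edges stay cheap for any gamma.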

### 4. 4 or 8 Connectivity

#### 4 Connectivity
![](https://github.com/saiamrit/technical-blog/blob/master/images/4conn.png)

#### 8 Connectivity
![](https://github.com/saiamrit/technical-blog/blob/master/images/8conn.png)

**Conclusion** - The results are better with 8 connectivity, because the smoothness term then takes more nearby pixels into account, including the diagonal neighbours.
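
To see how many extra links 8 connectivity adds, the sketch below counts neighbouring pixel pairs by brute force and compares the result with the closed-form counts that `compute_smoothness` uses when averaging colour differences. `count_pairs` is a hypothetical helper written for this check.

```python
def count_pairs(rows, cols, neighbours):
    """Count neighbouring pixel pairs by brute force on a rows x cols grid."""
    offsets = [(0, 1), (1, 0)]                 # right and down (4 connectivity)
    if neighbours == 8:
        offsets += [(1, 1), (1, -1)]           # the two diagonals
    pairs = 0
    for r in range(rows):
        for c in range(cols):
            for dr, dc in offsets:
                if 0 <= r + dr < rows and 0 <= c + dc < cols:
                    pairs += 1
    return pairs

rows, cols = 4, 5
four = count_pairs(rows, cols, 4)
eight = count_pairs(rows, cols, 8)

# Closed forms used in compute_smoothness.
assert four == 2 * rows * cols - rows - cols
assert eight == four + 2 * rows * cols - 2 * rows - 2 * cols + 2
print(four, eight)   # prints "31 55"
```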

### 5. Number of GMM Components

#### 2 Components
![](https://github.com/saiamrit/technical-blog/blob/master/images/2c.png)

#### 5 Components
![](https://github.com/saiamrit/technical-blog/blob/master/images/5c.png)

#### 10 Components
![](https://github.com/saiamrit/technical-blog/blob/master/images/10c.png)

**Conclusion** - The results are better with a higher number of components, but beyond a certain point, adding more components does not change the results significantly.
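
The role of the components can be illustrated with a simplified 1-D version of the data term, `-log(sum_k w_k * N(x; mu_k, var_k))`, which is what `calc_prob` followed by `-np.log` computes for 3-channel colours. The grey-level models below are hypothetical; the point is only that a second component lets the foreground model cover two distinct colour clusters.

```python
import math

def gaussian_pdf(x, mean, var):
    """Density of a 1-D Gaussian."""
    return math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2 * math.pi * var)

def data_term(x, weights, means, variances):
    """Negative log-likelihood of x under a mixture of 1-D Gaussians."""
    prob = sum(w * gaussian_pdf(x, m, v) for w, m, v in zip(weights, means, variances))
    return -math.log(prob)

# Hypothetical 1-D "grey level" models, for illustration only.
background = dict(weights=[1.0], means=[40.0], variances=[100.0])
foreground = dict(weights=[0.5, 0.5], means=[150.0, 220.0], variances=[100.0, 100.0])

for grey in (45.0, 160.0, 215.0):
    d_bg = data_term(grey, **background)
    d_fg = data_term(grey, **foreground)
    label = 'foreground' if d_fg < d_bg else 'background'
    print(f'grey={grey}: D_bg={d_bg:.1f}, D_fg={d_fg:.1f} -> {label}')
```

Once every dominant colour cluster has its own component, further components have little left to explain, which is consistent with the saturation noted above.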