# GrabCut Algorithm

This notebook demonstrates the GrabCut algorithm in Python. It implements the method described in this [paper](https://cvg.ethz.ch/teaching/cvl/2012/grabcut-siggraph04.pdf). GrabCut is essentially an iterative version of the GraphCut algorithm.

## Procedure
The code below takes an input image and follows these steps:
- It requires a bounding box to be drawn by the user to roughly segment out the foreground pixels
- It runs an initial min-cut optimization using the provided annotation
- The result of this optimization gives an initial segmentation 
- To further refine this segmentation, the user provides two kinds of strokes to aid the optimization
    - strokes on the background pixels
    - strokes on the foreground pixels
- The algorithm then uses these strokes to refine the original segmentation

In [1]:
# Importing the necessary packages

import numpy as np
import cv2
from gmm import *
from grabcut import *
from ui import *
import warnings
warnings.filterwarnings("ignore")

In [3]:
def run(filename, gamma=50, gmm_components=7, neighbours=8, color_space='RGB', n_iters=2):
    """
    Interactive GrabCut loop: collect user annotations, segment, show the result.

    Input
    -----
    filename (str)       : Path to image
    gamma                : Forwarded unchanged to GrabCut as ``gamma``
    gmm_components (int) : Forwarded unchanged to GrabCut as ``gmm_components``
    neighbours (int)     : Forwarded unchanged to GrabCut as ``neighbours``
    color_space (str)    : Accepted for interface compatibility; not used by this function
    n_iters (int)        : Forwarded unchanged to GrabCut as ``n_iters``
                           (defaults to 2, the value that used to be hard-coded)

    Output
    ------
    Returns None. Displays the segmented output on a window.

    Procedure
    ----------
    Accepts all the hyperparameters, initialises the UI that helps to draw the bounding box,
    passes the user inputs and params to the grabcut module and after all computations,
    displays the final output on a separate screen.

    Key bindings
    ------------
    Esc   : leave the loop
    0     : subsequent strokes mark background
    1     : subsequent strokes mark foreground
    r     : reset image, mask, flags and output to their initial state
    Enter : run the segmentation with the current annotations
    """

    COLORS = {
        'BLACK': [0, 0, 0],
        'RED':   [0, 0, 255],
        'GREEN': [0, 255, 0],
        'BLUE':  [255, 0, 0],
        'WHITE': [255, 255, 255],
    }

    # Mask labels: 0 = background, 1 = foreground,
    #              2 = probable background, 3 = probable foreground.
    DRAW_PR_BG = {'color': COLORS['BLACK'], 'val': 2}
    DRAW_PR_FG = {'color': COLORS['WHITE'], 'val': 3}
    DRAW_BG = {'color': COLORS['BLACK'], 'val': 0}
    DRAW_FG = {'color': COLORS['WHITE'], 'val': 1}

    def initial_flags():
        # Fresh UI state, used both at start-up and on reset.
        return {
            'RECT': (0, 0, 1, 1),
            'DRAW_STROKE': False,     # flag for drawing strokes
            'DRAW_RECT': False,       # flag for drawing rectangle
            'rect_over': False,       # flag to check if rectangle is drawn
            'rect_or_mask': -1,       # flag for selecting rectangle or stroke mode
            'value': DRAW_PR_FG,      # drawing strokes initialized to mark foreground
        }

    def initial_mask(shape):
        # Every pixel starts as "probable background" until the user annotates it.
        return np.full(shape[:2], DRAW_PR_BG['val'], dtype=np.uint8)

    FLAGS = initial_flags()

    img = load_image(filename, scale=0.8)
    img2 = img.copy()                          # pristine copy; img itself is handed to the UI
    mask = initial_mask(img.shape)
    output = np.zeros(img.shape, np.uint8)     # output image to be shown

    # Input and segmentation windows
    cv2.namedWindow('Input Image')
    cv2.namedWindow('Segmented image')

    # Instantiating the UI object to start the UI and creates the display windows
    EventObj = EventHandler(FLAGS, img, mask, COLORS)
    cv2.setMouseCallback('Input Image', EventObj.handler)
    cv2.moveWindow('Input Image', img.shape[1] + 10, 90)

    # Loop that accepts the user input and performs actions accordingly
    while True:
        # Re-read the state held by the UI handler on every iteration.
        img = EventObj.image
        mask = EventObj.mask
        FLAGS = EventObj.flags
        cv2.imshow('Segmented image', output)
        cv2.imshow('Input Image', img)

        # Keep only the low byte so the comparisons below see the plain key code.
        k = cv2.waitKey(1) & 0xFF

        # key bindings
        if k == 27:
            # esc to exit
            break

        elif k == ord('0'):
            # Strokes for background
            FLAGS['value'] = DRAW_BG

        elif k == ord('1'):
            # FG drawing
            FLAGS['value'] = DRAW_FG

        elif k == ord('r'):
            # reset everything; update in place so the handler's dict is the one reset
            FLAGS.update(initial_flags())
            img = img2.copy()
            # Same fill value as at start-up (previously reset to all-zeros, i.e. definite background).
            mask = initial_mask(img.shape)
            EventObj.image = img
            EventObj.mask = mask
            output = np.zeros(img.shape, np.uint8)

        elif k == 13:
            # Press Enter to initiate segmentation
            gc = GrabCut(img=img2, mask=mask, n_iters=n_iters, gamma=gamma,
                         gmm_components=gmm_components, neighbours=neighbours)
            # NOTE(review): this copy is only used for the display below; the next
            # iteration re-reads EventObj.mask. Confirm whether GrabCut updates the
            # mask it was given in place.
            mask = gc.mask.copy()

            EventObj.flags = FLAGS
            # Keep pixels labelled foreground or probable foreground, blank the rest.
            is_fg = np.isin(mask, (DRAW_FG['val'], DRAW_PR_FG['val']))
            mask2 = np.where(is_fg, 255, 0).astype('uint8')
            output = cv2.bitwise_and(img2, img2, mask=mask2)

In [6]:
# Initialise the parameters and call the run method

gmm_components = 5
gamma = 30
neighbours = 8
color_space = 'RGB'
n_iters = 10                                  # NOTE: not forwarded to run() below, so it has no effect


filename = './images/sheep.jpg'               # Path to image file
try:
    run(filename, gamma, gmm_components, neighbours, color_space)
finally:
    # OpenCV is imported as `cv2` in this notebook (there is no `cv` import);
    # close the windows even if run() raised.
    cv2.destroyAllWindows()