<a href="https://colab.research.google.com/github/wjdolan/start-here-guidelines/blob/master/ml_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ML Assisted Segmentation Labelling Tool

This Colab notebook shows how ML-assisted segmentation works. After running the code, you can label images yourself and see how working alongside a pretrained machine learning model speeds up the labelling process.

Note that this demo works best for images containing a single object. For more complicated images, consider commercial software such as [Datature](https://datature.io/).

Link to medium article: 

Credits go to: 
- https://github.com/ianhi/AC295-final-project-JWI
- https://github.com/ayoolaolafenwa/PixelLib

## Installation

In [None]:
#@title
!pip install --quiet gradio requests ipympl sidecar pixellib

In [None]:
#@title
!wget -N 'https://github.com/ayoolaolafenwa/PixelLib/releases/download/0.2.0/pointrend_resnet50.pkl'


In [None]:
#@title
import matplotlib.pyplot as plt
import numpy as np
from skimage.segmentation import flood
from os import path
import os
import cv2
from skimage import io
from skimage.transform import resize,rescale
import ipywidgets as widgets
import glob
from matplotlib.widgets import LassoSelector
from matplotlib.path import Path
from sidecar import Sidecar
from google.colab import output
output.enable_custom_widget_manager()

import pixellib
from pixellib.torchbackend.instance import instanceSegmentation

import warnings
warnings.filterwarnings("ignore")

%matplotlib widget

In [None]:
IMAGE_DIR   = '/content/images'
MASK_DIR    = '/content/masks'
ML_MASK_DIR = '/content/ml_assisted_masks'

# create the working folders if they are missing
for folder in (IMAGE_DIR, MASK_DIR, ML_MASK_DIR):
    os.makedirs(folder, exist_ok=True)

## Things to note before running the code

Before running the code, put your images into the **Images** folder (`/content/images`).
- To try the tool, I suggest downloading a couple of free-to-use images from pexels.com (e.g. https://www.pexels.com/search/bird/).

The **Masks** folder (`/content/masks`) is created automatically. Each mask you create is saved in this folder.

The **Model Predicted Masks** folder (`/content/ml_assisted_masks`) is also created automatically. The predicted masks are generated and saved there during initialization.
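
Before starting the tool, it can help to confirm that the image folder actually contains files the segmenter will pick up. The sketch below repeats the extension-based glob used later in the notebook; `list_images` is a hypothetical helper, not part of the original code. Note that `glob` matching is case-sensitive on Linux (and therefore on Colab), so a file named `bird.JPG` would not be found.

```python
import glob
import os

# Same extensions as VALID_IMAGE_TYPES in the notebook.
VALID_IMAGE_TYPES = ['jpeg', 'png', 'bmp', 'gif', 'jpg']

def list_images(img_dir, valid_types=VALID_IMAGE_TYPES):
    """Return the sorted paths in img_dir whose extension is in valid_types.

    Hypothetical helper that mirrors the glob pattern used by image_segmenter.
    """
    paths = []
    for type_ in valid_types:
        paths += glob.glob(os.path.join(img_dir, f'*.{type_}'))
    return sorted(paths)

# '/content/images' is the IMAGE_DIR used in this notebook.
found = list_images('/content/images')
print(f'{len(found)} image(s) found')
```

If the count is zero, upload or copy images into the folder before creating the segmenter.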





In [None]:
VALID_IMAGE_TYPES = ['jpeg', 'png', 'bmp', 'gif', 'jpg']

__all__ = [
    'panhandler',
    'image_segmenter'
]

class panhandler:
    """
    Enable right-click to pan the image.
    This class doesn't set up the event listeners; whatever uses it needs to call
    fig.mpl_connect('button_press_event', panhandler.press)
    fig.mpl_connect('button_release_event', panhandler.release)
    or something similar.
    """
    def __init__(self, figure):
        self.figure = figure
        self._id_drag = None

    def _cancel_action(self):
        self._xypress = []
        if self._id_drag:
            self.figure.canvas.mpl_disconnect(self._id_drag)
            self._id_drag = None
        
    def press(self, event):
        if event.button == 1:
            return
        elif event.button == 3:
            self._button_pressed = 1
        else:
            self._cancel_action()
            return

        x, y = event.x, event.y

        self._xypress = []
        for i, a in enumerate(self.figure.get_axes()):
            if (x is not None and y is not None and a.in_axes(event) and
                    a.get_navigate() and a.can_pan()):
                a.start_pan(x, y, event.button)
                self._xypress.append((a, i))
                self._id_drag = self.figure.canvas.mpl_connect(
                    'motion_notify_event', self._mouse_move)

    def release(self, event):
        # finish the pan on every axes that started one *before* clearing the
        # recorded state; _cancel_action then disconnects the drag listener
        for a, _ind in getattr(self, '_xypress', []):
            a.end_pan()
        self._cancel_action()

    def _mouse_move(self, event):
        for a, _ind in self._xypress:
            # safer to use the recorded button at the _press than current
            # button: # multiple button can get pressed during motion...
            a.drag_pan(1, event.key, event.x, event.y)
        self.figure.canvas.draw_idle()

class image_segmenter:
    def __init__(self, img_dir, overlay_alpha=.5,figsize=(10,10), scroll_to_zoom=True, zoom_scale=1.1):
        """
        TODO allow for initializing with a shape instead of an image

        parameters
        ----------
        img_dir : string
            path to the directory that contains the images to label
        overlay_alpha : float
            how strongly labelled pixels are dimmed in the display
        figsize : tuple
            size of the matplotlib figure
        scroll_to_zoom : boolean
            currently unused in this demo
        zoom_scale : float or None
            How much to scale the image per scroll (currently unused in this demo). If you enable
            scroll zooming I recommend using jupyterlab-sidecar in order to prevent the page from
            scrolling, or checking in on: https://github.com/matplotlib/ipympl/issues/222
        """

        self.img_dir = img_dir

        if not path.isdir(self.img_dir):
            raise ValueError(f"{self.img_dir} must be an existing directory containing the images")
        # ensure that the mask directory exists
        if not os.path.isdir(MASK_DIR):
            os.makedirs(MASK_DIR)

        self.image_paths = []
        for type_ in VALID_IMAGE_TYPES:
            self.image_paths += (glob.glob(self.img_dir.rstrip('/')+f'/*.{type_}'))
        self.shape = None

        # generate the ML-predicted masks and save them in ML_MASK_DIR
        print('Preparing ML predicted masks')
        ins = instanceSegmentation()
        ins.load_model("pointrend_resnet50.pkl", detection_speed='rapid')
        for img_path in self.image_paths:
            result = ins.segmentImage(img_path,show_bboxes=False)
            mask = (result[0]['masks']*1).astype(np.uint8)
            img_name = os.path.basename(img_path)
            mask_path = os.path.join(ML_MASK_DIR,img_name)
            io.imsave(mask_path,mask,check_contrast=False,quality=100)
        
        plt.ioff() # see https://github.com/matplotlib/matplotlib/issues/17013
        self.fig = plt.figure(figsize=figsize)
        self.ax = self.fig.gca()
        lineprops = {'color': 'black', 'linewidth': 1, 'alpha': 0.8}
        self.lasso = LassoSelector(self.ax, self.onselect,lineprops=lineprops, button=1,useblit=False)
        self.lasso.set_visible(True)
        self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        self.fig.canvas.mpl_connect('button_release_event', self._release)
        self.panhandler = panhandler(self.fig)

        # setup lasso stuff
        plt.ion()
        
        # hardcoded classes because we don't need it for our demo
        classes = 1
        if isinstance(classes, int):
            classes = np.arange(classes)
        if len(classes)<=10:
            self.colors = 'tab10'
        elif len(classes)<=20:
            self.colors = 'tab20'
        else:
            raise ValueError(f'Currently only up to 20 classes are supported, you tried to use {len(classes)} classes')
        
        self.colors = np.vstack([[0,0,0],plt.get_cmap(self.colors)(np.arange(len(classes)))[:,:3]])

        self.lasso_button = widgets.Button(
            description='lasso select',
            disabled=False,
            button_style='success', # 'success', 'info', 'warning', 'danger' or ''
            icon='mouse-pointer', # (FontAwesome names without the `fa-` prefix)
        )
        self.flood_button = widgets.Button(
            description='flood fill',
            disabled=False,
            button_style='', # 'success', 'info', 'warning', 'danger' or ''
            icon='fill-drip', # (FontAwesome names without the `fa-` prefix)
        )
        self.ml_button = widgets.Button(
            description='ml-assisted',
            disabled=False,
            button_style='', # 'success', 'info', 'warning', 'danger' or ''
            icon='fill-drip', # (FontAwesome names without the `fa-` prefix)
        )
        self.ml_assisted_check_box = widgets.Checkbox(
            value=False,
            description='ML Assisted',
            disabled=False,
            indent=False
        )
        
        self.erase_check_box = widgets.Checkbox(
            value=False,
            description='Erase Mode',
            disabled=False,
            indent=False
        )
        
        self.reset_button = widgets.Button(
            description='reset',
            disabled=False,
            button_style='', # 'success', 'info', 'warning', 'danger' or ''
            icon='refresh', # (FontAwesome names without the `fa-` prefix)
        )
        self.save_button = widgets.Button(
            description='save mask',
            button_style='',
            icon='floppy-o'
        )
        self.next_button = widgets.Button(
            description='next image',
            button_style='',
            icon='arrow-right'
        )
        self.prev_button = widgets.Button(
            description='previous image',
            button_style='',
            icon='arrow-left',
            disabled=True
        )
        self.reset_button.on_click(self.reset)
        self.save_button.on_click(self.save_mask)
        self.next_button.on_click(self._change_image_idx)
        self.prev_button.on_click(self._change_image_idx)

        self.ml_assisted = False

        def button_click(button):
            if button.description == 'ml-assisted':
                self.ml_assisted = True
                self.new_image(self.img_idx)
                self.ml_assisted = False
            elif button.description == 'flood fill':
                self.flood_button.button_style='success'
                self.lasso_button.button_style=''
                self.lasso.set_active(False)
            else:
                self.flood_button.button_style=''
                self.lasso_button.button_style='success'
                self.lasso.set_active(True)
        
        self.ml_button.on_click(button_click)
        self.lasso_button.on_click(button_click)
        self.flood_button.on_click(button_click)
        self.overlay_alpha = overlay_alpha
        self.indices = None
        
        self.new_image(0)

    def _change_image_idx(self, button):
        if button is self.next_button:
            if self.img_idx +1 < len(self.image_paths):
                self.img_idx += 1
                self.save_mask()
                self.new_image(self.img_idx)
                
                if self.img_idx == len(self.image_paths) - 1:
                    self.next_button.disabled = True
                self.prev_button.disabled=False
        elif button is self.prev_button:
            if self.img_idx>=1:
                self.img_idx -= 1
                self.save_mask()
                self.new_image(self.img_idx)
                
                if self.img_idx == 0:
                    self.prev_button.disabled=True
                
                self.next_button.disabled=False
            
    def new_image(self, img_idx):
        self.indices=None
        image = cv2.imread(self.image_paths[img_idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        self.img = resize(image, (450,540))
        self.img_idx = img_idx
        img_path = self.image_paths[self.img_idx]
        self.ax.set_title(os.path.basename(img_path))
        self.mask_path = MASK_DIR + f'/{os.path.basename(img_path)}'
        self.ml_mask_path = ML_MASK_DIR + f'/{os.path.basename(img_path)}'
        if self.img.shape != self.shape:
            self.shape = self.img.shape
            pix_x = np.arange(self.shape[0])
            pix_y = np.arange(self.shape[1])
            xv, yv = np.meshgrid(pix_y,pix_x)
            self.pix = np.vstack( (xv.flatten(), yv.flatten()) ).T
            self.displayed = self.ax.imshow((self.img * 255).astype(np.uint8))
            #ensure that the _nav_stack is empty
            self.fig.canvas.toolbar._nav_stack.clear()
            
            #add the initial view to the stack so that the home button works.
            self.fig.canvas.toolbar.push_current()
            if self.ml_assisted and os.path.exists(self.ml_mask_path):
                mask = io.imread(self.ml_mask_path)
                self.class_mask = cv2.resize(mask, (540,450)).astype(np.uint8)
            elif os.path.exists(self.mask_path):
                self.class_mask = io.imread(self.mask_path).astype(np.uint8)
            else:
                if os.path.exists(self.ml_mask_path) and self.ml_assisted_check_box.value:
                    mask = io.imread(self.ml_mask_path)
                    self.class_mask = cv2.resize(mask, (540,450)).astype(np.uint8)
                else:
                    self.class_mask = np.zeros([self.shape[0],self.shape[1]],dtype=np.uint8)
        else:
            self.displayed.set_data((self.img * 255).astype(np.uint8))
            if self.ml_assisted and os.path.exists(self.ml_mask_path):
                mask = io.imread(self.ml_mask_path)
                self.class_mask = cv2.resize(mask, (540,450)).astype(np.uint8)
            elif os.path.exists(self.mask_path):
                self.class_mask = io.imread(self.mask_path).astype(np.uint8)
                # should probs check that the first two dimensions are the same as the img
            else:
                if os.path.exists(self.ml_mask_path) and self.ml_assisted_check_box.value:
                    mask = io.imread(self.ml_mask_path)
                    self.class_mask = cv2.resize(mask, (540,450)).astype(np.uint8)
                else:
                    self.class_mask[:,:] = 0
            self.fig.canvas.toolbar.home()
        self.updateArray()

    def _release(self, event):
        self.panhandler.release(event)

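    def reset(self,*args):
        # show the unlabelled image on the same uint8 scale used by new_image
        self.displayed.set_data((self.img * 255).astype(np.uint8))
        # class_mask is uint8, so -1 would wrap to 255 and mark every pixel as labelled
        self.class_mask[:,:] = 0
        self.fig.canvas.draw()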

    def onclick(self, event):
        """
        handle clicking to remove already added stuff
        """
        if event.button == 1:
            if event.xdata is not None and not self.lasso.active:
                # transpose x and y bc imshow transposes
                # np.int was removed from NumPy; the builtin int does the same job
                self.indices = flood(self.class_mask,(int(event.ydata), int(event.xdata)))
                self.updateArray()
        elif event.button == 3:
            self.panhandler.press(event)

    def updateArray(self):
        array = self.displayed.get_array().data
        # self.img is float in [0, 1]; scale it to match the displayed array,
        # which is uint8 in [0, 255] when it was set by new_image
        img = self.img * 255 if array.dtype == np.uint8 else self.img

        if self.erase_check_box.value:
            if self.indices is not None:
                self.class_mask[self.indices] = 0
                array[self.indices] = img[self.indices]
        elif self.indices is not None:
            self.class_mask[self.indices] = 1
            array[self.indices] = img[self.indices]*(1-self.overlay_alpha)
        else:
            # new image and we found a class mask
            # so redraw entire array where class != 0
            idx = self.class_mask != 0
            array[idx] = img[idx]*(1-self.overlay_alpha)
        self.displayed.set_data(array)

    def onselect(self,verts):
        self.verts = verts
        p = Path(verts)

        self.indices = p.contains_points(self.pix, radius=0).reshape(450,540)

        self.updateArray()
        self.fig.canvas.draw_idle()
        
    def render(self):
        layers = [widgets.HBox([self.lasso_button, self.flood_button, self.ml_button])]
        layers.append(widgets.HBox([self.reset_button, self.erase_check_box]))
        layers.append(self.fig.canvas)
        layers.append(widgets.HBox([self.save_button, self.prev_button, self.next_button]))
        return widgets.VBox(layers)

    def save_mask(self, save_if_no_nonzero=False):
        """
        save_if_no_nonzero : boolean
            Whether to save if class_mask only contains 0s
        """
        if (save_if_no_nonzero or np.any(self.class_mask != 0)):
            io.imsave(self.mask_path,self.class_mask,check_contrast =False, quality=100)

    def _ipython_display_(self):
        display(self.render())
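
The lasso tool in the class above turns a hand-drawn outline into a mask by testing every pixel coordinate against the polygon. The sketch below isolates that step outside the widget; `lasso_to_mask` is a hypothetical helper name, and the small 6x8 grid stands in for the 450x540 image used by the notebook.

```python
import numpy as np
from matplotlib.path import Path

def lasso_to_mask(verts, height, width):
    """Return a boolean (height, width) mask of pixels inside the polygon.

    verts is a list of (x, y) vertices, as passed to LassoSelector's callback.
    Hypothetical helper that follows the same steps as image_segmenter.onselect.
    """
    # One (x, y) row per pixel, in row-major order so the result reshapes cleanly.
    xv, yv = np.meshgrid(np.arange(width), np.arange(height))
    pix = np.vstack((xv.flatten(), yv.flatten())).T
    # The path is treated as closed even though the last vertex is not repeated.
    inside = Path(verts).contains_points(pix, radius=0)
    return inside.reshape(height, width)

# A square outline drawn over a 6x8 image.
verts = [(1.5, 0.5), (4.5, 0.5), (4.5, 3.5), (1.5, 3.5)]
mask = lasso_to_mask(verts, height=6, width=8)
print(mask.astype(int))
```

Because the x coordinate comes first in each vertex while arrays are indexed as `[row, column]`, the grid is built with `x` running along the width and the result is reshaped to `(height, width)`.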

In [None]:
segmenter = image_segmenter(IMAGE_DIR)

Preparing ML predicted masks


The checkpoint state_dict contains keys that are not used by the model:
  proposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}


In [None]:
sc = Sidecar(title='Segmentation')
with sc:
    display(segmenter)

VBox(children=(HBox(children=(Button(button_style='success', description='lasso select', icon='mouse-pointer',…
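
In the displayed figure, labelled pixels are drawn as a dimmed copy of the image, scaled by `1 - overlay_alpha`, and erase mode puts the original pixels back. The following is a minimal sketch of that idea in plain NumPy, assuming the displayed array is `uint8` in [0, 255] and the source image is float in [0, 1]; `apply_overlay` is a hypothetical function, not the notebook's own method.

```python
import numpy as np

def apply_overlay(display, img, mask, overlay_alpha=0.5, erase=False):
    """Return a copy of display with the masked pixels dimmed or restored.

    display : uint8 array (H, W, 3), what is currently shown
    img     : float array (H, W, 3) in [0, 1], the original image
    mask    : bool array (H, W), pixels to update
    """
    out = display.copy()
    # Erasing restores the original pixel; labelling dims it.
    factor = 1.0 if erase else (1 - overlay_alpha)
    out[mask] = (img[mask] * 255 * factor).astype(np.uint8)
    return out

# Tiny all-white image with one labelled pixel.
img = np.ones((2, 2, 3))
display = (img * 255).astype(np.uint8)
mask = np.array([[True, False], [False, False]])

dimmed = apply_overlay(display, img, mask, overlay_alpha=0.5)
restored = apply_overlay(dimmed, img, mask, erase=True)
print(dimmed[0, 0], restored[0, 0])
```

Returning a copy keeps the sketch free of side effects; the widget itself updates the displayed array in place.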

