# Object-based Image Classification
*Author: Vladislav Kim*
* [Introduction](#intro)
* [Generate initial segmentation](#initialsegm)
* [Image labelling](#label)
* [Training and test set generation](#trainset)
* [Random forest classifier](#randomforest)
* [Parameter tuning and feature selection](#featureselect)
* [Comparison with other classifiers](#comparison)


<a id="intro"></a> 
## Introduction
Segmentation with classical computer vision approaches such as watershed often produces results that must be filtered by their region properties to eliminate segmentation artefacts (small objects, noise, etc.). If the image set is large, as in high-throughput screening, filtering with fixed thresholds may be suboptimal. To automate the filtering of artefacts we can turn to machine learning.


There are a number of schemes and machine learning models that can be used for this purpose. Here we show how to train an object-based random forest classifier. The input for this classifier will be cropped bounding boxes around the objects of an initial segmentation generated by simple spot detection, and the task is to classify these image patches into cell types.

Here we are dealing with coculture images containing two cell types, stroma and leukemia cells, which were not stained differentially. Because of this minimal staining palette, we need machine learning to automate the identification of leukemia cells.
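As a minimal sketch of the idea (not the classifier developed in this notebook), the code below trains a scikit-learn random forest on a synthetic table of per-object features. It assumes each image patch has been summarized by region properties such as area and mean intensity; the feature values and the two-class labelling (0 = stroma, 1 = leukemia) are invented for illustration.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
n = 200
# synthetic class labels: 0 = stroma, 1 = leukemia (illustrative only)
y = rng.integers(0, 2, size=n)
# synthetic per-object features; the two classes differ in size and brightness
features = pd.DataFrame({
    "area": np.where(y == 1, rng.normal(400, 50, n), rng.normal(1500, 200, n)),
    "mean_intensity": np.where(y == 1, rng.normal(0.6, 0.05, n), rng.normal(0.3, 0.05, n)),
    "eccentricity": rng.uniform(0, 1, n),  # uninformative on purpose
})

# hold out a stratified test set
X_train, X_test, y_train, y_test = train_test_split(
    features, y, test_size=0.25, random_state=0, stratify=y)

clf = RandomForestClassifier(n_estimators=100, random_state=0)
clf.fit(X_train, y_train)

accuracy = clf.score(X_test, y_test)
importances = pd.Series(clf.feature_importances_, index=features.columns)
print(f"test accuracy: {accuracy:.2f}")
print(importances.sort_values(ascending=False))
```

Feature importances give a first hint at which region properties carry the class signal, which is useful later for feature selection.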

In [None]:
# load third-party Python modules
import javabridge
import bioformats as bf
import skimage
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sn
import pandas as pd

import sys
sys.path.append('..')

# start the Java VM required by python-bioformats (stop it with javabridge.kill_vm() when done)
javabridge.start_vm(class_path=bf.JARS)

In [None]:
from base.utils import load_imgstack
imgstack = load_imgstack(fname="data/AML_trainset/180528_Plate3/r02c14.tiff")

In [None]:
# remove a 'dummy' z-axis
img = np.squeeze(imgstack)

In [None]:
from base.plot import plot_channels
gamma = 0.4
plot_channels([img[:,:,i]**gamma for i in range(3)],
              nrow=1, ncol=3, cmap='gray',
              titles=['Hoechst', 'Lysosomal dye', 'Viability'])

In [None]:
from base.plot import combine_channels

img_rgb = combine_channels([img[:,:,i] for i in range(3)],
                           colors=['blue', 'red', 'green'],
                           blend=[1.5, 1.5, 2],
                           gamma=[0.6, 0.6, 0.6])
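The source of `combine_channels` is not shown here. As a hypothetical stand-in, a channel-compositing helper with the same arguments could rescale each channel to [0, 1], apply its gamma, weight it by its `blend` factor, add it into the named RGB colour, and clip the result. This is a sketch of one plausible approach, not the `base.plot` implementation.

```python
import numpy as np

# RGB unit vectors for the colour names used in the cell above
COLORS = {"red": (1.0, 0.0, 0.0), "green": (0.0, 1.0, 0.0), "blue": (0.0, 0.0, 1.0)}

def blend_channels(channels, colors, blend, gamma):
    """Hypothetical helper: additively composite grayscale channels into RGB."""
    rgb = np.zeros(channels[0].shape + (3,), dtype=float)
    for ch, color, weight, g in zip(channels, colors, blend, gamma):
        ch = ch.astype(float)
        span = ch.max() - ch.min()
        # rescale intensities to [0, 1] before gamma correction
        norm = (ch - ch.min()) / span if span > 0 else np.zeros_like(ch)
        # gamma-correct, weight, and add into the chosen colour
        rgb += weight * norm[..., None] ** g * np.asarray(COLORS[color])
    # keep the composite displayable by imshow
    return np.clip(rgb, 0.0, 1.0)

chans = [np.random.default_rng(i).random((64, 64)) for i in range(3)]
rgb = blend_channels(chans, colors=["blue", "red", "green"],
                     blend=[1.5, 1.5, 2], gamma=[0.6, 0.6, 0.6])
print(rgb.shape, rgb.min(), rgb.max())
```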

In [None]:
plt.figure(figsize=(10,10))
plt.imshow(img_rgb)
plt.axis('off')

<a id="initialsegm"></a> 
## Generate initial segmentation
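We can generate an initial segmentation by spot detection in the nucleus (Hoechst) channel: threshold the gamma-corrected image with Otsu's method, erode the binary mask with a disk of radius 5 (which helps split nuclei joined by thin bridges), and label the connected components: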

In [None]:
from transform.process import threshold_img
hoechst = img[:,:,0]**gamma
img_th = threshold_img(hoechst, method='otsu')

In [None]:
plt.figure(figsize=(10,10))
plt.imshow(img_th)
plt.axis('off')

In [None]:
from skimage.morphology import binary_erosion, disk
img_th = binary_erosion(threshold_img(hoechst, binary=True, method='otsu'), disk(5))

In [None]:
from skimage.measure import label
from skimage.color import label2rgb
segm = label(img_th, connectivity=1)
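The same pipeline can be sketched with scikit-image calls directly, assuming `threshold_img(..., method='otsu')` applies a global Otsu threshold (its source is not shown here). The synthetic image below stands in for the Hoechst channel:

```python
import numpy as np
from skimage.draw import disk as draw_disk
from skimage.filters import threshold_otsu
from skimage.measure import label
from skimage.morphology import binary_erosion, disk

# synthetic "nuclear" channel: three bright disks on a noisy dark background
rng = np.random.default_rng(0)
img = rng.normal(0.1, 0.02, size=(200, 200))
for center in [(50, 50), (50, 150), (150, 100)]:
    rr, cc = draw_disk(center, 20, shape=img.shape)
    img[rr, cc] = 0.8

gamma = 0.4
img_gamma = np.clip(img, 0, None) ** gamma     # gamma < 1 lifts dim signal
mask = img_gamma > threshold_otsu(img_gamma)   # global Otsu threshold
mask = binary_erosion(mask, disk(5))           # shrink objects, cut thin bridges
segm = label(mask, connectivity=1)             # one integer label per object
print(segm.max())                              # number of labelled objects
```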

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(label2rgb(segm, image=hoechst, bg_label=0))
ax.axis('off')

In [None]:
from skimage.measure import regionprops
feats = regionprops(label_image=segm, intensity_image=hoechst)

In [None]:
def get_feattable(feats, keys):
    '''Collect the requested region properties into a DataFrame (one row per object).'''
    return pd.DataFrame({key: [f[key] for f in feats] for key in keys})

In [None]:
feat_df = get_feattable(feats, keys=['area', 'eccentricity', 'mean_intensity', 'perimeter'])

In [None]:
upper = np.logical_and(feat_df.area < 6000, feat_df.perimeter < 1000)
lower = np.logical_and(feat_df.area > 100, feat_df.perimeter > 50)
feat_subset = np.logical_and(lower, upper)
# `label` numbers objects consecutively from 1 and regionprops returns them
# in label order, so row i of feat_df corresponds to label i + 1
label_subset = np.where(feat_subset)[0] + 1
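The `+ 1` offset relies on consecutive labels. If labels may be non-consecutive (for example after some objects have been removed), reading the label column directly avoids the offset. A minimal sketch, assuming scikit-image ≥ 0.16 for `regionprops_table`, on a toy label image:

```python
import numpy as np
import pandas as pd
from skimage.measure import regionprops_table

# toy labelled image with non-consecutive labels
segm = np.zeros((20, 20), dtype=int)
segm[1:3, 1:3] = 2      # area 4
segm[5:15, 5:15] = 7    # area 100
segm[16:19, 16:19] = 9  # area 9

props = pd.DataFrame(regionprops_table(segm, properties=("label", "area")))
# select objects by their own label, not by row position
keep = props.loc[(props.area > 5) & (props.area < 50), "label"].to_numpy()
filtered = segm * np.isin(segm, keep)
print(keep)
```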

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(label2rgb(segm*np.isin(segm, label_subset), image=hoechst, bg_label=0))
ax.axis('off')

In [None]:
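from skimage.measure import regionprops

def filter_segm(labels, bounds):
    '''Subset labelled pixel map (segmentation)
       ----------------------------------------
       Subset the segmentation by providing lower and upper
       bounds on region properties. An object is kept only if
       every listed property lies strictly between its bounds.


       Parameters
       ----------
       labels : (nrow, ncol) array
           Labelled image (generated by segmentation)
       bounds : dict
           Dictionary with keys naming features (regionprops)
           and values giving (lower, upper) bounds, e.g.
           {'area': (100, 6000), 'perimeter': (50, 1000)}.
           Only properties that do not need an intensity
           image can be used.


       Returns
       -------
       segm_out : (nrow, ncol) array
           Filtered labelled array (removed objects set to 0)
        '''
    keep = [region.label for region in regionprops(labels)
            if all(lower < region[key] < upper
                   for key, (lower, upper) in bounds.items())]
    # zero out every object that failed at least one bound
    segm_out = labels * np.isin(labels, keep)
    return segm_out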

In [None]:
help(filter_segm)

The objects that pass these bounds form the initial candidate pool. Next, we deal with the large objects that were filtered out: in a mask containing only these large objects, we search for bright spots (apoptotic nuclei). The approach taken here is to "break up" large objects into smaller chunks, which can then be prefiltered by intensity.

In [None]:
big = np.where(~upper)[0]+1

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(label2rgb(segm*np.isin(segm, big), image=hoechst, bg_label=0))
ax.axis('off')

In [None]:
plt.figure(figsize=(10,10))
plt.imshow(hoechst*np.isin(segm, big))
plt.axis('off')

In [None]:
from skimage.morphology import white_tophat
img_tophat = white_tophat(hoechst*np.isin(segm, big), disk(25))
img_tophat_th = threshold_img(img_tophat, method='yen')
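White top-hat filtering subtracts the morphological opening from the image, so it keeps bright structures smaller than the footprint and suppresses broader ones. This is why `disk(25)` picks out compact bright spots inside large objects. A small synthetic check (illustrative values only):

```python
import numpy as np
from skimage.morphology import disk, opening, white_tophat

# synthetic patch: a broad dim blob (large nucleus) with a small bright spot
yy, xx = np.mgrid[:101, :101]
img = 0.3 * ((yy - 50) ** 2 + (xx - 50) ** 2 < 40 ** 2)
img[45:55, 45:55] += 0.6   # bright spot, much smaller than disk(25)

footprint = disk(25)
tophat = white_tophat(img, footprint)

# white top-hat = image minus its morphological opening
assert np.allclose(tophat, img - opening(img, footprint))
print(tophat[50, 50], tophat[30, 50])  # spot survives, blob body is removed
```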

In [None]:
plt.figure(figsize=(10,10))
plt.imshow(img_tophat_th)
plt.axis('off')

In [None]:
from skimage.morphology import remove_small_objects

In [None]:
plt.figure(figsize=(10,10))
plt.imshow(remove_small_objects(threshold_img(img_tophat, method='yen', binary=True),
                                min_size=500))
plt.axis('off')

Merge these with the other spots, generate (cropped) bounding boxes for each "cell" candidate, and label them in the next step.

In [None]:
segm1 = np.isin(segm, label_subset)

In [None]:
segm2 = remove_small_objects(threshold_img(img_tophat, method='yen', binary=True),
                             min_size=500)

In [None]:
segm_out = label(np.logical_or(segm1, segm2))
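Because the two masks are combined with a pixelwise OR before relabelling, a spot found in both masks (or touching an object in the other mask) becomes a single candidate rather than a duplicate. A toy example:

```python
import numpy as np
from skimage.measure import label

# two candidate masks: filtered spots and spots recovered from large objects
mask_a = np.zeros((10, 10), dtype=bool)
mask_b = np.zeros((10, 10), dtype=bool)
mask_a[1:4, 1:4] = True   # object only in mask_a
mask_a[6:9, 1:4] = True   # overlaps the next object from mask_b
mask_b[7:9, 2:5] = True
mask_b[1:4, 6:9] = True   # object only in mask_b

merged = label(np.logical_or(mask_a, mask_b))
print(merged.max())  # -> 3: the two overlapping objects became one label
```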

In [None]:
plt.figure(figsize=(10,10))
plt.imshow(label2rgb(segm_out, image=hoechst, bg_label=0))
plt.axis('off')

In [None]:
feats_out = regionprops(label_image=segm_out, intensity_image=hoechst)

In [None]:
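pad = 20
bbox = []

for f in feats_out:
    # regionprops bbox: (min_row, min_col, max_row, max_col), max values exclusive
    ymin, xmin, ymax, xmax = f.bbox
    # pad each box and clip it to the image: columns (x) against shape[1],
    # rows (y) against shape[0]
    bb = np.array((max(0, xmin - pad),
                   min(xmax + pad, hoechst.shape[1]),
                   max(0, ymin - pad),
                   min(ymax + pad, hoechst.shape[0])))
    bbox.append(bb)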
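The padded boxes can be used to cut out the image patches that will serve as classifier input. `crop_patches` below is a hypothetical helper, not part of the notebook's `base` package. It reads `bbox` in scikit-image's (min_row, min_col, max_row, max_col) order, whose max values are exclusive, and clips rows against `shape[0]` and columns against `shape[1]`, which matters for non-square images.

```python
import numpy as np
from skimage.measure import regionprops

def crop_patches(image, segm, pad=20):
    """Hypothetical helper: return one padded crop per labelled object."""
    patches = []
    for region in regionprops(segm):
        r0, c0, r1, c1 = region.bbox
        # pad, then clip rows to shape[0] and columns to shape[1]
        r0, c0 = max(0, r0 - pad), max(0, c0 - pad)
        r1, c1 = min(image.shape[0], r1 + pad), min(image.shape[1], c1 + pad)
        patches.append(image[r0:r1, c0:c1])
    return patches

# usage on a toy non-square image
image = np.arange(50 * 80, dtype=float).reshape(50, 80)
segm = np.zeros((50, 80), dtype=int)
segm[5:10, 60:75] = 1    # near the top-right corner
segm[30:40, 10:20] = 2
patches = crop_patches(image, segm, pad=20)
print([p.shape for p in patches])  # -> [(30, 40), (40, 40)]
```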

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(img_rgb)
for bb in bbox:
    # bb = (xmin, xmax, ymin, ymax); Rectangle takes the (x, y) corner,
    # then the width along x and the height along y
    start = (bb[0], bb[2])
    rec = plt.Rectangle(xy=start,
                        width=bb[1] - bb[0],
                        height=bb[3] - bb[2],
                        color="white", linewidth=2, fill=False)
    ax.add_patch(rec)
ax.axis('off')

We now need to store both the bounding-box coordinates and the region properties of each image patch, and we need an easy (and scalable) way of retrieving the bounding boxes for each image.

First, generate a table (`DataFrame`) with all the information:

In [None]:
# presumably a backport of skimage.measure.regionprops_table (in scikit-image since 0.16)
from base.future_versions import regionprops_table

In [None]:
keys = [k for k in feats_out[0]]

In [None]:
exclude = ['convex_image', 'coords', 'extent',
           'filled_image', 'image']

In [None]:
selected_keys = list(set(keys) - set(exclude))
# sort by key lexicographically
selected_keys.sort()

In [None]:
feat_dict = regionprops_table(segm_out,
                              intensity_image=hoechst,
                              properties=selected_keys)

feat_df = pd.DataFrame(feat_dict)
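In scikit-image ≥ 0.16 the same function is available as `skimage.measure.regionprops_table`. Tuple-valued properties such as `bbox` are expanded into one column per element (`bbox-0` to `bbox-3`, i.e. min_row, min_col, max_row, max_col), so the bounding boxes can be read straight from the table. A minimal sketch on a toy label image:

```python
import numpy as np
import pandas as pd
from skimage.measure import regionprops_table

segm = np.zeros((30, 40), dtype=int)
segm[2:6, 3:10] = 1
segm[20:28, 25:35] = 2

table = pd.DataFrame(regionprops_table(segm, properties=("label", "bbox", "area")))
print(table.columns.tolist())
# ['label', 'bbox-0', 'bbox-1', 'bbox-2', 'bbox-3', 'area']

# bounding box of object 2 as (min_row, min_col, max_row, max_col)
bbox_cols = ["bbox-0", "bbox-1", "bbox-2", "bbox-3"]
print(table.loc[table.label == 2, bbox_cols].iloc[0].tolist())
```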


In [None]:
feat_df.iloc[:6,:10]

In [None]:
feat_df.to_csv("data/AML_trainset/180528_Plate3/r02c14.csv", index=False)
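One scalable layout is to save one CSV per well, as above, and later combine them into a single table keyed by plate and well, so that the bounding boxes for any image are one lookup away. The sketch below uses small in-memory tables instead of files; the `plate`/`well` columns, the second well `r02c15`, and all values are hypothetical.

```python
import pandas as pd

# hypothetical per-image tables, e.g. as read back with pd.read_csv(...)
tables = {
    ("180528_Plate3", "r02c14"): pd.DataFrame({
        "label": [1, 2], "bbox-0": [0, 40], "bbox-1": [5, 60],
        "bbox-2": [30, 90], "bbox-3": [45, 100]}),
    ("180528_Plate3", "r02c15"): pd.DataFrame({
        "label": [1], "bbox-0": [12], "bbox-1": [7],
        "bbox-2": [50], "bbox-3": [44]}),
}

# stack all tables and tag each row with the image it came from
combined = pd.concat(
    [t.assign(plate=plate, well=well) for (plate, well), t in tables.items()],
    ignore_index=True,
)
indexed = combined.set_index(["plate", "well", "label"]).sort_index()

# all bounding boxes for one image, indexed by object label
boxes = indexed.xs(("180528_Plate3", "r02c14"), level=["plate", "well"])
print(boxes)
```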

<a id="label"></a> 
## Image labelling


<a id="trainset"></a> 
## Training and test set generation


<a id="randomforest"></a> 
## Random forest classifier

<a id="featureselect"></a> 
## Parameter tuning and feature selection

<a id="comparison"></a> 
## Comparison with other classifiers