## Label image data in plate 180528_Plate3
*Vladislav Kim*


* [Introduction](#1)
* [Initial training set](#2)

<a id="1"></a> 
## Introduction
The idea of this notebook series is to train a pseudo-online random forest (RF) classifier for AML vs. stroma cell classification. From selected plates we (for now) sample 6 DMSO wells with the highest Calcein cell count, generate predictions, correct misclassified instances and check in live ("online") mode how the predictions improve as we add more data. Note that the classifier is not truly online: the model is not updated incrementally. Instead, the RF classifier is completely retrained (in multicore mode) after each batch of new labels.
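
Selecting the wells could look like the sketch below. It assumes a hypothetical per-well metadata table with columns `plate`, `well`, `drug` and `calcein_count`; these names are illustrative and are not taken from the project's data files.

```python
import pandas as pd

def select_target_wells(meta: pd.DataFrame, n: int = 6, drug: str = "DMSO") -> pd.DataFrame:
    """Return, for each plate, the n wells treated with `drug` that have
    the highest Calcein cell count.

    The column names ('plate', 'well', 'drug', 'calcein_count') are hypothetical.
    """
    ctrl = meta[meta["drug"] == drug]
    # sort by count (descending) within each plate, then keep the top n rows per plate
    return (ctrl.sort_values(["plate", "calcein_count"], ascending=[True, False])
                .groupby("plate", sort=False)
                .head(n)
                .reset_index(drop=True))
```

`groupby(...).head(n)` keeps the first `n` rows of each plate after the descending sort, so a plate with fewer than `n` DMSO wells simply contributes all of them.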

More generally, we can implement a targeted online learning strategy: we select a number of wells that are of particular interest (target wells), e.g. DMSO control wells or wells treated with certain high-priority drugs, whose classification accuracy we want to improve first. We sample these target wells from selected plates and evaluate the classification accuracy as we go (pseudo-online learning).
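
The evaluate-then-retrain loop can be sketched with scikit-learn on synthetic data. Each batch stands in for one target well: its accuracy is measured *before* its labels join the training set, and the forest is then retrained from scratch, mirroring the pseudo-online scheme described above. The data, the number of batches and the forest size are placeholders, not the notebook's settings.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# synthetic stand-in for per-cell features from successive target wells
X, y = make_classification(n_samples=600, n_features=10, n_informative=5,
                           n_classes=3, random_state=0)
batches = np.array_split(np.arange(len(y)), 6)   # 6 hypothetical "wells"

Xtrain, ytrain = X[batches[0]], y[batches[0]]
accuracies = []
for idx in batches[1:]:
    # retrain from scratch on everything labeled so far (no incremental update)
    clf = RandomForestClassifier(n_estimators=100, random_state=0, n_jobs=-1)
    clf.fit(Xtrain, ytrain)
    # score the new well before its labels are added to the training set
    accuracies.append(clf.score(X[idx], y[idx]))
    Xtrain = np.vstack([Xtrain, X[idx]])
    ytrain = np.concatenate([ytrain, y[idx]])

print([round(a, 2) for a in accuracies])
```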


<a id="2"></a>
## Initial Training Set: 180528_Plate3
First, we retrain the classifier on plate `180528_Plate3`, as it shows a very striking contrast between mono- and co-cultures. We want to rule out the possibility that this contrast is a segmentation (in this case, classification) artefact.

In [None]:
# load third-party Python modules
import javabridge
import bioformats as bf
import skimage
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sn
import pandas as pd
import os
import sys
sys.path.append('../../..')

javabridge.start_vm(class_path=bf.JARS)

In [None]:
from base.utils import load_imgstack
imgstack = load_imgstack(fname="../../data/AML_trainset/180528_Plate3/r02c14.tiff")

# remove a 'dummy' z-axis
img = np.squeeze(imgstack)

# nuclei
hoechst = img[:,:,0]**0.3

In [None]:
df = pd.read_csv('../../data/AML_trainset/180528_Plate3/r02c14.csv')

In [None]:
from segment.tools import read_bbox
rmax, cmax = hoechst.shape

bbox = read_bbox(df=df, rmax=rmax, cmax=cmax)
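
`read_bbox` is not shown in this notebook; later calls pass a `pad` argument to it. For reference, padding bounding boxes and clipping them to the image could look like the sketch below. The `(rmin, cmin, rmax, cmax)` column layout and the name `pad_and_clip_bbox` are assumptions, not the `segment.tools` implementation.

```python
import numpy as np

def pad_and_clip_bbox(boxes, rmax: int, cmax: int, pad: int = 0) -> np.ndarray:
    """Expand (rmin, cmin, rmax, cmax) boxes by `pad` pixels on every side
    and clip them to the image extent (hypothetical layout).
    """
    boxes = np.asarray(boxes, dtype=int)
    out = boxes + np.array([-pad, -pad, pad, pad])
    # row coordinates stay within [0, rmax], column coordinates within [0, cmax],
    # so img[r0:r1, c0:c1] is always a valid slice
    out[:, [0, 2]] = np.clip(out[:, [0, 2]], 0, rmax)
    out[:, [1, 3]] = np.clip(out[:, [1, 3]], 0, cmax)
    return out
```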

In [None]:
from base.plot import show_bbox
#show_bbox(hoechst, bbox)

**Plotly visualization works!**

In [None]:
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import plotly.graph_objs as go
init_notebook_mode(connected=True)

In [None]:
from extra.viz import plotly_viz

In [None]:
from skimage.exposure import equalize_adapthist
gamma = 0.3
img_g = img**gamma
mip_rgb = equalize_adapthist(np.dstack((img_g[:,:,1],
                                        img_g[:,:,2],
                                        img_g[:,:,0])))
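
The gamma correction and channel reordering above can also be written as a single indexing step. This sketch assumes a float image with intensities in [0, 1], where a gamma below 1 brightens dim pixels (for example, 0.1 ** 0.4 ≈ 0.40); the helper name `to_rgb` is hypothetical. With `order=(1, 2, 0)` the Hoechst (nuclear) channel 0 lands in the blue slot, exactly as in the `np.dstack` call.

```python
import numpy as np

def to_rgb(img: np.ndarray, gamma: float = 0.4, order=(1, 2, 0)) -> np.ndarray:
    """Gamma-correct a (rows, cols, channels) image and reorder its channels.

    order=(1, 2, 0) is equivalent to
    np.dstack((img_g[:,:,1], img_g[:,:,2], img_g[:,:,0])).
    """
    return (img ** gamma)[:, :, list(order)]
```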

In [None]:
layout, cells = plotly_viz(mip_rgb, bb=bbox)

In [None]:
from extra.viz import plotly_predictions
ypred = np.zeros(len(bbox), dtype=int)
labels = ['cells']
layout, cells = plotly_predictions(img=mip_rgb, bb=bbox,
                                  ypred=ypred, labels=labels)

In [None]:
#iplot(dict(data=cells, layout=layout))

In [None]:
def get_train_instance(path, fname, pad=0):
    imgstack = load_imgstack(fname=os.path.join(path, fname + ".tiff"),
                            verbose=False)
    img = np.squeeze(imgstack)
    df = pd.read_csv(os.path.join(path, fname + ".csv"))
    rmax, cmax, _ = img.shape
    bbox = read_bbox(df=df, rmax=rmax,
                     cmax=cmax, pad=pad)
    return img, bbox

In [None]:
img, bbox = get_train_instance(path='../../data/AML_trainset/180528_Plate3',
                              fname='r02c14', pad=20)

**TODO: fix how `ImgX` is imported (it currently relies on appending a relative path to `sys.path`).**

In [None]:
sys.path.append('../../../../')
from bioimg.classify import ImgX

In [None]:
imgx_test = ImgX(img=img**0.4, bbox=bbox, n_chan=3)
imgx_test = imgx_test.compute_props()

In [None]:
# imgx_test.data[0].iloc[:10,:12]

**`IncrementalClassifier` wraps an `ImgX` instance and accumulates training data across images:**

In [None]:
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.ensemble import RandomForestClassifier

from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
import re

class IncrementalClassifier:
    def __init__(self):
        # attributes initialized to 'None' are set later
        self.imgx = None
        # (cell id, class index) pairs labeled in the current image
        self.newlabels = None
        # training data
        self.Xtrain = None
        self.ytrain = None
        # the classifier and the class names are set later
        self.clf = None
        self.classes = None

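    def __setattr__(self, name, value):
        self.__dict__[name] = value
        # if a new ImgX object is passed, compute its features and
        # reset 'newlabels' (cell ids are row indices into imgx.data,
        # so they are only meaningful within one image)
        if name == 'imgx':
            self._compute_imgx_data()
            self.newlabels = None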

    # function for setting individual class parameters
    def set_param(self, **kwargs):
        for k in kwargs.keys():
            self.__setattr__(k, kwargs[k])
         
    # internal function checks if the embedded
    # imgx object has the features computed
    def _compute_imgx_data(self):
        if self.imgx is not None and len(self.imgx.data) == 0:
            self.imgx.compute_props()
            
    # plot predictions overlaid with the original image
    # plot is a 'void' function (returns 'None')
    def plot_predictions(self):
        if self.imgx.y is not None:
            layout, feats = plotly_predictions(img=self.imgx.img,
                                               bb=self.imgx.bbox,
                                               ypred=self.imgx.y,
                                               labels=self.classes)
        else:
            layout, feats = plotly_viz(img=self.imgx.img,
                                       bb=self.imgx.bbox)
        iplot(dict(data=feats, layout=layout))
        
    def add_instances(self, newlabels):
        newlabels = np.unique(newlabels, axis=0)
        if self.newlabels is None:
            self.newlabels = newlabels
        else:
            a1_rows = newlabels.view(
                [('', newlabels.dtype)] * newlabels.shape[1])
            a2_rows = self.newlabels.view(
                [('', self.newlabels.dtype)] * self.newlabels.shape[1])

            newlabels = (np.setdiff1d(a1_rows, a2_rows).
                         view(newlabels.dtype).
                         reshape(-1, newlabels.shape[1]))
            self.newlabels = np.append(self.newlabels, newlabels, axis=0)
        # if 'newlabels' array is not empty
        if len(newlabels) > 0:
            self._push_traindata(newlabels=newlabels)
        return self
    
    def _push_traindata(self, newlabels):
        ids = newlabels[:,0]
        if self.Xtrain is None:
            self.Xtrain = self.imgx.data.iloc[ids,:]
            self.ytrain = label_binarize(newlabels[:,1],
                                        classes=range(len(self.classes)))
        else:
            self.Xtrain = pd.concat([self.Xtrain, self.imgx.data.iloc[ids,:]], axis=0)
            self.ytrain = np.append(self.ytrain, label_binarize(newlabels[:,1],
                                     classes=range(len(self.classes))), axis=0)
                        
        
    def set_classifier(self, clf=None):
        self.clf = clf
        # if 'None', fall back to a reasonable default:
        # one-vs-rest random forests with balanced class weights
        if clf is None:
            self.clf = OneVsRestClassifier(
                RandomForestClassifier(bootstrap=True,
                                       class_weight="balanced",
                                       n_estimators=500,
                                       random_state=123,
                                       n_jobs=-1))
        return self
    
    def train_classifier(self):
        self.clf.fit(self.Xtrain, self.ytrain)
        return self
    
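    # print a classification report on the existing training set
    # (resubstitution error, which is optimistic for random forests)
    def train_error(self):
        # argmax over per-class scores; argmax over the 0/1 output of
        # predict() would map rows with no positive class to class 0
        ypred = self.clf.predict_proba(self.Xtrain).argmax(axis=1)
        print(classification_report(self.ytrain.argmax(axis=1),
                                    ypred,
                                    labels=list(range(len(self.classes))),
                                    target_names=self.classes))
        #print(confusion_matrix(self.ytrain.argmax(axis=1), ypred,
        #                       labels=range(len(self.classes))))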
        
    # generate predictions and pass them to self.imgx.y
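    def generate_predictions(self, prob=False):
        Xtest = self.imgx.data
        # per-class scores from the one-vs-rest RF, shape (n_cells, n_classes)
        yscore = self.clf.predict_proba(Xtest)
        # argmax over the scores always assigns the top-scoring class; argmax
        # over the 0/1 output of predict() would map rows with no positive
        # class to class 0
        self.imgx.y = yscore.argmax(axis=1)
        # optionally return the per-class scores
        if prob:
            return yscore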
    
    def h5_write(self, fname):
        import h5py
        # the context manager closes the file even if writing fails
        with h5py.File(fname, 'w') as hf:
            hf.create_dataset('Xtrain', data=np.asarray(self.Xtrain))
            hf.create_dataset('ytrain', data=self.ytrain.argmax(axis=1))
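
`add_instances` discards (cell id, class) pairs that were already labeled in the current image by viewing each row as a single structured element and calling `np.setdiff1d`. The sketch below reproduces that trick next to a plainer set-of-tuples version so the two can be checked against each other; both function names are hypothetical.

```python
import numpy as np

def new_rows(candidates: np.ndarray, seen: np.ndarray) -> np.ndarray:
    """Unique rows of `candidates` that do not occur in `seen` (set-based)."""
    candidates = np.unique(candidates, axis=0)
    seen_set = set(map(tuple, np.asarray(seen).tolist()))
    keep = [row for row in candidates.tolist() if tuple(row) not in seen_set]
    return np.array(keep, dtype=candidates.dtype).reshape(-1, candidates.shape[1])

def new_rows_structured(candidates: np.ndarray, seen: np.ndarray) -> np.ndarray:
    """Same result via the structured-view + np.setdiff1d trick."""
    candidates = np.ascontiguousarray(np.unique(candidates, axis=0))
    seen = np.ascontiguousarray(seen, dtype=candidates.dtype)
    # one structured field per column, so each row becomes one comparable element
    rowtype = [("", candidates.dtype)] * candidates.shape[1]
    diff = np.setdiff1d(candidates.view(rowtype), seen.view(rowtype))
    return diff.view(candidates.dtype).reshape(-1, candidates.shape[1])
```

Because whole rows are compared, relabeling a cell (same id, different class) counts as a new pair, so both the old and the new label end up in the training set.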

In [None]:
gamma = 0.4
# adjust brightness by gamma correction
img_g = img**gamma
# reorder color channels into 'RGB' order
img_rgb = np.dstack((img_g[:,:,1],
                     img_g[:,:,2],
                     img_g[:,:,0]))
# initialize 'ImgX' class
imgx = ImgX(img=img_rgb, bbox=bbox, n_chan=3)

In [None]:
clf_incr = IncrementalClassifier()

In [None]:
clf_incr.imgx = imgx
clf_incr.classes = ['apoptotic', 'viable', 'other']
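
Assigning `clf_incr.imgx = imgx` is what triggers feature computation and resets `newlabels`, through the overridden `__setattr__`. A `property` setter expresses the same hook without intercepting every attribute assignment. The sketch below uses hypothetical `FeatureHolder` and `DummyImgX` classes as stand-ins for `IncrementalClassifier` and `ImgX`.

```python
class DummyImgX:
    """Hypothetical stand-in for ImgX: features are computed on demand."""
    def __init__(self):
        self.data = []
        self.n_calls = 0

    def compute_props(self):
        self.data = [[0.1, 0.2], [0.3, 0.4]]   # placeholder feature rows
        self.n_calls += 1
        return self


class FeatureHolder:
    """Minimal sketch of the 'compute features on assignment' hook."""
    def __init__(self):
        self._imgx = None
        self.newlabels = None

    @property
    def imgx(self):
        return self._imgx

    @imgx.setter
    def imgx(self, value):
        self._imgx = value
        # cell ids index rows of value.data, so earlier labels no longer apply
        self.newlabels = None
        if value is not None and len(value.data) == 0:
            value.compute_props()


holder = FeatureHolder()
img1 = DummyImgX()
holder.imgx = img1   # computes features once
holder.imgx = img1   # features already present: compute_props is not called again
```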

In [None]:
#clf_incr.plot_predictions()

In [None]:
def make_labels(arr, label=1):
    # pair each cell id with its class index -> (n, 2) integer array
    return np.vstack((arr, label * np.ones(arr.shape, dtype=int))).T
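
An equivalent formulation of `make_labels` with `np.column_stack` makes the output layout explicit: an `(n, 2)` array of `(cell id, class index)` rows, where the class index follows the order of `clf_incr.classes` (0 = apoptotic, 1 = viable, 2 = other). The name `make_labels_cs` is hypothetical.

```python
import numpy as np

def make_labels_cs(arr, label=1):
    """Return an (n, 2) int array pairing each cell id in `arr` with `label`."""
    arr = np.asarray(arr, dtype=int)
    # column_stack builds a new C-contiguous array (the .T version is a transposed view)
    return np.column_stack((arr, np.full(arr.shape, label, dtype=int)))

pairs = make_labels_cs([43, 41, 6], label=1)
# rows: (43, 1), (41, 1), (6, 1)
```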

In [None]:
viable = np.array([43, 41, 6, 29, 16, 48, 61, 69, 
                   59, 66, 73, 77, 80, 89, 94, 98,
                   112, 120, 119, 122, 100, 103, 57, 67, 55, 62])

apoptotic = np.array([123, 76, 82, 53, 47, 37, 30, 18])

other = np.array([12,22, 44, 34,1,51,38,19,10,17,0])

In [None]:
newlabels = np.concatenate((make_labels(viable, label=1),
              make_labels(apoptotic, label=0),
              make_labels(other, label=2)),
          axis=0)

In [None]:
clf_incr = clf_incr.add_instances(newlabels=newlabels)

In [None]:
clf_incr.Xtrain.shape

In [None]:
clf_incr.set_classifier().train_classifier()

In [None]:
#clf_incr.train_error()

In [None]:
clf_incr.generate_predictions()
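
`generate_predictions` turns the one-vs-rest output into class indices with `argmax`. When `OneVsRestClassifier` is fit on `label_binarize` output, `predict` returns a 0/1 indicator matrix, and a row in which none of the per-class forests predicts the positive class is all zeros, so `argmax` silently assigns it to class 0 ('apoptotic'). The synthetic-data sketch below compares that with taking the argmax of `predict_proba`, which always picks the top-scoring class; data and forest size are placeholders.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import label_binarize

X, y = make_classification(n_samples=300, n_features=8, n_informative=4,
                           n_classes=3, random_state=1)
Y = label_binarize(y, classes=[0, 1, 2])          # one-hot, shape (300, 3)

clf = OneVsRestClassifier(RandomForestClassifier(n_estimators=50, random_state=0))
clf.fit(X[:200], Y[:200])

indicator = clf.predict(X[200:])                  # 0/1 matrix, one column per class
empty = indicator.sum(axis=1) == 0                # rows with no positive class
labels_from_argmax = indicator.argmax(axis=1)     # empty rows silently become class 0
labels_from_proba = clf.predict_proba(X[200:]).argmax(axis=1)

print(f"{int(empty.sum())} of {len(empty)} held-out cells have no positive class")
```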

In [None]:
#clf_incr.plot_predictions()

**Try loading a new image and replacing `imgx` in `IncrementalClassifier`**

In [None]:
img, bbox = get_train_instance(path='../../data/AML_trainset/180528_Plate3',
                              fname='r05c12', pad=20)

img_g = img**gamma
# reorder color channels into 'RGB' order
img_rgb = np.dstack((img_g[:,:,1],
                     img_g[:,:,2],
                     img_g[:,:,0]))
# initialize 'ImgX' class
imgx = ImgX(img=img_rgb, bbox=bbox, n_chan=3)

Update the `imgx` in `clf_incr`:

In [None]:
clf_incr.imgx = imgx

In [None]:
clf_incr.generate_predictions()

In [None]:
viable = np.array([21])
other = np.array([22,32,14,52,98,76])

In [None]:
#clf_incr.plot_predictions()

In [None]:
newlabels = np.concatenate((make_labels(viable, label=1),
              make_labels(other, label=2)),
          axis=0)
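
Each annotation round repeats the same `make_labels`/`np.concatenate` pattern with hard-coded class indices. A small helper that looks the index up from the class name would make a misspelled class fail loudly instead of silently mislabeling cells. `build_newlabels` below is a hypothetical sketch, not part of the project code.

```python
import numpy as np

CLASSES = ['apoptotic', 'viable', 'other']   # same order as clf_incr.classes

def build_newlabels(classes, **ids_by_class):
    """Hypothetical helper: build the (cell_id, class_index) array for one image
    from keyword arguments named after the classes, e.g. viable=[21], other=[22].
    """
    blocks = []
    for name, ids in ids_by_class.items():
        idx = classes.index(name)            # raises ValueError for an unknown class
        ids = np.asarray(ids, dtype=int)
        blocks.append(np.column_stack((ids, np.full(ids.shape, idx, dtype=int))))
    if not blocks:
        return np.empty((0, 2), dtype=int)
    return np.concatenate(blocks, axis=0)

newlabels = build_newlabels(CLASSES, viable=[21], other=[22, 32, 14, 52, 98, 76])
```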

In [None]:
clf_incr = clf_incr.add_instances(newlabels=newlabels)

In [None]:
clf_incr.train_classifier()

In [None]:
clf_incr.generate_predictions()

In [None]:
#clf_incr.plot_predictions()

Load the next image:

In [None]:
img, bbox = get_train_instance(path='../../data/AML_trainset/180528_Plate3',
                              fname='r05c24', pad=20)

img_g = img**gamma
# reorder color channels into 'RGB' order
img_rgb = np.dstack((img_g[:,:,1],
                     img_g[:,:,2],
                     img_g[:,:,0]))
# initialize 'ImgX' class
imgx = ImgX(img=img_rgb, bbox=bbox, n_chan=3)

Update the `imgx` in `clf_incr`:

In [None]:
clf_incr.imgx = imgx

In [None]:
clf_incr.generate_predictions()

In [None]:
viable = np.array([19, 107, 61, 81, 77])
other = np.array([0,22,26])

In [None]:
newlabels = np.concatenate((make_labels(viable, label=1),
              make_labels(other, label=2)),
          axis=0)

In [None]:
clf_incr = clf_incr.add_instances(newlabels=newlabels)

In [None]:
clf_incr.train_classifier()

In [None]:
clf_incr.generate_predictions()

In [None]:
#clf_incr.plot_predictions()

Load the next image:

In [None]:
img, bbox = get_train_instance(path='../../data/AML_trainset/180528_Plate3',
                              fname='r06c16', pad=20)

img_g = img**gamma
# reorder color channels into 'RGB' order
img_rgb = np.dstack((img_g[:,:,1],
                     img_g[:,:,2],
                     img_g[:,:,0]))
# initialize 'ImgX' class
imgx = ImgX(img=img_rgb, bbox=bbox, n_chan=3)

In [None]:
clf_incr.imgx = imgx

In [None]:
clf_incr.generate_predictions()

In [None]:
other = np.array([22, 80, 105])

In [None]:
newlabels = make_labels(other, label=2)

In [None]:
clf_incr = clf_incr.add_instances(newlabels=newlabels)

In [None]:
clf_incr.train_classifier()

In [None]:
clf_incr.generate_predictions()

In [None]:
#clf_incr.plot_predictions()

In [None]:
img, bbox = get_train_instance(path='../../data/AML_trainset/180528_Plate3',
                              fname='r10c18', pad=20)

img_g = img**gamma
# reorder color channels into 'RGB' order
img_rgb = np.dstack((img_g[:,:,1],
                     img_g[:,:,2],
                     img_g[:,:,0]))
# initialize 'ImgX' class
imgx = ImgX(img=img_rgb, bbox=bbox, n_chan=3)

In [None]:
clf_incr.imgx = imgx

In [None]:
clf_incr.generate_predictions()

In [None]:
other = np.array([57,8, 86,97])

In [None]:
newlabels = make_labels(other, label=2)

In [None]:
clf_incr = clf_incr.add_instances(newlabels=newlabels)

In [None]:
clf_incr.train_classifier()

In [None]:
clf_incr.generate_predictions()

In [None]:
#clf_incr.plot_predictions()

In [None]:
img, bbox = get_train_instance(path='../../data/AML_trainset/180528_Plate3',
                              fname='r11c06', pad=20)

img_g = img**gamma
# reorder color channels into 'RGB' order
img_rgb = np.dstack((img_g[:,:,1],
                     img_g[:,:,2],
                     img_g[:,:,0]))
# initialize 'ImgX' class
imgx = ImgX(img=img_rgb, bbox=bbox, n_chan=3)

In [None]:
clf_incr.imgx = imgx

In [None]:
clf_incr.generate_predictions()

In [None]:
viable = np.array([16,57])
apoptotic = np.array([8])
other = np.array([5, 82,83, 47, 23])

In [None]:
newlabels = np.concatenate((make_labels(viable, label=1),
              make_labels(apoptotic, label=0),
              make_labels(other, label=2)),
          axis=0)

In [None]:
clf_incr = clf_incr.add_instances(newlabels=newlabels)

In [None]:
clf_incr.train_classifier()

In [None]:
clf_incr.generate_predictions()

In [None]:
#clf_incr.plot_predictions()

In [None]:
img, bbox = get_train_instance(path='../../data/AML_trainset/180528_Plate3',
                              fname='r12c10', pad=20)

img_g = img**gamma
# reorder color channels into 'RGB' order
img_rgb = np.dstack((img_g[:,:,1],
                     img_g[:,:,2],
                     img_g[:,:,0]))
# initialize 'ImgX' class
imgx = ImgX(img=img_rgb, bbox=bbox, n_chan=3)

In [None]:
clf_incr.imgx = imgx

In [None]:
clf_incr.generate_predictions()

In [None]:
other = np.array([56])

In [None]:
newlabels = make_labels(other, label=2)

In [None]:
clf_incr = clf_incr.add_instances(newlabels=newlabels)

In [None]:
clf_incr.train_classifier()

In [None]:
clf_incr.generate_predictions()

In [None]:
#clf_incr.plot_predictions()

Add the last image:

In [None]:
img, bbox = get_train_instance(path='../../data/AML_trainset/180528_Plate3',
                              fname='r16c02', pad=20)

img_g = img**gamma
# reorder color channels into 'RGB' order
img_rgb = np.dstack((img_g[:,:,1],
                     img_g[:,:,2],
                     img_g[:,:,0]))
# initialize 'ImgX' class
imgx = ImgX(img=img_rgb, bbox=bbox, n_chan=3)

In [None]:
clf_incr.imgx = imgx

In [None]:
clf_incr.generate_predictions()

In [None]:
#clf_incr.plot_predictions()

In [None]:
viable = np.array([3, 126])

In [None]:
newlabels = make_labels(viable, label=1)

In [None]:
clf_incr = clf_incr.add_instances(newlabels=newlabels)

In [None]:
clf_incr.train_classifier()

In [None]:
clf_incr.generate_predictions()

In [None]:
#clf_incr.plot_predictions()

In [None]:
clf_incr.Xtrain.shape

In [None]:
clf_incr.train_error()
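
`train_error` reports resubstitution accuracy, which tends to be close to perfect for a random forest because each training cell ends up in the bootstrap sample of most trees. An out-of-bag (OOB) estimate or stratified cross-validation gives a less optimistic picture. The sketch below uses synthetic data in place of `clf_incr.Xtrain` and a plain multiclass `RandomForestClassifier` instead of the notebook's `OneVsRestClassifier` wrapper, so it illustrates the technique rather than reproducing the notebook's numbers. With very few labeled cells in a class, `StratifiedKFold` needs at least `n_splits` members per class.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import StratifiedKFold, cross_val_predict

# synthetic stand-in for clf_incr.Xtrain and clf_incr.ytrain.argmax(axis=1)
X, y = make_classification(n_samples=300, n_features=10, n_informative=5,
                           n_classes=3, weights=[0.2, 0.5, 0.3], random_state=0)
classes = ['apoptotic', 'viable', 'other']

rf = RandomForestClassifier(n_estimators=200, class_weight="balanced",
                            oob_score=True, random_state=123, n_jobs=-1)
rf.fit(X, y)
train_acc = rf.score(X, y)   # resubstitution accuracy, typically close to 1
oob_acc = rf.oob_score_      # each cell scored only by trees that did not see it

# every cell is predicted by a model fit on the other folds
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
y_cv = cross_val_predict(rf, X, y, cv=cv)
print(f"train accuracy {train_acc:.2f}, OOB accuracy {oob_acc:.2f}")
print(classification_report(y, y_cv, target_names=classes))
```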