## Label image data in plate 180528_Plate3
*Vladislav Kim*


* [Introduction](#1)
* [Initial training set](#2)

<a id="1"></a> 
## Introduction
The idea of this notebook series is to train a pseudo-online random forest (RF) classifier for AML vs. stroma cell classification. From selected plates we (for now) sample 6 DMSO wells with the highest Calcein cell count, generate predictions, correct misclassified instances, and check in live ("online") mode how the predictions improve as we add more data. Note that the classifier is not truly online: we do not update the model incrementally, but retrain the RF classifier from scratch, in parallel on multiple cores, every time new labels are added.
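
The "retrain from scratch as labels accumulate" loop can be sketched on toy data. This is a minimal illustration, not the notebook's pipeline: a hypothetical nearest-centroid classifier stands in for the random forest, the features are synthetic, and the true labels play the role of the human correcting misclassified cells.

```python
import numpy as np

# Hypothetical stand-in for the random forest: a nearest-centroid classifier.
def fit_centroids(X, y):
    classes = np.unique(y)
    centroids = np.array([X[y == c].mean(axis=0) for c in classes])
    return classes, centroids

def predict_centroids(model, X):
    classes, centroids = model
    # squared Euclidean distance from every sample to every class centroid
    d = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return classes[d.argmin(axis=1)]

# synthetic two-class "cells" standing in for real per-cell features
rng = np.random.default_rng(0)
toy_pool_X = np.vstack([rng.normal(0.0, 1.0, (100, 5)),
                        rng.normal(3.0, 1.0, (100, 5))])
toy_pool_y = np.repeat([0, 1], 100)

# start from one labelled cell per class
toy_X, toy_y = toy_pool_X[[0, 100]], toy_pool_y[[0, 100]]
toy_history = []
for _ in range(5):
    toy_model = fit_centroids(toy_X, toy_y)   # full retrain on all labels so far
    toy_pred = predict_centroids(toy_model, toy_pool_X)
    toy_history.append(float((toy_pred == toy_pool_y).mean()))
    # "correct" up to 10 misclassified cells and add them to the training set
    wrong = np.flatnonzero(toy_pred != toy_pool_y)[:10]
    toy_X = np.vstack([toy_X, toy_pool_X[wrong]])
    toy_y = np.append(toy_y, toy_pool_y[wrong])
```

Each round discards the previous model entirely. A truly online learner would instead update its parameters with only the new instances.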

In general, we can implement a targeted online learning strategy: we select a number of wells of interest (target wells), e.g. DMSO control wells or wells treated with certain high-priority drugs, whose classification accuracy we want to improve first. We then sample from these target wells on selected plates and evaluate the classification accuracy as we go (pseudo-online learning).
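
Selecting the target wells amounts to a filter-and-rank step. A minimal sketch, assuming a well table whose keys (`well`, `treatment`, `calcein_count`) are hypothetical and whose values are synthetic:

```python
def select_target_wells(wells, treatments=("DMSO",), n=6):
    """Return the names of the n wells with the highest Calcein count whose
    treatment is in `treatments` (e.g. DMSO controls or priority drugs)."""
    candidates = [w for w in wells if w["treatment"] in treatments]
    ranked = sorted(candidates, key=lambda w: w["calcein_count"], reverse=True)
    return [w["well"] for w in ranked[:n]]

# synthetic well table: even wells are DMSO controls, odd wells carry a drug
toy_wells = [{"well": f"w{i}",
              "treatment": "DMSO" if i % 2 == 0 else "drugX",
              "calcein_count": 100 + 7 * i}
             for i in range(20)]
toy_targets = select_target_wells(toy_wells)
```

Passing several treatments (e.g. `("DMSO", "drugX")`) pools controls and high-priority drugs into one target set.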


<a id="2"></a>
## Initial Training Set: 180528_Plate3
As a first step, we re-train the classifier on plate `180528_Plate3`, as it shows a very striking contrast between mono- and co-cultures. We want to rule out the possibility that this contrast is a segmentation (in this case, classification) artefact.

In [None]:
# load third-party Python modules
import javabridge
import bioformats as bf
import skimage
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sn
import pandas as pd
import os
import sys
sys.path.append('../../..')

javabridge.start_vm(class_path=bf.JARS)

In [None]:
from base.utils import load_imgstack
imgstack = load_imgstack(fname="../../data/AML_trainset/180528_Plate3/r02c14.tiff")

# remove a 'dummy' z-axis
img = np.squeeze(imgstack)

# Hoechst (nuclei) channel, gamma-corrected (gamma = 0.3)
hoechst = img[:,:,0]**0.3

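The cell above applies a power-law (gamma) transform to the nuclear channel. With gamma < 1, small intensities grow proportionally more than large ones, so dim nuclei become more visible. A sketch of a variant that first min-max rescales to [0, 1], so the result does not depend on the bit depth; `gamma_correct` is a hypothetical helper, not part of the project code:

```python
import numpy as np

def gamma_correct(channel, gamma=0.3):
    """Min-max rescale a channel to [0, 1], then apply x ** gamma."""
    x = channel.astype(np.float64)
    x = (x - x.min()) / max(x.max() - x.min(), 1e-12)  # guard against flat images
    return x ** gamma

toy_raw = np.array([[0, 10], [100, 1000]], dtype=np.uint16)
toy_out = gamma_correct(toy_raw)
# a value at 1% of the range (10 / 1000) becomes 0.01 ** 0.3, roughly 0.25
```
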
In [None]:
df = pd.read_csv('../../data/AML_trainset/180528_Plate3/r02c14.csv')

In [None]:
from segment.tools import read_bbox
rmax, cmax = hoechst.shape

bbox = read_bbox(df=df, rmax=rmax, cmax=cmax)

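`read_bbox` comes from the project's `segment.tools` module and its internals are not shown here. One plausible reason it takes `rmax` and `cmax` is to keep (padded) boxes inside the image. The sketch below illustrates that idea under assumed conventions: boxes as half-open `(r0, c0, r1, c1)` rows and a hypothetical helper `pad_and_clip_bbox`. It is not the actual `read_bbox` logic.

```python
import numpy as np

def pad_and_clip_bbox(boxes, rmax, cmax, pad=0):
    """Grow each (r0, c0, r1, c1) box by `pad` pixels and clip it to the image.

    Ends are exclusive, so a clipped box can be used directly as
    img[r0:r1, c0:c1] with rmax, cmax = number of rows, columns.
    """
    b = np.asarray(boxes, dtype=int).copy()
    b[:, :2] -= pad   # move the top-left corner up and left
    b[:, 2:] += pad   # move the bottom-right corner down and right
    b[:, [0, 2]] = np.clip(b[:, [0, 2]], 0, rmax)
    b[:, [1, 3]] = np.clip(b[:, [1, 3]], 0, cmax)
    return b

toy_boxes = pad_and_clip_bbox([[0, 0, 10, 10], [95, 95, 100, 100]],
                              rmax=100, cmax=100, pad=2)
```
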
In [None]:
from base.plot import show_bbox
#show_bbox(hoechst, bbox)

**Plotly visualization works!**

In [None]:
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import plotly.graph_objs as go
init_notebook_mode(connected=True)

In [None]:
from extra.viz import plotly_viz

In [None]:
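from skimage.exposure import equalize_adapthist
# gamma-correct all channels, reorder channels (1, 2, 0) -> (R, G, B)
# and apply adaptive histogram equalization (CLAHE) for display
gamma = 0.3
img_g = img**gamma
mip_rgb = equalize_adapthist(np.dstack((img_g[:,:,1],
                                        img_g[:,:,2],
                                        img_g[:,:,0])))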

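The `np.dstack` call above maps channel 1 to red, channel 2 to green and channel 0 (the Hoechst/nuclei channel, per the `hoechst` cell) to blue. A small sketch isolating that reordering; `reorder_channels` is a hypothetical helper and the CLAHE step (`equalize_adapthist`) is left out:

```python
import numpy as np

def reorder_channels(img, order=(1, 2, 0)):
    """Stack channels of an (H, W, C) image so that img[..., order[k]]
    becomes output channel k (R, G, B for k = 0, 1, 2)."""
    return np.dstack([img[:, :, c] for c in order])

toy_img = np.zeros((4, 4, 3))
toy_img[:, :, 0] = 1.0          # e.g. the Hoechst (nuclei) channel
toy_rgb = reorder_channels(toy_img)
```
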
In [None]:
layout, cells = plotly_viz(mip_rgb, bb=bbox)

In [None]:
from extra.viz import plotly_predictions
ypred = np.zeros(len(bbox), dtype=int)  # placeholder: every cell gets class 0
labels = ['cells']
layout, cells = plotly_predictions(img=mip_rgb, bb=bbox,
                                  ypred=ypred, labels=labels)

In [None]:
#iplot(dict(data=cells, layout=layout))

In [None]:
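def get_train_instance(path, fname, pad=0):
    """Load the image stack `<fname>.tiff` and the cell table `<fname>.csv`
    from `path`; return the image and the bounding boxes (padded by `pad`)."""
    imgstack = load_imgstack(fname=os.path.join(path, fname + ".tiff"),
                             verbose=False)
    # remove the 'dummy' z-axis -> (rows, cols, channels)
    img = np.squeeze(imgstack)
    df = pd.read_csv(os.path.join(path, fname + ".csv"))
    rmax, cmax, _ = img.shape
    bbox = read_bbox(df=df, rmax=rmax, cmax=cmax, pad=pad)
    return img, bbox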

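`IncrementalClassifier` below works on `self.cellbb`, a list of per-cell image patches, which this notebook does not construct explicitly. A sketch of how such patches could be cut out of `img` with the boxes, assuming the half-open `(r0, c0, r1, c1)` layout; `crop_cells` is hypothetical:

```python
import numpy as np

def crop_cells(img, boxes):
    """Cut one patch per (r0, c0, r1, c1) box out of an (H, W, C) image."""
    return [img[r0:r1, c0:c1] for r0, c0, r1, c1 in boxes]

toy_stack = np.arange(5 * 5 * 3).reshape(5, 5, 3)
toy_patches = crop_cells(toy_stack, [[0, 0, 2, 3], [1, 1, 5, 5]])
```

In the class, each patch is then resized to a common `(self.w, self.h)` shape and flattened, so that every cell yields a feature vector of the same length.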
In [None]:
img, bbox = get_train_instance(path='../../data/AML_trainset/180528_Plate3',
                              fname='r02c14', pad=0)

**TODO: fix how `ImgX` is imported**

In [None]:
sys.path.append('../../../../')
from bioimg.classify import ImgX

In [None]:
imgx_test = ImgX(img=img**0.4, bbox=bbox)
imgx_test = imgx_test.compute_props(n_chan=3)

In [None]:
# imgx_test.data[0].iloc[:10,:12]

**Implement a class that holds an `ImgX` instance and can accumulate training data:**
+ How do we implement proper class composition, i.e. `IncrementalClassifier` **has an** `ImgX` (rather than inheriting from it)?

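Composition means the classifier keeps a reference to an `ImgX`-like object and calls into it, instead of subclassing it. A minimal sketch with hypothetical classes (`FeatureSource` stands in for `ImgX`, `LabelAccumulator` for the incremental classifier). The optional `__getattr__` forwards unknown attribute lookups to the wrapped object:

```python
class FeatureSource:
    """Hypothetical stand-in for ImgX: owns per-cell feature rows."""
    def __init__(self, n_cells):
        self.n_cells = n_cells
        self.data = None

    def compute_props(self):
        # toy features: one row per cell
        self.data = [[float(i), float(i) ** 2] for i in range(self.n_cells)]
        return self


class LabelAccumulator:
    """Has a FeatureSource (composition) instead of inheriting from it."""
    def __init__(self, source):
        self.source = source        # keep a reference to the wrapped object
        self.X, self.y = [], []

    def add_labels(self, pairs):
        # pairs: iterable of (cell index, class label)
        for idx, label in pairs:
            self.X.append(self.source.data[idx])
            self.y.append(label)
        return self

    def __getattr__(self, name):
        # only called for attributes not found on the accumulator itself
        if name == "source":        # avoid infinite recursion before __init__
            raise AttributeError(name)
        return getattr(self.source, name)


toy_acc = LabelAccumulator(FeatureSource(5).compute_props())
toy_acc = toy_acc.add_labels([(1, 0), (3, 2)])
```
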

In [None]:
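# third-party imports used by the class below
import h5py
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import cross_val_score
from skimage.transform import resize

# NOTE: `get_regionprop_feats`, `exclude`, `convert_well_name`, `fn` and `vi`
# are referenced below but are not defined in this notebook; they must be
# defined/imported at module level before the corresponding methods are called.


class IncrementalClassifier:
    def __init__(self, imgx):
        # composition: the classifier holds (has) an ImgX instance
        self.imgx = imgx
        # labels passed in the most recent call to add_instances
        self.newlabels = None
        # training data (pixel features, region-property features, labels)
        self.X_train_norm = None
        self.X_train_prop = None
        self.y_train = None
        # test features, cached by generate_predictions and cleared by reset
        self.X_test_norm = None
        self.X_test_prop = None
        self.y_pred = None
        # initialize the classifier and the PCA projection as 'None'
        self.clf = None
        self.pca = None

    # set arbitrary attributes by keyword (e.g. training arrays, patch size w, h)
    def set_param(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)
        return self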
            
    def set_classifier(self, n_components=150,
                         bootstrap=True,
                         max_depth=None,
                         n_estimators=100,
                         max_features='sqrt',
                         min_samples_split=5,
                         min_samples_leaf=5,
                         random_state=100):
        self.pca = PCA(n_components=n_components, svd_solver='randomized',
                       whiten=True, random_state=random_state).fit(self.X_train_norm)
        
        self.clf = RandomForestClassifier(bootstrap=bootstrap,
                                          class_weight="balanced",
                                          max_depth=max_depth,
                                          n_estimators=n_estimators,
                                          max_features=max_features,
                                          min_samples_split=min_samples_split,
                                          min_samples_leaf=min_samples_leaf,
                                          random_state=random_state,
                                          n_jobs=-1)
        return self
    
    def train_classifier(self):
        # project the train data
        X_train_pca = self.pca.transform(self.X_train_norm)
        X_train_all = np.append(X_train_pca, self.X_train_prop, axis=1)

        self.clf.fit(X_train_all, self.y_train)
        return self

    # generate predictions with the loaded feature data
    def generate_predictions(self, prob=False):
        if self.X_test_norm is None:
            cellbb_norm = [resize(cb, (self.w, self.h), anti_aliasing=True)
                           for cb in self.cellbb]
            self.X_test_norm = np.array([cbn.ravel() for cbn in cellbb_norm])

        if self.X_test_prop is None:
            X_prop_list = [get_regionprop_feats(
                mip_rgb=cbb, exclude=exclude) for cbb in self.cellbb]
            self.X_test_prop = np.vstack(X_prop_list)

        X_test_pca = self.pca.transform(self.X_test_norm)
        X_test_all = np.append(X_test_pca, self.X_test_prop, axis=1)

        self.y_pred = self.clf.predict(X_test_all)
        if prob:
            self.y_prob = self.clf.predict_proba(X_test_all)
        return self

    # set plotly graphical layers
    def set_scene(self):
        # derived object attributes
        wellpos = convert_well_name(self.select_well)

        mip = fn.get_mip(path=self.path, wellpos=wellpos)
        mip_rgb = equalize_adapthist(np.dstack((mip[:, :, 1],
                                                mip[:, :, 2],
                                                mip[:, :, 0])))
        self.img = mip_rgb
        if self.y_pred is not None:
            self.layout, self.feats = vi.plotly_predictions(img=mip_rgb,
                                                            bb=self.bb,
                                                            y_pred=self.y_pred,
                                                            target_names=self.target_names)
        else:
            self.layout, self.feats = vi.plotly_viz(img=mip_rgb,
                                                    bb=self.bb)
        return self

    # plot predictions overlaid with the original image
    # plot is a 'void' function (returns 'None')
    def plot(self):
        iplot(dict(data=self.feats, layout=self.layout))

    # update scene after refitting the pipeline
    def update_scene(self):
        self.feats = vi.update_feats(img=self.img,
                                     bb=self.bb,
                                     y_pred=self.y_pred,
                                     target_names=self.target_names)
        return self

    def add_instances(self, newlabels):
        if self.newlabels is None:
            self.newlabels = newlabels
        else:
            a1_rows = newlabels.view(
                [('', newlabels.dtype)] * newlabels.shape[1])
            a2_rows = self.newlabels.view(
                [('', self.newlabels.dtype)] * self.newlabels.shape[1])

            self.newlabels = newlabels
            newlabels = (np.setdiff1d(a1_rows, a2_rows).
                         view(newlabels.dtype).
                         reshape(-1, newlabels.shape[1]))

        if newlabels.size:
            cellbb_new = [self.cellbb[i] for i in newlabels[:, 0]]
            cellbb_new_norm = [resize(cb, (self.w, self.h), anti_aliasing=True)
                               for cb in cellbb_new]
            X_new_norm = np.array([cbn.ravel() for cbn in cellbb_new_norm])

            self.X_train_norm = np.concatenate((self.X_train_norm, X_new_norm))
            self.y_train = np.append(self.y_train, newlabels[:, 1])

            X_prop_list = [get_regionprop_feats(
                mip_rgb=cbb, exclude=exclude) for cbb in cellbb_new]
            X_new_prop = np.vstack(X_prop_list)
            self.X_train_prop = np.concatenate((self.X_train_prop, X_new_prop))
        return self

    # print the classification report and confusion matrix on the training set
    # (in-sample predictions, so these numbers are optimistic)
    def get_classification_report(self):
        X_train_pca = self.pca.transform(self.X_train_norm)
        X_train_all = np.append(X_train_pca, self.X_train_prop, axis=1)
        y_pred = self.clf.predict(X_train_all)
        print(classification_report(self.y_train,
                                    y_pred, target_names=self.target_names))
        print(confusion_matrix(self.y_train, y_pred,
                               labels=range(len(self.target_names))))
        return self

    def get_cross_val_score(self, kfold=5):
        X_train_pca = self.pca.transform(self.X_train_norm)
        X_train_all = np.append(X_train_pca, self.X_train_prop, axis=1)
        scores = cross_val_score(self.clf, X_train_all, self.y_train, cv=kfold)
        print(scores)

        return self

    def reset(self):
        self.newlabels = None
        self.X_test_norm = None
        self.X_test_prop = None
        return self

    # save the accumulated training data to an HDF5 file
    def h5_write(self, fname):
        with h5py.File(fname, 'w') as hf:
            hf.create_dataset('X_train_norm', data=self.X_train_norm)
            hf.create_dataset('X_train_prop', data=self.X_train_prop)
            hf.create_dataset('y_train', data=self.y_train)

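`train_classifier`, `generate_predictions` and the report methods all build the same design matrix: whitened PCA scores of the flattened, resized cell patches (`X_*_norm`), with the region-property features (`X_*_prop`) appended as extra columns. A NumPy-only sketch of that assembly on random data, with an exact SVD-based PCA standing in for scikit-learn's `PCA(whiten=True)` (its whitening should use the same per-component scaling, up to component signs); `pca_fit` and `pca_transform` are hypothetical helpers:

```python
import numpy as np

def pca_fit(X, n_components):
    """Center X and keep the top principal axes (computed via SVD)."""
    mean = X.mean(axis=0)
    _, s, vt = np.linalg.svd(X - mean, full_matrices=False)
    # std of each projected component on the training data, used for whitening
    scale = s[:n_components] / np.sqrt(X.shape[0] - 1)
    return mean, vt[:n_components], scale

def pca_transform(model, X):
    mean, components, scale = model
    return (X - mean) @ components.T / scale   # whitened projection

rng = np.random.default_rng(1)
toy_norm = rng.random((40, 64))   # 40 cells, flattened 8x8 patches
toy_prop = rng.random((40, 3))    # 3 region-property features per cell
toy_pca = pca_fit(toy_norm, n_components=10)
toy_all = np.append(pca_transform(toy_pca, toy_norm), toy_prop, axis=1)
```

Note that the PCA is fitted on the training patches only (`set_classifier`), and the same projection is then reused for test patches.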

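`add_instances` compares the incoming label array with the one from the previous call by viewing each row as a structured record and calling `np.setdiff1d`, which returns the unique new rows in sorted order. A plainer sketch of the same row-wise difference using Python tuples, which keeps the original row order; `new_rows` is hypothetical, and, like the class, it assumes each call passes the full cumulative label list:

```python
import numpy as np

def new_rows(current, previous):
    """Rows of `current` that do not appear in `previous` (both 2-D arrays)."""
    if previous is None:
        return current
    seen = {tuple(r) for r in previous.tolist()}
    kept = [r for r in current.tolist() if tuple(r) not in seen]
    return np.array(kept, dtype=current.dtype).reshape(-1, current.shape[1])

toy_prev = np.array([[45, 2], [91, 5]])
toy_cur = np.array([[45, 2], [91, 5], [85, 0], [2, 2]])
toy_new = new_rows(toy_cur, toy_prev)   # only the two rows not seen before
```
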
In [None]:
'''# old version of the function
def compute_props(self):
        X_prop_list =  [OT.get_regionprop_feats(mip_rgb=cbb,
                                                exclude=exclude) for cbb in cellbb_train]
        X_train_prop = np.vstack(X_prop_list)


        cellbb_norm = [resize(cb, (w, h), anti_aliasing=True) for cb in cellbb_train]
        X_train_norm = np.array([cbn.ravel() for cbn in cellbb_norm])
        # compute PCA of the image data set
        n_components = 150
        pca = PCA(n_components=n_components, svd_solver='randomized',
                  whiten=True).fit(X_train_norm)

        # project the train data
        X_train_pca = pca.transform(X_train_norm)

        X_train_all = np.append(X_train_pca, X_train_prop, axis=1)
        
        return self'''

**TODO: adapt the `IncrementalClassifier` class to our use case** (the cells below still call the older `OT.IncrementalClassifier` interface)

In [None]:
# incremental ("online") classifier
clf_incr = OT.IncrementalClassifier(path=path, featdir=featdir,
                                    select_well=select_inst[0],
                                    target_names=target_names,
                                    X_train_norm=X_train_norm,
                                    X_train_prop=X_train_prop,
                                    y_train=y_train)

In [None]:
clf_incr = (clf_incr.load_img()
            .train_classifier()
            .generate_predictions()
            .set_scene())

In [None]:
#clf_incr.plot()

In [None]:
# each row: [bounding-box (cell) index, corrected class label]
newlabels = np.array([[45, 2], [91, 5], [85, 0], [2, 2]])

In [None]:
clf_incr = (clf_incr
            .add_instances(newlabels=newlabels)
            .train_classifier()
            .generate_predictions()
            .update_scene())