## Validate Object Detection on the Test Set

In [None]:
# load Python modules (standard library and third-party)
import javabridge
import bioformats as bf
import skimage
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sn
import pandas as pd
import os
import sys
import h5py
# Add the parent directory to the module search path; presumably this is what
# makes the project-local `segment` package importable -- TODO confirm.
sys.path.append('..')

# Start the Java VM with the Bio-Formats jars on its class path.
# NOTE(review): nothing in the visible cells uses javabridge/bioformats after
# this call, and no matching javabridge.kill_vm() is visible -- confirm whether
# the JVM is needed for this notebook at all.
javabridge.start_vm(class_path=bf.JARS)

In [None]:
# Directory holding one `<image>.csv` file per test image; later cells read
# these files as the per-image detection tables.
# NOTE(review): absolute, machine-specific path -- adjust before running on
# another machine.
path = '/Users/vlad/Documents/embl/gitlab/microscopy/data/detect/test/'

In [None]:
# Base names (file name minus the `.csv` extension) of all CSV files in `path`.
# Only names that really END in `.csv` are accepted, so stray files such as
# `name.csv~` or `name.csv.bak` no longer yield bogus image names, and the
# list is sorted because os.listdir() returns entries in arbitrary order.
imgs = sorted(f[:-len('.csv')] for f in os.listdir(path) if f.endswith('.csv'))

In [None]:
# Display how many CSV files (one per image) were found in `path`.
len(imgs)

In [None]:
# Ground-truth label table for the test set. Later cells filter it by its
# `filename` column and read its `class` and bounding-box columns.
# NOTE(review): this path is relative, so it is resolved against the current
# working directory -- confirm the notebook is started from the right folder.
test_gtruth = pd.read_csv('data/AML_trainset/test_labels.csv')

In [None]:
from collections import namedtuple

# Axis-aligned rectangle; positional field order is (xmin, xmax, ymin, ymax).
Box = namedtuple('Box', ['xmin', 'xmax', 'ymin', 'ymax'])


def area_overlap(a, b):
    """Return the area shared by two axis-aligned rectangles.

    Parameters
    ----------
    a, b : Box
        Rectangles exposing ``xmin``, ``xmax``, ``ymin`` and ``ymax``.

    Returns
    -------
    number or None
        Width times height of the intersection. Rectangles that merely touch
        along an edge or corner give ``0``; rectangles that do not intersect
        at all give ``None``.
    """
    overlap_w = min(a.xmax, b.xmax) - max(a.xmin, b.xmin)
    overlap_h = min(a.ymax, b.ymax) - max(a.ymin, b.ymin)
    # A negative extent along either axis means the rectangles are disjoint.
    if not (overlap_w >= 0 and overlap_h >= 0):
        return None
    return overlap_w * overlap_h

In [None]:
# Pick one image (the first entry of `imgs`) for the step-by-step cells below
# and load its CSV from `path`.
i=0
df = pd.read_csv(os.path.join(path, imgs[i] + ".csv"))

In [None]:
from segment.tools import read_bbox

# Boxes for the selected image, built by `read_bbox` from the rows of the
# label table whose `filename` equals "<image name>.png".
bbox_gt = read_bbox(
    df=test_gtruth[test_gtruth.filename == imgs[i] + '.png'],
    columns=['ymin', 'xmin', 'ymax', 'xmax'],
    rmax=720,
    cmax=720,
    pad=0,
)

In [None]:
# Boxes for the selected image, built by `read_bbox` from its own CSV table
# with the same settings as the ground-truth boxes.
bbox = read_bbox(
    df=df,
    columns=['ymin', 'xmin', 'ymax', 'xmax'],
    rmax=720,
    cmax=720,
    pad=0,
)

In [None]:
# Ground-truth class labels of the selected image, recoded in place to the
# integer codes used by the later cells: 'apoptotic AML' -> 1, 'viable AML' -> 2.
# The array is copied explicitly before it is modified: with pandas
# Copy-on-Write the array returned by `.values` is read-only, so assigning
# into it directly raises an error.
y_gt = test_gtruth[test_gtruth.filename == imgs[i] + '.png']['class'].values.copy()
y_gt[y_gt == 'apoptotic AML'] = 1
y_gt[y_gt == 'viable AML'] = 2

In [None]:
# Display the first ground-truth box returned by read_bbox.
bbox_gt[0]

In [None]:
# Attach a ground-truth class `y` to the boxes of the selected image.
#
# Every (ground-truth box, detected box) pair whose overlap area exceeds 200
# is recorded; each detected box then keeps the label of the ground-truth box
# it overlaps most. Detected boxes without any such pair keep NaN in `y`.
# NOTE(review): the rows returned by read_bbox are unpacked positionally into
# Box(xmin, xmax, ymin, ymax) and into the columns of `label_df` -- this
# assumes read_bbox yields coordinates in exactly that order; TODO confirm.
df_list = []
for bt, y in zip(bbox_gt, y_gt):
    gt_box = Box(*bt)  # built once per ground-truth box, not once per pair
    for b in bbox:
        A_common = area_overlap(gt_box, Box(*b))
        # area_overlap returns None for disjoint boxes.
        if A_common is not None and A_common > 200:
            label_df = pd.DataFrame(data=b[None, ...],
                                    columns=['xmin', 'xmax', 'ymin', 'ymax'])
            label_df['y'] = y
            label_df['Acom'] = A_common
            df_list.append(label_df)
if df_list:
    # Keep exactly one row per detected box: the one with the largest overlap.
    # Sorting and dropping duplicates (instead of groupby/apply with
    # `Acom == Acom.max()`) avoids duplicated detections when two ground-truth
    # boxes tie, and does not depend on groupby.apply passing the key columns
    # through.
    df_unique = (pd.concat(df_list, ignore_index=True)
                 .sort_values('Acom', ascending=False, kind='mergesort')
                 .drop_duplicates(subset=['ymin', 'xmin', 'ymax', 'xmax'])
                 .reset_index(drop=True))
    img_df = pd.merge(left=df, right=df_unique,
                      on=['ymin', 'xmin', 'ymax', 'xmax'], how='left')
else:
    # No pair passed the overlap threshold: every detected box is unmatched.
    img_df = df.copy()
    img_df['y'] = np.nan
img_df = img_df[['ymin', 'xmin', 'ymax', 'xmax', 'y', 'class', 'prob']]

In [None]:
# Build one labelled detection table per image and collect them in `pred_df`.
#
# For every image: load its detection CSV, load its ground-truth boxes and
# classes, pair boxes whose overlap area exceeds 200, and give each detected
# box the class of the ground-truth box it overlaps most. Detected boxes with
# no such pair keep NaN in `y`.
# NOTE(review): ground-truth boxes that no detection overlaps never appear in
# the result, so missed objects are not represented in the metrics below.
# NOTE(review): the rows returned by read_bbox are unpacked positionally into
# Box(xmin, xmax, ymin, ymax) and into the columns of `label_df` -- this
# assumes read_bbox yields coordinates in exactly that order; TODO confirm.
pred_df = []
for img_name in imgs:
    df = pd.read_csv(os.path.join(path, img_name + ".csv"))
    gt_rows = test_gtruth[test_gtruth.filename == img_name + '.png']
    bbox_gt = read_bbox(df=gt_rows,
                        columns=['ymin', 'xmin', 'ymax', 'xmax'],
                        rmax=720,
                        cmax=720, pad=0)
    bbox = read_bbox(df=df,
                     columns=['ymin', 'xmin', 'ymax', 'xmax'],
                     rmax=720,
                     cmax=720, pad=0)
    # Copy before recoding in place: with pandas Copy-on-Write the array
    # returned by `.values` is read-only.
    y_gt = gt_rows['class'].values.copy()
    y_gt[y_gt == 'apoptotic AML'] = 1
    y_gt[y_gt == 'viable AML'] = 2

    df_list = []
    for bt, y in zip(bbox_gt, y_gt):
        gt_box = Box(*bt)  # built once per ground-truth box
        for b in bbox:
            A_common = area_overlap(gt_box, Box(*b))
            # area_overlap returns None for disjoint boxes.
            if A_common is not None and A_common > 200:
                label_df = pd.DataFrame(data=b[None, ...],
                                        columns=['xmin', 'xmax', 'ymin', 'ymax'])
                label_df['y'] = y
                label_df['Acom'] = A_common
                df_list.append(label_df)

    if df_list:
        # One row per detected box: the pair with the largest overlap. This
        # replaces groupby/apply with `Acom == Acom.max()`, which kept several
        # rows on ties and therefore duplicated detections in the merge.
        df_unique = (pd.concat(df_list, ignore_index=True)
                     .sort_values('Acom', ascending=False, kind='mergesort')
                     .drop_duplicates(subset=['ymin', 'xmin', 'ymax', 'xmax'])
                     .reset_index(drop=True))
        img_df = pd.merge(left=df, right=df_unique,
                          on=['ymin', 'xmin', 'ymax', 'xmax'], how='left')
    else:
        # Previously such images were skipped entirely, which silently removed
        # all of their (unmatched) detections from the evaluation.
        img_df = df.copy()
        img_df['y'] = np.nan
    img_df = img_df[['ymin', 'xmin', 'ymax', 'xmax', 'y', 'class', 'prob']]
    pred_df.append(img_df)

In [None]:
# Stack the per-image tables into a single DataFrame.
# ignore_index=True gives every detection its own row label; without it the
# 0..n-1 index of each image is repeated, which makes label-based selection
# and alignment on the combined table ambiguous.
# NOTE: this rebinds `pred_df` from a list to a DataFrame, so the cell cannot
# be re-run without first re-running the cell that builds the list.
pred_df = pd.concat(pred_df, ignore_index=True)

In [None]:
# Score for class 1 (apoptotic): `prob` where class 1 was predicted,
# otherwise the complement 1 - prob.
pred_df['p_apoptotic'] = np.where(pred_df['class'] == 1,
                                  pred_df['prob'],
                                  1.0 - pred_df['prob'])

In [None]:
# Score for class 2 (viable): `prob` where class 2 was predicted,
# otherwise the complement 1 - prob.
pred_df['p_viable'] = np.where(pred_df['class'] == 2,
                               pred_df['prob'],
                               1.0 - pred_df['prob'])

In [None]:
# Detections that were not paired with a ground-truth box carry NaN in `y`
# (they come from the left merge); recode them as class 0 so that `y` only
# holds the codes 0, 1 and 2 used by the metric cells below.
pred_df['y'] = pred_df['y'].fillna(value=0)

In [None]:
# Preview the first 20 rows of the combined detection table.
pred_df.head(20)

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
# Confusion matrix of assigned ground-truth class (`y`, passed as y_true)
# against predicted class (`class`, passed as y_pred).
confusion_matrix(pred_df['y'], pred_df['class'])

In [None]:
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_score
from sklearn.preprocessing import label_binarize

In [None]:
# One-hot encode the ground-truth codes: column k of `y_test` is 1 where y == k.
# `classes` must be passed by keyword -- it is a keyword-only parameter in
# current scikit-learn releases, so the positional form raises a TypeError.
y_test = label_binarize(pred_df['y'].values, classes=[0, 1, 2])
# Per-class scores; column 0 belongs to class 1 and column 1 to class 2.
probas_ = pred_df[['p_apoptotic', 'p_viable']].values

In [None]:
# Per-class ROC and precision-recall statistics, keyed by class code (1, 2).
fpr, tpr, roc_auc = {}, {}, {}
precision, recall, avprec = {}, {}, {}
for i in (1, 2):
    # Class i is column i of the one-hot truth and column i-1 of the scores.
    truth = y_test[:, i]
    scores = probas_[:, i - 1]
    fpr[i], tpr[i], _ = roc_curve(truth, scores)
    roc_auc[i] = auc(fpr[i], tpr[i])
    precision[i], recall[i], _ = precision_recall_curve(truth, scores)
    avprec[i] = average_precision_score(truth, scores)

In [None]:
import matplotlib

# Global font settings for the figures below.
# 'normal' is not a font family matplotlib can resolve (it only triggers a
# "font family not found" warning and a fallback to the default font), so the
# generic 'sans-serif' family is requested explicitly instead.
font = {'family': 'sans-serif',
        'size': 14}

matplotlib.rc('font', **font)

In [None]:
from itertools import cycle

# Two line colours, one per class. `colors` is an endless iterator that the
# ROC cell consumes as well.
colors = cycle(['#27496d', '#63b7af'])
# Indexed by class code; code 0 has no display name.
class_names = ['', 'Apoptotic AML', 'Viable AML']

plt.figure(figsize=(7, 7))

# Iso-F1 guide curves: for a fixed F1 value f, precision as a function of
# recall x is f * x / (2 * x - f).
f_scores = np.linspace(0.7, 0.96, num=5)
x = np.linspace(0.01, 1)  # recall grid, identical for every curve
for f_score in f_scores:
    y_ = f_score * x / (2 * x - f_score)
    # Only the non-negative branch of the curve is meaningful.
    plt.plot(x[y_ >= 0], y_[y_ >= 0], color='gray', alpha=0.2)
    # Two decimals are needed here: with one decimal the values 0.765 and
    # 0.83 were both labelled "0.8" and 0.96 was labelled "1.0".
    plt.annotate('F1={0:0.2f}'.format(f_score), xy=(0.9, y_[45] + 0.02))

# Precision-recall curve per class, labelled with its average precision.
for i, color in zip(range(1, 3), colors):
    plt.plot(recall[i], precision[i], color=color, lw=4,
             label='{0} (AP = {1:0.2f})'.format(class_names[i], avprec[i]))

plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-recall object detection')
plt.legend(loc="lower left")
#plt.savefig('PR-objdetect.pdf')

In [None]:
# ROC curve per class, labelled with its area under the curve.
plt.figure(figsize=(7, 7))
# `range` comes first in zip so that exactly two colours are drawn from the
# shared `colors` iterator.
for i, color in zip(range(1, 3), colors):
    roc_label = '{0} (AUCROC = {1:0.2f})'.format(class_names[i], roc_auc[i])
    plt.plot(fpr[i], tpr[i], color=color, lw=4, label=roc_label)

# Dashed diagonal from (0, 0) to (1, 1) as a visual reference.
plt.plot([0, 1], [0, 1], 'k--', lw=2)

plt.title('ROC object detection')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.legend(loc="lower right")
#plt.savefig('ROC-objdetect.pdf')