In [147]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage.filters import laplace
from skimage import feature
from itertools import chain

from load import *

# Default size for every figure in this notebook; matplotlib reads
# figure.figsize as (width, height) in inches.
plt.rcParams['figure.figsize'] = (16, 16)

In [152]:
class TestSet(object):
    """Image stacks and their ROI labels, loaded from disk for scoring.

    Attributes set by the constructor:
      data   -- list holding one ``load_stack`` result per experiment
      labels -- list holding one ``load_rois`` result per experiment,
                in the same order as ``data``
    """

    def __init__(self, experiments=(3, 4), data_dir='../data'):
        """Load the stack and ROI archive of every requested experiment.

        experiments -- iterable of experiment numbers ``j``; for each one the
            files ``<data_dir>/AMG3_exp<j>.tif`` and ``<data_dir>/AMG3_exp<j>.zip``
            are read. The default reproduces the previously hard-coded
            ``range(3, 5)``.
        data_dir -- directory containing the experiment files.
        """
        self.data = []
        self.labels = []
        for j in experiments:
            prefix = '%s/AMG3_exp%d' % (data_dir, j)
            self.data.append(load_stack(prefix + '.tif'))
            # 512, 512 are forwarded to load_rois unchanged -- presumably the
            # image dimensions of the stacks; TODO confirm against load.py.
            self.labels.append(load_rois(prefix + '.zip', 512, 512))

def calc_f1_score(precision_recall):
    """Return the F1 score (harmonic mean) of a precision/recall pair.

    precision_recall -- a 2-tuple ``(precision, recall)``. It is still passed
        as a single argument, exactly as before; the unpacking now happens in
        the body because tuple parameters in a signature are a syntax error
        on Python 3.

    Returns 0.0 when ``precision + recall`` is zero instead of raising
    ZeroDivisionError, so a stack without any matched ROI can still be scored.
    """
    precision, recall = precision_recall
    denominator = precision + recall
    if denominator == 0:
        return 0.0
    # 2.0 forces true division even if both inputs are ints on Python 2.
    return (2.0 * precision * recall) / denominator

def score(alg):
    """Load the test set, evaluate predictions against its labels, print metrics.

    alg -- intended to be a class whose instances offer ``predict(data)``.
        NOTE: it is currently unused; the ground-truth labels are scored
        against themselves (see the TODO below), so every metric comes out
        perfect.

    Prints four lines, each "per-stack values, overall value": precision,
    recall, F1 score and overlap boundary quality. Returns nothing.
    """
    test_set = TestSet()
    # TODO: enable the real predictions once the algorithm is wired in:
    #predictions = alg().predict(test_set.data)
    predictions = test_set.labels

    # Sanity checks: one prediction array per labelled stack, matching in all
    # dimensions after the ROI axis, and containing only 0/1 values.
    assert len(predictions) == len(test_set.labels)
    assert all(pred.shape[1:] == true.shape[1:]
               for pred, true in zip(predictions, test_set.labels))
    assert all(np.all(np.logical_or(pred == 1, pred == 0)) for pred in predictions)

    categorized = categorize(predictions, test_set.labels)
    precisions, total_precision, recalls, total_recall = calc_precision_recall(categorized)
    # A list (not map) so the values print properly on Python 3 as well.
    f1_scores = [calc_f1_score(pair) for pair in zip(precisions, recalls)]
    total_f1_score = calc_f1_score((total_precision, total_recall))
    overlap_bqs, total_overlap_bq = overlap_boundary_quality(categorized)

    # "%s %s" reproduces the output of the old two-item print statement and is
    # valid on both Python 2 and Python 3.
    print("%s %s" % (precisions, total_precision))
    print("%s %s" % (recalls, total_recall))
    print("%s %s" % (f1_scores, total_f1_score))
    print("%s %s" % (overlap_bqs, total_overlap_bq))

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator as a float, or 0.0 if the denominator is 0."""
    if denominator == 0:
        return 0.0
    return numerator / float(denominator)


def calc_precision_recall(categorized):
    """Compute detection precision and recall per stack and over all stacks.

    categorized -- list with one dict per stack, each holding the lists
        "pairs" (matched ROIs), "fps" (false positives) and "fns"
        (false negatives); only the list lengths are used.

    Returns ``(precisions, total_precision, recalls, total_recall)`` where
    precision = pairs / (pairs + fps) and recall = pairs / (pairs + fns).
    The per-stack values are lists; the totals pool the counts of all stacks.
    A ratio with a zero denominator (e.g. a stack without any prediction) is
    reported as 0.0 instead of raising ZeroDivisionError.
    """
    num_fps = [len(stack["fps"]) for stack in categorized]
    num_fns = [len(stack["fns"]) for stack in categorized]
    num_pairs = [len(stack["pairs"]) for stack in categorized]

    precisions = [_safe_ratio(tp, tp + fp) for tp, fp in zip(num_pairs, num_fps)]
    recalls = [_safe_ratio(tp, tp + fn) for tp, fn in zip(num_pairs, num_fns)]

    total_pairs = sum(num_pairs)
    total_precision = _safe_ratio(total_pairs, total_pairs + sum(num_fps))
    total_recall = _safe_ratio(total_pairs, total_pairs + sum(num_fns))
    return precisions, total_precision, recalls, total_recall
    
def categorize(predictions, labels, min_overlap=0.5):
    """Greedily match predicted ROIs to ground-truth ROIs, stack by stack.

    predictions, labels -- sequences of equal length; element ``i`` of each is
        copied and iterated along its first axis to obtain the individual ROI
        masks of stack ``i``.
    min_overlap -- a prediction is paired with its best remaining ground-truth
        ROI only if their overlap (first value returned by ``calc_overlap``)
        is strictly greater than this value. Defaults to the previously
        hard-coded 0.5.

    Returns a list with one dict per stack:
      "pairs" -- list of ``(roi_pred, roi_true)`` tuples that were matched
      "fps"   -- predictions without a match (false positives)
      "fns"   -- ground-truth ROIs no prediction was matched to (false negatives)
    Each ground-truth ROI is used in at most one pair.
    """
    categorized = []
    for i in range(len(predictions)):
        result = {"fps": [], "fns": [], "pairs": []}
        rois_pred = list(predictions[i].copy())
        rois_true = list(labels[i].copy())
        for roi_pred in rois_pred:
            overlaps = [calc_overlap(roi_pred, roi_true)[0] for roi_true in rois_true]
            # With no ground-truth ROI left, np.argmax would raise on the empty
            # list; such a prediction is simply a false positive.
            best_index = int(np.argmax(overlaps)) if overlaps else None
            if best_index is not None and overlaps[best_index] > min_overlap:
                # Remove the match by position: list.remove() would compare
                # arrays with ==, which raises "truth value of an array is
                # ambiguous" unless the match happens to be the first element.
                result["pairs"].append((roi_pred, rois_true.pop(best_index)))
            else:
                result["fps"].append(roi_pred)
        # Whatever ground truth is still unmatched was missed by the predictions.
        result["fns"].extend(rois_true)
        categorized.append(result)
    return categorized
        
def calc_overlap(roi_pred, roi_true):
    """Measure how well a predicted mask overlaps a ground-truth mask.

    roi_pred, roi_true -- array-likes of equal shape; every non-zero entry
        counts as part of the mask.

    Returns ``(general, precision, recall)``:
      general   -- intersection / union
      precision -- intersection / number of pixels in roi_pred
      recall    -- intersection / number of pixels in roi_true
    Any ratio whose denominator is zero is reported as 0.0, so an empty mask
    never produces nan or a division warning.
    """
    # Work on boolean masks so the areas are counted the same way as the
    # intersection and union (identical to summing for 0/1 inputs).
    pred_mask = np.asarray(roi_pred, dtype=bool)
    true_mask = np.asarray(roi_true, dtype=bool)

    intersection = np.sum(np.logical_and(pred_mask, true_mask))
    union = np.sum(np.logical_or(pred_mask, true_mask))
    if union == 0:
        # Both masks are empty.
        return 0.0, 0.0, 0.0

    pred_area = np.sum(pred_mask)
    true_area = np.sum(true_mask)
    # union > 0 only guarantees that one of the masks is non-empty, so each
    # area still needs its own zero check.
    precision = intersection / float(pred_area) if pred_area else 0.0
    recall = intersection / float(true_area) if true_area else 0.0
    general = intersection / float(union)
    return general, precision, recall
    
def overlap_boundary_quality(categorized):
    """Summarise the pixel-level precision/recall of the matched ROI pairs.

    categorized -- list with one dict per stack; only its "pairs" list of
        ``(roi_pred, roi_true)`` tuples is read.

    Returns ``(qualities, overall)``. ``qualities`` holds one dict per stack
    and ``overall`` one dict over the pairs of all stacks pooled together;
    every dict has the keys "mean precision", "std precision", "mean recall"
    and "std recall", computed from the second and third value that
    ``calc_overlap`` returns for each pair.
    """
    def summarize(precision_values, recall_values):
        # The same four statistics are needed per stack and for the pooled values.
        return {"mean precision": np.mean(precision_values),
                "std precision": np.std(precision_values),
                "mean recall": np.mean(recall_values),
                "std recall": np.std(recall_values)}

    stack_precisions = []
    stack_recalls = []
    for stack in categorized:
        # Drop the first (general overlap) value; keep (precision, recall).
        pair_scores = [calc_overlap(roi_pred, roi_true)[1:]
                       for roi_pred, roi_true in stack["pairs"]]
        stack_precisions.append([precision for precision, _ in pair_scores])
        stack_recalls.append([recall for _, recall in pair_scores])

    qualities = [summarize(precision_values, recall_values)
                 for precision_values, recall_values in zip(stack_precisions, stack_recalls)]

    pooled_precisions = [value for values in stack_precisions for value in values]
    pooled_recalls = [value for values in stack_recalls for value in values]
    return qualities, summarize(pooled_precisions, pooled_recalls)

In [153]:
# Run the scoring pipeline and print the metrics.
# NOTE(review): score() currently ignores its argument and scores the
# ground-truth labels against themselves, which is why every value printed
# below is 1.0. TestAlg is not defined in the visible cells -- presumably it
# comes from `from load import *`; verify.
score(TestAlg)

[1.0, 1.0] 1.0
[1.0, 1.0] 1.0
[1.0, 1.0] 1.0
[{'std recall': 0.0, 'mean precision': 1.0, 'std precision': 0.0, 'mean recall': 1.0}, {'std recall': 0.0, 'mean precision': 1.0, 'std precision': 0.0, 'mean recall': 1.0}] {'std recall': 0.0, 'mean precision': 1.0, 'std precision': 0.0, 'mean recall': 1.0}


In [68]:
# Two toy 2x3x3 arrays for trying out categorize(): both mark position
# (0, 0, 0); in the second slice m marks (1, 1, 1) while n marks (1, 2, 2).
m = np.zeros((2, 3, 3))
m[[0, 1], [0, 1], [0, 1]] = 1   # sets (0, 0, 0) and (1, 1, 1)
n = np.zeros_like(m)
n[[0, 1], [0, 2], [0, 2]] = 1   # sets (0, 0, 0) and (1, 2, 2)

In [69]:
# Show both toy arrays, two blank lines, then the categorisation result.
# Single-argument print(...) and print("") produce the same output on
# Python 2 and Python 3, unlike the print statement (a SyntaxError on 3).
print(m)
print(n)
print("")
print("")
# NOTE: m and n are single 3-D arrays, so categorize() treats each 3x3 slice
# as a "stack" and each 3-element row of a slice as one ROI (see the output).
categorize(m, n)

[[[ 1.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]]

 [[ 0.  0.  0.]
  [ 0.  1.  0.]
  [ 0.  0.  0.]]]
[[[ 1.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]]

 [[ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  1.]]]






[{'fns': [array([ 0.,  0.,  0.]), array([ 0.,  0.,  0.])],
  'fps': [array([ 0.,  0.,  0.]), array([ 0.,  0.,  0.])],
  'pairs': [(array([ 1.,  0.,  0.]), array([ 1.,  0.,  0.]))]},
 {'fns': [array([ 0.,  0.,  0.]),
   array([ 0.,  0.,  0.]),
   array([ 0.,  0.,  1.])],
  'fps': [array([ 0.,  0.,  0.]),
   array([ 0.,  1.,  0.]),
   array([ 0.,  0.,  0.])],
  'pairs': []}]