
Commit

Merge 03698f2 into 2d01948
ngreenwald committed Apr 7, 2020
2 parents 2d01948 + 03698f2 commit d2ab686
Showing 3 changed files with 408 additions and 86 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -12,7 +12,7 @@ python:
- 3.7

env:
- TF_VERSION=1.14.0
- TF_VERSION=1.14.0
- TF_VERSION=1.15.0

install:
206 changes: 174 additions & 32 deletions deepcell/metrics.py
@@ -41,10 +41,8 @@
from __future__ import print_function
from __future__ import division

from collections import Counter
import datetime
import decimal
import glob
import json
import operator
import os
@@ -53,16 +51,17 @@
import pandas as pd
import networkx as nx

from scipy.optimize import linear_sum_assignment
import matplotlib as mpl
import matplotlib.pyplot as plt

import skimage.io
from scipy.optimize import linear_sum_assignment
from skimage.measure import regionprops
from skimage.segmentation import relabel_sequential
from skimage.external.tifffile import TiffFile
from sklearn.metrics import confusion_matrix
from tensorflow.python.platform import tf_logging as logging

from deepcell.utils.compute_overlap import compute_overlap
from deepcell_toolbox import erode_edges


def stats_pixelbased(y_true, y_pred):
@@ -150,21 +149,21 @@ class ObjectAccuracy(object): # pylint: disable=useless-object-inheritance
analysis during testing
seg (:obj:`bool`, optional): Calculates SEG score for cell tracking
competition
force_event_links (:obj:`bool`, optional): Flag that determines whether to modify the IOU
calculation so that merge or split events with cells of very different sizes are
never misclassified as misses/gains.
Raises:
ValueError: If y_true and y_pred are not the same shape
Warning:
Position indices are not currently collected appropriately
"""
# TODO: Implement recording of object indices for each error group
def __init__(self,
y_true,
y_pred,
cutoff1=0.4,
cutoff2=0.1,
test=False,
seg=False):
seg=False,
force_event_links=False):
self.cutoff1 = cutoff1
self.cutoff2 = cutoff2
self.seg = seg
@@ -211,14 +210,19 @@ def __init__(self,
self.gained_indices = {}
self.gained_indices['y_pred'] = []

self.merge_indices = {}
self.merge_indices['y_true'] = []
self.merge_indices = {
'y_true': [],
'y_pred': []
}

self.split_indices = {}
self.split_indices['y_true'] = []
self.split_indices = {
'y_true': [],
'y_pred': []
}

self.catastrophe_indices = {}
self.catastrophe_indices['y_true'] = []
self.catastrophe_indices = {
'y_true': []
}
self.catastrophe_indices['y_pred'] = []

# Check if either frame is empty before proceeding
@@ -233,6 +237,7 @@ def __init__(self,
elif test is False:
self.empty_frame = False
self._calc_iou()
self._modify_iou(force_event_links)
self._make_matrix()
self._linear_assignment()
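
For orientation, a minimal usage sketch of the constructor with the new force_event_links flag. The toy masks are invented here, and the sketch assumes deepcell is installed with this version of metrics.py:

import numpy as np
from deepcell.metrics import ObjectAccuracy

# Toy 2D label masks: the ground truth has two adjacent cells,
# while the prediction covers both with a single label.
y_true = np.zeros((20, 20), dtype=int)
y_true[2:10, 2:18] = 1
y_true[10:18, 2:18] = 2
y_pred = np.zeros((20, 20), dtype=int)
y_pred[2:18, 2:18] = 1

# force_event_links=True triggers the IOU modification introduced in
# this commit, so size-mismatched merge/split partners stay linked.
o = ObjectAccuracy(y_true, y_pred, cutoff1=0.4, cutoff2=0.1,
                   force_event_links=True)
o.print_report()
print(o.save_to_dataframe())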

@@ -293,6 +298,49 @@ def get_box_labels(images):
(intersection.sum() > 0.5 * np.sum(self.y_true == index)):
self.seg_thresh[iou_y_true_idx - 1, iou_y_pred_idx - 1] = 1

def _modify_iou(self, force_event_links):
"""Modifies the IOU matrix to boost the value for small cells.
Args:
force_event_links (:obj:`bool`): flag that determines whether to modify IOU values of
large cells if a small cell has been split or merged with them.
"""

# identify cells that have matches in IOU but may be too small
true_labels, pred_labels = np.where(np.logical_and(self.iou > 0,
self.iou < (1 - self.cutoff1)))

self.iou_modified = self.iou.copy()

for idx in range(len(true_labels)):
# add 1 to get back to original label id
true_label, pred_label = true_labels[idx] + 1, pred_labels[idx] + 1
true_mask = self.y_true == true_label
pred_mask = self.y_pred == pred_label

# fraction of true cell that is contained within pred cell, and vice versa
true_in_pred = np.sum(self.y_true[pred_mask] == true_label) / np.sum(true_mask)
pred_in_true = np.sum(self.y_pred[true_mask] == pred_label) / np.sum(pred_mask)

iou_val = self.iou[true_label - 1, pred_label - 1]
max_val = np.max([true_in_pred, pred_in_true])

# if this cell has a small IOU due to its small size,
# but is at least half contained within the big cell,
# we bump its IOU value up so it doesn't get dropped from the graph
if iou_val <= self.cutoff1 and max_val > 0.5:
self.iou_modified[true_label - 1, pred_label - 1] = self.cutoff2

# optionally, we can also decrease the IOU value of the cell that
# swallowed up the small cell so that it doesn't directly match a different cell
if force_event_links:
if true_in_pred > 0.5:
fix_idx = np.where(self.iou[:, pred_label - 1] > 1 - self.cutoff1)
self.iou_modified[fix_idx, pred_label - 1] = 1 - self.cutoff1 - 0.01
elif pred_in_true > 0.5:
fix_idx = np.where(self.iou[true_label - 1, :] > 1 - self.cutoff1)
self.iou_modified[true_label - 1, fix_idx] = 1 - self.cutoff1 - 0.01
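
As a rough numeric illustration of the rule above, here is a standalone sketch with made-up overlap values; it mirrors the thresholds rather than calling the class:

import numpy as np

cutoff1, cutoff2 = 0.4, 0.1

# A small true cell sits almost entirely inside a large predicted cell:
# the raw IOU is tiny because of the size mismatch, but 90% of the small
# cell's area is covered by the prediction.
iou_val = 0.05           # would normally fall out of the graph
true_in_pred = 0.90      # fraction of the true cell inside the prediction
pred_in_true = 0.03      # fraction of the prediction inside the true cell

if iou_val <= cutoff1 and np.max([true_in_pred, pred_in_true]) > 0.5:
    iou_val = cutoff2    # bumped up so the pair is kept for assignment

# With force_event_links, another true cell that matches the same large
# prediction almost perfectly is pushed just under the 1 - cutoff1 line,
# so that prediction cannot also claim a clean 1:1 match.
other_iou = 0.80
if true_in_pred > 0.5 and other_iou > 1 - cutoff1:
    other_iou = 1 - cutoff1 - 0.01

print(iou_val, other_iou)  # 0.1 0.59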

def _make_matrix(self):
"""Assembles cost matrix using the iou matrix and cutoff1
@@ -306,8 +354,8 @@ def _make_matrix(self):
self.cm = np.ones((self.n_obj, self.n_obj))

# Assign 1 - iou to top left and bottom right
self.cm[:self.n_true, :self.n_pred] = 1 - self.iou
self.cm[-self.n_pred:, -self.n_true:] = 1 - self.iou.T
self.cm[:self.n_true, :self.n_pred] = 1 - self.iou_modified
self.cm[-self.n_pred:, -self.n_true:] = 1 - self.iou_modified.T

# Calculate diagonal corners
bl = self.cutoff1 * \
@@ -339,8 +387,8 @@ def _linear_assignment(self):
# Identify direct matches as true positives
correct_index = np.where(self.cm_res[:self.n_true, :self.n_pred] == 1)
self.correct_detections += len(correct_index[0])
self.correct_indices['y_true'].append(correct_index[0])
self.correct_indices['y_pred'].append(correct_index[1])
self.correct_indices['y_true'] += list(correct_index[0] + 1)
self.correct_indices['y_pred'] += list(correct_index[1] + 1)

# Calc seg score for true positives if requested
if self.seg is True:
@@ -366,9 +414,9 @@ def _assign_loners(self):

for i, t in enumerate(self.loners_true):
for j, p in enumerate(self.loners_pred):
self.cost_l[i, j] = self.iou[t, p]
self.cost_l[i, j] = self.iou_modified[t, p]

self.cost_l_bin = self.cost_l > self.cutoff2
self.cost_l_bin = self.cost_l >= self.cutoff2

def _array_to_graph(self):
"""Transform matrix for unassigned cells into a graph object
@@ -422,8 +470,8 @@ def _classify_graph(self):
# Get the highest degree node
k = max(dict(g.degree).items(), key=operator.itemgetter(1))[0]

# Map index back to original cost matrix
index = int(k.split('_')[-1])
# Map index back to original cost matrix, adjust for 1-based indexing in labels
index = int(k.split('_')[-1]) + 1
# Process degree 0 nodes
if g.degree[k] == 0:
if 'pred' in k:
@@ -436,7 +484,7 @@
# Process degree 1 nodes
if g.degree[k] == 1:
for node in g.nodes:
node_index = int(node.split('_')[-1])
node_index = int(node.split('_')[-1]) + 1
if 'pred' in node:
self.gained_detections += 1
self.gained_indices['y_pred'].append(node_index)
@@ -458,21 +506,27 @@
if 'pred' in node_type:
self.merge += 1
self.missed_det_from_merge += len(nodes) - 2
merge_indices = [int(node.split('_')[-1])
for node in nodes if 'true' in node]
self.merge_indices['y_true'] += merge_indices
true_merge_indices = [int(node.split('_')[-1]) + 1
for node in nodes if 'true' in node]
self.merge_indices['y_true'] += true_merge_indices
self.merge_indices['y_pred'].append(index)
# Check for splits
elif 'true' in node_type:
self.split += 1
self.gained_det_from_split += len(nodes) - 2
self.split_indices['y_true'].append(index)
pred_split_indices = [int(node.split('_')[-1]) + 1
for node in nodes if 'pred' in node]
self.split_indices['y_pred'] += pred_split_indices

# If there are multiple types of the high degree node,
# then we have a catastrophe
else:
self.catastrophe += 1
true_indices = [int(node.split('_')[-1]) for node in nodes if 'true' in node]
pred_indices = [int(node.split('_')[-1]) for node in nodes if 'pred' in node]
true_indices = [int(node.split('_')[-1]) + 1
for node in nodes if 'true' in node]
pred_indices = [int(node.split('_')[-1]) + 1
for node in nodes if 'pred' in node]

self.true_det_in_catastrophe = len(true_indices)
self.pred_det_in_catastrophe = len(pred_indices)
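
To make the degree-based branching concrete, a small self-contained networkx sketch follows. The 'true_i' / 'pred_j' node naming matches the convention used here, but the branching is simplified for illustration rather than copied from _classify_graph:

import operator
import networkx as nx

# Hypothetical leftover-cell graph: one predicted cell overlaps two
# ground-truth cells, which is the signature of a merge.
g = nx.Graph()
g.add_edges_from([('true_1', 'pred_3'), ('true_2', 'pred_3')])

# Highest-degree node drives the classification, as in the code above.
k = max(dict(g.degree).items(), key=operator.itemgetter(1))[0]
neighbors = list(g.neighbors(k))

if g.degree[k] == 0:
    print('isolated node: a gained or missed detection')
elif g.degree[k] == 1:
    print('single link: one true cell paired with one predicted cell')
elif k.startswith('pred') and all(n.startswith('true') for n in neighbors):
    print('merge: {} absorbs true cells {}'.format(k, neighbors))
elif k.startswith('true') and all(n.startswith('pred') for n in neighbors):
    print('split: {} is divided across {}'.format(k, neighbors))
else:
    print('catastrophe: mixed merge/split component')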
@@ -499,7 +553,7 @@ def _classify_graph(self):
split_label_image = np.zeros_like(self.y_true)
for l in self.split_indices['y_true']:
split_label_image[self.y_true == l] = l
self.split_props = regionprops(merge_label_image)
self.split_props = regionprops(split_label_image)

def print_report(self):
"""Print report of error types and frequency
@@ -544,6 +598,20 @@ def save_to_dataframe(self):

return df

def save_error_ids(self):
"""Saves the ids of cells in each error category for subsequent visualization
Returns:
error_dict: dictionary containing {category_name: id list} pairs
y_true, y_pred: the (relabeled) true and predicted masks that the listed ids refer to
"""

error_dict = {"splits": self.split_indices, "merges": self.merge_indices,
"gains": self.gained_indices, "misses": self.missed_indices,
"catastrophes": self.catastrophe_indices,
"correct": self.correct_indices}

return error_dict, self.y_true, self.y_pred
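
A short, self-contained sketch of how the returned structure is typically read. The toy masks are invented here; the category keys come straight from the dictionary above:

import numpy as np
from deepcell.metrics import ObjectAccuracy

y_true = np.zeros((20, 20), dtype=int)
y_true[2:10, 2:18] = 1
y_true[10:18, 2:18] = 2
y_pred = np.where(y_true > 0, 1, 0)   # both true cells collapse into one label

o = ObjectAccuracy(y_true, y_pred)
error_dict, true_labels, pred_labels = o.save_error_ids()

# Each category maps to per-mask label ids,
# e.g. error_dict['merges'] == {'y_true': [...], 'y_pred': [...]}.
for category, ids in error_dict.items():
    print(category, ids)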


def to_precision(x, p):
"""
@@ -727,7 +795,7 @@ def print_pixel_report(self):
print('\nConfusion Matrix')
print(self.cm)

def calc_object_stats(self, y_true, y_pred):
def calc_object_stats(self, y_true, y_pred, return_predictions=False):
"""Calculate object statistics and save to output
Loops over each frame in the zeroth dimension, which should pass in
@@ -737,8 +805,17 @@ def calc_object_stats(self, y_true, y_pred):
Args:
y_true (numpy.array): Labeled ground truth annotations
y_pred (numpy.array): Labeled prediction mask
return_predictions (bool): Determines whether predictions will be returned for analysis
Raises:
ValueError: if the input tensors have fewer than three dimensions
"""

if len(y_true.shape) < 3:
raise ValueError("Invalid input dimensions: must be at least 3D tensor")

self.stats = pd.DataFrame()
self.predictions = []

for i in range(y_true.shape[0]):
o = ObjectAccuracy(y_true[i],
@@ -747,6 +824,9 @@
cutoff2=self.cutoff2,
seg=self.seg)
self.stats = self.stats.append(o.save_to_dataframe())
if return_predictions:
predictions = o.save_error_ids()
self.predictions.append(predictions)
if i % 500 == 0:
logging.info('{} samples processed'.format(i))

@@ -768,6 +848,8 @@
))

self.print_object_report()
if return_predictions:
return self.predictions
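
A hedged batch-level sketch of the new return_predictions flow; it assumes the enclosing Metrics class takes a model name as its first argument, which is not shown in this hunk:

import numpy as np
from deepcell.metrics import Metrics

# Stack of 2D label frames along axis 0; anything flatter now raises
# the ValueError added above.
y_true = np.zeros((2, 32, 32), dtype=int)
y_true[:, 4:14, 4:14] = 1
y_true[:, 18:28, 18:28] = 2
y_pred = y_true.copy()

m = Metrics('example_model')   # assumed constructor signature
predictions = m.calc_object_stats(y_true, y_pred, return_predictions=True)
print(len(predictions))        # one save_error_ids() tuple per frame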

def print_object_report(self):
"""Print neat report of object based statistics
@@ -982,3 +1064,63 @@ def match_nodes(gt, res):
iou[frame, iou_gt_idx, iou_res_idx] = intersection.sum() / union.sum()

return iou


def assign_plot_values(y_true, y_pred, error_dict):
"""Generates a matrix with cells belong to error classes numbered for plotting
Args:
y_true: 2D matrix of true labels
y_pred: 2D matrix of predicted labels
error_dict: dictionary produced by save_error_ids with IDs of all error cells
Returns:
plotting_tif: 2D matrix in which cells belonging to the same error class share the same value
"""

plotting_tif = np.zeros_like(y_true)

# erode edges for easier visualization of adjacent cells
y_true = erode_edges(y_true, 1)
y_pred = erode_edges(y_pred, 1)

# missed detections are tracked with true labels
misses = error_dict.pop("misses")["y_true"]
plotting_tif[np.isin(y_true, misses)] = 1

# all other events are tracked with predicted labels
category_id = 2
for key in error_dict.keys():
labels = error_dict[key]["y_pred"]
plotting_tif[np.isin(y_pred, labels)] = category_id
category_id += 1

return plotting_tif


def plot_errors(y_true, y_pred, error_dict):
"""Plots the errors identified from linear assignment code
Due to sequential relabeling that occurs within the metrics code, only run
this plotting function on the outputs of save_error_ids so that values match up.
Args:
y_true: 2D matrix of true labels returned by save_error_ids
y_pred: 2D matrix of predicted labels returned by save_error_ids
error_dict: dictionary returned by save_error_ids with IDs of all error cells
"""

plotting_tif = assign_plot_values(y_true, y_pred, error_dict)

plotting_colors = ['Black', 'Pink', 'Blue', 'Green', 'tan', 'Red', 'Grey']
cmap = mpl.colors.ListedColormap(plotting_colors)

fig, ax = plt.subplots(nrows=1, ncols=1)
mat = ax.imshow(plotting_tif, cmap=cmap, vmin=np.min(plotting_tif) - .5,
vmax=np.max(plotting_tif) + .5)

# tell the colorbar to tick at integers
cbar = fig.colorbar(mat, ticks=np.arange(np.min(plotting_tif), np.max(plotting_tif) + 1))
cbar.ax.set_yticklabels(["Background", "misses", "splits", "merges",
"gains", "catastrophes", "correct"])
fig.tight_layout()
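
Finally, a hedged end-to-end sketch showing how the outputs of save_error_ids are meant to feed plot_errors, per the docstring above; the toy masks are invented here:

import numpy as np
from deepcell.metrics import ObjectAccuracy, plot_errors

# Toy frame: two true cells that the prediction merges into one.
y_true = np.zeros((20, 20), dtype=int)
y_true[2:10, 2:18] = 1
y_true[10:18, 2:18] = 2
y_pred = np.zeros((20, 20), dtype=int)
y_pred[2:18, 2:18] = 1

o = ObjectAccuracy(y_true, y_pred, force_event_links=True)
error_dict, true_labels, pred_labels = o.save_error_ids()

# Plot with the labels returned above so that ids line up after any
# sequential relabeling inside the metrics code.
plot_errors(true_labels, pred_labels, error_dict)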
