Skip to content

Commit

Permalink
Classify divisions that are +/- 1 frame as correct (#103)
Browse files Browse the repository at this point in the history
* Define a new class for benchmarking tracking that supports division offsets

* Bugfixes and import corrections

* Restore classify divisions as an independent function

* Move division shift correction function outside the metrics class

* Add test for correcting shifted divisions

* pep8
  • Loading branch information
msschwartz21 committed Oct 6, 2022
1 parent 45d9ee5 commit 09a835d
Show file tree
Hide file tree
Showing 4 changed files with 332 additions and 44 deletions.
2 changes: 1 addition & 1 deletion deepcell_tracking/isbi_utils.py
Expand Up @@ -37,7 +37,7 @@

# Imports for backwards compatibility
from deepcell_tracking.utils import match_nodes, contig_tracks
from deepcell_tracking.metrics import classify_divisions, calculate_summary_stats
from deepcell_tracking.metrics import calculate_summary_stats
from deepcell_tracking.metrics import benchmark_tracking_performance


Expand Down
249 changes: 215 additions & 34 deletions deepcell_tracking/metrics.py
Expand Up @@ -30,6 +30,9 @@
from __future__ import print_function

from collections import Counter
import itertools

import numpy as np

from deepcell_tracking.trk_io import load_trks
from deepcell_tracking.utils import match_nodes, trk_to_graph
Expand Down Expand Up @@ -94,10 +97,9 @@ def _map_node(gt_node):
div_res = [node for node, d in G_res.nodes(data=True)
if d.get('division', False)]

correct = 0 # Correct division
incorrect = 0 # Wrong division
false_positive = 0 # False positive division
missed = 0 # Missed division
correct = [] # Correct division
incorrect = [] # Wrong/mismatch division
missed = [] # Missed division

for node in div_gt:
idx = int(node.split('_')[0])
Expand All @@ -113,15 +115,15 @@ def _map_node(gt_node):
else:
# Node doesn't exist so count this division as missed
print('missed node {} division completely'.format(node))
missed += 1
missed.append(node)
continue # move on to next node in div_gt
# Check if the node exists with same id in G_res
elif node in G_res.nodes:
r_node = node
# Node doesn't exist
else:
print('missed node {} division completely'.format(node))
missed += 1
missed.append(node)
continue # move on to next node in div_gt

# If we found the results node, evaluate division result
Expand All @@ -139,10 +141,10 @@ def _map_node(gt_node):
# Parents and daughters are the same, perfect!
if (Counter(pred_gt) == Counter(pred_res) and
Counter(succ_gt) == Counter(succ_res)):
correct += 1
correct.append(node)

else: # what went wrong?
incorrect += 1
incorrect.append(node)
errors = ['out degree = {}'.format(G_res.out_degree(r_node))]
if Counter(succ_gt) != Counter(succ_res):
errors.append('daughters mismatch')
Expand All @@ -156,10 +158,10 @@ def _map_node(gt_node):

else: # valid division not in results, it was missed
print('missed node {} division completely'.format(node))
missed += 1
missed.append(node)

# Count any remaining res nodes as false positives
false_positive += len(div_res)
false_positive = div_res

return {
'correct_division': correct,
Expand All @@ -170,6 +172,113 @@ def _map_node(gt_node):
}


def correct_shifted_divisions(
        missed, false_positive, correct,
        y_gt, y_res,
        G_gt, G_res,
        threshold):
    """Correct divisions errors that are shifted by a frame and should be counted as correct

    A missed (ground truth) division and a false positive (predicted) division
    that occur in adjacent frames are treated as the same division when the
    combined daughters of the earlier division overlap the parent of the later
    division, in the later frame, with an IoU of at least ``threshold``.

    Node names are parsed as ``'<label>_<frame>'``.

    Note:
        ``missed``, ``false_positive`` and ``correct`` are modified in place and
        the very same list objects are returned in the result dictionary.

    Args:
        missed (list): List of nodes classified as a false negative division
        false_positive (list): List of nodes classified as false positive division
        correct (list): List of nodes where divisions were correctly assigned
        y_gt (np.array): Y mask for the ground truth data, indexed by frame first
        y_res (np.array): Y mask for the predicted data, indexed by frame first
        G_gt (networkx.graph): Graph of the ground truth
        G_res (networkx.graph): Graph of the results
        threshold (float): Value between 0 and 1 used to determine matching cells using IoU

    Returns:
        dict: Dictionary of updated missed, false_positive and correct division lists
    """
    metrics = {
        'missed': missed,
        'false_positive': false_positive,
        'correct': correct
    }
    y = {'gt': y_gt, 'res': y_res}
    G = {'gt': G_gt, 'res': G_res}

    # Group the error nodes by frame. Every entry is a (source, node) tuple so
    # the origin of a node never has to be recovered by parsing a string.
    d_missed, d_fp = {}, {}
    for by_frame, source, nodes in [(d_missed, 'gt', missed),
                                    (d_fp, 'res', false_positive)]:
        for node in nodes:
            t = int(node.split('_')[-1])
            by_frame.setdefault(t, []).append((source, node))

    # A frame pair is only of interest when a missed division sits directly
    # before or after a false positive division. The set removes duplicates.
    frame_pairs = set()
    for t in d_missed:
        if t + 1 in d_fp:
            frame_pairs.add((t, t + 1))
        if t - 1 in d_fp:
            frame_pairs.add((t - 1, t))

    # Each matched node is recorded exactly once, in discovery order. A node
    # can pass the IoU test against several partners, and removing it from its
    # error list more than once would raise a ValueError.
    matches = []
    seen = set()

    # Loop over each pair of frames (sorted for a reproducible order)
    for t1, t2 in sorted(frame_pairs):
        # Get nodes from each frame
        n1s = d_missed.get(t1, []) + d_fp.get(t1, [])
        n2s = d_missed.get(t2, []) + d_fp.get(t2, [])

        # Compare each pair and save if they are above the threshold
        for (source1, node1), (source2, node2) in itertools.product(n1s, n2s):
            # Only a ground truth node can be paired with a results node
            if source1 == source2:
                continue

            # Compare sum of daughters in n1 to parent in n2
            daughters = [int(d.split('_')[0]) for d in G[source1].succ[node1]]
            if not daughters:
                # No daughters means there is nothing to compare against
                continue
            mask1 = np.isin(y[source1][t2], daughters)
            mask2 = y[source2][t2] == int(node2.split('_')[0])

            # Compute iou; an empty union means neither object is present
            union = np.logical_or(mask1, mask2).sum()
            if union == 0:
                continue
            iou = np.logical_and(mask1, mask2).sum() / union
            if iou >= threshold:
                for match in ((source1, node1), (source2, node2)):
                    if match not in seen:
                        seen.add(match)
                        matches.append(match)

    # Remove matches from the list of errors
    for source, node in matches:
        if source == 'gt':
            # Remove error count and add node to the correct count
            metrics['missed'].remove(node)
            metrics['correct'].append(node)
            print('corrected division {} as a frameshift division not an error'.format(node))
        elif source == 'res':
            metrics['false_positive'].remove(node)

    return metrics


def calculate_association_accuracy(lineage_gt, lineage_res, cells_gt=[], cells_res=[]):
"""Calculate the association accuracy for each ground truth lineage
Expand Down Expand Up @@ -343,41 +452,113 @@ def calculate_summary_stats(correct_division,
}


def benchmark_tracking_performance(trk_gt, trk_res, threshold=1):
class TrackingMetrics:
    """Coordinate the benchmarking of a predicted tracking result against ground truth.

    All statistics are computed once, at construction time, and stored in
    ``self.stats``.
    """

    def __init__(self,
                 lineage_gt, y_gt,
                 lineage_res, y_res,
                 threshold=1,
                 allow_division_shift=True):
        """Match cells, build the lineage graphs and compute the metrics.

        Args:
            lineage_gt: Lineage of the ground truth data, e.g.
                ``load_trks(path)['lineages'][0]``.
            y_gt (np.array): Y mask for the ground truth data.
            lineage_res: Lineage of the predicted data, e.g.
                ``load_trks(path)['lineages'][0]``.
            y_res (np.array): Y mask for the predicted data.
            threshold (optional, float): threshold value for IoU to count as same cell. Default 1.
                If segmentations are identical, 1 works well.
                For imperfect segmentations try 0.6-0.8 to get better matching
            allow_division_shift (optional, bool): Allows divisions to be treated as correct if
                they are off by a single frame. Default True.
        """
        self.lineage_gt = lineage_gt
        self.lineage_res = lineage_res
        self.y_gt = y_gt
        self.y_res = y_res
        self.threshold = threshold
        self.allow_division_shift = allow_division_shift

        # Match up labels in GT to Results to allow for direct comparisons
        self.cells_gt, self.cells_res = match_nodes(y_gt, y_res, self.threshold)

        # Generate graphs without remapping nodes to avoid losing lineages
        self.G_gt = trk_to_graph(lineage_gt)
        self.G_res = trk_to_graph(lineage_res)

        self.stats = self.calculate_metrics()

    @classmethod
    def from_trk_files(cls, trk_gt, trk_res, threshold=1, allow_division_shift=True):
        """Build a TrackingMetrics object from a pair of .trk files.

        Args:
            trk_gt (path): Path to the ground truth .trk file.
            trk_res (path): Path to the predicted results .trk file.
            threshold (optional, float): IoU threshold forwarded to the constructor. Default 1.
            allow_division_shift (optional, bool): Forwarded to the constructor. Default True.

        Returns:
            TrackingMetrics: instance built from the first lineage and the ``y``
            array of each file.
        """
        # Load data
        trks = load_trks(trk_gt)
        lineage_gt, y_gt = trks['lineages'][0], trks['y']
        trks = load_trks(trk_res)
        lineage_res, y_res = trks['lineages'][0], trks['y']

        return cls(
            lineage_gt=lineage_gt, y_gt=y_gt,
            lineage_res=lineage_res, y_res=y_res,
            threshold=threshold,
            allow_division_shift=allow_division_shift
        )

    def calculate_metrics(self):
        """Compute division statistics, association accuracy and target effectiveness.

        Returns:
            dict: the division statistics, with every list of nodes replaced by
            its length, plus ``aa_tp``, ``aa_total``, ``te_tp`` and ``te_total``.
        """
        # Classify division errors
        stats = classify_divisions(
            self.G_gt, self.G_res, cells_gt=self.cells_gt, cells_res=self.cells_res)

        if self.allow_division_shift:
            updates = correct_shifted_divisions(
                missed=stats['false_negative_division'],
                false_positive=stats['false_positive_division'],
                correct=stats['correct_division'],
                y_gt=self.y_gt,
                y_res=self.y_res,
                G_gt=self.G_gt,
                G_res=self.G_res,
                threshold=self.threshold)

            # Keys as returned by correct_shifted_divisions; kept so existing
            # consumers of ``stats`` continue to see the same keys.
            for k, v in updates.items():
                stats[k] = v

            # Write the corrected lists back under the canonical keys explicitly,
            # so the result does not depend on correct_shifted_divisions having
            # modified the lists passed to it in place.
            stats['false_negative_division'] = updates['missed']
            stats['false_positive_division'] = updates['false_positive']
            stats['correct_division'] = updates['correct']

        # Convert list of nodes to counts
        for k, v in stats.items():
            if isinstance(v, list):
                stats[k] = len(v)

        # Calculate aa and te
        aa_tp, aa_total = calculate_association_accuracy(
            self.lineage_gt, self.lineage_res, self.cells_gt, self.cells_res)

        te_tp, te_total = calculate_target_effectiveness(
            self.lineage_gt, self.lineage_res, self.cells_gt, self.cells_res)

        return {
            **stats,
            'aa_tp': aa_tp,
            'aa_total': aa_total,
            'te_tp': te_tp,
            'te_total': te_total
        }


def benchmark_tracking_performance(trk_gt, trk_res, threshold=1, allow_division_shift=True):
    """Compare two related .trk files (one being the GT of the other)

    Calculate division statistics, target effectiveness and association accuracy.
    Currently included for backwards compatibility, but is no longer necessary:
    the work is delegated to ``TrackingMetrics.from_trk_files``.

    Args:
        trk_gt (path): Path to the ground truth .trk file.
        trk_res (path): Path to the predicted results .trk file.
        threshold (optional, float): threshold value for IoU to count as same cell. Default 1.
            If segmentations are identical, 1 works well.
            For imperfect segmentations try 0.6-0.8 to get better matching
        allow_division_shift (optional, bool): Allows divisions to be treated as correct if
            they are off by a single frame. Default True.

    Returns:
        dict: the ``stats`` attribute of the TrackingMetrics object built from
        the two files.
    """
    # Single code path: everything, including the optional division shift
    # correction, is computed by TrackingMetrics.
    m = TrackingMetrics.from_trk_files(trk_gt, trk_res,
                                       threshold=threshold,
                                       allow_division_shift=allow_division_shift)

    return m.stats

0 comments on commit 09a835d

Please sign in to comment.