diff --git a/deepcell_toolbox/metrics.py b/deepcell_toolbox/metrics.py index 82cc0f7..a39ff26 100644 --- a/deepcell_toolbox/metrics.py +++ b/deepcell_toolbox/metrics.py @@ -63,8 +63,6 @@ from deepcell_toolbox.compute_overlap import compute_overlap # pylint: disable=E0401 from deepcell_toolbox.compute_overlap_3D import compute_overlap_3D -# TODO: store the object/pixel metrics on Metrics better -# TODO: clean up print functions class Detection(object): # pylint: disable=useless-object-inheritance """Object to hold relevant information about a given detection.""" @@ -92,7 +90,7 @@ def __eq__(self, other): def __hash__(self): """Custom hasher, allow Detections to be hashable.""" return tuple((self.true_index, self.pred_index)).__hash__() - + def __repr__(self): return 'Detection({}, {})'.format(self.true_index, self.pred_index) @@ -178,10 +176,10 @@ def __init__(self, y_true, y_pred): if not np.issubdtype(y_pred.dtype, np.integer): warnings.warn('Casting y_pred from {} to int'.format(y_pred.dtype)) y_pred = y_pred.astype('int32') - + self.y_true = y_true self.y_pred = y_pred - + def to_dict(self): return dict() @@ -256,7 +254,7 @@ def precision(self): except ZeroDivisionError: _precision = 0 return _precision - + @property def f1(self): _recall = self.recall @@ -365,7 +363,7 @@ def __init__(self, raise ValueError('Expected dimensions for y_true (2D data) are 2 ' '(x, y) and 3 (x, y, chan). ' 'Got ndim: {}'.format(y_true.ndim)) - + elif is_3d and y_true.ndim != 3: raise ValueError('Expected dimensions for y_true (3D data) is 3.' 'Requires format is: (z, x, y)' @@ -501,7 +499,7 @@ def _get_modified_iou(self, force_event_links): force_event_links (:obj:`bool'): Whether to modify IOU values of large objects if they have been split or merged by a small object. - + Returns: np.array: The modified IoU matrix. """ @@ -520,8 +518,10 @@ def _get_modified_iou(self, force_event_links): pred_mask = self.y_pred == pred_label # fraction of true cell that is contained within pred cell, vice versa - true_in_pred = np.count_nonzero(self.y_true[pred_mask] == true_label) / np.sum(true_mask) - pred_in_true = np.count_nonzero(self.y_pred[true_mask] == pred_label) / np.sum(pred_mask) + true_in_pred = np.count_nonzero( + self.y_true[pred_mask] == true_label) / np.sum(true_mask) + pred_in_true = np.count_nonzero( + self.y_pred[true_mask] == pred_label) / np.sum(pred_mask) iou_val = self.iou[true_idx, pred_idx] max_val = np.max([true_in_pred, pred_in_true]) @@ -981,7 +981,7 @@ def calc_pixel_stats(self, y_true, y_pred, axis=-1): Args: y_true (numpy.array): Ground truth annotations after transform y_pred (numpy.array): Model predictions without labeling - + Returns: list: list of dictionaries with each stat being a key. @@ -1079,7 +1079,7 @@ def calc_object_stats(self, y_true, y_pred, progbar=True): all_object_metrics = [] # store all calculated metrics is_batch_relabeled = False # used to warn if batches were relabeled - + for i in tqdm(range(y_true.shape[0]), disable=not progbar): # check if labels aren't sequential, raise warning on first occurence if so true_batch, pred_batch = y_true[i], y_pred[i] @@ -1136,7 +1136,7 @@ def print_object_report(self, object_metrics): } for k in errors: errors[k] = int(object_metrics[k].sum()) - + bad_detections = [ 'gained_det_from_split', 'missed_det_from_merge', diff --git a/deepcell_toolbox/metrics_test.py b/deepcell_toolbox/metrics_test.py index ce50c80..e51722c 100644 --- a/deepcell_toolbox/metrics_test.py +++ b/deepcell_toolbox/metrics_test.py @@ -295,7 +295,7 @@ def test_init(self): true_idx, pred_idx = [1, 2], [2, 3] detection = metrics.Detection(true_idx, pred_idx) assert detection.is_catastrophe - + def test_hash(self): # test that Detections get hashed appropriately detection_set = set() @@ -306,7 +306,7 @@ def test_hash(self): assert d2 is not d1 # should be in the set since d2 == d1 assert d2 in detection_set - + def test_eq(self): # test Detection equality comparisons detection_set = set() @@ -317,12 +317,13 @@ def test_eq(self): assert d2 is not d1 assert d2 == d1 - assert d1 != None + assert d1 is not None print(d1) # test that __repr__ is called class TestPixelMetrics(): + def test_init(self): y_true, _ = _sample1(10, 10, 30, 30, True) @@ -332,10 +333,10 @@ def test_init(self): # Test mismatched input size with pytest.raises(ValueError): metrics.PixelMetrics(y_true, y_true[0]) - + # using float dtype warns but still works o = metrics.PixelMetrics(y_true.astype('float'), y_true.astype('float')) - + def test_y_true_equals_y_pred(self): y_true, _ = _sample1(10, 10, 30, 30, True) y_pred = y_true.copy() @@ -360,7 +361,7 @@ def test_y_pred_empty(self): assert o.precision == 0 assert o.f1 == 0 assert o.jaccard == 0 - + def test_y_true_empty(self): y_pred, _ = _sample1(10, 10, 30, 30, True) @@ -398,11 +399,11 @@ def test_init(self): y_true = np.zeros(shape=(10, 15, 15, 10)) # too many dimensions with pytest.raises(ValueError): metrics.ObjectMetrics(y_true, y_true, is_3d=True) - + # Test mismatched input size with pytest.raises(ValueError): metrics.ObjectMetrics(y_true, y_true[0]) - + # using float dtype warns but still works o = metrics.PixelMetrics(y_true.astype('float'), y_true.astype('float')) @@ -446,7 +447,7 @@ def test_y_pred_empty(self): assert o.precision == 0 assert o.f1 == 0 assert o.jaccard == 0 - + def test_y_true_empty(self): y_pred, _ = _sample1(10, 10, 30, 30, True) @@ -620,4 +621,4 @@ def test_split_stack(): with pytest.raises(ValueError): metrics.split_stack(arr, False, 11, 0, 10, 1) with pytest.raises(ValueError): - metrics.split_stack(arr, False, 10, 0, 11, 1) \ No newline at end of file + metrics.split_stack(arr, False, 10, 0, 11, 1)