Skip to content

Commit

Permalink
Update print reports
Browse files Browse the repository at this point in the history
  • Loading branch information
msschwartz21 committed Feb 12, 2019
1 parent d4b7464 commit 5be40a4
Showing 1 changed file with 47 additions and 3 deletions.
50 changes: 47 additions & 3 deletions deepcell/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,6 @@ def all_pixel_stats(self, y_true, y_pred):
Raises:
ValueError: If y_true and y_pred are not the same shape
"""
# Record length of output to use for printing later on

if y_pred.shape != y_true.shape:
raise ValueError('Shape of inputs need to match. Shape of prediction '
Expand All @@ -813,14 +812,16 @@ def all_pixel_stats(self, y_true, y_pred):
self.output = self.output + self.pixel_df_to_dict(self.pixel_df)

# Calculate confusion matrix
cm = self.calc_pixel_confusion_matrix(y_true, y_pred)
self.cm = self.calc_pixel_confusion_matrix(y_true, y_pred)
self.output.append(dict(
name='confusion_matrix',
value=cm.tolist(),
value=self.cm.tolist(),
feature='all',
stat_type='pixel'
))

self.print_pixel_report()

def pixel_df_to_dict(self, df):
"""Output pandas df as a list of dictionary objects
Expand Down Expand Up @@ -873,6 +874,15 @@ def calc_pixel_confusion_matrix(self, y_true, y_pred):

return confusion_matrix(y_true, y_pred)

def print_pixel_report(self):
"""Print report of pixel based statistics
"""

print('\n____________Pixel-based statistics____________\n')
print(self.pixel_df)
print('\nConfusion Matrix')
print(self.cm)

def calc_object_stats(self, y_true, y_pred):
"""Calculate object statistics and save to output
Expand Down Expand Up @@ -912,6 +922,40 @@ def calc_object_stats(self, y_true, y_pred):
stat_type='object'
))

self.print_object_report()

def print_object_report(self):
"""Print neat report of object based statistics
"""

print('\n____________Object-based statistics____________\n')
print('Number of true cells:\t\t', int(self.stats['n_true'].sum()))
print('Number of predicted cells:\t', int(self.stats['n_pred'].sum()))

print('\nTrue positives: {}\tAccuracy: {}%'.format(
int(self.stats['true_pos'].sum()),
100 * round(self.stats['true_pos'].sum() / self.stats['n_true'].sum(), 4)))

total_err = (self.stats['false_pos'].sum()
+ self.stats['false_neg'].sum()
+ self.stats['split'].sum()
+ self.stats['merge'].sum())
print('\nFalse positives: {}\tPerc Error: {}%'.format(
int(self.stats['false_pos'].sum()),
100 * round(self.stats['false_pos'].sum() / total_err, 4)))
print('False negatives: {}\tPerc Error: {}%'.format(
int(self.stats['true_pos'].sum()),
100 * round(self.stats['false_neg'].sum() / total_err, 4)))
print('Merges:\t\t {}\tPerc Error: {}%'.format(
int(self.stats['merge'].sum()),
100 * round(self.stats['merge'].sum() / total_err, 4)))
print('Splits:\t\t {}\tPerc Error: {}%'.format(
int(self.stats['split'].sum()),
100 * round(self.stats['split'].sum() / total_err, 4)))

if self.seg is True:
print('\nSEG:', round(self.stats['seg'].mean(), 4))

def run_all(self,
y_true_lbl,
y_pred_lbl,
Expand Down

0 comments on commit 5be40a4

Please sign in to comment.