Add batch mode to process many submissions in a folder.
rwightman committed Jul 9, 2017
1 parent 0852cfd commit 26bea69
Showing 1 changed file with 158 additions and 84 deletions.
242 changes: 158 additions & 84 deletions tracklets/python/evaluate_tracklets.py
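With this change the evaluator can be pointed at a folder of submissions instead of a single tracklet file: the new find_submissions() walks the prediction folder for subdirectories containing .xml tracklet files, each one is scored against the ground truth via process_submission(), and the per-submission results are collected into a summary CSV. Assuming the flags added in the diff below (paths hypothetical), a batch run would look something like:

python evaluate_tracklets.py submissions/ groundtruth/ --batch -o results/

Note that the batch path joins an output folder for every submission unconditionally, so -o appears to be effectively required in batch mode.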
@@ -6,6 +6,7 @@
from shapely.geometry import Polygon
from collections import Counter, defaultdict
import numpy as np
+import pandas as pd
import argparse
import os
import sys
@@ -419,65 +420,17 @@ def process_sequence(
metric_fn=metric_fn)


-def main():
-parser = argparse.ArgumentParser(description='Evaluate two tracklet files.')
-parser.add_argument('prediction', type=str, nargs='?', default='tracklet_labels.xml',
-help='Predicted tracklet label filename or folder')
-parser.add_argument('groundtruth', type=str, nargs='?', default='tracklet_labels_gt.xml',
-help='Groundtruth tracklet label filename or folder')
-parser.add_argument('-f', '--include_indices', type=str, nargs='?', default=None,
-help='CSV file containing frame indices to include in evaluation. All frames included if argument empty.')
-parser.add_argument('-e', '--exclude_indices', type=str, nargs='?', default=None,
-help='CSV file containing frame indices to exclude (takes priority over inclusions) from evaluation.')
-parser.add_argument('-o', '--outdir', type=str, nargs='?', default=None,
-help='Output folder')
-parser.add_argument('-m', '--method', type=str, nargs='?', default='',
-help='Volume intersection calculation method override. "box", "cylinder", '
-'"sphere" (default = "", no override)')
-parser.add_argument('-v', '--eval_metric', type=str, nargs='?', default='iou',
-help='Eval metric. "iou" or "dice" (default = "iou")')
-parser.add_argument('-w', '--class_weight', type=str, nargs='?', default='instance',
-help='Weighting method across all classes. "simple", "volume", "instance", or "none" (default="instance")')
-parser.add_argument('-g', dest='override_lwh_with_gt', action='store_true',
-help='Override predicted lwh values with value from first gt tracklet.')
-parser.add_argument('--test', dest='test_mode', action='store_true', help='Test mode enable')
-parser.add_argument('-d', dest='debug', action='store_true', help='Debug print enable')
-parser.set_defaults(test_mode=False)
-parser.set_defaults(debug=False)
-parser.set_defaults(override_lwh_with_gt=False)
-args = parser.parse_args()
-include_indices_path = args.include_indices
-exclude_indices_path = args.exclude_indices
-output_dir = args.outdir
-eval_metric = args.eval_metric
-class_weighting = args.class_weight
-if class_weighting not in CLASS_WEIGHTING:
-print('Error: Invalid class weighting "%s". Must be one of %s\n'
-% (class_weighting, CLASS_WEIGHTING))
-exit(-1)
-
-process_params = dict()
-process_params['test_mode'] = args.test_mode
-process_params['override_lwh_with_gt'] = args.override_lwh_with_gt
-process_params['override_volume_method'] = ''
-if args.method:
-if args.method not in VOLUME_METHODS:
-print('Error: Invalid volume method override "%s". Must be one of %s\n'
-% (args.method, VOLUME_METHODS))
-exit(-1)
-else:
-print('Overriding volume intersection method with %s' % args.method)
-process_params['override_volume_method'] = args.method
-
-pred_path = args.prediction
-if not os.path.exists(pred_path):
-sys.stderr.write('Error: Prediction file %s not found.\n' % pred_path)
-exit(-1)
-
-gt_path = args.groundtruth
-if not os.path.exists(gt_path):
-sys.stderr.write('Error: Ground-truth file %s not found.\n' % gt_path)
-exit(-1)
+def process_submission(
+pred_path,
+gt_path,
+include_indices_path,
+exclude_indices_path,
+process_params,
+metric_params,
+output_dir,
+class_weighting='instance',
+prefix='',
+suppress_print=False):

sequences = []
if os.path.isfile(pred_path) and os.path.isfile(gt_path):
@@ -508,7 +461,7 @@ def _f(path):
if pb in exclude_indices_files:
seq['exclude_indices_file'] = exclude_indices_files[pb]
sequences.append(seq)
-if len(gt_files) != len(pred_files):
+if len(pred_files) < len(gt_files):
print('Warning: Only %d of %d ground-truth files matched with predictions.'
% (len(pred_files), len(gt_files)))
assert len(gt_files) == len(sequences)
@@ -522,19 +475,10 @@ def _f(path):
exit(-1)

metric_thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
-if eval_metric == 'dice':
-metric_fn = dice
-METRIC_PER_OBJ = 'dice_score_per_obj'
-PR_PER_METRIC = 'pr_per_dice_score'
-else:
-metric_fn = iou
-METRIC_PER_OBJ = 'iou_per_obj'
-PR_PER_METRIC = 'pr_per_iou'

pr_at_thresh = {k: Counter() for k in metric_thresholds}
counters = defaultdict(Counter)
for s in sequences:
-process_sequence(s, counters, pr_at_thresh, metric_fn, process_params)
+process_sequence(s, counters, pr_at_thresh, metric_params['fn'], process_params)

# Class weighting
# * '' or 'none': calculate the metric across all class volumes (large volume objects dominate small)
@@ -545,12 +489,12 @@ def _f(path):
metric_weight_sum = 0.0
combined_vol_sum = 0.0
intersection_vol_sum = 0.0
-results_table = {METRIC_PER_OBJ: {}, PR_PER_METRIC: {}}
+results_table = {metric_params['metric_per_obj']: {}, metric_params['pr_per_metric']: {}}
for k in counters['combined_volume'].keys():
combined_vol_sum += counters['combined_volume'][k]
intersection_vol_sum += counters['intersection_volume'][k]
-metric_val = metric_fn(counters['combined_volume'][k], 0., counters['intersection_volume'][k])
-results_table[METRIC_PER_OBJ][k] = float(metric_val)
+metric_val = metric_params['fn'](counters['combined_volume'][k], 0., counters['intersection_volume'][k])
+results_table[metric_params['metric_per_obj']][k] = float(metric_val)
if class_weighting == 'instance':
metric_weight = counters['gt_instance_count'][k]
elif class_weighting == 'volume':
@@ -565,29 +509,159 @@ def _f(path):
all_metric = metric_value_sum / metric_weight_sum if metric_weight_sum else 0.
else:
# no weighting is enabled, compute the metric over the volumes summed across all classes
-all_metric = metric_fn(combined_vol_sum, 0., intersection_vol_sum)
-results_table[METRIC_PER_OBJ]['All'] = float(all_metric)
+all_metric = metric_params['fn'](combined_vol_sum, 0., intersection_vol_sum)
+results_table[metric_params['metric_per_obj']]['All'] = float(all_metric)


# FIXME add support for per class P/R scores?
# NOTE P/R scores need further analysis given their use with the greedy pred - gt matching
for k, v in pr_at_thresh.items():
p = v['TP'] / (v['TP'] + v['FP']) if v['TP'] else 0.0
r = v['TP'] / (v['TP'] + v['FN']) if v['TP'] else 0.0
-results_table[PR_PER_METRIC][k] = {'precision': p, 'recall': r}
+results_table[metric_params['pr_per_metric']][k] = {'precision': p, 'recall': r}

-print('\nResults')
-print(yaml.safe_dump(results_table, default_flow_style=False, explicit_start=True))
+if not suppress_print:
+print('\nResults')
+print(yaml.safe_dump(results_table, default_flow_style=False, explicit_start=True))

if output_dir is not None:
-with open(os.path.join(output_dir, METRIC_PER_OBJ + '.csv'), 'w') as f:
-f.write('object_type,%s\n' % eval_metric)
+with open(os.path.join(output_dir, metric_params['metric_per_obj'] + '.csv'), 'w') as f:
+f.write('object_type,%s\n' % metric_params['name'])
[f.write('{0},{1}\n'.format(k, v))
-for k, v in sorted(results_table[METRIC_PER_OBJ].items(), key=lambda x: x[0])]
-with open(os.path.join(output_dir, PR_PER_METRIC + '.csv'), 'w') as f:
-f.write('%s_threshold,p,r\n' % eval_metric)
+for k, v in sorted(results_table[metric_params['metric_per_obj']].items(), key=lambda x: x[0])]
+with open(os.path.join(output_dir, metric_params['pr_per_metric'] + '.csv'), 'w') as f:
+f.write('%s_threshold,p,r\n' % metric_params['name'])
[f.write('{0},{1},{2}\n'.format(k, v['precision'], v['recall']))
-for k, v in sorted(results_table[PR_PER_METRIC].items(), key=lambda x: x[0])]
+for k, v in sorted(results_table[metric_params['pr_per_metric']].items(), key=lambda x: x[0])]

+return results_table
+
+
+def find_submissions(folder):
+submissions = []
+for root, _, files in os.walk(folder, topdown=False):
+for rel_filename in files:
+base, ext = os.path.splitext(rel_filename)
+if ext.lower() == '.xml':
+submissions.append(root)
+break
+return submissions
+
+
+def get_outdir(base_dir, name=''):
+outdir = os.path.join(base_dir, name)
+if not os.path.exists(outdir):
+os.makedirs(outdir)
+return outdir
+
+
+def main():
+parser = argparse.ArgumentParser(description='Evaluate two tracklet files.')
+parser.add_argument('prediction', type=str, nargs='?', default='tracklet_labels.xml',
+help='Predicted tracklet label filename or folder')
+parser.add_argument('groundtruth', type=str, nargs='?', default='tracklet_labels_gt.xml',
+help='Groundtruth tracklet label filename or folder')
+parser.add_argument('-f', '--include_indices', type=str, nargs='?', default=None,
+help='CSV file containing frame indices to include in evaluation. All frames included if argument empty.')
+parser.add_argument('-e', '--exclude_indices', type=str, nargs='?', default=None,
+help='CSV file containing frame indices to exclude (takes priority over inclusions) from evaluation.')
+parser.add_argument('-o', '--outdir', type=str, nargs='?', default=None,
+help='Output folder')
+parser.add_argument('-m', '--method', type=str, nargs='?', default='',
+help='Volume intersection calculation method override. "box", "cylinder", '
+'"sphere" (default = "", no override)')
+parser.add_argument('-v', '--eval_metric', type=str, nargs='?', default='iou',
+help='Eval metric. "iou" or "dice" (default = "iou")')
+parser.add_argument('-w', '--class_weight', type=str, nargs='?', default='instance',
+help='Weighting method across all classes. "simple", "volume", "instance", or "none" (default="instance")')
+parser.add_argument('-g', dest='override_lwh_with_gt', action='store_true',
+help='Override predicted lwh values with value from first gt tracklet.')
+parser.add_argument('--batch', dest='batch_mode', action='store_true', help='Batch mode enable')
+parser.add_argument('--test', dest='test_mode', action='store_true', help='Test mode enable')
+parser.add_argument('-d', dest='debug', action='store_true', help='Debug print enable')
+parser.set_defaults(test_mode=False)
+parser.set_defaults(debug=False)
+parser.set_defaults(override_lwh_with_gt=False)
+args = parser.parse_args()
+include_indices_path = args.include_indices
+exclude_indices_path = args.exclude_indices
+output_dir = args.outdir
+eval_metric = args.eval_metric
+class_weighting = args.class_weight
+if class_weighting not in CLASS_WEIGHTING:
+print('Error: Invalid class weighting "%s". Must be one of %s\n'
+% (class_weighting, CLASS_WEIGHTING))
+exit(-1)
+
+process_params = dict()
+process_params['test_mode'] = args.test_mode
+process_params['override_lwh_with_gt'] = args.override_lwh_with_gt
+process_params['override_volume_method'] = ''
+if args.method:
+if args.method not in VOLUME_METHODS:
+print('Error: Invalid volume method override "%s". Must be one of %s\n'
+% (args.method, VOLUME_METHODS))
+exit(-1)
+else:
+print('Overriding volume intersection method with %s' % args.method)
+process_params['override_volume_method'] = args.method
+
+pred_path = args.prediction
+if not os.path.exists(pred_path):
+sys.stderr.write('Error: Prediction file/folder %s not found.\n' % pred_path)
+exit(-1)
+
+gt_path = args.groundtruth
+if not os.path.exists(gt_path):
+sys.stderr.write('Error: Ground-truth file/folder %s not found.\n' % gt_path)
+exit(-1)
+
+metric_params = {}
+if eval_metric == 'dice':
+metric_params['name'] = 'dice'
+metric_params['fn'] = dice
+metric_params['metric_per_obj'] = 'dice_score_per_obj'
+metric_params['pr_per_metric'] = 'pr_per_dice_score'
+else:
+metric_params['name'] = 'iou'
+metric_params['fn'] = iou
+metric_params['metric_per_obj'] = 'iou_per_obj'
+metric_params['pr_per_metric'] = 'pr_per_iou'
+
+if args.batch_mode:
+results = {}
+submission_paths = find_submissions(pred_path)
+for p in submission_paths:
+pb = os.path.relpath(p, pred_path)
+result = process_submission(
+p, gt_path,
+include_indices_path, exclude_indices_path,
+process_params,
+metric_params,
+get_outdir(output_dir, pb),
+class_weighting=class_weighting,
+suppress_print=True)
+results[pb] = result
+
+if output_dir:
+per_obj = {}
+print(metric_params['metric_per_obj'])
+for k, r in results.items():
+print(k, os.path.basename(k))
+per_obj[k] = r[metric_params['metric_per_obj']]
+
+per_obj_df = pd.DataFrame.from_dict(per_obj, orient='index')
+per_obj_df.sort_values('All', ascending=False, inplace=True)
+per_obj_df.to_csv(os.path.join(output_dir, 'results_per_obj.csv'), index=True, index_label='Submission')
+
+print(yaml.safe_dump(results, default_flow_style=False, explicit_start=True))
+
+else:
+process_submission(
+pred_path, gt_path,
+include_indices_path, exclude_indices_path,
+process_params, metric_params, output_dir,
+class_weighting=class_weighting)


if __name__ == '__main__':
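To make the class-weighting options above concrete, here is a small self-contained sketch. The scores and counts are invented, and 'simple' is assumed to weight every class equally (that branch is folded out of the diff):

# Toy illustration of the class weighting in process_submission(); all numbers invented.
per_class_metric = {'Car': 0.72, 'Pedestrian': 0.35}  # hypothetical per-class IoU
gt_instance_count = {'Car': 90, 'Pedestrian': 10}     # hypothetical gt instance counts

# 'instance' weighting: classes contribute in proportion to their gt instance
# counts, so frequent classes dominate the combined score
weighted_sum = sum(per_class_metric[k] * gt_instance_count[k] for k in per_class_metric)
print(weighted_sum / sum(gt_instance_count.values()))  # 0.683

# 'simple' weighting (assumed): every class contributes equally
print(sum(per_class_metric.values()) / len(per_class_metric))  # 0.535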

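The batch summary written at the end of main() can likewise be sketched in isolation; the submission names and scores below are hypothetical, but the pandas calls mirror the diff:

import pandas as pd

# Hypothetical per-submission results, keyed by relative path as in main().
per_obj = {
    'team_a': {'Car': 0.61, 'Pedestrian': 0.28, 'All': 0.55},
    'team_b/run1': {'Car': 0.74, 'Pedestrian': 0.40, 'All': 0.68},
}
per_obj_df = pd.DataFrame.from_dict(per_obj, orient='index')  # one row per submission
per_obj_df.sort_values('All', ascending=False, inplace=True)  # best overall score first
per_obj_df.to_csv('results_per_obj.csv', index=True, index_label='Submission')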