Use COCO metric instead of Pascal VOC's when evaluating
Agustín Azzinnari authored and nagitsu committed Mar 9, 2018
1 parent 07c46a4 commit 238d8b9
Showing 1 changed file with 109 additions and 70 deletions.
luminoth/eval.py (109 additions, 70 deletions)
@@ -18,10 +18,9 @@
@click.option('--from-global-step', type=int, default=None, help='Consider only checkpoints after this global step') # noqa
@click.option('override_params', '--override', '-o', multiple=True, help='Override model config params.') # noqa
@click.option('--files-per-class', type=int, default=10, help='How many files per class display in every epoch.') # noqa
-@click.option('--iou-threshold', type=float, default=0.5, help='IoU threshold to use.') # noqa
@click.option('--max-detections', type=int, default=100, help='Max detections to consider.') # noqa
-def eval(dataset_split, config_files, watch, from_global_step,
-override_params, files_per_class, iou_threshold, max_detections):
+def eval(dataset_split, config_files, watch, from_global_step, override_params,
+files_per_class, max_detections):
"""Evaluate models using dataset."""

# If the config file is empty, our config will be the base_config for the
@@ -56,6 +55,9 @@ def eval(dataset_split, config_files, watch, from_global_step,
# Override max detections with specified value.
config.model.rcnn.proposals.total_max_detections = max_detections

+# Also overwrite `min_prob_threshold` in order to use all the detections.
+config.model.rcnn.proposals.min_prob_threshold = 0.0

# Only a single run over the dataset to calculate metrics.
config.train.num_epochs = 1
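The `min_prob_threshold` override above matters because average precision integrates the whole precision-recall curve: the low-confidence detections that a serving-time threshold would drop are exactly the ones that populate its high-recall end. A toy illustration of the effect (not part of this commit):

    import numpy as np

    # Three detections for some class, all of them true positives.
    scores = np.array([0.9, 0.4, 0.05])
    num_gt = 3

    kept = scores >= 0.5           # a hypothetical serving-time threshold
    print(kept.sum() / num_gt)     # 0.33... -- the PR curve is truncated

    kept = scores >= 0.0           # the threshold used for evaluation
    print(kept.sum() / num_gt)     # 1.0 -- the full curve is reachable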

@@ -174,7 +176,6 @@ def eval(dataset_split, config_files, watch, from_global_step,
image_vis=config.eval.image_vis,
files_per_class=files_per_class,
files_to_visualize=files_to_visualize,
-iou_threshold=iou_threshold,
)
last_global_step = checkpoint['global_step']
tf.logging.info('Evaluated in {:.2f}s'.format(
@@ -259,8 +260,7 @@ def get_checkpoints(run_dir, from_global_step=None, last_only=False):

def evaluate_once(config, writer, saver, ops, checkpoint,
metrics_scope='metrics', image_vis=None,
-files_per_class=None, files_to_visualize=None,
-iou_threshold=0.5):
+files_per_class=None, files_to_visualize=None):
"""Run the evaluation once.
Create a new session with the previously-built graph, run it through the
@@ -377,24 +377,59 @@ def evaluate_once(config, writer, saver, ops, checkpoint,

# Save final evaluation stats into summary under the checkpoint's
# global step.
-map_at_iou, per_class_at_iou = calculate_map(
-output_per_batch, config.model.network.num_classes,
-iou_threshold
+ap_per_class, ar_per_class = calculate_metrics(
+output_per_batch, config.model.network.num_classes
)

+map_at_50 = np.mean(ap_per_class[:, 0])
+map_at_75 = np.mean(ap_per_class[:, 5])
+map_at_range = np.mean(ap_per_class)
+mar_at_range = np.mean(ar_per_class)

tf.logging.info('Finished evaluation at step {}.'.format(
checkpoint['global_step']))
tf.logging.info('Evaluated {} images.'.format(total_evaluated))
-tf.logging.info('mAP@{} = {:.3f}'.format(
-iou_threshold, map_at_iou
-))

+# TODO: Find a way to generate these summaries automatically, or
+# less manually.
+tf.logging.info(
+'Average Precision (AP) @ [0.50] = {:.3f}'.format(map_at_50)
+)
+tf.logging.info(
+'Average Precision (AP) @ [0.75] = {:.3f}'.format(map_at_75)
+)
+tf.logging.info(
+'Average Precision (AP) @ [0.50:0.95] = {:.3f}'.format(
+map_at_range
+)
+)
+tf.logging.info(
+'Average Recall (AR) @ [0.50:0.95] = {:.3f}'.format(
+mar_at_range
+)
+)
+
+for idx, val in enumerate(ap_per_class[:, 0]):
+tf.logging.debug(
+'Average Precision (AP) @ [0.50] for {} = {:.3f}'.format(
+idx, val
+)
+)

summary = [
tf.Summary.Value(
-tag='{}/mAP@{}'.format(metrics_scope, iou_threshold),
-simple_value=map_at_iou
+tag='{}/AP@0.50'.format(metrics_scope),
+simple_value=map_at_50
),
+tf.Summary.Value(
+tag='{}/AP@0.75'.format(metrics_scope),
+simple_value=map_at_75
+),
+tf.Summary.Value(
+tag='{}/AP@[0.50:0.95]'.format(metrics_scope),
+simple_value=map_at_range
+),
+tf.Summary.Value(
+tag='{}/AR@[0.50:0.95]'.format(metrics_scope),
+simple_value=mar_at_range
+),
tf.Summary.Value(
tag='{}/total_evaluated'.format(metrics_scope),
@@ -406,17 +441,6 @@ def evaluate_once(config, writer, saver, ops, checkpoint,
),
]

-for idx, val in enumerate(per_class_at_iou):
-tf.logging.debug('AP@{} for {} = {:.2f}'.format(
-iou_threshold, idx, val
-))
-summary.append(tf.Summary.Value(
-tag='{}/AP@{}/{}'.format(
-metrics_scope, iou_threshold, idx
-),
-simple_value=val
-))

for loss_name, loss_value in val_losses.items():
tf.logging.debug('{} loss = {:.4f}'.format(
loss_name, loss_value))
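For reference, `ap_per_class` as returned by `calculate_metrics` above has shape `(num_classes, 10)`, one column per IoU threshold in `[0.50, 0.55, ..., 0.95]`: column 0 is AP at IoU 0.50 and column 5 is AP at IoU 0.75, which is what the reductions feeding the summaries pick out. A minimal sketch on fabricated values (not part of the commit):

    import numpy as np

    # Fabricated per-class, per-threshold AP values: 2 classes x 10 thresholds.
    ap_per_class = np.linspace(0.9, 0.3, 20).reshape(2, 10)

    map_at_50 = np.mean(ap_per_class[:, 0])   # mean over classes at IoU=0.50
    map_at_75 = np.mean(ap_per_class[:, 5])   # thresholds step by 0.05
    map_at_range = np.mean(ap_per_class)      # COCO's headline mAP@[0.50:0.95]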
@@ -436,15 +460,15 @@ def evaluate_once(config, writer, saver, ops, checkpoint,
coord.join(threads)


-def calculate_map(output_per_batch, num_classes, iou_threshold=0.5):
-"""Calculates mAP@iou_threshold from the detector's output.
+def calculate_metrics(output_per_batch, num_classes):
+"""Calculates per-class COCO-style AP and AR from the detector's output.
The procedure for calculating the average precision for class ``C`` is as
follows (see `VOC mAP metric`_ for more details):
Start by ranking all the predictions (for a given image and said class) in
order of confidence. Each of these predictions is marked as correct (true
-positive, when it has a IoU-threshold greater or equal to `iou_threshold`)
+positive, when its IoU with the matched ground truth is greater than or
+equal to the threshold in `iou_thresholds`)
or incorrect (false positive, in the other case). This matching is
performed greedily over the confidence scores, so a higher-confidence
prediction will be matched over another lower-confidence one even if the
@@ -453,10 +477,9 @@ def calculate_map(output_per_batch, num_classes, iou_threshold=0.5):
We then integrate over the interpolated PR curve, thus obtaining the value
for the class' average precision. This interpolation makes sure the
-precision curve is monotonically decreasing; for this, at each recall point
-``r``, the precision is the maximum precision value among all recalls
-higher than ``r``. The integration is performed over 11 fixed points over
-the curve (``[0.0, 0.1, ..., 1.0]``).
+precision curve is monotonically non-increasing; for this, we traverse the
+precisions from the highest recall backwards, replacing each value with the
+maximum precision seen so far. The integration is performed over 101 fixed
+points over the curve (``[0.0, 0.01, ..., 1.0]``).
Average the result among all the classes to obtain the final, ``mAP``,
value.
@@ -467,23 +490,23 @@ def calculate_map(output_per_batch, num_classes, iou_threshold=0.5):
``gt_bboxes``, ``gt_classes``. Under each key, there should be a
list of the results per batch as returned by the detector.
num_classes (int): Number of classes on the dataset.
-threshold (float): IoU threshold for considering a match.
Returns:
-(``np.float``, ``ndarray``) tuple. The first value is the mAP, while
-the second is an array of size (`num_classes`,), with the AP value per
-class.
-Note:
-The "difficult example" flag of VOC dataset is being ignored.
-Todo:
-* Use VOC2012-style for integrating the curve. That is, use all recall
-points instead of a fixed number of points like in VOC2007.
+(``ndarray``, ``ndarray``) tuple. Both arrays are of size
+(`num_classes`, `num_iou_thresholds`): the first holds the AP value per
+class and IoU threshold, the second the AR.
.. _VOC mAP metric:
http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham10.pdf
"""
+iou_thresholds = np.linspace(
+0.50, 0.95, int(np.round((0.95 - 0.50) / 0.05)) + 1
+)
+# 101 recall levels, same as COCO evaluation.
+rec_thresholds = np.linspace(
+0.00, 1.00, int(np.round((1.00 - 0.00) / 0.01)) + 1
+)
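Both grids are inclusive: 10 IoU thresholds `[0.50, 0.55, ..., 0.95]` and 101 recall levels `[0.00, 0.01, ..., 1.00]`. The interpolation-plus-integration described in the docstring can be condensed into a self-contained toy sketch (the helper name is made up for illustration; it mirrors the logic further down):

    import numpy as np

    def toy_coco_ap(tp, num_gt, rec_thresholds):
        """AP for one class and one IoU threshold, given ranked TP flags."""
        tp = np.asarray(tp, dtype=float)
        recall = np.cumsum(tp) / num_gt
        precision = np.cumsum(tp) / np.arange(1, len(tp) + 1)

        # Make the precision monotonically non-increasing (interpolation).
        for i in range(len(precision) - 1, 0, -1):
            precision[i - 1] = max(precision[i - 1], precision[i])

        # Average the interpolated precision over the fixed recall grid.
        inds = np.searchsorted(recall, rec_thresholds)
        valid = inds < len(recall)
        return precision[inds[valid]].sum() / len(rec_thresholds)

    rec_thresholds = np.linspace(0.0, 1.0, 101)
    # Ranked detections: TP, FP, TP, against 2 ground-truth boxes.
    print(toy_coco_ap([1, 0, 1], num_gt=2, rec_thresholds=rec_thresholds))  # ~0.835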

# List; first by class, then by example. Each entry is a tuple of ndarrays
# of size (D_{c,i},), for tp/fp labels and for score, where D_{c,i} is the
# number of detected boxes for class `c` on image `i`.
@@ -518,10 +541,10 @@ def calculate_map(output_per_batch, num_classes, iou_threshold=0.5):
sorted_indices = np.argsort(-cls_scores)

# Whether the ground-truth has been previously detected.
-is_detected = np.zeros(num_gt)
+is_detected = np.zeros((num_gt, len(iou_thresholds)))

# TP/FP labels for detected bboxes of (class, image).
-tp_fp_labels = np.zeros(len(sorted_indices))
+tp_fp_labels = np.zeros((len(sorted_indices), len(iou_thresholds)))

if num_gt == 0:
# If no ground truth examples for class, all predictions must
@@ -537,19 +560,22 @@ def calculate_map(output_per_batch, num_classes, iou_threshold=0.5):
# Greedily assign bboxes to ground truths (highest score first).
for bbox_idx in sorted_indices:
gt_match = np.argmax(ious[bbox_idx, :])
-if ious[bbox_idx, gt_match] >= iou_threshold:
-# Over IoU threshold.
-if not is_detected[gt_match]:
-# And first detection: it's a true positive.
-tp_fp_labels[bbox_idx] = True
-is_detected[gt_match] = True
+# TODO: Try to vectorize.
+for iou_idx, iou_threshold in enumerate(iou_thresholds):
+if ious[bbox_idx, gt_match] >= iou_threshold:
+# Over IoU threshold.
+if not is_detected[gt_match, iou_idx]:
+# And first detection: it's a true positive.
+tp_fp_labels[bbox_idx, iou_idx] = True
+is_detected[gt_match, iou_idx] = True

tp_fp_labels_by_class[cls].append(
(tp_fp_labels, cls_scores[sorted_indices])
)

# Calculate average precision per class.
-ap_per_class = np.zeros(num_classes)
+ap_per_class = np.zeros((num_classes, len(iou_thresholds)))
+ar_per_class = np.zeros((num_classes, len(iou_thresholds)))
for cls in range(num_classes):
tp_fp_labels = tp_fp_labels_by_class[cls]
num_examples = num_examples_per_class[cls]
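The `# TODO: Try to vectorize.` in the hunk above can be addressed by comparing each detection's IoU against the whole `iou_thresholds` vector at once, removing the Python-level inner loop. A sketch under the assumption of boolean arrays and `num_gt > 0` (the real code special-cases empty ground truth; the function name is made up):

    import numpy as np

    def label_detections(ious, scores, iou_thresholds):
        """Greedy TP/FP labelling for one (class, image) pair, vectorized
        over IoU thresholds. `ious` has shape (num_detections, num_gt)."""
        num_detections, num_gt = ious.shape
        is_detected = np.zeros((num_gt, len(iou_thresholds)), dtype=bool)
        tp_fp_labels = np.zeros((num_detections, len(iou_thresholds)), dtype=bool)

        for bbox_idx in np.argsort(-scores):
            gt_match = np.argmax(ious[bbox_idx, :])
            # One comparison against every threshold at once.
            matched = ious[bbox_idx, gt_match] >= iou_thresholds
            # True positive only where the ground truth is still unclaimed.
            first_hit = matched & ~is_detected[gt_match]
            tp_fp_labels[bbox_idx] = first_hit
            is_detected[gt_match] |= first_hit

        return tp_fp_labels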
@@ -562,11 +588,11 @@ def calculate_map(output_per_batch, num_classes, iou_threshold=0.5):
# Sort the tp/fp labels by decreasing confidence score and calculate
# precision and recall at every position of this ranked output.
sorted_indices = np.argsort(-scores)
-true_positives = labels[sorted_indices]
+true_positives = labels[sorted_indices, :]
false_positives = 1 - true_positives

-cum_true_positives = np.cumsum(true_positives)
-cum_false_positives = np.cumsum(false_positives)
+cum_true_positives = np.cumsum(true_positives, axis=0)
+cum_false_positives = np.cumsum(false_positives, axis=0)

recall = cum_true_positives.astype(float) / num_examples
precision = np.divide(
Expand All @@ -575,19 +601,32 @@ def calculate_map(output_per_batch, num_classes, iou_threshold=0.5):
)

# Find AP by integrating over PR curve, with interpolated precision.
-ap = 0
-for t in np.linspace(0, 1, 11):
-if not np.any(recall >= t):
-# Recall is never higher than `t`, continue.
-continue
-ap += np.max(precision[recall >= t]) / 11  # Interpolated.
-
-ap_per_class[cls] = ap
-
-# Finally, mAP.
-mean_ap = np.mean(ap_per_class)
-
-return mean_ap, ap_per_class
+for iou_idx in range(len(iou_thresholds)):
+p = precision[:, iou_idx]
+r = recall[:, iou_idx]
+
+# Interpolate the precision. (Make it monotonically non-increasing.)
+for i in range(len(p) - 1, 0, -1):
+if p[i] > p[i-1]:
+p[i-1] = p[i]
+
+ap = 0
+inds = np.searchsorted(r, rec_thresholds)
+for ridx, pidx in enumerate(inds):
+if pidx >= len(r):
+# Out of bounds: recall never reaches this threshold, nor any
+# of the remaining ones (as they're ordered).
+break
+
+ap += p[pidx] / len(rec_thresholds)
+
+ap_per_class[cls, iou_idx] = ap
+if len(r):
+ar_per_class[cls, iou_idx] = r[-1]
+else:
+ar_per_class[cls, iou_idx] = 0
+
+return ap_per_class, ar_per_class
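A hypothetical invocation with a single one-image batch, to make the expected input concrete. The key names come from the docstring; the `[x_min, y_min, x_max, y_max]` box format and the per-batch array shapes are assumptions, not confirmed by this diff:

    import numpy as np

    output_per_batch = {
        'bboxes': [np.array([[10, 10, 60, 60], [12, 12, 58, 58]])],
        'classes': [np.array([0, 0])],
        'scores': [np.array([0.9, 0.6])],
        'gt_bboxes': [np.array([[10, 10, 60, 60]])],
        'gt_classes': [np.array([0])],
    }

    ap_per_class, ar_per_class = calculate_metrics(output_per_batch, num_classes=1)
    print(ap_per_class.shape)       # (1, 10) -- classes x IoU thresholds
    print(np.mean(ap_per_class))    # mAP@[0.50:0.95]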


if __name__ == '__main__':
