Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion research/object_detection/builders/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,8 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
iou_thresh=frcnn_config.first_stage_nms_iou_threshold,
max_size_per_class=frcnn_config.first_stage_max_proposals,
max_total_size=frcnn_config.first_stage_max_proposals,
use_static_shapes=use_static_shapes)
use_static_shapes=use_static_shapes,
use_combined_nms=frcnn_config.use_combined_nms_in_first_stage)
first_stage_loc_loss_weight = (
frcnn_config.first_stage_localization_loss_weight)
first_stage_obj_loss_weight = frcnn_config.first_stage_objectness_loss_weight
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ def _build_non_max_suppressor(nms_config):
'max_total_detections.')
if nms_config.soft_nms_sigma < 0.0:
raise ValueError('soft_nms_sigma should be non-negative.')
if nms_config.use_combined_nms and nms_config.use_class_agnostic_nms:
raise ValueError('combined_nms does not support class_agnostic_nms')

non_max_suppressor_fn = functools.partial(
post_processing.batch_multiclass_non_max_suppression,
score_thresh=nms_config.score_threshold,
Expand All @@ -97,7 +100,8 @@ def _build_non_max_suppressor(nms_config):
use_static_shapes=nms_config.use_static_shapes,
use_class_agnostic_nms=nms_config.use_class_agnostic_nms,
max_classes_per_detection=nms_config.max_classes_per_detection,
soft_nms_sigma=nms_config.soft_nms_sigma)
soft_nms_sigma=nms_config.soft_nms_sigma,
use_combined_nms=nms_config.use_combined_nms)
return non_max_suppressor_fn


Expand Down
50 changes: 49 additions & 1 deletion research/object_detection/core/batch_multiclass_nms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,55 @@ def test_batch_multiclass_nms_with_additional_fields_and_num_valid_boxes(
exp_nms_additional_fields[key])
self.assertAllClose(num_detections, [1, 1])

# TODO(bhattad): Remove conditional after CMLE moves to TF 1.9
def test_combined_nms_with_batch_size_2(self):
  """Checks batch NMS with `use_combined_nms=True` on a batch of two images.

  Each image has four anchors with one box per class (two classes). With
  `max_size_per_class` and `max_total_size` both set to 3 and static shapes
  enabled, three detections are expected per image. Masks and additional
  fields are expected to come back as None on this path.
  """
  # Shape [batch=2, anchors=4, classes=2, 4]: per-class box coordinates.
  input_boxes = tf.constant(
      [[[[0, 0, 0.1, 0.1], [0, 0, 0.1, 0.1]],
        [[0, 0.01, 1, 0.11], [0, 0.6, 0.1, 0.7]],
        [[0, -0.01, 0.1, 0.09], [0, -0.1, 0.1, 0.09]],
        [[0, 0.11, 0.1, 0.2], [0, 0.11, 0.1, 0.2]]],
       [[[0, 0, 0.2, 0.2], [0, 0, 0.2, 0.2]],
        [[0, 0.02, 0.2, 0.22], [0, 0.02, 0.2, 0.22]],
        [[0, -0.02, 0.2, 0.19], [0, -0.02, 0.2, 0.19]],
        [[0, 0.21, 0.2, 0.3], [0, 0.21, 0.2, 0.3]]]],
      tf.float32)
  # Shape [batch=2, anchors=4, classes=2]: per-class scores.
  input_scores = tf.constant(
      [[[0.1, 0.9], [0.75, 0.8], [0.6, 0.3], [0.95, 0.1]],
       [[0.1, 0.9], [0.75, 0.8], [0.6, 0.3], [0.95, 0.1]]])
  max_detections = 3

  expected_boxes = np.array([[[0, 0.11, 0.1, 0.2],
                              [0, 0, 0.1, 0.1],
                              [0, 0.6, 0.1, 0.7]],
                             [[0, 0.21, 0.2, 0.3],
                              [0, 0, 0.2, 0.2],
                              [0, 0.02, 0.2, 0.22]]])
  expected_scores = np.array([[0.95, 0.9, 0.8],
                              [0.95, 0.9, 0.75]])
  expected_classes = np.array([[0, 1, 1],
                               [0, 1, 0]])

  nms_outputs = post_processing.batch_multiclass_non_max_suppression(
      input_boxes,
      input_scores,
      score_thresh=0.1,
      iou_thresh=0.5,
      max_size_per_class=max_detections,
      max_total_size=max_detections,
      use_static_shapes=True,
      use_combined_nms=True)
  (boxes_op, scores_op, classes_op, masks_out, extra_fields_out,
   count_op) = nms_outputs

  # The combined-NMS path does not produce masks or additional fields.
  self.assertIsNone(masks_out)
  self.assertIsNone(extra_fields_out)

  with self.test_session() as sess:
    boxes_val, scores_val, classes_val, count_val = sess.run(
        [boxes_op, scores_op, classes_op, count_op])
    self.assertAllClose(boxes_val, expected_boxes)
    self.assertAllClose(scores_val, expected_scores)
    self.assertAllClose(classes_val, expected_classes)
    self.assertListEqual(count_val.tolist(), [3, 3])

if __name__ == '__main__':
tf.test.main()
67 changes: 64 additions & 3 deletions research/object_detection/core/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,8 @@ def batch_multiclass_non_max_suppression(boxes,
use_static_shapes=False,
parallel_iterations=32,
use_class_agnostic_nms=False,
max_classes_per_detection=1):
max_classes_per_detection=1,
use_combined_nms=False):
"""Multi-class version of non maximum suppression that operates on a batch.

This op is similar to `multiclass_non_max_suppression` but operates on a batch
Expand Down Expand Up @@ -866,14 +867,28 @@ def batch_multiclass_non_max_suppression(boxes,
False.
scope: tf scope name.
use_static_shapes: If true, the output nmsed boxes are padded to be of
length `max_size_per_class` and it doesn't clip boxes to max_total_size.
length `minimum(max_total_size, max_size_per_class*num_classes)`.
If false, they are padded to be of length `max_total_size`.
Defaults to false.
parallel_iterations: (optional) number of batch items to process in
parallel.
use_class_agnostic_nms: If true, this uses class-agnostic non max
suppression
max_classes_per_detection: Maximum number of retained classes per detection
box in class-agnostic NMS.
use_combined_nms: If true, it uses tf.image.combined_non_max_suppression (
a multi-class version of NMS that operates on a batch).
It greedily selects a subset of detection bounding boxes, pruning away
boxes that have high IOU (intersection over union) overlap (> thresh) with
already selected boxes. It operates independently for each image in the
batch. Within each image, it operates independently for each class for
which scores are provided (via the `scores` argument), pruning boxes with
score less than a provided threshold prior to applying NMS. This operation
is performed on *all* images and *all* classes in the batch, therefore any
background classes should be removed prior to calling this function.
Masks and additional fields are not supported.
See argument checks in the code below for unsupported arguments.

Returns:
'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor
Expand All @@ -899,11 +914,57 @@ def batch_multiclass_non_max_suppression(boxes,
ValueError: if `q` in boxes.shape is not 1 or not equal to number of
classes as inferred from scores.shape.
"""
if use_combined_nms:
if change_coordinate_frame:
raise ValueError(
'change_coordinate_frame (normalizing coordinates'
' relative to clip_window) is not supported by combined_nms.')
if num_valid_boxes is not None:
raise ValueError('num_valid_boxes is not supported by combined_nms.')
if masks is not None:
raise ValueError('masks is not supported by combined_nms.')
if soft_nms_sigma != 0.0:
raise ValueError('Soft NMS is not supported by combined_nms.')
if use_class_agnostic_nms:
raise ValueError('class-agnostic NMS is not supported by combined_nms.')
if clip_window is not None:
tf.compat.v1.logging.warning(
'clip_window is not supported by combined_nms unless it is'
' [0. 0. 1. 1.] for each image.')
if additional_fields is not None:
tf.compat.v1.logging.warning(
'additional_fields is not supported by combined_nms.')
if parallel_iterations != 32:
tf.compat.v1.logging.warning(
'Number of batch items to be processed in parallel is'
' not configurable by combined_nms.')
if max_classes_per_detection > 1:
tf.compat.v1.logging.warning(
'max_classes_per_detection is not configurable by combined_nms.')

with tf.name_scope(scope, 'CombinedNonMaxSuppression'):
(batch_nmsed_boxes, batch_nmsed_scores, batch_nmsed_classes,
batch_num_detections) = tf.image.combined_non_max_suppression(
boxes=boxes,
scores=scores,
max_output_size_per_class=max_size_per_class,
max_total_size=max_total_size,
iou_threshold=iou_thresh,
score_threshold=score_thresh,
pad_per_class=use_static_shapes)
# Not supported by combined_non_max_suppression.
batch_nmsed_masks = None
# Not supported by combined_non_max_suppression.
batch_nmsed_additional_fields = None
return (batch_nmsed_boxes, batch_nmsed_scores, batch_nmsed_classes,
batch_nmsed_masks, batch_nmsed_additional_fields,
batch_num_detections)

q = shape_utils.get_dim_as_int(boxes.shape[2])
num_classes = shape_utils.get_dim_as_int(scores.shape[2])
if q != 1 and q != num_classes:
raise ValueError('third dimension of boxes must be either 1 or equal '
'to the third dimension of scores')
'to the third dimension of scores.')
# Normalizing coordinates relative to a window is impossible without one.
# The first literal ends with a space so the two concatenated literals do
# not run together as "clip_windowmust".
if change_coordinate_frame and clip_window is None:
  raise ValueError('if change_coordinate_frame is True, then a clip_window '
                   'must be specified.')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1621,7 +1621,8 @@ def normalize_boxes(args):
normalize_boxes,
elems=[raw_proposal_boxes, image_shapes],
dtype=tf.float32)
proposal_multiclass_scores = nmsed_additional_fields['multiclass_scores']
# `nmsed_additional_fields` may be None or empty, in which case there are no
# multiclass scores to report. The conditional is parenthesized and has no
# trailing comma: a trailing comma here would turn the value into a 1-tuple
# instead of the tensor (or None) that callers expect.
proposal_multiclass_scores = (
    nmsed_additional_fields.get('multiclass_scores')
    if nmsed_additional_fields else None)
return (normalized_proposal_boxes, proposal_scores,
proposal_multiclass_scores, num_proposals,
raw_normalized_proposal_boxes, rpn_objectness_softmax)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,8 @@ def postprocess(self, prediction_dict, true_image_shapes):
fields.DetectionResultFields.detection_classes:
nmsed_classes,
fields.DetectionResultFields.detection_multiclass_scores:
nmsed_additional_fields['multiclass_scores'],
nmsed_additional_fields.get(
'multiclass_scores') if nmsed_additional_fields else None,
fields.DetectionResultFields.num_detections:
tf.cast(num_detections, dtype=tf.float32),
fields.DetectionResultFields.raw_detection_boxes:
Expand Down
3 changes: 3 additions & 0 deletions research/object_detection/protos/faster_rcnn.proto
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ message FasterRcnn {
// If True, uses implementation of ops with static shape guarantees when
// running evaluation (specifically not is_training if False).
optional bool use_static_shapes_for_eval = 37 [default = false];

// Whether to use tf.image.combined_non_max_suppression.
optional bool use_combined_nms_in_first_stage = 38 [default=false];
}


Expand Down
3 changes: 3 additions & 0 deletions research/object_detection/protos/post_processing.proto
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ message BatchNonMaxSuppression {

// Soft NMS sigma parameter; Bodla et al, https://arxiv.org/abs/1704.04503)
optional float soft_nms_sigma = 9 [default = 0.0];

// Whether to use tf.image.combined_non_max_suppression.
optional bool use_combined_nms = 10 [default = false];
}

// Configuration proto for post-processing predicted boxes and
Expand Down