2 changes: 1 addition & 1 deletion research/object_detection/README.md
@@ -79,7 +79,7 @@ Extras:
   Run the evaluation for the Open Images Challenge 2018</a><br>
 * <a href='g3doc/tpu_compatibility.md'>
   TPU compatible detection pipelines</a><br>
-* <a href='g3doc/running_on_mobile_tensorflowlite.md'>
+* <a href='g3doc/running_on_mobile_tensorflowlite.md'>
   Running object detection on mobile devices with TensorFlow Lite</a><br>
 
 ## Getting Help
research/object_detection/anchor_generators/multiple_grid_anchor_generator.py
@@ -157,12 +157,10 @@ def _generate(self, feature_map_shape_list, im_height=1, im_width=1):
         correspond to an 8x8 layer followed by a 7x7 layer.
       im_height: the height of the image to generate the grid for. If both
         im_height and im_width are 1, the generated anchors default to
-        normalized coordinates, otherwise absolute coordinates are used for the
-        grid.
+        absolute coordinates, otherwise normalized coordinates are produced.
       im_width: the width of the image to generate the grid for. If both
         im_height and im_width are 1, the generated anchors default to
-        normalized coordinates, otherwise absolute coordinates are used for the
-        grid.
+        absolute coordinates, otherwise normalized coordinates are produced.
 
     Returns:
       boxes_list: a list of BoxLists each holding anchor boxes corresponding to
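To make the corrected docstring concrete: with the defaults im_height=im_width=1 the generator hands back pixel-space (absolute) anchors, and real image dimensions are needed before any rescaling to [0, 1] can happen. A minimal sketch, assuming the research/object_detection package is importable and reusing the constructor arguments from the tests further down:

from object_detection.anchor_generators import multiscale_grid_anchor_generator as mg

gen = mg.MultiscaleGridAnchorGenerator(
    min_level=5, max_level=5, anchor_scale=1.0, aspect_ratios=[1.0],
    scales_per_octave=1, normalize_coordinates=False)

# Default im_height=im_width=1: anchors come back in absolute pixel
# coordinates, e.g. [0, 0, 32, 32].
boxes_default = gen.generate(feature_map_shape_list=[(2, 2)])

# Explicit dimensions: still absolute for this generator, but they are what a
# generator built with normalize_coordinates=True needs to rescale into [0, 1].
boxes_sized = gen.generate(
    feature_map_shape_list=[(2, 2)], im_height=64, im_width=64)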
research/object_detection/anchor_generators/multiscale_grid_anchor_generator.py
@@ -57,14 +57,12 @@ def __init__(self, min_level, max_level, anchor_scale, aspect_ratios,
     self._scales_per_octave = scales_per_octave
     self._normalize_coordinates = normalize_coordinates
 
+    scales = [2**(float(scale) / scales_per_octave)
+              for scale in xrange(scales_per_octave)]
[Review comment] py3 compatibility. Perhaps six? or are these small enough to just use a list?
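Two py3-safe rewrites of the flagged comprehension, sketching what the comment suggests (illustrative only, not code from this PR):

# Option 1: six.moves.range resolves to xrange on py2 and range on py3:
#   from six.moves import range
# Option 2: these sequences are tiny, so the py2 builtin range (which
# builds a real list) is just as good, with no new dependency:
scales_per_octave = 2  # illustrative value
scales = [2**(float(scale) / scales_per_octave)
          for scale in range(scales_per_octave)]
print(scales)  # [1.0, 1.4142135623730951]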

+    aspects = list(aspect_ratios)
+
     for level in range(min_level, max_level + 1):
       anchor_stride = [2**level, 2**level]
-      scales = []
-      aspects = []
-      for scale in range(scales_per_octave):
-        scales.append(2**(float(scale) / scales_per_octave))
-      for aspect_ratio in aspect_ratios:
-        aspects.append(aspect_ratio)
       base_anchor_size = [2**level * anchor_scale, 2**level * anchor_scale]
       self._anchor_grid_info.append({
           'level': level,
@@ -84,7 +82,7 @@ def num_anchors_per_location(self):
     return len(self._anchor_grid_info) * [
         len(self._aspect_ratios) * self._scales_per_octave]
 
-  def _generate(self, feature_map_shape_list, im_height, im_width):
+  def _generate(self, feature_map_shape_list, im_height=1, im_width=1):
     """Generates a collection of bounding boxes to be used as anchors.
 
     Currently we require the input image shape to be statically defined. That
@@ -95,14 +93,20 @@ def _generate(self, feature_map_shape_list, im_height, im_width):
         format [(height_0, width_0), (height_1, width_1), ...]. For example,
         setting feature_map_shape_list=[(8, 8), (7, 7)] asks for anchors that
         correspond to an 8x8 layer followed by a 7x7 layer.
-      im_height: the height of the image to generate the grid for.
-      im_width: the width of the image to generate the grid for.
+      im_height: the height of the image to generate the grid for. If both
+        im_height and im_width are 1, anchors can only be generated in
+        absolute coordinates.
+      im_width: the width of the image to generate the grid for. If both
+        im_height and im_width are 1, anchors can only be generated in
+        absolute coordinates.
 
     Returns:
       boxes_list: a list of BoxLists each holding anchor boxes corresponding to
         the input feature map shapes.
     Raises:
       ValueError: if im_height and im_width are not integers.
+      ValueError: if im_height and im_width are 1, but normalized coordinates
+        were requested.
     """
     if not isinstance(im_height, int) or not isinstance(im_width, int):
       raise ValueError('MultiscaleGridAnchorGenerator currently requires '
@@ -118,9 +122,9 @@ def _generate(self, feature_map_shape_list, im_height, im_width):
       feat_h = feat_shape[0]
       feat_w = feat_shape[1]
       anchor_offset = [0, 0]
-      if im_height % 2.0**level == 0:
+      if im_height % 2.0**level == 0 or im_height == 1:
         anchor_offset[0] = stride / 2.0
-      if im_width % 2.0**level == 0:
+      if im_width % 2.0**level == 0 or im_width == 1:
         anchor_offset[1] = stride / 2.0
       ag = grid_anchor_generator.GridAnchorGenerator(
           scales,
@@ -131,6 +135,11 @@ def _generate(self, feature_map_shape_list, im_height, im_width):
       (anchor_grid,) = ag.generate(feature_map_shape_list=[(feat_h, feat_w)])
 
       if self._normalize_coordinates:
+        if im_height == 1 or im_width == 1:
+          raise ValueError(
+              'Normalized coordinates were requested upon construction of the '
+              'MultiscaleGridAnchorGenerator, but a subsequent call to '
+              'generate did not supply dimension information.')
         anchor_grid = box_list_ops.to_normalized_coordinates(
             anchor_grid, im_height, im_width, check_range=False)
       anchor_grid_list.append(anchor_grid)
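The new error path and its remedy, as a call-level sketch (constructor arguments mirror the tests below):

from object_detection.anchor_generators import multiscale_grid_anchor_generator as mg

gen = mg.MultiscaleGridAnchorGenerator(
    min_level=5, max_level=5, anchor_scale=1.0, aspect_ratios=[1.0],
    scales_per_octave=1, normalize_coordinates=True)

# Raises ValueError: normalized coordinates were requested at construction,
# but the default im_height=im_width=1 carries no size to normalize by.
# gen.generate(feature_map_shape_list=[(2, 2)])

# Passing real dimensions succeeds and yields coordinates in [0, 1].
boxes = gen.generate(
    feature_map_shape_list=[(2, 2)], im_height=64, im_width=64)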
research/object_detection/anchor_generators/multiscale_grid_anchor_generator_test.py
@@ -47,6 +47,40 @@ def test_construct_single_anchor(self):
       anchor_corners_out = anchor_corners.eval()
       self.assertAllClose(anchor_corners_out, exp_anchor_corners)
 
+  def test_construct_single_anchor_unit_dimensions(self):
+    min_level = 5
+    max_level = 5
+    anchor_scale = 1.0
+    aspect_ratios = [1.0]
+    scales_per_octave = 1
+    im_height = 1
+    im_width = 1
+    feature_map_shape_list = [(2, 2)]
+    # Positive offsets are produced.
+    exp_anchor_corners = [[0, 0, 32, 32],
+                          [0, 32, 32, 64],
+                          [32, 0, 64, 32],
+                          [32, 32, 64, 64]]
+
+    anchor_generator = mg.MultiscaleGridAnchorGenerator(
+        min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave,
+        normalize_coordinates=False)
+    anchors_list = anchor_generator.generate(
+        feature_map_shape_list, im_height=im_height, im_width=im_width)
+    anchor_corners = anchors_list[0].get()
+
+    with self.test_session():
+      anchor_corners_out = anchor_corners.eval()
+      self.assertAllClose(anchor_corners_out, exp_anchor_corners)
+
+  def test_construct_normalized_anchors_fails_with_unit_dimensions(self):
+    anchor_generator = mg.MultiscaleGridAnchorGenerator(
+        min_level=5, max_level=5, anchor_scale=1.0, aspect_ratios=[1.0],
+        scales_per_octave=1, normalize_coordinates=True)
+    with self.assertRaisesRegexp(ValueError, 'Normalized coordinates'):
+      anchor_generator.generate(
+          feature_map_shape_list=[(2, 2)], im_height=1, im_width=1)
+
   def test_construct_single_anchor_in_normalized_coordinates(self):
     min_level = 5
     max_level = 5
@@ -94,7 +128,7 @@ def test_construct_single_anchor_fails_with_tensor_image_size(self):
     anchor_generator = mg.MultiscaleGridAnchorGenerator(
         min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave,
         normalize_coordinates=False)
-    with self.assertRaises(ValueError):
+    with self.assertRaisesRegexp(ValueError, 'statically defined'):
       anchor_generator.generate(
           feature_map_shape_list, im_height=im_height, im_width=im_width)

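A quick arithmetic check on the expected corners in test_construct_single_anchor_unit_dimensions: at level 5 the stride is 2**5 = 32, a unit image dimension now triggers the offset branch (offset = stride / 2 = 16), and anchor_scale=1.0 makes each anchor 32x32, so every box is [cy - 16, cx - 16, cy + 16, cx + 16]. A plain-Python sketch reproducing the expected list:

# Recompute exp_anchor_corners for the (2, 2) feature map by hand.
level, anchor_scale = 5, 1.0
stride = 2**level                 # 32
offset = stride / 2.0             # 16, via the new 'im_height == 1' branch
half = (2**level * anchor_scale) / 2.0

corners = []
for row in range(2):
  for col in range(2):
    cy, cx = offset + row * stride, offset + col * stride
    corners.append([cy - half, cx - half, cy + half, cx + half])

print(corners)
# [[0.0, 0.0, 32.0, 32.0], [0.0, 32.0, 32.0, 64.0],
#  [32.0, 0.0, 64.0, 32.0], [32.0, 32.0, 64.0, 64.0]]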
157 changes: 91 additions & 66 deletions research/object_detection/builders/box_predictor_builder.py
@@ -15,7 +15,12 @@
 
 """Function to build box predictor from configuration."""
 
-from object_detection.core import box_predictor
+from object_detection.predictors import convolutional_box_predictor
+from object_detection.predictors import mask_rcnn_box_predictor
+from object_detection.predictors import rfcn_box_predictor
+from object_detection.predictors.mask_rcnn_heads import box_head
+from object_detection.predictors.mask_rcnn_heads import class_head
+from object_detection.predictors.mask_rcnn_heads import mask_head
 from object_detection.protos import box_predictor_pb2
 
 
The local config variables below are renamed to config_box_predictor so they no longer shadow the newly imported mask_rcnn_box_predictor and rfcn_box_predictor modules.

@@ -48,92 +53,112 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
   box_predictor_oneof = box_predictor_config.WhichOneof('box_predictor_oneof')
 
   if box_predictor_oneof == 'convolutional_box_predictor':
-    conv_box_predictor = box_predictor_config.convolutional_box_predictor
-    conv_hyperparams_fn = argscope_fn(conv_box_predictor.conv_hyperparams,
+    config_box_predictor = box_predictor_config.convolutional_box_predictor
+    conv_hyperparams_fn = argscope_fn(config_box_predictor.conv_hyperparams,
                                       is_training)
-    box_predictor_object = box_predictor.ConvolutionalBoxPredictor(
-        is_training=is_training,
-        num_classes=num_classes,
-        conv_hyperparams_fn=conv_hyperparams_fn,
-        min_depth=conv_box_predictor.min_depth,
-        max_depth=conv_box_predictor.max_depth,
-        num_layers_before_predictor=(conv_box_predictor.
-                                     num_layers_before_predictor),
-        use_dropout=conv_box_predictor.use_dropout,
-        dropout_keep_prob=conv_box_predictor.dropout_keep_probability,
-        kernel_size=conv_box_predictor.kernel_size,
-        box_code_size=conv_box_predictor.box_code_size,
-        apply_sigmoid_to_scores=conv_box_predictor.apply_sigmoid_to_scores,
-        class_prediction_bias_init=(conv_box_predictor.
-                                    class_prediction_bias_init),
-        use_depthwise=conv_box_predictor.use_depthwise
-    )
+    box_predictor_object = (
+        convolutional_box_predictor.ConvolutionalBoxPredictor(
+            is_training=is_training,
+            num_classes=num_classes,
+            conv_hyperparams_fn=conv_hyperparams_fn,
+            min_depth=config_box_predictor.min_depth,
+            max_depth=config_box_predictor.max_depth,
+            num_layers_before_predictor=(
+                config_box_predictor.num_layers_before_predictor),
+            use_dropout=config_box_predictor.use_dropout,
+            dropout_keep_prob=config_box_predictor.dropout_keep_probability,
+            kernel_size=config_box_predictor.kernel_size,
+            box_code_size=config_box_predictor.box_code_size,
+            apply_sigmoid_to_scores=config_box_predictor.
+            apply_sigmoid_to_scores,
+            class_prediction_bias_init=(
+                config_box_predictor.class_prediction_bias_init),
+            use_depthwise=config_box_predictor.use_depthwise))
     return box_predictor_object

   if box_predictor_oneof == 'weight_shared_convolutional_box_predictor':
-    conv_box_predictor = (box_predictor_config.
-                          weight_shared_convolutional_box_predictor)
-    conv_hyperparams_fn = argscope_fn(conv_box_predictor.conv_hyperparams,
+    config_box_predictor = (
+        box_predictor_config.weight_shared_convolutional_box_predictor)
+    conv_hyperparams_fn = argscope_fn(config_box_predictor.conv_hyperparams,
                                       is_training)
-    box_predictor_object = box_predictor.WeightSharedConvolutionalBoxPredictor(
-        is_training=is_training,
-        num_classes=num_classes,
-        conv_hyperparams_fn=conv_hyperparams_fn,
-        depth=conv_box_predictor.depth,
-        num_layers_before_predictor=(
-            conv_box_predictor.num_layers_before_predictor),
-        kernel_size=conv_box_predictor.kernel_size,
-        box_code_size=conv_box_predictor.box_code_size,
-        class_prediction_bias_init=conv_box_predictor.
-        class_prediction_bias_init,
-        use_dropout=conv_box_predictor.use_dropout,
-        dropout_keep_prob=conv_box_predictor.dropout_keep_probability,
-        share_prediction_tower=conv_box_predictor.share_prediction_tower)
+    apply_batch_norm = config_box_predictor.conv_hyperparams.HasField(
+        'batch_norm')
+    box_predictor_object = (
+        convolutional_box_predictor.WeightSharedConvolutionalBoxPredictor(
+            is_training=is_training,
+            num_classes=num_classes,
+            conv_hyperparams_fn=conv_hyperparams_fn,
+            depth=config_box_predictor.depth,
+            num_layers_before_predictor=(
+                config_box_predictor.num_layers_before_predictor),
+            kernel_size=config_box_predictor.kernel_size,
+            box_code_size=config_box_predictor.box_code_size,
+            class_prediction_bias_init=config_box_predictor.
+            class_prediction_bias_init,
+            use_dropout=config_box_predictor.use_dropout,
+            dropout_keep_prob=config_box_predictor.dropout_keep_probability,
+            share_prediction_tower=config_box_predictor.share_prediction_tower,
+            apply_batch_norm=apply_batch_norm))
     return box_predictor_object

   if box_predictor_oneof == 'mask_rcnn_box_predictor':
-    mask_rcnn_box_predictor = box_predictor_config.mask_rcnn_box_predictor
-    fc_hyperparams_fn = argscope_fn(mask_rcnn_box_predictor.fc_hyperparams,
+    config_box_predictor = box_predictor_config.mask_rcnn_box_predictor
+    fc_hyperparams_fn = argscope_fn(config_box_predictor.fc_hyperparams,
                                     is_training)
     conv_hyperparams_fn = None
-    if mask_rcnn_box_predictor.HasField('conv_hyperparams'):
+    if config_box_predictor.HasField('conv_hyperparams'):
       conv_hyperparams_fn = argscope_fn(
-          mask_rcnn_box_predictor.conv_hyperparams, is_training)
-    box_predictor_object = box_predictor.MaskRCNNBoxPredictor(
+          config_box_predictor.conv_hyperparams, is_training)
+    box_prediction_head = box_head.BoxHead(
         is_training=is_training,
         num_classes=num_classes,
         fc_hyperparams_fn=fc_hyperparams_fn,
-        use_dropout=mask_rcnn_box_predictor.use_dropout,
-        dropout_keep_prob=mask_rcnn_box_predictor.dropout_keep_probability,
-        box_code_size=mask_rcnn_box_predictor.box_code_size,
-        conv_hyperparams_fn=conv_hyperparams_fn,
-        predict_instance_masks=mask_rcnn_box_predictor.predict_instance_masks,
-        mask_height=mask_rcnn_box_predictor.mask_height,
-        mask_width=mask_rcnn_box_predictor.mask_width,
-        mask_prediction_num_conv_layers=(
-            mask_rcnn_box_predictor.mask_prediction_num_conv_layers),
-        mask_prediction_conv_depth=(
-            mask_rcnn_box_predictor.mask_prediction_conv_depth),
-        masks_are_class_agnostic=(
-            mask_rcnn_box_predictor.masks_are_class_agnostic),
-        predict_keypoints=mask_rcnn_box_predictor.predict_keypoints,
+        use_dropout=config_box_predictor.use_dropout,
+        dropout_keep_prob=config_box_predictor.dropout_keep_probability,
+        box_code_size=config_box_predictor.box_code_size,
         share_box_across_classes=(
-            mask_rcnn_box_predictor.share_box_across_classes))
+            config_box_predictor.share_box_across_classes))
+    class_prediction_head = class_head.ClassHead(
+        is_training=is_training,
+        num_classes=num_classes,
+        fc_hyperparams_fn=fc_hyperparams_fn,
+        use_dropout=config_box_predictor.use_dropout,
+        dropout_keep_prob=config_box_predictor.dropout_keep_probability)
+    third_stage_heads = {}
+    if config_box_predictor.predict_instance_masks:
+      third_stage_heads[
+          mask_rcnn_box_predictor.MASK_PREDICTIONS] = mask_head.MaskHead(
+              num_classes=num_classes,
+              conv_hyperparams_fn=conv_hyperparams_fn,
+              mask_height=config_box_predictor.mask_height,
+              mask_width=config_box_predictor.mask_width,
+              mask_prediction_num_conv_layers=(
+                  config_box_predictor.mask_prediction_num_conv_layers),
+              mask_prediction_conv_depth=(
+                  config_box_predictor.mask_prediction_conv_depth),
+              masks_are_class_agnostic=(
+                  config_box_predictor.masks_are_class_agnostic))
+    box_predictor_object = mask_rcnn_box_predictor.MaskRCNNBoxPredictor(
+        is_training=is_training,
+        num_classes=num_classes,
+        box_prediction_head=box_prediction_head,
+        class_prediction_head=class_prediction_head,
+        third_stage_heads=third_stage_heads)
     return box_predictor_object

   if box_predictor_oneof == 'rfcn_box_predictor':
-    rfcn_box_predictor = box_predictor_config.rfcn_box_predictor
-    conv_hyperparams_fn = argscope_fn(rfcn_box_predictor.conv_hyperparams,
+    config_box_predictor = box_predictor_config.rfcn_box_predictor
+    conv_hyperparams_fn = argscope_fn(config_box_predictor.conv_hyperparams,
                                       is_training)
-    box_predictor_object = box_predictor.RfcnBoxPredictor(
+    box_predictor_object = rfcn_box_predictor.RfcnBoxPredictor(
         is_training=is_training,
         num_classes=num_classes,
         conv_hyperparams_fn=conv_hyperparams_fn,
-        crop_size=[rfcn_box_predictor.crop_height,
-                   rfcn_box_predictor.crop_width],
-        num_spatial_bins=[rfcn_box_predictor.num_spatial_bins_height,
-                          rfcn_box_predictor.num_spatial_bins_width],
-        depth=rfcn_box_predictor.depth,
-        box_code_size=rfcn_box_predictor.box_code_size)
+        crop_size=[config_box_predictor.crop_height,
+                   config_box_predictor.crop_width],
+        num_spatial_bins=[config_box_predictor.num_spatial_bins_height,
+                          config_box_predictor.num_spatial_bins_width],
+        depth=config_box_predictor.depth,
+        box_code_size=config_box_predictor.box_code_size)
     return box_predictor_object
   raise ValueError('Unknown box predictor: {}'.format(box_predictor_oneof))
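End to end, the refactored builder is invoked the same way as before. A hedged sketch (the config text and the hyperparams_builder wiring are illustrative; field names follow object_detection/protos):

from google.protobuf import text_format
from object_detection.builders import box_predictor_builder
from object_detection.builders import hyperparams_builder
from object_detection.protos import box_predictor_pb2

config_text = """
convolutional_box_predictor {
  conv_hyperparams {
    regularizer { l2_regularizer { weight: 0.0004 } }
    initializer { truncated_normal_initializer { stddev: 0.03 } }
  }
}
"""
config = text_format.Merge(config_text, box_predictor_pb2.BoxPredictor())

predictor = box_predictor_builder.build(
    argscope_fn=hyperparams_builder.build,  # turns conv_hyperparams into an arg scope
    box_predictor_config=config,
    is_training=True,
    num_classes=90)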