Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a99b58f
Adding option for one_box_for_all_classes to the box_predictor
pkulzc Apr 13, 2018
2bbccfe
Extend to accept different ratios of conv channels.
pkulzc Apr 13, 2018
dfbc5c0
Remove inaccurate caveat from proto file.
pkulzc Apr 14, 2018
ef5d809
Add option to set dropout for classification net in weight shared box…
pkulzc Apr 15, 2018
a7ea572
fix flakiness in testSSDRandomCropWithMultiClassScores due to randomn…
pkulzc Apr 16, 2018
732cb5b
Post-process now works again in train mode.
pkulzc Apr 16, 2018
1541668
Adding support for reading in logits as groundtruth labels and applyi…
pkulzc Apr 16, 2018
7b6619b
Add a util function to visualize value histogram as a tf.summary.image.
pkulzc Apr 17, 2018
8f5b8a0
Do not add batch norm parameters to final conv2d ops that predict box…
pkulzc Apr 17, 2018
1f36070
Make sure the final layers are also resized proportional to conv_dept…
pkulzc Apr 17, 2018
5337ade
Remove deprecated batch_norm_trainable field from ssd mobilenet v2 co…
pkulzc Apr 17, 2018
b056905
Updating coco evaluation metrics to allow for a batch of image info, …
pkulzc Apr 18, 2018
0f0cdef
Update protobuf requirements to 3+ in installation docs.
pkulzc Apr 18, 2018
d8a852a
Add support for training keypoints.
pkulzc Apr 19, 2018
2fbea58
Fix data augmentation functions.
pkulzc Apr 20, 2018
0e3566a
Read the default batch size from config file.
pkulzc Apr 23, 2018
4ed70f6
Fixing a bug in the coco evaluator.
pkulzc Apr 23, 2018
0d13a83
num_gt_boxes_per_image and num_det_boxes_per_image value incorrect.
pkulzc Apr 24, 2018
d74cd76
Add option to evaluate any checkpoint (without requiring write access…
pkulzc Apr 25, 2018
faa7193
PiperOrigin-RevId: 190346687
sguada Mar 24, 2018
c4da8e0
- Expose slim arg_scope function to compute keys to enable testing.
sguada Mar 29, 2018
3cdb031
Add an option to not set slim arg_scope for batch_norm is_training pa…
sguada Apr 4, 2018
5e43819
PiperOrigin-RevId: 191955231
sguada Apr 6, 2018
83f0e34
PiperOrigin-RevId: 193254125
sguada Apr 17, 2018
79dbee1
PiperOrigin-RevId: 193371562
sguada Apr 18, 2018
922a985
PiperOrigin-RevId: 194085628
sguada Apr 24, 2018
21654bf
Sync to latest.
pkulzc Apr 26, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions research/object_detection/builders/box_predictor_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,14 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
num_classes=num_classes,
conv_hyperparams_fn=conv_hyperparams_fn,
depth=conv_box_predictor.depth,
num_layers_before_predictor=(conv_box_predictor.
num_layers_before_predictor),
num_layers_before_predictor=(
conv_box_predictor.num_layers_before_predictor),
kernel_size=conv_box_predictor.kernel_size,
box_code_size=conv_box_predictor.box_code_size,
class_prediction_bias_init=conv_box_predictor.class_prediction_bias_init
)
class_prediction_bias_init=conv_box_predictor.
class_prediction_bias_init,
use_dropout=conv_box_predictor.use_dropout,
dropout_keep_prob=conv_box_predictor.dropout_keep_probability)
return box_predictor_object

if box_predictor_oneof == 'mask_rcnn_box_predictor':
Expand Down Expand Up @@ -113,7 +115,9 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
mask_rcnn_box_predictor.mask_prediction_conv_depth),
masks_are_class_agnostic=(
mask_rcnn_box_predictor.masks_are_class_agnostic),
predict_keypoints=mask_rcnn_box_predictor.predict_keypoints)
predict_keypoints=mask_rcnn_box_predictor.predict_keypoints,
share_box_across_classes=(
mask_rcnn_box_predictor.share_box_across_classes))
return box_predictor_object

if box_predictor_oneof == 'rfcn_box_predictor':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def test_non_default_mask_rcnn_box_predictor(self):
use_dropout: true
dropout_keep_probability: 0.8
box_code_size: 3
share_box_across_classes: true
}
"""
hyperparams_proto = hyperparams_pb2.Hyperparams()
Expand All @@ -338,6 +339,7 @@ def mock_fc_argscope_builder(fc_hyperparams_arg, is_training):
self.assertEqual(box_predictor.num_classes, 90)
self.assertTrue(box_predictor._is_training)
self.assertEqual(box_predictor._box_code_size, 3)
self.assertEqual(box_predictor._share_box_across_classes, True)

def test_build_default_mask_rcnn_box_predictor(self):
box_predictor_proto = box_predictor_pb2.BoxPredictor()
Expand Down
9 changes: 9 additions & 0 deletions research/object_detection/builders/losses_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ def build_faster_rcnn_classification_loss(loss_config):
config = loss_config.weighted_softmax
return losses.WeightedSoftmaxClassificationLoss(
logit_scale=config.logit_scale)
if loss_type == 'weighted_logits_softmax':
config = loss_config.weighted_logits_softmax
return losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
logit_scale=config.logit_scale)

# By default, Faster RCNN second stage classifier uses Softmax loss
# with anchor-wise outputs.
Expand Down Expand Up @@ -193,6 +197,11 @@ def _build_classification_loss(loss_config):
return losses.WeightedSoftmaxClassificationLoss(
logit_scale=config.logit_scale)

if loss_type == 'weighted_logits_softmax':
config = loss_config.weighted_logits_softmax
return losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
logit_scale=config.logit_scale)

if loss_type == 'bootstrapped_sigmoid':
config = loss_config.bootstrapped_sigmoid
return losses.BootstrappedSigmoidClassificationLoss(
Expand Down
31 changes: 31 additions & 0 deletions research/object_detection/builders/losses_builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,24 @@ def test_build_weighted_softmax_classification_loss(self):
self.assertTrue(isinstance(classification_loss,
losses.WeightedSoftmaxClassificationLoss))

def test_build_weighted_logits_softmax_classification_loss(self):
  """Builds a full Loss config using weighted_logits_softmax.

  The classification loss returned by `losses_builder.build` must be a
  `WeightedSoftmaxClassificationAgainstLogitsLoss` instance.
  """
  config_text = """
classification_loss {
weighted_logits_softmax {
}
}
localization_loss {
weighted_l2 {
}
}
"""
  loss_config = losses_pb2.Loss()
  text_format.Merge(config_text, loss_config)
  # The builder returns a 5-tuple; only the first entry is checked here.
  built_losses = losses_builder.build(loss_config)
  classification_loss, _, _, _, _ = built_losses
  is_expected_type = isinstance(
      classification_loss,
      losses.WeightedSoftmaxClassificationAgainstLogitsLoss)
  self.assertTrue(is_expected_type)

def test_build_weighted_softmax_classification_loss_with_logit_scale(self):
losses_text_proto = """
classification_loss {
Expand Down Expand Up @@ -442,6 +460,19 @@ def test_build_softmax_loss(self):
self.assertTrue(isinstance(classification_loss,
losses.WeightedSoftmaxClassificationLoss))

def test_build_logits_softmax_loss(self):
  """Builds the Faster R-CNN classification loss from weighted_logits_softmax.

  `build_faster_rcnn_classification_loss` must return a
  `WeightedSoftmaxClassificationAgainstLogitsLoss` for this config.
  """
  config_text = """
weighted_logits_softmax {
}
"""
  classification_config = losses_pb2.ClassificationLoss()
  text_format.Merge(config_text, classification_config)
  built_loss = losses_builder.build_faster_rcnn_classification_loss(
      classification_config)
  is_expected_type = isinstance(
      built_loss, losses.WeightedSoftmaxClassificationAgainstLogitsLoss)
  self.assertTrue(is_expected_type)

def test_build_softmax_loss_by_default(self):
losses_text_proto = """
"""
Expand Down
28 changes: 23 additions & 5 deletions research/object_detection/core/box_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ def __init__(self,
mask_prediction_num_conv_layers=2,
mask_prediction_conv_depth=256,
masks_are_class_agnostic=False,
predict_keypoints=False):
predict_keypoints=False,
share_box_across_classes=False):
"""Constructor.

Args:
Expand Down Expand Up @@ -341,7 +342,8 @@ def __init__(self,
masks_are_class_agnostic: Boolean determining if the mask-head is
class-agnostic or not.
predict_keypoints: Whether to predict keypoints inside detection boxes.

share_box_across_classes: Whether to share boxes across classes rather
than use a different box for each class.

Raises:
ValueError: If predict_instance_masks is true but conv_hyperparams is not
Expand All @@ -362,6 +364,7 @@ def __init__(self,
self._mask_prediction_conv_depth = mask_prediction_conv_depth
self._masks_are_class_agnostic = masks_are_class_agnostic
self._predict_keypoints = predict_keypoints
self._share_box_across_classes = share_box_across_classes
if self._predict_keypoints:
raise ValueError('Keypoint prediction is unimplemented.')
if ((self._predict_instance_masks or self._predict_keypoints) and
Expand Down Expand Up @@ -403,10 +406,14 @@ def _predict_boxes_and_classes(self, image_features):
flattened_image_features = slim.dropout(flattened_image_features,
keep_prob=self._dropout_keep_prob,
is_training=self._is_training)
number_of_boxes = 1
if not self._share_box_across_classes:
number_of_boxes = self._num_classes

with slim.arg_scope(self._fc_hyperparams_fn()):
box_encodings = slim.fully_connected(
flattened_image_features,
self._num_classes * self._box_code_size,
number_of_boxes * self._box_code_size,
activation_fn=None,
scope='BoxEncodingPredictor')
class_predictions_with_background = slim.fully_connected(
Expand All @@ -415,7 +422,7 @@ def _predict_boxes_and_classes(self, image_features):
activation_fn=None,
scope='ClassPredictor')
box_encodings = tf.reshape(
box_encodings, [-1, 1, self._num_classes, self._box_code_size])
box_encodings, [-1, 1, number_of_boxes, self._box_code_size])
class_predictions_with_background = tf.reshape(
class_predictions_with_background, [-1, 1, self._num_classes + 1])
return box_encodings, class_predictions_with_background
Expand Down Expand Up @@ -778,7 +785,9 @@ def __init__(self,
num_layers_before_predictor,
box_code_size,
kernel_size=3,
class_prediction_bias_init=0.0):
class_prediction_bias_init=0.0,
use_dropout=False,
dropout_keep_prob=0.8):
"""Constructor.

Args:
Expand All @@ -796,6 +805,8 @@ def __init__(self,
kernel_size: Size of final convolution kernel.
class_prediction_bias_init: constant value to initialize bias of the last
conv2d layer before class prediction.
use_dropout: Whether to apply dropout to class prediction head.
dropout_keep_prob: Probability of keeping activations.
"""
super(WeightSharedConvolutionalBoxPredictor, self).__init__(is_training,
num_classes)
Expand All @@ -805,6 +816,8 @@ def __init__(self,
self._box_code_size = box_code_size
self._kernel_size = kernel_size
self._class_prediction_bias_init = class_prediction_bias_init
self._use_dropout = use_dropout
self._dropout_keep_prob = dropout_keep_prob

def _predict(self, image_features, num_predictions_per_location_list):
"""Computes encoded object locations and corresponding confidences.
Expand Down Expand Up @@ -867,6 +880,7 @@ def _predict(self, image_features, num_predictions_per_location_list):
num_predictions_per_location * self._box_code_size,
[self._kernel_size, self._kernel_size],
activation_fn=None, stride=1, padding='SAME',
normalizer_fn=None,
scope='BoxEncodingPredictor')

for i in range(self._num_layers_before_predictor):
Expand All @@ -877,11 +891,15 @@ def _predict(self, image_features, num_predictions_per_location_list):
stride=1,
padding='SAME',
scope='ClassPredictionTower/conv2d_{}'.format(i))
if self._use_dropout:
class_predictions_net = slim.dropout(
class_predictions_net, keep_prob=self._dropout_keep_prob)
class_predictions_with_background = slim.conv2d(
class_predictions_net,
num_predictions_per_location * num_class_slots,
[self._kernel_size, self._kernel_size],
activation_fn=None, stride=1, padding='SAME',
normalizer_fn=None,
biases_initializer=tf.constant_initializer(
self._class_prediction_bias_init),
scope='ClassPredictor')
Expand Down
63 changes: 58 additions & 5 deletions research/object_detection/core/box_predictor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,33 @@ def test_get_boxes_with_five_classes(self):
self.assertAllEqual(box_encodings_shape, [2, 1, 5, 4])
self.assertAllEqual(class_predictions_with_background_shape, [2, 1, 6])

def test_get_boxes_with_five_classes_share_box_across_classes(self):
  """Checks output shapes when share_box_across_classes is enabled.

  With five classes and a box code size of four, the box encodings are
  expected to have shape [2, 1, 1, 4] while the class predictions (including
  background) keep shape [2, 1, 6].
  """
  features = tf.random_uniform([2, 7, 7, 3], dtype=tf.float32)
  predictor = box_predictor.MaskRCNNBoxPredictor(
      is_training=False,
      num_classes=5,
      fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
      use_dropout=False,
      dropout_keep_prob=0.5,
      box_code_size=4,
      share_box_across_classes=True)
  predictions = predictor.predict(
      [features], num_predictions_per_location=[1], scope='BoxPredictor')
  box_encodings = predictions[box_predictor.BOX_ENCODINGS]
  class_scores = predictions[
      box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND]
  init_op = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init_op)
    # Shape ops are created after initialization, inside the session scope.
    shape_fetches = [tf.shape(box_encodings), tf.shape(class_scores)]
    box_shape, class_shape = sess.run(shape_fetches)
    self.assertAllEqual(box_shape, [2, 1, 1, 4])
    self.assertAllEqual(class_shape, [2, 1, 6])

def test_value_error_on_predict_instance_masks_with_no_conv_hyperparms(self):
with self.assertRaises(ValueError):
box_predictor.MaskRCNNBoxPredictor(
Expand Down Expand Up @@ -403,9 +430,14 @@ def _build_arg_scope_with_conv_hyperparams(self):
}
}
initializer {
truncated_normal_initializer {
random_normal_initializer {
stddev: 0.01
mean: 0.0
}
}
batch_norm {
train: true,
}
"""
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams)
return hyperparams_builder.build(conv_hyperparams, is_training=True)
Expand Down Expand Up @@ -434,6 +466,27 @@ def graph_fn(image_features):
self.assertAllEqual(box_encodings.shape, [4, 320, 1, 4])
self.assertAllEqual(objectness_predictions.shape, [4, 320, 1])

def test_bias_predictions_to_background_with_sigmoid_score_conversion(self):
  """Checks the effect of class_prediction_bias_init on sigmoid scores.

  With the class prediction bias initialized to -4.6, the mean of the
  sigmoid-converted class predictions is expected to be 0.01 to three
  decimal places.
  """

  def graph_fn(image_features):
    predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
        is_training=True,
        num_classes=2,
        conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
        depth=32,
        num_layers_before_predictor=1,
        class_prediction_bias_init=-4.6,
        box_code_size=4)
    predictions = predictor.predict(
        [image_features],
        num_predictions_per_location=[5],
        scope='BoxPredictor')
    class_logits = tf.concat(
        predictions[box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND],
        axis=1)
    return (tf.nn.sigmoid(class_logits),)

  random_features = np.random.rand(4, 8, 8, 64).astype(np.float32)
  sigmoid_scores = self.execute(graph_fn, [random_features])
  mean_score = np.mean(sigmoid_scores)
  self.assertAlmostEqual(mean_score, 0.01, places=3)

def test_get_multi_class_predictions_for_five_aspect_ratios_per_location(
self):

Expand Down Expand Up @@ -524,19 +577,19 @@ def graph_fn(image_features1, image_features2):
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxEncodingPredictionTower/conv2d_0/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxEncodingPredictionTower/conv2d_0/biases'),
'BoxEncodingPredictionTower/conv2d_0/BatchNorm/beta'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxEncodingPredictionTower/conv2d_1/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxEncodingPredictionTower/conv2d_1/biases'),
'BoxEncodingPredictionTower/conv2d_1/BatchNorm/beta'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictionTower/conv2d_0/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictionTower/conv2d_0/biases'),
'ClassPredictionTower/conv2d_0/BatchNorm/beta'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictionTower/conv2d_1/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'ClassPredictionTower/conv2d_1/biases'),
'ClassPredictionTower/conv2d_1/BatchNorm/beta'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
'BoxEncodingPredictor/weights'),
('BoxPredictor/WeightSharedConvolutionalBoxPredictor/'
Expand Down
49 changes: 49 additions & 0 deletions research/object_detection/core/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Classification losses:
* WeightedSigmoidClassificationLoss
* WeightedSoftmaxClassificationLoss
* WeightedSoftmaxClassificationAgainstLogitsLoss
* BootstrappedSigmoidClassificationLoss
"""
from abc import ABCMeta
Expand Down Expand Up @@ -317,6 +318,54 @@ def _compute_loss(self, prediction_tensor, target_tensor, weights):
return tf.reshape(per_row_cross_ent, tf.shape(weights)) * weights


class WeightedSoftmaxClassificationAgainstLogitsLoss(Loss):
  """Weighted softmax cross entropy whose targets are given as logits.

  The targets are not "one hot" vectors or probability distributions; they
  are logits, which this loss scales and passes through a softmax before
  computing the cross entropy against the (scaled) predicted logits.
  """

  def __init__(self, logit_scale=1.0):
    """Constructor.

    Args:
      logit_scale: Divisor applied to both the target and the prediction
        logits. A high value "diffuses" the target, a low value makes the
        target peakier. (default 1.0)
    """
    self._logit_scale = logit_scale

  def _scale_and_softmax_logits(self, logits):
    """Divides logits by the logit scale and returns their softmax."""
    return tf.nn.softmax(
        tf.divide(logits, self._logit_scale, name='scale_logits'),
        name='convert_scores')

  def _compute_loss(self, prediction_tensor, target_tensor, weights):
    """Compute loss function.

    Args:
      prediction_tensor: A float tensor of shape [batch_size, num_anchors,
        num_classes] representing the predicted logits for each class
      target_tensor: A float tensor of shape [batch_size, num_anchors,
        num_classes] representing logit classification targets
      weights: a float tensor of shape [batch_size, num_anchors]

    Returns:
      loss: a float tensor of shape [batch_size, num_anchors]
        representing the value of the loss function.
    """
    depth = prediction_tensor.get_shape().as_list()[-1]
    row_shape = [-1, depth]

    # Targets are converted to a distribution first; predictions are only
    # scaled because the cross entropy op applies its own softmax to them.
    soft_targets = self._scale_and_softmax_logits(target_tensor)
    scaled_predictions = tf.divide(
        prediction_tensor, self._logit_scale, name='scale_logits')

    flat_targets = tf.reshape(soft_targets, row_shape)
    flat_predictions = tf.reshape(scaled_predictions, row_shape)
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
        labels=flat_targets, logits=flat_predictions)

    # Restore the per-anchor layout before applying the anchor weights.
    per_anchor_loss = tf.reshape(cross_entropy, tf.shape(weights))
    return per_anchor_loss * weights


class BootstrappedSigmoidClassificationLoss(Loss):
"""Bootstrapped sigmoid cross entropy classification loss function.

Expand Down
Loading