Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions research/object_detection/box_coders/mean_stddev_box_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@
class MeanStddevBoxCoder(box_coder.BoxCoder):
"""Mean stddev box coder."""

def __init__(self, stddev=0.01):
"""Constructor for MeanStddevBoxCoder.

Args:
stddev: The standard deviation used to encode and decode boxes.
"""
self._stddev = stddev

@property
def code_size(self):
return 4
Expand All @@ -34,37 +42,38 @@ def _encode(self, boxes, anchors):

Args:
boxes: BoxList holding N boxes to be encoded.
anchors: BoxList of N anchors. We assume that anchors has an associated
stddev field.
anchors: BoxList of N anchors.

Returns:
a tensor representing N anchor-encoded boxes

Raises:
ValueError: if the anchors BoxList does not have a stddev field
ValueError: if the anchors still have deprecated stddev field.
"""
if not anchors.has_field('stddev'):
raise ValueError('anchors must have a stddev field')
box_corners = boxes.get()
if anchors.has_field('stddev'):
raise ValueError("'stddev' is a parameter of MeanStddevBoxCoder and "
"should not be specified in the box list.")
means = anchors.get()
stddev = anchors.get_field('stddev')
return (box_corners - means) / stddev
return (box_corners - means) / self._stddev

def _decode(self, rel_codes, anchors):
"""Decode.

Args:
rel_codes: a tensor representing N anchor-encoded boxes.
anchors: BoxList of anchors. We assume that anchors has an associated
stddev field.
anchors: BoxList of anchors.

Returns:
boxes: BoxList holding N bounding boxes

Raises:
ValueError: if the anchors BoxList does not have a stddev field
ValueError: if the anchors still have deprecated stddev field and expects
the decode method to use stddev value from that field.
"""
if not anchors.has_field('stddev'):
raise ValueError('anchors must have a stddev field')
means = anchors.get()
stddevs = anchors.get_field('stddev')
box_corners = rel_codes * stddevs + means
if anchors.has_field('stddev'):
raise ValueError("'stddev' is a parameter of MeanStddevBoxCoder and "
"should not be specified in the box list.")
box_corners = rel_codes * self._stddev + means
return box_list.BoxList(box_corners)
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,9 @@ def testGetCorrectRelativeCodesAfterEncoding(self):
boxes = box_list.BoxList(tf.constant(box_corners))
expected_rel_codes = [[0.0, 0.0, 0.0, 0.0], [-5.0, -5.0, -5.0, -3.0]]
prior_means = tf.constant([[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 1.0, 0.8]])
prior_stddevs = tf.constant(2 * [4 * [.1]])
priors = box_list.BoxList(prior_means)
priors.add_field('stddev', prior_stddevs)

coder = mean_stddev_box_coder.MeanStddevBoxCoder()
coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
rel_codes = coder.encode(boxes, priors)
with self.test_session() as sess:
rel_codes_out = sess.run(rel_codes)
Expand All @@ -42,11 +40,9 @@ def testGetCorrectBoxesAfterDecoding(self):
rel_codes = tf.constant([[0.0, 0.0, 0.0, 0.0], [-5.0, -5.0, -5.0, -3.0]])
expected_box_corners = [[0.0, 0.0, 0.5, 0.5], [0.0, 0.0, 0.5, 0.5]]
prior_means = tf.constant([[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 1.0, 0.8]])
prior_stddevs = tf.constant(2 * [4 * [.1]])
priors = box_list.BoxList(prior_means)
priors.add_field('stddev', prior_stddevs)

coder = mean_stddev_box_coder.MeanStddevBoxCoder()
coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
decoded_boxes = coder.decode(rel_codes, priors)
decoded_box_corners = decoded_boxes.get()
with self.test_session() as sess:
Expand Down
3 changes: 2 additions & 1 deletion research/object_detection/builders/box_coder_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def build(box_coder_config):
])
if (box_coder_config.WhichOneof('box_coder_oneof') ==
'mean_stddev_box_coder'):
return mean_stddev_box_coder.MeanStddevBoxCoder()
return mean_stddev_box_coder.MeanStddevBoxCoder(
stddev=box_coder_config.mean_stddev_box_coder.stddev)
if box_coder_config.WhichOneof('box_coder_oneof') == 'square_box_coder':
return square_box_coder.SquareBoxCoder(scale_factors=[
box_coder_config.square_box_coder.y_scale,
Expand Down
42 changes: 42 additions & 0 deletions research/object_detection/builders/graph_rewriter_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for quantized training and evaluation."""

import tensorflow as tf


def build(graph_rewriter_config, is_training):
"""Returns a function that modifies default graph based on options.

Args:
graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
is_training: whether in training of eval mode.
"""
def graph_rewrite_fn():
"""Function to quantize weights and activation of the default graph."""
if (graph_rewriter_config.quantization.weight_bits != 8 or
graph_rewriter_config.quantization.activation_bits != 8):
raise ValueError('Only 8bit quantization is supported')

# Quantize the graph by inserting quantize ops for weights and activations
if is_training:
tf.contrib.quantize.create_training_graph(
input_graph=tf.get_default_graph(),
quant_delay=graph_rewriter_config.quantization.delay)
else:
tf.contrib.quantize.create_eval_graph(input_graph=tf.get_default_graph())

tf.contrib.layers.summarize_collection('quant_vars')
return graph_rewrite_fn
57 changes: 57 additions & 0 deletions research/object_detection/builders/graph_rewriter_builder_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graph_rewriter_builder."""
import mock
import tensorflow as tf
from object_detection.builders import graph_rewriter_builder
from object_detection.protos import graph_rewriter_pb2


class QuantizationBuilderTest(tf.test.TestCase):

def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
with mock.patch.object(
tf.contrib.quantize, 'create_training_graph') as mock_quant_fn:
with mock.patch.object(tf.contrib.layers,
'summarize_collection') as mock_summarize_col:
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
graph_rewriter_proto.quantization.delay = 10
graph_rewriter_proto.quantization.weight_bits = 8
graph_rewriter_proto.quantization.activation_bits = 8
graph_rewrite_fn = graph_rewriter_builder.build(
graph_rewriter_proto, is_training=True)
graph_rewrite_fn()
_, kwargs = mock_quant_fn.call_args
self.assertEqual(kwargs['input_graph'], tf.get_default_graph())
self.assertEqual(kwargs['quant_delay'], 10)
mock_summarize_col.assert_called_with('quant_vars')

def testQuantizationBuilderSetsUpCorrectEvalArguments(self):
with mock.patch.object(tf.contrib.quantize,
'create_eval_graph') as mock_quant_fn:
with mock.patch.object(tf.contrib.layers,
'summarize_collection') as mock_summarize_col:
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
graph_rewriter_proto.quantization.delay = 10
graph_rewrite_fn = graph_rewriter_builder.build(
graph_rewriter_proto, is_training=False)
graph_rewrite_fn()
_, kwargs = mock_quant_fn.call_args
self.assertEqual(kwargs['input_graph'], tf.get_default_graph())
mock_summarize_col.assert_called_with('quant_vars')


if __name__ == '__main__':
tf.test.main()
17 changes: 14 additions & 3 deletions research/object_detection/builders/losses_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""A function to build localization and classification losses from config."""

from object_detection.core import balanced_positive_negative_sampler as sampler
from object_detection.core import losses
from object_detection.protos import losses_pb2

Expand All @@ -34,9 +35,12 @@ def build(loss_config):
classification_weight: Classification loss weight.
localization_weight: Localization loss weight.
hard_example_miner: Hard example miner object.
random_example_sampler: BalancedPositiveNegativeSampler object.

Raises:
ValueError: If hard_example_miner is used with sigmoid_focal_loss.
ValueError: If random_example_sampler is getting non-positive value as
desired positive example fraction.
"""
classification_loss = _build_classification_loss(
loss_config.classification_loss)
Expand All @@ -54,9 +58,16 @@ def build(loss_config):
loss_config.hard_example_miner,
classification_weight,
localization_weight)
return (classification_loss, localization_loss,
classification_weight,
localization_weight, hard_example_miner)
random_example_sampler = None
if loss_config.HasField('random_example_sampler'):
if loss_config.random_example_sampler.positive_sample_fraction <= 0:
raise ValueError('RandomExampleSampler should not use non-positive'
'value as positive sample fraction.')
random_example_sampler = sampler.BalancedPositiveNegativeSampler(
positive_fraction=loss_config.random_example_sampler.
positive_sample_fraction)
return (classification_loss, localization_loss, classification_weight,
localization_weight, hard_example_miner, random_example_sampler)


def build_hard_example_miner(config,
Expand Down
Loading