Skip to content

Commit

Permalink
Merged commit includes the following changes:
Browse files Browse the repository at this point in the history
196161788  by Zhichao Lu:

    Add eval_on_train_steps parameter.

    Since the number of samples in train dataset is usually different to the number of samples in the eval dataset.

--
196151742  by Zhichao Lu:

    Add an optional random sampling process for SSD meta arch and update mean stddev coder to use default std dev when corresponding tensor is not added to boxlist field.

--
196148940  by Zhichao Lu:

    Release ssdlite mobilenet v2 coco trained model.

--
196058528  by Zhichao Lu:

    Apply FPN feature map generation before we add additional layers on top of resnet feature extractor.

--
195818367  by Zhichao Lu:

    Add support for exporting detection keypoints.

--
195745420  by Zhichao Lu:

    Introduce include_metrics_per_category option to Object Detection eval_config.

--
195734733  by Zhichao Lu:

    Rename SSDLite config to be more explicit.

--
195717383  by Zhichao Lu:

    Add quantized training to object_detection.

--
195683542  by Zhichao Lu:

    Fix documentation for the interaction of fine_tune_checkpoint_type and load_all_detection_checkpoint_vars interaction.

--
195668233  by Zhichao Lu:

    Using batch size from params dictionary if present.

--
195570173  by Zhichao Lu:

    A few fixes to get new estimator API eval to match legacy detection eval binary by (1) plumbing `is_crowd` annotations through to COCO evaluator, (2) setting the `sloppy` flag in tf.contrib.data.parallel_interleave based on whether shuffling is enabled, and (3) saving the original image instead of the resized original image, which allows for small/medium/large mAP metrics to be properly computed.

--
195316756  by Zhichao Lu:

    Internal change

--

PiperOrigin-RevId: 196161788
  • Loading branch information
pkulzc committed May 11, 2018
1 parent 6305421 commit 324d6dc
Show file tree
Hide file tree
Showing 42 changed files with 1,088 additions and 299 deletions.
37 changes: 23 additions & 14 deletions research/object_detection/box_coders/mean_stddev_box_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@
class MeanStddevBoxCoder(box_coder.BoxCoder):
"""Mean stddev box coder."""

def __init__(self, stddev=0.01):
"""Constructor for MeanStddevBoxCoder.
Args:
stddev: The standard deviation used to encode and decode boxes.
"""
self._stddev = stddev

@property
def code_size(self):
return 4
Expand All @@ -34,37 +42,38 @@ def _encode(self, boxes, anchors):
Args:
boxes: BoxList holding N boxes to be encoded.
anchors: BoxList of N anchors. We assume that anchors has an associated
stddev field.
anchors: BoxList of N anchors.
Returns:
a tensor representing N anchor-encoded boxes
Raises:
ValueError: if the anchors BoxList does not have a stddev field
ValueError: if the anchors still have deprecated stddev field.
"""
if not anchors.has_field('stddev'):
raise ValueError('anchors must have a stddev field')
box_corners = boxes.get()
if anchors.has_field('stddev'):
raise ValueError("'stddev' is a parameter of MeanStddevBoxCoder and "
"should not be specified in the box list.")
means = anchors.get()
stddev = anchors.get_field('stddev')
return (box_corners - means) / stddev
return (box_corners - means) / self._stddev

def _decode(self, rel_codes, anchors):
"""Decode.
Args:
rel_codes: a tensor representing N anchor-encoded boxes.
anchors: BoxList of anchors. We assume that anchors has an associated
stddev field.
anchors: BoxList of anchors.
Returns:
boxes: BoxList holding N bounding boxes
Raises:
ValueError: if the anchors BoxList does not have a stddev field
ValueError: if the anchors still have deprecated stddev field and expects
the decode method to use stddev value from that field.
"""
if not anchors.has_field('stddev'):
raise ValueError('anchors must have a stddev field')
means = anchors.get()
stddevs = anchors.get_field('stddev')
box_corners = rel_codes * stddevs + means
if anchors.has_field('stddev'):
raise ValueError("'stddev' is a parameter of MeanStddevBoxCoder and "
"should not be specified in the box list.")
box_corners = rel_codes * self._stddev + means
return box_list.BoxList(box_corners)
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,9 @@ def testGetCorrectRelativeCodesAfterEncoding(self):
boxes = box_list.BoxList(tf.constant(box_corners))
expected_rel_codes = [[0.0, 0.0, 0.0, 0.0], [-5.0, -5.0, -5.0, -3.0]]
prior_means = tf.constant([[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 1.0, 0.8]])
prior_stddevs = tf.constant(2 * [4 * [.1]])
priors = box_list.BoxList(prior_means)
priors.add_field('stddev', prior_stddevs)

coder = mean_stddev_box_coder.MeanStddevBoxCoder()
coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
rel_codes = coder.encode(boxes, priors)
with self.test_session() as sess:
rel_codes_out = sess.run(rel_codes)
Expand All @@ -42,11 +40,9 @@ def testGetCorrectBoxesAfterDecoding(self):
rel_codes = tf.constant([[0.0, 0.0, 0.0, 0.0], [-5.0, -5.0, -5.0, -3.0]])
expected_box_corners = [[0.0, 0.0, 0.5, 0.5], [0.0, 0.0, 0.5, 0.5]]
prior_means = tf.constant([[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 1.0, 0.8]])
prior_stddevs = tf.constant(2 * [4 * [.1]])
priors = box_list.BoxList(prior_means)
priors.add_field('stddev', prior_stddevs)

coder = mean_stddev_box_coder.MeanStddevBoxCoder()
coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
decoded_boxes = coder.decode(rel_codes, priors)
decoded_box_corners = decoded_boxes.get()
with self.test_session() as sess:
Expand Down
3 changes: 2 additions & 1 deletion research/object_detection/builders/box_coder_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def build(box_coder_config):
])
if (box_coder_config.WhichOneof('box_coder_oneof') ==
'mean_stddev_box_coder'):
return mean_stddev_box_coder.MeanStddevBoxCoder()
return mean_stddev_box_coder.MeanStddevBoxCoder(
stddev=box_coder_config.mean_stddev_box_coder.stddev)
if box_coder_config.WhichOneof('box_coder_oneof') == 'square_box_coder':
return square_box_coder.SquareBoxCoder(scale_factors=[
box_coder_config.square_box_coder.y_scale,
Expand Down
42 changes: 42 additions & 0 deletions research/object_detection/builders/graph_rewriter_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for quantized training and evaluation."""

import tensorflow as tf


def build(graph_rewriter_config, is_training):
"""Returns a function that modifies default graph based on options.
Args:
graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
is_training: whether in training of eval mode.
"""
def graph_rewrite_fn():
"""Function to quantize weights and activation of the default graph."""
if (graph_rewriter_config.quantization.weight_bits != 8 or
graph_rewriter_config.quantization.activation_bits != 8):
raise ValueError('Only 8bit quantization is supported')

# Quantize the graph by inserting quantize ops for weights and activations
if is_training:
tf.contrib.quantize.create_training_graph(
input_graph=tf.get_default_graph(),
quant_delay=graph_rewriter_config.quantization.delay)
else:
tf.contrib.quantize.create_eval_graph(input_graph=tf.get_default_graph())

tf.contrib.layers.summarize_collection('quant_vars')
return graph_rewrite_fn
57 changes: 57 additions & 0 deletions research/object_detection/builders/graph_rewriter_builder_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graph_rewriter_builder."""
import mock
import tensorflow as tf
from object_detection.builders import graph_rewriter_builder
from object_detection.protos import graph_rewriter_pb2


class QuantizationBuilderTest(tf.test.TestCase):

def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
with mock.patch.object(
tf.contrib.quantize, 'create_training_graph') as mock_quant_fn:
with mock.patch.object(tf.contrib.layers,
'summarize_collection') as mock_summarize_col:
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
graph_rewriter_proto.quantization.delay = 10
graph_rewriter_proto.quantization.weight_bits = 8
graph_rewriter_proto.quantization.activation_bits = 8
graph_rewrite_fn = graph_rewriter_builder.build(
graph_rewriter_proto, is_training=True)
graph_rewrite_fn()
_, kwargs = mock_quant_fn.call_args
self.assertEqual(kwargs['input_graph'], tf.get_default_graph())
self.assertEqual(kwargs['quant_delay'], 10)
mock_summarize_col.assert_called_with('quant_vars')

def testQuantizationBuilderSetsUpCorrectEvalArguments(self):
with mock.patch.object(tf.contrib.quantize,
'create_eval_graph') as mock_quant_fn:
with mock.patch.object(tf.contrib.layers,
'summarize_collection') as mock_summarize_col:
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
graph_rewriter_proto.quantization.delay = 10
graph_rewrite_fn = graph_rewriter_builder.build(
graph_rewriter_proto, is_training=False)
graph_rewrite_fn()
_, kwargs = mock_quant_fn.call_args
self.assertEqual(kwargs['input_graph'], tf.get_default_graph())
mock_summarize_col.assert_called_with('quant_vars')


if __name__ == '__main__':
tf.test.main()
17 changes: 14 additions & 3 deletions research/object_detection/builders/losses_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

"""A function to build localization and classification losses from config."""

from object_detection.core import balanced_positive_negative_sampler as sampler
from object_detection.core import losses
from object_detection.protos import losses_pb2

Expand All @@ -34,9 +35,12 @@ def build(loss_config):
classification_weight: Classification loss weight.
localization_weight: Localization loss weight.
hard_example_miner: Hard example miner object.
random_example_sampler: BalancedPositiveNegativeSampler object.
Raises:
ValueError: If hard_example_miner is used with sigmoid_focal_loss.
ValueError: If random_example_sampler is getting non-positive value as
desired positive example fraction.
"""
classification_loss = _build_classification_loss(
loss_config.classification_loss)
Expand All @@ -54,9 +58,16 @@ def build(loss_config):
loss_config.hard_example_miner,
classification_weight,
localization_weight)
return (classification_loss, localization_loss,
classification_weight,
localization_weight, hard_example_miner)
random_example_sampler = None
if loss_config.HasField('random_example_sampler'):
if loss_config.random_example_sampler.positive_sample_fraction <= 0:
raise ValueError('RandomExampleSampler should not use non-positive'
'value as positive sample fraction.')
random_example_sampler = sampler.BalancedPositiveNegativeSampler(
positive_fraction=loss_config.random_example_sampler.
positive_sample_fraction)
return (classification_loss, localization_loss, classification_weight,
localization_weight, hard_example_miner, random_example_sampler)


def build_hard_example_miner(config,
Expand Down
Loading

0 comments on commit 324d6dc

Please sign in to comment.