-
Notifications
You must be signed in to change notification settings - Fork 45.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merged commit includes the following changes:
196161788 by Zhichao Lu: Add eval_on_train_steps parameter. Since the number of samples in train dataset is usually different to the number of samples in the eval dataset. -- 196151742 by Zhichao Lu: Add an optional random sampling process for SSD meta arch and update mean stddev coder to use default std dev when corresponding tensor is not added to boxlist field. -- 196148940 by Zhichao Lu: Release ssdlite mobilenet v2 coco trained model. -- 196058528 by Zhichao Lu: Apply FPN feature map generation before we add additional layers on top of resnet feature extractor. -- 195818367 by Zhichao Lu: Add support for exporting detection keypoints. -- 195745420 by Zhichao Lu: Introduce include_metrics_per_category option to Object Detection eval_config. -- 195734733 by Zhichao Lu: Rename SSDLite config to be more explicit. -- 195717383 by Zhichao Lu: Add quantized training to object_detection. -- 195683542 by Zhichao Lu: Fix documentation for the interaction of fine_tune_checkpoint_type and load_all_detection_checkpoint_vars interaction. -- 195668233 by Zhichao Lu: Using batch size from params dictionary if present. -- 195570173 by Zhichao Lu: A few fixes to get new estimator API eval to match legacy detection eval binary by (1) plumbing `is_crowd` annotations through to COCO evaluator, (2) setting the `sloppy` flag in tf.contrib.data.parallel_interleave based on whether shuffling is enabled, and (3) saving the original image instead of the resized original image, which allows for small/medium/large mAP metrics to be properly computed. -- 195316756 by Zhichao Lu: Internal change -- PiperOrigin-RevId: 196161788
- Loading branch information
Showing
42 changed files
with
1,088 additions
and
299 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
42 changes: 42 additions & 0 deletions
42
research/object_detection/builders/graph_rewriter_builder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Functions for quantized training and evaluation.""" | ||
|
||
import tensorflow as tf | ||
|
||
|
||
def build(graph_rewriter_config, is_training): | ||
"""Returns a function that modifies default graph based on options. | ||
Args: | ||
graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto. | ||
is_training: whether in training of eval mode. | ||
""" | ||
def graph_rewrite_fn(): | ||
"""Function to quantize weights and activation of the default graph.""" | ||
if (graph_rewriter_config.quantization.weight_bits != 8 or | ||
graph_rewriter_config.quantization.activation_bits != 8): | ||
raise ValueError('Only 8bit quantization is supported') | ||
|
||
# Quantize the graph by inserting quantize ops for weights and activations | ||
if is_training: | ||
tf.contrib.quantize.create_training_graph( | ||
input_graph=tf.get_default_graph(), | ||
quant_delay=graph_rewriter_config.quantization.delay) | ||
else: | ||
tf.contrib.quantize.create_eval_graph(input_graph=tf.get_default_graph()) | ||
|
||
tf.contrib.layers.summarize_collection('quant_vars') | ||
return graph_rewrite_fn |
57 changes: 57 additions & 0 deletions
57
research/object_detection/builders/graph_rewriter_builder_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for graph_rewriter_builder.""" | ||
import mock | ||
import tensorflow as tf | ||
from object_detection.builders import graph_rewriter_builder | ||
from object_detection.protos import graph_rewriter_pb2 | ||
|
||
|
||
class QuantizationBuilderTest(tf.test.TestCase): | ||
|
||
def testQuantizationBuilderSetsUpCorrectTrainArguments(self): | ||
with mock.patch.object( | ||
tf.contrib.quantize, 'create_training_graph') as mock_quant_fn: | ||
with mock.patch.object(tf.contrib.layers, | ||
'summarize_collection') as mock_summarize_col: | ||
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter() | ||
graph_rewriter_proto.quantization.delay = 10 | ||
graph_rewriter_proto.quantization.weight_bits = 8 | ||
graph_rewriter_proto.quantization.activation_bits = 8 | ||
graph_rewrite_fn = graph_rewriter_builder.build( | ||
graph_rewriter_proto, is_training=True) | ||
graph_rewrite_fn() | ||
_, kwargs = mock_quant_fn.call_args | ||
self.assertEqual(kwargs['input_graph'], tf.get_default_graph()) | ||
self.assertEqual(kwargs['quant_delay'], 10) | ||
mock_summarize_col.assert_called_with('quant_vars') | ||
|
||
def testQuantizationBuilderSetsUpCorrectEvalArguments(self): | ||
with mock.patch.object(tf.contrib.quantize, | ||
'create_eval_graph') as mock_quant_fn: | ||
with mock.patch.object(tf.contrib.layers, | ||
'summarize_collection') as mock_summarize_col: | ||
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter() | ||
graph_rewriter_proto.quantization.delay = 10 | ||
graph_rewrite_fn = graph_rewriter_builder.build( | ||
graph_rewriter_proto, is_training=False) | ||
graph_rewrite_fn() | ||
_, kwargs = mock_quant_fn.call_args | ||
self.assertEqual(kwargs['input_graph'], tf.get_default_graph()) | ||
mock_summarize_col.assert_called_with('quant_vars') | ||
|
||
|
||
if __name__ == '__main__': | ||
tf.test.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.