Add SavedModel export to Resnet #3759
File 1: the shared ResNet run loop (resnet_model_fn / resnet_main / ResnetArgParser).
@@ -30,6 +30,7 @@
 
 from official.resnet import resnet_model
 from official.utils.arg_parsers import parsers
+from official.utils.export import export
 from official.utils.logging import hooks_helper
 from official.utils.logging import logger
 
@@ -219,7 +220,13 @@ def resnet_model_fn(features, labels, mode, model_class,
   }
 
   if mode == tf.estimator.ModeKeys.PREDICT:
-    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
+    # Return the predictions and the specification for serving a SavedModel
+    return tf.estimator.EstimatorSpec(
+        mode=mode,
+        predictions=predictions,
+        export_outputs={
+            'predict': tf.estimator.export.PredictOutput(predictions)
+        })
 
   # Calculate loss, which includes softmax cross entropy and L2 regularization.
   cross_entropy = tf.losses.softmax_cross_entropy(
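For context (not part of the diff): the 'predict' key above becomes the name of a SignatureDef in the exported SavedModel. A minimal sketch of checking that, assuming a model has already been exported to a hypothetical directory:

import tensorflow as tf

# Hypothetical path; export_savedmodel writes a timestamped subdirectory.
export_dir = '/tmp/resnet_export/1520000000'

with tf.Session(graph=tf.Graph()) as sess:
  meta_graph = tf.saved_model.loader.load(
      sess, [tf.saved_model.tag_constants.SERVING], export_dir)
  # Should include the 'predict' signature declared via export_outputs.
  print(list(meta_graph.signature_def.keys()))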
@@ -310,8 +317,20 @@ def validate_batch_size_for_multi_gpu(batch_size):
     raise ValueError(err)
 
 
-def resnet_main(flags, model_function, input_function):
-  """Shared main loop for ResNet Models."""
+def resnet_main(flags, model_function, input_function, shape=None):
+  """Shared main loop for ResNet Models.
+
+  Args:
+    flags: FLAGS object that contains the params for running. See
+      ResnetArgParser for created flags.
+    model_function: the function that instantiates the Model and builds the
+      ops for train/eval. This will be passed directly into the estimator.
+    input_function: the function that processes the dataset and returns a
+      dataset that the estimator can train on. This will be wrapped with
+      all the relevant flags for running and passed to estimator.
+    shape: list of ints representing the shape of the images used for training.
+      This is only used if flags.export_dir is passed.
+  """
 
   # Using the Winograd non-fused algorithms provides a small performance boost.
   os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
@@ -389,16 +408,34 @@ def input_fn_eval():
   if benchmark_logger:
     benchmark_logger.log_estimator_evaluation_result(eval_results)
 
+  if flags.export_dir is not None:
+    warn_on_multi_gpu_export(flags.multi_gpu)
+
+    # Exports a saved model for the given classifier.
+    input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
+        shape, batch_size=flags.batch_size)
+    classifier.export_savedmodel(flags.export_dir, input_receiver_fn)
+
+
+def warn_on_multi_gpu_export(multi_gpu=False):
+  """For the time being, multi-GPU mode does not play nicely with exporting."""
+  if multi_gpu:
+    tf.logging.warning(
+        'You are exporting a SavedModel while in multi-GPU mode. Note that '
+        'the resulting SavedModel will require the same GPUs be available. '
+        'If you wish to serve the SavedModel from a different device, '
+        'try exporting the SavedModel with multi-GPU mode turned off.')
Review discussion:

FYI @isaprykin - For now, warning the user. Eventually, it would be nice if Estimator would know not to run the saved model through the replication part of the graph.

Checkpoint loading seems flexible enough that you can load between single and multi gpu models. So in the case of multi gpu models it might make sense to construct a new estimator, load in the trained weights, and then serialize that. It's perhaps not the most elegant solution, but training with multi_gpu and then serving on a per-GPU basis seems like the most common use case, so supporting it is somewhat important.

For now, I'm going to leave as is. DistributionStrategies is a moving target that aims to hide replicate_model_fn, so we can reevaluate in a ~month when that is firmed up.

That sounds perfectly reasonable.
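A rough sketch (not part of this PR) of the re-export idea described above: build a second Estimator over the same model_dir with multi_gpu turned off, so the serving graph contains no replication ops, then export from that. The params handling is an assumption; model_function, flags, params, and shape stand for the objects resnet_main already has in scope.

# Hypothetical sketch; assumes the multi-GPU checkpoint restores cleanly into
# a single-GPU graph, as suggested in the review comment above.
serving_params = dict(params, multi_gpu=False)  # params: whatever resnet_main built
serving_classifier = tf.estimator.Estimator(
    model_fn=model_function, model_dir=flags.model_dir, params=serving_params)

input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
    shape, batch_size=flags.batch_size)
# export_savedmodel restores the latest checkpoint from model_dir into this
# non-replicated graph and serializes it to flags.export_dir.
serving_classifier.export_savedmodel(flags.export_dir, input_receiver_fn)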
 
 
 class ResnetArgParser(argparse.ArgumentParser):
-  """Arguments for configuring and running a Resnet Model.
-  """
+  """Arguments for configuring and running a Resnet Model."""
 
   def __init__(self, resnet_size_choices=None):
     super(ResnetArgParser, self).__init__(parents=[
         parsers.BaseParser(),
         parsers.PerformanceParser(),
         parsers.ImageModelParser(),
+        parsers.ExportParser(),
         parsers.BenchmarkParser(),
     ])
File 2 (new): export utilities, imported as official.utils.export.export.

@@ -0,0 +1,49 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convenience functions for exporting models as SavedModels or other types."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def build_tensor_serving_input_receiver_fn(shape, dtype=tf.float32,
                                           batch_size=1):
  """Returns an input_receiver_fn that can be used during serving.

  This expects examples to come through as float tensors, and simply
  wraps them as TensorServingInputReceivers.

  Arguably, this should live in tf.estimator.export. Testing here first.

  Args:
    shape: list representing target size of a single example.
    dtype: the expected datatype for the input example
    batch_size: number of input tensors that will be passed for prediction

  Returns:
    A function that itself returns a TensorServingInputReceiver.
  """
  def serving_input_receiver_fn():
    # Prep a placeholder where the input example will be fed in
    features = tf.placeholder(
        dtype=dtype, shape=[batch_size] + shape, name='input_tensor')

    return tf.estimator.export.TensorServingInputReceiver(
        features=features, receiver_tensors=features)

  return serving_input_receiver_fn
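To make the intended usage concrete, here is a self-contained sketch (not from this PR) that trains a trivial Estimator for one step and then exports it with build_tensor_serving_input_receiver_fn. The toy model, shapes, and export path are illustrative only.

import tensorflow as tf

from official.utils.export import export


def _toy_model_fn(features, labels, mode):
  """Stand-in for the ResNet model_fn: one dense layer over flat features."""
  logits = tf.layers.dense(tf.layers.flatten(features), units=3)
  predictions = {
      'classes': tf.argmax(logits, axis=1),
      'probabilities': tf.nn.softmax(logits),
  }
  if mode == tf.estimator.ModeKeys.PREDICT:
    # Mirrors the PREDICT branch added to resnet_model_fn in this PR.
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'predict': tf.estimator.export.PredictOutput(predictions)})
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
      loss, global_step=tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)


def _toy_input_fn():
  # Tiny synthetic dataset: 8 "images" of shape [4, 5] with integer labels.
  features = tf.random_uniform([8, 4, 5])
  labels = tf.random_uniform([8], maxval=3, dtype=tf.int64)
  dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(2)
  return dataset.make_one_shot_iterator().get_next()


classifier = tf.estimator.Estimator(model_fn=_toy_model_fn)
classifier.train(input_fn=_toy_input_fn, steps=1)

# Serving requests will be a single float tensor of shape [1, 4, 5].
input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
    shape=[4, 5], batch_size=1)
export_dir = classifier.export_savedmodel('/tmp/toy_export', input_receiver_fn)
print('SavedModel written to', export_dir)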
File 3 (new): tests for the export utilities.

@@ -0,0 +1,63 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for exporting utils."""
Review discussion:

Would it be much trouble to put tests for the individual models? (Probably in their respective files rather than this one.) Basically generate a trivial model with synthetic data and then confirm that it can be loaded and served.

Hahahaha. I laugh because, as I learned during this process, actually using a saved model is incredibly difficult without a dedicated TF Serving instance. I am hoping we can change that next quarter, but, for now, it would require quite a few tf.contrib calls. Punting on that until we formalize the story on model exports, hopefully next quarter.
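For reference, one contrib-based way to poke at an exported SavedModel without a TF Serving instance is tf.contrib.predictor. A hedged sketch (not from this PR; the export path and the 'input' feed key are assumptions):

import numpy as np
import tensorflow as tf

# Hypothetical path produced by export_savedmodel in the earlier sketch.
export_dir = '/tmp/toy_export/1520000000'

# 'predict' matches the export_outputs key defined in the model_fn.
predictor = tf.contrib.predictor.from_saved_model(
    export_dir, signature_def_key='predict')

# Inspect the input/output tensors the signature expects before calling it.
print(predictor.feed_tensors)
print(predictor.fetch_tensors)

# The feed key below ('input') is an assumption about how the receiver tensor
# is named; use the printed feed_tensors keys if it differs.
outputs = predictor({'input': np.zeros([1, 4, 5], dtype=np.float32)})
print(outputs)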

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.utils.export import export


class ExportUtilsTest(tf.test.TestCase):
  """Tests for the ExportUtils."""

  def test_build_tensor_serving_input_receiver_fn(self):
    receiver_fn = export.build_tensor_serving_input_receiver_fn(shape=[4, 5])
    with tf.Graph().as_default():
      receiver = receiver_fn()
      self.assertIsInstance(
          receiver, tf.estimator.export.TensorServingInputReceiver)

      self.assertIsInstance(receiver.features, tf.Tensor)
      self.assertEqual(receiver.features.shape, tf.TensorShape([1, 4, 5]))
      self.assertEqual(receiver.features.dtype, tf.float32)
      self.assertIsInstance(receiver.receiver_tensors, dict)
      # Note that Python 3 can no longer index .values() directly; cast to list.
      self.assertEqual(list(receiver.receiver_tensors.values())[0].shape,
                       tf.TensorShape([1, 4, 5]))

  def test_build_tensor_serving_input_receiver_fn_batch_dtype(self):
    receiver_fn = export.build_tensor_serving_input_receiver_fn(
        shape=[4, 5], dtype=tf.int8, batch_size=10)

    with tf.Graph().as_default():
      receiver = receiver_fn()
      self.assertIsInstance(
          receiver, tf.estimator.export.TensorServingInputReceiver)

      self.assertIsInstance(receiver.features, tf.Tensor)
      self.assertEqual(receiver.features.shape, tf.TensorShape([10, 4, 5]))
      self.assertEqual(receiver.features.dtype, tf.int8)
      self.assertIsInstance(receiver.receiver_tensors, dict)
      # Note that Python 3 can no longer index .values() directly; cast to list.
      self.assertEqual(list(receiver.receiver_tensors.values())[0].shape,
                       tf.TensorShape([10, 4, 5]))


if __name__ == "__main__":
  tf.test.main()
Review discussion:

I guess we should stay with double quotes for consistency.

This file uses single quotes. Perhaps sometime we should just choose which quote style to use for this repo.

Yeah, I would opt for double quotes just because docstrings are """, but we can fix in a separate PR for all files if desired. For now, will stick with what's there to avoid confusion in this PR.