Skip to content

Commit

Permalink
changed exporter to save variables for Tensorflow Serving support, as…
Browse files Browse the repository at this point in the history
… per github issue tensorflow#1988, comment tensorflow#1988 (comment)
  • Loading branch information
pplior committed Sep 18, 2019
1 parent e2293a9 commit 7a561d5
Showing 1 changed file with 31 additions and 46 deletions.
77 changes: 31 additions & 46 deletions research/object_detection/exporter.py
Expand Up @@ -263,54 +263,39 @@ def add_output_tensor_nodes(postprocessed_tensors,


def write_saved_model(saved_model_path,
frozen_graph_def,
trained_checkpoint_prefix,
inputs,
outputs):
"""Writes SavedModel to disk.
If checkpoint_path is not None bakes the weights into the graph thereby
eliminating the need of checkpoint files during inference. If the model
was trained with moving averages, setting use_moving_averages to true
restores the moving averages, otherwise the original set of variables
is restored.
Args:
saved_model_path: Path to write SavedModel.
frozen_graph_def: tf.GraphDef holding frozen graph.
inputs: The input placeholder tensor.
outputs: A tensor dictionary containing the outputs of a DetectionModel.
"""
with tf.Graph().as_default():
with tf.Session() as sess:

tf.import_graph_def(frozen_graph_def, name='')

builder = tf.saved_model.builder.SavedModelBuilder(saved_model_path)

tensor_info_inputs = {
'inputs': tf.saved_model.utils.build_tensor_info(inputs)}
tensor_info_outputs = {}
for k, v in outputs.items():
tensor_info_outputs[k] = tf.saved_model.utils.build_tensor_info(v)

detection_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs=tensor_info_inputs,
outputs=tensor_info_outputs,
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
))

builder.add_meta_graph_and_variables(
sess,
[tf.saved_model.tag_constants.SERVING],
signature_def_map={
tf.saved_model.signature_constants
.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
detection_signature,
},
)
builder.save()

saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, trained_checkpoint_prefix)

builder = tf.saved_model.builder.SavedModelBuilder(saved_model_path)

tensor_info_inputs = {
'inputs': tf.saved_model.utils.build_tensor_info(inputs)}
tensor_info_outputs = {}
for k, v in outputs.items():
tensor_info_outputs[k] = tf.saved_model.utils.build_tensor_info(v)

detection_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs=tensor_info_inputs,
outputs=tensor_info_outputs,
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
))

builder.add_meta_graph_and_variables(
sess,
[tf.saved_model.tag_constants.SERVING],
signature_def_map={
tf.saved_model.signature_constants
.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
detection_signature,
},
)
builder.save()

def write_graph_and_checkpoint(inference_graph_def,
model_path,
Expand Down Expand Up @@ -441,7 +426,7 @@ def _export_inference_graph(input_type,
clear_devices=True,
initializer_nodes='')

write_saved_model(saved_model_path, frozen_graph_def,
write_saved_model(saved_model_path, trained_checkpoint_prefix,
placeholder_tensor, outputs)


Expand Down

0 comments on commit 7a561d5

Please sign in to comment.