|
| 1 | +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | + |
| 16 | +r"""Tool to export an object detection model for inference. |
| 17 | +
|
| 18 | +Prepares an object detection tensorflow graph for inference using model |
| 19 | +configuration and a trained checkpoint. Outputs inference |
| 20 | +graph, associated checkpoint files, a frozen inference graph and a |
| 21 | +SavedModel (https://tensorflow.github.io/serving/serving_basic.html). |
| 22 | +
|
| 23 | +The inference graph contains one of three input nodes depending on the user |
| 24 | +specified option. |
| 25 | + * `image_tensor`: Accepts a uint8 4-D tensor of shape [None, None, None, 3] |
| 26 | + * `encoded_image_string_tensor`: Accepts a 1-D string tensor of shape [None] |
| 27 | + containing encoded PNG or JPEG images. Image resolutions are expected to be |
| 28 | + the same if more than 1 image is provided. |
| 29 | + * `tf_example`: Accepts a 1-D string tensor of shape [None] containing |
| 30 | + serialized TFExample protos. Image resolutions are expected to be the same |
| 31 | + if more than 1 image is provided. |
| 32 | +
|
| 33 | +and the following output nodes returned by the model.postprocess(..): |
| 34 | + * `num_detections`: Outputs float32 tensors of the form [batch] |
| 35 | + that specifies the number of valid boxes per image in the batch. |
| 36 | + * `detection_boxes`: Outputs float32 tensors of the form |
| 37 | + [batch, num_boxes, 4] containing detected boxes. |
| 38 | + * `detection_scores`: Outputs float32 tensors of the form |
| 39 | + [batch, num_boxes] containing class scores for the detections. |
| 40 | + * `detection_classes`: Outputs float32 tensors of the form |
| 41 | + [batch, num_boxes] containing classes for the detections. |
| 42 | + * `detection_masks`: Outputs float32 tensors of the form |
| 43 | + [batch, num_boxes, mask_height, mask_width] containing predicted instance |
| 44 | + masks for each box if its present in the dictionary of postprocessed |
| 45 | + tensors returned by the model. |
| 46 | +
|
| 47 | +Notes: |
| 48 | + * This tool uses `use_moving_averages` from eval_config to decide which |
| 49 | + weights to freeze. |
| 50 | +
|
| 51 | +Example Usage: |
| 52 | +-------------- |
| 53 | +python export_inference_graph \ |
| 54 | + --input_type image_tensor \ |
| 55 | + --pipeline_config_path path/to/ssd_inception_v2.config \ |
| 56 | + --trained_checkpoint_prefix path/to/model.ckpt \ |
| 57 | + --output_directory path/to/exported_model_directory |
| 58 | +
|
| 59 | +The expected output would be in the directory |
| 60 | +path/to/exported_model_directory (which is created if it does not exist) |
| 61 | +with contents: |
| 62 | + - inference_graph.pbtxt |
| 63 | + - model.ckpt.data-00000-of-00001 |
| 64 | + - model.ckpt.info |
| 65 | + - model.ckpt.meta |
| 66 | + - frozen_inference_graph.pb |
| 67 | + + saved_model (a directory) |
| 68 | +
|
| 69 | +Config overrides (see the `config_override` flag) are text protobufs |
| 70 | +(also of type pipeline_pb2.TrainEvalPipelineConfig) which are used to override |
| 71 | +certain fields in the provided pipeline_config_path. These are useful for |
| 72 | +making small changes to the inference graph that differ from the training or |
| 73 | +eval config. |
| 74 | +
|
| 75 | +Example Usage (in which we change the second stage post-processing score |
| 76 | +threshold to be 0.5): |
| 77 | +
|
| 78 | +python export_inference_graph \ |
| 79 | + --input_type image_tensor \ |
| 80 | + --pipeline_config_path path/to/ssd_inception_v2.config \ |
| 81 | + --trained_checkpoint_prefix path/to/model.ckpt \ |
| 82 | + --output_directory path/to/exported_model_directory \ |
| 83 | + --config_override " \ |
| 84 | + model{ \ |
| 85 | + faster_rcnn { \ |
| 86 | + second_stage_post_processing { \ |
| 87 | + batch_non_max_suppression { \ |
| 88 | + score_threshold: 0.5 \ |
| 89 | + } \ |
| 90 | + } \ |
| 91 | + } \ |
| 92 | + }" |
| 93 | +""" |
| 94 | +import tensorflow as tf |
| 95 | +from google.protobuf import text_format |
| 96 | +from object_detection import exporter |
| 97 | +from object_detection.protos import pipeline_pb2 |
| 98 | + |
| 99 | +slim = tf.contrib.slim |
| 100 | +flags = tf.app.flags |
| 101 | + |
| 102 | +flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can be ' |
| 103 | + 'one of [`image_tensor`, `encoded_image_string_tensor`, ' |
| 104 | + '`tf_example`]') |
| 105 | +flags.DEFINE_string('input_shape', None, |
| 106 | + 'If input_type is `image_tensor`, this can explicitly set ' |
| 107 | + 'the shape of this input tensor to a fixed size. The ' |
| 108 | + 'dimensions are to be provided as a comma-separated list ' |
| 109 | + 'of integers. A value of -1 can be used for unknown ' |
| 110 | + 'dimensions. If not specified, for an `image_tensor, the ' |
| 111 | + 'default shape will be partially specified as ' |
| 112 | + '`[None, None, None, 3]`.') |
| 113 | +flags.DEFINE_string('pipeline_config_path', None, |
| 114 | + 'Path to a pipeline_pb2.TrainEvalPipelineConfig config ' |
| 115 | + 'file.') |
| 116 | +flags.DEFINE_string('trained_checkpoint_prefix', None, |
| 117 | + 'Path to trained checkpoint, typically of the form ' |
| 118 | + 'path/to/model.ckpt') |
| 119 | +flags.DEFINE_string('output_directory', None, 'Path to write outputs.') |
| 120 | +flags.DEFINE_string('config_override', '', |
| 121 | + 'pipeline_pb2.TrainEvalPipelineConfig ' |
| 122 | + 'text proto to override pipeline_config_path.') |
| 123 | +flags.DEFINE_boolean('write_inference_graph', False, |
| 124 | + 'If true, writes inference graph to disk.') |
| 125 | +tf.app.flags.mark_flag_as_required('pipeline_config_path') |
| 126 | +tf.app.flags.mark_flag_as_required('trained_checkpoint_prefix') |
| 127 | +tf.app.flags.mark_flag_as_required('output_directory') |
| 128 | +FLAGS = flags.FLAGS |
| 129 | + |
| 130 | + |
| 131 | +def main(_): |
| 132 | + pipeline_config = pipeline_pb2.TrainEvalPipelineConfig() |
| 133 | + with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f: |
| 134 | + text_format.Merge(f.read(), pipeline_config) |
| 135 | + text_format.Merge(FLAGS.config_override, pipeline_config) |
| 136 | + if FLAGS.input_shape: |
| 137 | + input_shape = [ |
| 138 | + int(dim) if dim != '-1' else None |
| 139 | + for dim in FLAGS.input_shape.split(',') |
| 140 | + ] |
| 141 | + else: |
| 142 | + input_shape = None |
| 143 | + exporter.export_inference_graph( |
| 144 | + FLAGS.input_type, pipeline_config, FLAGS.trained_checkpoint_prefix, |
| 145 | + FLAGS.output_directory, input_shape=input_shape, |
| 146 | + write_inference_graph=FLAGS.write_inference_graph) |
| 147 | + |
| 148 | + |
| 149 | +if __name__ == '__main__': |
| 150 | + tf.app.run() |
0 commit comments