Skip to content

Commit

Permalink
Script to convert TF2 SSD models to TFLite, with documentation
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 329938888
  • Loading branch information
srjoglekar246 authored and TF Object Detection Team committed Sep 3, 2020
1 parent f2c76e4 commit 575684c
Show file tree
Hide file tree
Showing 7 changed files with 761 additions and 0 deletions.
7 changes: 7 additions & 0 deletions research/object_detection/README.md
Expand Up @@ -73,6 +73,13 @@ documentation of the Object Detection API:

## Whats New

### Mobile Inference for TF2 models

TF2 OD API models can now be converted to TensorFlow Lite! Only SSD models
currently supported. See <a href='running_on_mobile_tf2.md'>documentation</a>.

**Thanks to contributors**: Sachin Joglekar

### TensorFlow 2 Support

We are happy to announce that the TF OD API officially supports TF2! Our release
Expand Down
254 changes: 254 additions & 0 deletions research/object_detection/export_tflite_graph_lib_tf2.py
@@ -0,0 +1,254 @@
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library to export TFLite-compatible SavedModel from TF2 detection models."""
import os
import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf

from object_detection.builders import model_builder
from object_detection.builders import post_processing_builder
from object_detection.core import box_list

_DEFAULT_NUM_CHANNELS = 3
_DEFAULT_NUM_COORD_BOX = 4
_MAX_CLASSES_PER_DETECTION = 1
_DETECTION_POSTPROCESS_FUNC = 'TFLite_Detection_PostProcess'


def get_const_center_size_encoded_anchors(anchors):
"""Exports center-size encoded anchors as a constant tensor.
Args:
anchors: a float32 tensor of shape [num_anchors, 4] containing the anchor
boxes
Returns:
encoded_anchors: a float32 constant tensor of shape [num_anchors, 4]
containing the anchor boxes.
"""
anchor_boxlist = box_list.BoxList(anchors)
y, x, h, w = anchor_boxlist.get_center_coordinates_and_sizes()
num_anchors = y.get_shape().as_list()

with tf1.Session() as sess:
y_out, x_out, h_out, w_out = sess.run([y, x, h, w])
encoded_anchors = tf1.constant(
np.transpose(np.stack((y_out, x_out, h_out, w_out))),
dtype=tf1.float32,
shape=[num_anchors[0], _DEFAULT_NUM_COORD_BOX],
name='anchors')
return num_anchors[0], encoded_anchors


class SSDModule(tf.Module):
"""Inference Module for TFLite-friendly SSD models."""

def __init__(self, pipeline_config, detection_model, max_detections,
use_regular_nms):
"""Initialization.
Args:
pipeline_config: The original pipeline_pb2.TrainEvalPipelineConfig
detection_model: The detection model to use for inference.
max_detections: Max detections desired from the TFLite model.
use_regular_nms: If True, TFLite model uses the (slower) multi-class NMS.
"""
self._process_config(pipeline_config)
self._pipeline_config = pipeline_config
self._model = detection_model
self._max_detections = max_detections
self._use_regular_nms = use_regular_nms

def _process_config(self, pipeline_config):
self._num_classes = pipeline_config.model.ssd.num_classes
self._nms_score_threshold = pipeline_config.model.ssd.post_processing.batch_non_max_suppression.score_threshold
self._nms_iou_threshold = pipeline_config.model.ssd.post_processing.batch_non_max_suppression.iou_threshold
self._scale_values = {}
self._scale_values[
'y_scale'] = pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.y_scale
self._scale_values[
'x_scale'] = pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.x_scale
self._scale_values[
'h_scale'] = pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.height_scale
self._scale_values[
'w_scale'] = pipeline_config.model.ssd.box_coder.faster_rcnn_box_coder.width_scale

image_resizer_config = pipeline_config.model.ssd.image_resizer
image_resizer = image_resizer_config.WhichOneof('image_resizer_oneof')
self._num_channels = _DEFAULT_NUM_CHANNELS

if image_resizer == 'fixed_shape_resizer':
self._height = image_resizer_config.fixed_shape_resizer.height
self._width = image_resizer_config.fixed_shape_resizer.width
if image_resizer_config.fixed_shape_resizer.convert_to_grayscale:
self._num_channels = 1
else:
raise ValueError(
'Only fixed_shape_resizer'
'is supported with tflite. Found {}'.format(
image_resizer_config.WhichOneof('image_resizer_oneof')))

def input_shape(self):
"""Returns shape of TFLite model input."""
return [1, self._height, self._width, self._num_channels]

def postprocess_implements_signature(self):
"""Returns tf.implements signature for MLIR legalization of TFLite NMS."""
implements_signature = [
'name: "%s"' % _DETECTION_POSTPROCESS_FUNC,
'attr { key: "max_detections" value { i: %d } }' % self._max_detections,
'attr { key: "max_classes_per_detection" value { i: %d } }' %
_MAX_CLASSES_PER_DETECTION,
'attr { key: "use_regular_nms" value { b: %s } }' %
str(self._use_regular_nms).lower(),
'attr { key: "nms_score_threshold" value { f: %f } }' %
self._nms_score_threshold,
'attr { key: "nms_iou_threshold" value { f: %f } }' %
self._nms_iou_threshold,
'attr { key: "y_scale" value { f: %f } }' %
self._scale_values['y_scale'],
'attr { key: "x_scale" value { f: %f } }' %
self._scale_values['x_scale'],
'attr { key: "h_scale" value { f: %f } }' %
self._scale_values['h_scale'],
'attr { key: "w_scale" value { f: %f } }' %
self._scale_values['w_scale'],
'attr { key: "num_classes" value { i: %d } }' % self._num_classes
]
implements_signature = ' '.join(implements_signature)
return implements_signature

def _get_postprocess_fn(self, num_anchors, num_classes):
# There is no TF equivalent for TFLite's custom post-processing op.
# So we add an 'empty' composite function here, that is legalized to the
# custom op with MLIR.
@tf.function(
experimental_implements=self.postprocess_implements_signature())
# pylint: disable=g-unused-argument,unused-argument
def dummy_post_processing(box_encodings, class_predictions, anchors):
boxes = tf.constant(0.0, dtype=tf.float32, name='boxes')
scores = tf.constant(0.0, dtype=tf.float32, name='scores')
classes = tf.constant(0.0, dtype=tf.float32, name='classes')
num_detections = tf.constant(0.0, dtype=tf.float32, name='num_detections')
return boxes, scores, classes, num_detections

return dummy_post_processing

@tf.function
def inference_fn(self, image):
"""Encapsulates SSD inference for TFLite conversion.
NOTE: The Args & Returns sections below indicate the TFLite model signature,
and not what the TF graph does (since the latter does not include the custom
NMS op used by TFLite)
Args:
image: a float32 tensor of shape [num_anchors, 4] containing the anchor
boxes
Returns:
num_detections: a float32 scalar denoting number of total detections.
classes: a float32 tensor denoting class ID for each detection.
scores: a float32 tensor denoting score for each detection.
boxes: a float32 tensor denoting coordinates of each detected box.
"""
predicted_tensors = self._model.predict(image, true_image_shapes=None)
# The score conversion occurs before the post-processing custom op
_, score_conversion_fn = post_processing_builder.build(
self._pipeline_config.model.ssd.post_processing)
class_predictions = score_conversion_fn(
predicted_tensors['class_predictions_with_background'])

with tf.name_scope('raw_outputs'):
# 'raw_outputs/box_encodings': a float32 tensor of shape
# [1, num_anchors, 4] containing the encoded box predictions. Note that
# these are raw predictions and no Non-Max suppression is applied on
# them and no decode center size boxes is applied to them.
box_encodings = tf.identity(
predicted_tensors['box_encodings'], name='box_encodings')
# 'raw_outputs/class_predictions': a float32 tensor of shape
# [1, num_anchors, num_classes] containing the class scores for each
# anchor after applying score conversion.
class_predictions = tf.identity(
class_predictions, name='class_predictions')
# 'anchors': a float32 tensor of shape
# [4, num_anchors] containing the anchors as a constant node.
num_anchors, anchors = get_const_center_size_encoded_anchors(
predicted_tensors['anchors'])
anchors = tf.identity(anchors, name='anchors')

# tf.function@ seems to reverse order of inputs, so reverse them here.
return self._get_postprocess_fn(num_anchors,
self._num_classes)(box_encodings,
class_predictions,
anchors)[::-1]


def export_tflite_model(pipeline_config, trained_checkpoint_dir,
output_directory, max_detections, use_regular_nms):
"""Exports inference SavedModel for TFLite conversion.
NOTE: Only supports SSD meta-architectures for now, and the output model will
have static-shaped, single-batch input.
This function creates `output_directory` if it does not already exist,
which will hold the intermediate SavedModel that can be used with the TFLite
converter.
Args:
pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
trained_checkpoint_dir: Path to the trained checkpoint file.
output_directory: Path to write outputs.
max_detections: Max detections desired from the TFLite model.
use_regular_nms: If True, TFLite model uses the (slower) multi-class NMS.
Raises:
ValueError: if pipeline is invalid.
"""
output_saved_model_directory = os.path.join(output_directory, 'saved_model')

# Build the underlying model using pipeline config.
# TODO(b/162842801): Add support for other architectures.
if pipeline_config.model.WhichOneof('model') != 'ssd':
raise ValueError('Only ssd models are supported in tflite. '
'Found {} in config'.format(
pipeline_config.model.WhichOneof('model')))
detection_model = model_builder.build(
pipeline_config.model, is_training=False)

ckpt = tf.train.Checkpoint(model=detection_model)
manager = tf.train.CheckpointManager(
ckpt, trained_checkpoint_dir, max_to_keep=1)
status = ckpt.restore(manager.latest_checkpoint).expect_partial()

# The module helps build a TF SavedModel appropriate for TFLite conversion.
detection_module = SSDModule(pipeline_config, detection_model, max_detections,
use_regular_nms)

# Getting the concrete function traces the graph and forces variables to
# be constructed; only after this can we save the saved model.
status.assert_existing_objects_matched()
concrete_function = detection_module.inference_fn.get_concrete_function(
tf.TensorSpec(
shape=detection_module.input_shape(), dtype=tf.float32, name='input'))
status.assert_existing_objects_matched()

# Export SavedModel.
tf.saved_model.save(
detection_module,
output_saved_model_directory,
signatures=concrete_function)

0 comments on commit 575684c

Please sign in to comment.