Skip to content

Commit

Permalink
Merge 536df12 into ff12468
Browse files Browse the repository at this point in the history
  • Loading branch information
vanvalen committed Apr 20, 2020
2 parents ff12468 + 536df12 commit b08e14d
Showing 1 changed file with 74 additions and 2 deletions.
76 changes: 74 additions & 2 deletions deepcell/utils/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
from __future__ import division

import os

import numpy as np
import tensorflow as tf

from tensorflow.python.keras import backend as K
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import signature_constants
Expand All @@ -48,7 +49,8 @@ def export_model(keras_model, export_path, model_version=0, weights_path=None):
weights_path (str): path to a .h5 or .tf weights file
"""
# Start the tensorflow session
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8, allow_growth=False)
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.8,
allow_growth=False)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

K.set_session(sess)
Expand Down Expand Up @@ -98,3 +100,73 @@ def export_model(keras_model, export_path, model_version=0, weights_path=None):

# Save the graph
builder.save()


def export_model_to_tflite(model_file, export_path, calibration_images,
norm=True, location=True, file_name='model.tflite'):
"""Export a saved keras model to tensorflow-lite with int8 precision.
This export function has only been tested with PanopticNet models. For the
export to be successful, the PanopticNet model must have normalization set
to None, location set to False, and the upsampling layers must use bilinear
interpolation.
Args:
model_file (str): Path to saved keras model
export_path (str): Directory to save the exported tflite model
calibration_images (numpy array): Array of images used for calibration
during model quantization
norm (boolean): Whether to normalize calibration images.
Defaults to True.
location (boolean): Whether to append a location image to calibration
images. Defaults to True.
file_name (str): File name for the exported model. Defaults to
'model.tflite'
"""
# Define helper function - normalization
def norm_images(images):
mean = np.mean(images, axis=(1, 2), keepdims=True)
std = np.std(images, axis=(1, 2), keepdims=True)
norm = (images - mean) / std
return norm

# Define helper function - add location layer
def add_location(images):
x = np.arange(0, images.shape[1], dtype='float32')
y = np.arange(0, images.shape[2], dtype='float32')

x = x / max(x)
y = y / max(y)

loc_x, loc_y = np.meshgrid(x, y, indexing='ij')
loc = np.stack([loc_x, loc_y], axis=-1)
loc = np.expand_dims(loc, axis=0)
loc = np.tile(loc, (images.shape[0], 1, 1, 1))
images_with_loc = np.concatenate([images, loc], axis=-1)
return images_with_loc

# Create generator to calibrate model quantization
calibration_images = calibration_images.astype('float32')
if norm:
calibration_images = norm_images(calibration_images)
if location:
calibration_images = add_location(calibration_images)

def representative_data_gen():
for image in calibration_images:
data = [np.expand_dims(image, axis=0)]
yield data

converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(model_file)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
tflite_model = converter.convert()

# Save converted model
save_path = os.path.join(export_path, file_name)
open(save_path, "wb").write(tflite_quant_model)

return tflite_model

0 comments on commit b08e14d

Please sign in to comment.