##### Copyright 2018 The TensorFlow Authors.


In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TensorFlow Lite Gesture Classification Example Conversion Script


This guide shows how to convert a model trained with TensorFlow.js to the TensorFlow Lite FlatBuffer format.

Run all steps in order. At the end, the `model.tflite` file will be downloaded.


<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/examples/blob/master/lite/examples/gesture_classification/ml/tensorflowjs_to_tflite_colab_notebook.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />
    Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/examples/blob/master/lite/examples/gesture_classification/ml/tensorflowjs_to_tflite_colab_notebook.ipynb">
    <img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />
    View source on GitHub</a>
  </td>
</table>

**Install Dependencies**

In [None]:
# NOTE(review): tensorflowjs is installed unpinned; consider pinning a version that is
# compatible with the TensorFlow version preinstalled in the runtime.
!pip install -q tensorflowjs

In [None]:
import traceback
import logging
import tensorflow as tf 
import os

from google.colab import files
from tensorflow import keras
from tensorflowjs.converters import load_keras_model

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# basicConfig() does nothing if the root logger already has handlers (which may be
# the case in a notebook runtime), so set this logger's level explicitly; otherwise
# the converter's INFO progress messages could be filtered out.
logger.setLevel(logging.INFO)

***Clean up any existing models if necessary***

In [None]:
# WARNING: deletes EVERY *.h5, *.tflite, *.json and *.bin file in the current working
# directory, not only files produced by this notebook (presumably so stale uploads and
# outputs from a previous run are not picked up).
!rm -rf *.h5 *.tflite *.json *.bin

**Upload your TensorFlow.js Artifacts Here**

That is, the model topology and weights manifest **model.json** and the binary weights file **model-weights.bin**.

In [None]:
# files.upload() returns a {filename: file contents} dict. Keep it in a variable rather
# than leaving it as the cell's last expression, which would echo the raw weight bytes
# into the notebook output.
uploaded_files = files.upload()
for uploaded_name, uploaded_bytes in uploaded_files.items():
    print("Uploaded {} ({} bytes)".format(uploaded_name, len(uploaded_bytes)))

**Export Configuration**

In [None]:
#@title Export Configuration

# TensorFlow.js arguments

# Path of the uploaded TensorFlow.js model.json (model topology + weights manifest).
config_json = "model.json" #@param {type:"string"}
# Directory holding the weight binaries. None is passed straight through to
# load_keras_model (presumably it then looks next to model.json -- TODO confirm).
weights_path_prefix = None #@param {type:"raw"}
# Output filename for the converted TFLite FlatBuffer; also the file downloaded at the end.
model_tflite = "model.tflite" #@param {type:"string"}


**Model Converter**

The following class merges the uploaded TensorFlow.js model with a MobileNet base model and converts the result to a TFLite FlatBuffer.

In [None]:
class ModelConverter:
    """
    Converts a TensorFlow.js classification head into a TFLite FlatBuffer.

    The TensorFlow.js model is loaded as a Keras model and applied to the output of a
    Keras MobileNet (without its top) at layer ``depthwise_conv_layer``. The merged model
    is then converted with ``tf.lite.TFLiteConverter`` and written to disk.

    Args:
      :param config_json_path: Full filepath of weights manifest file containing the model architecture.
      :param weights_path_prefix: Directory in which the weights binaries exist. Passed unchanged
        to ``load_keras_model`` (``None`` presumably means "next to the config file" -- TODO confirm).
      :param tflite_model_file: Name of the TFLite FlatBuffer file to be exported.
      :param image_size: Height and width of the square RGB input image (default 224).
      :param alpha: MobileNet width multiplier (default 0.25).
      :param depth_multiplier: MobileNet depth multiplier (default 1).
      :param input_node_name: Name given to the Keras input layer (default 'the_input').
      :param depthwise_conv_layer: Name of the MobileNet layer whose output is fed to the
        TensorFlow.js model (default 'conv_pw_13_relu').
    """

    def __init__(self,
                 config_json_path,
                 weights_path_prefix,
                 tflite_model_file,
                 image_size=224,
                 alpha=0.25,
                 depth_multiplier=1,
                 input_node_name='the_input',
                 depthwise_conv_layer='conv_pw_13_relu',
                 ):
        self.config_json_path = config_json_path
        self.weights_path_prefix = weights_path_prefix
        self.tflite_model_file = tflite_model_file
        # Kept for backward compatibility; no method in this class writes this file.
        self.keras_model_file = 'merged.h5'

        # MobileNet options. The output of `depthwise_conv_layer` must be compatible with
        # the input the TensorFlow.js model expects, so these must match how it was trained.
        self.input_node_name = input_node_name
        self.image_size = image_size
        self.alpha = alpha
        self.depth_multiplier = depth_multiplier
        # Batch of 1, square image, 3 channels.
        self._input_shape = (1, self.image_size, self.image_size, 3)
        self.depthwise_conv_layer = depthwise_conv_layer

    def convert(self):
        """
        Builds the merged Keras model and writes it out as a TFLite FlatBuffer.

        :return: Path of the written FlatBuffer (``self.tflite_model_file``).
        """
        self.save_keras_model()
        self._deserialize_tflite_from_keras()
        logger.info('The TFLite model has been generated')
        return self.tflite_model_file

    def save_keras_model(self):
        """
        Loads the TensorFlow.js model and merges it with the MobileNet base model.

        Despite the name, nothing is written to disk: the merged model is kept in
        memory as ``self._merged_model``.
        """
        top_model = load_keras_model(self.config_json_path, self.weights_path_prefix,
                                     weights_data_buffers=None,
                                     load_weights=True,
                                     )
        base_model = self.get_base_model()
        self._merged_model = self.merge(base_model, top_model)

        logger.info("The merged Keras model has been generated.")

    def merge(self, base_model, top_model):
        """
        Merges the base model with the classification block.

        :param base_model: Keras MobileNet returned by ``get_base_model``.
        :param top_model: Keras model loaded from the TensorFlow.js artifacts.
        :return: A Keras model from the base model's input to ``top_model`` applied to
          the output of layer ``self.depthwise_conv_layer``.
        """
        logger.info("Initializing model...")

        layer = base_model.get_layer(self.depthwise_conv_layer)
        model = keras.Model(inputs=base_model.input, outputs=top_model(layer.output))
        logger.info("Model created.")

        return model

    def get_base_model(self):
        """
        Builds a MobileNet without its classification top, using this converter's
        input size, ``alpha`` and ``depth_multiplier``.

        ``weights`` is not passed, so Keras's default for ``keras.applications.MobileNet``
        applies.

        :return: The base MobileNet model.
        """
        input_tensor = keras.Input(shape=self._input_shape[1:], name=self.input_node_name)
        base_model = keras.applications.MobileNet(input_shape=self._input_shape[1:],
                                                  alpha=self.alpha,
                                                  depth_multiplier=self.depth_multiplier,
                                                  input_tensor=input_tensor,
                                                  include_top=False)
        return base_model

    def _deserialize_tflite_from_keras(self):
        """
        Converts ``self._merged_model`` to a TFLite FlatBuffer and writes it to
        ``self.tflite_model_file``.

        The name is historical: this serializes to TFLite, it does not deserialize.
        ``save_keras_model`` must have been called first.
        """
        converter = tf.lite.TFLiteConverter.from_keras_model(self._merged_model)
        tflite_model = converter.convert()

        with open(self.tflite_model_file, "wb") as tflite_file:
            tflite_file.write(tflite_model)

In [None]:
try:
    converter = ModelConverter(config_json,
                               weights_path_prefix,
                               model_tflite)

    converter.convert()

except ValueError:
    # Re-raise so "Run all" stops here with the real cause (IPython prints the
    # traceback), instead of continuing to the download cell with no model file.
    print("Error occurred while converting.")
    raise

In [None]:
# Fail with an actionable message if the conversion cell did not produce the file.
if not os.path.isfile(model_tflite):
    raise FileNotFoundError(
        "{} was not found; re-run the conversion cell and check its output.".format(model_tflite))
files.download(model_tflite)