##### Copyright 2024 The TensorFlow Authors.

In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Converting Keras to TFLite (via the JAX backend)
==========

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/lite/examples/keras/keras_jax_backend_to_tfl"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/examples/keras/keras_jax_backend_to_tfl.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

In [None]:
import os

# Select the JAX backend for Keras. Keras picks its backend when it is first
# imported, so this cell must run before the `import keras` cell below.
os.environ.update(KERAS_BACKEND="jax")

## Setup

In [None]:
import keras
import tensorflow as tf
import numpy as np

## Get the test image data

In [None]:
import io

from PIL import Image
import requests

# Download the sample image. Fail loudly on network/HTTP errors instead of
# handing an error page (or hanging forever) to PIL.
url = "https://storage.googleapis.com/download.tensorflow.org/example_images/astrid_l_shaped.jpg"
response = requests.get(url, timeout=30)
response.raise_for_status()

# Decode from the fully-read body and force 3-channel RGB so the array below has
# shape (224, 224, 3) even if the source were grayscale or RGBA; then resize to
# ResNet50's 224x224 input resolution.
image = Image.open(io.BytesIO(response.content)).convert("RGB")
image = image.resize((224, 224))
input_image = np.array(image)
# Add a leading batch dimension: (224, 224, 3) -> (1, 224, 224, 3).
input_image = np.expand_dims(input_image, axis=0)

## Instantiate a ResNet50 model from Keras Applications

In [None]:
# ResNet50 with its classification head (include_top=True) and weights
# pre-trained on ImageNet; it runs on the JAX backend selected above.
jax_model = keras.applications.resnet.ResNet50(
    weights="imagenet",
    include_top=True,
)

## Run the Keras model (JAX backend) on the test input

In [None]:
# Apply the ResNet-specific input preprocessing, then run a forward pass of the
# Keras model (JAX backend) on the single-image batch.
input_data = keras.applications.resnet50.preprocess_input(input_image)
jax_model_output = jax_model(input_data)

# decode_predictions yields, for each image, a list of (class_id, class_name,
# score) tuples; keep the top-1 entry for the only image in the batch.
top1_per_image = keras.applications.resnet.decode_predictions(jax_model_output, top=1)
decoded_preds = top1_per_image[0][0]
print("Predicted class:", decoded_preds[1])

## Export the Keras model as a TensorFlow SavedModel

In [None]:
# Export the Keras model as a TensorFlow SavedModel directory, which is the
# input format the TFLite converter reads in the next section.
saved_model_dir = "resnet50_saved_model"
jax_model.export(saved_model_dir, format="tf_saved_model")

## Convert to a TFLite model file

In [None]:
import pathlib

# Convert the exported SavedModel into a serialized TFLite flatbuffer (bytes).
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

# Persist the flatbuffer as a .tflite file so it can be deployed or reloaded
# later with tf.lite.Interpreter(model_path=...).
tflite_model_path = pathlib.Path("resnet50.tflite")
num_bytes = tflite_model_path.write_bytes(tflite_model)
print(f"Wrote {num_bytes:,} bytes to {tflite_model_path}")

## Run the converted model with the TFLite interpreter

In [None]:
# Load the converted flatbuffer from memory and allocate its tensors before use.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

# Feed the same preprocessed batch that was given to the JAX model, cast to the
# dtype the converted graph declares for its input. This code assumes a single
# input tensor (index 0).
input_details = interpreter.get_input_details()[0]
interpreter.set_tensor(
    input_details["index"], input_data.astype(input_details["dtype"])
)
interpreter.invoke()

output_details = interpreter.get_output_details()
output_data = interpreter.get_tensor(output_details[0]["index"])

# Despite its name, this holds the top-1 (class_id, class_name, score) tuple.
tfl_predicted_class_idx = keras.applications.resnet.decode_predictions(
    output_data, top=1
)[0][0]
print("Predicted class:", tfl_predicted_class_idx[1])

# Sanity-check the conversion: TFLite outputs should closely match the JAX ones.
max_abs_diff = np.max(np.abs(output_data - np.asarray(jax_model_output)))
print("Max |TFLite - JAX| output difference:", max_abs_diff)