Copyright 2024 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

# [OSS] JAX to TFLite with StableHLO Quantization Demonstration for ODML.

This example shows how to convert a JAX-backed Keras reference model (ResNet50) into a StableHLO module via `jax2tf`, export it as a TensorFlow SavedModel, and then quantize it in the ODML Converter using the StableHLO Quantizer.

Note: This API is experimental and will likely break with models other than the one shown here. If it does not work for your model, please reach out to [scalable-opt-team@google.com](mailto:scalable-opt-team@google.com) and we will help support your use case.

## StableHLO Quantizer



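StableHLO Quantizer is a quantization API that operates on StableHLO, a portable set of ML operations, to enable ML framework optionality and hardware retargetability: models from any framework that lowers to StableHLO can share the same quantization flow, and the quantized result can be retargeted to different hardware.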
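
To make the calibration step in the conversion below more concrete, here is a minimal sketch of what post-training static int8 quantization computes for one tensor. It assumes a simple per-tensor, asymmetric min/max scheme; the StableHLO Quantizer's actual scheme (for example per-channel weights or symmetric ranges) may differ, and the helpers `affine_params`, `quantize`, and `dequantize` are hypothetical names used only for illustration.

```python
import numpy as np

QMIN, QMAX = -128, 127  # signed int8 range


def affine_params(x_min, x_max, qmin=QMIN, qmax=QMAX):
    """Per-tensor asymmetric scale and zero point for an observed range."""
    # Widen the range to include 0.0 so that zero is exactly representable.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # degenerate all-zero tensor; any positive scale works
    zero_point = int(round(qmin - x_min / scale))
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


def quantize(x, scale, zero_point, qmin=QMIN, qmax=QMAX):
    """float -> int8: q = clip(round(x / scale) + zero_point)."""
    q = np.round(x / scale) + zero_point
    return np.clip(q, qmin, qmax).astype(np.int8)


def dequantize(q, scale, zero_point):
    """int8 -> float: x_hat = (q - zero_point) * scale."""
    return (q.astype(np.float32) - zero_point) * scale


# Round-trip a random float32 tensor through int8.
rng = np.random.default_rng(seed=0)
x = rng.uniform(low=-1.0, high=1.0, size=1000).astype(np.float32)
scale, zero_point = affine_params(float(x.min()), float(x.max()))
x_hat = dequantize(quantize(x, scale, zero_point), scale, zero_point)
max_error = float(np.abs(x - x_hat).max())
print(f"scale={scale:.5f} zero_point={zero_point} max_error={max_error:.5f}")
```

Because every value inside the calibrated range lands on the nearest grid point, the round-trip error stays within about half a quantization step (`scale / 2`); values outside the calibrated range are clipped, which is why representative calibration data matters.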

```python
# Remove the stable TensorFlow package so it does not conflict with tf-nightly below.
!pip uninstall tensorflow --yes
```

```python
# Install the nightly TensorFlow build used by this demo, plus Keras Core.
!pip3 install tf-nightly
!pip3 install keras-core
```

```python
# Confirm which TensorFlow build is being imported.
import tensorflow as tf
print("TensorFlow version:", tf.__version__)
```

```python
import os

# Select the JAX backend; this must be set before keras_core is imported.
os.environ['KERAS_BACKEND'] = 'jax'

import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf
from keras_core.applications import ResNet50
```

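```python
input_shape = (1, 224, 224, 3)

# Convert the forward pass of the Keras ResNet50 reference model (JAX backend)
# with jax2tf. native_serialization=True serializes the JAX computation as a
# StableHLO module, lowered here for the CPU platform.
# Note: Keras only applies `pooling` when include_top=False; with the default
# include_top=True, the 1000-class classification head is kept.
jax_callable = jax2tf.convert(
    ResNet50(
        input_shape=input_shape[1:],
        pooling='avg',
    ).call,
    with_gradient=False,
    native_serialization=True,
    native_serialization_platforms=('cpu',),
)

# Wrap the converted function in a tf.Module and export it as a SavedModel.
# The TensorSpec name becomes the input name of the exported signature.
tf_module = tf.Module()
tf_module.f = tf.function(
    jax_callable,
    autograph=False,
    input_signature=[
        tf.TensorSpec(input_shape, tf.float32, 'lhs_operand')
    ],
)

saved_model_dir = '/tmp/saved_model'
tf.saved_model.save(tf_module, saved_model_dir)


def calibration_dataset():
  """Yields calibration samples keyed by the SavedModel input name.

  Random inputs keep the demo self-contained; real, preprocessed images give
  more representative quantization ranges.
  """
  rng = np.random.default_rng(seed=1235)
  for _ in range(2):
    yield {
        'lhs_operand': rng.uniform(low=-1.0, high=1.0, size=input_shape).astype(
            np.float32
        )
    }


converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.SELECT_TF_OPS,  # Enable TensorFlow ops.
    tf.lite.OpsSet.TFLITE_BUILTINS,  # Enable TFLite builtin ops.
]
# A representative dataset together with Optimize.DEFAULT requests
# post-training quantization calibrated on the samples above.
converter.representative_dataset = calibration_dataset
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Use the StableHLO Quantizer instead of the default TFLite quantizer.
converter.experimental_use_stablehlo_quantizer = True

quantized_model = converter.convert()

with open('/tmp/resnet50_quantized.tflite', 'wb') as f:
  f.write(quantized_model)
```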
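
The converter calls `representative_dataset` to observe the value ranges it needs for quantization. The sketch below illustrates that idea for the model input only, assuming a simple running min/max observer; the real converter records ranges for intermediate tensors internally and may use a different calibration method. `observe_ranges` is a hypothetical helper, and the generator is copied from the cell above so the snippet runs on its own.

```python
import numpy as np

input_shape = (1, 224, 224, 3)


def calibration_dataset():
  # Same random generator as in the conversion cell, keyed by input name.
  rng = np.random.default_rng(seed=1235)
  for _ in range(2):
    yield {
        'lhs_operand': rng.uniform(low=-1.0, high=1.0, size=input_shape).astype(
            np.float32
        )
    }


def observe_ranges(dataset_fn):
  """Tracks the running (min, max) of each named input across all samples."""
  ranges = {}
  for sample in dataset_fn():
    for name, value in sample.items():
      lo, hi = float(value.min()), float(value.max())
      if name in ranges:
        old_lo, old_hi = ranges[name]
        lo, hi = min(lo, old_lo), max(hi, old_hi)
      ranges[name] = (lo, hi)
  return ranges


ranges = observe_ranges(calibration_dataset)
print(ranges)
```

With uniform random inputs the observed range is simply close to [-1, 1], which says little about real images; feeding a few hundred preprocessed images from the target domain generally yields ranges that better match deployment.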

```python
# Report the size of the quantized model file in MiB (2**20 bytes).
size_mib = os.path.getsize('/tmp/resnet50_quantized.tflite') / 2**20
print(f'{size_mib:.1f} MiB')
```
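
To put the printed file size in context, a back-of-the-envelope estimate of weight storage helps: storing a weight in 8 bits instead of 32 cuts its footprint by 4x. The sketch assumes roughly 25.6M parameters, which is approximately what Keras reports for ResNet50 with its default classification head; treat that count as an assumed input. `weight_storage_mib` is a hypothetical helper, and it ignores graph structure, metadata, and any tensors kept in float.

```python
def weight_storage_mib(num_params: int, bits_per_param: int) -> float:
  """Storage for the weights alone, in MiB (2**20 bytes)."""
  return num_params * bits_per_param / 8 / 2**20


# Assumed parameter count: about 25.6M, roughly Keras ResNet50 with its head.
num_params = 25_636_712
float32_mib = weight_storage_mib(num_params, 32)
int8_mib = weight_storage_mib(num_params, 8)
print(f'float32: {float32_mib:.1f} MiB, int8: {int8_mib:.1f} MiB')
```

The actual `.tflite` size will differ from this estimate, since ops left in float, biases, and the serialized graph also take space.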