In [1]:
# This example successfully demonstrates the conversion of JAX functions to TensorFlow and subsequently to ONNX, 
# using TensorFlow's tf2onnx tool for ONNX export and ONNX Runtime for inference. 

from jax.experimental import jax2tf
from jax import numpy as jnp

import numpy as np
import tensorflow as tf

import onnx
import tf2onnx
import onnxruntime as ort


In [2]:
# Define JAX functions
def f_jax1(x):
    return jnp.sin(x)

def f_jax2(x):
    return jnp.cos(x)



In [3]:
# Convert JAX functions to TensorFlow functions
# enable_xla is set to False to ensure that XLA compilation is disabled for compatibility with tf2onnx
f_tf1 = jax2tf.convert(f_jax1, enable_xla=False)
f_tf2 = jax2tf.convert(f_jax2, enable_xla=False)

# Define a higher-level function that combines the individual functions
def combined_function(x):
    y1 = f_tf1(x)
    y2 = f_tf2(x)
    return y1 + y2

# Convert the combined function to a TensorFlow graph
combined_function_graph = tf.function(combined_function, autograph=False)

#print(combined_function_graph.get_concrete_function(tf.constant(1)).graph.as_graph_def())


In [4]:
# Export the module to ONNX
inference_onnx, _ = tf2onnx.convert.from_function(combined_function_graph, input_signature=[tf.TensorSpec([1, 3])])

onnx_model_path = "combined_model.onnx"
with open(onnx_model_path, "wb") as f:
    f.write(inference_onnx.SerializeToString())

# Load the ONNX model with ONNX Runtime
ort_session = ort.InferenceSession(onnx_model_path)

# Prepare the input for ONNX Runtime
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name

# Test the combined function with a NumPy array
sample_input = np.random.random((1, 3)).astype(np.float32)  # Example input

# Perform inference with ONNX Runtime
onnx_input = sample_input
onnx_output = ort_session.run([output_name], {input_name: onnx_input})[0]

# Display the results
print("Input:")
print(sample_input)
print("Output from TensorFlow:")
print(combined_function_graph(sample_input))
print("Output from ONNX Runtime:")
print(onnx_output)

Input:
[[0.8283861  0.26919276 0.6401985 ]]
Output from TensorFlow:
tf.Tensor([[1.4129071 1.2299392 1.3993318]], shape=(1, 3), dtype=float32)
Output from ONNX Runtime:
[[1.4129071 1.2299392 1.3993318]]


In [5]:
import netron
netron.start(onnx_model_path)

Serving 'combined_model.onnx' at http://localhost:8081


('localhost', 8081)

In [5]:
# In this case, a TF.Module is employed to organize the variables
class CombinedModule(tf.Module):
    def __init__(self):
        super().__init__()

        # Convert JAX functions to TensorFlow functions
        self.f_tf1 = jax2tf.convert(f_jax1, enable_xla=False)
        self.f_tf2 = jax2tf.convert(f_jax2, enable_xla=False)

    @tf.function(input_signature=[tf.TensorSpec([1, 3], dtype=tf.float32)])
    def combined_function(self, x):
        y1 = self.f_tf1(x)
        y2 = self.f_tf2(x)
        return y1 + y2


# Example usage:
# Create an instance of CombinedModule
combined_module = CombinedModule()

# Test the combined function with a NumPy array
result_tf = combined_module.combined_function(sample_input)

# Export the module to ONNX
inference_onnx, _ = tf2onnx.convert.from_function(
    combined_module.combined_function,
    input_signature=[tf.TensorSpec([1, 3], dtype=tf.float32)],
    opset=13
)

onnx_model_path = "combined_model2.onnx"
with open(onnx_model_path, "wb") as f:
    f.write(inference_onnx.SerializeToString())

# Load the ONNX model with ONNX Runtime
ort_session = ort.InferenceSession(onnx_model_path)

# Perform inference with ONNX Runtime
onnx_output = ort_session.run(None, {ort_session.get_inputs()[0].name: sample_input})[0]

# Display the results
print("Input:")
print(sample_input)
print("Output from TensorFlow:")
print(result_tf)
print("Output from ONNX Runtime:")
print(onnx_output)


Input:
[[0.5753632  0.6161297  0.90586555]]
Output from TensorFlow:
tf.Tensor([[1.3831344 1.394002  1.403964 ]], shape=(1, 3), dtype=float32)
Output from ONNX Runtime:
[[1.3831344 1.394002  1.403964 ]]


In [60]:
def f_jax3(a, lst):
    b = []
    print(lst)
    for l in lst:
        print(l)
        b.append(a[l])
    return b 

# Convert JAX functions to TensorFlow functions
# enable_xla is set to False to ensure that XLA compilation is disabled for compatibility with tf2onnx
f_tf3 = jax2tf.convert(f_jax3, enable_xla=False)

# Convert the combined function to a TensorFlow graph
function_graph = tf.function(f_tf3, autograph=False)

# Export the model to ONNX
input_signature = [
        [tf.TensorSpec(shape=(3), dtype=tf.float32), tf.TensorSpec(shape=(3), dtype=tf.float32)], 
        [], 
        #tf.TensorSpec([3], dtype=tf.int32),  
    ]

# Export the module to ONNX
inference_onnx, _ = tf2onnx.convert.from_function(function_graph, input_signature=input_signature)

onnx_model_path = "trial.onnx"
with open(onnx_model_path, "wb") as f:
    f.write(inference_onnx.SerializeToString())


from onnxsim import simplify

# Load your ONNX model
model = onnx.load(onnx_model_path)

# Simplify the model
model_simp, check = simplify(model)

# Ensure the simplified model is valid
assert check, "Simplified ONNX model could not be validated"

# Save the simplified model
simplified_model_path = "path_to_simplified_model.onnx"
onnx.save(model_simp, simplified_model_path)

print(f"Simplified model saved at: {simplified_model_path}")

import netron
netron.start(simplified_model_path)

[Traced<ShapedArray(int32[])>with<TensorFlowTrace(level=0/1)> with
  val = <tf.Tensor 'jax2tf_arg_2:0' shape=() dtype=int32>
  _aval = ShapedArray(int32[]), Traced<ShapedArray(int32[])>with<TensorFlowTrace(level=0/1)> with
  val = <tf.Tensor 'jax2tf_arg_3:0' shape=() dtype=int32>
  _aval = ShapedArray(int32[])]
Traced<ShapedArray(int32[])>with<TensorFlowTrace(level=0/1)> with
  val = <tf.Tensor 'jax2tf_arg_2:0' shape=() dtype=int32>
  _aval = ShapedArray(int32[])


TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

In [53]:

# Define the JAX function
def f_jax3(a, lst):
    return a[lst]

# Convert the JAX function to TensorFlow
f_tf3 = jax2tf.convert(f_jax3, enable_xla=False)

# Create a TensorFlow function graph
function_graph = tf.function(f_tf3, autograph=False)

# Define the input signature for TensorFlow
input_signature = [
    tf.TensorSpec(shape=(3,), dtype=tf.float32),  # Shape of 'a'
    tf.TensorSpec(shape=(3,), dtype=tf.int32)     # Shape of 'lst'
]

# Create the concrete function
inference_onnx, _ = tf2onnx.convert.from_function(function_graph, input_signature=input_signature)

# Save the ONNX model to a file
with open("f_jax3.onnx", "wb") as f:
    f.write(inference_onnx.SerializeToString())