# JAX Flax --> ONNX

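This notebook converts the Flax decoder MLP of a track-mjx (Brax PPO) policy checkpoint to an ONNX model. The decoder is rebuilt in TensorFlow/Keras, the checkpoint weights are copied into it, and the Keras model is exported with tf2onnx.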

In [2]:
import os

# Process-wide settings, written before `import jax` below:
# EGL as the MuJoCo / PyOpenGL backend, and XLA GPU preallocation switched off.
os.environ.update(
    {
        "MUJOCO_GL": "egl",
        "PYOPENGL_PLATFORM": "egl",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    }
)

import jax

# Persistent compilation cache stored under /tmp/jax_cache.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# Size / compile-time thresholds at their lowest values -- presumably so that
# every compiled computation is written to the cache.
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [3]:
from brax.training.agents.ppo import networks as ppo_networks
from mujoco_playground.config import locomotion_params, manipulation_params
from mujoco_playground import locomotion, manipulation
import functools
import pickle
import jax.numpy as jp
import jax
import tf2onnx
import tensorflow as tf
from tensorflow.keras import layers
import onnxruntime as rt
from brax.training.acme import running_statistics

2025-03-12 06:05:57.684347: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1741759557.700928   36020 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1741759557.706371   36020 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1741759557.719727   36020 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1741759557.719740   36020 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1741759557.719742   36020 computation_placer.cc:177] computation placer alr

In [16]:
from track_mjx.environment.walker import rodent
from track_mjx.agent import checkpointing
from track_mjx.agent.intention_network import Decoder
from track_mjx.analysis.rollout import create_environment
from track_mjx.environment.wrappers import EvalClipResetWrapper
from track_mjx.analysis.render import render_from_saved_rollout, display_video
from track_mjx.analysis.utils import save_to_h5py, load_from_h5py

import jax
from jax import numpy as jnp
from brax.training import distribution
from brax.training.acme import running_statistics
from pathlib import Path

# Replace with the path to your own checkpoint directory.
ckpt_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/250306_194809"
# Load the checkpoint; later cells read its "cfg" and "policy" entries.
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)

cfg = ckpt["cfg"]

# Override machine-specific entries of the stored config:
# replace with the absolute path to your data
# -- your notebook may not have access to the same relative path.
cfg.data_path = "/root/vast/scott-yang/track-mjx/data/transform_snips.h5"
cfg.train_setup.checkpoint_to_restore = ckpt_path

# NOTE: To benefit from jax.jit caching, run the following code only once.
# Re-running it will trigger recompilation.

env = create_environment(cfg)

# env_name = "BerkeleyHumanoidJoystickFlatTerrain"
# # ppo_params = locomotion_params.brax_ppo_config(env_name)
# ppo_params = locomotion_params.brax_ppo_config(env_name)

Loading checkpoint from /root/vast/scott-yang/track-mjx/model_checkpoints/250306_194809 at step 144
env._steps_for_cur_frame: 2.0


In [5]:
# Observation and action sizes as reported by the environment.
obs_size, act_size = env.observation_size, env.action_size
print(obs_size, act_size)

617 38


In [18]:
# Network hyper-parameters stored in the checkpoint config.
network_config = cfg["network_config"]

# Build the (parameter-free) Flax decoder module. Its final layer has
# 2 * action_size outputs -- presumably the two halves parameterize the
# NormalTanhDistribution created below (TODO confirm against the training code).
decoder = Decoder(
    network_config["decoder_layer_sizes"] + [network_config["action_size"] * 2]
)

normalizer = running_statistics.normalize
# First element of the stored policy: the observation-normalizer parameters.
normalizer_param = ckpt["policy"][0]

# Second element of the stored policy: network parameters; keep only the decoder subtree.
decoder_raw = ckpt["policy"][1]["params"]["decoder"]
# Wrap the subtree in the {"params": ...} layout that `decoder.apply` is given.
decoder_param = {"params": decoder_raw}
# Action distribution with one event dimension per action.
action_distribution = distribution.NormalTanhDistribution(
    event_size=network_config["action_size"]
)

# Jit once here so that later calls reuse the compiled functions.
jit_env_reset = jax.jit(env.reset, static_argnames=("clip_idx",))
jit_env_step = jax.jit(env.step)
jit_apply = jax.jit(decoder.apply)

In [33]:
from tensorflow.keras import layers


class DecoderTF(tf.keras.Model):
    """Keras MLP decoder: a stack of Dense layers, each non-final layer followed
    by an activation and a LayerNormalization.

    Layers are named ``hidden_{i}`` and ``LayerNorm_{i}`` so that individual
    layers can be looked up by name with ``get_layer``.
    """

    def __init__(
        self,
        layer_sizes,
        activation=tf.nn.relu,
        kernel_init="lecun_uniform",
        activate_final=False,
        bias=True,
        layer_norm_epsilon=1e-3,
    ):
        """
        Initializes the Decoder.

        Args:
            layer_sizes (Sequence[int]): List of layer sizes for each Dense layer.
            activation (callable): Activation function to apply.
            kernel_init (str or callable): Kernel initializer for Dense layers.
            activate_final (bool): Whether to apply activation (and layer norm) on the final layer.
            bias (bool): Whether the Dense layers should use a bias term.
            layer_norm_epsilon (float): Epsilon passed to every LayerNormalization layer.
                Defaults to 1e-3, the Keras default, which keeps the previous behavior.
                NOTE(review): Flax's ``nn.LayerNorm`` defaults to 1e-6; if the source
                decoder relies on that default, pass 1e-6 here so both networks
                normalize identically -- confirm against the Flax Decoder definition.
        """
        super().__init__()
        self.layer_sizes = layer_sizes
        self.activation = activation
        self.kernel_init = kernel_init
        self.activate_final = activate_final
        self.bias = bias
        self.layer_norm_epsilon = layer_norm_epsilon

        # Parallel lists: layer_norms[i] is the LayerNormalization applied after
        # dense_layers[i], or None when that layer is left un-activated.
        self.dense_layers = []
        self.layer_norms = []

        last_index = len(self.layer_sizes) - 1
        for i, size in enumerate(self.layer_sizes):
            self.dense_layers.append(
                layers.Dense(
                    size,
                    kernel_initializer=self.kernel_init,
                    use_bias=self.bias,
                    name=f"hidden_{i}",
                )
            )
            # Activation + layer norm apply to every layer except the final one,
            # unless activate_final is True.
            if i != last_index or self.activate_final:
                self.layer_norms.append(
                    layers.LayerNormalization(
                        epsilon=self.layer_norm_epsilon, name=f"LayerNorm_{i}"
                    )
                )
            else:
                self.layer_norms.append(None)

    def call(self, inputs, get_activation=False):
        """
        Forward pass through the network.

        Args:
            inputs (tf.Tensor): Input tensor.
            get_activation (bool): If True, also return a dict with activations per layer.

        Returns:
            tf.Tensor or Tuple[tf.Tensor, dict]: The output tensor, and optionally a dictionary
            mapping ``layer_{i}`` to the post-LayerNorm activation of each activated layer
            (an un-activated final layer has no entry).
        """
        activations = {}
        x = inputs
        for i, (dense_layer, layer_norm) in enumerate(
            zip(self.dense_layers, self.layer_norms)
        ):
            x = dense_layer(x)
            # A LayerNorm exists exactly for the layers that are activated
            # (see __init__), so its presence decides both steps.
            if layer_norm is not None:
                x = self.activation(x)
                x = layer_norm(x)
                if get_activation:
                    activations[f"layer_{i}"] = x
        if get_activation:
            return x, activations
        return x


def make_policy_network(
    param_size,
    hidden_layer_sizes=(512, 512, 512),
):
    """Build a ``DecoderTF`` with the given hidden sizes and a final layer of ``param_size`` units.

    Args:
        param_size (int): Width of the final (output) Dense layer.
        hidden_layer_sizes (Sequence[int]): Widths of the hidden Dense layers.
            The default is an immutable tuple so it cannot be shared and mutated
            between calls; any sequence (list, tuple, ...) is accepted.

    Returns:
        DecoderTF: The constructed, not yet built, Keras model.
    """
    return DecoderTF(layer_sizes=[*hidden_layer_sizes, param_size])

In [34]:
# Mirror the Flax decoder: take the hidden sizes from the checkpoint config
# (the same `decoder_layer_sizes` the Flax Decoder is built from) instead of
# relying on make_policy_network's default, and end with 2 * act_size outputs.
tf_policy_network = make_policy_network(
    param_size=act_size * 2,
    hidden_layer_sizes=network_config["decoder_layer_sizes"],
)

In [46]:
# Inspect the parameter names stored for the first LayerNorm in the checkpoint.
decoder_raw["LayerNorm_0"].keys()

dict_keys(['bias', 'scale'])

In [36]:
# Run one dummy forward pass; the layers report built=True afterwards.
# The input is one row of 60 + 147 = 207 features
# (NOTE(review): what the 60 and the 147 stand for is not shown here -- confirm).
example_input = tf.zeros((1, 60 + 147))
example_output = tf_policy_network(example_input)
print(example_output.shape)

(1, 76)


In [37]:
# List the Keras layers; their names are the keys used for the by-name weight transfer.
tf_policy_network.layers

[<Dense name=hidden_0, built=True>,
 <LayerNormalization name=LayerNorm_0, built=True>,
 <Dense name=hidden_1, built=True>,
 <LayerNormalization name=LayerNorm_1, built=True>,
 <Dense name=hidden_2, built=True>,
 <LayerNormalization name=LayerNorm_2, built=True>,
 <Dense name=hidden_3, built=True>]

In [47]:
import numpy as np
import tensorflow as tf


def transfer_weights(jax_params, tf_model):
    """
    Copy weights from a flat JAX/Flax parameter dictionary into a Keras model,
    matching layers by name.

    Parameters:
    - jax_params: dict
      Mapping {layer_name: layer_params}. For example:
      {
        'hidden_0': {'kernel': np.ndarray, 'bias': np.ndarray},
        'LayerNorm_0': {'scale': np.ndarray, 'bias': np.ndarray},
      }
      Dense entries need 'kernel', plus 'bias' when the Keras layer was created
      with use_bias=True. LayerNormalization entries need 'scale' and 'bias'.

    - tf_model: tf.keras.Model
      Model whose layers can be fetched with get_layer(name=...). Only Dense and
      LayerNormalization layers are handled.

    Prints one line per processed entry. The final line is
    "Weights transferred successfully." only when every JAX entry was copied and
    no weighted Keras layer was left untouched; otherwise the skipped entries
    and untouched layers are listed. Returns None.
    """
    transferred = set()
    skipped = []
    for layer_name, layer_params in jax_params.items():
        try:
            tf_layer = tf_model.get_layer(name=layer_name)
        except ValueError:
            print(f"Layer {layer_name} not found in TensorFlow model.")
            skipped.append(layer_name)
            continue
        if isinstance(tf_layer, tf.keras.layers.Dense):
            kernel = np.array(layer_params["kernel"])
            if tf_layer.use_bias:
                bias = np.array(layer_params["bias"])
                print(
                    f"Transferring Dense layer {layer_name}, kernel shape {kernel.shape}, bias shape {bias.shape}"
                )
                tf_layer.set_weights([kernel, bias])
            else:
                # A bias-less Dense layer holds a single weight: the kernel.
                print(
                    f"Transferring Dense layer {layer_name}, kernel shape {kernel.shape}, no bias"
                )
                tf_layer.set_weights([kernel])
        elif isinstance(tf_layer, tf.keras.layers.LayerNormalization):
            # Flax names the multiplicative term 'scale' and the additive term 'bias';
            # the Keras layer takes them in the order [gamma, beta].
            gamma = np.array(layer_params["scale"])
            beta = np.array(layer_params["bias"])
            print(
                f"Transferring LayerNorm layer {layer_name}, gamma shape {gamma.shape}, beta shape {beta.shape}"
            )
            tf_layer.set_weights([gamma, beta])
        else:
            print(f"Unhandled layer type in {layer_name}: {type(tf_layer)}")
            skipped.append(layer_name)
            continue
        transferred.add(layer_name)

    # Keras layers that own weights but received nothing would silently keep
    # their initial values, so surface them explicitly.
    untouched = [
        layer.name
        for layer in tf_model.layers
        if layer.weights and layer.name not in transferred
    ]

    if skipped or untouched:
        print("Weight transfer incomplete.")
        if skipped:
            print(f"  JAX entries not transferred: {skipped}")
        if untouched:
            print(f"  TensorFlow layers left at their initial weights: {untouched}")
    else:
        print("Weights transferred successfully.")

In [48]:
# Copy the checkpoint's decoder parameters into the Keras model, layer by layer (matched by name).
transfer_weights(decoder_raw, tf_policy_network)

Transferring LayerNorm layer LayerNorm_0, gamma shape (512,), beta shape (512,)
Transferring LayerNorm layer LayerNorm_1, gamma shape (512,), beta shape (512,)
Transferring LayerNorm layer LayerNorm_2, gamma shape (512,), beta shape (512,)
Transferring Dense layer hidden_0, kernel shape (207, 512), bias shape (512,)
Transferring Dense layer hidden_1, kernel shape (512, 512), bias shape (512,)
Transferring Dense layer hidden_2, kernel shape (512, 512), bias shape (512,)
Transferring Dense layer hidden_3, kernel shape (512, 76), bias shape (76,)
Weights transferred successfully.


In [51]:
output_path = "decoder.onnx"

# One row of 60 + 147 = 207 input features, shared by the example input and the export signature.
decoder_input_shape = (1, 60 + 147)

# Example input for the model. Pass the array itself: wrapping it in a list adds
# a leading axis, so indexing the result with [0] would give a (1, 76) row
# instead of a (76,) vector.
test_input = np.ones(decoder_input_shape, dtype=np.float32)

# Define the TensorFlow input signature; "obs" becomes the ONNX input name.
spec = [tf.TensorSpec(shape=decoder_input_shape, dtype=tf.float32, name="obs")]

# Reference prediction from the Keras model (first and only batch row).
tensorflow_pred = tf_policy_network(test_input)[0]
print(f"Tensorflow prediction: {tensorflow_pred}")

tf_policy_network.output_names = ["continuous_actions"]

# opset 11 matches isaac lab.
model_proto, _ = tf2onnx.convert.from_keras(
    tf_policy_network, input_signature=spec, opset=11, output_path=output_path
)

# Load the exported file into an ONNX Runtime session (CPU only).
output_names = ["continuous_actions"]
providers = ["CPUExecutionProvider"]
m = rt.InferenceSession(output_path, providers=providers)

Tensorflow prediction: [[ 1.8439487   1.8195764   0.96545035 -0.439171   -0.40265936  0.2564969
   0.1916398   0.7053365   0.10381285 -0.8021101   1.0065378   0.4596085
   0.63145524 -0.3105858   0.07498297 -0.46999425  0.20024359 -0.73598313
  -1.059566   -0.02754471  0.0551969  -0.17984359 -0.02428782  0.24375072
  -0.72107244 -0.9132743   0.18826191 -0.68193454 -0.06793734  0.08407827
  -0.16330367  0.8781642  -0.5258842  -1.2544575  -0.27115586 -0.5708299
  -0.6491313   0.3832001   2.5505161   2.4471612   3.5239034   1.0350533
   1.0109748   0.5961331   0.20818818  1.012076    0.9032251   0.99604803
   1.6650752   1.0874788   2.1373665   0.8951385   1.0131183   1.1382874
   2.09338     2.1891723   2.6804817   0.3315918   0.00923654  0.72923476
   0.29644024  1.1556864   0.81286275  1.0483586   0.6993547   0.7837227
   0.21743718  0.48821545  0.28367886  0.42269164  0.6529687   0.412511
   0.7705666   0.99292445  0.6567806   0.11907423]]


I0000 00:00:1741761195.481004   36020 devices.cc:67] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 1
I0000 00:00:1741761195.481153   36020 single_machine.cc:374] Starting new session
I0000 00:00:1741761195.487490   36020 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 42374 MB memory:  -> device: 0, name: NVIDIA A40, pci bus id: 0000:17:00.0, compute capability: 8.6
I0000 00:00:1741761195.882569   36020 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 42374 MB memory:  -> device: 0, name: NVIDIA A40, pci bus id: 0000:17:00.0, compute capability: 8.6
I0000 00:00:1741761196.066661   36020 devices.cc:67] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 1
I0000 00:00:1741761196.066847   36020 single_machine.cc:374] Starting new session
I0000 00:00:1741761196.073196   36020 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 42374 MB memory: 

In [53]:
# Prepare inputs for ONNX Runtime: "obs" is the input name from the export signature.
onnx_input = {"obs": np.ones((1, 147 + 60), dtype=np.float32)}
# First requested output, first (and only) batch row.
onnx_pred = m.run(output_names, onnx_input)[0][0]


print("ONNX prediction shape:", onnx_pred.shape)
print("ONNX prediction:", onnx_pred)

# The exported graph must reproduce the Keras prediction for the same all-ones
# input. Flatten the Keras result so the check does not depend on whether it
# still carries a leading batch axis.
np.testing.assert_allclose(
    onnx_pred, np.asarray(tensorflow_pred).reshape(-1), rtol=1e-4, atol=1e-4
)

ONNX prediction shape: (76,)
ONNX prediction: [ 1.843951    1.8195753   0.965451   -0.43917096 -0.40265906  0.2564969
  0.19163963  0.70533675  0.10381252 -0.8021095   1.0065386   0.45960838
  0.6314552  -0.31058574  0.07498334 -0.4699943   0.20024568 -0.7359829
 -1.0595651  -0.02754396  0.05519673 -0.1798436  -0.02428795  0.2437514
 -0.7210727  -0.91327393  0.18826176 -0.6819357  -0.06793791  0.08407864
 -0.16330372  0.8781644  -0.5258841  -1.254457   -0.27115592 -0.57083106
 -0.64913124  0.38320005  2.5505152   2.4471622   3.5239048   1.0350527
  1.0109746   0.5961333   0.20818841  1.0120753   0.9032248   0.9960484
  1.6650743   1.0874784   2.137367    0.8951385   1.013118    1.1382874
  2.0933805   2.1891718   2.6804807   0.33159184  0.00923641  0.7292344
  0.29644024  1.1556861   0.8128632   1.0483584   0.6993551   0.78372234
  0.21743718  0.4882155   0.28367874  0.42269155  0.652969    0.41251102
  0.7705667   0.9929241   0.65678036  0.11907405]


In [None]:
# Reference prediction from the original Flax decoder on the same all-ones
# input that is fed to the Keras and ONNX models.
# (`obs_size` is a plain int in this notebook and no `inference_fn` is defined,
# so the decoder is applied directly through the jitted `decoder.apply`.)
# NOTE(review): assumes Decoder.__call__ returns just the output tensor when
# get_activation is left at its default, like DecoderTF -- confirm.
test_input = jp.ones((1, 60 + 147), dtype=jp.float32)
jax_pred = jit_apply(decoder_param, test_input)[0]
print(jax_pred)

In [None]:
import matplotlib.pyplot as plt

print(onnx_pred.shape)
print(tensorflow_pred.shape)
print(jax_pred.shape)

# Flatten every prediction before plotting: a (1, N) array would otherwise be
# drawn as N single-point lines and the curve would be invisible.
plt.plot(np.ravel(onnx_pred), label="onnx")
plt.plot(np.ravel(tensorflow_pred), label="tensorflow")
plt.plot(np.ravel(jax_pred), label="jax")
plt.xlabel("decoder output index")
plt.ylabel("output value")
plt.title("Decoder output: ONNX vs TensorFlow vs JAX")
plt.legend()
plt.show()

In [20]:
# env_cfg = locomotion.get_default_config(env_name)
# env = locomotion.load(env_name, config=env_cfg)
# jit_reset = jax.jit(env.reset)
# jit_step = jax.jit(env.step)

In [21]:
# # Test the policy.

# # env_cfg = locomotion.get_default_config(env_name)
# # env_cfg.init_from_crouch = 0.0
# # env = locomotion.load(env_name, config=env_cfg)
# # env_cfg = manipulation.get_default_config(env_name)
# # env = manipulation.load(env_name, config=env_cfg)
# # jit_reset = jax.jit(env.reset)
# # jit_step = jax.jit(env.step)

# x = 0.8
# y = 0.0
# yaw = 0.3
# command = jp.array([x, y, yaw])
# # actions = []

# states = [state := jit_reset(jax.random.PRNGKey(555))]
# state.info["command"] = command
# for _ in range(env_cfg.episode_length):
#   onnx_input = {'obs': np.array(state.obs["state"].reshape(1, -1))}
#   action = m.run(output_names, onnx_input)[0][0]
#   state = jit_step(state, jp.array(action))
#   state.info["command"] = command
#   states.append(state)
#   # actions.append(state.info["motor_targets"])
#   # actions.append(action)
#   if state.done:
#     print("Unexpected termination.")
#     break

In [22]:
# import mediapy as media
# fps = 1.0 / env.dt

# frames = env.render(
#     states,
#     camera="track",
#     width=640*2,
#     height=480*2,
# )
# media.show_video(frames, fps=fps, loop=False)