# JAX Flax --> ONNX

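This notebook converts the decoder MLP of a track-mjx policy checkpoint (JAX/Flax) to an ONNX model by rebuilding it in TensorFlow/Keras, copying the weights over, and exporting with tf2onnx.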

In [None]:
import os

# These environment variables are set before `import jax` below so that they
# are already in place when the libraries that read them get imported.
# "egl" presumably selects headless (off-screen) OpenGL rendering.
os.environ["MUJOCO_GL"] = "egl"
os.environ["PYOPENGL_PLATFORM"] = "egl"
# Ask XLA not to preallocate accelerator memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax

# Enable persistent compilation cache.
# Compiled programs are stored under /tmp/jax_cache; the -1 / 0 thresholds
# presumably disable the minimum size / compile-time filters so every
# compilation is cached -- confirm against the JAX cache documentation.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [None]:
from brax.training.agents.ppo import networks as ppo_networks
from mujoco_playground.config import locomotion_params, manipulation_params
from mujoco_playground import locomotion, manipulation
import functools
import pickle
import jax.numpy as jp
import jax
import tf2onnx
import tensorflow as tf
from tensorflow.keras import layers
import onnxruntime as rt
from brax.training.acme import running_statistics

In [None]:
from track_mjx.environment.walker import rodent
from track_mjx.agent import checkpointing
from track_mjx.agent.intention_network import Decoder
from track_mjx.analysis.rollout import create_environment
from track_mjx.environment.wrappers import EvalClipResetWrapper
from track_mjx.analysis.render import render_from_saved_rollout, display_video
from track_mjx.analysis.utils import save_to_h5py, load_from_h5py

import jax
from jax import numpy as jnp
from brax.training import distribution
from brax.training.acme import running_statistics
from pathlib import Path

# Replace with the path to your own checkpoint directory.
ckpt_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/250306_194809"
# Load the checkpoint; the training config is stored inside it under "cfg".
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)

cfg = ckpt["cfg"]

# Override the config entries that depend on the local machine.
# Use an absolute path to your data: the notebook may not be able to resolve
# the relative path that was used during training.
cfg.data_path = "/root/vast/scott-yang/track-mjx/data/transform_snips.h5"
cfg.train_setup.checkpoint_to_restore = ckpt_path

# NOTE: To benefit from jax.jit, run the code below only once per session.
# Re-running it will trigger recompilation.

env = create_environment(cfg)

# env_name = "BerkeleyHumanoidJoystickFlatTerrain"
# # ppo_params = locomotion_params.brax_ppo_config(env_name)
# ppo_params = locomotion_params.brax_ppo_config(env_name)

In [None]:
# Read the observation and action dimensionality from the environment and
# print both for a quick sanity check.
obs_size, act_size = env.observation_size, env.action_size
print(obs_size, act_size)

In [None]:
# Display the reference-observation size stored in the checkpoint's network
# config (the bare expression is rendered as the cell output).
cfg["network_config"]["reference_obs_size"]

In [None]:
# Network hyper-parameters saved alongside the checkpoint.
network_config = cfg["network_config"]

# Build the (parameter-free) Flax decoder module. The final layer is twice the
# action size because it emits both the mean and the variance parameters of the
# action distribution.
decoder = Decoder(
    network_config["decoder_layer_sizes"] + [network_config["action_size"] * 2]
)

# Observation normalization function from brax.
normalizer = running_statistics.normalize
# Normalizer parameters are the first element of the saved policy.
# Later cells read `.mean` and `.std` from this object.
normalizer_param = ckpt["policy"][0]

# Decoder weights live under params/decoder in the second policy element.
decoder_raw = ckpt["policy"][1]["params"]["decoder"]
# Re-wrap them in the {"params": ...} layout expected by `decoder.apply`.
decoder_param = {"params": decoder_raw}
# Action distribution with one event dimension per action.
action_distribution = distribution.NormalTanhDistribution(
    event_size=network_config["action_size"]
)

# Wrap the functions in jax.jit once here so that later cells reuse the same
# jitted callables instead of creating (and recompiling) new ones.
jit_env_reset = jax.jit(env.reset, static_argnames=("clip_idx",))
jit_env_step = jax.jit(env.step)
jit_apply = jax.jit(decoder.apply)

In [None]:
# Inspect the loaded normalizer parameters (rendered as the cell output).
normalizer_param

In [None]:
from tensorflow.keras import layers


class DecoderTF(tf.keras.Model):
    """Keras re-implementation of the decoder MLP used for ONNX export.

    The network is a stack of Dense layers. Every layer except the last one
    (or every layer, when ``activate_final`` is True) is followed by the
    activation and then a LayerNormalization. Dense layers are named
    ``hidden_{i}`` and normalization layers ``LayerNorm_{i}`` so that weights
    can be looked up by name with ``get_layer``.

    When ``mean_std`` is given, the observation normalization is part of the
    model: the first ``intention_size`` input columns are passed through
    unchanged and the remaining columns are normalized as
    ``(x - mean) / std``.
    """

    def __init__(
        self,
        layer_sizes,
        activation=tf.nn.relu,
        kernel_init="lecun_uniform",
        activate_final=False,
        bias=True,
        mean_std=None,
        intention_size=60,
        layer_norm_epsilon=1e-3,
    ):
        """
        Initializes the Decoder, including the observation normalizer

        Args:
            layer_sizes (Sequence[int]): List of layer sizes for each Dense layer.
            activation (callable): Activation function to apply.
            kernel_init (str or callable): Kernel initializer for Dense layers.
            activate_final (bool): Whether to apply activation (and layer norm) on the final layer.
            bias (bool): Whether the Dense layers should use a bias term.
            mean_std (tuple): Mean and standard deviation for obs normalization.
                When None, inputs are fed to the first Dense layer unchanged.
            intention_size (int): Number of leading input columns that are left
                un-normalized. Defaults to 60, the value that was previously
                hard-coded in ``call``.
            layer_norm_epsilon (float): Epsilon of every LayerNormalization
                layer. Defaults to 1e-3, the Keras default.
                NOTE(review): Flax's ``nn.LayerNorm`` defaults to 1e-6; confirm
                which epsilon the JAX decoder uses and pass it here if exact
                numerical parity is required.

        Raises:
            ValueError: If ``intention_size`` is negative.
        """
        super().__init__()
        if intention_size < 0:
            raise ValueError(
                f"intention_size must be non-negative, got {intention_size}"
            )
        self.layer_sizes = layer_sizes
        self.activation = activation
        self.kernel_init = kernel_init
        self.activate_final = activate_final
        self.bias = bias
        self.intention_size = int(intention_size)
        self.layer_norm_epsilon = layer_norm_epsilon
        if mean_std is not None:
            # Stored as non-trainable variables so the statistics become part
            # of the model.
            self.mean = tf.Variable(mean_std[0], trainable=False, dtype=tf.float32)
            self.std = tf.Variable(mean_std[1], trainable=False, dtype=tf.float32)
        else:
            self.mean = None
            self.std = None

        # Build lists to store Dense layers and corresponding LayerNorm layers.
        # `layer_norms` has one entry per Dense layer; the entry is None when
        # that layer is not followed by a normalization.
        self.dense_layers = []
        self.layer_norms = []

        for i, size in enumerate(self.layer_sizes):
            dense_layer = layers.Dense(
                size,
                kernel_initializer=self.kernel_init,
                use_bias=self.bias,
                name=f"hidden_{i}",
            )
            self.dense_layers.append(dense_layer)
            # Apply activation and layer norm if it's not the final layer or if activate_final is True.
            if i != len(self.layer_sizes) - 1 or self.activate_final:
                self.layer_norms.append(
                    layers.LayerNormalization(
                        epsilon=self.layer_norm_epsilon, name=f"LayerNorm_{i}"
                    )
                )
            else:
                self.layer_norms.append(None)

    def call(self, inputs, get_activation=False):
        """
        Forward pass through the network.

        Args:
            inputs (tf.Tensor): Input tensor, indexed as (batch, features). A
                list or tuple is also accepted, in which case only its first
                element is used.
            get_activation (bool): If True, also return a dict with activations per layer.

        Returns:
            tf.Tensor or Tuple[tf.Tensor, dict]: The output tensor, and optionally a dictionary
            mapping ``layer_{i}`` to the post-activation, post-LayerNorm output
            of layer ``i``. Layers without an activation (the final layer
            unless ``activate_final`` is True) have no entry.
        """
        if isinstance(inputs, (list, tuple)):
            inputs = inputs[0]
        activations = {}
        if self.mean is not None and self.std is not None:
            split = self.intention_size
            # Normalize only the columns after the intention part.
            normalized_part = (inputs[:, split:] - self.mean) / self.std
            # Concatenate the unchanged part with the normalized part along axis 1.
            inputs = tf.concat([inputs[:, :split], normalized_part], axis=1)
        x = inputs
        for i, dense_layer in enumerate(self.dense_layers):
            x = dense_layer(x)
            # Apply activation (and layer norm) for all layers except the final one
            # unless activate_final is True.
            if i != len(self.layer_sizes) - 1 or self.activate_final:
                x = self.activation(x)
                if self.layer_norms[i] is not None:
                    x = self.layer_norms[i](x)
                if get_activation:
                    activations[f"layer_{i}"] = x
        if get_activation:
            return x, activations
        return x


def make_policy_network(
    param_size,
    hidden_layer_sizes=(512, 512, 512),
    mean_std=None,
):
    """Build the TensorFlow policy network.

    Args:
        param_size (int): Width of the output layer, appended after the hidden
            layers.
        hidden_layer_sizes (Sequence[int]): Widths of the hidden Dense layers.
            The default is an immutable tuple so it cannot be shared and
            mutated across calls.
        mean_std (tuple): Optional (mean, std) pair forwarded to ``DecoderTF``
            for observation normalization.

    Returns:
        DecoderTF: A model with layer sizes ``hidden_layer_sizes + [param_size]``.
    """
    return DecoderTF(
        layer_sizes=[*hidden_layer_sizes, param_size],
        mean_std=mean_std,
    )

In [None]:
# Drop the leading `reference_obs_size` entries of the normalizer statistics;
# only the remaining entries are used for normalization inside the TF model.
ref_obs_size = cfg["network_config"]["reference_obs_size"]
mean_std = (normalizer_param.mean[ref_obs_size:], normalizer_param.std[ref_obs_size:])
# The output layer has two values per action dimension.
tf_policy_network = make_policy_network(
    param_size=act_size * 2,
    mean_std=mean_std,
)

In [None]:
# Run one forward pass on a dummy batch; this also creates the layer weights.
# 60 is the intention size hard-coded in DecoderTF.call; 147 is presumably the
# size of the normalized observation part -- TODO confirm against the config.
example_input = tf.zeros((1, 60 + 147))
example_output = tf_policy_network(example_input)
print(example_output.shape)

In [None]:
# List the Keras layers of the model (rendered as the cell output).
tf_policy_network.layers

In [None]:
import numpy as np
import tensorflow as tf


def transfer_weights(jax_params, tf_model):
    """
    Transfer weights from a JAX parameter dictionary to the TensorFlow model.

    Each top-level key of ``jax_params`` is looked up as a layer name in
    ``tf_model``. Names that do not exist in the model, and layers that are
    neither Dense nor LayerNormalization, are reported and skipped.

    Parameters:
    - jax_params: dict
      Flat dictionary with structure {layer_name: {param_name: array}}.
      For example:
      {
        'hidden_0': {'kernel': np.ndarray, 'bias': np.ndarray},
        'LayerNorm_0': {'scale': np.ndarray, 'bias': np.ndarray},
        'hidden_1': {'kernel': np.ndarray, 'bias': np.ndarray},
      }
      Dense entries need 'kernel' and may omit 'bias' (for layers built with
      use_bias=False); LayerNormalization entries need 'scale' and 'bias'.

    - tf_model: tf.keras.Model
      A model whose layers can be retrieved with ``get_layer(name=...)``.
      The target layers must already have their weights created (call the
      model once on example data first) so that ``set_weights`` can succeed.

    Returns:
    - None. Progress and a final summary are printed.
    """
    transferred = 0
    skipped = []
    for layer_name, layer_params in jax_params.items():
        try:
            tf_layer = tf_model.get_layer(name=layer_name)
        except ValueError:
            print(f"Layer {layer_name} not found in TensorFlow model.")
            skipped.append(layer_name)
            continue
        if isinstance(tf_layer, tf.keras.layers.Dense):
            kernel = np.array(layer_params["kernel"])
            weights = [kernel]
            if "bias" in layer_params:
                bias = np.array(layer_params["bias"])
                weights.append(bias)
                print(
                    f"Transferring Dense layer {layer_name}, kernel shape {kernel.shape}, bias shape {bias.shape}"
                )
            else:
                # Bias-free Dense layers carry only a kernel.
                print(
                    f"Transferring Dense layer {layer_name}, kernel shape {kernel.shape}, no bias"
                )
            tf_layer.set_weights(weights)
            transferred += 1
        elif isinstance(tf_layer, tf.keras.layers.LayerNormalization):
            # The JAX 'scale' / 'bias' entries are passed to Keras as gamma / beta.
            gamma = np.array(layer_params["scale"])
            beta = np.array(layer_params["bias"])
            print(
                f"Transferring LayerNorm layer {layer_name}, gamma shape {gamma.shape}, beta shape {beta.shape}"
            )
            tf_layer.set_weights([gamma, beta])
            transferred += 1
        else:
            print(f"Unhandled layer type in {layer_name}: {type(tf_layer)}")
            skipped.append(layer_name)

    # Only claim success when nothing was skipped; otherwise name what is missing.
    if skipped:
        print(
            f"Weights transferred for {transferred} layer(s); skipped {len(skipped)}: {skipped}"
        )
    else:
        print("Weights transferred successfully.")

In [None]:
# Copy the Flax decoder weights into the Keras model, matching layers by name.
transfer_weights(decoder_raw, tf_policy_network)

In [None]:
# File the ONNX model is written to (relative to the working directory).
output_path = "decoder.onnx"

# Example inputs for the model
# (a one-element list; DecoderTF.call uses only the first element of a list).
test_input = [np.ones((1, 60 + 147), dtype=np.float32)]

# Define the TensorFlow input signature
# The batch dimension is fixed to 1 and the input tensor is named "obs".
spec = [tf.TensorSpec(shape=(1, 60 + 147), dtype=tf.float32, name="obs")]

# Call the model on the example data before exporting and keep the first
# (only) row of the output as the TensorFlow reference prediction.
tensorflow_pred = tf_policy_network(test_input)[0]
print(f"Tensorflow prediction: {tensorflow_pred}")

# Name of the model output; the same name is requested from ONNX Runtime below.
# NOTE(review): presumably tf2onnx reads `output_names` from the Keras model -- confirm.
tf_policy_network.output_names = ["continuous_actions"]

# opset 11 matches isaac lab.
model_proto, _ = tf2onnx.convert.from_keras(
    tf_policy_network, input_signature=spec, opset=11, output_path=output_path
)

# Run inference with ONNX Runtime
# Load the file that was just written into a CPU-only inference session.
output_names = ["continuous_actions"]
providers = ["CPUExecutionProvider"]
m = rt.InferenceSession(output_path, providers=providers)

In [None]:
# Feed ONNX Runtime an all-ones observation under the "obs" input name and
# keep the first row of the first requested output.
onnx_input = {"obs": np.ones((1, 147 + 60), dtype=np.float32)}
onnx_pred = m.run(output_names, onnx_input)[0][0]

print(f"ONNX prediction shape: {onnx_pred.shape}")
print(f"ONNX prediction: {onnx_pred}")

In [None]:
# JAX reference prediction for the same all-ones observation that is fed to the
# TensorFlow and ONNX models, computed with the jitted Flax decoder
# (`jit_apply`) and the decoder weights loaded from the checkpoint.
jax_obs = jp.ones((1, 60 + 147), dtype=jp.float32)
# Mirror DecoderTF.call: keep the first 60 (intention) columns as they are and
# normalize the remaining columns with the same mean / std statistics.
jax_obs_normalized = jp.concatenate(
    [jax_obs[:, :60], (jax_obs[:, 60:] - mean_std[0]) / mean_std[1]],
    axis=1,
)
# NOTE(review): assumes Decoder.__call__ takes only the observation batch and
# returns the raw output-layer values -- confirm against intention_network.Decoder.
# [0] drops the batch dimension so the shape matches onnx_pred / tensorflow_pred.
jax_pred = jit_apply(decoder_param, jax_obs_normalized)[0]
print(jax_pred)

In [None]:
import matplotlib.pyplot as plt

# Compare the three predictions: print every shape first, then overlay the
# curves in a single figure.
predictions = (
    ("onnx", onnx_pred),
    ("tensorflow", tensorflow_pred),
    ("jax", jax_pred),
)
for _, pred in predictions:
    print(pred.shape)
for label, pred in predictions:
    plt.plot(pred, label=label)
plt.legend()
plt.show()

In [None]:
# env_cfg = locomotion.get_default_config(env_name)
# env = locomotion.load(env_name, config=env_cfg)
# jit_reset = jax.jit(env.reset)
# jit_step = jax.jit(env.step)

In [None]:
# # Test the policy.

# # env_cfg = locomotion.get_default_config(env_name)
# # env_cfg.init_from_crouch = 0.0
# # env = locomotion.load(env_name, config=env_cfg)
# # env_cfg = manipulation.get_default_config(env_name)
# # env = manipulation.load(env_name, config=env_cfg)
# # jit_reset = jax.jit(env.reset)
# # jit_step = jax.jit(env.step)

# x = 0.8
# y = 0.0
# yaw = 0.3
# command = jp.array([x, y, yaw])
# # actions = []

# states = [state := jit_reset(jax.random.PRNGKey(555))]
# state.info["command"] = command
# for _ in range(env_cfg.episode_length):
#   onnx_input = {'obs': np.array(state.obs["state"].reshape(1, -1))}
#   action = m.run(output_names, onnx_input)[0][0]
#   state = jit_step(state, jp.array(action))
#   state.info["command"] = command
#   states.append(state)
#   # actions.append(state.info["motor_targets"])
#   # actions.append(action)
#   if state.done:
#     print("Unexpected termination.")
#     break

In [None]:
# import mediapy as media
# fps = 1.0 / env.dt

# frames = env.render(
#     states,
#     camera="track",
#     width=640*2,
#     height=480*2,
# )
# media.show_video(frames, fps=fps, loop=False)