# Jax Flat to ONNX
For envs trained in colab, which uses an earlier version of brax

## Importing Libraries

In [1]:
import os
os.environ["MUJOCO_GL"] = "egl"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import functools
from etils import epath
import numpy as np
import subprocess

import jax
import jax.numpy as jp
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.acme import running_statistics
from mujoco_playground.config import locomotion_params
from mujoco_playground import locomotion

import tf2onnx
import tensorflow as tf
from tensorflow.keras import layers

import onnxruntime as rt
from orbax import checkpoint as ocp


## Env Config

In [2]:
from hector_pg import constants as consts
env_name = "HectorWBCFlatTerrain"

ROOT_PATH = consts.ROOT_PATH
cpt_path = consts.ROOT_PATH.parent / 'logs' / 'wbc' / 's2_0820_3'

output_path = f"hector_wbc_s_test.onnx"

In [3]:
ppo_params = locomotion_params.brax_ppo_config(env_name)

def identity_observation_preprocessor(observation, preprocessor_params):
  del preprocessor_params
  return observation

network_factory=functools.partial(
  ppo_networks.make_ppo_networks,
  **ppo_params.network_factory,
  # We need to explicitly call the normalization function here since only the brax
  # PPO train.py script creates it if normalize_observations is True.
  preprocess_observations_fn=running_statistics.normalize,
)

## Load Env, Net dfn, Checkpoints

In [4]:
env_cfg = locomotion.get_default_config(env_name)
env = locomotion.load(env_name, config=env_cfg)

obs_size = env.observation_size
act_size = env.action_size
print(f"Observation Size: {obs_size}")
print(f"Action Size: {act_size}")

ppo_network = network_factory(obs_size, act_size)

l_hip_yaw: -0.79 to 0.79
l_hip_roll: -0.79 to 0.79
l_hip_pitch: -3.14 to 3.14
l_knee: -3.14 to 3.14
l_ankle: -1.40 to 1.40
r_hip_yaw: -0.79 to 0.79
r_hip_roll: -0.79 to 0.79
r_hip_pitch: -3.14 to 3.14
r_knee: -3.14 to 3.14
r_ankle: -1.40 to 1.40
l_shoulder_yaw: -1.40 to 1.40
l_shoulder_pitch: -3.14 to 1.57
l_shoulder_roll: -1.40 to 1.40
l_elbow: -3.00 to 3.00
r_shoulder_yaw: -1.40 to 1.40
r_shoulder_pitch: -3.14 to 1.57
r_shoulder_roll: -1.40 to 1.40
r_elbow: -3.00 to 3.00
Observation Size: {'privileged_state': (408,), 'state': (328,)}
Action Size: 18


In [5]:
checkpoint_root = epath.Path(cpt_path).expanduser().resolve()

checkpoint_dirs = list(checkpoint_root.glob('*'))
checkpoint_dirs = [ckpt for ckpt in checkpoint_dirs if ckpt.is_dir()]
checkpoint_dirs.sort(key=lambda x: int(x.name))

latest_checkpoint = checkpoint_dirs[-1]
print(f'Loading checkpoint: {latest_checkpoint}')
checkpointer = ocp.PyTreeCheckpointer()
params = checkpointer.restore(latest_checkpoint)

normalizer_params = params[1]
policy_params = params[2]

make_inference_fn = ppo_networks.make_inference_fn(ppo_network)
inference_fn = make_inference_fn(params, deterministic=True)

Loading checkpoint: /home/ps/Documents/hector_dev/src/logs/wbc/s2_0820_3/52920320




## Build TF Network

In [6]:
class MLP(tf.keras.Model):
    def __init__(
        self,
        layer_sizes,
        activation=tf.nn.relu,
        kernel_init="lecun_uniform",
        activate_final=False,
        bias=True,
        layer_norm=False,
        mean_std=None,
    ):
        super().__init__()

        self.layer_sizes = layer_sizes
        self.activation = activation
        self.kernel_init = kernel_init
        self.activate_final = activate_final
        self.bias = bias
        self.layer_norm = layer_norm

        if mean_std is not None:
            self.mean = tf.Variable(mean_std[0], trainable=False, dtype=tf.float32)
            self.std = tf.Variable(mean_std[1], trainable=False, dtype=tf.float32)
        else:
            self.mean = None
            self.std = None

        self.mlp_block = tf.keras.Sequential(name="MLP_0")
        for i, size in enumerate(self.layer_sizes):
            dense_layer = layers.Dense(
                size,
                activation=self.activation,
                kernel_initializer=self.kernel_init,
                name=f"hidden_{i}",
                use_bias=self.bias,
            )
            self.mlp_block.add(dense_layer)
            if self.layer_norm:
                self.mlp_block.add(layers.LayerNormalization(name=f"layer_norm_{i}"))
        if not self.activate_final and self.mlp_block.layers:
            if hasattr(self.mlp_block.layers[-1], 'activation') and self.mlp_block.layers[-1].activation is not None:
                self.mlp_block.layers[-1].activation = None

        self.submodules = [self.mlp_block]

    def call(self, inputs):
        if isinstance(inputs, list):
            inputs = inputs[0]
        if self.mean is not None and self.std is not None:
            print(self.mean.shape, self.std.shape)
            inputs = (inputs - self.mean) / self.std
        logits = self.mlp_block(inputs)
        loc, _ = tf.split(logits, 2, axis=-1)
        return tf.tanh(loc)

def make_policy_network(
    param_size,
    mean_std,
    hidden_layer_sizes=[256, 256],
    activation=tf.nn.relu,
    kernel_init="lecun_uniform",
    layer_norm=False,
):
    policy_network = MLP(
        layer_sizes=list(hidden_layer_sizes) + [param_size],
        activation=activation,
        kernel_init=kernel_init,
        layer_norm=layer_norm,
        mean_std=mean_std,
    )
    return policy_network

mean = params[0]['mean']['state']
std = params[0]['std']['state']

# Convert mean/std jax arrays to tf tensors.
mean_std = (tf.convert_to_tensor(mean), tf.convert_to_tensor(std))

tf_policy_network = make_policy_network(
    param_size=act_size * 2,
    mean_std=mean_std,
    hidden_layer_sizes=ppo_params.network_factory.policy_hidden_layer_sizes,
    activation=tf.nn.swish,
)

example_input = tf.zeros((1, obs_size["state"][0]))
example_output = tf_policy_network(example_input)
print(f"Example output shape: {example_output.shape}")

(328,) (328,)
(328,) (328,)
Example output shape: (1, 18)


## Transfering Weights

In [7]:
def transfer_weights(jax_params, tf_model):
    """
    Transfer weights from a JAX parameter dictionary to the TensorFlow model.
    """
    for layer_name, layer_params in jax_params.items():
        try:
            tf_layer = tf_model.get_layer("MLP_0").get_layer(name=layer_name)
        except ValueError:
            print(f"Layer {layer_name} not found in TensorFlow model.")
            continue
        if isinstance(tf_layer, tf.keras.layers.Dense):
            kernel = np.array(layer_params['kernel'])
            bias = np.array(layer_params['bias'])
            print(f"Transferring Dense layer {layer_name}, kernel shape {kernel.shape}, bias shape {bias.shape}")
            tf_layer.set_weights([kernel, bias])
        else:
            print(f"Unhandled layer type in {layer_name}: {type(tf_layer)}")

    print("Weights transferred successfully.")

transfer_weights(params[1]['params'], tf_policy_network)

Transferring Dense layer hidden_0, kernel shape (328, 1024), bias shape (1024,)
Transferring Dense layer hidden_1, kernel shape (1024, 512), bias shape (512,)
Transferring Dense layer hidden_2, kernel shape (512, 256), bias shape (256,)
Transferring Dense layer hidden_3, kernel shape (256, 36), bias shape (36,)
Weights transferred successfully.


## ONNX Conversion and Inference

In [None]:
# Example inputs for the model
test_input = [np.ones((1, obs_size["state"][0]), dtype=np.float32)]
# Define the TensorFlow input signature
spec = [tf.TensorSpec(shape=(1, obs_size["state"][0]), dtype=tf.float32, name="obs")]
tensorflow_pred = tf_policy_network(test_input)[0]
# Build the model by calling it with example data
print(f"Tensorflow prediction: {tensorflow_pred}")
tf_policy_network.output_names = ['continuous_actions']

# Monkey patch version attribute if missing, quick fix
if not hasattr(tf2onnx, 'version'):
    class DummyVersion:
        __version__ = "1.16.1"
        git_version = "dummy_git"

    tf2onnx.version = DummyVersion()

# opset 11 matches isaac lab.
model_proto, _ = tf2onnx.convert.from_keras(tf_policy_network, input_signature=spec, opset=11, output_path=output_path)

# Run inference with ONNX Runtime
output_names = ['continuous_actions']
providers = ['CPUExecutionProvider']
m = rt.InferenceSession(output_path, providers=providers)

# Dummy input
onnx_input = {
  'obs': np.ones((1, obs_size["state"][0]), dtype=np.float32)
}
# Prepare inputs for ONNX Runtime
onnx_pred = m.run(output_names, onnx_input)[0][0]

print("ONNX prediction:", onnx_pred)

(328,) (328,)
Tensorflow prediction: [ 1. -1.  1. -1.  1.  1. -1.  1.  1.  1.  1.  1.  1.  1. -1. -1. -1. -1.]
(328,) (328,)


I0000 00:00:1756103546.822929   21557 devices.cc:76] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0 (Note: TensorFlow was not compiled with CUDA or ROCm support)
I0000 00:00:1756103546.823112   21557 single_machine.cc:376] Starting new session
I0000 00:00:1756103546.919445   21557 devices.cc:76] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0 (Note: TensorFlow was not compiled with CUDA or ROCm support)
I0000 00:00:1756103546.919648   21557 single_machine.cc:376] Starting new session


ONNX prediction: [ 1. -1.  1. -1.  1.  1. -1.  1.  1.  1.  1.  1.  1.  1. -1. -1. -1. -1.]


## Rollout and Visualization

In [9]:
env_cfg = locomotion.get_default_config(env_name)

env = locomotion.load(env_name, config=env_cfg)
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# Reset action scale to match policy
env._config.action_scale = 0.5

l_hip_yaw: -0.79 to 0.79
l_hip_roll: -0.79 to 0.79
l_hip_pitch: -3.14 to 3.14
l_knee: -3.14 to 3.14
l_ankle: -1.40 to 1.40
r_hip_yaw: -0.79 to 0.79
r_hip_roll: -0.79 to 0.79
r_hip_pitch: -3.14 to 3.14
r_knee: -3.14 to 3.14
r_ankle: -1.40 to 1.40
l_shoulder_yaw: -1.40 to 1.40
l_shoulder_pitch: -3.14 to 1.57
l_shoulder_roll: -1.40 to 1.40
l_elbow: -3.00 to 3.00
r_shoulder_yaw: -1.40 to 1.40
r_shoulder_pitch: -3.14 to 1.57
r_shoulder_roll: -1.40 to 1.40
r_elbow: -3.00 to 3.00


In [10]:
# # Rollout with onnx in mujoco playground
x_vel = 1.0  #@param {type: "number"}
y_vel = 0.0  #@param {type: "number"}
yaw_vel = 0.0  #@param {type: "number"}
body_height = 0.55
body_euler = jp.array([0,0,0])
qarm_left = jp.array([0.0, 0.0, 0.0, 0.0])
qarm_right = jp.array([0.0, 0.0, 0.0, 0.0])
command = jp.array([x_vel, y_vel, yaw_vel, body_height, *body_euler, *qarm_left, *qarm_right])

gait_freq = 1.0 # Hz
phase_dt = 2 * jp.pi * env.dt * gait_freq
phase = jp.array([0, jp.pi])

actions = []

states = [state := jit_reset(jax.random.PRNGKey(1927))] 
state.info["phase_dt"] = phase_dt
state.info["phase"] = phase
  
for _ in range(env_cfg.episode_length):
  onnx_input = {'obs': np.array(state.obs["state"].reshape(1, -1))}
  action = m.run(output_names, onnx_input)[0][0]
  state = jit_step(state, jp.array(action))
  print(action)
  state.info["command"] = command

  states.append(state)
  # actions.append(state.info["motor_targets"])
  # actions.append(action)
  if state.done:
    print("Unexpected termination.")
    break

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


[ 0.12111455 -0.05962836 -0.09541052 -0.08864472 -0.15967077 -0.03396695
  0.01995199  0.27641046  0.65379095 -0.28575718 -0.232775   -0.07531313
  0.01445885  0.24145658  0.2555966   0.0347183   0.12793052 -0.09207653]
[ 0.14394827 -0.20260744 -0.25009844 -0.3896737  -0.13970166 -0.11122089
  0.17649746  0.4136985   0.8632422  -0.40284634 -0.24989828 -0.06888857
  0.08171035  0.32565725  0.30329126  0.11391615  0.13051142  0.00784444]
[ 0.17024879 -0.10653716 -0.2669206  -0.30221817 -0.09002856 -0.07754818
 -0.05596821  0.27062523  0.88075924 -0.4549299  -0.16568723 -0.0575278
  0.04646344  0.14982438  0.18107574  0.00634623  0.13019428 -0.07230149]
[ 0.26098943 -0.09562615  0.02941147 -0.15688403 -0.04138898 -0.0712753
 -0.08174719  0.15790108  0.84287256 -0.4659128  -0.18970276 -0.05520608
 -0.02756009  0.09752022  0.03328868 -0.02783323  0.1080807  -0.13192827]
[ 0.3453794   0.02885916  0.11229324  0.7835272  -0.10545123  0.03090804
 -0.05191378  0.38571385  0.50673413 -0.4036032  

In [11]:
# Local ffmpeg
import os
import subprocess

# Add colon separator between paths!
bin_dir = (ROOT_PATH.parent / "visualize" / "ffmpeg" / "ffmpeg-7.0.2-amd64-static")
os.environ["PATH"] = os.pathsep.join([str(bin_dir), os.environ.get("PATH", "")])

print(subprocess.check_output(["ffmpeg", "-version"]).decode())

ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-l

In [12]:
import mediapy as media

fps = 1.0 / env.dt

frames = env.render(
    states,
    camera="track",
    width=640,
    height=480,
)

media.show_video(frames, fps=fps, loop=False)

100%|██████████| 1001/1001 [00:05<00:00, 170.74it/s]


0
This browser does not support the video tag.
