# `track-mjx` - Checkpoint Playground

This notebook is a playground for the saved checkpoints of `track-mjx`'s kinematic replay tasks. It allows you to load the intention network and perform different operations on it, such as inspecting the decoder with random intentions, generating a trajectory from a given intention, evaluating model performance with various metrics, visualizing the trajectories and intention distributions, and implementing additional analysis tools to enhance the exploration of model behavior and capabilities.

## !! JAX's Sharp Tips !!

If you modify the inputs of a `jax.jit`-compiled function, make sure that the input arguments have consistent types across the notebook. For consistency, this notebook uses `jnp.array` as the default type for all inputs. If your data starts out as `np.array`, convert it to `jnp.array` before passing it in, or otherwise keep it consistent with the types used when the function was first compiled.

If you trigger excessive re-compilation, please review the input types and ensure they are properly aligned with previous compiled functions.

In [None]:
# imports

%load_ext autoreload
%autoreload 2

import os
import logging

# Send logging outputs to stdout (comment this out if preferred)
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)

# Rendering backend for MuJoCo / PyOpenGL. Both are set to "egl" here; change
# them to "glfw" if that backend is available on your machine instead.
os.environ.update({"MUJOCO_GL": "egl", "PYOPENGL_PLATFORM": "egl"})

# Limit the fraction of GPU memory claimed by JAX, which matters when several
# notebooks use JAX on the same machine. A value already present in the
# environment is kept; otherwise the default of 0.4 is used.
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.4")

from track_mjx.agent import checkpointing
from track_mjx.agent.intention_network import Decoder
from track_mjx.analysis.rollout import create_environment
from track_mjx.environment.wrappers import EvalClipResetWrapper
from track_mjx.analysis.render import render_from_saved_rollout, display_video
from track_mjx.analysis.utils import save_to_h5py, load_from_h5py

import jax
from jax import numpy as jnp
from brax.training import distribution
from brax.training.acme import running_statistics
from pathlib import Path

# rendering related
import mujoco
from dm_control import mjcf as mjcf_dm
from tqdm import tqdm
import numpy as np
from dm_control.locomotion.walkers import rescale

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Part 0: Network Initialization and Checkpoint Loading

First, we need to load the checkpoint. If you want to load your own custom checkpoint, you can replace the `ckpt_path` variable with your checkpoint path.

In [None]:
# Replace with the path to your own checkpoint directory.
ckpt_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/250227_200156"

# Load the checkpoint for evaluation; the training config is stored under "cfg".
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)

cfg = ckpt["cfg"]

# Override the config entries that depend on the local filesystem.
# Use an absolute path to your data -- your notebook may not have access to
# the same relative path that was used during training.
cfg.data_path = "/root/vast/scott-yang/track-mjx/data/transform_snips.h5"
cfg.train_setup.checkpoint_to_restore = ckpt_path

# NOTE: The environment, the MuJoCo rendering model, `render_rollout` and the
# jitted functions are deliberately NOT created here. They are built exactly
# once in the setup cell under "JAX Tips" below, after `decoder` has been
# constructed. Building them in this cell as well would fail on a fresh kernel
# (`decoder` does not exist yet) and would create the environment twice.

Loading checkpoint from /root/vast/scott-yang/track-mjx/model_checkpoints/250227_200156 at step 88


The following code initializes the network and loads the checkpoint.

In [4]:
# Network hyper-parameters stored in the checkpoint's config.
network_config = cfg["network_config"]

# Rebuild the decoder module. Its last layer is twice the action size because
# it represents both the mean and the variance of the action distribution.
decoder = Decoder(
    network_config["decoder_layer_sizes"] + [2 * network_config["action_size"]]
)

# Entry 0 of the checkpointed policy holds the normalizer parameters; the
# decoder weights live under entry 1 at params/decoder.
normalizer_param = ckpt["policy"][0]
decoder_raw = ckpt["policy"][1]["params"]["decoder"]
# Wrap the raw weights in the {"params": ...} layout passed to `decoder.apply`.
decoder_param = dict(params=decoder_raw)

# Function used to normalize observations with `normalizer_param`.
normalizer = running_statistics.normalize

# Action distribution applied on top of the decoder output.
action_distribution = distribution.NormalTanhDistribution(
    event_size=network_config["action_size"]
)

Test run: feed decoder with zero intentions

In [22]:
# Compile the decoder's forward pass so it can be reused across calls.
jit_apply = jax.jit(decoder.apply)

# Smoke test: feed an all-zero input vector and take the mode of the resulting
# action distribution. The input has 60 + 147 entries: 60 matches the intention
# size used in the rollouts below; 147 is presumably the size of the
# non-reference part of the observation -- TODO confirm against the config.
actions = action_distribution.mode(jit_apply(decoder_param, jnp.zeros(60 + 147)))
print("Number of actions:", len(actions))
actions

Number of actions: 38


Array([ 0.02897133, -0.03831713, -0.02078261,  0.58814603,  0.17070429,
       -0.03250994, -0.07583926,  0.0453225 ,  0.4106154 , -0.1252416 ,
        0.17990206, -0.18754841,  0.02986068, -0.0783918 ,  0.15125185,
       -0.24585633,  0.4074648 , -0.25561202,  0.22987182,  0.05525849,
        0.53349316,  0.0109764 , -0.03424448, -0.24943057, -0.01396631,
        0.40130252,  0.1403449 ,  0.27226597, -0.4767266 , -0.8498172 ,
        0.08346232, -0.11328582, -0.08953246,  0.34286287,  0.09904323,
        0.13242985, -0.5625049 , -0.80803865], dtype=float32)

### JAX Tips:

When performing training, it’s recommended to fully leverage JAX’s functionality, such as using `jax.lax.scan` to replace Python-native for loops. This brings the entire loop into the computation graph, allowing JAX to optimize it more effectively.

However, during inference and analysis, it’s often better to avoid compiling the entire function — including the `for` loop — all at once. Instead, you can compile individual sub-functions and call those smaller, compiled functions within a regular Python for loop. This approach is preferable because inference and analysis workflows often involve frequent modifications to the rollout functions. By keeping the loop outside the compiled graph, you can avoid unnecessary recompilation of the entire loop when only a small part of the algorithm changes.

!! To avoid costly recompilation, make sure the input parameters to a jitted function have the same type on every call.

In [None]:
# NOTE: Run this cell only once per kernel session. Running it again rebuilds
# the environment and triggers recompilation of the jitted functions.

env = EvalClipResetWrapper(create_environment(cfg))

# Rendering model; the XML path is hardcoded for the rodent walker.
rodent_XML_path = "/root/vast/scott-yang/track-mjx/track_mjx/environment/walker/assets/rodent/rodent.xml"
root = mjcf_dm.from_path(rodent_XML_path)

# Rescale the walker by the configured factor relative to 0.8; the same ratio
# is passed for both scale arguments.
rescale_ratio = cfg.walker_config.rescale_factor / 0.8
rescale.rescale_subtree(root, rescale_ratio, rescale_ratio)

mj_model = mjcf_dm.Physics.from_mjcf_model(root).model.ptr
# Conjugate-gradient solver; use mujoco.mjtSolver.mjSOL_NEWTON for Newton.
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
mj_model.opt.iterations = 6
mj_model.opt.ls_iterations = 6
mj_data = mujoco.MjData(mj_model)

# Calculate the real-time rendering fps: physics steps per second divided by
# the number of physics steps taken per control step.
physics_hz = 1.0 / mj_model.opt.timestep
render_fps = physics_hz / cfg.env_config.env_args.physics_steps_per_control_step

# Root PRNG key; later cells split their sub-keys from it.
rng = jax.random.PRNGKey(0)


def render_rollout(qposes_rollout: jnp.ndarray):
    """
    Render one frame for every pose of a rollout.

    Relies on the notebook-level `mj_model`, `mj_data` and `cfg`. `mj_data.qpos`
    is overwritten with each pose in turn, so `mj_data` holds the last pose
    when the function returns.

    Args:
        qposes_rollout (jnp.ndarray): The qposes of the rollout to render,
            one per frame. Every entry is rendered.

    Returns:
        List of rendered frames, in rollout order.
    """
    mujoco.mj_kinematics(mj_model, mj_data)
    renderer = mujoco.Renderer(mj_model, height=480, width=640)
    frames = []
    print("MuJoCo Rendering...")
    try:
        for qpos1 in tqdm(qposes_rollout):
            # Load the pose and recompute derived quantities before drawing.
            mj_data.qpos = qpos1
            mujoco.mj_forward(mj_model, mj_data)
            renderer.update_scene(mj_data, camera=cfg.env_config.render_camera_name)
            frames.append(renderer.render())
    finally:
        # A new renderer is created on every call; release its rendering
        # resources so repeated calls from notebook cells do not leak them.
        renderer.close()
    return frames


# Prevent recompilation: wrap each function with jax.jit once here and reuse
# the same jitted callables in every cell below.
# `clip_idx` is declared static for the jitted reset.
jit_env_reset = jax.jit(env.reset, static_argnames=("clip_idx",))
jit_env_step = jax.jit(env.step)
jit_apply = jax.jit(decoder.apply)

env._steps_for_cur_frame: 2.0


## Part 1: Decoder with Random Intentions

In this section, we will explore how the decoder behaves with random intentions, compared with the randomly generated actions.

In [21]:
# Split off a key for the environment reset; `rng` keeps advancing below.
rng, env_rng = jax.random.split(rng)
# NOTE(review): `init_key` is not used anywhere in this cell.
rng, init_key = jax.random.split(rng)

# Reset the environment onto clip 41.
state = jit_env_reset(rng=env_rng, clip_idx=41)

qposes_rollout = []

for _ in range(500):
    # Normalize the current observation with the checkpointed normalizer
    # parameters.
    obs = normalizer(state.obs, normalizer_param)
    # Draw a fresh 60-dimensional standard-normal intention using JAX's RNG.
    rng, subkey = jax.random.split(rng)
    intentions = jax.random.normal(subkey, shape=(60,))
    # Decoder input: the random intention followed by the observation entries
    # from index `reference_obs_size` onward.
    concatenated = jnp.concatenate(
        [intentions, obs[..., cfg["network_config"]["reference_obs_size"] :]],
        axis=-1,
    )
    # Run the decoder; the action below is the mode of the distribution
    # parameterized by its output.
    logits = jit_apply(decoder_param, concatenated)
    # NOTE(review): this `subkey` is not passed to anything below; the split
    # only advances `rng`.
    rng, subkey = jax.random.split(rng)
    # NOTE: the result is wrapped in jnp.array to keep the input type of
    # jit_env_step consistent. Per the original note, the brax call may return
    # an np array, which would trigger recompilation of the jitted step.
    action = jnp.array(action_distribution.mode(logits))
    # Record the pose before stepping.
    qposes_rollout.append(state.pipeline_state.qpos)
    # Step the environment.
    state = jit_env_step(state, action)


# Render every recorded pose of the rollout.
frames = render_rollout(jnp.array(qposes_rollout))
print("Rollout videos with random intentions")
display_video(frames, framerate=100)

MuJoCo Rendering...


100%|██████████| 500/500 [00:01<00:00, 489.01it/s]


Rollout videos with random intentions


The first run will take ~7 minutes; the second run takes about 20 seconds on an A40 GPU.

In [19]:
# Split off a key for the environment reset; `rng` keeps advancing below.
rng, env_rng = jax.random.split(rng)
# NOTE(review): `init_key` is not used anywhere in this cell.
rng, init_key = jax.random.split(rng)

# Reset the environment onto clip 41.
state = jit_env_reset(rng=env_rng, clip_idx=41)

qposes_rollout = []

for _ in range(500):
    # Baseline: draw the 38-dimensional action directly from a standard normal
    # with JAX's RNG, bypassing the decoder entirely.
    rng, subkey = jax.random.split(rng)
    action = jax.random.normal(subkey, shape=(38,))
    # Record the pose before stepping.
    qposes_rollout.append(state.pipeline_state.qpos)
    # Step the environment.
    state = jit_env_step(state, action)


# Render every recorded pose of the rollout.
frames = render_rollout(jnp.array(qposes_rollout))
# Take a look at the video.
print("Rollout with random actions")
display_video(frames, framerate=100)

MuJoCo Rendering...


100%|██████████| 500/500 [00:01<00:00, 490.99it/s]


Rollout with random actions


As you can see, the behavior produced by the decoder with random intentions is not very good, but it is still better than pure random actions: the rodent remains upright and does not fall over.

## Part 2: Recreate the Recorded Rollout

The following cell recreates the recorded rollout based on the recorded statistics taken during the saved rollout. We will explore different ways of recreating the recorded rollout, including:

- 2.0: Load the saved rollout
- 2.1: Step with recorded **actions**
- 2.2: Step with recorded **decoder inputs**
- 2.3: Step with recorded **intentions**

### Part 2.0: Load the saved rollout


In [24]:
# Saved rollout data, generated in rollout_from_checkpoint.ipynb.
example_rollout_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/250227_200156/rollout_41.h5"
rollout_data = load_from_h5py(example_rollout_path)

# Render the saved rollout and play it back at the frame rate returned by the
# renderer.
frames, realtime_framerate = render_from_saved_rollout(cfg, rollout_data)
display_video(frames, framerate=realtime_framerate)

MuJoCo Rendering...


100%|██████████| 500/500 [00:02<00:00, 210.84it/s]


### Part 2.1: Step with Recorded Actions

The following cell steps through the MuJoCo environment based on the recorded actions taken during the saved rollout. In other words, the next for loop re-steps through the physics simulator based on the recorded actions.

In [None]:
# Split off a fresh key for the environment reset. `env_rng` is also reused by
# the following replay cells.
rng, env_rng = jax.random.split(rng)
rng, init_key = jax.random.split(rng)

# Reset the environment onto clip 41.
state = jit_env_reset(rng=env_rng, clip_idx=41)

qposes = []

# Convert the recorded actions to a jnp array so that every action fed into
# jit_env_step has the same input type as in the other cells.
ctrl = jnp.array(rollout_data["ctrl"])

# Iterate over the converted array, not the raw recorded data; otherwise the
# conversion above has no effect on what the jitted step receives.
for action in ctrl:
    state = jit_env_step(state, action)
    # Record the pose reached after applying the recorded action.
    qposes.append(state.pipeline_state.qpos)

frames = render_rollout(jnp.array(qposes))
print("Rollout with the recorded actions")
display_video(frames, framerate=100)

MuJoCo Rendering...


100%|██████████| 499/499 [00:01<00:00, 471.50it/s]


Rollout with the recorded actions


From the video we can see that it generates a trajectory somewhat similar to the recorded one, but the rodent eventually falls over. This is because we are not using the same random seed for the MuJoCo simulator, and the physics engine is not deterministic. In order to make it deterministic, we need to set the random seed for the MuJoCo simulator.

### Part 2.2: Step with Recorded Decoder Inputs

The following cell steps through the MuJoCo environment based on the recorded decoder inputs taken during the saved rollout. In other words, the next for loop re-steps through the physics simulator based on the recorded decoder inputs.

In [28]:
# Recorded decoder inputs, converted to a jnp array in a single step.
inps = jnp.array(rollout_data["activations"]["decoder_inputs"])

# Reset onto clip 41, reusing the `env_rng` key created in an earlier cell.
state = jit_env_reset(rng=env_rng, clip_idx=41)
qposes = []

for inp in inps:
    # Replay the recorded input through the decoder and act on the mode of
    # the resulting action distribution; the live observation is not used.
    logits = jit_apply(decoder_param, inp)
    action = jnp.array(action_distribution.mode(logits))
    state = jit_env_step(state, action)
    # Record the pose reached after the step.
    qposes.append(state.pipeline_state.qpos)

frames = render_rollout(jnp.array(qposes))
print("Rollout videos with decoder inputs")
display_video(frames, framerate=100)

MuJoCo Rendering...


100%|██████████| 499/499 [00:01<00:00, 478.42it/s]


Rollout videos with decoder inputs


### Part 2.3: Step with Recorded Intentions

The following cell steps through the MuJoCo environment based on the recorded intentions taken during the saved rollout. In other words, the next for loop re-steps through the physics simulator based on the recorded intentions.

In [42]:
qposes = []
# Reset onto clip 41, reusing the `env_rng` key created in an earlier cell.
state = jit_env_reset(rng=env_rng, clip_idx=41)
# Recorded intentions, converted to a jnp array for type consistency.
intentions = jnp.array(rollout_data["activations"]["intention"])

for intention in intentions:
    # Normalize the *live* observation; only the intention is replayed.
    obs = normalizer(state.obs, normalizer_param)
    # Decoder input: the recorded intention followed by the observation
    # entries from index `reference_obs_size` onward.
    concatenated = jnp.concatenate(
        [intention, obs[..., cfg["network_config"]["reference_obs_size"] :]],
        axis=-1,
    )
    logits = jit_apply(decoder_param, concatenated)
    # Act on the mode of the action distribution.
    action = jnp.array(action_distribution.mode(logits))
    state = jit_env_step(state, action)
    # Record the pose reached after the step.
    qposes.append(state.pipeline_state.qpos)

frames = render_rollout(jnp.array(qposes))
print("Rollout videos with recorded intention")
display_video(frames, framerate=100)

MuJoCo Rendering...


100%|██████████| 499/499 [00:01<00:00, 487.02it/s]


Rollout videos with recorded intention


This is remarkable! The rodent is able to perform the locomotion task more stably than the recorded actions and decoder inputs. This is because, with the sensory feedback loop provided by egocentric observations fed back into the decoder, the low-level controller can continuously adjust locomotion based on real-time sensory feedback. This highlights the importance of the sensorimotor loop.

### Bonus: Step with concatenated intentions

In [43]:
# Re-render only steps 130-229 of the poses collected in the previous cell
# (`qposes` must still hold that rollout).
frames = render_rollout(jnp.array(qposes)[130:230])
print("Rollout videos with recorded intention from 130-230")
display_video(frames, framerate=100)

MuJoCo Rendering...


100%|██████████| 100/100 [00:00<00:00, 486.20it/s]


Rollout videos with recorded intention from 130-230


In [45]:
qposes = []
# Reset onto clip 41, reusing the `env_rng` key created in an earlier cell.
state = jit_env_reset(rng=env_rng, clip_idx=41)


# Keep only the recorded intentions for steps 130-229 (100 steps).
intentions = jnp.array(rollout_data["activations"]["intention"])[130:230]

# Play that 100-step intention segment 20 times back to back. The environment
# is not reset between repeats, so the state carries over.
for _ in range(20):
    for intention in intentions:
        # Normalize the live observation; only the intention is replayed.
        obs = normalizer(state.obs, normalizer_param)
        # Decoder input: the recorded intention followed by the observation
        # entries from index `reference_obs_size` onward.
        concatenated = jnp.concatenate(
            [intention, obs[..., cfg["network_config"]["reference_obs_size"] :]],
            axis=-1,
        )
        logits = jit_apply(decoder_param, concatenated)
        # Act on the mode of the action distribution.
        action = jnp.array(action_distribution.mode(logits))
        state = jit_env_step(state, action)
        # Record the pose reached after the step.
        qposes.append(state.pipeline_state.qpos)

frames = render_rollout(jnp.array(qposes))
print("Rollout videos with 20 repeats of recorded intention from 130 to 230")
display_video(frames, framerate=100)

MuJoCo Rendering...


100%|██████████| 2000/2000 [00:04<00:00, 465.30it/s]


Rollout videos with 20 repeats of recorded intention from 130 to 230


## Misc. Useful Utils -- Param Shape Comparison

The following cell compares the parameter structure (the parameter here refers to the `flax.linen` model parameter, which is usually a `pytree` object) between two checkpoints. When reconstructing the model and aligning with the submodule we created, the following utilities are essential to ensure that the parameters in the checkpoint are identical and applicable to the created model.

In [15]:
import jax
import jax.numpy as jnp
from track_mjx.agent.intention_network import Decoder


def compare_param_structure(params1, params2, path=""):
    """
    Recursively compare the structure of two parameter trees.

    Any mapping type is treated as a subtree (plain ``dict`` as well as other
    ``collections.abc.Mapping`` implementations such as immutable parameter
    containers), so nested differences are not silently skipped when the
    parameters are not stored in plain dicts.

    Args:
        params1: Mapping holding the freshly initialized ("init") parameters.
        params2: Mapping holding the checkpoint parameters.
        path: Slash-separated key prefix of the subtree being compared. Used
            by the recursion; top-level callers can leave the default.

    Returns:
        List of strings describing every key that exists on only one side and
        every leaf whose ``shape`` attribute differs. The list is empty when
        the two structures match.
    """
    from collections.abc import Mapping

    differences = []
    # Walk params1 and look each key up in params2.
    for key in params1:
        key_path = f"{path}/{key}" if path else key
        if key not in params2:
            differences.append(
                f"Key '{key_path}' exists in init params but not in checkpoint params."
            )
            continue
        val1 = params1[key]
        val2 = params2[key]
        # Both sides are subtrees: descend.
        if isinstance(val1, Mapping) and isinstance(val2, Mapping):
            differences.extend(compare_param_structure(val1, val2, key_path))
            continue
        # Otherwise compare shapes. A value without a `shape` attribute
        # (including a subtree facing a leaf) counts as shape None.
        shape1 = getattr(val1, "shape", None)
        shape2 = getattr(val2, "shape", None)
        if shape1 != shape2:
            differences.append(
                f"Key '{key_path}' has shape {shape1} in init params vs {shape2} in checkpoint params."
            )
    # Report keys that exist only in params2.
    for key in params2:
        key_path = f"{path}/{key}" if path else key
        if key not in params1:
            differences.append(
                f"Key '{key_path}' exists in checkpoint params but not in init params."
            )
    return differences


# Network hyper-parameters come from the config loaded with the checkpoint.
network_config = cfg["network_config"]

# Build the Decoder module with the same layer sizes as the checkpointed one;
# the last layer is twice the action size.
decoder = Decoder(
    network_config["decoder_layer_sizes"] + [2 * network_config["action_size"]]
)

# IMPORTANT: Dense layers need a leading batch dimension on the dummy input.
input_shape = (1, 60 + 147)
dummy_input = jnp.zeros(input_shape)

# A fresh initialization provides the reference parameter structure.
init_variables = decoder.init(jax.random.PRNGKey(0), dummy_input)
params_init = init_variables["params"]

# Checkpoint parameters; unwrap them when they are nested under a "params" key.
params_checkpoint = decoder_raw["params"] if "params" in decoder_raw else decoder_raw

# Report every key or shape mismatch between the two parameter dictionaries.
structure_diffs = compare_param_structure(params_init, params_checkpoint)
if not structure_diffs:
    print("No structural differences found between the parameter dictionaries.")
else:
    print("Structure differences between init params and checkpoint params:")
    for diff in structure_diffs:
        print(" -", diff)

No structural differences found between the parameter dictionaries.
