# Checkpoint Playground

This notebook is a playground for the saved checkpoints of `track-mjx`'s kinematic replay tasks. It allows you to load the intention network and perform different operations on it, such as inspecting the decoder with random intentions, generating a trajectory from a given intention, evaluating model performance with various metrics, visualizing the trajectories and intention distributions, and implementing additional analysis tools to enhance the exploration of model behavior and capabilities.

In [1]:
# imports

%load_ext autoreload
%autoreload 2

import os
import logging

# Uncomment to send logging output (INFO and above) to stdout.
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)

# Headless OpenGL backend for MuJoCo rendering and PyOpenGL; switch both to
# "glfw" (windowed) or "osmesa" (software) if EGL is not available.
os.environ.update(MUJOCO_GL="egl", PYOPENGL_PLATFORM="egl")

# Cap the fraction of GPU memory JAX preallocates so several JAX notebooks can
# share one machine; a value already set in the environment takes precedence.
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.4")

from track_mjx.agent import checkpointing
from track_mjx.agent.intention_network import Decoder
from track_mjx.analysis.rollout import (
    create_rollout_generator,
    create_environment,
)
from track_mjx.environment.wrappers import EvalClipResetWrapper
from track_mjx.analysis.render import render_from_saved_rollout, display_video
from track_mjx.analysis.utils import save_to_h5py, load_from_h5py

import jax
from jax import numpy as jnp
from brax.training import distribution
from pathlib import Path

# rendering related
import mujoco
from dm_control import mjcf as mjcf_dm
from tqdm import tqdm
import numpy as np
from dm_control.locomotion.walkers import rescale

## Network Initialization and Checkpoint Loading

First, we need to load the checkpoint. To load your own checkpoint, set the `ckpt_path` variable to your checkpoint path.

In [2]:
# Edit these two locations for your setup. Absolute paths are used because the
# notebook kernel may not share the repository's working directory.
ckpt_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/250227_200156"
reference_data_path = "/root/vast/scott-yang/track-mjx/data/transform_snips.h5"

# Restore the checkpoint for evaluation; its training config travels with it.
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)
cfg = ckpt["cfg"]

# Point the config at the local dataset and at the checkpoint just loaded.
cfg.data_path = reference_data_path
cfg.train_setup.checkpoint_to_restore = ckpt_path

Loading checkpoint from /root/vast/scott-yang/track-mjx/model_checkpoints/250227_200156 at step 88


The following code builds the decoder and loads its parameters from the checkpoint.

In [3]:
# Decoder architecture comes from the training config stored in the checkpoint.
network_config = cfg["network_config"]

# The final layer has 2 * action_size units. brax's NormalTanhDistribution splits
# them into the location and scale of a Normal, and tanh squashes the result.
decoder = Decoder(
    network_config["decoder_layer_sizes"] + [2 * network_config["action_size"]]
)

# Pull the decoder weights out of the policy parameters and wrap them in the
# {"params": ...} collection layout that `decoder.apply` expects.
# NOTE(review): policy[1] presumably holds the network params (index 0 likely the
# observation normalizer state) -- confirm against the checkpoint format.
decoder_raw = ckpt["policy"][1]["params"]["decoder"]
decoder_param = dict(params=decoder_raw)

# Turns raw decoder outputs into a tanh-squashed Normal over actions.
action_distribution = distribution.NormalTanhDistribution(
    event_size=network_config["action_size"]
)

Test run: feed the decoder an all-zero input (zero intention and zero observation).

In [4]:
# Compile the decoder forward pass once; later cells reuse this wrapper.
jit_apply = jax.jit(decoder.apply)

# All-zero decoder input: 60 intention dims followed by 147 observation dims.
zero_input = jnp.zeros(60 + 147)
actions = action_distribution.mode(jit_apply(decoder_param, zero_input))
print("Number of actions:", len(actions))
actions

Number of actions: 38


Array([ 0.02897133, -0.03831713, -0.02078261,  0.58814603,  0.17070429,
       -0.03250994, -0.07583926,  0.0453225 ,  0.4106154 , -0.1252416 ,
        0.17990206, -0.18754841,  0.02986068, -0.0783918 ,  0.15125185,
       -0.24585633,  0.4074648 , -0.25561202,  0.22987182,  0.05525849,
        0.53349316,  0.0109764 , -0.03424448, -0.24943057, -0.01396631,
        0.40130252,  0.1403449 ,  0.27226597, -0.4767266 , -0.8498172 ,
        0.08346232, -0.11328582, -0.08953246,  0.34286287,  0.09904323,
        0.13242985, -0.5625049 , -0.80803865], dtype=float32)

## Examine the Decoder with Random Intentions

The first thing we can do is to examine the decoder with random intentions. We can generate random intentions and see how the decoder generates the trajectory from them.

In [5]:
# NOTE: To use accelerated JAX.JIT, only run the following cell once.
# Every re-run creates new jitted wrappers below, which forces recompilation.

# Evaluation environment, wrapped so `reset` accepts a `clip_idx` to pick a clip.
env = EvalClipResetWrapper(create_environment(cfg))

# Standalone MuJoCo model used only for rendering, rescaled like the walker.
rodent_XML_path = "/root/vast/scott-yang/track-mjx/track_mjx/environment/walker/assets/rodent/rodent.xml"
root = mjcf_dm.from_path(rodent_XML_path)
# NOTE(review): 0.8 is presumably the scale already baked into rodent.xml -- confirm.
render_scale = cfg.walker_config.rescale_factor / 0.8
rescale.rescale_subtree(root, render_scale, render_scale)

mj_model = mjcf_dm.Physics.from_mjcf_model(root).model.ptr
# Conjugate-gradient solver with 6 solver / 6 line-search iterations.
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
mj_model.opt.iterations = 6
mj_model.opt.ls_iterations = 6
mj_data = mujoco.MjData(mj_model)

# Real-time playback rate: one recorded frame per control step.
control_substeps = cfg.env_config.env_args.physics_steps_per_control_step
render_fps = (1.0 / mj_model.opt.timestep) / control_substeps

rng = jax.random.PRNGKey(0)

# Jitted entry points shared by the Python loops below; `clip_idx` is static
# because it selects a clip rather than being traced data.
jit_env_reset = jax.jit(env.reset, static_argnames=("clip_idx",))
jit_env_step = jax.jit(env.step)
jit_apply = jax.jit(decoder.apply)


@jax.jit
def rollout(carry, _):
    """Single ``jax.lax.scan`` step that drives the decoder with a random intention.

    Args:
        carry: ``(state, key)`` -- current environment state and PRNG key.
        _: Unused per-step input (the scan is run with ``xs=None``).

    Returns:
        ``((next_state, key), qpos)`` where ``qpos`` is the pose *before* stepping.
    """
    env_state, rng_key = carry
    obs_offset = cfg["network_config"]["reference_obs_size"]

    # Fresh N(0, I) intention for this step.
    rng_key, intention_key = jax.random.split(rng_key)
    random_intention = jax.random.normal(intention_key, shape=(60,))

    # Decoder input = intention followed by the observation past the reference part.
    decoder_input = jnp.concatenate(
        [random_intention, env_state.obs[..., obs_offset:]], axis=-1
    )
    # Deterministic action: the mode of the decoder's action distribution.
    action = action_distribution.mode(jit_apply(decoder_param, decoder_input))

    # Consumed but unused split, kept so the key stream matches the Python-loop version.
    rng_key, _unused_key = jax.random.split(rng_key)

    next_state = jit_env_step(env_state, action)
    return (next_state, rng_key), env_state.pipeline_state.qpos

env._steps_for_cur_frame: 2.0


### JAX Tips:

When performing training, it’s recommended to fully leverage JAX’s functionality, such as using `jax.lax.scan` to replace Python-native for loops. This brings the entire loop into the computation graph, allowing JAX to optimize it more effectively.

However, during inference and analysis, it’s often better to avoid compiling the entire function — including the `for` loop — all at once. Instead, you can compile individual sub-functions and call those smaller, compiled functions within a regular Python for loop. This approach is preferable because inference and analysis workflows often involve frequent modifications to the rollout functions. By keeping the loop outside the compiled graph, you can avoid unnecessary recompilation of the entire loop when only a small part of the algorithm changes.

To avoid costly recompilation, make sure the arguments you pass to a jitted function always have the same type, shape, and dtype.

In [None]:
# Roll out 200 control steps, feeding the decoder a fresh N(0, I) intention each step.
# The key splits (including the unused ones) are kept so the random stream stays
# identical to earlier runs of this cell.
rng, env_rng = jax.random.split(rng)
rng, init_key = jax.random.split(rng)  # init_key is unused in this loop

state = jit_env_reset(rng=env_rng, clip_idx=41)
obs_offset = cfg["network_config"]["reference_obs_size"]

qposes_rollout = []

for _ in range(200):
    # Random intention for this step.
    rng, subkey = jax.random.split(rng)
    intentions = jax.random.normal(subkey, shape=(60,))

    # Decoder input = intention followed by the observation past the reference part.
    concatenated = jnp.concatenate(
        [intentions, state.obs[..., obs_offset:]],
        axis=-1,
    )

    logits = jit_apply(decoder_param, concatenated)
    rng, subkey = jax.random.split(rng)  # unused: the action is the deterministic mode
    action = action_distribution.mode(logits)
    # Record the pre-step pose so frame 0 is the reset pose.
    qposes_rollout.append(state.pipeline_state.qpos)
    state = jit_env_step(state, action)


# Render the 200 recorded poses. The context manager closes the renderer and
# releases its GL context; previously a new renderer leaked on every run.
mujoco.mj_kinematics(mj_model, mj_data)
frames = []
print("MuJoCo Rendering...")
with mujoco.Renderer(mj_model, height=480, width=640) as renderer:
    for qpos1 in tqdm(qposes_rollout):
        mj_data.qpos = qpos1
        mujoco.mj_forward(mj_model, mj_data)
        renderer.update_scene(mj_data, camera=cfg.env_config.render_camera_name)
        frames.append(renderer.render())

# take a look at the video
display_video(frames, framerate=50)

MuJoCo Rendering...


100%|██████████| 200/200 [00:00<00:00, 443.97it/s]


In [12]:
# Concrete type of the decoder input fed to the jitted functions (see JAX tips above).
concatenated.__class__

jaxlib.xla_extension.ArrayImpl

The following rollout is written the JAX way: `jax.lax.scan` compiles the entire fixed-length rollout into a single computation graph.

The first run takes ~4.5 minutes (mostly compilation); the second run takes about 6.3 seconds on an A40 GPU.

In [9]:
# Fully compiled alternative to the Python loop above: `jax.lax.scan` traces all
# 200 steps of `rollout` into one graph. Left commented out because the first call
# takes ~4.5 minutes to compile; uncomment to try it.
# NOTE(review): unlike the loop above, this reset omits `clip_idx` -- confirm intended.
# # Initialize state and random key.
# rng, env_rng = jax.random.split(rng)
# rng, init_key = jax.random.split(rng)

# init_state = jit_env_reset(rng=env_rng)
# carry = (init_state, init_key)

# # Run the rollout for 200 steps in a single compiled graph, storing results.
# (final_carry, qposes_rollout) = jax.lax.scan(rollout, carry, None, length=200)


# # render the recorded poses
# mujoco.mj_kinematics(mj_model, mj_data)
# renderer = mujoco.Renderer(mj_model, height=480, width=640)
# frames = []
# print("MuJoCo Rendering...")
# for qpos1 in tqdm(qposes_rollout):
#     mj_data.qpos = qpos1
#     mujoco.mj_forward(mj_model, mj_data)
#     renderer.update_scene(mj_data, camera=cfg.env_config.render_camera_name)
#     pixels = renderer.render()
#     frames.append(pixels)

# # take a look at the video
# display_video(frames, framerate=50)

As you can see, the motion produced by the decoder from random intentions is not very natural, but it is still better than purely random actions: the rodent remains upright and does not fall over.

## Recreate the Recorded Rollout

The following cell loads a rollout saved by `rollout_from_checkpoint.ipynb` and renders it with `render_from_saved_rollout`.

In [10]:
# Rollout written by rollout_from_checkpoint.ipynb (presumably for clip 41, the
# clip the other cells reset to).
example_rollout_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/250227_200156/rollout_41.h5"
rollout_data = load_from_h5py(example_rollout_path)

# Play the saved rollout back at the frame rate returned by the renderer helper.
rendered_frames_and_rate = render_from_saved_rollout(cfg, rollout_data)
frames, realtime_framerate = rendered_frames_and_rate
display_video(frames, framerate=realtime_framerate)

MuJoCo Rendering...


100%|██████████| 500/500 [00:02<00:00, 227.49it/s]


## Step Through Environment Based on Recorded Action

The following cell steps through the MuJoCo environment using the actions recorded in the saved rollout. In other words, the loop below re-runs the physics simulator on the recorded actions.

In [14]:
# Replay the recorded controls open-loop, starting from a reset of the same clip.
rng, env_rng = jax.random.split(rng)
rng, init_key = jax.random.split(rng)  # init_key is unused; split kept for key-stream parity

state = jit_env_reset(rng=env_rng, clip_idx=41)

qposes = []

# Convert once and iterate over the converted array, so jit_env_step receives JAX
# arrays (the same input type the other cells pass) rather than raw NumPy rows.
ctrl = jnp.array(rollout_data["ctrl"])

for action in ctrl:
    state = jit_env_step(state, action)
    # Post-step pose, so the reset pose itself is not rendered.
    qposes.append(state.pipeline_state.qpos)

# Render; the context manager closes the renderer and releases its GL context.
mujoco.mj_kinematics(mj_model, mj_data)
frames = []
print("MuJoCo Rendering...")
with mujoco.Renderer(mj_model, height=480, width=640) as renderer:
    for qpos1 in tqdm(qposes):
        mj_data.qpos = qpos1
        mujoco.mj_forward(mj_model, mj_data)
        renderer.update_scene(mj_data, camera=cfg.env_config.render_camera_name)
        frames.append(renderer.render())
# take a look at the video
display_video(frames, framerate=100)

MuJoCo Rendering...


100%|██████████| 499/499 [00:01<00:00, 474.39it/s]


## Load the Rollout File and Run the Decoder with the Recorded Intentions

In [None]:
# Scan body for the batch (jax.lax.scan) rollout with recorded intentions.
# It must be defined (not commented out): the scan cell below calls it.
@jax.jit
def rollout_with_recorded_intention(carry, inputs):
    """Single ``jax.lax.scan`` step that replays one recorded intention through the decoder.

    Args:
        carry: ``(state, key)`` -- current environment state and PRNG key.
        inputs: ``(intention, obs)`` -- one recorded intention vector and the
            recorded observation paired with it. The decoder's observation slice
            comes from this recorded ``obs``, not from ``state.obs``.

    Returns:
        ``((new_state, key), qpos)`` where ``qpos`` is the pose *before* stepping.
    """
    intention, obs = inputs
    state, key = carry
    concatenated = jnp.concatenate(
        [intention, obs[..., cfg["network_config"]["reference_obs_size"] :]],
        axis=-1,
    )

    logits = jit_apply(decoder_param, concatenated)
    # Split kept for parity with `rollout`; the action is the deterministic mode.
    key, subkey = jax.random.split(key)
    action = action_distribution.mode(logits)

    new_state = jit_env_step(state, action)

    return (new_state, key), state.pipeline_state.qpos

In [None]:
# Closed-loop Python-loop variant of the recorded-intention rollout.
# NOTE: this is NOT an exact loop form of `rollout_with_recorded_intention`. That
# function pairs each intention with the *recorded* observation and returns the
# pre-step qpos. This loop feeds the decoder the live `state.obs` and stores the
# post-step qpos.
# `env_rng` is reused from the recorded-action replay cell above, so the reset matches it.

qposes = []

state = jit_env_reset(rng=env_rng, clip_idx=41)

intentions = jnp.array(rollout_data["activations"]["intention"])
obs_offset = cfg["network_config"]["reference_obs_size"]

for intention in intentions:
    concatenated = jnp.concatenate(
        [intention, state.obs[..., obs_offset:]],
        axis=-1,
    )
    logits = jit_apply(decoder_param, concatenated)
    # `mode` already returns a JAX array, so no extra jnp.array copy is needed.
    action = action_distribution.mode(logits)
    state = jit_env_step(state, action)
    qposes.append(state.pipeline_state.qpos)

In [15]:
# take a look at the intention of this rollout data
intention_activations = rollout_data["activations"]["intention"]
obs = rollout_data["observations"][1:]

inputs = (intention_activations, obs)


# Initialize state and random key.
rng, env_rng = jax.random.split(rng)
rng, init_key = jax.random.split(rng)
init_state = jit_env_reset(rng=env_rng, clip_idx=41)
carry = (init_state, init_key)

# Run the rollout for 499 steps.
(final_carry, qposes_rollout) = jax.lax.scan(
    rollout_with_recorded_intention, carry, inputs
)

# generate rendering
mujoco.mj_kinematics(mj_model, mj_data)
renderer = mujoco.Renderer(mj_model, height=480, width=640)
frames = []
print("MuJoCo Rendering...")
for qpos1 in tqdm(qposes_rollout):
    mj_data.qpos = qpos1
    mujoco.mj_forward(mj_model, mj_data)
    renderer.update_scene(mj_data, camera=cfg.env_config.render_camera_name)
    pixels = renderer.render()
    frames.append(pixels)

# take a look at the video
display_video(frames, framerate=50)

KeyboardInterrupt: 

## DEBUG: Discrepancies Between the Decoder Outputs

In [38]:
# Recorded intentions vs. recorded observations (observations has one extra entry).
intention_shape = rollout_data["activations"]["intention"].shape
observation_shape = rollout_data["observations"].shape
intention_shape, observation_shape

((499, 60), (500, 617))

In [36]:
# Re-run the decoder on intention 1 paired with observation 1's post-reference slice.
obs_offset = cfg["network_config"]["reference_obs_size"]
step_one_input = jnp.concatenate(
    [
        rollout_data["activations"]["intention"][1],
        rollout_data["observations"][1][..., obs_offset:],
    ],
    axis=-1,
)
decoder_output = jit_apply(decoder_param, step_one_input)
post_processed_actions = action_distribution.mode(decoder_output)
post_processed_actions

Array([-0.11938352,  0.06624471, -0.21458954,  0.5335462 ,  0.23338881,
       -0.19262846,  0.04601811,  0.10032655,  0.42320186, -0.56087536,
        0.22693026, -0.5360854 ,  0.04485681, -0.51897395,  0.07576308,
       -0.1680723 ,  0.41057068, -0.25355864, -0.5224915 , -0.5156675 ,
       -0.10000287,  0.23435864,  0.18586385, -0.44378692,  0.05259428,
        0.28083146, -0.01142785,  0.5370593 , -0.57357585, -0.5034295 ,
        0.21739757, -0.18083051,  0.15029038,  0.19064692, -0.06631051,
        0.12721173, -0.7578277 , -0.5097104 ], dtype=float32)

In [21]:
# Rebuild the step-0 decoder input from intention 0 and observation 0.
obs_offset = cfg["network_config"]["reference_obs_size"]
intention_step0 = rollout_data["activations"]["intention"][0]
obs_tail_step0 = rollout_data["observations"][0][..., obs_offset:]
reconst_input = jnp.concatenate([intention_step0, obs_tail_step0], axis=-1)

reconst_input.shape

(207,)

In [35]:
# Does the recorded step-0 decoder input equal intention 0 + observation 1's slice?
# (It did not when this notebook was last run -- see output.)
obs_offset = cfg["network_config"]["reference_obs_size"]
candidate_input = jnp.concatenate(
    [
        rollout_data["activations"]["intention"][0],
        rollout_data["observations"][1][..., obs_offset:],
    ],
    axis=-1,
)
jnp.allclose(rollout_data["activations"]["decoder_inputs"][0], candidate_input)

Array(False, dtype=bool)

In [17]:
# Peek at elements 61..70 of the recorded step-0 decoder input.
first_decoder_input = rollout_data["activations"]["decoder_inputs"][0]
first_decoder_input[61:71]

array([-1.3304596 , -0.4962544 ,  0.9060362 , -0.0579351 , -0.00223444,
        0.4027019 , -0.26366568, -0.79364526,  0.43583098,  0.5819483 ],
      dtype=float32)

In [11]:
# Size of one recorded decoder input (intention + observation slice).
np.shape(rollout_data["activations"]["decoder_inputs"][0])

(207,)

In [15]:
# Recorded control vector at index 1, for comparison with the decoder output above.
recorded_ctrl = rollout_data["ctrl"]
recorded_ctrl[1]

array([-0.03618211, -0.29606265,  0.16927134,  0.47706315,  0.31237212,
       -0.04365025, -0.05012481,  0.01011869,  0.20906284, -0.6429165 ,
        0.27653074, -0.39353222, -0.63390845, -0.5489313 , -0.12551486,
       -0.2929979 ,  0.39697972, -0.52558404, -0.5085432 , -0.5282537 ,
       -0.31953505,  0.20412096,  0.17608738, -0.36201817,  0.06921589,
        0.26846585, -0.12370621,  0.48651183, -0.67890507, -0.798092  ,
       -0.00809924, -0.29790795,  0.09154154,  0.12609   , -0.34413195,
        0.4570201 , -0.6977743 , -0.75365824], dtype=float32)

In [44]:
# Re-render the saved rollout for comparison with the debug results above.
saved_render = render_from_saved_rollout(cfg, rollout_data)
frames, realtime_framerate = saved_render
display_video(frames, framerate=realtime_framerate)

MuJoCo Rendering...


100%|██████████| 500/500 [00:02<00:00, 226.35it/s]


## Useful Utils -- Param Shape Comparison

The following cell compares the parameter structure of two parameter trees: the freshly initialized decoder and the checkpoint. When reconstructing the model and aligning it with the submodule we created, this utility is essential to ensure that the checkpoint parameters have the same keys and shapes as the created model, so they can be applied to it.

In [15]:
import jax
import jax.numpy as jnp
from track_mjx.agent.intention_network import Decoder


def compare_param_structure(params1, params2, path=""):
    """
    Recursively compare the structure of two parameter trees.

    Nested levels may be any ``collections.abc.Mapping`` (plain ``dict``, flax
    ``FrozenDict``, ...). Checking only for ``dict`` would treat a subtree as a
    leaf whenever the two trees use different mapping types. Both sides then have
    no ``shape``, so mismatched parameters would be reported as identical.

    Args:
        params1: Reference tree; reported as "init params" in messages.
        params2: Tree to check; reported as "checkpoint params" in messages.
        path: Slash-separated key prefix used when recursing; leave empty.

    Returns:
        A list of strings describing missing/extra keys and leaf ``shape``
        mismatches. The list is empty when the two trees match.
    """
    from collections.abc import Mapping

    differences = []
    # Keys in params1: missing from params2, or present with a different structure.
    for key in params1:
        key_path = f"{path}/{key}" if path else key
        if key not in params2:
            differences.append(
                f"Key '{key_path}' exists in init params but not in checkpoint params."
            )
            continue
        val1 = params1[key]
        val2 = params2[key]
        if isinstance(val1, Mapping) and isinstance(val2, Mapping):
            differences.extend(compare_param_structure(val1, val2, key_path))
        else:
            # Leaves: compare shapes when available (e.g. JAX/NumPy arrays).
            shape1 = getattr(val1, "shape", None)
            shape2 = getattr(val2, "shape", None)
            if shape1 != shape2:
                differences.append(
                    f"Key '{key_path}' has shape {shape1} in init params vs {shape2} in checkpoint params."
                )
    # Keys present only in params2.
    for key in params2:
        if key not in params1:
            key_path = f"{path}/{key}" if path else key
            differences.append(
                f"Key '{key_path}' exists in checkpoint params but not in init params."
            )
    return differences


# Rebuild the decoder from the checkpoint's config and initialize fresh params to
# obtain the parameter tree the module expects.
network_config = cfg["network_config"]
decoder = Decoder(
    network_config["decoder_layer_sizes"] + [network_config["action_size"] * 2]
)

# Dense layers need a leading batch dimension: (batch=1, 60 + 147 features).
input_shape = (1, 60 + 147)
dummy_input = jnp.zeros(input_shape)
init_variables = decoder.init(jax.random.PRNGKey(0), dummy_input)
params_init = init_variables["params"]

# Checkpoint decoder params; unwrap an outer "params" level if one is present.
params_checkpoint = (
    decoder_raw["params"] if "params" in decoder_raw else decoder_raw
)

# Report every key / shape mismatch between the two trees.
structure_diffs = compare_param_structure(params_init, params_checkpoint)
if not structure_diffs:
    print("No structural differences found between the parameter dictionaries.")
else:
    print("Structure differences between init params and checkpoint params:")
    for diff in structure_diffs:
        print(" -", diff)

No structural differences found between the parameter dictionaries.
