# Generate and render a rollout from an existing checkpoint

This notebook demonstrates how to load a training checkpoint, perform a rollout, and render the result. Full network activations are saved as an output of this rollout for further analysis.

## Imports

In [1]:
%load_ext autoreload
%autoreload 2

import os

# These environment variables only take effect if they are set before JAX
# initializes its backend and before MuJoCo is imported (track_mjx presumably
# imports MuJoCo), so they are set here ahead of those imports.
# Stop XLA from preallocating most of the GPU memory up front; JAX then
# allocates on demand, which leaves room for other processes/notebooks.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# OpenGL backend MuJoCo uses for rendering. osmesa is a software renderer;
# change this to egl or glfw if available.
os.environ["MUJOCO_GL"] = "osmesa"
import mediapy as media

from track_mjx.agent import checkpointing
from track_mjx.analysis import rollout, render, utils

import jax
from jax import numpy as jp
from pathlib import Path

## Load checkpoint

In [2]:
# Checkpoint directory to evaluate -- replace with your own checkpoint path.
ckpt_path = Path.cwd().parent / "model_checkpoints/scott_best"

# The loaded checkpoint bundles the training config ("cfg") together with the
# policy ("policy"); the config drives everything else in this notebook.
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)
cfg = ckpt["cfg"]

# Optional config overrides (uncomment as needed). The data paths below are
# relative to this repo's layout -- use absolute paths to your data if this
# notebook cannot reach them from here.
# cfg.data_path = Path.cwd().parent / "data/transform_snips.h5"
# cfg.data_path = Path.cwd().parent / "data/transform_coltrane_2021_07_29_1.h5"
# cfg.train_setup.checkpoint_to_restore = ckpt_path



Loading checkpoint from /n/holylabs-olveczky/Users/charleszhang/track-mjx/model_checkpoints/scott_best at step 99


In [3]:
from vnl_mjx.tasks.rodent import imitation

def _track_to_vnl_cfg(cfg):
    """Build a vnl_mjx imitation config from a track-mjx hydra config.

    Starts from ``imitation.default_config()`` and overwrites the simulation,
    walker, reference-clip and reward settings with the corresponding values
    found in ``cfg``.

    Args:
        cfg: Hydra config loaded from a track-mjx checkpoint. Values are read
            from ``env_config.env_args``, ``walker_config``,
            ``reference_config`` and ``env_config.reward_weights``.

    Returns:
        The vnl_mjx imitation config with those values filled in.
    """
    vnl_cfg = imitation.default_config()

    # --- simulation / solver settings ---
    sim = cfg.env_config.env_args
    vnl_cfg.solver = sim.solver
    vnl_cfg.iterations = sim.iterations
    vnl_cfg.ls_iterations = sim.ls_iterations
    vnl_cfg.sim_dt = sim.mj_model_timestep
    vnl_cfg.mocap_hz = sim.mocap_hz

    # --- walker settings ---
    walker = cfg.walker_config
    vnl_cfg.torque_actuators = walker.torque_actuators
    vnl_cfg.rescale_factor = walker.rescale_factor

    # --- reference clip settings ---
    reference = cfg.reference_config
    vnl_cfg.clip_length = reference.clip_length
    vnl_cfg.reference_length = reference.traj_length
    vnl_cfg.start_frame_range = [0, reference.random_init_range]

    weights = cfg.env_config.reward_weights

    # --- imitation rewards ---
    # Each row: (vnl reward term, hydra exp-scale field, hydra weight field).
    # vnl's exp_scale is the reciprocal of the hydra exp-scale value.
    imitation_terms = (
        ("root_pos", "pos_reward_exp_scale", "pos_reward_weight"),
        ("root_quat", "quat_reward_exp_scale", "quat_reward_weight"),
        ("joints", "joint_reward_exp_scale", "joint_reward_weight"),
        ("joints_vel", "angvel_reward_exp_scale", "angvel_reward_weight"),
        ("bodies_pos", "bodypos_reward_exp_scale", "bodypos_reward_weight"),
        ("end_eff", "endeff_reward_exp_scale", "endeff_reward_weight"),
    )
    for term, scale_field, weight_field in imitation_terms:
        vnl_cfg.reward_terms[term] = {
            "exp_scale": 1.0 / getattr(weights, scale_field),
            "weight": getattr(weights, weight_field),
        }

    # --- cost terms ---
    vnl_cfg.reward_terms["control_cost"] = {"weight": weights.ctrl_cost_weight}
    vnl_cfg.reward_terms["control_diff_cost"] = {
        "weight": weights.ctrl_diff_cost_weight
    }
    # energy_cost already exists in the default config with a different
    # structure, so only its weight is overwritten instead of the whole entry.
    vnl_cfg.reward_terms["energy_cost"]["weight"] = weights.energy_cost_weight

    # --- healthy torso height ---
    vnl_cfg.reward_terms["torso_z_range"] = {
        "healthy_z_range": tuple(weights.healthy_z_range),
        "weight": 1.0,  # the hydra config has no weight for this term
    }

    # Termination thresholds (too_far_dist, bad_pose_dist, bad_quat_dist) are
    # not mapped here, so the vnl_mjx default termination criteria stay in
    # effect.
    return vnl_cfg

In [4]:
from vnl_mjx.tasks.rodent import wrappers as vnl_wrappers
# Translate the checkpoint's hydra config, build the vnl_mjx imitation env from
# it, and wrap it so observations come out as a single flat vector.
env_cfg = _track_to_vnl_cfg(cfg)
base_imitation_env = imitation.Imitation(config=env_cfg)
env = vnl_wrappers.FlattenObsWrapper(base_imitation_env)


In [5]:
# Sizes of the reference (non-proprioceptive) and proprioceptive parts of the
# observation exposed by the wrapped env.
reference_obs_size, proprioceptive_obs_size = (
    env.non_proprioceptive_obs_size,
    env.proprioceptive_obs_size,
)
print(
    f"Reference observation size: {reference_obs_size}\n"
    f"Proprioceptive observation size: {proprioceptive_obs_size}"
)

2025-09-24 16:24:08.671300: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.
2025-09-24 16:24:14.896670: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.


Reference observation size: 640
Proprioceptive observation size: 226


In [6]:
# Inspect the observation spec without actually stepping the simulation:
# eval_shape traces `reset` abstractly and returns only shapes/dtypes.
abstract_state = jax.eval_shape(env.reset, jax.random.PRNGKey(0))

# Sanity check: the flat observation should consist of exactly the reference
# and proprioceptive parts reported by the env (640 + 226 = 866 in the saved
# run). Fails loudly if the wrapper's layout ever changes.
expected_obs_len = reference_obs_size + proprioceptive_obs_size
assert abstract_state.obs.shape == (expected_obs_len,), (
    f"obs shape {abstract_state.obs.shape} != ({expected_obs_len},)"
)
abstract_state.obs

ShapeDtypeStruct(shape=(866,), dtype=float32)

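### Restore the policy and build rollout functions

The cell below rebinds `env` to the track-mjx environment created by `rollout.create_environment`. This replaces the vnl_mjx environment built above, so observation sizes reported from here on refer to the track-mjx env.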

In [3]:
# Value written to the track-mjx env config's `max_start_frame` before the env
# is built. Presumably this caps the frame a rollout may start from within a
# clip -- TODO confirm against track_mjx's environment config.
MAX_START_FRAME = 40
cfg.env_config.env_args.max_start_frame = MAX_START_FRAME

# NOTE: this rebinds `env` to the track-mjx environment built from `cfg`,
# replacing the vnl_mjx environment constructed earlier in the notebook. All
# cells below operate on this track-mjx env.
env = rollout.create_environment(cfg)
inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
generate_rollout = rollout.create_rollout_generator(
    cfg,
    env,
    inference_fn,
    log_activations=True,
    log_metrics=True,
    log_sensor_data=True,
)

Converting to torque actuators
Rescaling body tree with scale factor 0.9




In [4]:
# Compile the env's reset/step once; later calls reuse the compiled versions.
reset, step = jax.jit(env.reset), jax.jit(env.step)

state = reset(jax.random.PRNGKey(0))
# NOTE(review): `_get_proprioception` is a private env method and may change
# without notice -- prefer a public accessor if track_mjx exposes one.
obs = env._get_proprioception(state.pipeline_state)

  brax_contact = _reformat_contact(sys, data.contact)


In [7]:
# Proprioceptive observation size as recorded in the env state's info dict.
# NOTE(review): the saved output (231) disagrees with `obs.shape` (226) from
# `_get_proprioception` -- confirm which value is authoritative.
state.info["proprioceptive_obs_size"]

Array(231, dtype=int32, weak_type=True)

In [6]:
# Shape of the full observation vector of the track-mjx env. It differs from
# the vnl_mjx env's flattened obs (866) because `env` was rebound above.
state.obs.shape

(701,)

In [5]:
# Shape of the proprioceptive vector extracted from `state.pipeline_state`.
obs.shape

(226,)

### Generate rollouts from the checkpoint

With the checkpoint loaded, we can run the policy to generate rollouts.

We can generate a rollout that imitates a single clip, specified by its clip index. The first call incurs ~1–3 min of JIT compilation; after that, each call takes only a few seconds.

In [9]:
# Roll out the policy on clip 0. The first call triggers JIT compilation.
single_rollout = generate_rollout(clip_idx=0)

In [None]:
# Render the single rollout to frames, played back at real-time speed.
frames, realtime_framerate = render.render_rollout(
    cfg, single_rollout, height=480, width=640
)

# Save the video next to the checkpoint, then display it inline.
video_path = Path(ckpt_path) / "rollout.mp4"
media.write_video(video_path, frames, fps=realtime_framerate)
media.show_video(frames, fps=realtime_framerate)

### Save the rollout to disk

In [8]:
# HDF5 file for the rollout, stored alongside the checkpoint. Path() keeps this
# working even if ckpt_path was replaced with a plain string.
save_path = Path(ckpt_path).joinpath("rollout.h5")

In [9]:
# Write the single rollout to the HDF5 file, using its absolute path.
utils.save_to_h5py(save_path.resolve(), single_rollout)

### Load the rollout back from disk

In [None]:
# Load the saved rollout back from disk. It is bound to a new name so that the
# `rollout` module imported at the top is not shadowed; otherwise re-running
# cells that call `rollout.create_environment` would fail.
loaded_rollout = utils.load_from_h5py(save_path)

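## Rendering limitations

Note: rendering (`render.render_rollout`, used above) currently only works for single rollouts, not for the batched rollouts generated below.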

#### Batch-generating rollouts

Alternatively, you can use `jax.vmap` to parallelize the rollout function. This is useful for performing a rollout over an entire dataset for eval/analysis purposes. We pass in a 1D array of clip indexes (`clip_idxs`) as input. 

The first run for this will also have a few minutes of compilation time.

**Note:** `vmap` compiles based on the input shape. This means that if you use the same length for `clip_idxs`, JAX will reuse the compiled function for acceleration. However, if the input length changes, JAX will **recompile the entire function**, incurring additional overhead.

In [9]:
# Vectorize the rollout generator over clip indices, then JIT the result.
# The compiled function is specialized to the length of its input array.
jit_vmap_generate_rollout = jax.jit(jax.vmap(generate_rollout))

# Roll out clips 0-4 in parallel.
clip_idxs = jp.arange(5)
jit_vmap_out = jit_vmap_generate_rollout(clip_idxs)

In [10]:
# `clip_idxs` now has 15 entries instead of 5. Because the input length
# differs, JAX recompiles `jit_vmap_generate_rollout` before running it.
clip_idxs = jp.arange(15, 30)
jit_vmap_out2 = jit_vmap_generate_rollout(clip_idxs)