# Generate and render a rollout for an existing checkpoint

This notebook demonstrates how to load a training checkpoint, perform a rollout, and render the result. Full network activations are saved as an output of this rollout for further analysis.

## Terminology

**Rollout** - a process of running a trained agent in an environment to collect data. In this case, we will run the agent in the environment and save the activations of the network.

**Rollout File** - a file containing the activations of the network at each step of the rollout for a specific episode, saved in HDF5 (`.h5`) format.
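To make these two terms concrete, here is a toy sketch of the pattern: a policy is stepped forward in a loop and its hidden activations are recorded at every step. Everything in it is hypothetical: `toy_policy`, the `obs + 0.1 * action` update that stands in for an environment step, and the array sizes. track-mjx steps a MuJoCo environment instead, and its rollout code may be structured differently.

```python
import jax
import jax.numpy as jnp


def toy_policy(params, obs):
    """Hypothetical one-hidden-layer policy returning (action, hidden activations)."""
    hidden = jnp.tanh(obs @ params["w1"])
    action = hidden @ params["w2"]
    return action, hidden


def toy_rollout(params, init_obs, num_steps):
    """Step the policy num_steps times, recording obs, action, and activations."""

    def step(obs, _):
        action, hidden = toy_policy(params, obs)
        next_obs = obs + 0.1 * action  # stand-in for an environment step
        return next_obs, {"obs": obs, "action": action, "activations": hidden}

    # lax.scan stacks the per-step outputs along a new leading time axis
    _, traj = jax.lax.scan(step, init_obs, xs=None, length=num_steps)
    return traj


key1, key2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w1": jax.random.normal(key1, (3, 8)),
    "w2": jax.random.normal(key2, (8, 3)),
}
traj = toy_rollout(params, jnp.zeros(3), num_steps=50)
print({k: v.shape for k, v in traj.items()})
```

Each entry of `traj` holds one row per step, which is the kind of per-step data a rollout file stores.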

## Imports

In [None]:
%load_ext autoreload
%autoreload 2

import os
import logging
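# Uncomment the two lines below to set the root logger to INFO level
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)

# Headless GPU rendering via EGL. If EGL is unavailable, set both variables
# to "osmesa" (CPU software rendering), or set MUJOCO_GL to "glfw" on a machine with a display
os.environ["MUJOCO_GL"] = "egl"
os.environ["PYOPENGL_PLATFORM"] = "egl"
# Let JAX allocate GPU memory on demand instead of preallocating a large fraction up front
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"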


from track_mjx.agent import checkpointing
from track_mjx.analysis import rollout, render, utils

import jax
from jax import numpy as jp
from pathlib import Path
from orbax import checkpoint as ocp
from omegaconf import OmegaConf

# Enable persistent compilation cache.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

## Load checkpoint

If loading fails with an error about missing config keys, the checkpoint was most likely created with an older version of track-mjx. In that case, patch the checkpoint's config so that it matches the current config schema.
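One way to patch an old config is to fill in any missing keys from a set of defaults while leaving existing values untouched. The sketch below assumes the config has first been converted to a plain dict (for example with `OmegaConf.to_container(cfg)`); the helper `fill_missing_keys` and the keys used in the example (`new_option`, the empty `data_path` default) are hypothetical, not part of track-mjx.

```python
import copy


def fill_missing_keys(old_cfg: dict, defaults: dict) -> dict:
    """Return a copy of old_cfg with keys missing from it filled in from defaults.

    Existing values in old_cfg are never overwritten; nested dicts are merged recursively.
    """
    patched = copy.deepcopy(old_cfg)
    for key, default_value in defaults.items():
        if key not in patched:
            patched[key] = copy.deepcopy(default_value)
        elif isinstance(patched[key], dict) and isinstance(default_value, dict):
            patched[key] = fill_missing_keys(patched[key], default_value)
    return patched


# Hypothetical example: an old checkpoint config lacking keys added later
old_cfg = {"train_setup": {"checkpoint_to_restore": None}}
defaults = {"train_setup": {"checkpoint_to_restore": None, "new_option": 1}, "data_path": ""}
patched = fill_missing_keys(old_cfg, defaults)
print(patched)
```

Convert the result back with `OmegaConf.create(patched)`. `OmegaConf.merge(defaults_cfg, old_cfg)` gives a similar result, since later arguments take precedence in a merge.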

In [17]:
# replace with your checkpoint path
ckpt_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/250521_200910"
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)

# Load config from checkpoint
cfg = ckpt["cfg"]

# Override config fields for use in this notebook.
# Replace with the absolute path to your data: relative paths from
# training may not resolve from the notebook's working directory.
cfg.data_path = "/root/vast/scott-yang/track-mjx/data/transform_snips.h5"
cfg.train_setup.checkpoint_to_restore = ckpt_path

Loading checkpoint from /root/vast/scott-yang/track-mjx/model_checkpoints/250521_200910 at step 158


### Restore policy and make rollout functions

In [27]:
cfg = OmegaConf.create(cfg)
env = rollout.create_environment(cfg)
inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
generate_rollout = rollout.create_rollout_generator(cfg, env, inference_fn, model="mlp")

Converting to torque actuators
Rescaling body tree with scale factor 0.9


### Generate rollouts from the checkpoint

With the checkpoint loaded, we can run inference to generate rollouts.

To generate a rollout that imitates a single clip, pass its clip index. The first call takes roughly 1 to 3 minutes to compile; later calls take only a few seconds.

In [None]:
clip_idx = 0
single_rollout = generate_rollout(clip_idx=clip_idx, seed=0)

In [None]:
frames, render_fps = render.render_rollout(cfg, single_rollout)
render.display_video(frames, framerate=render_fps)

#### Batch-generating rollouts

Alternatively, you can use `jax.vmap` to parallelize the rollout function. This is useful for running rollouts over an entire dataset for evaluation or analysis. Pass a 1D array of clip indices (`clip_idxs`) as input.
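A minimal sketch of what `jax.vmap` does to a rollout function that returns a dict: every array in the output gains a leading batch axis of length `len(clip_idxs)`, so `out[key][i]` corresponds to clip `clip_idxs[i]`. `toy_generate_rollout` and its 100-step output are hypothetical stand-ins for `generate_rollout`.

```python
import jax
import jax.numpy as jnp


def toy_generate_rollout(clip_idx, seed=0):
    """Hypothetical stand-in for generate_rollout: returns a dict of per-step arrays."""
    t = jnp.arange(100.0)  # 100 time steps
    ref = jnp.sin(t * 0.1 + clip_idx)
    return {"qposes_ref": ref, "qposes_rollout": ref + 0.01 * seed}


# vmap maps over the first positional argument; seed keeps its default
batched = jax.jit(jax.vmap(toy_generate_rollout))
out = batched(jnp.arange(0, 5))
print({k: v.shape for k, v in out.items()})

# Row i of each batched array matches an unbatched call for clip i
single = toy_generate_rollout(2)
print(jnp.allclose(out["qposes_ref"][2], single["qposes_ref"], atol=1e-5))
```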

The first run for this will also have a few minutes of compilation time.

**Note:** `jax.jit` compiles a separate version of the function for each input shape. If you call it again with a `clip_idxs` array of the same length, JAX reuses the compiled function; if the length changes, JAX **recompiles the entire function**, incurring the compilation overhead again.
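The caching rule can be checked with a toy function. A Python side effect such as incrementing a counter runs only while JAX traces the function, so the counter shows when a new trace (and compilation) happens. `toy_rollout` is a hypothetical stand-in for the real rollout function.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def toy_rollout(clip_idx):
    global trace_count
    trace_count += 1  # Python side effects run only while JAX is tracing
    return jnp.cos(clip_idx * 0.5)


batched = jax.jit(jax.vmap(toy_rollout))
batched(jnp.arange(0, 5))   # shape (5,): trace and compile
batched(jnp.arange(5, 10))  # shape (5,) again: cached, no retrace
print(trace_count)  # 1
batched(jnp.arange(0, 8))   # shape (8,): retrace and recompile
print(trace_count)  # 2
```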

In [None]:
# Generate rollout for 5 clips simultaneously
jit_vmap_generate_rollout = jax.jit(jax.vmap(generate_rollout))
clip_idxs = jp.arange(0, 5)
jit_vmap_out = jit_vmap_generate_rollout(clip_idxs)

In [5]:
# Roll out clips 0-420 (421 clips) in one batched call
clip_idxs = jp.arange(0, 421)
jit_vmap_out = jit_vmap_generate_rollout(clip_idxs)

# Keep only the reference and rollout qpos trajectories to keep the file small
keys = ["qposes_ref", "qposes_rollout"]
masked = {key: jit_vmap_out[key] for key in keys}

save_path = Path(ckpt_path) / "rollout_0_421.h5"
print(f"Saving rollout to {save_path}")
utils.save_to_h5py(save_path.resolve(), masked)

Saving rollout to /root/vast/scott-yang/track-mjx/model_checkpoints/250521_200910/rollout_0_421.h5


In [6]:
# Roll out the next 421 clips (421-841); the input shape is unchanged,
# so the compiled function from the previous cell is reused
clip_idxs = jp.arange(421, 842)
jit_vmap_out = jit_vmap_generate_rollout(clip_idxs)

keys = ["qposes_ref", "qposes_rollout"]
masked = {key: jit_vmap_out[key] for key in keys}

save_path = Path(ckpt_path) / "rollout_421_842.h5"
print(f"Saving rollout to {save_path}")
utils.save_to_h5py(save_path.resolve(), masked)

Saving rollout to /root/vast/scott-yang/track-mjx/model_checkpoints/250521_200910/rollout_421_842.h5
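The two cells above split the clips into two chunks of 421, so both calls see the same input shape and the second call reuses the compiled function. A small helper (hypothetical, not part of track-mjx) can produce the `(start, end)` ranges for any number of clips. If `num_clips` is not a multiple of `chunk_size`, the last chunk is shorter and triggers one extra compilation.

```python
def chunk_ranges(num_clips: int, chunk_size: int) -> list[tuple[int, int]]:
    """Split clip indices [0, num_clips) into consecutive (start, end) ranges."""
    return [
        (start, min(start + chunk_size, num_clips))
        for start in range(0, num_clips, chunk_size)
    ]


ranges = chunk_ranges(842, 421)
print(ranges)  # [(0, 421), (421, 842)]
print([f"rollout_{start}_{end}.h5" for start, end in ranges])

# In the notebook, the loop would look like (not executed here):
# for start, end in chunk_ranges(842, 421):
#     out = jit_vmap_generate_rollout(jp.arange(start, end))
#     ...
```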


### Save the rollout to disk

In [None]:
save_path = Path(ckpt_path) / f"rollout_{clip_idx}.h5"
print(f"Saving rollout to {save_path}")

In [None]:
utils.save_to_h5py(save_path.resolve(), single_rollout)
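The on-disk layout depends on how `utils.save_to_h5py` maps the rollout dict to HDF5 groups and datasets, which this notebook does not show. A quick way to check is to list every dataset and its shape with h5py. The demo file and its dataset names below are made up for illustration; point `list_datasets` at `save_path` to inspect a real rollout file.

```python
import os
import tempfile

import h5py
import numpy as np


def list_datasets(path):
    """Return {dataset_path: shape} for every dataset in an HDF5 file."""
    shapes = {}

    def visit(name, obj):
        if isinstance(obj, h5py.Dataset):
            shapes[name] = obj.shape

    with h5py.File(path, "r") as f:
        f.visititems(visit)  # walks all groups and datasets recursively
    return shapes


# Build a small hypothetical file with one top-level and one nested dataset
demo_path = os.path.join(tempfile.mkdtemp(), "demo_rollout.h5")
with h5py.File(demo_path, "w") as f:
    f.create_dataset("qposes_ref", data=np.zeros((100, 7)))
    f.create_group("activations").create_dataset("layer_0", data=np.zeros((100, 16)))

print(list_datasets(demo_path))
```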

# Render Rollout Videos from the Saved Rollouts

## Load the rollout file

In [None]:
# Use a new name so the imported `rollout` module is not shadowed
loaded_rollout = utils.load_from_h5py(save_path)

## Render rollout

Note: `render.render_rollout` currently supports only a single (non-batched) rollout.

In [None]:
frames, realtime_framerate = render.render_rollout(cfg, loaded_rollout)
render.display_video(frames, framerate=realtime_framerate)
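To render one clip from a batched (vmapped) output, you can first index the leading batch axis of every array. This is a minimal sketch with a hypothetical helper `select_clip` and made-up array sizes; it assumes `render.render_rollout` accepts a dict with the same structure as a single rollout, e.g. `render.render_rollout(cfg, select_clip(jit_vmap_out, 0))`.

```python
import jax
import jax.numpy as jnp


def select_clip(batched_rollout, i):
    """Take clip i from a vmapped rollout by indexing the leading axis of every leaf."""
    return jax.tree_util.tree_map(lambda x: x[i], batched_rollout)


# Hypothetical batched output: 5 clips, 100 steps, 7 qpos values
batched_rollout = {
    "qposes_ref": jnp.zeros((5, 100, 7)),
    "qposes_rollout": jnp.ones((5, 100, 7)),
}
clip_2 = select_clip(batched_rollout, 2)
print({k: v.shape for k, v in clip_2.items()})
```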

# Batch Rollout & Save Locally

In [None]:
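from pathlib import Path

from tqdm import tqdm
from track_mjx.analysis.utils import save_to_h5py

# Replace with your own output directory; create it if it does not exist
out_dir = Path("/root/vast/kaiwen/track-mjx/force_analysis_rollout_charles")
out_dir.mkdir(parents=True, exist_ok=True)

# Generate rollouts one clip at a time and save each to its own file
for i in tqdm(range(800)):
    output = generate_rollout(clip_idx=i)
    save_to_h5py(out_dir / f"clip_{i}.h5", output)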