# Generating Rollouts from an Existing Checkpoint

This notebook demonstrates how to use the `track_mjx.analysis.rollout` module to load a checkpoint from a previous training run and generate rollouts from it. The module abstracts away the boilerplate code for initializing the environment, so the workflow stays concise and readable while remaining customizable.
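Generating a rollout means repeatedly querying the policy and stepping the environment. As a conceptual sketch only, that pattern can be written with `jax.lax.scan` on a toy point-mass system. `env_step`, `policy`, and `rollout` below are hypothetical stand-ins and are not how `track_mjx.analysis.rollout` is actually implemented.

```python
import jax
import jax.numpy as jnp

def env_step(state, action, dt=0.1):
    # Hypothetical toy dynamics: a 1-D point mass with state (position, velocity)
    pos, vel = state[0], state[1]
    vel = vel + action * dt
    pos = pos + vel * dt
    return jnp.stack([pos, vel])

def policy(state):
    # Hypothetical policy: a PD controller pulling the point mass toward 0
    return -1.0 * state[0] - 0.5 * state[1]

def rollout(init_state, num_steps):
    def step(state, _):
        action = policy(state)
        next_state = env_step(state, action)
        # carry the next state forward; record this step's (state, action) pair
        return next_state, (state, action)

    final_state, (states, actions) = jax.lax.scan(step, init_state, None, length=num_steps)
    return {"states": states, "actions": actions, "final_state": final_state}

# num_steps sets the scan length, so it must be static under jit
rollout = jax.jit(rollout, static_argnames="num_steps")

traj = rollout(jnp.array([1.0, 0.0]), num_steps=500)
print(traj["states"].shape, traj["actions"].shape)  # (500, 2) (500,)
```

Expressing the loop as a `scan` keeps the whole trajectory inside one compiled function, which is one reason a rollout pays a large compilation cost once and then runs quickly on later calls.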

### Step 1: Import modules and recover the config

In [1]:
# set environment variables for rendering
%env MUJOCO_GL=egl
%env PYOPENGL_PLATFORM=egl
%matplotlib inline
%load_ext autoreload
%autoreload 2

from omegaconf import OmegaConf
from track_mjx.analysis.rollout import (
    restore_config,
    create_rollout_generator,
    create_environment,
    create_inference_fn,
)
from track_mjx.analysis.render import (
    render_from_saved_rollout,
    display_video,
)
from jax import numpy as jp

# save and load
from track_mjx.analysis.utils import save_to_h5py
import h5py
from tqdm import tqdm

from IPython.display import clear_output

# recover the config (point ckpt_path at your own checkpoint directory)
# ckpt_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/rodent_data/ReferenceClip.p_250127_062443"
ckpt_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/250409_063617"
# ckpt_path = "/root/vast/kaiwen/track-mjx/model_checkpoints/fly_data/FlyReferenceClipFull.p_250130_061146"
# ckpt_path = "/root/vast/kaiwen/track-mjx/model_checkpoints/rodent_data/try"
config = restore_config(ckpt_path)
cfg = OmegaConf.create(config)

# override config fields for this session
# use an absolute path to your data: the notebook may not have access
# to the same relative path used during training
cfg.data_path = "/root/vast/scott-yang/track-mjx/data/ReferenceClip.p"
# cfg.data_path = "/root/vast/kaiwen/track-mjx/data/FlyReferenceClipFull.p"
cfg.train_setup.checkpoint_to_restore = ckpt_path

env: MUJOCO_GL=egl
env: PYOPENGL_PLATFORM=egl
latest checkpoint step: 998


### Step 2: Restore the policy and create the rollout functions

In [2]:
env = create_environment(cfg)
inference_fn = create_inference_fn(env, cfg)
generate_rollout = create_rollout_generator(cfg["reference_config"], env, inference_fn)

env._steps_for_cur_frame: 2.0
latest checkpoint step: 998


In [3]:
env.walker._pair_rendering_xml_path

'assets/rodent/rodent_ghostpair_scale080.xml'

In [4]:
# list the sensors defined in the walker model and their output dimensions
model = env.walker._mjcf_model.model.ptr
for i in range(model.nsensor):
    print(f"{model.sensor(i).name}: dim={model.sensor_dim[i]}")

accelerometer: dim=3
velocimeter: dim=3
gyro: dim=3
torso: dim=3


### Step 3: Generate rollouts from the checkpoint!

Once the checkpoint is loaded, we can run inference with the policy to generate rollouts!

#### Generate rollout for a single clip

The following cell generates a rollout for a single clip, specified by its clip index. On the first call, JAX must trace and `JIT`-compile the function, which takes around 3 minutes. After compilation, generating a rollout takes only about 8 to 9 seconds, since the compiled code is hardware accelerated.

In [5]:
import jax

# the first call takes ~2m38s because it includes JIT compilation
# once compiled, each call takes only ~9 seconds
output = generate_rollout(8)
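The gap between the first call and later calls is JAX's trace-and-compile cost being paid once. To measure it yourself, time the first and second calls and block on the result, because JAX dispatches work asynchronously and a call can return before the computation finishes. `heavy_fn` and `timed` are hypothetical helpers used only for illustration.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def heavy_fn(x):
    # Hypothetical stand-in for an expensive function such as a rollout.
    # The Python loop is unrolled during tracing, which makes compilation costly.
    for _ in range(50):
        x = jnp.tanh(x @ x.T)
    return x

def timed(fn, *args):
    start = time.perf_counter()
    out = fn(*args)
    jax.block_until_ready(out)  # JAX dispatches asynchronously; wait for the result
    return out, time.perf_counter() - start

x = jnp.full((64, 64), 0.01)
out1, t_first = timed(heavy_fn, x)   # includes tracing and XLA compilation
out2, t_second = timed(heavy_fn, x)  # reuses the cached compiled executable
print(f"first call: {t_first:.3f}s, second call: {t_second:.3f}s")
```

Assuming `generate_rollout` returns a pytree of JAX arrays, `timed(generate_rollout, 8)` should show the same pattern: a compile-dominated first call followed by much faster ones.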

In [6]:
frames = render_from_saved_rollout(output, walker_name='rodent')
display_video(frames, framerate=50)

MuJoCo Rendering...


100%|██████████| 500/500 [00:02<00:00, 175.78it/s]


#### Batch Generating Rollouts

Alternatively, you can use `jax.vmap` to batch the rollout function across multiple clips. To do so, pass a 1D array of clip indices (`clip_idxs`) as input.

On the first call, JAX will perform `JIT` compilation, which takes approximately **3 minutes**. Once compiled, subsequent rollouts execute in just **8 seconds**, benefiting from hardware acceleration.

**Note:** `jax.jit` caches the compiled function by input shape. If you call it again with a `clip_idxs` array of the same length, JAX reuses the compiled function. If the length changes, JAX must **retrace and recompile the entire function**, incurring the full compilation overhead again.
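A quick way to observe this caching is to count how often JAX traces the function: Python side effects inside a jitted function run only during tracing, so a counter increments once per compilation. `fake_rollout` below is a hypothetical stand-in for `generate_rollout`.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def fake_rollout(clip_idx):
    # Hypothetical stand-in for generate_rollout: one clip index -> one trajectory
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing, not on cached calls
    return {"qpos": jnp.ones(10) * clip_idx}

batched = jax.jit(jax.vmap(fake_rollout))

out_a = batched(jnp.arange(0, 5))    # shape (5,): trace + compile
out_b = batched(jnp.arange(10, 15))  # shape (5,) again: cached, no retrace
print(trace_count)  # 1
out_c = batched(jnp.arange(15, 30))  # shape (15,): retrace + recompile
print(trace_count)  # 2
```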

In [9]:
# generate rollout for 5 clips simultaneously
jit_vmap_generate_rollout = jax.jit(jax.vmap(generate_rollout))
clip_idxs = jp.arange(0, 5)
jit_vmap_out = jit_vmap_generate_rollout(clip_idxs)

In [10]:
# same input shape (5 clips): reuses the compiled function, no recompilation
clip_idxs = jp.arange(10, 15)
jit_vmap_out2 = jit_vmap_generate_rollout(clip_idxs)

In [5]:
# different input shape (15 clips): triggers recompilation
clip_idxs = jp.arange(15, 30)
jit_vmap_out3 = jit_vmap_generate_rollout(clip_idxs)
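When processing many clips, one way to avoid that recompilation is to split the indices into chunks of a fixed size and pad the last chunk, so every call sees the same shape and JAX compiles once. The sketch below assumes a batch size of 5 and at least one clip index; `fake_rollout`, `BATCH_SIZE`, and `run_in_fixed_batches` are hypothetical. The padded entries do redundant work that is discarded.

```python
import jax
import jax.numpy as jnp
import numpy as np

BATCH_SIZE = 5  # hypothetical fixed batch size: every call sees shape (BATCH_SIZE,)
num_traces = 0

def fake_rollout(clip_idx):
    # Hypothetical stand-in for generate_rollout: maps one clip index to a trajectory
    global num_traces
    num_traces += 1  # only increments when JAX (re)traces the function
    return {"qpos": jnp.ones(3) * clip_idx}

batched = jax.jit(jax.vmap(fake_rollout))

def run_in_fixed_batches(clip_ids, batch_size=BATCH_SIZE):
    """Run `batched` over any number (>= 1) of clips with a single compilation.

    The last chunk is padded by repeating its final index; the padded results
    are sliced off before the chunks are concatenated.
    """
    clip_ids = np.asarray(clip_ids)
    results = []
    for start in range(0, len(clip_ids), batch_size):
        chunk = clip_ids[start:start + batch_size]
        n_valid = len(chunk)
        if n_valid < batch_size:
            padding = np.full(batch_size - n_valid, chunk[-1], dtype=chunk.dtype)
            chunk = np.concatenate([chunk, padding])
        out = batched(jnp.asarray(chunk))
        # drop the padded entries from every leaf of the output pytree
        results.append(jax.tree_util.tree_map(lambda x: x[:n_valid], out))
    # stitch the chunks back together along the leading (clip) axis
    return jax.tree_util.tree_map(lambda *xs: jnp.concatenate(xs, axis=0), *results)

out = run_in_fixed_batches(range(15, 27))  # 12 clips -> chunks of 5, 5, and 2 (+3 padding)
print(out["qpos"].shape, num_traces)  # (12, 3) 1
```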

### Step 4: Save rollouts to disk

In [4]:
# vmap stacks outputs along a new leading axis, so the first dimension is the number of clips
jit_vmap_out['joint_forces'].shape
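Because every leaf of a batched output carries the clips along its first axis, you can recover a single clip's rollout by indexing each leaf at the same position. A minimal sketch with a made-up pytree; `batched_out`, its keys, and its sizes are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical batched rollout output: 5 clips x 500 timesteps, made-up feature size 4
batched_out = {
    "joint_forces": jnp.zeros((5, 500, 4)),
    "info": {"clip_idx": jnp.arange(5)},
}

def select_clip(batched, i):
    """Return clip i by indexing every leaf of the pytree along its leading axis."""
    return jax.tree_util.tree_map(lambda x: x[i], batched)

per_clip = [select_clip(batched_out, i) for i in range(5)]
print(per_clip[2]["joint_forces"].shape)  # (500, 4)
```

Each `per_clip[i]` has the same structure a single-clip call would produce, so, assuming `save_to_h5py` accepts any such dictionary, it could be written to its own file without rerunning the rollout.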

In [7]:
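import os

# generate and save one rollout per clip; make sure the output directory exists
out_dir = "/root/vast/kaiwen/track-mjx/force_analysis_rollout_better"
os.makedirs(out_dir, exist_ok=True)
for i in tqdm(range(500, 842)):
    output = generate_rollout(i)
    with h5py.File(os.path.join(out_dir, f"clip_{i}.h5"), "w") as h5file:
        save_to_h5py(h5file, output)
    clear_output(wait=True)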

 81%|████████  | 404/500 [38:57<09:13,  5.76s/it]