# Generate and render a rollout for an existing checkpoint

This notebook demonstrates how to load a training checkpoint, perform a rollout, and render the result. Full network activations are saved as an output of this rollout for further analysis.

## Terminology

**Rollout** - a process of running a trained agent in an environment to collect data. In this case, we will run the agent in the environment and save the activations of the network.

**Rollout File** - a file containing the network activations for each step of the rollout for a specific episode, saved in `.h5` format.

## Imports

In [1]:
%load_ext autoreload
%autoreload 2

import os
import logging
# Send logging output to stdout (uncomment the next two lines to enable)
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)

# MuJoCo rendering backend: set to "egl" or "glfw", whichever is available
os.environ["MUJOCO_GL"] = "egl"
os.environ["PYOPENGL_PLATFORM"] = "egl"

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = os.environ.get(
    "XLA_PYTHON_CLIENT_MEM_FRACTION", "0.4"
)


from track_mjx.agent import checkpointing
from track_mjx.analysis import rollout, render, utils

import jax
from jax import numpy as jp
from pathlib import Path



## Load checkpoint

In [5]:
# replace with your checkpoint path
# ckpt_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/250227_200156"
ckpt_path = "/root/vast/kaiwen/track-mjx/model_checkpoints/rodent_data/RodentReferenceClip.h5_250416_022944" # LSTM checkpoint
# Load config from checkpoint
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)

cfg = ckpt["cfg"]

# make some changes to the config
# replace with absolute path to your data
# -- your notebook may not have access to the same relative path
cfg.data_path = "/root/vast/scott-yang/track-mjx/data/transform_snips.h5"
cfg.train_setup.checkpoint_to_restore = ckpt_path

### Restore policy and make rollout functions

In [3]:
env = rollout.create_environment(cfg)
inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
generate_rollout = rollout.create_rollout_generator(cfg, env, inference_fn)

env._steps_for_cur_frame: 2.0


### Generate rollouts from the checkpoint

With the checkpoint loaded, we can run inference to generate rollouts.

We can generate a rollout imitating a single clip, specified by its clip index. The first call incurs roughly 1-3 minutes of compilation time; subsequent calls take only a few seconds.
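
The compile-then-cache behavior can be observed on a toy function. This is a minimal sketch, not the notebook's rollout function: `toy_step` and its input are hypothetical stand-ins, and the absolute timings depend on your hardware.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_step(x):
    # Hypothetical stand-in for an expensive jitted rollout function.
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((256, 256))

# First call: JAX traces and compiles the function before executing it.
start = time.perf_counter()
first = toy_step(x).block_until_ready()
first_call_s = time.perf_counter() - start

# Second call with the same shape and dtype: the compiled version is reused.
start = time.perf_counter()
second = toy_step(x).block_until_ready()
second_call_s = time.perf_counter() - start

print(f"first call:  {first_call_s:.4f} s (includes compilation)")
print(f"second call: {second_call_s:.4f} s")
```

`block_until_ready()` matters here because JAX dispatches work asynchronously; without it the timer may stop before the computation has finished.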

In [4]:
clip_idx = 41
single_rollout = generate_rollout(clip_idx=clip_idx, seed=0)

In [7]:
single_rollout["activations"]["decoder_inputs"]

Array([[ 2.97204942e-01,  2.54580110e-01,  1.37342319e-01, ...,
         6.69378648e-03, -4.29338291e-02, -8.50227848e-02],
       [ 1.28380883e+00,  2.03114009e+00,  5.88994384e-01, ...,
        -3.92604041e+00,  7.82569027e+00,  7.75961733e+00],
       [ 6.20529056e-01,  2.28442097e+00,  1.04933619e-01, ...,
         5.17710507e-01,  2.51279283e+00,  4.50232410e+00],
       ...,
       [ 9.70331907e-01,  1.32450163e+00,  5.03547072e-01, ...,
         1.17652245e-01, -2.93279231e-01, -1.03792116e-01],
       [ 1.03791022e+00,  1.37213612e+00,  4.76267785e-01, ...,
         2.35759281e-02, -3.20093036e-01, -1.43825144e-01],
       [ 1.10235357e+00,  1.43230534e+00,  3.74840140e-01, ...,
         9.23475344e-03, -2.34795928e-01, -1.94541767e-01]],      dtype=float32)
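
Since the rollout is a nested dictionary of arrays, a quick way to get an overview is to map every array to its shape. The sketch below uses a toy dictionary: only the `"activations"` / `"decoder_inputs"` path comes from the cell above, while the other keys and all array sizes are hypothetical.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for `single_rollout`. Only the
# "activations" -> "decoder_inputs" path is taken from the notebook;
# the remaining keys and every shape are made up for illustration.
toy_rollout = {
    "activations": {
        "decoder_inputs": jnp.zeros((500, 60)),
        "hypothetical_layer": jnp.zeros((500, 128)),
    },
    "hypothetical_rewards": jnp.zeros((500,)),
}

# Replace every array leaf with its shape, keeping the nesting intact.
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, toy_rollout)

for name, shape in shapes["activations"].items():
    print(f"activations/{name}: {shape}")
```

Applied to the real `single_rollout`, the same `tree_map` call should list whatever entries the rollout actually contains.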

#### Batch Generating Rollouts

Alternatively, you can use `jax.vmap` to parallelize the rollout function. This is useful for performing a rollout over an entire dataset for eval/analysis purposes. We pass in a 1D array of clip indexes (`clip_idxs`) as input. 

The first run for this will also have a few minutes of compilation time.

**Note:** `vmap` compiles based on the input shape. This means that if you use the same length for `clip_idxs`, JAX will reuse the compiled function for acceleration. However, if the input length changes, JAX will **recompile the entire function**, incurring additional overhead.
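
The shape-dependent caching can be checked with a small counter. This is a sketch using a hypothetical `toy_rollout` function in place of `generate_rollout`; it relies on the fact that Python side effects inside a jitted function run only while JAX traces it.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def toy_rollout(clip_idx):
    # Hypothetical stand-in for `generate_rollout`.
    # This counter increases once per (re)trace, not once per call.
    global trace_count
    trace_count += 1
    return clip_idx * 2.0


batched_rollout = jax.jit(jax.vmap(toy_rollout))

batched_rollout(jnp.arange(0, 5))  # first call: traced and compiled
traces_after_first = trace_count

batched_rollout(jnp.arange(5, 10))  # same length (5): cached version reused
traces_after_same_shape = trace_count

batched_rollout(jnp.arange(15, 30))  # new length (15): traced again
traces_after_new_shape = trace_count

print(traces_after_first, traces_after_same_shape, traces_after_new_shape)
```

If a dataset does not divide evenly, one option is to process it in fixed-size chunks so that only the final, shorter chunk triggers a second compilation.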

In [None]:
# Generate rollout for 5 clips simultaneously
jit_vmap_generate_rollout = jax.jit(jax.vmap(generate_rollout))
clip_idxs = jp.arange(0, 5)
jit_vmap_out = jit_vmap_generate_rollout(clip_idxs)

In [None]:
# Running it with a different clip_idxs length will cause recompilation
clip_idxs = jp.arange(15, 30)
jit_vmap_out2 = jit_vmap_generate_rollout(clip_idxs)

### Save the rollout to disk

In [8]:
save_path = Path(ckpt_path) / f"rollout_{clip_idx}.h5"
print(f"Saving rollout to {save_path}")

Saving rollout to /root/vast/scott-yang/track-mjx/model_checkpoints/250227_200156/rollout_41.h5


In [9]:
utils.save_to_h5py(save_path.resolve(), single_rollout)
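
To see what ended up in an `.h5` file without going through `track_mjx`, the file can be walked directly with `h5py`. The sketch below writes a toy file first so that it runs on its own; the group layout and array size are hypothetical, and the layout that `utils.save_to_h5py` actually produces may differ.

```python
import os
import tempfile

import h5py
import numpy as np

# Write a toy file so the example is self-contained.
# The "activations/decoder_inputs" layout and the (500, 60) size are assumptions.
toy_path = os.path.join(tempfile.mkdtemp(), "toy_rollout.h5")
with h5py.File(toy_path, "w") as f:
    f.create_dataset(
        "activations/decoder_inputs",
        data=np.zeros((500, 60), dtype=np.float32),
    )

found = {}


def record_dataset(name, obj):
    # visititems calls this for every group and dataset; keep datasets only.
    if isinstance(obj, h5py.Dataset):
        found[name] = (obj.shape, str(obj.dtype))


with h5py.File(toy_path, "r") as f:
    f.visititems(record_dataset)

for name, (shape, dtype) in found.items():
    print(name, shape, dtype)
```

Pointing `h5py.File` at the real `save_path` instead of `toy_path` should list the datasets stored for the rollout.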

# Render Rollout Videos from the Saved Rollouts

## Load the rollout file

In [10]:
# Use a distinct name so the imported `rollout` module is not shadowed
loaded_rollout = utils.load_from_h5py(save_path)

## Render rollout

Note: Rendering currently works only for a single rollout, not for batched output from `jax.vmap`.
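
One way to get a single rollout out of a batched result is to index the leading batch axis of every array. This is a minimal sketch on a toy dictionary with hypothetical sizes; `select_rollout` is an illustrative helper, not part of `track_mjx`, and whether `render.render_rollout` accepts the sliced result has not been verified here.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for the output of a vmapped rollout function:
# every leaf carries a leading batch axis (5 clips here).
toy_batched = {
    "activations": {"decoder_inputs": jnp.zeros((5, 500, 60))},
}


def select_rollout(batched, index):
    # Hypothetical helper: take entry `index` along axis 0 of every leaf.
    return jax.tree_util.tree_map(lambda leaf: leaf[index], batched)


one_rollout = select_rollout(toy_batched, 2)
print(one_rollout["activations"]["decoder_inputs"].shape)  # (500, 60)
```

This works because `jax.vmap` stacks outputs along axis 0 by default, so index `i` on each leaf corresponds to the `i`-th entry of the input `clip_idxs`.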

In [11]:
frames, realtime_framerate = render.render_rollout(cfg, single_rollout)
render.display_video(frames, framerate=realtime_framerate)

MuJoCo Rendering...


100%|██████████| 500/500 [00:02<00:00, 207.50it/s]
