# Generate and render a rollout for an existing checkpoint

This notebook demonstrates how to load a training checkpoint, perform a rollout, and render the result. Full network activations are saved as an output of this rollout for further analysis.

## Terminology

**Rollout** - a process of running a trained agent in an environment to collect data. In this case, we will run the agent in the environment and save the activations of the network.

**Rollout File** - a file containing the activations of the network for each step of the rollout for a specific episode, saved in `.h5` format.

## Imports

In [2]:
%load_ext autoreload
%autoreload 2

import os
import logging

# # Send logging outputs to stdout (comment this out if preferred)
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)

# MuJoCo rendering backend: set to "egl" or "glfw", whichever is available
os.environ["MUJOCO_GL"] = "egl"
os.environ["PYOPENGL_PLATFORM"] = "egl"

from track_mjx.agent import checkpointing
from track_mjx.analysis.rollout import (
    create_rollout_generator,
    create_environment,
)
from track_mjx.analysis.render import render_from_saved_rollout, display_video
from track_mjx.analysis.utils import save_to_h5py, load_from_h5py

import jax
from jax import numpy as jp
from pathlib import Path

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Load checkpoint

In [None]:
# replace with your checkpoint path
ckpt_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/250227_200156"
# Load config from checkpoint
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)

cfg = ckpt["cfg"]

# make some changes to the config
# replace with absolute path to your data
# -- your notebook may not have access to the same relative path
cfg.data_path = "/root/vast/scott-yang/track-mjx/data/transform_snips.h5"
cfg.train_setup.checkpoint_to_restore = ckpt_path


### Restore policy and make rollout functions

In [3]:
env = create_environment(cfg)
inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
generate_rollout = create_rollout_generator(cfg["reference_config"], env, inference_fn)

INFO:root:Loading data: /root/vast/scott-yang/track-mjx/data/transform_snips.h5


env._steps_for_cur_frame: 2.0


### Generate rollouts from the checkpoint

After loading the checkpoint, we can run inference with the restored policy to generate rollouts.

We can generate a rollout imitating a single clip, specified by the clip index. The first call to the function incurs roughly 1-3 minutes of compilation time; subsequent calls take only a few seconds.

In [4]:
single_rollout = generate_rollout(clip_idx=1, seed=0)

#### Batch Generating Rollouts

Alternatively, you can use `jax.vmap` to vectorize the rollout function over a batch of clips. This is useful for performing a rollout over an entire dataset for eval/analysis purposes. We pass in a 1D array of clip indexes (`clip_idxs`) as input.

The first batched call also incurs a few minutes of compilation time.

**Note:** `vmap` compiles based on the input shape. This means that if you use the same length for `clip_idxs`, JAX will reuse the compiled function for acceleration. However, if the input length changes, JAX will **recompile the entire function**, incurring additional overhead.

In [9]:
# Generate rollout for 5 clips simultaneously
jit_vmap_generate_rollout = jax.jit(jax.vmap(generate_rollout))
clip_idxs = jp.arange(0, 5)
jit_vmap_out = jit_vmap_generate_rollout(clip_idxs)

In [10]:
# Running it with a different clip_idxs length will cause recompilation
clip_idxs = jp.arange(15, 30)
jit_vmap_out2 = jit_vmap_generate_rollout(clip_idxs)
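
One way to avoid the recompilation described in the note above is to always call the batched function with the same batch length, padding the final batch when the number of clips is not a multiple of it. The sketch below illustrates the idea with plain Python lists. `fixed_size_batches` is a hypothetical helper and not part of `track_mjx`; padding by repeating the last index is just one reasonable choice.

```python
from typing import Iterable, Iterator, List, Tuple


def fixed_size_batches(
    indices: Iterable[int], batch_size: int
) -> Iterator[Tuple[List[int], int]]:
    """Yield (batch, n_valid) pairs in which every batch has exactly batch_size entries.

    The final batch is padded by repeating its last index, and n_valid
    reports how many leading entries of the batch are real (not padding).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    indices = list(indices)
    for start in range(0, len(indices), batch_size):
        chunk = indices[start : start + batch_size]
        n_valid = len(chunk)
        # Pad so the input shape stays constant across calls
        chunk = chunk + [chunk[-1]] * (batch_size - n_valid)
        yield chunk, n_valid


# Hypothetical example: 12 clips processed in batches of 5
batches = list(fixed_size_batches(range(12), batch_size=5))
for batch, n_valid in batches:
    print(batch, n_valid)
```

Each batch could then be converted with `jp.asarray(batch)` and passed to `jit_vmap_generate_rollout`, keeping only the first `n_valid` rollouts of the result. How to slice the output depends on its structure, which is not shown here.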

### Save the rollout to disk

In [9]:
save_path = Path(ckpt_path) / "rollout.h5"
print(f"Saving rollout to {save_path}")

Saving rollout to /root/vast/scott-yang/track-mjx/model_checkpoints/250227_200156/rollout.h5


In [6]:
save_to_h5py(save_path.resolve(), single_rollout)

# Render Rollout Videos from the Saved Rollouts

## Load the rollout file

In [10]:
rollout = load_from_h5py(save_path)
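
To see what a rollout file contains before rendering or analysis, it can help to list its datasets directly. The sketch below uses `h5py`, which is assumed to be installed because the rollout is stored as `.h5`. `summarize_h5` is a hypothetical helper, not part of `track_mjx`, and the dataset names it reports depend on how `save_to_h5py` lays out the file.

```python
import h5py


def summarize_h5(path):
    """Return {dataset_path: (shape, dtype)} for every dataset in an HDF5 file."""
    summary = {}

    def visit(name, obj):
        # Groups are skipped; only datasets carry array data
        if isinstance(obj, h5py.Dataset):
            summary[name] = (obj.shape, str(obj.dtype))

    with h5py.File(path, "r") as f:
        f.visititems(visit)
    return summary
```

For example, `summarize_h5(save_path)` would return a dictionary mapping each dataset path in the rollout file to its shape and dtype, which is a quick way to locate the saved activations.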

## Render rollout

**Note:** Rendering currently only works for non-batched rollouts.

In [11]:
frames, realtime_framerate = render_from_saved_rollout(cfg, rollout)
display_video(frames, framerate=realtime_framerate)

MuJoCo Rendering...


100%|██████████| 500/500 [00:02<00:00, 200.95it/s]
INFO:matplotlib.animation:Animation.save using <class 'matplotlib.animation.FFMpegWriter'>
INFO:matplotlib.animation:MovieWriter._run: running command: ffmpeg -f rawvideo -vcodec rawvideo -s 640x480 -pix_fmt rgba -framerate 100.0 -loglevel error -i pipe: -vcodec h264 -pix_fmt yuv420p -y /tmp/tmpwjrg69o1/temp.m4v
