# Generate and render a rollout for existing checkpoint

This notebook demonstrates how to load a training checkpoint, perform a rollout, and render the result. Full network activations are saved as an output of this rollout for further analysis.

## Imports

In [None]:
%load_ext autoreload
%autoreload 2

import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Change this to egl or glfw if available
os.environ["MUJOCO_GL"] = "glfw"
import mediapy as media

from track_mjx.agent import checkpointing
from track_mjx.analysis import rollout, render, utils

import jax
from jax import numpy as jp
from pathlib import Path

## Load checkpoint

In [None]:
# Checkpoint directory to evaluate. Set TRACK_MJX_CKPT_PATH to use a different
# checkpoint without editing this cell.
# Alternative checkpoint from earlier runs: ar_rodent_data/AR_250619_140732_166355
ckpt_path = os.environ.get(
    "TRACK_MJX_CKPT_PATH",
    "/root/vast/kaiwen/track-mjx/model_checkpoints/charles_good_rodent_data/250613_173348_472536",
)

# Absolute path to the reference-clip data. The notebook may not have access to
# the same relative path the checkpoint config was trained with, e.g.
# str(Path.cwd().parent / "data/transform_snips.h5"). Override via TRACK_MJX_DATA_PATH.
data_path = os.environ.get(
    "TRACK_MJX_DATA_PATH",
    "/root/vast/kaiwen/track-mjx/data/RodentReferenceClip.h5",
)
# Fail fast with a clear message instead of deep inside environment creation.
if not Path(data_path).is_file():
    raise FileNotFoundError(f"Reference clip data not found: {data_path}")

# Load the checkpoint for evaluation; its config lives under "cfg".
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)
cfg = ckpt["cfg"]

# Point the config at the local data file and at the checkpoint being evaluated.
cfg.data_path = data_path
cfg.train_setup.checkpoint_to_restore = ckpt_path

### Restore policy and make rollout functions

In [None]:
# Environment described by the checkpoint's config.
env = rollout.create_environment(cfg)
# Policy inference function rebuilt from the checkpoint's "policy" entry.
inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])

# Record activations, metrics and sensor data alongside each rollout.
rollout_logging = dict(
    log_activations=True,
    log_metrics=True,
    log_sensor_data=True,
)
generate_rollout = rollout.create_rollout_generator(cfg, env, inference_fn, **rollout_logging)

### Generate rollouts from the checkpoint

With the checkpoint loaded, we can run the policy in the environment to generate rollouts.

We can generate a rollout that imitates a single clip, specified by its clip index. The first call takes ~1–3 minutes for JIT compilation; subsequent calls take only a few seconds.

In [None]:
# Index of the reference clip to imitate; the first call triggers compilation.
clip_index = 0
single_rollout = generate_rollout(clip_idx=clip_index)

In [None]:
activations = single_rollout['activations']
activations['egocentric_obs'].shape

In [None]:
intention = single_rollout['activations']['intention']
intention.shape

In [None]:
ctrl_trajectory = single_rollout['ctrl']
ctrl_trajectory.shape

In [None]:
state_rewards = single_rollout['state_rewards']
state_rewards.shape

In [None]:
qposes = single_rollout['qposes_rollout']
qposes.shape

In [None]:
# Render the single-clip rollout to video frames; `realtime_framerate` is
# presumably the fps for real-time playback (confirm in render.render_rollout).
frames, realtime_framerate = render.render_rollout(
    cfg,
    single_rollout,
    height=480,
    width=640,
)

# Set to True to also write the video into the checkpoint directory.
SAVE_VIDEO = False
if SAVE_VIDEO:
    media.write_video(Path(ckpt_path) / "rollout.mp4", frames, fps=realtime_framerate)

media.show_video(frames, fps=realtime_framerate)

### Save the rollout to disk

In [8]:
# The rollout file is stored next to the checkpoint it came from.
save_path = Path(ckpt_path).joinpath("rollout.h5")

In [9]:
# Persist the whole rollout dict (including logged activations) to HDF5.
resolved_save_path = save_path.resolve()
utils.save_to_h5py(resolved_save_path, single_rollout)

### Load the rollout back from disk

In [None]:
# Load the saved rollout back. Use a distinct name: assigning to `rollout` would
# shadow the `track_mjx.analysis.rollout` module and break re-running the
# environment / rollout-generator cell.
loaded_rollout = utils.load_from_h5py(save_path)

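## Rendering limitations

Note: rendering with `render.render_rollout` (shown above) currently only works for single rollouts, not for the batched outputs generated below.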

### Generating rollouts in batches

Alternatively, you can use `jax.vmap` to parallelize the rollout function across clips. This is useful for running rollouts over an entire dataset for evaluation or analysis. The input is a 1D array of clip indices (`clip_idxs`).

The first run for this will also have a few minutes of compilation time.

**Note:** `jax.jit` compiles the vmapped function for a specific input shape. If you call it again with a `clip_idxs` array of the same length, JAX reuses the compiled function. If the length changes, JAX **recompiles the entire function**, which incurs the compilation overhead again.

In [9]:
# Generate rollouts for clips 0..num_clips-1 in a single batched, jitted call.
# 842 is the clip count used throughout this notebook (presumably the size of
# the reference dataset; confirm against the data file).
num_clips = 842
jit_vmap_generate_rollout = jax.jit(jax.vmap(generate_rollout))
clip_idxs = jp.arange(0, num_clips)
jit_vmap_out = jit_vmap_generate_rollout(clip_idxs)

In [10]:
# jax.jit specializes on input shape: a batch whose length differs from the
# previous call's triggers a full recompilation on first use.
first_clip, stop_clip = 15, 30
clip_idxs = jp.arange(first_clip, stop_clip)
jit_vmap_out2 = jit_vmap_generate_rollout(clip_idxs)

In [None]:
from tqdm.auto import tqdm

from track_mjx.analysis.utils import save_to_h5py

# Export a compact subset of every clip's rollout, one HDF5 file per clip.
num_clips = 842
minimal_out_dir = Path("/root/vast/kaiwen/track-mjx/rollouts/rodent_gau_minimal")
# Create the output directory up front so the first write does not fail.
minimal_out_dir.mkdir(parents=True, exist_ok=True)

for clip_idx in tqdm(range(num_clips), desc="minimal rollouts"):
    clip_rollout = generate_rollout(clip_idx)
    minimal_rollout = {
        'obs': clip_rollout['activations']['egocentric_obs'],
        'intention': clip_rollout['activations']['intention'],
        'ctrl': clip_rollout['ctrl'],
        'reward': clip_rollout['state_rewards'],
        'qpos': clip_rollout['qposes_rollout'],
    }
    save_to_h5py(minimal_out_dir / f"clip_{clip_idx}.h5", minimal_rollout)

In [None]:
from tqdm.auto import tqdm

from track_mjx.analysis.utils import save_to_h5py

# Export everything generate_rollout returns for every clip, one HDF5 file per clip.
num_clips = 842
full_out_dir = Path("/root/vast/kaiwen/track-mjx/rollouts/rodent_gau")
# Create the output directory up front so the first write does not fail.
full_out_dir.mkdir(parents=True, exist_ok=True)

for clip_idx in tqdm(range(num_clips), desc="full rollouts"):
    clip_rollout = generate_rollout(clip_idx)
    save_to_h5py(full_out_dir / f"clip_{clip_idx}.h5", clip_rollout)