# Generate and render a rollout for an existing checkpoint

This notebook demonstrates how to load a training checkpoint, perform a rollout, and render the result. Full network activations are saved as an output of this rollout for further analysis.

## Imports

In [1]:
import os
import logging
# Send logging outputs to stdout (comment this out if preferred)
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Change this to egl or glfw if available
os.environ["MUJOCO_GL"] = "osmesa"

from track_mjx.agent import checkpointing
from track_mjx.analysis.rollout import (
    create_rollout_generator,
    create_environment,
)
from track_mjx.analysis.render import render_from_saved_rollout, display_video
from track_mjx.analysis.utils import save_to_h5py, load_from_h5py

import jax
from jax import numpy as jp
from pathlib import Path

INFO:absl:Handler "orbax.checkpoint._src.handlers.array_checkpoint_handler.ArrayCheckpointHandler" already exists in the registry with associated type <class 'orbax.checkpoint._src.handlers.array_checkpoint_handler.ArrayCheckpointHandler'>. Skipping registration.
INFO:absl:Handler "orbax.checkpoint._src.handlers.proto_checkpoint_handler.ProtoCheckpointHandler" already exists in the registry with associated type <class 'orbax.checkpoint._src.handlers.proto_checkpoint_handler.ProtoCheckpointHandler'>. Skipping registration.
INFO:absl:Handler "orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler" already exists in the registry with associated type <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler'>. Skipping registration.
INFO:absl:Handler "orbax.checkpoint._src.handlers.base_pytree_checkpoint_handler.BasePyTreeCheckpointHandler" already exists in the registry with associated type <class 'orbax.checkpoint._src.handlers.base_pytree

## Load checkpoint

In [2]:
# replace with your checkpoint path
ckpt_path = "/home/tim.kim/track-mjx/model_checkpoints/250315_113155/"
# Load config from checkpoint 
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)

cfg = ckpt["cfg"]

# make some changes to the config
# replace with absolute path to your data
# -- your notebook may not have access to the same relative path
cfg.data_path = "/allen/aind/scratch/tim.kim/track-mjx/data/transform_snips.h5"
cfg.train_setup.checkpoint_to_restore = ckpt_path

INFO:absl:[thread=MainThread] Failed to get flag value for EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:absl:[process=0][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers=None, handler_registry=None
INFO:absl:Initialized registry DefaultCheckpointHandlerRegistry({('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7efc395eea90>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7efc395

ValueError: Requested shape: (470, 512) is not compatible with the stored shape: (617, 512). Truncating/padding is disabled by setting of `strict=True`. When using standard Orbax APIs, this behavior can be modified by specifying `strict=False` in `ArrayRestoreArgs` for any array in which padding/truncation is desired.

### Restore policy and make rollout functions

In [None]:
env = create_environment(cfg)
inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
generate_rollout = create_rollout_generator(cfg, env, inference_fn)

### Generate rollouts from the checkpoint

With the checkpoint loaded, we can run inference to generate rollouts.

We can generate a rollout that imitates a single clip, specified by its clip index. The first call incurs roughly 1-3 minutes of compilation time; subsequent calls take only a few seconds.

In [7]:
single_rollout = generate_rollout(clip_idx=1)
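
The compile-once behavior described above can be observed on a toy function. The following is a minimal sketch, not the actual rollout: `toy_rollout` is a hypothetical stand-in that runs 100 steps of a simple update under `jax.jit`, and the timings it prints will vary by machine.

```python
import time

import jax
from jax import numpy as jp


@jax.jit
def toy_rollout(x0):
    # Hypothetical stand-in for a rollout: 100 steps of a simple update.
    def step(x, _):
        x = 0.9 * x + 0.1 * jp.sin(x)
        return x, x

    _, trajectory = jax.lax.scan(step, x0, None, length=100)
    return trajectory


x0 = jp.ones(8)

start = time.perf_counter()
first = toy_rollout(x0).block_until_ready()  # includes tracing and compilation
first_call_s = time.perf_counter() - start

start = time.perf_counter()
second = toy_rollout(x0).block_until_ready()  # reuses the compiled function
second_call_s = time.perf_counter() - start

print(f"first call: {first_call_s:.4f}s, second call: {second_call_s:.4f}s")
```

`block_until_ready()` is needed because JAX dispatches work asynchronously; without it the timer would stop before the computation finishes.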

#### Batch Generating Rollouts

Alternatively, you can use `jax.vmap` to parallelize the rollout function. This is useful for running rollouts over an entire dataset for evaluation or analysis. We pass in a 1D array of clip indexes (`clip_idxs`) as input.

The first batched call also takes a few minutes to compile.
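
As a rough illustration of the batching pattern, the sketch below applies `jax.vmap` to a toy function that takes a clip index and returns a dictionary. `clips` and `toy_generate_rollout` are hypothetical stand-ins, not part of `track_mjx`; the point is only that every entry of the returned dictionary gains a leading batch axis.

```python
import jax
from jax import numpy as jp

# Hypothetical reference data: 10 clips, 4 values each.
clips = jp.arange(40.0).reshape(10, 4)


def toy_generate_rollout(clip_idx):
    # Stand-in for generate_rollout: look up a clip and return a dict (pytree).
    clip = clips[clip_idx]
    return {"clip_idx": clip_idx, "summary": clip.sum()}


batched = jax.jit(jax.vmap(toy_generate_rollout))
out = batched(jp.arange(0, 5))

print(out["clip_idx"].shape)  # (5,)
print(out["summary"].shape)   # (5,)
```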

**Note:** `jax.jit` compiles a separate version of the function for each input shape. If `clip_idxs` keeps the same length across calls, JAX reuses the compiled function. If the length changes, JAX **recompiles the entire function**, incurring additional overhead.

In [9]:
# Generate rollout for 5 clips simultaneously
jit_vmap_generate_rollout = jax.jit(jax.vmap(generate_rollout))
clip_idxs = jp.arange(0, 5)
jit_vmap_out = jit_vmap_generate_rollout(clip_idxs)

In [10]:
# Running it with a different clip_idxs length will cause recompilation
clip_idxs = jp.arange(15, 30)
jit_vmap_out2 = jit_vmap_generate_rollout(clip_idxs)
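
One way to see the recompilation, and one possible way to avoid it, is sketched below on a toy function. The Python body of a jitted function runs only while JAX traces it, so a counter inside it records how many times a new input shape triggered a trace. `toy_rollout`, `counted`, and `pad_to` are hypothetical helpers introduced for this illustration; padding to a fixed batch size is a common workaround, not something the notebook itself does.

```python
import jax
from jax import numpy as jp

trace_count = 0


def toy_rollout(clip_idx):
    # Stand-in for generate_rollout; the real function is far more expensive.
    return clip_idx * 2.0


def counted(clip_idxs):
    # This Python body runs only during tracing, so the counter increases
    # once per new input shape, not once per call.
    global trace_count
    trace_count += 1
    return jax.vmap(toy_rollout)(clip_idxs)


batched = jax.jit(counted)

batched(jp.arange(0, 5))    # traced for shape (5,)
batched(jp.arange(5, 10))   # same shape, compiled version is reused
batched(jp.arange(15, 30))  # shape (15,), traced again
print(trace_count)  # 2


def pad_to(clip_idxs, batch_size):
    # Pad by repeating the last index so every call has the same shape.
    n = clip_idxs.shape[0]
    pad = jp.full((batch_size - n,), clip_idxs[-1], dtype=clip_idxs.dtype)
    return jp.concatenate([clip_idxs, pad]), n


padded, n_valid = pad_to(jp.arange(0, 5), 15)
out = batched(padded)[:n_valid]  # shape (15,) is already compiled
print(trace_count)  # still 2
```

The trade-off is wasted computation on the padded entries, which may or may not be worth it depending on how long compilation takes relative to a single rollout.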

### Save the rollout to disk

In [8]:
save_path = Path(ckpt_path) / "rollout.h5"

In [9]:
save_to_h5py(save_path.resolve(), single_rollout)

# Render rollout videos from the saved rollouts

## Load the rollout file

In [10]:
rollout = load_from_h5py(save_path)

## Render rollout

Note: Rendering currently works only for non-batched rollouts.

In [None]:
frames, realtime_framerate = render_from_saved_rollout(cfg, rollout)
display_video(frames, framerate=realtime_framerate)
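
Because rendering expects a non-batched rollout, a batched output would first need to be reduced to a single rollout. The sketch below shows one way to do that, assuming the batched output is a pytree whose leaves all carry the batch on their leading axis (which is what `jax.vmap` produces). The keys `qpos` and `rewards` and the helper `select_rollout` are hypothetical, and whether the result can be passed directly to `render_from_saved_rollout` has not been verified here.

```python
import jax
from jax import numpy as jp

# Hypothetical batched output: 5 rollouts, 20 timesteps each.
batched_rollout = {
    "qpos": jp.zeros((5, 20, 7)),
    "rewards": jp.ones((5, 20)),
}


def select_rollout(batched, i):
    # Index the leading (batch) axis of every leaf in the pytree.
    return jax.tree_util.tree_map(lambda leaf: leaf[i], batched)


single = select_rollout(batched_rollout, 2)
print(single["qpos"].shape)     # (20, 7)
print(single["rewards"].shape)  # (20,)
```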