# Generate and render a rollout for an existing checkpoint

This notebook demonstrates how to load a training checkpoint, perform a rollout, and render the result. Full network activations are saved as an output of this rollout for further analysis.

## Terminology

**Rollout** - a process of running a trained agent in an environment to collect data. In this case, we will run the agent in the environment and save the activations of the network.

**Rollout File** - a file containing the activations of the network for each step of the rollout for a specific episode, saved in `.h5` format.

## Imports

In [None]:
%load_ext autoreload
%autoreload 2

import os
import logging
# Send logging outputs to stdout (comment this out if preferred)
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)

# MuJoCo rendering backend: change "egl" to "glfw" or "osmesa" if EGL is not available
os.environ["MUJOCO_GL"] = "egl"
os.environ["PYOPENGL_PLATFORM"] = "egl"

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = os.environ.get(
    "XLA_PYTHON_CLIENT_MEM_FRACTION", "0.4"
)


from track_mjx.agent import checkpointing
from track_mjx.analysis import rollout, render, utils

import jax
from jax import numpy as jp
from pathlib import Path
from orbax import checkpoint as ocp
from omegaconf import OmegaConf

## Load checkpoint

In [2]:
# replace with your checkpoint path
# ckpt_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/250227_200156"
ckpt_path = "/root/vast/kaiwen/track-mjx/model_checkpoints/rodent_data/RodentReferenceClip.h5_250416_022944" # LSTM checkpoint

Patch older checkpoints that are missing some config arguments.

In [76]:
def load_config_from_checkpoint(checkpoint_path):
    """Restore the ``config`` item stored with the latest checkpoint step.

    Opens an orbax ``CheckpointManager`` on ``checkpoint_path`` (steps are
    looked up with the ``"PPONetwork"`` prefix), prints the latest step
    number, and restores only the ``config`` item of that step via
    ``JsonRestore``.

    Args:
        checkpoint_path: Directory that holds the checkpoint steps.

    Returns:
        The restored ``config`` item of the latest step.
    """
    manager_options = ocp.CheckpointManagerOptions(step_prefix="PPONetwork")
    # Only the JSON config is requested; no other checkpoint item is restored.
    restore_args = ocp.args.Composite(config=ocp.args.JsonRestore(None))

    with ocp.CheckpointManager(checkpoint_path, options=manager_options) as manager:
        print(f"latest checkpoint step: {manager.latest_step()}")
        restored = manager.restore(manager.latest_step(), args=restore_args)
        config = restored["config"]
    return config


def patch(cfg: dict, env_state, env, normalize_observations) -> dict:
    """Fill in config entries that older checkpoints did not store.

    Existing values are never overwritten; only missing keys are added.
    ``cfg`` is modified in place and also returned.

    Args:
        cfg: Checkpoint config as a plain dict.
        env_state: Environment state; ``env_state.obs.shape[-1]`` and
            ``env_state.info["reference_obs_size"]`` are read only when the
            corresponding config keys are missing.
        env: Environment; ``env.action_size`` is read only when
            ``action_size`` is missing from the config.
        normalize_observations: Value used for
            ``network_config.normalize_observations`` when it is missing.

    Returns:
        The same ``cfg`` object, with ``network_config`` and
        ``logging_config`` completed.
    """
    network_cfg = cfg.setdefault("network_config", {})

    # Each fallback is wrapped in a callable so it is evaluated only when its
    # key is absent. With plain ``setdefault`` the fallback expression runs
    # unconditionally, so a config that already has every value would still
    # fail if, e.g., ``env_state.info`` lacks "reference_obs_size".
    network_fallbacks = {
        "observation_size": lambda: int(env_state.obs.shape[-1]),
        "action_size": lambda: int(env.action_size),
        "normalize_observations": lambda: normalize_observations,
        "reference_obs_size": lambda: int(env_state.info["reference_obs_size"]),
    }
    for key, compute_fallback in network_fallbacks.items():
        if key not in network_cfg:
            network_cfg[key] = compute_fallback()

    logging_cfg = cfg.setdefault("logging_config", {})
    logging_cfg.setdefault(
        "rollout_metrics",
        [
            "pos_reward",
            "quat_reward",
            "joint_reward",
            "angvel_reward",
            "bodypos_reward",
            "endeff_reward",
            "summed_pos_distance",
            "joint_distance",
            "quat_distance",
            "ctrl_cost",
            "ctrl_diff_cost",
            "energy_cost",
            "too_far",
            "bad_pose",
            "bad_quat",
            "fall",
        ],
    )

    return cfg

In [None]:
# Read the config stored with the latest step of the checkpoint.
cfg = load_config_from_checkpoint(ckpt_path)
# Override the data path with an absolute path -- replace with the location of your own data.
cfg['data_path'] = "/root/vast/kaiwen/track-mjx/data/RodentReferenceClip.h5"
# Build the environment and reset it once; `env` and `env_state` are passed
# to `patch` in the next cell to fill in missing config values.
env = rollout.create_environment(cfg)
key = jax.random.PRNGKey(0)
env_state = env.reset(key)

In [78]:
# The last argument is `normalize_observations`; it is only used when the
# checkpoint config does not already define that key.
cfg = patch(cfg, env_state, env, True)

In [None]:
cfg['network_config']

In [80]:
def cfg_only_load_checkpoint_for_eval(
    checkpoint_path: str,
    cfg: OmegaConf | dict,
    step_prefix: str = "PPONetwork",
    step: int | None = None,
):
    """Restore a policy for evaluation using an externally supplied config.

    Unlike loading the config from the checkpoint itself, the caller passes
    in ``cfg`` (e.g. a config that has been patched beforehand).

    Args:
        checkpoint_path: Directory that holds the checkpoint steps.
        cfg: Config as an OmegaConf object or a plain dict; a dict is
            converted with ``OmegaConf.create``.
        step_prefix: Prefix used to locate step directories.
        step: Step to restore. ``None`` selects the manager's latest step.

    Returns:
        A dict with keys ``"cfg"`` (the OmegaConf config) and ``"policy"``
        (the value returned by ``checkpointing.load_policy``).
    """
    if not OmegaConf.is_config(cfg):
        cfg = OmegaConf.create(cfg)

    mgr_opts = ocp.CheckpointManagerOptions(create=False, step_prefix=step_prefix)
    # Use the manager as a context manager so it is closed even if restoring
    # fails, matching `load_config_from_checkpoint`.
    # NOTE(review): assumes `load_policy` returns fully materialized values
    # that do not need the manager after it returns -- confirm.
    with ocp.CheckpointManager(checkpoint_path, options=mgr_opts) as ckpt_mgr:
        if step is None:
            step = ckpt_mgr.latest_step()
        logging.info("Loading checkpoint from %s @ step %s", checkpoint_path, step)

        policy = checkpointing.load_policy(
            checkpoint_path=checkpoint_path,
            cfg=cfg,
            ckpt_mgr=ckpt_mgr,
            step_prefix=step_prefix,
            step=step,
        )

    return {"cfg": cfg, "policy": policy}

In [None]:
# Restore the policy from the checkpoint using the patched `cfg`;
# the result is a dict with "cfg" and "policy" entries.
ckpt = cfg_only_load_checkpoint_for_eval(ckpt_path, cfg)

# cfg = ckpt["cfg"]

# # make some changes to the config
# # replace with absolute path to your data
# # -- your notebook may not have access to the same relative path
# cfg.data_path = "/root/vast/scott-yang/track-mjx/data/transform_snips.h5"
# cfg.train_setup.checkpoint_to_restore = ckpt_path

### Restore policy and make rollout functions

In [82]:
# Convert the patched config to an OmegaConf object and rebuild the environment from it.
cfg = OmegaConf.create(cfg)
env = rollout.create_environment(cfg)
# `ckpt` is the dict returned by cfg_only_load_checkpoint_for_eval above.
inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
# NOTE(review): model='lstm' presumably has to match the checkpoint's
# architecture (ckpt_path is labelled an LSTM checkpoint) -- confirm before
# reusing this cell with a different checkpoint.
generate_rollout = rollout.create_rollout_generator(cfg, env, inference_fn, model='lstm')

### Generate rollouts from the checkpoint

After we load the checkpoint, we can do inference on the rollout!

We can generate a rollout imitating a single clip, specified by the clip index. The first time you call the function there will be ~1-3 min of compilation time, after which it will take only a few seconds.

In [None]:
# Index of the reference clip to imitate; also used later to name the saved rollout file.
clip_idx = 41
single_rollout = generate_rollout(clip_idx=clip_idx, seed=0)

In [None]:
single_rollout["activations"]["decoder_inputs"]

In [None]:
single_rollout["activations"]["hidden_states"]

#### Batch Generating Rollouts

Alternatively, you can use `jax.vmap` to parallelize the rollout function. This is useful for performing a rollout over an entire dataset for eval/analysis purposes. We pass in a 1D array of clip indexes (`clip_idxs`) as input. 

The first run for this will also have a few minutes of compilation time.

**Note:** `vmap` compiles based on the input shape. This means that if you use the same length for `clip_idxs`, JAX will reuse the compiled function for acceleration. However, if the input length changes, JAX will **recompile the entire function**, incurring additional overhead.

In [None]:
# Generate rollout for 5 clips simultaneously
# vmap maps generate_rollout over the entries of clip_idxs; jit compiles the batched function.
# NOTE(review): no seed is passed here, unlike the single-clip call above --
# presumably generate_rollout falls back to a default seed; confirm.
jit_vmap_generate_rollout = jax.jit(jax.vmap(generate_rollout))
clip_idxs = jp.arange(0, 5)
jit_vmap_out = jit_vmap_generate_rollout(clip_idxs)

In [None]:
# Running it with a different clip_idxs length will cause recompilation
clip_idxs = jp.arange(15, 30)
jit_vmap_out2 = jit_vmap_generate_rollout(clip_idxs)

### Save the rollout to disk

In [None]:
# The rollout file is placed inside the checkpoint directory and named after the clip index.
save_path = Path(ckpt_path) / f"rollout_{clip_idx}.h5"
print(f"Saving rollout to {save_path}")

In [9]:
utils.save_to_h5py(save_path.resolve(), single_rollout)

# Render Rollout Videos from the Saved Rollouts

## Load the rollout file

In [10]:
# Bound to `saved_rollout` rather than `rollout`: the latter would shadow the
# imported `track_mjx.analysis.rollout` module and break re-running earlier
# cells that call `rollout.create_environment` / `rollout.create_rollout_generator`.
saved_rollout = utils.load_from_h5py(save_path)

## Render rollout

Note: Currently only works for single rollouts

In [None]:
# Renders the in-memory `single_rollout` generated earlier (not the copy reloaded from the .h5 file).
frames, realtime_framerate = render.render_rollout(cfg, single_rollout)
render.display_video(frames, framerate=realtime_framerate)