# Generate and render a rollout for existing checkpoint

This notebook demonstrates how to load a training checkpoint, perform a rollout, and render the result. Full network activations are saved as an output of this rollout for further analysis.

## Imports

In [1]:
import os
import logging
# Send logging outputs to stdout (comment this out if preferred)
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)

# MuJoCo rendering backend (valid values include "glfw", "egl" and "osmesa").
# "glfw" is the default here; a MUJOCO_GL already set in the environment takes precedence,
# so headless machines can export MUJOCO_GL=egl/osmesa without editing the notebook.
os.environ.setdefault("MUJOCO_GL", "glfw")

from track_mjx.agent import checkpointing
from track_mjx.analysis.rollout import (
    create_rollout_generator,
    create_environment,
)
from track_mjx.analysis.render import render_from_saved_rollout, display_video
from track_mjx.analysis.utils import save_to_h5py, load_from_h5py

import jax
from jax import numpy as jp
from pathlib import Path

## Load checkpoint

In [2]:
# Root of your local track-mjx checkout. Set the TRACK_MJX_ROOT environment variable
# (absolute path) instead of editing the paths below -- the notebook's working directory
# may not be the repo root, so relative paths are not reliable here.
TRACK_MJX_ROOT = Path(
    os.environ.get("TRACK_MJX_ROOT", "/Users/charleszhang/GitHub/track-mjx")
)

# Checkpoint directory to evaluate (replace the run name with your own).
ckpt_path = str(TRACK_MJX_ROOT / "model_checkpoints" / "250220_125514")
# Load the checkpoint; it provides the training config ("cfg") and the policy ("policy").
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)

cfg = ckpt["cfg"]

# Point the config at local data and at this checkpoint.
cfg.data_path = str(TRACK_MJX_ROOT / "data" / "transform_snips.h5")
cfg.train_setup.checkpoint_to_restore = ckpt_path

### Restore policy and make rollout functions

In [7]:
# Build the tracking environment and restore the trained policy from the checkpoint.
env = create_environment(cfg)
inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])

# Single-clip rollout function: call as generate_rollout(clip_idx=...).
generate_rollout = create_rollout_generator(
    cfg["reference_config"],
    env,
    inference_fn,
)

KeyboardInterrupt: 

### Generate rollouts from the checkpoint

With the checkpoint loaded, we can run the policy to generate rollouts.

We can generate a rollout imitating a single clip, specified by its clip index. The first call triggers ~1-3 min of JIT compilation; subsequent calls take only a few seconds.

In [7]:
# Roll out the policy on reference clip 1 (the first call includes ~1-3 min of JIT compilation).
single_rollout = generate_rollout(clip_idx=1)

#### Batch Generating Rollouts

Alternatively, you can use `jax.vmap` to parallelize the rollout function across clips. This is useful for running rollouts over an entire dataset for evaluation/analysis. We pass in a 1D array of clip indices (`clip_idxs`) as input.

The first run for this will also have a few minutes of compilation time.

**Note:** `jax.jit` compiles a separate version of the function for each input shape. If you reuse the same length for `clip_idxs`, JAX reuses the cached compiled function. If the length changes, JAX **recompiles the entire function**, incurring additional overhead.

In [9]:
# Batch of 5 clip indices (clips 0..4), rolled out in parallel.
clip_idxs = jp.arange(5)
# vmap maps the single-clip rollout over the batch axis; jit compiles the batched version.
jit_vmap_generate_rollout = jax.jit(jax.vmap(generate_rollout))
jit_vmap_out = jit_vmap_generate_rollout(clip_idxs)

In [10]:
# A different clip_idxs length (15 here vs. 5 above) changes the input shape,
# so this call triggers a full recompilation before running.
clip_idxs = jp.arange(15, 30)
jit_vmap_out2 = jit_vmap_generate_rollout(clip_idxs)

### Save the rollout to disk

In [13]:
# Rollout file stored inside the checkpoint directory. Resolve it once so that saving
# and loading use the same absolute path regardless of the working directory.
save_path = (Path(ckpt_path) / "rollout.h5").resolve()

In [9]:
# Write the single-clip rollout to an .h5 file inside the checkpoint directory.
save_to_h5py(save_path.resolve(), single_rollout)

# Render Rollout Videos from the Saved Rollouts

## Load the rollout file

In [14]:
# Read the saved rollout back from disk.
rollout = load_from_h5py(save_path)

## Render rollout

**Note:** Rendering currently works only for non-batched (single-clip) rollouts.

In [None]:
# Render frames from the loaded rollout and play them back at the returned real-time frame rate.
frames, realtime_framerate = render_from_saved_rollout(cfg, rollout)
display_video(frames, framerate=realtime_framerate)

In [3]:
# Extra imports for the decoder / random-walk experiments in the cells below.
# NOTE(review): these belong in the top imports cell for Restart & Run All hygiene.
from track_mjx.agent import ppo_networks as networks
from track_mjx.environment import wrappers

# Fresh environment for the wrappers built below.
env = create_environment(cfg)

  link = jax.tree_map(lambda x: x[1:].copy(), link)
  motion = jax.tree_map(lambda *x: np.concatenate(x), *motions)
  limit = jax.tree_map(lambda *x: np.concatenate(x), *limits)
  act_kwargs = jax.tree_map(lambda x: x[act_mask], act_kwargs)
  sys = jax.tree_map(jp.array, sys)
  link = jax.tree_map(lambda x: x[1:].copy(), link)


env._steps_for_cur_frame: 2.0


  motion = jax.tree_map(lambda *x: np.concatenate(x), *motions)
  limit = jax.tree_map(lambda *x: np.concatenate(x), *limits)
  act_kwargs = jax.tree_map(lambda x: x[act_mask], act_kwargs)
  sys = jax.tree_map(jp.array, sys)


In [4]:
# Build a standalone decoder policy function from the checkpoint directory.
decoder_policy_fn = networks.make_decoder_policy_fn(ckpt_path)

In [5]:
# Smoke-test the decoder on a dummy all-ones input.
# NOTE(review): 60 + 147 presumably = latent (intention) size + observation size fed to the
# decoder; TODO confirm and derive these from cfg["network_config"] instead of hard-coding.
key = jax.random.PRNGKey(0)
action, _ = decoder_policy_fn(jp.ones(60 + 147), key)

In [12]:
# TODO: RenderRolloutWrapperTracking overrides `reset` without calling the wrapped env's
# `reset`, so it must be applied first (innermost) in the wrapper stack.
rollout_env = wrappers.RenderRolloutWrapperTracking(env)

# Dimensionality of the latent (intention) vector.
latent_size = cfg["network_config"]["intention_size"]

# High-level wrapper: takes the decoder policy and the reference observation size.
rollout_highlvlenv = wrappers.HighLevelWrapper(
    rollout_env, decoder_policy_fn, cfg["network_config"]["reference_obs_size"]
)

In [13]:
# JIT-wrapped reset/step of the high-level env (compiled on first call).
jit_step = jax.jit(rollout_highlvlenv.step)
jit_reset = jax.jit(rollout_highlvlenv.reset)
key = jax.random.PRNGKey(0)

In [48]:
@jax.jit
def random_walk_rollout(clip_idx, seed=1, decay=0.99, noise_scale=0.01):
    """Roll out the high-level env, driving it with a random-walk latent instead of a policy.

    The latent passed to ``rollout_highlvlenv.step`` starts at zeros and evolves as
    ``latents_t = decay * latents_{t-1} + noise_scale * N(0, I)``.

    Args:
        clip_idx: Index of the reference clip used to reset the environment.
        seed: Integer seed for ``jax.random.PRNGKey``.
        decay: Multiplicative decay applied to the previous latent each step.
        noise_scale: Standard deviation of the Gaussian increment added each step
            (0.01, i.e. the ``0.1 * 0.1`` used originally).

    Returns:
        Dict with keys ``"rewards"`` (per-step reward metrics plus ``"torso_heights"``),
        ``"observations"``, ``"ctrl"``, ``"qposes_ref"``, ``"qposes_rollout"`` and ``"info"``.
        Per-state entries contain the reset state followed by ``num_steps`` stepped
        states. ``"ctrl"`` holds the ``num_steps`` latents only (no entry for the reset state).
    """
    key = jax.random.PRNGKey(seed)
    _, reset_rng, rollout_key = jax.random.split(key, 3)
    init_state = jit_reset(reset_rng, clip_idx=clip_idx)
    # Static Python int: total states = clip_length * steps per frame, minus one for the
    # reset state that is prepended after the scan.
    num_steps = (
        int(cfg["reference_config"].clip_length * rollout_highlvlenv._steps_for_cur_frame) - 1
    )

    def _step_fn(carry, _):
        state, rng, last_latents = carry
        # Split first and sample with the fresh subkey; the carried key is never used
        # directly for sampling, so no key is reused across steps.
        rng, noise_rng = jax.random.split(rng)
        noise = jax.random.normal(noise_rng, (latent_size,), dtype=jp.float32)
        latents = decay * last_latents + noise_scale * noise
        next_state = jit_step(state, latents)
        return (next_state, rng, latents), (next_state, latents)

    init_carry = (init_state, rollout_key, jp.zeros(latent_size))
    _, (states, ctrls) = jax.lax.scan(_step_fn, init_carry, None, length=num_steps)

    def _prepend(first, rest):
        # Scalar (0-d) leaves are left untouched; others get the reset state prepended
        # along the leading (time) axis produced by scan.
        if rest.ndim == 0:
            return rest
        return jp.concatenate([first[None], rest])

    rollout_states = jax.tree.map(_prepend, init_state, states)

    # Every leaf of rollout_states already has a leading time axis, so per-step metrics
    # are read directly instead of through an identity jax.vmap.
    # TODO: refactor to collect metrics based on cfg metric list
    reward_names = ("pos", "endeff", "quat", "angvel", "bodypos", "joint")
    rewards = {
        f"{name}_rewards": rollout_states.metrics[f"{name}_reward"]
        for name in reward_names
    }
    # z-coordinate (index 2) of the torso body's xpos at every step.
    rewards["torso_heights"] = rollout_states.pipeline_state.xpos[
        :, rollout_highlvlenv.walker._torso_idx, 2
    ]

    # Reference qposes, repeated so each reference frame covers its env steps.
    ref_traj = rollout_highlvlenv._get_reference_clip(init_state.info)
    qposes_ref = jp.repeat(
        jp.hstack([ref_traj.position, ref_traj.quaternion, ref_traj.joints]),
        int(rollout_highlvlenv._steps_for_cur_frame),
        axis=0,
    )

    return {
        "rewards": rewards,
        "observations": rollout_states.obs,
        "ctrl": ctrls,
        "qposes_ref": qposes_ref,
        "qposes_rollout": rollout_states.pipeline_state.qpos,
        "info": rollout_states.info,
    }

In [49]:
# Random-walk rollout on reference clip 1 with seed 3.
# NOTE: this overwrites the `rollout` loaded from disk earlier in the notebook.
rollout = random_walk_rollout(clip_idx=1, seed=3)

In [50]:
# Render the random-walk rollout with the same renderer used for the saved rollout.
frames, realtime_framerate = render_from_saved_rollout(cfg, rollout)
display_video(frames, framerate=realtime_framerate)

MuJoCo Rendering...


100%|██████████| 500/500 [00:07<00:00, 64.60it/s]
