# Load Intention Checkpoint and Rendering

This notebook loads training checkpoints of the intention network, rolls out the policy in the bowl-escape environment, and renders the resulting rollouts.

In [None]:
%load_ext autoreload
%autoreload 2

import os
import logging
from tqdm import tqdm
# Send logging outputs to stdout (uncomment the two lines below to enable)
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)

# Rendering backend selection. "egl" is used here; change MUJOCO_GL to "glfw"
# (or another backend supported on your machine) if EGL is not available.
# NOTE(review): these variables are set above the mujoco/jax imports on purpose;
# presumably the libraries read them at import/initialisation time - confirm
# before moving this block.
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # visible GPU masks
os.environ["MUJOCO_GL"] = "egl"
os.environ["PYOPENGL_PLATFORM"] = "egl"
# Ask XLA not to preallocate GPU memory up front.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


from track_mjx.agent import checkpointing
from vnl_mjx.tasks.rodent import flat_arena, bowl_escape
import mediapy as media
import numpy as np
import jax
from jax import numpy as jnp
import mediapy as media
import imageio
import mujoco as mj
import mujoco
from pathlib import Path
from track_mjx.analysis.utils import load_from_h5py, save_to_h5py
from orbax import checkpoint as ocp

# enable JAX persistent compilation cache
# Compiled programs are written to /tmp/jax_cache so later kernel sessions can
# reuse them instead of recompiling.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
# The two thresholds below (-1 bytes, 0 seconds) are set so that no compilation
# is excluded from the cache for being too small or too quick to compile.
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
# Additionally enable XLA-level caching (value names the per-fusion autotune cache).
jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir")


def render(
    model,
    data=None,
    height: int = 400,
    camera: int | str = -1,
    filename="render.png",
    save=False,
    *,
    render_height: int = 2160,
    render_width: int = 3840,
):
    """Render a single still frame of a MuJoCo model and show it inline.

    Args:
        model: MuJoCo model to render.
        data: Simulation state to render. When ``None`` a fresh
            ``mj.MjData(model)`` is created. ``mj.mj_forward`` is called on
            it before rendering, so the object is updated in place.
        height: Display height forwarded to ``media.show_image``. This only
            affects how the image is shown, not the rendered resolution.
        camera: Camera id or name forwarded to ``renderer.update_scene``.
        filename: Output path used when ``save`` is true.
        save: When true, the rendered frame is also written to ``filename``
            with ``imageio.imwrite``.
        render_height: Height in pixels of the offscreen render. Defaults to
            the previously hard-coded 2160.
        render_width: Width in pixels of the offscreen render. Defaults to
            the previously hard-coded 3840.

    Returns:
        None. The frame is displayed (and optionally saved) as a side effect.
    """
    if data is None:
        data = mj.MjData(model)
    # NOTE(review): the requested resolution presumably has to fit inside the
    # model's offscreen framebuffer; lower render_height/render_width if the
    # renderer rejects the default size - confirm against the model settings.
    with mj.Renderer(model, render_height, render_width) as renderer:
        # Run forward kinematics so the scene reflects the current qpos.
        mj.mj_forward(model, data)
        renderer.update_scene(data, camera=camera)
        frame = renderer.render()
        media.show_image(frame, height=height)
        if save:
            imageio.imwrite(filename, frame)

In [None]:
# Load config from checkpoint
# Root directory that contains one sub-directory per training run.
# replace with your checkpoint path
checkpoints_dir = Path("/root/vast/scott-yang/vnl-mjx/model_checkpoints")

# Experiment label -> checkpoint directory, in the order the runs are evaluated.
ckpts_paths = {
    label: checkpoints_dir / run_id
    for label, run_id in (
        ("transfer_and_freeze", "250701_203400"),
        ("transfer_not_freeze", "250701_203433"),
        ("transfer_005_lr", "250702_035047"),
        ("transfer_05_lr", "250702_035205"),
        ("transfer_01_lr", "250702_040055"),
    )
}

# Individual handles, kept because later cells refer to them by name.
transfer_and_freeze = ckpts_paths["transfer_and_freeze"]
transfer_not_freeze = ckpts_paths["transfer_not_freeze"]
transfer_005_lr = ckpts_paths["transfer_005_lr"]
transfer_05_lr = ckpts_paths["transfer_05_lr"]
transfer_01_lr = ckpts_paths["transfer_01_lr"]

# The environment is built once from the config stored in the first checkpoint,
# with bowl_vsize overridden to 0.5.
ckpt = checkpointing.load_checkpoint_for_eval(transfer_and_freeze)
cfg = ckpt["cfg"]
cfg["env_config"]["env_args"]["bowl_vsize"] = 0.5
env = bowl_escape.BowlEscape(config_overrides=cfg["env_config"]["env_args"])

# Jit-compiled reset/step, shared by every rollout in this notebook.
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

In [None]:
# replace with your checkpoint path
# NOTE(review): as written this reloads the same checkpoint
# (transfer_and_freeze) that the config-loading cell already stored in `ckpt`;
# point ckpt_path at a different run to evaluate another policy.
ckpt_path = transfer_and_freeze
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)

In [None]:
# once you compile the environment, you can easily swap the checkpoint
# and do the rollout and renderings
# Builds the policy function from the currently loaded `ckpt["policy"]` and the
# `cfg` taken from the checkpoint used to construct the environment.
inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
jit_inference_fn = jax.jit(inference_fn)

In [None]:
def rollout_fn(
    seed,
    jit_reset=jit_reset,
    jit_step=jit_step,
    jit_inference_fn=jit_inference_fn,
    num_steps: int = 3000,
):
    """Roll out a policy in the environment for a fixed number of steps.

    Args:
        seed: Integer seed for ``jax.random.PRNGKey``; it drives both the
            environment reset and the per-step policy keys.
        jit_reset: Callable mapping a PRNG key to an initial state.
        jit_step: Callable mapping ``(state, ctrl)`` to the next state.
        jit_inference_fn: Callable mapping ``(obs, key)`` to a tuple whose
            first element is the control; the second element is ignored.
        num_steps: Number of environment steps to take. Defaults to 3000,
            the horizon that was previously hard-coded.

    Returns:
        A list of ``num_steps + 1`` states: the reset state followed by the
        state after every step.

    Note:
        The callable defaults are bound when this function is defined, so
        re-creating ``jit_inference_fn`` later does not change the default;
        pass the new callables explicitly. The loop never inspects a
        termination flag, so it always runs for exactly ``num_steps`` steps.
    """
    rng = jax.random.PRNGKey(seed)
    rng, reset_rng = jax.random.split(rng)
    state = jit_reset(reset_rng)
    rollout = [state]
    for _ in tqdm(range(num_steps)):
        # Fresh key for the policy at every step; `rng` carries over.
        act_rng, rng = jax.random.split(rng)
        ctrl, _ = jit_inference_fn(state.obs, act_rng)
        state = jit_step(state, ctrl)
        rollout.append(state)
    return rollout

In [None]:
# Roll out every checkpoint in `ckpts_paths` and collect joint positions and
# the position of body index 3 for each episode.
# NOTE: the loop rebinds the module-level `ckpt`, `inference_fn` and
# `jit_inference_fn`; once it finishes they refer to the last entry of
# `ckpts_paths`, which is what later cells will pick up.
task = {}

for name, ckpt_path in ckpts_paths.items():
    # load checkpoint & inference fn
    ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)
    inference_fn = checkpointing.load_inference_fn(ckpt["cfg"], ckpt["policy"])
    jit_inference_fn = jax.jit(inference_fn)

    # collect qpos/xpos for 10 seeds (0-9)
    qposes_episodes = []
    xposes_root_episodes = []
    for seed in range(10):
        rollout = rollout_fn(seed, jit_reset, jit_step, jit_inference_fn)
        qposes_single = [r.data.qpos for r in rollout]
        # xpos[3]: presumably the rodent's root/torso body - TODO confirm index
        xposes_single = [r.data.xpos[3] for r in rollout]
        qposes_episodes.append(qposes_single)
        xposes_root_episodes.append(xposes_single)

    # store under this checkpoint name
    # Arrays are indexed as [episode, step, ...].
    task[name] = {
        "qpos": np.array(qposes_episodes),
        "xpos": np.array(xposes_root_episodes),
    }

In [None]:
# One extra rollout with seed 10, using whichever `jit_inference_fn` is
# currently bound in the kernel. This `rollout` feeds the pickle dump and the
# video rendering cells.
rollout = rollout_fn(10, jit_reset, jit_step, jit_inference_fn)

In [None]:
# Names of the geoms listed on the environment spec's worldbody.
geom_names = [g.name for g in env._spec.worldbody.geoms]
# Display the list so a geom's position (Python-list index) can be read off.
geom_names

In [None]:
# Per-step position of body index 3 and per-step qpos for the single `rollout`.
xposes_full_episode = [r.data.xpos[3] for r in rollout]
qposes_full_episode = [r.data.qpos for r in rollout]

In [None]:
# Persist the per-checkpoint qpos/xpos arrays collected in `task`.
save_to_h5py("bowl_escape_rollout_info.h5", task)

In [None]:
import pickle

# Dump the full list of rollout states to disk.
# NOTE: pickle files should only ever be loaded back from trusted sources.
with open("rollout.pkl", "wb") as f:
    pickle.dump(rollout, f)

In [None]:
# Render the rollout as a video, keeping every `render_every`-th state and
# scaling the frame rate by the same factor.
render_every = 2
traj = rollout[::render_every]
fps = 1.0 / env.dt / render_every

# Visualization flags: show contact points, hide contact forces, opaque bodies.
scene_option = mujoco.MjvOption()
for vis_flag, enabled in (
    (mujoco.mjtVisFlag.mjVIS_CONTACTPOINT, True),
    (mujoco.mjtVisFlag.mjVIS_CONTACTFORCE, False),
    (mujoco.mjtVisFlag.mjVIS_TRANSPARENT, False),
):
    scene_option.flags[vis_flag] = enabled

render_kwargs = dict(
    camera="close_profile-rodent",
    scene_option=scene_option,
    height=480,
    width=640,
)
frames = env.render(traj, **render_kwargs)

# Show inline, then write the same frames to disk.
media.show_video(frames, fps=fps, loop=False)
media.write_video(f"succ_escape.mp4", frames, fps=fps, qp=18)

In [None]:
# Inspect the qpos vector of the fourth state in the rollout.
rollout[3].data.qpos

## Use the renderer environment

In [None]:
# load the qposes of the rollout
# NOTE(review): "qposes_task_transfer.h5" is not written anywhere in this
# notebook (the save cell writes "bowl_escape_rollout_info.h5") - presumably it
# comes from an earlier run; confirm the file exists.

rollouts = load_from_h5py("qposes_task_transfer.h5")

In [None]:
# List the stored entries, then take the first element of the
# "transfer_and_freeze" entry.
print(rollouts.keys())
transfer_freeze = rollouts["transfer_and_freeze"][0]

In [None]:
# Look up the integer body id of "torso-rodent" in the environment model and
# display it.
bid = mujoco.mj_name2id(env.mj_model, mujoco.mjtObj.mjOBJ_BODY, "torso-rodent")
bid

In [None]:
# Reload the saved rollout data from disk; this replaces any in-memory `task`.
task = load_from_h5py("bowl_escape_rollout_info.h5")
print(task.keys())

In [None]:
# Select one run and take its first episode.
# NOTE(review): despite its name, `trans_freeze` holds the "transfer_01_lr" run.
trans_freeze = task["transfer_01_lr"]
qpos, xpos = trans_freeze["qpos"][0], trans_freeze["xpos"][0]

In [None]:
# Load config from checkpoint
# TODO: serialize the vsize and the seed with this.
# NOTE: `ckpt` is whichever checkpoint was loaded most recently in the kernel.
cfg = ckpt["cfg"]
cfg["env_config"]["env_args"]["bowl_vsize"] = 0.5
# Rendering environment built with num_rodents=30.
# NOTE(review): presumably this must match the number of qpos frames written
# into the render model's data later on - confirm.
render_env = bowl_escape.BowlEscapeRender(
    num_rodents=30, config_overrides=cfg["env_config"]["env_args"]
)

In [None]:
# Position of body index 3 at every step of `rollout`, stacked into one array.
xposes = np.array([state.data.xpos[3] for state in rollout])

In [None]:
# Remove the line geoms from the render environment before adding new ones.
render_env.remove_line_geoms()

In [None]:
# Add line geoms built from `xpos`, the per-step positions of the selected episode.
render_env.add_line_geoms(xpos)

In [None]:
# Preview the render model in its default state (no data passed) from the
# "close_profile-rodent-2" camera.
render(render_env.mj_model, height=400, camera="close_profile-rodent-2")

In [None]:
# Pose the render model with `num_frames` snapshots of the episode, taken every
# `stride` steps from `start`, concatenated into one flat qpos vector.
data = mujoco.MjData(render_env.mj_model)
start, stride, num_frames = 0, 100, 30
frame_indices = range(start, start + stride * num_frames, stride)
qposes = np.concatenate([qpos[idx] for idx in frame_indices], axis=0)
data.qpos = qposes

In [None]:
# Render the posed scene from several cameras attached to rodent `i`.
# Each entry is (camera name, extra keyword arguments for `render`); only the
# views whose extras set save=True are written to disk.
i = 4
views = (
    (f"close_profile-rodent-{i}", {}),
    (f"back-rodent-{i}", {"save": False, "filename": f"transfer_0.1_back-{i}.png"}),
    (f"side-rodent-{i}", {}),
    (f"side_alt-rodent-{i}", {"filename": f"transfer_0.1_side_alt-{i}.png", "save": True}),
    (f"top-rodent-{i}", {"filename": "transfer_0.1_top.png", "save": True}),
)
for camera_name, extra in views:
    render(render_env.mj_model, data, height=400, camera=camera_name, **extra)
# render(render_env.mj_model, data, height=400, camera=-1)