# Checkpoint Playground

This notebook is a playground for saved `track-mjx` model checkpoints. It lets you load the intention network and perform different operations on it, such as probing the decoder with random intentions or generating a trajectory from a given intention.

In [1]:
# imports

%load_ext autoreload
%autoreload 2

import os
import logging

# Send logging outputs to stdout (comment this out if preferred)
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)

# Rendering backend for MuJoCo / PyOpenGL; switch both values to "glfw" if that
# is what your machine provides. Set before the rendering imports further down.
os.environ.update({"MUJOCO_GL": "egl", "PYOPENGL_PLATFORM": "egl"})

from track_mjx.agent import checkpointing
from track_mjx.agent.intention_network import Decoder
from track_mjx.analysis.rollout import (
    create_rollout_generator,
    create_environment,
)
from track_mjx.analysis.render import render_from_saved_rollout, display_video
from track_mjx.analysis.utils import save_to_h5py, load_from_h5py

import jax
from jax import numpy as jnp
from brax.training import distribution
from pathlib import Path

# rendering related
import mujoco
from dm_control import mjcf as mjcf_dm
from tqdm import tqdm
import numpy as np  

In [2]:
# Point this at your own checkpoint directory.
ckpt_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/250227_200156"

# Restore the checkpoint for evaluation and take the config stored inside it.
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)
cfg = ckpt["cfg"]

# Override the config entries that depend on the local machine:
# - the checkpoint to restore is the one loaded above;
# - the data path must be absolute, because the notebook may not resolve the
#   same relative path that was used when the checkpoint was written.
cfg.train_setup.checkpoint_to_restore = ckpt_path
cfg.data_path = "/root/vast/scott-yang/track-mjx/data/transform_snips.h5"

In [3]:
# Build the decoder module (architecture only, no weights yet) from the
# network configuration stored with the checkpoint.
network_config = cfg["network_config"]
action_size = network_config["action_size"]
# The final layer is twice the action size — presumably one location and one
# scale per action dimension for the distribution below; confirm against brax.
decoder_layer_sizes = network_config["decoder_layer_sizes"] + [action_size * 2]
decoder = Decoder(decoder_layer_sizes)

# Pull the decoder weights out of the checkpointed policy and wrap them under a
# "params" key, the form passed to `decoder.apply` later in the notebook.
decoder_raw = ckpt["policy"][1]["params"]["decoder"]
decoder_param = {"params": decoder_raw}

# Distribution used to turn decoder outputs into sampled actions.
action_distribution = distribution.NormalTanhDistribution(event_size=action_size)

In [4]:
# Smoke test: push an all-zero input of length 60 + 147 through the decoder and
# sample a single action vector from the resulting distribution parameters.
zero_input = jnp.zeros(60 + 147)
decoder_output = decoder.apply(decoder_param, zero_input)
actions = action_distribution.sample(decoder_output, seed=jax.random.PRNGKey(0))
print("Number of actions:", len(actions))
actions

Number of actions: 38


Array([ 0.18949549, -0.00768461, -0.0074492 ,  0.6148951 ,  0.12509039,
       -0.24831723, -0.19015917, -0.04588649,  0.5104487 , -0.44528174,
        0.43416873, -0.43378222, -0.18514986, -0.17391203,  0.1971217 ,
       -0.01080685,  0.38500613, -0.19363432, -0.01408207, -0.47464454,
       -0.02117862,  0.03347109,  0.12487873,  0.00479396, -0.07515386,
        0.5482495 ,  0.3974665 ,  0.26658338, -0.71396106, -0.79009014,
       -0.51609546, -0.01904   , -0.02236654,  0.50454557,  0.1739393 ,
        0.31533882, -0.4272301 , -0.8686534 ], dtype=float32)

## Examine the Decoder with Random Intentions

The first thing we can do is probe the decoder with random intentions: at every step we sample a random intention vector, feed it to the decoder, and observe the trajectory that the resulting actions produce.

In [6]:
# Build the environment from the checkpoint config and JIT-compile its
# reset/step functions.
env = create_environment(cfg)
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)

# Load the rodent model separately as a plain MuJoCo model; it is only used
# for rendering further down.
rodent_XML_path = "/root/vast/scott-yang/track-mjx/track_mjx/environment/walker/assets/rodent/rodent.xml"
root = mjcf_dm.from_path(rodent_XML_path)
mj_model = mjcf_dm.Physics.from_mjcf_model(root).model.ptr

# Solver settings for the rendering model; change the lookup key to "newton"
# to use the Newton solver instead of conjugate gradient.
solver_by_name = {
    "cg": mujoco.mjtSolver.mjSOL_CG,
    "newton": mujoco.mjtSolver.mjSOL_NEWTON,
}
mj_model.opt.solver = solver_by_name["cg"]
mj_model.opt.iterations = 6
mj_model.opt.ls_iterations = 6
mj_data = mujoco.MjData(mj_model)

# Calculate the real-time rendering fps: physics steps per second divided by
# the number of physics steps taken per control step.
physics_steps_per_control_step = (
    cfg.env_config.env_args.physics_steps_per_control_step
)
render_fps = (1.0 / mj_model.opt.timestep) / physics_steps_per_control_step

# Initial environment state for the rollout cells below.
state = jit_env_reset(rng=jax.random.PRNGKey(0))

env._steps_for_cur_frame: 2.0


In [13]:
# Roll the environment forward for 200 control steps, feeding the decoder a
# freshly sampled random intention (length 60) at every step and recording the
# joint positions visited along the way.
rollout_key = jax.random.PRNGKey(0)
reference_obs_size = cfg["network_config"]["reference_obs_size"]

qposes_rollout = []
for _ in tqdm(range(200)):
    # Split the key on every step. Passing the same PRNGKey to every `sample`
    # call would draw identical action noise at each step, and seeding the
    # intentions from the same key chain makes the rollout reproducible.
    rollout_key, intention_key, action_key = jax.random.split(rollout_key, 3)

    intentions = jax.random.normal(intention_key, shape=(60,))
    # Replace the first `reference_obs_size` entries of the observation with
    # the random intention; the rest of the observation is passed through.
    concatenated = jnp.concatenate(
        [intentions, state.obs[..., reference_obs_size:]],
        axis=-1,
    )
    action = action_distribution.sample(
        decoder.apply(decoder_param, concatenated), seed=action_key
    )
    qposes_rollout.append(state.pipeline_state.qpos)
    state = jit_env_step(state, action)


# Render the recorded poses off-screen; the frames are kept in memory in
# `frames` (nothing is written to disk or logged here).
mujoco.mj_kinematics(mj_model, mj_data)
renderer = mujoco.Renderer(mj_model, height=480, width=640)
frames = []
print("MuJoCo Rendering...")
for qpos1 in tqdm(qposes_rollout):
    mj_data.qpos = qpos1
    # Recompute derived quantities for the new pose before drawing it.
    mujoco.mj_forward(mj_model, mj_data)
    renderer.update_scene(mj_data, camera=cfg.env_config.render_camera_name)
    pixels = renderer.render()
    frames.append(pixels)

100%|██████████| 200/200 [00:25<00:00,  7.77it/s]


MuJoCo Rendering...


100%|██████████| 200/200 [00:00<00:00, 433.20it/s]


In [12]:
# Play the rendered frames back inline at a fixed 50 frames per second.
playback_fps = 50
display_video(frames, framerate=playback_fps)

In [None]:
# Re-create the decoder module (architecture only) from the checkpoint's
# network configuration; the final layer is twice the action size.
network_config = cfg["network_config"]
decoder_output_size = network_config["action_size"] * 2
decoder = Decoder(network_config["decoder_layer_sizes"] + [decoder_output_size])

## Load the Rollout File and run through decoder with recorded intentions

In [None]:
# Path to a rollout file stored in the same directory as the checkpoint.
example_rollout_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/250227_200156/rollout.h5"

## Useful Utils -- Param Shape Comparison

The following cell compares the parameter structure of a freshly initialized module against the parameters stored in a checkpoint. When reconstructing part of the model as a standalone submodule, this check helps ensure that the checkpoint's parameters match the submodule's expected keys and shapes, so that they can be applied to it directly.

In [15]:
import jax
import jax.numpy as jnp
from track_mjx.agent.intention_network import Decoder


def compare_param_structure(params1, params2, path=""):
    """
    Recursively compare the structure of two parameter trees.

    Any ``collections.abc.Mapping`` is treated as a nested container, not just
    plain ``dict``, so immutable mapping types are traversed too instead of
    being compared as opaque leaves.

    Args:
        params1: Mapping of the freshly initialised ("init") parameters.
        params2: Mapping of the parameters restored from a checkpoint.
        path: Slash-separated key prefix of the current subtree; used only to
            build readable messages. Leave empty for the top-level call.

    Returns:
        A list of strings, one per difference found: keys present on only one
        side, entries that are a container on one side and a leaf on the
        other, and leaves whose ``shape`` attributes differ. Leaves without a
        ``shape`` attribute are treated as having shape ``None``. The list is
        empty when the two trees have the same structure.
    """
    # Local import keeps this function self-contained within its cell.
    from collections.abc import Mapping

    differences = []
    # Compare keys in params1 against params2
    for key in params1:
        key_path = f"{path}/{key}" if path else key
        if key not in params2:
            differences.append(
                f"Key '{key_path}' exists in init params but not in checkpoint params."
            )
            continue

        val1 = params1[key]
        val2 = params2[key]
        is_container1 = isinstance(val1, Mapping)
        is_container2 = isinstance(val2, Mapping)

        if is_container1 and is_container2:
            # Both sides are nested containers: descend.
            differences.extend(compare_param_structure(val1, val2, key_path))
        elif is_container1 != is_container2:
            # One side is a container and the other a leaf. Comparing shapes
            # alone could miss this (both may lack a `shape` attribute).
            kind1 = "a nested dict" if is_container1 else "a leaf"
            kind2 = "a nested dict" if is_container2 else "a leaf"
            differences.append(
                f"Key '{key_path}' is {kind1} in init params but {kind2} in checkpoint params."
            )
        else:
            # Compare shape if the object has a shape attribute (e.g., a JAX array).
            shape1 = getattr(val1, "shape", None)
            shape2 = getattr(val2, "shape", None)
            if shape1 != shape2:
                differences.append(
                    f"Key '{key_path}' has shape {shape1} in init params vs {shape2} in checkpoint params."
                )
    # Check for keys in params2 that are not in params1.
    for key in params2:
        key_path = f"{path}/{key}" if path else key
        if key not in params1:
            differences.append(
                f"Key '{key_path}' exists in checkpoint params but not in init params."
            )
    return differences


# Rebuild the decoder from the network configuration held in `cfg`.
network_config = cfg["network_config"]
decoder = Decoder(
    network_config["decoder_layer_sizes"] + [network_config["action_size"] * 2]
)

# Initialise fresh parameters with a dummy all-zero input that has a leading
# batch dimension of 1 and 60 + 147 features.
input_shape = (1, 60 + 147)
dummy_input = jnp.zeros(input_shape)
init_variables = decoder.init(jax.random.PRNGKey(0), dummy_input)
params_init = init_variables["params"]

# Use the decoder parameters taken from the checkpoint, unwrapping them first
# if they are still nested under a "params" key.
params_checkpoint = decoder_raw["params"] if "params" in decoder_raw else decoder_raw

# Report every key or shape mismatch between the two parameter trees.
structure_diffs = compare_param_structure(params_init, params_checkpoint)
if not structure_diffs:
    print("No structural differences found between the parameter dictionaries.")
else:
    print("Structure differences between init params and checkpoint params:")
    for diff in structure_diffs:
        print(" -", diff)

No structural differences found between the parameter dictionaries.
