# `track-mjx` - Checkpoint Playground

This notebook is a playground for saved checkpoints from `track-mjx`'s kinematic replay tasks. It lets you load the intention network and work with it: inspect the decoder with random intentions, generate a trajectory from a given intention, evaluate model performance with various metrics, visualize trajectories and intention distributions, and add further analysis tools to explore the model's behavior and capabilities.

## !! JAX Sharp Bits !!

If you modify the inputs to a `jax.jit`-compiled function, make sure the argument types stay consistent across the notebook. This notebook uses `jnp.array` as the default type for all inputs. If your data starts out as `np.array`, convert it to `jnp.array` first, or otherwise match the types that were used when the function was first compiled.

If you see excessive recompilation, check that the input types match those used in earlier calls to the compiled functions.
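
As a minimal sketch of why input types matter, the toy function below (not part of `track-mjx`; `scale` and `trace_count` are made up for illustration) counts how often JAX traces it. The counter is a Python side effect, so it only runs when JAX traces the function again for a new input signature.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the function


@jax.jit
def scale(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return x * 2.0


scale(jnp.zeros(3, dtype=jnp.float32))
scale(jnp.ones(3, dtype=jnp.float32))  # same shape and dtype: compiled version is reused
traces_after_same_type = trace_count

scale(jnp.zeros(3, dtype=jnp.int32))  # different dtype: traced and compiled again
traces_after_new_dtype = trace_count

scale(jnp.zeros(4, dtype=jnp.float32))  # different shape: traced and compiled again
traces_after_new_shape = trace_count

print(traces_after_same_type, traces_after_new_dtype, traces_after_new_shape)
```

The same applies to `jit_env_step` and `jit_apply` later in this notebook, where each retrace is far more expensive than in this toy case.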

In [None]:
# imports

%load_ext autoreload
%autoreload 2

import os
import logging

# Send logging outputs to stdout (comment this out if preferred)
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)

# Change this to egl or glfw if available
os.environ["MUJOCO_GL"] = "egl"
os.environ["PYOPENGL_PLATFORM"] = "egl"

# This is important when multiple notebooks use JAX on the same machine
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = os.environ.get(
    "XLA_PYTHON_CLIENT_MEM_FRACTION", "0.4"
)


from track_mjx.agent import checkpointing
from track_mjx.agent.mlp_ppo.intention_network import Decoder
from track_mjx.analysis.rollout import create_environment
from track_mjx.environment.wrappers import RenderRolloutWrapperMulticlipTracking
from track_mjx.analysis.render import render_rollout, display_video
from track_mjx.analysis.utils import save_to_h5py, load_from_h5py

import jax
from jax import numpy as jnp
# enable JAX persistent compilation cache
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir")

from brax.training import distribution
from brax.training.acme import running_statistics
from pathlib import Path

# rendering related
import mujoco
from dm_control import mjcf as mjcf_dm
from tqdm import tqdm
import numpy as np
from dm_control.locomotion.walkers import rescale
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from PIL import Image
import multiprocessing as mp
import functools

## Part 0: Network Initialization and Checkpoint Loading

First, we need to load the checkpoint. To load your own checkpoint, set the `ckpt_path` variable to your checkpoint path.

In [None]:
# replace with your checkpoint path
ckpt_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/250701_232752_537208"
# Load config from checkpoint
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)

cfg = ckpt["cfg"]

# make some changes to the config
# replace with absolute path to your data
# -- your notebook may not have access to the same relative path
cfg.data_path = "/root/vast/scott-yang/track-mjx/data/transform_snips.h5"
cfg.train_setup.checkpoint_to_restore = ckpt_path

# NOTE: To benefit from jax.jit, run the following code only once.
# Re-running it creates a new environment and triggers recompilation.

env = create_environment(cfg)
env = RenderRolloutWrapperMulticlipTracking(env)

The following code initializes the network and loads the checkpoint.

In [None]:
# initialize the abstract decoder
network_config = cfg["network_config"]

# initialize the decoder; its last layer outputs the mean and variance of the action distribution
decoder = Decoder(
    network_config["decoder_layer_sizes"] + [network_config["action_size"] * 2]
)

normalizer = running_statistics.normalize
# load the normalizer parameters
normalizer_param = ckpt["policy"][0]

# load the decoder parameters
decoder_raw = ckpt["policy"][1]["params"]["decoder"]
decoder_param = {"params": decoder_raw}
# initialize the action distribution
action_distribution = distribution.NormalTanhDistribution(
    event_size=network_config["action_size"]
)

# prevent recompilation
jit_env_reset = jax.jit(env.reset, static_argnames=("clip_idx",))
jit_env_step = jax.jit(env.step)
jit_apply = jax.jit(decoder.apply)

Test run: feed decoder with zero intentions

In [None]:
# Initialize state and random key.
rng = jax.random.PRNGKey(0)
rng, env_rng = jax.random.split(rng)
rng, init_key = jax.random.split(rng)

state = jit_env_reset(rng=env_rng, clip_idx=41)
proprioceptive_obs_size, reference_obs_size, intention_size = (
    state.info["proprioceptive_obs_size"],
    state.info["reference_obs_size"],
    cfg["network_config"]["intention_size"],
)

actions = action_distribution.mode(
    jit_apply(decoder_param, jnp.zeros(proprioceptive_obs_size + intention_size))[0]
)
print("Number of actions:", len(actions))
actions
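
The decoder's last layer has `action_size * 2` outputs, which the action distribution turns into an action. The NumPy sketch below illustrates the deterministic `mode` path under the assumption that Brax's `NormalTanhDistribution` splits the parameters in half along the last axis into a location and a raw scale, and returns `tanh(location)` as the mode. Please check this against the Brax version you have installed. The helper names and the toy sizes are made up for illustration.

```python
import numpy as np


def split_decoder_output(logits):
    """Split the last axis into (location, raw_scale) halves."""
    loc, raw_scale = np.split(logits, 2, axis=-1)
    return loc, raw_scale


def deterministic_action(logits):
    """Assumed mode of a tanh-squashed normal: tanh of the location half."""
    loc, _ = split_decoder_output(logits)
    return np.tanh(loc)


# Toy example with action_size = 4, so the decoder emits 8 numbers.
logits = np.array([0.0, 1.0, -1.0, 5.0, 0.3, 0.3, 0.3, 0.3])
action = deterministic_action(logits)
print(action.shape)
```

Because of the `tanh`, every action component lies strictly between -1 and 1, regardless of how large the decoder output is.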

### JAX Tips:

When performing training, it’s recommended to fully leverage JAX’s functionality, such as using `jax.lax.scan` to replace Python-native for loops. This brings the entire loop into the computation graph, allowing JAX to optimize it more effectively.

However, during inference and analysis, it’s often better to avoid compiling the entire function — including the `for` loop — all at once. Instead, you can compile individual sub-functions and call those smaller, compiled functions within a regular Python for loop. This approach is preferable because inference and analysis workflows often involve frequent modifications to the rollout functions. By keeping the loop outside the compiled graph, you can avoid unnecessary recompilation of the entire loop when only a small part of the algorithm changes.

!! To avoid costly recompilation, make sure the inputs to a jitted function always have the same type.
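
The two styles can be sketched side by side. The `step` function below is a toy stand-in for `env.step` (an assumption for illustration, not the real environment). The first variant jits only the step and loops in Python, as this notebook does. The second puts the whole loop into `jax.lax.scan`, as you would during training.

```python
import jax
import jax.numpy as jnp


def step(state, action):
    # Toy dynamics standing in for env.step.
    return 0.9 * state + action


jit_step = jax.jit(step)
actions = jnp.linspace(0.0, 1.0, 5)

# Analysis style: a plain Python loop around a small jitted function.
# The loop body can be edited freely without recompiling `jit_step`.
state = jnp.zeros(())
trajectory = []
for action in actions:
    state = jit_step(state, action)
    trajectory.append(state)
loop_final = state


# Training style: the whole loop lives inside the compiled graph.
def scan_body(carry, action):
    new_state = step(carry, action)
    return new_state, new_state


scan_final, scan_trajectory = jax.lax.scan(scan_body, jnp.zeros(()), actions)
print(loop_final, scan_final)
```

Both variants compute the same trajectory. The difference is only in what gets recompiled when you change the code around the step.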

## Part 1: Decoder with Random Intentions

In this section, we will explore how the decoder behaves with random intentions, compared with the randomly generated actions.

In [None]:
# Initialize state and random key.
rng = jax.random.PRNGKey(0)
rng, env_rng = jax.random.split(rng)
rng, init_key = jax.random.split(rng)

state = jit_env_reset(rng=env_rng, clip_idx=41)

qposes_rollout = []

for _ in range(500):
    # Generate intentions using JAX’s random for consistency
    # normalize the observation
    obs = normalizer(state.obs, normalizer_param)
    rng, subkey = jax.random.split(rng)
    intentions = jax.random.normal(subkey, shape=(intention_size,))
    # Concatenate intentions with a slice of the observation.
    concatenated = jnp.concatenate(
        [intentions, obs[..., cfg["network_config"]["reference_obs_size"] :]],
        axis=-1,
    )
    # Get the decoder output and take the mode of the action distribution.
    logits = jit_apply(decoder_param, concatenated)[0]
    rng, subkey = jax.random.split(rng)
    # NOTE: The Brax implementation returns the action as an np array, which
    # would trigger recompilation of the jitted step, so convert it to jnp.array.
    action = jnp.array(action_distribution.mode(logits))
    qposes_rollout.append(state.pipeline_state.qpos)
    # Step the environment.
    state = jit_env_step(state, action)

rollout = {"qposes_rollout": jnp.array(qposes_rollout)}

# render the first 200 steps
frames, render_fps = render_rollout(cfg, rollout, render_ghost=False)
print("Rollout videos with random intentions")
display_video(frames, framerate=render_fps)
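
The decoder input in the loop above is assembled by dropping the reference part of the observation and putting the intention in front of the remaining proprioceptive part. The NumPy sketch below shows the same slicing with toy sizes. The sizes are made up, and the layout (reference features first, proprioceptive features after) is inferred from the slicing used in this notebook.

```python
import numpy as np

# Toy sizes, chosen only for illustration.
reference_obs_size = 6
proprioceptive_obs_size = 4
intention_size = 3

# Observation layout assumed here: [reference | proprioceptive].
obs = np.arange(reference_obs_size + proprioceptive_obs_size, dtype=np.float32)
intention = np.zeros(intention_size, dtype=np.float32)

# Keep only the proprioceptive part and prepend the intention.
proprioceptive = obs[..., reference_obs_size:]
decoder_input = np.concatenate([intention, proprioceptive], axis=-1)
print(decoder_input.shape)
```

The resulting vector has `intention_size + proprioceptive_obs_size` entries, which matches the zero vector fed to the decoder in the test run in Part 0.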

The first run takes about 7 minutes; the second run takes about 20 seconds on an A40 GPU.

In [None]:
# Initialize state and random key.
rng, env_rng = jax.random.split(rng)
rng, init_key = jax.random.split(rng)

state = jit_env_reset(rng=env_rng, clip_idx=41)

qposes_rollout = []

for _ in range(500):
    # Generate random actions using JAX's random for consistency
    rng, subkey = jax.random.split(rng)
    action = jax.random.normal(subkey, shape=(38,))
    qposes_rollout.append(state.pipeline_state.qpos)
    # Step the environment.
    state = jit_env_step(state, action)

rollout = {"qposes_rollout": jnp.array(qposes_rollout)}

# render the first 200 steps
frames, render_fps = render_rollout(cfg, rollout, render_ghost=False)
# take a look at the video
print("Rollout with random actions")
display_video(frames, framerate=render_fps)

As you can see, the behavior produced by decoding random intentions is not very good, but it is still better than pure random actions: the rodent remains upright and does not fall over.

## Part 2: Recreate the Recorded Rollout

The following cells recreate the recorded rollout from the statistics recorded during the saved rollout. We will explore different ways of recreating it:

- 2.0: Load the saved rollout
- 2.1: Step with recorded **actions**
- 2.2: Step with recorded **decoder inputs**
- 2.3: Step with recorded **intentions**

**Note:** Please generate the saved rollout file with the `rollout_from_checkpoint.ipynb` notebook, with `clip_idx` set to `41`.

In [None]:
# some rendering related code
from typing import List


def visualize_ego_observation(
    ego_obs, window_size, title, frame, feature_mask: List[int] | None = None
):
    """
    Render one frame of the sliding window plot.

    For frame indices less than window_size, it displays the data from index 0 to current index,
    but fixes the x-axis from 0 to window_size. For later frames, the most recent window_size time
    steps are displayed. A red vertical line indicates the current time index.

    Args:
        ego_obs: A 2D numpy array of shape (n_features, n_steps), e.g. (147, N),
            representing the ego observation (or any other per-step signal).
        window_size: The size of the sliding window.
        title: The title shown above the plot.
        frame: The current frame index.
        feature_mask: Optional [start, stop] pair; if given, only the feature
            rows start:stop are plotted.

    Returns:
        rgb_array: A numpy array of shape (480, 640, 3) representing the rendered image.
    """
    if feature_mask is not None:
        ego_obs = ego_obs[feature_mask[0] : feature_mask[1], :]
    # Compute global min and max to keep colormap consistent
    global_vmin = float(ego_obs.min())
    global_vmax = float(ego_obs.max())

    current_idx = frame + 1  # current time step (1-indexed)

    # Create the figure with fixed dimensions: 6.4x4.8 inches at dpi=100 gives 640x480 pixels.
    fig, ax = plt.subplots(figsize=(6.4, 4.8), dpi=100)
    plt.tight_layout()
    canvas = FigureCanvas(fig)

    if current_idx < window_size:
        # For early frames, plot data from 0 to current_idx.
        window_data = ego_obs[:, :current_idx]  # shape: (147, current_idx)
        # Set extent so the data covers [0, current_idx] in x,
        # but later we force the x-axis to show 0 to window_size.
        extent = [0, current_idx, 0, ego_obs.shape[0]]
        ax.imshow(
            window_data,
            aspect="auto",
            interpolation="none",
            origin="lower",
            extent=extent,
            vmin=global_vmin,
            vmax=global_vmax,
        )
        # Fix x-axis to [0, window_size]
        ax.set_xlim(0, window_size)
        # Draw red vertical line at the current index
        ax.axvline(x=current_idx, color="red", linewidth=2)
    else:
        # For later frames, display the most recent window_size time steps.
        window_data = ego_obs[
            :, current_idx - window_size : current_idx
        ]  # shape: (147, window_size)
        extent = [current_idx - window_size, current_idx, 0, ego_obs.shape[0]]
        ax.imshow(
            window_data,
            aspect="auto",
            interpolation="none",
            origin="lower",
            extent=extent,
            vmin=global_vmin,
            vmax=global_vmax,
        )
        # Here the view follows the window: [current_idx - window_size, current_idx].
        ax.set_xlim(current_idx - window_size, current_idx)

        ax.axvline(x=current_idx, color="red", linewidth=2)
    # ax.yaxis.set_visible(False)
    # This hides the entire y-axis including ticks and labels
    ax.set_title(title)
    # Render the figure and convert to a NumPy RGB array using the proper canvas method.
    canvas.draw()
    s, (width, height) = canvas.print_to_buffer()
    image = Image.frombytes("RGBA", (width, height), s)
    rgb_array = np.array(image.convert("RGB"))
    plt.close(fig)
    return rgb_array
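
The windowing logic of `visualize_ego_observation` can be isolated from the plotting. The helper below is a hypothetical sketch (it is not part of the notebook's code) that returns the column range to display and the x-axis limits for a given frame, following the same two cases as the function above.

```python
def window_bounds(frame, window_size):
    """Return ((start, stop) of displayed columns, (xmin, xmax) of the x-axis)."""
    current_idx = frame + 1  # 1-indexed time step, as in the plotting function
    if current_idx < window_size:
        # Early frames: show everything so far, but keep the axis width fixed.
        return (0, current_idx), (0, window_size)
    # Later frames: show the most recent window and let the axis follow it.
    start = current_idx - window_size
    return (start, current_idx), (start, current_idx)


for frame in (0, 98, 99, 249):
    print(frame, window_bounds(frame, 100))
```

Separating this logic makes it easy to check the boundary frame (where `current_idx` equals `window_size`) without rendering anything.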

### Part 2.0: Load the saved rollout


In [None]:
# load the rollout data, generated in rollout_from_checkpoint.ipynb
example_rollout_path = (
    "/root/vast/scott-yang/track-mjx/model_checkpoints/250227_200156/rollout_41.h5"
)

rollout_data = load_from_h5py(example_rollout_path)

# render the saved rollout
frames, realtime_framerate = render_rollout(cfg, rollout_data)
display_video(frames, framerate=realtime_framerate)

### Part 2.1: Step with Recorded Actions

The following cell steps through the MuJoCo environment using the actions recorded during the saved rollout. In other words, the for loop re-steps the physics simulator with the recorded actions.

In [None]:
# Initialize state and random key.
rng, env_rng = jax.random.split(rng)
rng, init_key = jax.random.split(rng)

state = jit_env_reset(rng=env_rng, clip_idx=41)

qposes = []

# convert to jnp.array so the input type matches what jit_env_step was compiled with
ctrl = jnp.array(rollout_data["ctrl"])

for action in ctrl:
    state = jit_env_step(state, action)
    qposes.append(state.pipeline_state.qpos)

# render_rollout takes the config and a rollout dict, as in the cells above
frames, render_fps = render_rollout(
    cfg, {"qposes_rollout": jnp.array(qposes)}, render_ghost=False
)
print("Rollout with the recorded actions")
display_video(frames, framerate=100)

In [None]:
# Parameters
# transpose to (features, time), the layout visualize_ego_observation expects
ctrl = np.array(ctrl).T
window_size = 100
total_time = ctrl.shape[1]  # should be 499
title = "Actions"
# Use multiprocessing to parallelize the rendering.
f = functools.partial(visualize_ego_observation, ctrl, window_size, title)
with mp.Pool(processes=mp.cpu_count()) as pool:
    frames_plt_1 = pool.map(f, list(range(total_time)))

In [None]:
combined_video = np.concatenate((frames, frames_plt_1), axis=2)
display_video(combined_video, framerate=30)
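
Concatenating along `axis=2` places the videos side by side, since the frame arrays have shape `(time, height, width, channels)`. This only works when all videos have the same number of frames and the same height. The helper below is a hypothetical sketch (not part of `track-mjx`) that makes these requirements explicit, shown with small dummy arrays.

```python
import numpy as np


def side_by_side(*videos):
    """Concatenate (time, height, width, channels) videos along the width axis."""
    videos = [np.asarray(v) for v in videos]
    heights = {v.shape[1] for v in videos}
    if len(heights) != 1:
        raise ValueError(f"All videos need the same height, got {sorted(heights)}")
    # Trim to the shortest video so the frame counts agree.
    n_frames = min(v.shape[0] for v in videos)
    return np.concatenate([v[:n_frames] for v in videos], axis=2)


# Dummy data: a 5-frame "simulation" video and a 4-frame "plot" video.
sim_video = np.zeros((5, 48, 64, 3), dtype=np.uint8)
plot_video = np.ones((4, 48, 64, 3), dtype=np.uint8)
combined = side_by_side(sim_video, plot_video)
print(combined.shape)
```

Trimming to the shortest video is one possible design choice. Raising an error on mismatched frame counts would be stricter and may be preferable when the mismatch indicates a bug.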

From the video we can see that replaying the recorded actions generates a trajectory somewhat similar to the recorded one, but the rodent eventually falls over. This is because we are not using the same random seed as the recorded rollout, and the simulation is not guaranteed to be deterministic, so small differences accumulate when actions are replayed without feedback. To make the replay reproducible, we would need to fix the random seed for the simulator.

### Part 2.2: Step with Recorded Decoder Inputs

The following cell steps through the MuJoCo environment based on the recorded decoder inputs taken during the saved rollout. In other words, the next for loop re-steps through the physics simulator based on the recorded decoder inputs.

In [None]:
inps = rollout_data["activations"]["decoder_inputs"]
inps = jnp.array(inps)


state = jit_env_reset(rng=env_rng, clip_idx=41)
qposes = []
actual_ego_obs = []

for inp in inps:
    obs = normalizer(state.obs, normalizer_param)
    actual_ego_obs.append(obs[..., cfg["network_config"]["reference_obs_size"] :])
    # index [0] as in Part 0 and Part 1 to get the action-distribution parameters
    logits = jit_apply(decoder_param, inp)[0]
    action = jnp.array(action_distribution.mode(logits))
    state = jit_env_step(state, action)
    qposes.append(state.pipeline_state.qpos)

actual_ego_obs = np.array(actual_ego_obs).T
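# render_rollout takes the config and a rollout dict, as in Part 1
frames_mj, render_fps = render_rollout(
    cfg, {"qposes_rollout": jnp.array(qposes)}, render_ghost=False
)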
print("Rollout videos with decoder inputs")
# display_video(frames_mj, framerate=100)

In [None]:
ego_obs = rollout_data["activations"]["decoder_inputs"][..., 60:].T

# Parameters
window_size = 100
total_time = ego_obs.shape[1]  # should be 499
title = "Recorded EgoObs (Offline EgoObs) (Feed into Decoder)"
# Use multiprocessing to parallelize the rendering.
f = functools.partial(visualize_ego_observation, ego_obs, window_size, title)
with mp.Pool(processes=mp.cpu_count()) as pool:
    frames_plt = pool.map(f, list(range(total_time)))

In [None]:
# Parameters
window_size = 100
total_time = actual_ego_obs.shape[1]  # should be 499
title = "Actual EgoObs (Online EgoObs) (Discarded)"
# Use multiprocessing to parallelize the rendering.
f = functools.partial(visualize_ego_observation, actual_ego_obs, window_size, title)
with mp.Pool(processes=mp.cpu_count()) as pool:
    frames_plt_1 = pool.map(f, list(range(total_time)))

In [None]:
combined_video = np.concatenate((frames_mj, frames_plt, frames_plt_1), axis=2)
display_video(combined_video, framerate=30)

The following cells highlight the features that are related to the gait cycle.

In [None]:
ego_obs = rollout_data["activations"]["decoder_inputs"][..., 60:].T

# Parameters
window_size = 100
total_time = ego_obs.shape[1]  # should be 499
title = "Recorded EgoObs (Offline EgoObs) (Feed into Decoder)"
feature_mask = [75, 100]
# Use multiprocessing to parallelize the rendering.
f = functools.partial(
    visualize_ego_observation, ego_obs, window_size, title, feature_mask=feature_mask
)
with mp.Pool(processes=mp.cpu_count()) as pool:
    frames_plt = pool.map(f, list(range(total_time)))

In [None]:
# Parameters
window_size = 100
total_time = actual_ego_obs.shape[1]  # should be 499
title = "Actual EgoObs (Online EgoObs) (Discarded)"
feature_mask = [75, 100]
# Use multiprocessing to parallelize the rendering.
f = functools.partial(
    visualize_ego_observation,
    actual_ego_obs,
    window_size,
    title,
    feature_mask=feature_mask,
)
with mp.Pool(processes=mp.cpu_count()) as pool:
    frames_plt_1 = pool.map(f, list(range(total_time)))

In [None]:
combined_video = np.concatenate((frames_mj, frames_plt, frames_plt_1), axis=2)
display_video(combined_video, framerate=30)

### Part 2.3: Step with Recorded Intentions

The following cell steps through the MuJoCo environment using the intentions recorded during the saved rollout. In other words, the for loop re-steps the physics simulator with the recorded intentions, while the egocentric observations come from the live simulation.

In [None]:
qposes = []
state = jit_env_reset(rng=env_rng, clip_idx=41)
intentions = jnp.array(rollout_data["activations"]["intention"])
ego_obs = []
actions = []

for intention in intentions:
    obs = normalizer(state.obs, normalizer_param)
    ego_obs.append(obs[..., cfg["network_config"]["reference_obs_size"] :])
    concatenated = jnp.concatenate(
        [intention, obs[..., cfg["network_config"]["reference_obs_size"] :]],
        axis=-1,
    )
    # index [0] as in Part 0 and Part 1 to get the action-distribution parameters
    logits = jit_apply(decoder_param, concatenated)[0]
    action = jnp.array(action_distribution.mode(logits))
    actions.append(action)
    state = jit_env_step(state, action)
    qposes.append(state.pipeline_state.qpos)

actions = np.array(actions).T
ego_obs = np.array(ego_obs).T
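# render_rollout takes the config and a rollout dict, as in Part 1
frames_mj, render_fps = render_rollout(
    cfg, {"qposes_rollout": jnp.array(qposes)}, render_ghost=False
)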
print("Rollout videos with recorded intention")
# display_video(frames_mj, framerate=100)

In [None]:
# ego_obs is a numpy array of shape (features, time), e.g. (147, 499),
# collected in the rollout loop above.

# Parameters
window_size = 100
total_time = ego_obs.shape[1]  # should be 499
title = "Online EgoObs (Feed into Decoder)"
# Use multiprocessing to parallelize the rendering.
f = functools.partial(visualize_ego_observation, ego_obs, window_size, title)
with mp.Pool(processes=mp.cpu_count()) as pool:
    frames_plt = pool.map(f, list(range(total_time)))

In [None]:
# Parameters
window_size = 100
total_time = actions.shape[1]  # should be 499
title = "Actions"
# Use multiprocessing to parallelize the rendering.
f = functools.partial(visualize_ego_observation, actions, window_size, title)
with mp.Pool(processes=mp.cpu_count()) as pool:
    frames_plt_1 = pool.map(f, list(range(total_time)))

In [None]:
combined_video = np.concatenate((frames_mj, frames_plt, frames_plt_1), axis=2)
display_video(combined_video, framerate=30)

The following cells focus on the gait-related features.

In [None]:
# ego_obs is a numpy array of shape (features, time), e.g. (147, 499),
# collected in the rollout loop above.

# Parameters
window_size = 100
total_time = ego_obs.shape[1]  # should be 499
title = "Online EgoObs (Feed into Decoder)"
feature_mask = [75, 100]
# Use multiprocessing to parallelize the rendering.
f = functools.partial(
    visualize_ego_observation, ego_obs, window_size, title, feature_mask=feature_mask
)
with mp.Pool(processes=mp.cpu_count()) as pool:
    frames_plt = pool.map(f, list(range(total_time)))

In [None]:
combined_video = np.concatenate((frames_mj, frames_plt), axis=2)
display_video(combined_video, framerate=30)

This is remarkable! The rodent is able to perform the locomotion task more stably than the recorded actions and decoder inputs. This is because, with the sensory feedback loop provided by egocentric observations fed back into the decoder, the low-level controller can continuously adjust locomotion based on real-time sensory feedback. This highlights the importance of the sensorimotor loop.
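
The effect of the feedback loop can be illustrated with a toy system that has nothing to do with the rodent model. The sketch below assumes an unstable scalar system `x_next = A * x + u` and a simple feedback rule `u = -K * x` (both made up for illustration). Replaying recorded actions from a slightly different start diverges, while recomputing actions from the live state stays stable.

```python
A = 1.1      # open-loop growth factor; values above 1 make the system unstable
K = 0.6      # feedback gain; the closed loop then contracts by A - K = 0.5
STEPS = 100


def feedback_action(x):
    """Action computed from the current state (the 'sensorimotor loop')."""
    return -K * x


def step(x, u):
    """Toy dynamics standing in for the physics simulator."""
    return A * x + u


# Record a rollout with feedback, starting from x = 1.0.
x = 1.0
recorded_actions = []
for _ in range(STEPS):
    u = feedback_action(x)
    recorded_actions.append(u)
    x = step(x, u)
recorded_final = x

# Open loop: replay the recorded actions from a slightly perturbed start.
x = 1.01
for u in recorded_actions:
    x = step(x, u)
open_loop_final = x

# Closed loop: recompute each action from the live state.
x = 1.01
for _ in range(STEPS):
    x = step(x, feedback_action(x))
closed_loop_final = x

print(abs(open_loop_final - recorded_final), abs(closed_loop_final))
```

In the open-loop replay the initial error of 0.01 is multiplied by `A` at every step, whereas the closed loop shrinks it. This mirrors, in a much simpler setting, the contrast between Part 2.1 and Part 2.3.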

### Bonus: Step with concatenated intentions

In [None]:
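# render_rollout takes the config and a rollout dict, as in Part 1
frames, render_fps = render_rollout(
    cfg, {"qposes_rollout": jnp.array(qposes)[130:230]}, render_ghost=False
)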
print("Rollout videos with recorded intention from 130-230")
display_video(frames, framerate=100)

In [None]:
qposes = []
state = jit_env_reset(rng=env_rng, clip_idx=41)


intentions = jnp.array(rollout_data["activations"]["intention"])[130:230]
ego_obs = []

for _ in range(20):
    for intention in intentions:
        obs = normalizer(state.obs, normalizer_param)
        ego_obs.append(obs[..., cfg["network_config"]["reference_obs_size"] :])
        concatenated = jnp.concatenate(
            [intention, obs[..., cfg["network_config"]["reference_obs_size"] :]],
            axis=-1,
        )
        # index [0] as in Part 0 and Part 1 to get the action-distribution parameters
        logits = jit_apply(decoder_param, concatenated)[0]
        action = jnp.array(action_distribution.mode(logits))
        state = jit_env_step(state, action)
        qposes.append(state.pipeline_state.qpos)

ego_obs = np.array(ego_obs).T
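# render_rollout takes the config and a rollout dict, as in Part 1
frames_mj, render_fps = render_rollout(
    cfg, {"qposes_rollout": jnp.array(qposes)}, render_ghost=False
)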
print("Rollout videos with 20 repeats of recorded intention from 130 to 230")
# display_video(frames_mj, framerate=100)
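
The nested loop above repeats the same intention segment 20 times. An equivalent way to build the repeated sequence up front is to tile the segment along the time axis, which leaves a single flat loop over intentions. The NumPy sketch below uses toy data with made-up sizes.

```python
import numpy as np

# Toy intentions with shape (time, intention_size).
intentions = np.arange(12, dtype=np.float32).reshape(4, 3)

# Select a segment and repeat it along the time axis.
segment = intentions[1:3]
n_repeats = 5
tiled = np.tile(segment, (n_repeats, 1))
print(tiled.shape)
```

With the real data this would correspond to tiling `intentions[130:230]` 20 times, giving 2000 time steps.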

In [None]:
# ego_obs was already transposed to shape (features, time) in the previous cell,
# so it must not be transposed again here.
# Parameters
window_size = 100
total_time = ego_obs.shape[1]  # 20 repeats x 100 steps = 2000
title = "Online EgoObs (Feed into Decoder)"
# feature_mask = [75, 100]
# Use multiprocessing to parallelize the rendering.
f = functools.partial(visualize_ego_observation, ego_obs, window_size, title)
with mp.Pool(processes=mp.cpu_count()) as pool:
    frames_plt = pool.map(f, list(range(total_time)))

In [None]:
combined_video = np.concatenate((frames_mj, frames_plt), axis=2)
display_video(combined_video, framerate=30)

The following cells focus on the gait-related features.

In [None]:
# ego_obs is a numpy array of shape (features, time), computed above.
# Parameters
window_size = 100
total_time = ego_obs.shape[1]  # 20 repeats x 100 steps = 2000
title = "Online EgoObs (Feed into Decoder)"
feature_mask = [75, 100]
# Use multiprocessing to parallelize the rendering.
f = functools.partial(
    visualize_ego_observation, ego_obs, window_size, title, feature_mask=feature_mask
)
with mp.Pool(processes=mp.cpu_count()) as pool:
    frames_plt = pool.map(f, list(range(total_time)))

In [None]:
combined_video = np.concatenate((frames_mj, frames_plt), axis=2)
display_video(combined_video, framerate=30)

## Misc. Useful Utils -- Param Shape Comparison

The following cell compares the parameter structure of two parameter sets: here, the parameters of a freshly initialized decoder and those loaded from the checkpoint. The parameters are `flax.linen` model parameters, which are usually `pytree` objects. When reconstructing the model and aligning it with the submodule we created, this utility helps verify that the checkpoint parameters have the same structure as the created model and can be applied to it.

In [None]:
import jax
import jax.numpy as jnp
# same import path as in the imports cell at the top of the notebook
from track_mjx.agent.mlp_ppo.intention_network import Decoder


def compare_param_structure(params1, params2, path=""):
    """
    Recursively compare the structure of two parameter dictionaries.
    Returns a list of strings describing any key or shape differences.
    """
    differences = []
    # Compare keys in params1 against params2
    for key in params1:
        key_path = f"{path}/{key}" if path else key
        if key not in params2:
            differences.append(
                f"Key '{key_path}' exists in init params but not in checkpoint params."
            )
        else:
            val1 = params1[key]
            val2 = params2[key]
            # If both values are dicts, compare recursively.
            if isinstance(val1, dict) and isinstance(val2, dict):
                differences.extend(compare_param_structure(val1, val2, key_path))
            else:
                # Compare shape if the object has a shape attribute (e.g., a JAX array).
                shape1 = getattr(val1, "shape", None)
                shape2 = getattr(val2, "shape", None)
                if shape1 != shape2:
                    differences.append(
                        f"Key '{key_path}' has shape {shape1} in init params vs {shape2} in checkpoint params."
                    )
    # Check for keys in params2 that are not in params1.
    for key in params2:
        key_path = f"{path}/{key}" if path else key
        if key not in params1:
            differences.append(
                f"Key '{key_path}' exists in checkpoint params but not in init params."
            )
    return differences


# Assume your configuration dictionary is available as cfg.
network_config = cfg["network_config"]

# Create the Decoder module.
decoder = Decoder(
    network_config["decoder_layer_sizes"] + [network_config["action_size"] * 2]
)

# IMPORTANT: For Dense layers, make sure to include a batch dimension.
input_shape = (1, 60 + 147)
dummy_input = jnp.zeros(input_shape)

# Initialize the decoder's parameters.
init_variables = decoder.init(jax.random.PRNGKey(0), dummy_input)
params_init = init_variables["params"]

# Assume decoder_param holds your checkpoint parameters.
# If checkpoint parameters are wrapped in a 'params' key, extract them.
params_checkpoint = decoder_raw
if "params" in params_checkpoint:
    params_checkpoint = params_checkpoint["params"]

# Compare the structure of the two parameter dictionaries.
structure_diffs = compare_param_structure(params_init, params_checkpoint)
if structure_diffs:
    print("Structure differences between init params and checkpoint params:")
    for diff in structure_diffs:
        print(" -", diff)
else:
    print("No structural differences found between the parameter dictionaries.")
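
An alternative to the recursive comparison is to flatten each parameter tree into a `{path: shape}` dictionary first and then compare the two flat dictionaries with set operations. The sketch below uses NumPy arrays and made-up layer names (`layers_0`, `layers_1`) as stand-ins for real decoder parameters.

```python
import numpy as np


def flatten_shapes(params, prefix=""):
    """Map each leaf's path (joined with '/') to its shape."""
    shapes = {}
    for key, value in params.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            shapes.update(flatten_shapes(value, path))
        else:
            shapes[path] = getattr(value, "shape", None)
    return shapes


# Hypothetical parameter trees, for illustration only.
init_params = {
    "layers_0": {"kernel": np.zeros((207, 512)), "bias": np.zeros(512)},
}
checkpoint_params = {
    "layers_0": {"kernel": np.zeros((207, 1024)), "bias": np.zeros(512)},
    "layers_1": {"kernel": np.zeros((1024, 76))},
}

init_shapes = flatten_shapes(init_params)
checkpoint_shapes = flatten_shapes(checkpoint_params)

# Paths present in only one tree, and shared paths whose shapes differ.
only_in_checkpoint = sorted(set(checkpoint_shapes) - set(init_shapes))
only_in_init = sorted(set(init_shapes) - set(checkpoint_shapes))
mismatched = sorted(
    path
    for path in set(init_shapes) & set(checkpoint_shapes)
    if init_shapes[path] != checkpoint_shapes[path]
)
print(only_in_checkpoint, only_in_init, mismatched)
```

The flat dictionaries are also convenient for printing a compact summary of every parameter shape in a checkpoint.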