In [1]:
import orbax.checkpoint as ocp
from track_mjx.agent.custom_ppo import TrainingState
from track_mjx.agent import custom_losses as ppo_losses
from track_mjx.agent import custom_ppo, custom_ppo_networks
import jax
from jax import numpy as jnp
import optax
from brax import envs
from brax.training.acme import running_statistics, specs
from track_mjx.environment import custom_wrappers

from track_mjx.environment.task.multi_clip_tracking import MultiClipTracking
from track_mjx.environment.task.single_clip_tracking import SingleClipTracking
from track_mjx.environment.task.reward import RewardConfig

from track_mjx.environment.walker.rodent import Rodent
import hydra
from hydra import initialize, compose
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf
import pickle
import functools

# Clear any Hydra state left over from a previous run of this cell so that
# `initialize` can be called again, then compose the experiment config.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
initialize(version_base=None, config_path="./../track_mjx/config")
cfg = compose(config_name="rodent-full-clips")

# Make the tracking tasks available to `envs.get_environment` by name.
for registered_name, task_cls in (
    ("rodent_single_clip", SingleClipTracking),
    ("rodent_multi_clip", MultiClipTracking),
    ("fly_multi_clip", MultiClipTracking),
):
    envs.register_environment(registered_name, task_cls)

# Fully resolved plain-container copy of the config, plus handles to the sections used below.
env_cfg = OmegaConf.to_container(cfg, resolve=True)
env_args = cfg.env_config.env_args
env_rewards = cfg.env_config.reward_weights
train_config = cfg.train_setup.train_config
walker_config = cfg.walker_config
traj_config = cfg.reference_config

# Create the `TrainingState`

Currently this cell is copied directly from the training script. In the future, for the sake of runtime and cleaner code, we can add a helper function that constructs the abstract `TrainingState` directly, without creating the actual environment. This is on Scott's TODO list.

In [None]:
# TODO(Scott): move this to track_mjx.io module
input_data_path = "/root/vast/scott-yang/track-mjx/data/ReferenceClip.p"
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open(input_data_path, "rb") as file:
    reference_clip = pickle.load(file)

# Walker classes selectable through `env_cfg["walker_type"]`.
walker_map = {
    "rodent": Rodent,
}

walker_class = walker_map[env_cfg["walker_type"]]
walker = walker_class(**walker_config)

# Not built with **kwargs because penalty_pos_distance_scale needs converting to a jnp array.
reward_config = RewardConfig(
    too_far_dist=env_rewards.too_far_dist,
    bad_pose_dist=env_rewards.bad_pose_dist,
    bad_quat_dist=env_rewards.bad_quat_dist,
    ctrl_cost_weight=env_rewards.ctrl_cost_weight,
    ctrl_diff_cost_weight=env_rewards.ctrl_diff_cost_weight,
    pos_reward_weight=env_rewards.pos_reward_weight,
    quat_reward_weight=env_rewards.quat_reward_weight,
    joint_reward_weight=env_rewards.joint_reward_weight,
    angvel_reward_weight=env_rewards.angvel_reward_weight,
    bodypos_reward_weight=env_rewards.bodypos_reward_weight,
    endeff_reward_weight=env_rewards.endeff_reward_weight,
    healthy_z_range=env_rewards.healthy_z_range,
    pos_reward_exp_scale=env_rewards.pos_reward_exp_scale,
    quat_reward_exp_scale=env_rewards.quat_reward_exp_scale,
    joint_reward_exp_scale=env_rewards.joint_reward_exp_scale,
    angvel_reward_exp_scale=env_rewards.angvel_reward_exp_scale,
    bodypos_reward_exp_scale=env_rewards.bodypos_reward_exp_scale,
    endeff_reward_exp_scale=env_rewards.endeff_reward_exp_scale,
    penalty_pos_distance_scale=jnp.array(env_rewards.penalty_pos_distance_scale),
)

# env_args and the reference config are splatted in as keyword arguments.
environment = envs.get_environment(
    env_name=cfg.env_config.env_name,
    reference_clip=reference_clip,
    walker=walker,
    reward_config=reward_config,
    **env_args,
    **traj_config,
)
# Episode length is equal to (clip length - random init range - traj length) * steps per cur frame.
# Will work on not hardcoding these values later
episode_length = (
    traj_config.clip_length - traj_config.random_init_range - traj_config.traj_length
) * environment._steps_for_cur_frame
print(f"episode_length {episode_length}")

network_factory = functools.partial(
    custom_ppo_networks.make_intention_ppo_networks,
    encoder_hidden_layer_sizes=tuple(cfg.network_config.encoder_layer_sizes),
    decoder_hidden_layer_sizes=tuple(cfg.network_config.decoder_layer_sizes),
    value_hidden_layer_sizes=tuple(cfg.network_config.critic_layer_sizes),
)

seed = 42

key = jax.random.PRNGKey(seed)
global_key, local_key = jax.random.split(key)
local_key, key_env, eval_key = jax.random.split(local_key, 3)
# key_networks should be global, so that networks are initialized the same
# way for different processes.
key_policy, key_value, policy_params_fn_key = jax.random.split(global_key, 3)

# No randomization function is applied to the training env here.
v_randomization_fn = None
# The training script picks `custom_wrappers.wrap` for every environment type
# (both arms of its isinstance check were identical), so use it directly.
wrap_for_training = custom_wrappers.wrap

train_env = wrap_for_training(
    environment,
    episode_length=episode_length,
    action_repeat=1,
    randomization_fn=v_randomization_fn,
)

# Reset a batch of one environment; its state supplies the observation sizes below.
reset_fn = train_env.reset
key_envs = jax.random.split(key_env, 1)
env_state = reset_fn(key_envs)

# Observation normalization is always on (the copied toggle was hard-wired to `if True`).
normalize = running_statistics.normalize
ppo_network = network_factory(
    env_state.obs.shape[-1],
    int(env_state.info["reference_obs_size"][0]),
    train_env.action_size,
    preprocess_observations_fn=normalize,
)
make_policy = custom_ppo_networks.make_inference_fn(ppo_network)

# NOTE(review): this optimizer is only used to build a restore target; presumably the
# state's pytree structure does not depend on the learning rate -- confirm before reuse.
optimizer = optax.adam(learning_rate=1e-4)

init_params = ppo_losses.PPONetworkParams(
    policy=ppo_network.policy_network.init(key_policy),
    value=ppo_network.value_network.init(key_value),
)

training_state = TrainingState(  # pytype: disable=wrong-arg-types  # jax-ndarray
    optimizer_state=optimizer.init(init_params),  # pytype: disable=wrong-arg-types  # numpy-scalars
    params=init_params,
    normalizer_params=running_statistics.init_state(specs.Array(env_state.obs.shape[-1:], jnp.dtype("float32"))),
    env_steps=0,
)

# Restore target for the "policy" checkpoint item: (normalizer params, policy params).
abstract_policy = (training_state.normalizer_params, training_state.params.policy)

# Orbax Checkpoint Related

The following code creates the checkpoint manager and restores both the policy and the training state from the latest checkpoint.

In [None]:
# Restore the whole model (both the policy module and the training state) from the latest step.
CHECKPOINT_DIR = "/root/vast/scott-yang/track-mjx/model_checkpoints/rodent_data/ReferenceClip.p_250116_073049/"
options = ocp.CheckpointManagerOptions(step_prefix="PPONetwork")
with ocp.CheckpointManager(CHECKPOINT_DIR, options=options) as mngr:
    # Query once so both restores are guaranteed to read the same step.
    latest_step = mngr.latest_step()
    if latest_step is None:
        # latest_step() returns None when no checkpoint has been saved under this directory.
        raise FileNotFoundError(f"No 'PPONetwork' checkpoints found in {CHECKPOINT_DIR}")
    print(f"latest checkpoint step: {latest_step}")
    policy = mngr.restore(
        latest_step, args=ocp.args.Composite(policy=ocp.args.StandardRestore(abstract_policy))
    )["policy"]
    train_state = mngr.restore(
        latest_step, args=ocp.args.Composite(train_state=ocp.args.StandardRestore(training_state))
    )["train_state"]

# Get Internal Activations During Rollout

In [4]:
import numpy as np

# Fixed seed so this setup is reproducible across kernel restarts.
rollout_key = jax.random.PRNGKey(42)

# Config sections read by the rollout helpers below.
ref_trak_config = cfg["reference_config"]
env_config = cfg["env_config"]
walker_config = cfg["walker_config"]

# Wrap the raw environment for rendering-style rollouts.
# NOTE(review): confirm exactly which wrappers RenderRolloutWrapperTracking applies.
rollout_env = custom_wrappers.RenderRolloutWrapperTracking(environment)

# Compile reset/step once; every rollout below reuses the compiled functions.
jit_reset, jit_step = jax.jit(rollout_env.reset), jax.jit(rollout_env.step)

# Deterministic policy that also returns its internal activations.
jit_inference_fn = jax.jit(make_policy(policy, deterministic=True, get_activation=True))
rollout_key, reset_rng, act_rng = jax.random.split(rollout_key, 3)

In [5]:
def process_state_to_save(state):
    """Flatten a rollout state into one dict of saveable fields.

    Reads ``obs``, ``reward``, ``done``, ``metrics`` and ``info`` from ``state``
    itself, then a fixed list of simulator fields from ``state.pipeline_state``.
    The returned dict maps each field name to its value, top-level fields first,
    followed by the pipeline-state fields in list order.
    """
    top_level_fields = ("obs", "reward", "done", "metrics", "info")
    pipeline_fields = (
        "act", "act_dot",
        "actuator_force", "actuator_length", "actuator_moment", "actuator_velocity",
        "bvh_aabb_dyn", "bvh_active",
        "cacc", "cam_xmat", "cam_xpos", "cdof", "cdof_dot", "cfrc_ext", "cfrc_int",
        "cinert", "crb", "ctrl", "cvel",
        "efc_D", "efc_J", "efc_aref", "efc_force", "efc_frictionloss",
        "efc_margin", "efc_pos", "efc_type",
        "eq_active",
        "flexedge_J", "flexedge_J_colind", "flexedge_J_rowadr", "flexedge_J_rownnz",
        "flexedge_length", "flexedge_velocity", "flexelem_aabb", "flexvert_xpos",
        "geom_xmat", "geom_xpos",
        "light_xdir", "light_xpos",
        "mocap_pos", "mocap_quat",
        "ncon", "ne", "nefc", "nf", "nl",
        "q", "qDeriv", "qH", "qHDiagInv", "qLD", "qLDiagInv", "qLDiagSqrtInv", "qLU", "qM",
        "qacc", "qacc_smooth", "qacc_warmstart", "qd",
        "qfrc_actuator", "qfrc_applied", "qfrc_bias", "qfrc_constraint", "qfrc_damper",
        "qfrc_fluid", "qfrc_gravcomp", "qfrc_inverse", "qfrc_passive", "qfrc_smooth",
        "qfrc_spring",
        "qpos", "qvel",
        "sensordata", "site_xmat", "site_xpos", "solver_niter",
        "subtree_angmom", "subtree_com", "subtree_linvel",
        "ten_J", "ten_J_colind", "ten_J_rowadr", "ten_J_rownnz",
        "ten_length", "ten_velocity", "ten_wrapadr", "ten_wrapnum",
        "time", "userdata", "wrap_obj", "wrap_xpos",
        "xanchor", "xaxis", "xfrc_applied", "ximat", "xipos", "xmat", "xpos", "xquat",
    )
    physics = state.pipeline_state
    flattened = {name: getattr(state, name) for name in top_level_fields}
    flattened.update((name, getattr(physics, name)) for name in pipeline_fields)
    return flattened

In [6]:
from typing import List, Dict, Tuple


def generate_rollout(clip_idx: int | None = None, seed: int = 42) -> Dict:
    """
    Generate a rollout for a given clip index with the loaded checkpoint.

    Relies on the notebook-level ``jit_reset``, ``jit_step``, ``jit_inference_fn``,
    ``rollout_env``, ``environment`` and ``ref_trak_config``.

    Args:
        clip_idx (int | None): Index of the reference clip to roll out; forwarded to
            ``jit_reset``. If None, the environment chooses the clip (per the original
            notes, a random one).
        seed (int): Seed for the reset/action PRNG keys. Defaults to 42, the value that
            was previously hard-coded, so default calls behave exactly as before.

    Returns:
        Dict: rollout data with every leaf converted to a NumPy array:
            "rewards": dict of per-step reward metrics, tracking distances and torso heights,
            "observations": per-step observations,
            "ctrl": per-step actions produced by the policy,
            "activations": per-step ``extras["activations"]`` from the policy,
            "qposes_ref": reference qpos (position, quaternion, joints) with each
                reference frame repeated ``environment._steps_for_cur_frame`` times,
            "qposes_rollout": simulated ``qpos`` at every step,
            "info": per-step ``state.info``.
        The state-derived entries include the reset state, so they have one more
        element than "ctrl" and "activations".
    """
    rollout_key = jax.random.PRNGKey(seed)
    rollout_key, reset_rng, act_rng = jax.random.split(rollout_key, 3)
    # do a rollout on the saved model
    state = jit_reset(reset_rng, clip_idx=clip_idx)

    # NOTE(review): the rollout stops one step short of clip_length * steps-per-frame;
    # why (and what the observations are for the last few steps) is still unconfirmed.
    n_steps = int(ref_trak_config.clip_length * environment._steps_for_cur_frame) - 1

    rollout_states = [state]
    ctrls, activations = [], []
    for _ in range(n_steps):
        # Fresh action key each step (the first half of the split is discarded).
        _, act_rng = jax.random.split(act_rng)
        ctrl, extras = jit_inference_fn(state.obs, act_rng)
        state = jit_step(state, ctrl)
        rollout_states.append(state)
        ctrls.append(ctrl)
        activations.append(extras["activations"])

    # Per-step reward terms (output key -> `state.metrics` key); might also be shown in the rendering.
    metric_keys = {
        "pos_rewards": "pos_reward",
        "endeff_rewards": "endeff_reward",
        "quat_rewards": "quat_reward",
        "angvel_rewards": "angvel_reward",
        "bodypos_rewards": "bodypos_reward",
        "joint_rewards": "joint_reward",
    }
    rewards = {
        out_key: [s.metrics[metric_key] for s in rollout_states]
        for out_key, metric_key in metric_keys.items()
    }
    rewards["summed_pos_distances"] = [s.info["summed_pos_distance"] for s in rollout_states]
    rewards["joint_distances"] = [s.info["joint_distance"] for s in rollout_states]
    torso_idx = environment.walker._torso_idx
    # Component 2 of the torso body's xpos (the height, per the key name).
    rewards["torso_heights"] = [s.pipeline_state.xpos[torso_idx][2] for s in rollout_states]

    # get qposes for both rollout and reference
    ref_traj = rollout_env._get_reference_clip(rollout_states[0].info)
    qposes_ref = np.repeat(
        np.hstack([ref_traj.position, ref_traj.quaternion, ref_traj.joints]),
        environment._steps_for_cur_frame,
        axis=0,
    )
    qposes_rollout = np.array([s.pipeline_state.qpos for s in rollout_states])
    # Full simulator states are not included; process_state_to_save can extract them if needed.
    output = {
        "rewards": rewards,
        "observations": [s.obs for s in rollout_states],
        "ctrl": ctrls,
        "activations": activations,
        "qposes_ref": qposes_ref,
        "qposes_rollout": qposes_rollout,
        "info": [s.info for s in rollout_states],
    }
    # Convert every leaf (device arrays included) to a NumPy array.
    return jax.tree.map(lambda x: np.array(x), output)

# Save to HDF5 File

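Since the rollout output is a pytree (nested dictionaries and lists of arrays), we need to walk the tree recursively and write each leaf as an HDF5 dataset, using nested group paths for dictionary keys and list indices.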

In [None]:
from tqdm import tqdm

# Quick sanity-check rollout on a single clip.
output = generate_rollout(clip_idx=32)

To do analysis, we need to save the observations, info, metrics, done flags, and rewards from the state, as well as the MuJoCo variables in the pipeline state. Let's save the complete list of items.

We also need to serialize the clip information so that the video rendering is directly reproducible.

# TODO

Run the rollout for all of the clips.

In [11]:
import h5py
from tqdm import tqdm


def save_to_h5py(file, data, group_path="/"):
    """
    Save a pytree (nested dicts / lists / tuples with array-like leaves) into an HDF5 file.

    Dict keys and sequence indices become path components, so ``{"a": [x, y]}``
    under the root is written as datasets ``/a/0`` and ``/a/1``. Sequences are
    stored by index, so ``load_from_h5py`` reads tuples back as lists. Empty
    dicts/lists write nothing.

    Args:
        file (h5py.File): An open, writable HDF5 file object (only ``create_dataset`` is used).
        data: The data to save: dict, list, tuple, Python scalar, string, NumPy
            array or NumPy scalar, or any object with a ``.numpy()`` method
            (e.g. PyTorch tensors).
        group_path (str): The HDF5 group path under which to save the data.

    Raises:
        TypeError: If a leaf has an unsupported type (e.g. ``None``).
    """
    # Drop trailing slashes so the root "/" yields "/key" instead of "//key".
    prefix = group_path.rstrip("/")
    if isinstance(data, dict):
        for key, value in data.items():
            save_to_h5py(file, value, f"{prefix}/{key}")
    elif isinstance(data, (list, tuple)):
        for i, item in enumerate(data):
            save_to_h5py(file, item, f"{prefix}/{i}")
    elif isinstance(data, (int, float, str, bool, np.ndarray, np.generic)):
        # np.generic covers NumPy scalars such as np.float32 / np.bool_, which are
        # not subclasses of the Python builtins.
        file.create_dataset(group_path, data=data)
    elif hasattr(data, "numpy"):  # e.g. PyTorch tensors
        file.create_dataset(group_path, data=data.numpy())
    else:
        raise TypeError(f"Unsupported data type: {type(data)}")

In [None]:
from IPython.display import clear_output

from pathlib import Path

# One HDF5 file per clip (clip_<idx>.h5) is written here; create the directory
# up front so h5py.File does not fail on a fresh machine.
ROLLOUT_DIR = Path("/root/vast/scott-yang/rodent_rollout_info")
ROLLOUT_DIR.mkdir(parents=True, exist_ok=True)

for i in tqdm(range(environment._n_clips)):
    output = generate_rollout(i)
    with h5py.File(str(ROLLOUT_DIR / f"clip_{i}.h5"), "w") as h5file:
        save_to_h5py(h5file, output)
    # Keep the cell output from growing with every iteration.
    clear_output(wait=True)

In [28]:
def load_from_h5py(file, group_path="/"):
    """
    Rebuild a pytree from an HDF5 file, mirroring the layout used by ``save_to_h5py``.

    A dataset is returned as its stored value. A group whose member names are
    all digit strings becomes a list ordered numerically by name (note: an empty
    group therefore loads as ``[]``); any other group becomes a dict keyed by
    member name.

    Args:
        file (h5py.File): An open HDF5 file object.
        group_path (str): The HDF5 path to read from.

    Returns:
        The reconstructed data structure.

    Raises:
        TypeError: If the object at ``group_path`` is neither a dataset nor a group.
    """
    node = file[group_path]
    if isinstance(node, h5py.Dataset):
        # NOTE(review): string datasets may come back as bytes depending on the h5py version.
        return node[()]
    if not isinstance(node, h5py.Group):
        raise TypeError(f"Unsupported group type: {type(node)}")

    member_names = list(node.keys())
    if all(name.isdigit() for name in member_names):
        ordered = sorted(member_names, key=int)
        return [load_from_h5py(file, f"{group_path}/{name}") for name in ordered]
    return {name: load_from_h5py(file, f"{group_path}/{name}") for name in member_names}


# Example usage: read back one of the per-clip files written by the save loop
# (the previous "clip1_rollout.h5" is never produced by this notebook).
with h5py.File("/root/vast/scott-yang/rodent_rollout_info/clip_0.h5", "r") as h5file:
    loaded_data = load_from_h5py(h5file)

# Analysis

In [None]:
# Roll out the first 50 clips for the intention analysis below.
# `generate_rollout` returns a dict, so unpacking it into three names raised a
# ValueError; pick the pieces out by key instead. Each per-step activation is
# wrapped as {"activations": ...} to keep the `extrases[j]["activations"]`
# indexing used downstream.
rollout_info = {}
for i in tqdm(range(environment._n_clips)[:50]):
    rollout = generate_rollout(i)
    rollout_info[i] = {
        "ctrls": rollout["ctrl"],
        "extrases": [{"activations": act} for act in rollout["activations"]],
        "rewards": rollout["rewards"],
    }

In [48]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

In [32]:
# Stack every step's intention activation across the analysed clips, remembering
# which clip each step came from (used to colour the plots below).
intentions, clips = [], []
for clip_id in sorted(rollout_info):
    rollout = rollout_info[clip_id]["extrases"]
    # Use at most episode_length steps, but never index past the end of a shorter rollout.
    n_steps = min(int(episode_length), len(rollout))
    for j in range(n_steps):
        intentions.append(rollout[j]["activations"]["intention"])
        clips.append(clip_id)
intentions, clips = np.array(intentions), np.array(clips)

In [None]:
# Fit PCA on all collected intention vectors (PCA.fit returns the fitted estimator).
pca = PCA().fit(intentions)
# Cumulative fraction of variance explained by the first 10 components.
print(np.cumsum(pca.explained_variance_ratio_[:10]))
pca_embedded = pca.transform(intentions)

In [45]:
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'

In [None]:
# Every intention vector projected onto the first two PCs, coloured by clip index.
fig, ax = plt.subplots()
points = ax.scatter(pca_embedded[:, 0], pca_embedded[:, 1], c=clips, cmap="tab20", alpha=0.5)
ax.set(
    xlabel=f"PC 1 ({pca.explained_variance_ratio_[0] * 100:.1f}%)",
    ylabel=f"PC 2 ({pca.explained_variance_ratio_[1] * 100:.1f}%)",
    title="Intention PCA, coloured by clip",
)
fig.colorbar(points, ax=ax, label="clip index");

In [None]:
clip_ind = 0

# Project this clip's intention trajectory onto the PCs fitted on all clips.
is_clip = clips == clip_ind
clip_intentions = intentions[is_clip]
clip_projections = pca.transform(clip_intentions)

fig, ax = plt.subplots(figsize=(12, 4))
# Plot up to the first 4 PCs (fewer if the intention space is smaller).
for pc_ind in range(min(4, clip_projections.shape[1])):
    ax.plot(clip_projections[:, pc_ind], label=f"PC {pc_ind} ({pca.explained_variance_ratio_[pc_ind]*100:.1f}%)")

ax.legend()
ax.set(xlabel="Time (steps)", ylabel="Projection", title=f"Clip {clip_ind}: leading intention PCs over time");

# Rendering

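Plan: show the rendered video on the left, alongside the PCA projection of the intention trajectory as the rollout unfolds.

Walker type: TODO — support walkers other than `rodent` (currently the only entry in `walker_map`).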

In [62]:
import pandas as pd
import seaborn as sns

In [None]:
clip_ind = 0

is_clip = clips == clip_ind
clip_intentions = intentions[is_clip]
clip_projections = pca.transform(clip_intentions)

# Keep at most 10 PCs (fewer if the intention space has fewer components).
n_pcs = min(10, clip_projections.shape[1])
clip_df = pd.DataFrame(clip_projections[:, :n_pcs], columns=[f"PC {i+1}" for i in range(n_pcs)])
clip_df["time"] = np.arange(clip_projections.shape[0])

sns.set_context("talk")
fig, ax = plt.subplots(figsize=(6, 6))
# seaborn's defaults sort points by x and aggregate repeated x values, which would
# scramble the trajectory; sort=False / estimator=None connect points in time order.
sns.lineplot(data=clip_df, x="PC 1", y="PC 2", sort=False, estimator=None, ax=ax)
ax.set_title(f"Clip {clip_ind}: intention trajectory in PC space");

In [None]:
# `embedded` is not defined by any earlier cell, so fit the 2-D t-SNE embedding here.
tsne = TSNE(n_components=2, random_state=42, n_jobs=-1, perplexity=30)
embedded = tsne.fit_transform(intentions)

fig, ax = plt.subplots()
points = ax.scatter(embedded[:, 0], embedded[:, 1], c=clips, cmap="tab20", alpha=0.5)
ax.set(xlabel="t-SNE 1", ylabel="t-SNE 2", title="Intention t-SNE, coloured by clip")
fig.colorbar(points, ax=ax, label="clip index");

# rollout save

Save rollouts with `xarray` / `hdf5`, including everything (states, observations, …). The dataset (one data point per session) should be modular.

Save one HDF5 file per clip.