# Imports & Initializer

In [3]:
import sys
from pathlib import Path

# Put the parent of the current working directory on the import path
# (presumably so the `track_mjx` package imported below can be found).
main_path = Path().resolve().parent
main_path_str = str(main_path)
if main_path_str not in sys.path:
    sys.path.append(main_path_str)
    
import orbax.checkpoint as ocp
from track_mjx.agent.custom_ppo import TrainingState
from track_mjx.agent import custom_losses as ppo_losses
from track_mjx.agent import custom_ppo, custom_ppo_networks
import jax
from jax import numpy as jnp
import optax
from brax import envs
from brax.training.acme import running_statistics, specs
from track_mjx.environment import custom_wrappers

from track_mjx.environment.task.multi_clip_tracking import MultiClipTracking
from track_mjx.environment.task.single_clip_tracking import SingleClipTracking
from track_mjx.environment.task.reward import RewardConfig

from track_mjx.environment.walker.rodent import Rodent
import hydra
from hydra import initialize, compose
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf
import pickle
import functools

# Reset Hydra's process-wide state if a previous run of this cell already
# initialized it, then initialize it against the project's config directory.
global_hydra = GlobalHydra.instance()
if global_hydra.is_initialized():
    global_hydra.clear()
initialize(version_base=None, config_path="./../track_mjx/config")

# Load the config file
cfg = compose(config_name="rodent-full-clips")

# Register the task classes with Brax under the names used to look them up.
for task_name, task_class in (
    ("rodent_single_clip", SingleClipTracking),
    ("rodent_multi_clip", MultiClipTracking),
    ("fly_multi_clip", MultiClipTracking),
):
    envs.register_environment(task_name, task_class)

# Plain-container copy of the full config with interpolations resolved.
env_cfg = OmegaConf.to_container(cfg, resolve=True)

# Handles to the config sub-sections consumed by the cells below.
env_args = cfg.env_config["env_args"]
env_rewards = cfg.env_config["reward_weights"]
train_config = cfg.train_setup["train_config"]
walker_config = cfg["walker_config"]
traj_config = cfg["reference_config"]

# Create the `TrainingState`

Currently, I copy and paste the setup directly from the training code. In the future, for the sake of runtime and cleaner code, we can add a helper function that constructs the abstract `TrainingState` directly, without creating the actual environment. This is on Scott's TODO list.

In [4]:
# TODO(Scott): move this to track_mjx.io module
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
input_data_path = "/root/vast/scott-yang/track-mjx/data/ReferenceClip.p"
with Path(input_data_path).open("rb") as file:
    reference_clip = pickle.load(file)

# Supported walker classes, keyed by the `walker_type` entry of the config.
walker_map = {"rodent": Rodent}

# Look up the configured walker class and build it from its config section.
walker_class = walker_map[env_cfg["walker_type"]]
walker = walker_class(**walker_config)

# The reward section is not splatted directly because
# `penalty_pos_distance_scale` has to be converted to a JAX array first;
# every other field is forwarded unchanged.
reward_field_names = (
    "too_far_dist",
    "bad_pose_dist",
    "bad_quat_dist",
    "ctrl_cost_weight",
    "ctrl_diff_cost_weight",
    "pos_reward_weight",
    "quat_reward_weight",
    "joint_reward_weight",
    "angvel_reward_weight",
    "bodypos_reward_weight",
    "endeff_reward_weight",
    "healthy_z_range",
    "pos_reward_exp_scale",
    "quat_reward_exp_scale",
    "joint_reward_exp_scale",
    "angvel_reward_exp_scale",
    "bodypos_reward_exp_scale",
    "endeff_reward_exp_scale",
)
reward_kwargs = {name: getattr(env_rewards, name) for name in reward_field_names}
reward_kwargs["penalty_pos_distance_scale"] = jnp.array(env_rewards.penalty_pos_distance_scale)
reward_config = RewardConfig(**reward_kwargs)

# Build the registered environment; the `env_args` and `reference_config`
# sections are expanded into keyword arguments of the task constructor.
environment = envs.get_environment(
    env_name=cfg.env_config.env_name,
    reference_clip=reference_clip,
    walker=walker,
    reward_config=reward_config,
    **env_args,
    **traj_config,
)

# Episode length is equal to (clip length - random init range - traj length) * steps per cur frame.
# Will work on not hardcoding these values later
usable_frames = traj_config.clip_length - traj_config.random_init_range - traj_config.traj_length
episode_length = usable_frames * environment._steps_for_cur_frame
print(f"episode_length {episode_length}")

# Pre-bind the layer sizes from the config so that the factory only needs the
# remaining arguments supplied at the call site.
net_cfg = cfg.network_config
network_factory = functools.partial(
    custom_ppo_networks.make_intention_ppo_networks,
    encoder_hidden_layer_sizes=tuple(net_cfg.encoder_layer_sizes),
    decoder_hidden_layer_sizes=tuple(net_cfg.decoder_layer_sizes),
    value_hidden_layer_sizes=tuple(net_cfg.critic_layer_sizes),
)

seed = 42

# Root key, split into a "global" stream and a "local" stream.
key = jax.random.PRNGKey(seed)
global_key, local_key = jax.random.split(key, num=2)

# The local stream provides the environment and evaluation keys.
local_key, key_env, eval_key = jax.random.split(local_key, num=3)

# key_networks should be global, so that networks are initialized the same
# way for different processes.
key_policy, key_value, policy_params_fn_key = jax.random.split(global_key, num=3)

# No randomization function is passed to the wrapper.
v_randomization_fn = None

# The same wrapper is used regardless of the environment type.
wrap_for_training = custom_wrappers.wrap

# Wrap the raw environment with the episode length computed from the
# reference-clip configuration.
train_env = wrap_for_training(
    environment,
    episode_length=episode_length,
    action_repeat=1,
    randomization_fn=v_randomization_fn,
)

# Reset the wrapped environment with a batch of one key to obtain a sample
# state; its observation and info entries are used below to size the networks.
reset_fn = train_env.reset
key_envs = jax.random.split(key_env, 1)
env_state = reset_fn(key_envs)

# Observation normalization is always enabled in this notebook.
normalize = running_statistics.normalize

# Positional arguments: the last observation dimension, the reference
# observation size read from the state info (index 0 presumably selects the
# single environment in the batch - TODO confirm), and the action size.
ppo_network = network_factory(
    env_state.obs.shape[-1],
    int(env_state.info["reference_obs_size"][0]),
    train_env.action_size,
    preprocess_observations_fn=normalize,
)
make_policy = custom_ppo_networks.make_inference_fn(ppo_network)

optimizer = optax.adam(learning_rate=1e-4)

# Freshly initialized policy and value parameters.
policy_init = ppo_network.policy_network.init(key_policy)
value_init = ppo_network.value_network.init(key_value)
init_params = ppo_losses.PPONetworkParams(policy=policy_init, value=value_init)

# Spec of a single float32 observation vector, used to size the normalizer state.
obs_spec = specs.Array(env_state.obs.shape[-1:], jnp.dtype("float32"))

# This state is handed to the checkpoint restore below as the target structure.
training_state = TrainingState(  # pytype: disable=wrong-arg-types  # jax-ndarray
    optimizer_state=optimizer.init(init_params),  # pytype: disable=wrong-arg-types  # numpy-scalars
    params=init_params,
    normalizer_params=running_statistics.init_state(obs_spec),
    env_steps=0,
)

# The policy alone is described by its normalizer statistics and policy parameters.
abstract_policy = (training_state.normalizer_params, training_state.params.policy)

self._steps_for_cur_frame: 2.0
episode_length 390.0


  reference_frame = jax.tree_map(
  return jax.tree_map(lambda x: x[info["clip_idx"]], self._reference_clips)
  return jax.tree_map(lambda x: x[info["clip_idx"]], self._reference_clips)


# Orbax Checkpoint Related

The following code creates the checkpoint manager and restores both the policy and the training state from the latest checkpoint.

In [5]:
# restore the whole model (both policy module, and training state)
checkpoint_dir = "/root/vast/scott-yang/track-mjx/model_checkpoints/rodent_data/ReferenceClip.p_c52ed4"
options = ocp.CheckpointManagerOptions(step_prefix="PPONetwork")
with ocp.CheckpointManager(checkpoint_dir, options=options) as mngr:
    # Query the latest step once and reuse it for both restores.
    latest_step = mngr.latest_step()
    print(f"latest checkpoint step: {latest_step}")
    if latest_step is None:
        # Fail early with a clear message instead of passing None to restore().
        raise FileNotFoundError(f"No checkpoint found under {checkpoint_dir}")

    # The abstract policy / training state built above serve as restore targets.
    policy = mngr.restore(
        latest_step, args=ocp.args.Composite(policy=ocp.args.StandardRestore(abstract_policy))
    )["policy"]
    train_state = mngr.restore(
        latest_step, args=ocp.args.Composite(train_state=ocp.args.StandardRestore(training_state))
    )["train_state"]

latest checkpoint step: 1053736960


# Rendering from the loaded checkpoint

For this, I created a simplified version of the `policy_param_fn`, which just renders the video and saves it to disk.

In [6]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation


def display_video(frames, framerate=30):
    """Build an inline HTML5 video from a sequence of image frames.

    Args:
        frames (array): (n_frames, height, width, 3)
        framerate (int): playback rate in frames per second.

    Returns:
        IPython.display.HTML: wrapper around the HTML5 video markup of the
        animation.

    Raises:
        ValueError: if `frames` is empty.
    """
    if len(frames) == 0:
        raise ValueError("display_video requires at least one frame")

    # Imported here so the function works regardless of which notebook cells
    # have already been run.
    from IPython.display import HTML

    height, width, _ = frames[0].shape
    dpi = 70
    orig_backend = matplotlib.get_backend()
    matplotlib.use("Agg")  # Switch to headless 'Agg' to inhibit figure rendering.
    try:
        fig, ax = plt.subplots(1, 1, figsize=(width / dpi, height / dpi), dpi=dpi)
        plt.close("all")  # Figure auto-closing upon backend switching is deprecated.
    finally:
        # Always switch back to the original backend, even if figure creation fails.
        matplotlib.use(orig_backend)
    ax.set_axis_off()
    ax.set_aspect("equal")
    ax.set_position([0, 0, 1, 1])
    im = ax.imshow(frames[0])

    def update(frame):
        # Swap the displayed image data; returning the artist is required for blitting.
        im.set_data(frame)
        return [im]

    interval = 1000 / framerate  # Delay between frames in milliseconds.
    anim = animation.FuncAnimation(fig=fig, func=update, frames=frames, interval=interval, blit=True, repeat=False)
    return HTML(anim.to_html5_video())

In [7]:
from track_mjx.agent.logging import render_rollout
from IPython.display import HTML

# Roll out the restored policy, collect the rendered frames, and show them inline.
# NOTE(review): `mngr` has already left the `with` block that created it;
# confirm that `latest_step()` is still valid on a closed manager.
checkpoint_step = mngr.latest_step()
frames = render_rollout(checkpoint_step, make_policy, policy, key_policy, cfg, environment, ".")

display_video(frames, framerate=30)

  reference_frame = jax.tree_map(
  return jax.tree_map(lambda x: x[info["clip_idx"]], self._reference_clips)
  return jax.tree_map(lambda x: x[info["clip_idx"]], self._reference_clips)
E0120 02:30:56.002812  471435 pjrt_stream_executor_client.cc:3085] Execution of replica 0 failed: INTERNAL: Failed to allocate 169242000 bytes for new constant


XlaRuntimeError: INTERNAL: Failed to allocate 169242000 bytes for new constant