# Load the checkpoint and render

In [None]:
%load_ext autoreload
%autoreload 2

from flax import nnx
import orbax.checkpoint as ocp
import jax
from jax import numpy as jnp
import numpy as np

from track_mjx.agent.nnx_ppo_network import make_intention_ppo_networks

from absl import flags
import hydra
from omegaconf import DictConfig
import uuid
import jax
import wandb
from brax import envs
import pickle
import warnings

from track_mjx.environment.task.multi_clip_tracking import RodentMultiClipTracking
from track_mjx.environment.task.single_clip_tracking import RodentTracking
from track_mjx.io import preprocess as preprocessing  # the pickle file needs it
from track_mjx.agent import nnx_ppo
from track_mjx.agent import nnx_ppo_network
from track_mjx.agent.logging import policy_params_fn
from track_mjx.environment.walker.rodent import Rodent

# Handle to absl's global flag registry; flags are defined on it further down.
FLAGS = flags.FLAGS
# Suppress DeprecationWarning messages.
warnings.filterwarnings("ignore", category=DeprecationWarning)

from hydra import initialize, compose
from omegaconf import OmegaConf

# Compose the experiment config with Hydra's Compose API.
# Hydra allows only one global initialization per process, so any previous
# initialization is cleared first; without this, re-executing this code in the
# same kernel makes `initialize` raise.
from hydra.core.global_hydra import GlobalHydra

GlobalHydra.instance().clear()
initialize(version_base=None, config_path="./../track_mjx/config")

# Load the config file
cfg = compose(config_name="rodent-mc-intention")


"""Main function using Hydra configs"""

# Count the available GPUs, falling back to a single device when JAX has no
# GPU backend. JAX reports a missing/uninitialisable backend as RuntimeError,
# so only that is caught; anything else (e.g. KeyboardInterrupt) propagates.
try:
    n_devices = jax.device_count(backend="gpu")
    print(f"Using {n_devices} GPUs")
except RuntimeError:
    n_devices = 1
    print("Not using GPUs")

# Define the solver-related absl flags. absl raises DuplicateFlagError when a
# flag name is defined twice, so each definition is skipped if the flag is
# already registered; this keeps the code safe to re-execute in one kernel.
if "solver" not in FLAGS:
    flags.DEFINE_enum("solver", "cg", ["cg", "newton"], "constraint solver")
if "iterations" not in FLAGS:
    flags.DEFINE_integer("iterations", 4, "number of solver iterations")
if "ls_iterations" not in FLAGS:
    flags.DEFINE_integer("ls_iterations", 4, "number of linesearch iterations")

# Register the tracking environments under the names looked up by
# `envs.get_environment`.
envs.register_environment("single clip", RodentTracking)
envs.register_environment("multi clip", RodentMultiClipTracking)

# Unpickle the reference clip that is handed to the environment below.
# NOTE: per the import comment at the top of the file, unpickling needs the
# `preprocessing` module to be imported. Only load pickle files you trust.
input_data_path = "/root/vast/scott-yang/track-mjx/data/ReferenceClip.p"
print(f"Loading data: {input_data_path}")
with open(input_data_path, "rb") as clip_file:
    reference_clip = pickle.load(clip_file)

# TODO (Kevin): add this as a yaml config
walker = Rodent

# Short aliases for the two config sections read below.
_env_cfg = cfg.env_config
_reward_cfg = _env_cfg.reward_weights

# Instantiate the environment registered under `env_name`, forwarding the
# reference clip, the walker class and every setting from the config.
env = envs.get_environment(
    _env_cfg.env_name,
    reference_clip=reference_clip,
    walker=walker,
    # settings read directly from env_config
    torque_actuators=_env_cfg.torque_actuators,
    solver=_env_cfg.solver,
    iterations=_env_cfg.iterations,
    ls_iterations=_env_cfg.ls_iterations,
    physics_steps_per_control_step=_env_cfg.physics_steps_per_control_step,
    # settings read from env_config.reward_weights
    too_far_dist=_reward_cfg.too_far_dist,
    bad_pose_dist=_reward_cfg.bad_pose_dist,
    bad_quat_dist=_reward_cfg.bad_quat_dist,
    ctrl_cost_weight=_reward_cfg.ctrl_cost_weight,
    ctrl_diff_cost_weight=_reward_cfg.ctrl_diff_cost_weight,
    pos_reward_weight=_reward_cfg.pos_reward_weight,
    quat_reward_weight=_reward_cfg.quat_reward_weight,
    joint_reward_weight=_reward_cfg.joint_reward_weight,
    angvel_reward_weight=_reward_cfg.angvel_reward_weight,
    bodypos_reward_weight=_reward_cfg.bodypos_reward_weight,
    endeff_reward_weight=_reward_cfg.endeff_reward_weight,
    # converted to a plain tuple before being passed on
    healthy_z_range=tuple(_reward_cfg.healthy_z_range),
)

"""Main logging functions for policy params,
cfg, wandb, model_path, and env currently func.partial from train.py"""

from track_mjx.environment import custom_wrappers
import os
import mujoco
from dm_control import mjcf as mjcf_dm
from dm_control.locomotion.walkers import rescale
import imageio

# Wrap the env with the rollout-rendering wrapper (the auto-reset wrapper is
# kept below, commented out, as the alternative).
# rollout_env = custom_wrappers.AutoResetWrapperTracking(env)
rollout_env = custom_wrappers.RenderRolloutWrapperTracking(env)
# JIT-compiled versions of the wrapped env's reset and step functions
jit_reset, jit_step = (jax.jit(fn) for fn in (rollout_env.reset, rollout_env.step))


In [2]:
# Setup MuJoCo Renderer -- ONLY RUN ONCE
# TODO Better relative path scripts
_XML_PATH = "/root/vast/scott-yang/track-mjx/track_mjx/environment/walker/assets/rodent_ghostpair_scale080.xml"
# The same factor is passed for both rescale arguments.
_RESCALE_FACTOR = 0.9 / 0.8

# Load the MJCF model and rescale its subtree in place.
root = mjcf_dm.from_path(_XML_PATH)
rescale.rescale_subtree(root, _RESCALE_FACTOR, _RESCALE_FACTOR)

# Compile the MJCF into a raw MuJoCo model and fix the solver options used
# for rendering (CG solver, 6 iterations, 6 line-search iterations).
mj_model = mjcf_dm.Physics.from_mjcf_model(root).model.ptr
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
mj_model.opt.iterations = 6
mj_model.opt.ls_iterations = 6
mj_data = mujoco.MjData(mj_model)

# walker._load_mjcf_model(torque_actuators=cfg.env_config.torque_actuators, path=_XML_PATH)

# Run kinematics once on the fresh data, then create a 512x512 renderer.
mujoco.mj_kinematics(mj_model, mj_data)
renderer = mujoco.Renderer(mj_model, height=512, width=512)

In [None]:
# Reset once (un-jitted, fixed PRNG seed) to obtain a sample state; later code
# reads its observation shape and `info["reference_obs_size"]` to size the networks.
env_state = rollout_env.reset(rng=jax.random.PRNGKey(0))

In [4]:
train_cfg = cfg.train_config
network_cfg = cfg.network_config

# Build a concrete PPO network sized from the sample env state.
# Encoder/decoder layer sizes come from `network_cfg`, the same source used
# for the abstract model that the checkpoint is restored into, so the network
# shown here always has the architecture of the restored one.
ppo_network = make_intention_ppo_networks(
    env_state.obs.shape[-1],
    env_state.info["reference_obs_size"],
    env.action_size,
    normalize_obs=True,
    encoder_layers=network_cfg.encoder_layer_sizes,
    decoder_layers=network_cfg.decoder_layer_sizes,
    value_layers=network_cfg.critic_layer_sizes,
)

In [None]:
# Display the freshly built network with nnx's viewer.
nnx.display(ppo_network)

In [None]:
# Restore the checkpoint back to its `nnx.State` structure - need an abstract reference.
def _build_reference_networks():
    """Build the PPO networks with layer sizes taken from `network_cfg`."""
    return make_intention_ppo_networks(
        env_state.obs.shape[-1],
        env_state.info["reference_obs_size"],
        env.action_size,
        normalize_obs=True,
        encoder_layers=network_cfg.encoder_layer_sizes,
        decoder_layers=network_cfg.decoder_layer_sizes,
        value_layers=network_cfg.critic_layer_sizes,
    )


# Evaluate the builder through `nnx.eval_shape` to get the abstract model.
abstract_model = nnx.eval_shape(_build_reference_networks)
# Split off the RNG key state from the rest; the rest is the restore target.
graphdef, key, abstract_state = nnx.split(abstract_model, nnx.RngKey, ...)
print("The abstract NNX state (all leaves are abstract arrays):")
nnx.display(abstract_state)

In [None]:
# Location of the saved run and the checkpoint to load from it.
ckpt_path = "/root/vast/scott-yang/track_mjx_model_checkpoints/d4a8e221-52f4-481e-9d07-10fe4dc8c8fc"
ckpt_name = "ppo_networks_1"

# Restore the saved state, using the abstract state as the target structure.
checkpointer = ocp.StandardCheckpointer()
state_restored = checkpointer.restore(
    f"{ckpt_path}/{ckpt_name}",
    abstract_state,
)

In [8]:
# Rebuild a concrete model from the graph definition, the restored state and
# the RNG key state that was split off from the abstract model.
model = nnx.merge(graphdef, state_restored, key)

from track_mjx.brax_nnx import acting
from tqdm import tqdm

The following script can be run in a separate Python thread dedicated to rendering the NNX model.

In [None]:
# Roll out and render a series of saved checkpoints.
# For each checkpoint the policy is run deterministically for 5 episodes, and
# each episode is written as an mp4 inside that checkpoint's directory.
rollout_rng, reset_rng, act_rng = jax.random.split(jax.random.PRNGKey(0), 3)

# Loop-invariant setup: one checkpointer and one checkpoint root serve every
# restore, instead of being rebuilt on each iteration.
checkpointer = ocp.StandardCheckpointer()
ckpt_path = "/root/vast/scott-yang/track_mjx_model_checkpoints/d4a8e221-52f4-481e-9d07-10fe4dc8c8fc"

for j in range(100):
    # restore the model (checkpoints ppo_networks_8 .. ppo_networks_107)
    ckpt_name = f"ppo_networks_{j+8}"
    state_restored = checkpointer.restore(f"{ckpt_path}/{ckpt_name}", abstract_state)
    model = nnx.merge(graphdef, state_restored, key)

    for _episode in range(5):
        reset_rng, reset_key = jax.random.split(reset_rng)
        state = jit_reset(reset_key)

        # Step the policy, keeping every state (including the reset state).
        rollout = [state]
        for _ in range(int(250 * env._steps_for_cur_frame)):
            rollout_key, rollout_rng = jax.random.split(rollout_rng)
            obs = state.obs
            ctrl, extras = model.policy(obs, rollout_key, deterministic=True)
            state = jit_step(state, ctrl)
            rollout.append(state)

        # Per-step diagnostics of the latest rollout. They are not consumed
        # below; they are kept available for inspection.
        pos_rewards = [s.metrics["pos_reward"] for s in rollout]
        endeff_rewards = [s.metrics["endeff_reward"] for s in rollout]
        quat_rewards = [s.metrics["quat_reward"] for s in rollout]
        angvel_rewards = [s.metrics["angvel_reward"] for s in rollout]
        bodypos_rewards = [s.metrics["bodypos_reward"] for s in rollout]
        joint_rewards = [s.metrics["joint_reward"] for s in rollout]
        summed_pos_distances = [s.info["summed_pos_distance"] for s in rollout]
        joint_distances = [s.info["joint_distance"] for s in rollout]
        torso_heights = [s.pipeline_state.xpos[env.walker._torso_idx][2] for s in rollout]

        # Render the walker with the reference expert demonstration trajectory
        qposes_rollout = np.array([s.pipeline_state.qpos for s in rollout])

        # ref_traj = jax.tree_util.tree_map(f, reference_clip)
        ref_traj = rollout_env._get_reference_clip(rollout[0].info)
        clip_idx = rollout[0].info["clip_idx"]
        # Print the clip index only (the label says clip_id), not the whole info dict.
        print(f"clip_id:{clip_idx}")
        # Repeat each reference frame so the reference has one row per env step.
        qposes_ref = np.repeat(
            np.hstack([ref_traj.position, ref_traj.quaternion, ref_traj.joints]),
            env._steps_for_cur_frame,
            axis=0,
        )

        # render while stepping using mujoco
        video_path = ckpt_path + f"/{ckpt_name}/clip_{clip_idx}.mp4"

        # Every second frame is rendered and the fps is halved to match.
        rollout_frames = qposes_rollout[::2]
        ref_frames = qposes_ref[::2]
        n_frames = min(len(rollout_frames), len(ref_frames))  # zip stops at the shorter one

        with imageio.get_writer(video_path, fps=int((1.0 / env.dt / 2))) as video:
            for qpos_walker, qpos_ref in tqdm(zip(rollout_frames, ref_frames), total=n_frames):

                # TODO: ValueError: could not broadcast input array from shape (148,) into shape (74,)
                # The rollout and reference poses are concatenated into one qpos vector.
                mj_data.qpos = np.append(qpos_walker, qpos_ref)
                mujoco.mj_forward(mj_model, mj_data)
                renderer.update_scene(mj_data, camera="close_profile")
                pixels = renderer.render()
                video.append_data(pixels)