In [1]:
import os
import logging
# Optional: raise the root logger to INFO so library log messages are emitted.
# Disabled by default; uncomment the two lines below to enable it.
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)

# MuJoCo rendering backend: set to "egl" or "glfw", whichever is available.
# NOTE(review): this is assigned ahead of the jax / track_mjx imports below,
# presumably because the backend is picked up at import time -- confirm before
# reordering this cell.
os.environ["MUJOCO_GL"] = "egl"
# TEMP: restricts this process to the GPU with index 1.
# NOTE(review): machine-specific; remove or adjust before sharing the notebook.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# TEMP (end)
    
from track_mjx.agent import checkpointing
from track_mjx.agent import ppo_networks as networks
from track_mjx.environment import wrappers
from track_mjx.analysis import rollout, render, utils

import jax
from jax import numpy as jp
from pathlib import Path

In [2]:
# --- User-specific locations: edit these two paths for your own setup ---
# Directory of the checkpoint to evaluate.
ckpt_path = "/n/holylabs/LABS/olveczky_lab/Users/charleszhang/track-mjx/model_checkpoints/250428_112805"
# Absolute path to the data file. The notebook may not be able to resolve the
# same relative path that the checkpoint's config carries.
data_path = "/n/holylabs/LABS/olveczky_lab/Users/charleszhang/track-mjx/data/transform_snips.h5"

# Restore the checkpoint at step 11 and take the config stored inside it.
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path, step=11)
cfg = ckpt["cfg"]

# Override the restored config so it points at the locations chosen above.
cfg.data_path = data_path
cfg.train_setup.checkpoint_to_restore = ckpt_path

In [3]:
# Build the environment from the restored config.
# NOTE: `rollout` here must still be the `track_mjx.analysis.rollout` module
# imported in the first cell.
env = rollout.create_environment(cfg)
# Build the decoder policy function from the same checkpoint directory; it is
# handed to the high-level wrapper in the next cell.
decoder_policy_fn = networks.make_decoder_policy_fn(ckpt_path)

In [4]:
# Both sizes used below come from the same section of the restored config.
network_cfg = cfg["network_config"]

# Wrap the base environment with the decoder policy function, then add the
# render/rollout wrapper that the rollout function steps through.
highlvlenv = wrappers.HighLevelWrapper(
    env, decoder_policy_fn, network_cfg["reference_obs_size"]
)
rollout_env = wrappers.RenderRolloutWrapperMulticlipTracking(highlvlenv)

# Length of the latent vector passed to `rollout_env.step` as the action.
latent_size = network_cfg["intention_size"]


In [5]:
# Root PRNG key for the notebook; later cells derive their keys from it.
key = jax.random.PRNGKey(0)

# Jit-compiled reset/step of the wrapped environment, reused by every rollout.
jit_reset, jit_step = jax.jit(rollout_env.reset), jax.jit(rollout_env.step)

In [6]:
@jax.jit
def random_walk_rollout(clip_idx, scale, constant_latent=None, seed=1):
    """Roll out `rollout_env` while driving it directly with latent vectors.

    At every step a latent vector of length `latent_size` is passed to
    `jit_step` as the action. The latent is either held constant or follows a
    smoothed random walk:

        latent_t = 0.9 * latent_{t-1} + 0.1 * (scale * N(0, I))

    starting from a zero vector.

    Relies on the notebook globals `cfg`, `rollout_env`, `jit_reset`,
    `jit_step` and `latent_size`.

    Args:
        clip_idx: Clip index forwarded to `jit_reset(..., clip_idx=clip_idx)`.
        scale: Multiplier applied to the Gaussian noise of the random walk.
            Unused when `constant_latent` is given.
        constant_latent: Optional latent vector applied unchanged at every
            step. It is fed through the scan carry, so it must match the shape
            and dtype of `jp.zeros(latent_size)`. When `None` (checked in
            Python at trace time) the random walk is used instead.
        seed: Integer seed for the PRNG key used for the reset and the noise.

    Returns:
        A dict with the keys:
            "rollout_metrics": one entry `f"{name}s"` per name listed in
                `cfg.logging_config.rollout_metrics`, read from each state's
                `metrics`.
            "observations": `obs` of every state (initial state included).
            "ctrl": the latent vector applied at each step.
            "qposes_ref": reference position/quaternion/joints stacked
                horizontally and repeated `rollout_env._steps_for_cur_frame`
                times along axis 0.
            "qposes_rollout": `pipeline_state.qpos` of every state.
            "info": `info` of every state.
    """
    key = jax.random.PRNGKey(seed)
    _, reset_rng, rollout_key = jax.random.split(key, 3)
    init_state = jit_reset(reset_rng, clip_idx=clip_idx)
    # One fewer than the number of simulated steps in a clip, because the
    # initial state is prepended to the scanned states below.
    num_steps = (
        int(cfg["reference_config"].clip_length * rollout_env._steps_for_cur_frame) - 1
    )

    def _step_fn(carry, _):
        state, act_rng, last_latents = carry
        # Derive a single-use key for this step's noise and a distinct key to
        # carry forward. Sampling from the carried key itself would reuse it:
        # the next step's split would start from the key that generated this
        # step's noise.
        act_rng, noise_rng = jax.random.split(act_rng)
        # Action is a random walk, unless a constant latent is given.
        latents = constant_latent
        if constant_latent is None:
            # The shape is an explicit 1-tuple; `(latent_size)` is a bare int.
            noise = jax.random.normal(noise_rng, (latent_size,), dtype=jp.float32)
            latents = last_latents * 0.9 + ((noise * scale) * 0.1)
        next_state = jit_step(state, latents)
        return (next_state, act_rng, latents), (next_state, latents)

    # Run rollout
    init_carry = (init_state, rollout_key, jp.zeros(latent_size))
    (final_state, _, _), (states, ctrls) = jax.lax.scan(
        _step_fn, init_carry, None, length=num_steps
    )

    def prepend(element, arr):
        # Scalar elements shouldn't be modified
        if arr.ndim == 0:
            return arr

        # Add a leading axis to the initial-state leaf and put it in front of
        # the stacked per-step leaves.
        return jp.concatenate([element[None], arr])

    rollout_states = jax.tree.map(prepend, init_state, states)

    # Get metrics
    rollout_metrics = {}
    for rollout_metric in cfg.logging_config.rollout_metrics:
        rollout_metrics[f"{rollout_metric}s"] = jax.vmap(
            lambda s: s.metrics[rollout_metric]
        )(rollout_states)

    # Reference and rollout qposes
    ref_traj = rollout_env._get_reference_clip(init_state.info)
    qposes_ref = jp.repeat(
        jp.hstack([ref_traj.position, ref_traj.quaternion, ref_traj.joints]),
        int(rollout_env._steps_for_cur_frame),
        axis=0,
    )

    # Collect qposes from states
    qposes_rollout = jax.vmap(lambda s: s.pipeline_state.qpos)(rollout_states)

    return {
        "rollout_metrics": rollout_metrics,
        "observations": jax.vmap(lambda s: s.obs)(rollout_states),
        "ctrl": ctrls,
        "qposes_ref": qposes_ref,
        "qposes_rollout": qposes_rollout,
        "info": jax.vmap(lambda s: s.info)(rollout_states),
    }

In [11]:
_, key = jax.random.split(key)
rollout = random_walk_rollout(
    clip_idx=1, 
    scale=0.01, 
    constant_latent=jax.random.normal(key, (latent_size), dtype=jp.float32) * 0.1, 
    seed=1
)

In [None]:
frames, realtime_framerate = render.render_rollout(cfg, rollout)
render.display_video(frames, framerate=realtime_framerate)