### Testing generalization to eval dataset (transform_coltrane_2021_07_28_1.h5)

In [1]:
import os

# Environment variables must be set BEFORE jax / mujoco / brax are imported
# (directly or indirectly through track_mjx), otherwise they may be read too
# late and silently ignored.
# MuJoCo rendering backend; change to "egl" (headless GPU) or "osmesa"
# (software) if glfw is unavailable.
os.environ["MUJOCO_GL"] = "glfw"
# Don't preallocate jax memory.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import logging
from pathlib import Path
from typing import Callable, Dict

# Uncomment to raise the root logger's level to INFO.
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)

import hydra
import jax
import numpy as np
from brax import envs
from brax.envs.base import Env
from jax import numpy as jnp
from omegaconf import DictConfig
from tqdm import tqdm

from track_mjx.agent import checkpointing
from track_mjx.analysis import rollout, render, utils
from track_mjx.environment import wrappers
from track_mjx.environment.task.multi_clip_tracking import MultiClipTracking
from track_mjx.environment.task.reward import RewardConfig
from track_mjx.environment.task.single_clip_tracking import SingleClipTracking
from track_mjx.environment.walker.fly import Fly
from track_mjx.environment.walker.rodent import Rodent
from track_mjx.io import load


In [2]:
# Paths can be overridden with environment variables. The defaults are the
# original author's absolute local paths: set the variables (or edit the
# defaults) for your machine. The notebook's working directory may not match
# the training run's, so use absolute paths.

# Held-out eval data used to build the reference clips below.
eval_data_path = os.environ.get(
    "TRACK_MJX_EVAL_DATA_PATH",
    "/Users/charleszhang/GitHub/stac-mjx/transform_coltrane_2021_07_28_1.h5",
)

# Checkpoint to evaluate.
ckpt_path = os.environ.get(
    "TRACK_MJX_CKPT_PATH",
    "/Users/charleszhang/GitHub/track-mjx/model_checkpoints/ademamix_64",
)
# Load the checkpoint; it carries the training config ("cfg") and the policy.
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)

cfg = ckpt["cfg"]

# Point the stored config at paths that are valid on this machine.
cfg.data_path = os.environ.get(
    "TRACK_MJX_TRAIN_DATA_PATH",
    "/Users/charleszhang/GitHub/track-mjx/data/transform_snips.h5",
)
cfg.train_setup.checkpoint_to_restore = ckpt_path

In [None]:
# Number of reference frames per clip passed to make_multiclip_data.
n_frames_per_clip = 1000

# Deterministic-activation-free policy built from the checkpoint's params.
inference_fn = checkpointing.load_inference_fn(
    cfg,
    ckpt["policy"],
    get_activation=False,
)

# Eval data repackaged as multi-clip reference data.
reference_clip = load.make_multiclip_data(
    eval_data_path,
    n_frames_per_clip=n_frames_per_clip,
)
reference_clip.position.shape

Create the environment with the test-set (eval) reference clips.

In [None]:
# Make the multi-clip tracking task available to brax under this name.
envs.register_environment("rodent_multi_clip", MultiClipTracking)

# Pull the env / reward / walker / reference settings from the training config.
env_section = cfg["env_config"]
env_args = env_section["env_args"]
env_rewards = env_section["reward_weights"]
walker_config = cfg["walker_config"]
traj_config = cfg["reference_config"]

walker = Rodent(**walker_config)
reward_config = RewardConfig(**env_rewards)

# env_args and traj_config are splatted as keyword arguments, so their keys
# must match the task constructor's parameter names (and must not collide).
env = envs.get_environment(
    env_name="rodent_multi_clip",
    reference_clip=reference_clip,
    walker=walker,
    reward_config=reward_config,
    **env_args,
    **traj_config,
)

In [5]:
# Wrap the multi-clip tracking env with track_mjx's render-rollout wrapper
# (presumably exposes what the rollout/render cells below need — TODO confirm).
rollout_env = wrappers.RenderRolloutWrapperMulticlipTracking(env)

In [6]:
# Batch the rollout env across environments (vmap wrapper), then stack the
# auto-align wrapper on top; keep its reset/step handles for jitting below.
vmapped_rollout_env = wrappers.RenderRolloutVmapWrapper(rollout_env)
align_vmap_env = wrappers.AutoAlignWrapperTracking(vmapped_rollout_env)
align_reset, align_step = align_vmap_env.reset, align_vmap_env.step

In [None]:
# JIT-compile the policy (vmapped over the env batch) and the wrapped
# env's reset/step so each rollout step runs as compiled XLA code.
batched_inference_fn = jax.vmap(inference_fn)
jit_inference_fn = jax.jit(batched_inference_fn)
jit_align_reset, jit_align_step = (jax.jit(fn) for fn in (align_reset, align_step))


### Roll out on the eval data: it was transformed to multi-clip form above with a long clip length (`n_frames_per_clip`, 1000 frames here), and each environment is initialized on its own clip

In [None]:
# Auto-align wrapper rollout: `num_envs` parallel environments, env i tracking clip i.
num_envs = 3
key_envs = jax.random.split(jax.random.PRNGKey(0), num_envs)

clip_idxs = jnp.arange(0, num_envs)

init_states = jit_align_reset(key_envs, clip_idx=clip_idxs)
# `rollouts` ends up with int(n_frames_per_clip * env._steps_for_cur_frame)
# entries, counting init_states as the first one.
num_steps = int(n_frames_per_clip * env._steps_for_cur_frame) - 1

rollouts = [init_states]
state = init_states
activations = []  # only filled if the policy is loaded with get_activation=True
rng_policy = jax.random.split(jax.random.PRNGKey(1), num_envs)

for step in tqdm(range(num_steps)):
    # vmap(split) yields keys of shape (num_envs, 2, ...). Index the split axis
    # explicitly: tuple-unpacking would iterate over the env axis instead and
    # fail (or silently pick one env's keys) whenever num_envs != 2.
    split_keys = jax.vmap(jax.random.split)(rng_policy)
    rng_policy, act_keys = split_keys[:, 0], split_keys[:, 1]
    ctrl, extras = jit_inference_fn(state.obs, act_keys)
    # activations.append(extras["activations"])
    state = jit_align_step(state, ctrl)
    rollouts.append(state)

In [None]:
# Render environment 0's trajectory.
# NOTE(review): "qposes_ref" is filled with the rollout's own qpos (the original
# comment says it should be the mocap qpos), so the rendered reference just
# duplicates the rollout — TODO substitute the reference clip's qpos.
env0_qposes = [r.pipeline_state.qpos[0] for r in rollouts]
rollout_dict = {
    "qposes_ref": env0_qposes,
    "qposes_rollout": list(env0_qposes),
}
frames, realtime_framerate = render.render_rollout(cfg, rollout_dict)
render.display_video(frames, framerate=realtime_framerate)


In [None]:
# Rate of "done" frames, scaled up to a per-hour figure.
# `dones` stacks r.metrics["done"] for every saved state (presumably one entry
# per env), so `dones.size` counts frames over all envs. Dividing by the render
# framerate converts that to seconds (assumes realtime_framerate is sim frames/s).
dones = jnp.array([r.metrics["done"] for r in rollouts])
done_count = dones.sum()
rollout_seconds = dones.size / realtime_framerate
done_per_hour = done_count * (3600 / rollout_seconds)
done_per_hour


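Plot the per-step reward over time for environment 0's continuous rollout (`int(n_frames_per_clip * env._steps_for_cur_frame)` steps in total, e.g. 2000 steps ≈ 40 seconds).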

In [None]:
from matplotlib import pyplot as plt

# Per-step reward of environment 0 (the same env that was rendered above).
fig, ax = plt.subplots()
ax.plot([r.reward[0] for r in rollouts])
ax.set(xlabel="env step", ylabel="reward", title="Reward over rollout (env 0)")
plt.show()