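# Generate and render a rollout for an existing checkpoint

This notebook was set up to load a training checkpoint, perform a rollout, and render the result, saving full network activations as an output of the rollout for further analysis. In its current form, however, the cells below do not load a checkpoint or run a rollout: they load a config, read reference joint positions (`qpos`) and behavior labels from an HDF5 recording, select clean, long bouts of one behavior (`RearHigh`), and render those clips with the MuJoCo ghost-pair model.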

## Imports

In [1]:
import os
import logging
# Lower the root logger's threshold to INFO so INFO-level records from imported
# libraries (e.g. the absl/orbax messages shown below) are not filtered out.
# Comment these two lines out to keep Python's default WARNING threshold.
logger = logging.getLogger()
logger.setLevel("INFO")

# MuJoCo chooses its OpenGL backend from MUJOCO_GL, so this must be set before any
# module that imports mujoco. "osmesa" is software rendering; switch to "egl"
# (headless GPU) or "glfw" (windowed) if available.
os.environ.update(MUJOCO_GL="osmesa")

from track_mjx.agent import checkpointing
from track_mjx.analysis.rollout import (
    create_rollout_generator,
    create_environment,
)
from track_mjx.analysis.render import render_from_saved_rollout, display_video
from track_mjx.analysis.utils import save_to_h5py, load_from_h5py

import jax
from jax import numpy as jp
from pathlib import Path

INFO:absl:Handler "orbax.checkpoint._src.handlers.array_checkpoint_handler.ArrayCheckpointHandler" already exists in the registry with associated type <class 'orbax.checkpoint._src.handlers.array_checkpoint_handler.ArrayCheckpointHandler'>. Skipping registration.
INFO:absl:Handler "orbax.checkpoint._src.handlers.proto_checkpoint_handler.ProtoCheckpointHandler" already exists in the registry with associated type <class 'orbax.checkpoint._src.handlers.proto_checkpoint_handler.ProtoCheckpointHandler'>. Skipping registration.
INFO:absl:Handler "orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler" already exists in the registry with associated type <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler'>. Skipping registration.
INFO:absl:Handler "orbax.checkpoint._src.handlers.base_pytree_checkpoint_handler.BasePyTreeCheckpointHandler" already exists in the registry with associated type <class 'orbax.checkpoint._src.handlers.base_pytree

## Load configuration

In [2]:
from omegaconf import DictConfig, OmegaConf

# Resolve the config relative to the notebook's working directory, assumed to be one
# level below the repository root (the ghost-pair XML path used for rendering makes
# the same assumption), rather than a user-specific absolute home-directory path.
CONFIG_PATH = "../track_mjx/config/rodent-full-clips.yaml"
cfg = OmegaConf.load(CONFIG_PATH)

In [3]:
# Recording to analyze. Previously used alternative:
#   /allen/aind/scratch/tim.kim/track-mjx/data/transform_snips.h5
cfg["data_path"] = "/allen/aind/scratch/tim.kim/track-mjx/data/art_tmjx/2020_12_22_1.h5"

# Render Videos from the Recorded Reference Data

## Load the reference data file

In [4]:
import h5py
import numpy as np

# Read from the file the config points at, so the path is defined in one place.
filename = cfg.data_path
# The handle is intentionally left open: the following cells read datasets from `f`.
f = h5py.File(filename, "r")
# NOTE: for the transform_snips.h5 dataset, qpos was previously reshaped to
# (842, 250, -1) after loading.

In [5]:
# List the datasets stored in the recording file.
f.keys()

<KeysViewHDF5 ['kp_data', 'motion_mapper', 'motion_mapper_names', 'qpos', 'qvel', 'spike_counts', 'xpos', 'xquat']>

In [6]:
# Read the whole `qpos` dataset into memory (`[()]` reads the full dataset).
qpos = f['qpos'][()]

In [7]:
# Behavior cluster label per sample. Assumed to be aligned with the rows of qpos
# (bout indices derived from it are used to slice qpos below) — TODO confirm.
motion_mapper = f['motion_mapper'][()]

In [8]:
# Read the name dataset from disk once (not once per element), then decode each
# stored byte string to a Python str.
raw_motion_mapper_names = f['motion_mapper_names'][()]
motion_mapper_names = np.array(
    [name.decode("utf-8") for name in raw_motion_mapper_names]
)

In [9]:
# Display the decoded behavior names (note that many names repeat across entries).
motion_mapper_names

array(['ProneStill', 'RearHigh', 'ProneStill', 'FaceGroom', 'ProneStill',
       'ProneSlow', 'ProneStill', 'ProneSlow', 'ProneSniff', 'RearHigh',
       '(check)', 'RearMid', 'ProneSniff', 'Amble', 'ProneSlow', 'Walk',
       'ProneStill', 'ProneSlow', 'Amble', 'ProneSlow', 'ProneSlow',
       'Amble', 'FaceGroom', 'ProneSniff', 'ProneSlow', 'RearHigh',
       'Amble', 'ProneStill', 'Amble', 'ProneStill', 'Amble',
       'ProneStill', 'FaceGroom', 'RearDown', 'ProneSlow', 'ProneSniff',
       '(check)', 'Amble', 'RearSniff', 'Amble', 'ProneSniff',
       'ProneSlow', 'ProneStill', 'RearHigh', 'ProneStill', 'RearMid',
       'ProneStill', 'ProneSlow', 'ProneSniff', 'ProneStill', 'ProneSlow',
       'ProneSlow', 'ProneStill', 'ProneSniff', 'Walk', 'TrackingError',
       'FaceGroom', 'ProneStill', 'Walk', 'ProneStill', 'FaceGroom',
       'ProneSlow', 'RearHigh', 'WalkFast', 'ProneSlow', 'ProneSlow',
       'ProneSlow', 'FaceGroom', 'ProneStill', 'Hunch', 'ProneSlow',
       'ProneSniff

In [10]:
# Behavior whose frames we want to select.
TARGET_BEHAVIOR = 'RearHigh'

# Several entries of motion_mapper_names share the same behavior name. Entry i is
# matched against label value i + 1 in motion_mapper (presumably 1-based cluster
# labels — confirm against how the file was written).
indx_list = np.where(motion_mapper_names == TARGET_BEHAVIOR)[0] + 1

# Boolean mask: True wherever the label is any of the matching values.
# np.isin replaces the manual accumulate-then-cast loop and avoids the `np.bool`
# alias, which raises AttributeError on NumPy 1.24–1.26.
motion_indices = np.isin(motion_mapper, indx_list)

In [11]:
def find_clean_long_bouts(bool_array, min_length=100, pre_false_count=10):
    """Return start indices of long True-runs that follow a clean False-run.

    A bout is a maximal run of consecutive truthy entries. A bout starting at
    index ``s`` is reported when both of the following hold:

    * the ``pre_false_count`` entries immediately before ``s`` are all falsy
      and lie inside the array (``s >= pre_false_count``);
    * the bout ends inside the array (a later falsy entry is seen) and its
      length ``end - s`` is strictly greater than ``min_length``.

    Bouts still open at the end of the array are ignored. A run beginning at
    index 0 is never reported, because it has no rising edge.

    Args:
        bool_array: 1-D sequence of truthy/falsy values (list or NumPy array).
        min_length: Bouts must be strictly longer than this many entries.
        pre_false_count: Number of consecutive falsy entries required directly
            before a bout's first truthy entry.

    Returns:
        List of int start indices in increasing order.
    """
    bouts = []
    bout_start = None  # start index of the currently open qualifying bout, if any
    false_run = 0      # length of the all-falsy run ending at index i - 1

    for i in range(1, len(bool_array)):
        prev = bool(bool_array[i - 1])
        curr = bool(bool_array[i])
        # Maintain the trailing falsy-run length incrementally instead of
        # re-scanning pre_false_count entries at every rising edge.
        false_run = 0 if prev else false_run + 1

        if bout_start is None:
            # Rising edge preceded by enough falsy entries opens a bout.
            if not prev and curr and false_run >= pre_false_count:
                bout_start = i
        elif prev and not curr:
            # Falling edge closes the open bout; keep it only if long enough.
            if i - bout_start > min_length:
                bouts.append(bout_start)
            bout_start = None

    return bouts

In [12]:
clean_long_bout_indices = find_clean_long_bouts(
    motion_indices,
    pre_false_count=50,  # 50 consecutive non-target samples before each onset
    min_length=50,       # bout must last strictly more than 50 samples
)

In [15]:
# Take HALF_WINDOW samples before and after each bout onset and stack the clips
# end-to-end. The bout search above (pre_false_count=50, min_length=50) is what
# keeps these windows inside the recording.
HALF_WINDOW = 50
bout_windows = [
    qpos[onset - HALF_WINDOW:onset + HALF_WINDOW, :]
    for onset in clean_long_bout_indices
]
if not bout_windows:
    raise ValueError("No qualifying bouts found; nothing to render.")
# A negative start index would silently wrap around, and a window past the end
# would be truncated. Either way the clip length would be wrong, so check it.
assert all(w.shape[0] == 2 * HALF_WINDOW for w in bout_windows), \
    "bout window extends outside the recording"
qpos_ = np.vstack(bout_windows)

In [16]:
# No policy rollout is produced in this notebook, so the same reference clip is
# used for both members of the rendered pair.
qposes_ref, qposes_rollout = qpos_, qpos_
if cfg.walker_type == "rodent":
    pair_render_xml_path = "../track_mjx/environment/walker/assets/rodent/rodent_ghostpair_scale080.xml"
else:
    # Fail here instead of hitting a NameError (or reusing a stale path) later.
    raise NotImplementedError(
        f"No ghost-pair render model configured for walker_type={cfg.walker_type!r}"
    )

In [17]:
from dm_control import mjcf as mjcf_dm
# Parse the ghost-pair MJCF file into a dm_control MJCF model tree, which is
# rescaled in place before being compiled into a MuJoCo model.
root = mjcf_dm.from_path(pair_render_xml_path)

In [18]:
from dm_control.locomotion.walkers import rescale

In [19]:
# The ghost-pair XML is presumably authored at 0.8 scale (its filename ends in
# `scale080`), so scale it by the configured factor relative to 0.8. The same
# value is passed as both the position and the size factor.
relative_scale = cfg.walker_config.rescale_factor / 0.8
rescale.rescale_subtree(root, relative_scale, relative_scale)

In [20]:
import mujoco

# Compile the rescaled dm_control MJCF tree and take the raw mujoco.MjModel.
mj_model = mjcf_dm.Physics.from_mjcf_model(root).model.ptr
# Conjugate-gradient solver with a small iteration / line-search budget.
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
mj_model.opt.iterations = mj_model.opt.ls_iterations = 6
mj_data = mujoco.MjData(mj_model)

In [21]:
from tqdm.auto import tqdm

# Playback rate implied by the physics settings:
# 1 / timestep / physics_steps_per_control_step.
# NOTE(review): the video below is played at a hard-coded 50 fps rather than
# render_fps; confirm which rate matches the sampling rate of the qpos data.
render_fps = (
    1.0 / mj_model.opt.timestep
) / cfg.env_config.env_args.physics_steps_per_control_step

# Render each frame offscreen into an in-memory list (nothing is logged or saved
# here). The pair model's qpos is the two walkers' qpos vectors concatenated.
mujoco.mj_kinematics(mj_model, mj_data)
renderer = mujoco.Renderer(mj_model, height=480, width=640)
frames = []
print("MuJoCo Rendering...")
for qpos1, qpos2 in tqdm(
    zip(qposes_rollout, qposes_ref), total=len(qposes_rollout)
):
    mj_data.qpos = np.append(qpos1, qpos2)
    mujoco.mj_forward(mj_model, mj_data)
    renderer.update_scene(mj_data, camera=cfg.env_config.render_camera_name)
    pixels = renderer.render()
    frames.append(pixels)

MuJoCo Rendering...


100%|██████████| 4000/4000 [13:10<00:00,  5.06it/s]


In [22]:
# Physics timestep (seconds) of the compiled render model.
mj_model.opt.timestep

0.002

In [23]:
# Physics steps per control step; combined with the timestep this gives render_fps.
cfg.env_config.env_args.physics_steps_per_control_step

5

## Render rollout

Note: Currently only works for non-batched rollouts

In [24]:
# Play the rendered frames inline.
# NOTE(review): this uses a fixed 50 fps, whereas render_fps evaluates to
# 1 / 0.002 / 5 = 100 fps with the settings shown above. Confirm which rate
# matches the qpos sampling rate before interpreting movement speeds.
display_video(frames, framerate=50)

INFO:matplotlib.animation:Animation.save using <class 'matplotlib.animation.FFMpegWriter'>
INFO:matplotlib.animation:MovieWriter._run: running command: ffmpeg -f rawvideo -vcodec rawvideo -s 640x480 -pix_fmt rgba -framerate 50.0 -loglevel error -i pipe: -vcodec h264 -pix_fmt yuv420p -y /tmp/tmp5g9_6731/temp.m4v
