In [1]:
import os
import argparse
import pickle
import numpy as np
from pathlib import Path
from tqdm import tqdm
import jax

# Runtime configuration via environment variables. These are assigned before
# the track_mjx / brax imports that follow in this cell.
# NOTE(review): `jax` is already imported above; JAX is expected to read the
# XLA_* variables when its backend initializes rather than at import — confirm.
os.environ.update(
    {
        # Mujoco rendering backend.
        "MUJOCO_GL": "osmesa",
        # Don't preallocate JAX (XLA client) GPU memory up front.
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        # Extra XLA GPU flags (Triton fusions). This replaces any XLA_FLAGS
        # value already present in the environment.
        "XLA_FLAGS": (
            "--xla_gpu_enable_triton_softmax_fusion=true --xla_gpu_triton_gemm_any=True "
        ),
    }
)
from track_mjx.agent import checkpointing
from track_mjx.environment import wrappers
from track_mjx.environment.walker.rodent import Rodent
from track_mjx.environment.task.reward import RewardConfig
from track_mjx.environment.task.multi_clip_tracking import MultiClipTracking
from track_mjx.io import load
from brax import envs

### Helper functions

In [2]:
def setup_env(cfg, reference_clip):
    """Set up the multi-clip tracking environment with the reference clip(s).

    Registers ``MultiClipTracking`` with brax under ``"rodent_multi_clip"``,
    builds it from the settings stored in ``cfg``, and wraps it with the
    rollout, vmap and auto-align wrappers used for evaluation.

    Args:
        cfg: Config mapping (e.g. ``ckpt["cfg"]``). Reads
            ``cfg["env_config"]["env_args"]``,
            ``cfg["env_config"]["reward_weights"]``, ``cfg["walker_config"]``
            and ``cfg["reference_config"]``. ``cfg`` itself is not modified.
        reference_clip: Reference trajectories, passed to the environment as
            ``reference_clip``.

    Returns:
        ``AutoAlignWrapperTracking(RenderRolloutVmapWrapper(
        RenderRolloutWrapperMulticlipTracking(env)))``.
    """
    envs.register_environment("rodent_multi_clip", MultiClipTracking)

    env_args = cfg["env_config"]["env_args"]
    walker_config = cfg["walker_config"]
    traj_config = cfg["reference_config"]

    # Work on a copy of the reward weights so that filling in the default below
    # does not mutate the caller's cfg. Configs without an energy cost weight
    # (presumably from older checkpoints) get a weight of 0.0.
    env_rewards = dict(cfg["env_config"]["reward_weights"])
    env_rewards.setdefault("energy_cost_weight", 0.0)

    walker = Rodent(**walker_config)
    reward_config = RewardConfig(**env_rewards)

    env = envs.get_environment(
        env_name="rodent_multi_clip",
        reference_clip=reference_clip,
        walker=walker,
        reward_config=reward_config,
        **env_args,
        **traj_config,
    )

    # Wrapper stack, innermost first: rollout -> vmap -> auto-align.
    rollout_env = wrappers.RenderRolloutWrapperMulticlipTracking(env)
    align_vmap_env = wrappers.AutoAlignWrapperTracking(
        wrappers.RenderRolloutVmapWrapper(rollout_env)
    )

    return align_vmap_env

In [3]:
def make_eval_env(ckpt_path, eval_data_path, n_frames_per_clip=250, n_eval_clips=400):
    """Build the evaluation environment and its jitted reset/step functions.

    The environment configuration comes from the checkpoint at ``ckpt_path``.
    The reference clips come from ``eval_data_path``.

    Args:
        ckpt_path: Checkpoint directory. It is loaded without an explicit step
            (NOTE(review): presumably the latest step — confirm), and only its
            ``"cfg"`` entry is used.
        eval_data_path: Path to the eval reference data, passed to
            ``load.make_multiclip_data``.
        n_frames_per_clip: Reference frames per clip.
        n_eval_clips: Keep only the first ``n_eval_clips`` clips. The default of
            400 matches the previously hard-coded value; ``None`` keeps all clips.

    Returns:
        ``(env, jit_reset, jit_step, n_clips)``, where ``n_clips`` is the number
        of clips actually kept. It can be fewer than ``n_eval_clips`` if the
        dataset is smaller.
    """
    ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)
    cfg = ckpt["cfg"]

    reference_clip = load.make_multiclip_data(
        eval_data_path, n_frames_per_clip=n_frames_per_clip
    )
    # Slice the leading (clip) axis of every leaf; x[:None] keeps everything.
    eval_reference_clip = jax.tree.map(lambda x: x[:n_eval_clips], reference_clip)
    env = setup_env(cfg, eval_reference_clip)

    jit_align_reset = jax.jit(env.reset)
    jit_align_step = jax.jit(env.step)

    n_clips = eval_reference_clip.position.shape[0]
    return env, jit_align_reset, jit_align_step, n_clips

In [4]:

def evaluate_checkpoint(
    ckpt_path,
    ckpt_step,
    env,
    reset_fn,
    step_fn,
    n_clips,
    n_frames_per_clip=250,
    num_envs=400,
    max_batches=4,
    steps_per_second=100,
):
    """Evaluate a checkpoint on the eval dataset.

    Rolls out the policy saved at ``ckpt_step`` on the eval clips in batches
    of ``num_envs`` parallel envs, one clip per env. Only complete batches are
    run, and at most ``max_batches`` of them. Any trailing ``n_clips % num_envs``
    clips, and any clips beyond ``max_batches * num_envs``, are not evaluated.

    Args:
        ckpt_path: Checkpoint directory.
        ckpt_step: Checkpoint step whose policy is evaluated.
        env: Environment built by ``make_eval_env``. Its
            ``_steps_for_cur_frame`` determines the rollout length.
        reset_fn: Jitted ``env.reset``. It is called with per-env keys and
            ``clip_idx``.
        step_fn: Jitted ``env.step``.
        n_clips: Number of clips loaded in ``env``.
        n_frames_per_clip: Reference frames per clip.
        num_envs: Parallel envs (clips) per batch.
        max_batches: Upper bound on the number of batches.
        steps_per_second: Env steps per second of simulated time, used to turn
            the done count into a per-hour rate. Defaults to the previously
            hard-coded 100 — NOTE(review): confirm against the env's control rate.

    Returns:
        A dict with three keys:
        ``checkpoint_step``;
        ``avg_reward``, the mean reward over every recorded step of every env;
        ``dones_per_hour``, the count of ``done`` flags per hour of simulated
        time, pooled over envs.

    Raises:
        ValueError: If ``num_envs < 1``, or if no complete batch can be formed
            (``n_clips < num_envs`` or ``max_batches < 1``). Previously the
            latter ended in a ZeroDivisionError when computing ``dones_per_hour``.
    """
    if num_envs < 1:
        raise ValueError(f"num_envs must be positive, got {num_envs}")
    batches = min(n_clips // num_envs, max_batches)
    if batches < 1:
        # Fail before the (expensive) checkpoint load and rollouts.
        raise ValueError(
            f"Nothing to evaluate: need at least one full batch of {num_envs} "
            f"clips, got n_clips={n_clips}, max_batches={max_batches}."
        )

    # Load the checkpoint and build a batched, jitted policy.
    ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path, step=ckpt_step)
    cfg = ckpt["cfg"]
    inference_fn = checkpointing.load_inference_fn(
        cfg, ckpt["policy"], get_activation=False
    )
    jit_inference_fn = jax.jit(jax.vmap(inference_fn))

    # Loop invariants. The episode length is the same for every batch, and the
    # reset keys are fixed, so every batch and every run is reproducible.
    num_steps = int(n_frames_per_clip * env._steps_for_cur_frame) - 1
    key_envs = jax.random.split(jax.random.PRNGKey(0), num_envs)

    all_metrics = {
        "reward": [],
        "done": [],
    }
    for i in range(batches):
        first_clip = i * num_envs
        print(f"Processing clips {first_clip} to {first_clip + num_envs}")

        clip_idxs = jax.numpy.arange(0, num_envs) + first_clip
        state = reset_fn(key_envs, clip_idx=clip_idxs)
        rng_policy = jax.random.split(jax.random.PRNGKey(1), num_envs)

        for _ in tqdm(range(num_steps)):
            # Advance each env's key. NOTE(review): the same key is passed to
            # the policy and then split again on the next step; JAX advises
            # against reusing a key like this. Kept as-is so results stay
            # comparable with earlier runs.
            rng_policy = jax.vmap(jax.random.split)(rng_policy)[:, 1, :]
            ctrl, _ = jit_inference_fn(state.obs, rng_policy)
            state = step_fn(state, ctrl)
            all_metrics["reward"].append(state.reward)
            all_metrics["done"].append(state.metrics["done"])

    # Pool every (step, env) sample from all batches into one flat array.
    reshaped_dones = np.array(all_metrics["done"]).reshape(-1, order="F")
    reshaped_rewards = np.array(all_metrics["reward"]).reshape(-1, order="F")

    avg_reward = float(reshaped_rewards.mean())

    # size / steps_per_second = simulated seconds, summed over all envs.
    done_count = reshaped_dones.sum()
    dones_per_hour = float(
        done_count * (3600 / (reshaped_dones.size / steps_per_second))
    )

    return {
        "checkpoint_step": ckpt_step,
        "avg_reward": avg_reward,
        "dones_per_hour": dones_per_hour,
    }

### Evaluate a given checkpoint

In [5]:
# Evaluation inputs. Both paths are relative, so they resolve against the
# directory the notebook kernel is started in.
ckpt_dir, test_data, n_frames_per_clip = (
    "track-mjx/model_checkpoints/250220_125514",  # checkpoint directory
    "transform_coltrane_2021_07_28_1.h5",  # eval reference data
    250,  # reference frames per clip
)

In [6]:
# Build the eval env once. Its jitted reset/step functions are passed to
# evaluate_checkpoint below.
env, reset_fn, step_fn, n_clips = make_eval_env(
    ckpt_path=ckpt_dir,
    eval_data_path=test_data,
    n_frames_per_clip=n_frames_per_clip,
)

env._steps_for_cur_frame: 2.0


In [9]:
# Checkpoint step (inside ckpt_dir) whose policy is evaluated.
eval_step = 11
result = evaluate_checkpoint(
    ckpt_path=ckpt_dir,
    ckpt_step=eval_step,
    env=env,
    reset_fn=reset_fn,
    step_fn=step_fn,
    n_clips=n_clips,
    n_frames_per_clip=n_frames_per_clip,
)
result

Processing clips 0 to 400


100%|██████████| 499/499 [00:23<00:00, 21.05it/s]


{'checkpoint_step': 11,
 'avg_reward': 4.156365394592285,
 'dones_per_hour': 703.4067993164062}