In [None]:
# Enable IPython's autoreload extension (mode 2) so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload

%autoreload 2

In [None]:
import os

# Environment configuration for JAX/XLA and MuJoCo. This cell sits ahead of
# the cell that imports jax / mujoco (presumably so those libraries see the
# values when they initialise -- keep it first).
TRITON_GEMM_FLAG = "--xla_gpu_triton_gemm_any=True"

xla_flags = os.environ.get("XLA_FLAGS", "")
# Append the flag only when it is not already present, so re-running this
# cell (or inheriting the flag from the shell) does not accumulate duplicates.
if TRITON_GEMM_FLAG not in xla_flags.split():
    xla_flags += " " + TRITON_GEMM_FLAG
os.environ["XLA_FLAGS"] = xla_flags
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["MUJOCO_GL"] = "egl"

In [None]:
import functools
import json
from datetime import datetime
from typing import Any, Dict, Optional

import jax
import jax.numpy as jp
import pickle
import matplotlib.pyplot as plt
import mediapy as media
import mujoco
import numpy as np
import wandb
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from etils import epath
from flax.training import orbax_utils
from IPython.display import clear_output, display
from orbax import checkpoint as ocp

from mujoco_playground import registry
from mujoco_playground.config import locomotion_params, manipulation_params, dm_control_suite_params

from brax.io import model as brax_io_model
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as brax_ppo_networks
from brax.training.agents.sac import networks as brax_sac_networks


In [None]:
def get_inference_fn(
    obs_size, act_size, normalize_obs, network_factory_kwargs, params,
    is_ppo=True,
):
    """Build a jitted, deterministic policy function from saved parameters.

    Args:
        obs_size: Observation size handed to the network factory.
        act_size: Action size handed to the network factory.
        normalize_obs: When truthy, observations are preprocessed with
            `running_statistics.normalize`; otherwise they pass through
            unchanged.
        network_factory_kwargs: Extra keyword arguments forwarded to the
            network factory, or None/empty for none.
        params: Policy parameters given to the policy factory.
        is_ppo: Use the PPO network builders when truthy, the SAC builders
            otherwise.

    Returns:
        The policy created with `deterministic=True`, wrapped in `jax.jit`.
    """
    if normalize_obs:
        preprocess = running_statistics.normalize
    else:
        # Pass-through with the same two-argument call shape as the normalizer.
        preprocess = lambda obs, _stats: obs

    factory_kwargs = network_factory_kwargs or {}

    if is_ppo:
        networks = brax_ppo_networks.make_ppo_networks(
            obs_size,
            act_size,
            preprocess_observations_fn=preprocess,
            **factory_kwargs,
        )
        policy_factory = brax_ppo_networks.make_inference_fn(networks)
    else:
        networks = brax_sac_networks.make_sac_networks(
            obs_size,
            act_size,
            preprocess_observations_fn=preprocess,
            **factory_kwargs,
        )
        policy_factory = brax_sac_networks.make_inference_fn(networks)

    return jax.jit(policy_factory(params, deterministic=True))

In [None]:
# Load every saved policy under ../data/pkl and build a jitted inference
# function for it, keyed by environment name.
policy_params = {}
inference_fns = {}
for f in epath.Path('../data/pkl').glob('*.pkl'):
    # File names are expected to look like "<env>_<algo>[_<anything>].pkl".
    # Split the stem (name without the final suffix) so the extension never
    # ends up in the algo token. str.rstrip('.pkl') strips a *set of
    # characters*, not a suffix, and would mangle any token ending in
    # '.', 'p', 'k' or 'l'.
    env_name, algo, *_ = f.stem.split('_')
    policy_params[env_name] = brax_io_model.load_params(f)

    # Pick the training config that matches the environment family. Only the
    # dm_control_suite branch distinguishes between PPO and SAC configs.
    if env_name in registry.locomotion.ALL_ENVS:
        algo_params = locomotion_params.brax_ppo_config(env_name)
    elif env_name in registry.dm_control_suite.ALL_ENVS:
        if algo == 'ppo':
            algo_params = dm_control_suite_params.brax_ppo_config(env_name)
        elif algo == 'sac':
            algo_params = dm_control_suite_params.brax_sac_config(env_name)
        else:
            # Without this guard `algo_params` would be unbound on the first
            # file, or silently reused from the previous file afterwards.
            raise ValueError(
                f"Unsupported algorithm {algo!r} in checkpoint {f.name!r}; "
                "expected 'ppo' or 'sac'."
            )
    elif env_name in registry.manipulation.ALL_ENVS:
        algo_params = manipulation_params.brax_ppo_config(env_name)
    else:
        raise AssertionError(
            f"{env_name!r} (from {f.name!r}) is not a registered locomotion, "
            "dm_control_suite or manipulation environment."
        )

    # Not every config defines custom network settings.
    network_factory_kwargs = None
    if hasattr(algo_params, 'network_factory'):
        network_factory_kwargs = dict(algo_params.network_factory)

    env = registry.load(env_name)
    inference_fns[env_name] = get_inference_fn(
        env.observation_size, env.action_size,
        normalize_obs=algo_params.normalize_observations,
        network_factory_kwargs=network_factory_kwargs,
        params=policy_params[env_name],
        is_ppo=(algo == 'ppo')
    )
    print(env_name, ' Done!')


In [None]:
# Which task to roll out, and for how many episodes.
N_EPISODES = 1
env_name = 'CheetahRun'

# Default config and an environment instance for the chosen task.
cfg = registry.get_default_config(env_name)
env = registry.load(env_name)

# jax.jit-wrapped reset/step used by the rollout loop.
jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)

In [None]:
# Roll the selected policy out for N_EPISODES episodes, collecting every
# visited state (including each episode's initial state) in `rollout`.
rng = jax.random.PRNGKey(0)
policy = inference_fns[env_name]
rollout = []

for _ in range(N_EPISODES):
    # Each episode starts from a fresh reset with its own subkey; a key that
    # has been consumed is never split or used again.
    rng, reset_rng = jax.random.split(rng)
    state = jit_reset(reset_rng)
    rollout.append(state)

    step, done = 0, False
    while not done and step < cfg.episode_length:
        if step % 100 == 0:
            print(step)

        rng, act_rng = jax.random.split(rng)
        action = policy(state.obs, act_rng)[0]
        state = jit_step(state, action)
        rollout.append(state)

        step += cfg.action_repeat
        # Read `done` from the state just produced, so the loop stops right
        # after termination instead of stepping once more from a done state.
        done = bool(state.done)


In [None]:
# Playback rate: one rendered frame per environment step of length env.dt.
fps = 1.0 / env.dt
print(f"fps: {fps}")

# Visualisation options: toggle geom groups 2 / 3 and turn on the
# contact-point and perturbation-force overlays.
scene_option = mujoco.MjvOption()
for group, visible in ((2, True), (3, False)):
    scene_option.geomgroup[group] = visible
for flag in (
    mujoco.mjtVisFlag.mjVIS_CONTACTPOINT,
    mujoco.mjtVisFlag.mjVIS_PERTFORCE,
):
    scene_option.flags[flag] = True

frames = env.render(
    rollout,
    scene_option=scene_option,
    width=640,
    height=480,
    # Optional extras, currently disabled:
    # camera="track",
    # modify_scene_fns=mod_fns,
)
media.show_video(frames, fps=fps, loop=False)