# FPO on MuJoCo Playground

We begin with imports and JAX compilation caching.

In [1]:
%load_ext autoreload
%autoreload 2
%env MUJOCO_GL=egl
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env CUDA_VISIBLE_DEVICES=0

env: MUJOCO_GL=egl
env: XLA_PYTHON_CLIENT_PREALLOCATE=false
env: CUDA_VISIBLE_DEVICES=0


In [2]:
# Use JAX with compilation cache.
import jax
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update(
    "jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir"
)

In [3]:
from flow_policy import fpo, rollouts

## Load configurations

In [4]:
from mujoco_playground import registry

# Load environment and default FPO config.
env_name = "CheetahRun"
env = registry.load(env_name)
env_cfg = registry.get_default_config(env_name)
config = fpo.FpoConfig()

## Initialize training state and environments

In [5]:
agent_state = fpo.FpoState.init(prng=jax.random.key(42), env=env, config=config)
rollout_state = rollouts.BatchedRolloutState.init(
    env,
    prng=jax.random.key(42),
    num_envs=config.num_envs,
)

## FPO training loop

In [6]:
import time
import numpy as onp

outer_iters = config.num_timesteps // config.iterations_per_env // config.num_envs
eval_iters = set(onp.linspace(0, outer_iters - 1, config.num_evals, dtype=int))
for i in range(outer_iters):
    # Rollout + inner training loop.
    rollout_state, transitions = rollout_state.rollout(
        agent_state,
        episode_length=config.episode_length,
        iterations_per_env=config.iterations_per_env,
    )
    agent_state, metrics = agent_state.training_step(transitions)

    # Print eval metrics.
    if i in eval_iters:
        eval_outputs = rollouts.eval_policy(
            agent_state,
            prng=jax.random.fold_in(agent_state.prng, i),
            num_envs=128,
            max_episode_length=config.episode_length,
        )
        s_np = {k: onp.array(v) for k, v in eval_outputs.scalar_metrics.items()}
        print(f"Eval metrics at FPO step {i}/{outer_iters}:")
        print(
            f"  Reward: {s_np['reward_mean']:.2f} +/- {s_np['reward_std']:.2f}"
        )

Eval metrics at FPO step 0/61:
  Reward: 21.68 +/- 4.52
Eval metrics at FPO step 6/61:
  Reward: 284.23 +/- 31.72
Eval metrics at FPO step 13/61:
  Reward: 581.11 +/- 16.62
Eval metrics at FPO step 20/61:
  Reward: 700.04 +/- 117.06
Eval metrics at FPO step 26/61:
  Reward: 753.22 +/- 100.79
Eval metrics at FPO step 33/61:
  Reward: 808.30 +/- 95.87
Eval metrics at FPO step 40/61:
  Reward: 869.48 +/- 60.77
Eval metrics at FPO step 46/61:
  Reward: 857.08 +/- 41.23
Eval metrics at FPO step 53/61:
  Reward: 784.28 +/- 132.16
Eval metrics at FPO step 60/61:
  Reward: 877.80 +/- 52.56


## Render rollouts from trained policy

In [7]:
import mediapy as media
from jax import numpy as jnp
import random
from tqdm.auto import tqdm

jit_act = jax.jit(type(agent_state).sample_action, static_argnums=(3,))
jit_step = jax.jit(env.step)
jit_reset = jax.jit(env.reset)

def render_and_show() -> None:
    rng = jax.random.key(random.randint(0, 10_000))
    
    rollout = []
    n_episodes = 1
    render_every = 1
    
    for _ in range(n_episodes):
        state = jit_reset(rng)
        rollout.append(state)
        for i in tqdm(range(100)):
            act_rng, rng = jax.random.split(rng)
            ctrl, _ = jit_act(agent_state, state.obs, act_rng, deterministic=False)
            state = jit_step(state, jnp.tanh(ctrl))
            rollout.append(state)
            
    frames = env.render(rollout[::render_every], height=480, width=640)
    media.show_video(frames, fps=1.0 / env.dt / render_every)
    media.write_video("renders/fpo_rollout.mp4", frames, fps=1.0 / env.dt / render_every)
    print("Wrote to renders/fpo_rollout.mp4")

In [8]:
render_and_show()

  0%|          | 0/100 [00:00<?, ?it/s]

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 101/101 [00:00<00:00, 576.37it/s]


0
This browser does not support the video tag.


Wrote to renders/fpo_rollout.mp4
