In [None]:
# Reload imported modules before each cell executes (`%autoreload 2` reloads
# all modules), so edits to their source are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Runtime configuration for JAX/XLA and MuJoCo rendering. This cell sits before
# the `import jax` cell, presumably so XLA sees the flags when it initializes.
_TRITON_GEMM_FLAG = "--xla_gpu_triton_gemm_any=True"

xla_flags = os.environ.get("XLA_FLAGS", "")
# Append the flag only if it is not already present, so re-running this cell is
# idempotent (a plain `+=` appended a duplicate copy on every re-run).
if _TRITON_GEMM_FLAG not in xla_flags.split():
    xla_flags = f"{xla_flags} {_TRITON_GEMM_FLAG}".strip()
os.environ["XLA_FLAGS"] = xla_flags
# Disable JAX's up-front GPU memory preallocation; allocate as needed instead.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Select MuJoCo's EGL rendering backend (headless OpenGL).
os.environ["MUJOCO_GL"] = "egl"

In [None]:
import functools
import json
from datetime import datetime

import re
import pandas as pd
import jax
import jax.numpy as jp
import matplotlib.pyplot as plt
import mediapy as media
import mujoco
import numpy as np
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from etils import epath
from flax.training import orbax_utils
from IPython.display import clear_output, display
from orbax import checkpoint as ocp

from mujoco_playground import registry, wrapper
from mujoco_playground.config import locomotion_params, manipulation_params

In [None]:
def clean_string_for_filename(s):
  """Turn an arbitrary label into a filesystem-safe filename stem.

  Leading/trailing whitespace is trimmed, each interior space becomes "_",
  and every remaining character that is not a word character (letter, digit,
  underscore) or "-" is dropped. E.g. " 1x GPU A/B " -> "1x_GPU_AB".
  """
  underscored = "_".join(s.strip().split(" "))
  return re.sub(r'[^\w_-]', '', underscored)


# Long-format results accumulator (dict of lists). Each progress callback in
# run_for_env appends one entry to every list, keeping the columns aligned.
df = {
    column: []
    for column in (
        'env',
        'seed',
        'training/walltime',
        'step',
        'eval/episode_reward',
    )
}

In [None]:
def _ppo_params_for_env(env_name):
    """Return the Brax PPO config for `env_name`.

    Envs listed in `registry.locomotion.ALL_ENVS` use the locomotion defaults;
    every other env uses the manipulation defaults.
    """
    if env_name in registry.locomotion.ALL_ENVS:
        return locomotion_params.brax_ppo_config(env_name)
    return manipulation_params.brax_ppo_config(env_name)


def run_for_env(env_name, n_seeds=3):
    """Train PPO on `env_name` for several seeds and log eval curves into `df`.

    For each seed in ``range(n_seeds)`` this trains Brax PPO with the env's
    default config, live-plots eval reward against environment steps, and
    appends one row per progress callback to the module-level `df`. Each row
    holds the env name, seed, seconds since this seed's run started, step
    count and eval episode reward.

    Args:
      env_name: Environment name understood by `registry`,
        e.g. "Go1JoystickFlatTerrain".
      n_seeds: Number of seeds to train (seeds 0..n_seeds-1). Defaults to 3,
        the previously hard-coded value.

    Note:
      Mutates the global `df`, calling `.append` on each of its columns. It
      therefore expects `df` to still be the dict of lists, not the DataFrame
      it is converted to later in the notebook.
    """
    env_cfg = registry.get_default_config(env_name)
    try:
        randomizer = registry.get_domain_randomizer(env_name)
    except Exception:
        # Presumably raised when no domain randomizer is registered for this
        # env; train without one. (A bare `except:` would also have swallowed
        # KeyboardInterrupt/SystemExit.)
        randomizer = None

    ppo_params = _ppo_params_for_env(env_name)
    # network_factory is passed to ppo.train separately (wrapped below), so
    # drop it from the kwargs; otherwise it would be passed twice.
    training_params = dict(ppo_params)
    del training_params["network_factory"]
    # Loop-invariant: identical for every seed.
    network_factory = functools.partial(
        ppo_networks.make_ppo_networks,
        **ppo_params.network_factory,
    )

    for seed in range(n_seeds):
        print(f'Running seed: {seed}')
        training_params['seed'] = seed

        x_data, y_data, y_dataerr = [], [], []
        times = [datetime.now()]
        # A dedicated figure per seed. Drawing on the shared current figure
        # stacked a new errorbar on every callback and left earlier seeds'
        # curves visible.
        fig, ax = plt.subplots()

        def progress(num_steps, metrics):
            clear_output(wait=True)

            times.append(datetime.now())
            x_data.append(num_steps)
            y_data.append(metrics["eval/episode_reward"])
            y_dataerr.append(metrics["eval/episode_reward_std"])

            # Redraw from scratch so artists don't accumulate across callbacks.
            ax.cla()
            ax.set_xlim([0, training_params["num_timesteps"] * 1.25])
            ax.set_xlabel("# environment steps")
            ax.set_ylabel("reward per episode")
            ax.set_title(f"y={y_data[-1]:.3f}")
            ax.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")

            df['env'].append(env_name)
            df['seed'].append(seed)
            df['training/walltime'].append((times[-1] - times[0]).total_seconds())
            df['step'].append(num_steps)
            df['eval/episode_reward'].append(metrics["eval/episode_reward"])

            display(fig)

        train_fn = functools.partial(
            ppo.train,
            **training_params,
            network_factory=network_factory,
            progress_fn=progress,
            wrap_env_fn=wrapper.wrap_for_brax_training,
            randomization_fn=randomizer,
        )

        env = registry.load(env_name, config=env_cfg)
        eval_env = registry.load(env_name, config=env_cfg)
        try:
            # The returned (make_inference_fn, params, ...) tuple is unused here.
            train_fn(environment=env, eval_env=eval_env)
        finally:
            # The figure was already shown via display(). Closing it frees it
            # and avoids an extra render when the cell finishes.
            plt.close(fig)

        # times[1] only exists if the progress callback ran at least once.
        if len(times) > 1:
            print(f"time to jit: {times[1] - times[0]}")
            print(f"time to train: {times[-1] - times[1]}")

In [None]:
# Train and record eval curves for Go1JoystickFlatTerrain across all seeds.
run_for_env(env_name="Go1JoystickFlatTerrain")

In [None]:
# Train and record eval curves for LeapCubeReorient across all seeds.
run_for_env(env_name="LeapCubeReorient")

In [None]:
# Train and record eval curves for G1Joystick across all seeds.
run_for_env(env_name="G1Joystick")

In [None]:
# Summarize the accelerator setup as "<count>x <device kind>". It is used to tag
# every results row and to name the output CSV.
_devices = jax.devices()
num_devices, device_kind = len(_devices), _devices[0].device_kind
device_topo = f'{num_devices}x {device_kind}'

In [None]:
# Convert the accumulated dict of lists into a DataFrame and tag every row with
# the hardware it ran on. `drop(..., errors="ignore")` replaces a bare `del`,
# which raised KeyError on a fresh kernel (the dict never has a 'device_topo'
# key) and only worked when re-running on an already-converted DataFrame.
# The cell is now idempotent for both cases.
df = pd.DataFrame(df).drop(columns="device_topo", errors="ignore")
df['device_topo'] = device_topo

In [None]:
# One CSV per hardware setup, named after device_topo (spaces -> "_", other
# non-word characters stripped). Create the directory first, since to_csv
# fails if ../data does not exist.
os.makedirs('../data', exist_ok=True)
df.to_csv('../data/' + clean_string_for_filename(device_topo) + '.csv')

In [None]:
# NOTE(review): dead cell. If uncommented, it converts the DataFrame back into a
# dict of lists (column -> values), the shape run_for_env's progress callback
# appends to, presumably to allow more training runs after conversion.
# df = {k: list(v.values()) for k, v in df.to_dict().items()}

In [None]:
# Show the full results table: one row per progress callback, tagged with env,
# seed and device_topo.
df