This Colab notebook requires a GPU runtime (a later cell asserts that JAX is using the GPU backend). From the Colab menu, choose **Runtime > Change runtime type**, then select **'GPU'**.

In [None]:
!apt-get install -y \
    libgl1-mesa-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libosmesa6-dev \
    software-properties-common

!apt-get install -y patchelf

In [None]:
!pip install gym
!pip install free-mujoco-py

In [None]:
from IPython.display import HTML, clear_output
# import jax:
import jax

In [None]:
# import brax
try:
  import brax
except ImportError:
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

In [None]:
# xpag install and import:
try:
    import xpag
except ImportError:
    !pip install git+https://github.com/perrin-isir/xpag.git
    clear_output()
    import xpag
from xpag.wrappers import gym_vec_env
from xpag.buffers import DefaultEpisodicBuffer
from xpag.samplers import DefaultEpisodicSampler, HER
from xpag.goalsetters import DefaultGoalSetter
from xpag.agents import SAC
from xpag.tools import learn

In [None]:
# Fail fast unless JAX's default backend is the GPU.
# `jax.default_backend()` is the public API for this; the previous
# `jax.lib.xla_bridge.get_backend()` path is internal and deprecated.
assert jax.default_backend() == 'gpu', (
    f"JAX default backend is {jax.default_backend()!r}, expected 'gpu'. "
    "Switch to a GPU runtime (Runtime > Change runtime type) and restart."
)

We first define the training and evaluation environments:

In [None]:
# Environment id to train on; change this to target another gym environment.
env_name = 'HalfCheetah-v3'
num_envs = 5  # the number of rollouts in parallel during training
# Returns the (vectorised) training env, a separate evaluation env, and a dict
# of environment metadata (dims, goal-env flag, episode length) used below.
env, eval_env, env_info = gym_vec_env(env_name, num_envs)

We then define the agent, the buffer (together with its sampler) and the goal-setter:

In [None]:
# SAC agent. Its input size is the observation size, extended by the
# desired-goal size when the environment is goal-conditioned.
agent = SAC(
    (env_info['observation_dim'] + env_info['desired_goal_dim'])
    if env_info['is_goalenv']
    else env_info['observation_dim'],
    env_info['action_dim'],
    {},  # agent parameters: empty dict, i.e. whatever defaults SAC applies
)

# Goal-conditioned envs sample with HER; otherwise use the default sampler.
sampler = HER() if env_info['is_goalenv'] else DefaultEpisodicSampler()

# Episodic buffer holding up to one million steps, sized per episode by the
# environment's maximum episode length.
buffer = DefaultEpisodicBuffer(
    max_episode_steps=env_info['max_episode_steps'],
    buffer_size=1_000_000,
    sampler=sampler,
)

goalsetter = DefaultGoalSetter()

We set the hyperparameters:

In [None]:
# Hyperparameters, all forwarded to `learn` in the next cell.

# Optimisation.
batch_size = 256
gd_steps_per_step = 1  # presumably: gradient steps per environment step

# Schedule (counted in environment steps).
# Wait for ten maximum-length episodes' worth of steps before training starts.
start_training_after_x_steps = 10 * env_info['max_episode_steps']
max_steps = 10_000_000
evaluate_every_x_steps = 5_000
save_agent_every_x_steps = 20_000

# Output options.
# NOTE(review): presumably `None` / `False` disable saving and plotting in
# xpag -- confirm against `xpag.tools.learn`.
save_dir = None
save_episode = False
plot_projection = None

Finally, we run the training loop:

In [None]:
# Run xpag's training loop with everything defined above.
# Arguments are passed positionally, so their order must match the
# signature of `xpag.tools.learn`.
learn(
    env,                           # training env (num_envs parallel rollouts)
    eval_env,                      # separate env used for evaluation
    env_info,                      # env metadata returned by gym_vec_env
    agent,                         # SAC agent
    buffer,                        # episodic buffer (with its sampler)
    goalsetter,                    # DefaultGoalSetter
    batch_size,
    gd_steps_per_step,
    start_training_after_x_steps,
    max_steps,
    evaluate_every_x_steps,
    save_agent_every_x_steps,
    save_dir,
    save_episode,
    plot_projection,
)