<a href="https://colab.research.google.com/github/perrin-isir/xpag-tutorials/blob/main/train_brax.ipynb"> <img align="left" src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab" title="Open in Google Colaboratory"></a>
<a id="raw-url" href="https://raw.githubusercontent.com/perrin-isir/xpag-tutorials/main/train_brax.ipynb" download> <img align="left" src="https://img.shields.io/badge/Github-Download%20(Right%20click%20%2B%20Save%20link%20as...)-blue" alt="Download (Right click + Save link as)" title="Download Notebook"></a>

**IMPORTANT:** This notebook expects a GPU runtime: a later cell asserts that JAX is running on the GPU backend.  
In Colab, choose Runtime > Change runtime type from the menu, then select **'GPU'**.

In [None]:
# pin package versions on Colab, then restart the runtime so that they take effect
import os
from packaging import version
from IPython import get_ipython
import jax
restart_runtime = False
if 'google.colab' in str(get_ipython()) and version.parse(jax.__version__) != version.parse("0.3.1"):
    # due to an issue with recent flax versions (PReLU import), we use old versions of jax, jaxlib and flax
    !pip install -I flax==0.3.6
    !pip install --user --no-cache-dir -I jax[cuda11_cudnn805]==0.3.1 -f https://storage.googleapis.com/jax-releases/jax_releases.html
    !pip install -I jaxlib==0.1.75
    restart_runtime = True
if restart_runtime:
    # kill the current kernel process (signal 9) to force a runtime restart
    os.kill(os.getpid(), 9)

In [None]:
import os
from IPython.display import clear_output
import jax
import flax
# verify that jax uses the GPU backend:
assert jax.lib.xla_bridge.get_backend().platform == 'gpu', "jax is not using a GPU backend"

Brax and xpag imports:

In [None]:
# brax:
try:
    import brax
except ImportError:
    if 'google.colab' in str(get_ipython()):
        !pip install git+https://github.com/google/brax.git@v0.0.10
        clear_output()  # clear the pip log only; the message below must stay visible
    else:
        print("brax not found.")
    import brax
# xpag:
try:
    import xpag
except ImportError:
    if 'google.colab' in str(get_ipython()):
        !pip install git+https://github.com/perrin-isir/xpag.git
        clear_output()  # clear the pip log only; the message below must stay visible
    else:
        print("xpag not found.")
    import xpag
from xpag.wrappers import brax_vec_env
from xpag.buffers import DefaultEpisodicBuffer
from xpag.samplers import DefaultEpisodicSampler, HER
from xpag.goalsetters import DefaultGoalSetter
from xpag.agents import TD3
from xpag.tools import learn, brax_notebook_replay

We first define the training and evaluation environments:

In [None]:
num_envs = 50  # the number of rollouts in parallel during training
env, eval_env, env_info = brax_vec_env('walker2d', num_envs)

We then define the agent, the sampler, the buffer and the goal-setter:

In [None]:
agent = TD3(
    env_info['observation_dim'] if not env_info['is_goalenv']
    else env_info['observation_dim'] + env_info['desired_goal_dim'],
    env_info['action_dim'],
    {}
)
sampler = DefaultEpisodicSampler() if not env_info['is_goalenv'] else HER(env.compute_reward)
buffer = DefaultEpisodicBuffer(
    max_episode_steps=env_info['max_episode_steps'],
    buffer_size=1_000_000,
    sampler=sampler
)
goalsetter = DefaultGoalSetter()
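
The first argument passed to `TD3` above depends on whether the environment is goal-based. The sketch below isolates that rule on two hypothetical `env_info` dictionaries. The numeric values and the helper name `agent_input_dim` are assumptions made for illustration; they are not read from brax or xpag. Summing the two dimensions presumably reflects that the observation and the desired goal are fed to the agent together.

```python
# Hypothetical env_info dictionaries (illustrative values only).
plain_env_info = {"is_goalenv": False, "observation_dim": 17, "action_dim": 6}
goal_env_info = {
    "is_goalenv": True,
    "observation_dim": 10,
    "desired_goal_dim": 3,
    "action_dim": 4,
}


def agent_input_dim(info):
    """Return the input size given to the agent (hypothetical helper)."""
    if info["is_goalenv"]:
        # Goal-based case: observation size plus desired-goal size.
        return info["observation_dim"] + info["desired_goal_dim"]
    # Otherwise the observation size is used on its own.
    return info["observation_dim"]


print(agent_input_dim(plain_env_info))  # 17
print(agent_input_dim(goal_env_info))   # 13
```

The same flag also selects the sampler in the cell above: `DefaultEpisodicSampler` when `is_goalenv` is false, and `HER` otherwise.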

We set the hyperparameters:

In [None]:
batch_size = 256
gd_steps_per_step = 1
start_training_after_x_steps = env_info['max_episode_steps'] * 10
max_steps = 10_000_000
evaluate_every_x_steps = 5_000
save_agent_every_x_steps = 100_000
save_dir = os.path.join(os.path.expanduser('~'), 'results', 'xpag', 'train_brax')
save_episode = True
plot_projection = None
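
To get a feel for these values, here is a rough sketch of the schedule they imply. It assumes `max_episode_steps` is 1000, which is a placeholder (the real value comes from `env_info`), and it simply divides step counts; how `learn` handles the boundaries exactly (for instance an evaluation at step 0) is not specified here.

```python
# Assumption: placeholder value; in the notebook it is env_info['max_episode_steps'].
max_episode_steps = 1000

start_training_after_x_steps = max_episode_steps * 10
max_steps = 10_000_000
evaluate_every_x_steps = 5_000
save_agent_every_x_steps = 100_000

# Approximate number of evaluations and agent saves over a full run.
num_evaluations = max_steps // evaluate_every_x_steps
num_saves = max_steps // save_agent_every_x_steps
# Share of the run spent collecting data before training starts.
warmup_fraction = start_training_after_x_steps / max_steps

print(start_training_after_x_steps)  # 10000
print(num_evaluations)               # 2000
print(num_saves)                     # 100
print(f"{warmup_fraction:.2%}")      # 0.10%
```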

Finally, we run the training loop.  
In Colab, every 5k steps should take just over 9 seconds. Interesting results (reward > 1000) usually start to appear after around 150k steps,
although the variance of evaluation rollouts is high.
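
These figures give a back-of-the-envelope estimate of wall-clock time. The sketch below assumes a constant 9 seconds per 5k steps, so the results are lower bounds ("just over 9 seconds"), and actual timings depend on the hardware Colab allocates. The helper name `estimated_seconds` is made up for this illustration.

```python
seconds_per_interval = 9          # lower bound taken from the text above
evaluate_every_x_steps = 5_000    # one timing interval
max_steps = 10_000_000
steps_to_interesting = 150_000    # "around 150k steps"


def estimated_seconds(steps):
    """Estimate the training time for a given number of steps (hypothetical helper)."""
    return steps / evaluate_every_x_steps * seconds_per_interval


print(estimated_seconds(steps_to_interesting) / 60)  # 4.5 (minutes)
print(estimated_seconds(max_steps) / 3600)           # 5.0 (hours)
```

Under these assumptions, the first interesting results would show up after roughly four and a half minutes, while running all `max_steps` would take about five hours.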

In [None]:
learn(
    env,
    eval_env,
    env_info,
    agent,
    buffer,
    goalsetter,
    batch_size,
    gd_steps_per_step,
    start_training_after_x_steps,
    max_steps,
    evaluate_every_x_steps,
    save_agent_every_x_steps,
    save_dir,
    save_episode,
    plot_projection,
)

After stopping the training, we can replay the last evaluation episode.

In [None]:
brax_notebook_replay(save_dir)