<a href="https://colab.research.google.com/github/perrin-isir/xpag-tutorials/blob/main/train_gmazes.ipynb"> <img align="left" src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab" title="Open in Google Colaboratory"></a>
<a id="raw-url" href="https://raw.githubusercontent.com/perrin-isir/xpag-tutorials/main/train_gmazes.ipynb" download> <img align="left" src="https://img.shields.io/badge/Github-Download%20(Right%20click%20%2B%20Save%20link%20as...)-blue" alt="Download (Right click + Save link as)" title="Download Notebook"></a>

**IMPORTANT:** This colab requires a GPU runtime (a later cell asserts that JAX is using the GPU backend).  
From the Colab menu, choose Runtime > Change Runtime Type, then select **'GPU'**.

In [None]:
# avoid errors with recent versions of matplotlib in colab 
if 'google.colab' in str(get_ipython()):
    !pip uninstall -y matplotlib
    !pip install matplotlib==3.1.3
    os.kill(os.getpid(), 9)

Imports:

In [None]:
import os
from ipywidgets import interact
from IPython.display import display, Image, clear_output
# gym-gmazes:
try:
    import gym_gmazes
except ImportError:
    !pip install git+https://github.com/perrin-isir/gym-gmazes.git
    clear_output()
    import gym_gmazes
# jax:
import jax
# flax:
try:
    import flax
except ImportError:
    !pip install git+https://github.com/google/flax.git@v0.3.6
    clear_output()
    import brax
# brax:
try:
    import brax
except ImportError:
    !pip install git+https://github.com/google/brax.git@main
    clear_output()
    import brax
# xpag:
try:
    import xpag
except ImportError:
    !pip install git+https://github.com/perrin-isir/xpag.git
    clear_output()
    import xpag
from xpag.wrappers import gym_vec_env
from xpag.buffers import DefaultEpisodicBuffer
from xpag.samplers import DefaultEpisodicSampler, HER
from xpag.goalsetters import DefaultGoalSetter
from xpag.agents import SAC
from xpag.tools import learn

In [None]:
# verifying GPU backend for jax (jax.default_backend() is the public API that
# returns the platform name of the default backend):
assert jax.default_backend() == 'gpu', (
    f"JAX default backend is '{jax.default_backend()}', expected 'gpu'; "
    "select a GPU runtime before running this notebook."
)

We first define the training and eval environments:

In [None]:
num_envs = 10  # the number of rollouts in parallel during training
# gym_vec_env returns three values: the training env, the evaluation env, and
# env_info, which is indexed by string keys in the cells below
# ('observation_dim', 'desired_goal_dim', 'action_dim', 'is_goalenv',
# 'max_episode_steps').
env, eval_env, env_info = gym_vec_env('GMazeGoalDubins-v0', num_envs)

We set the walls of the maze:

In [None]:
# A list containing a single wall, given as a pair of 2-element lists.
# NOTE(review): whether each pair holds segment endpoints or per-axis ranges
# is defined by gym-gmazes' set_walls — confirm there.
walls = [([0.0, 1.01], [0.0, -0.5])]
# The same walls are applied to both the training and the evaluation envs.
env.set_walls(walls)
eval_env.set_walls(walls)

We then define the agent, the buffer and the goal-setter:

In [None]:
# Input size of the agent: the observation dimension, plus the desired-goal
# dimension when the environment is a goal environment.
if env_info['is_goalenv']:
    agent_input_dim = env_info['observation_dim'] + env_info['desired_goal_dim']
else:
    agent_input_dim = env_info['observation_dim']
agent = SAC(agent_input_dim, env_info['action_dim'], {})

# Goal environments use the HER sampler (built from the env's compute_reward);
# other environments use the default episodic sampler.
if env_info['is_goalenv']:
    sampler = HER(env.compute_reward)
else:
    sampler = DefaultEpisodicSampler()

buffer = DefaultEpisodicBuffer(
    max_episode_steps=env_info['max_episode_steps'],
    buffer_size=1_000_000,
    sampler=sampler,
)
goalsetter = DefaultGoalSetter()

We set the hyperparameters:

In [None]:
# All of the values below are passed positionally to learn(...) in the
# training cell.
batch_size = 256
gd_steps_per_step = 1
# 10 times the maximum episode length reported by env_info
start_training_after_x_steps = env_info['max_episode_steps'] * 10
max_steps = 10_000_000
evaluate_every_x_steps = 5_000
save_agent_every_x_steps = 100_000
# resolves to <home directory>/results/xpag/train_gmazes
save_dir = os.path.join(os.path.expanduser('~'), 'results', 'xpag', 'train_gmazes')
save_episode = True
def plot_projection(x):
    """Return the first two components of ``x``.

    ``x`` can be any sliceable sequence; the result is the slice ``x[:2]``.
    This function is passed to ``learn`` as its last argument.
    """
    return x[:2]

Finally, we run the training loop:

In [None]:
# Run the xpag training loop. All arguments are passed positionally, so their
# order must match the signature of xpag.tools.learn.
# NOTE(review): with max_steps = 10_000_000 this cell is presumably meant to be
# interrupted manually (the next section starts "After stopping the training").
learn(
    env,
    eval_env,
    env_info,
    agent,
    buffer,
    goalsetter,
    batch_size,
    gd_steps_per_step,
    start_training_after_x_steps,
    max_steps,
    evaluate_every_x_steps,
    save_agent_every_x_steps,
    save_dir,
    save_episode,
    plot_projection,
)

After stopping the training, we can display the evaluation episodes.

In [None]:
plots_dir = os.path.join(save_dir, 'plots')


@interact
def show_images(file=sorted(os.listdir(plots_dir))):
    """Display the image ``file`` from the ``plots`` sub-directory of ``save_dir``.

    The default value of ``file`` is the sorted list of entries found in that
    directory when this cell is run; ``interact`` uses it to build the
    selection widget, and calls this function with the selected entry.
    """
    image_path = os.path.join(plots_dir, file)
    display(Image(filename=image_path))