<a href="https://colab.research.google.com/github/perrin-isir/xpag-tutorials/blob/main/train_mujoco.ipynb"> <img align="left" src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab" title="Open in Google Colaboratory"></a>
<a id="raw-url" href="https://raw.githubusercontent.com/perrin-isir/xpag-tutorials/main/train_mujoco.ipynb" download> <img align="left" src="https://img.shields.io/badge/Github-Download%20(Right%20click%20%2B%20Save%20link%20as...)-blue" alt="Download (Right click + Save link as)" title="Download Notebook"></a>

**IMPORTANT:** This notebook runs best on a GPU runtime (the first code cell asserts that JAX uses a GPU backend).  
In Colab: from the Colab menu, choose Runtime > Change Runtime Type, then select **'GPU'**.

In [None]:
import os
import sys
import PIL
import importlib
from packaging import version
from IPython import get_ipython
from IPython.display import clear_output, HTML, display
import inspect
import jax
# verify that JAX uses the GPU backend:
assert jax.default_backend() == 'gpu', "JAX is not using a GPU backend"

if 'google.colab' in str(get_ipython()) and version.parse(PIL.__version__) < version.parse("9.0.1"):
    # upgrade Pillow
    !pip install --upgrade Pillow
    # to avoid restarting the runtime, reload the PIL submodules that are already imported
    # (iterate over a snapshot of the namespace, since reloading may modify it):
    for obj in list(vars(PIL).values()):
        if inspect.ismodule(obj):
            importlib.reload(obj)
    clear_output()

In [None]:
# mujoco install:
if 'google.colab' in str(get_ipython()):
    !pip install mujoco

clear_output()

Next, we import xpag (installing it first when running in Colab):

In [None]:
try:
    import xpag
except ImportError:
    if 'google.colab' in str(get_ipython()):
        !pip install git+https://github.com/perrin-isir/xpag.git
        clear_output()
        import xpag
    else:
        sys.exit("ImportError: xpag not found.")
from xpag.wrappers import gym_vec_env
from xpag.buffers import DefaultBuffer
from xpag.samplers import DefaultSampler
from xpag.setters import DefaultSetter
from xpag.agents import SAC
from xpag.tools import learn
from xpag.tools import mujoco_notebook_replay

# filter out a warning from tensorflow_probability, a library used by the SAC agent in xpag:
# "WARNING:root:The use of `check_types` is deprecated and does not have any effect."
import logging
logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
    def filter(self, record):
        return "check_types" not in record.getMessage()


logger.addFilter(CheckTypesFilter())
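The filter above works because a `logging.Filter` attached to a logger can veto records before they reach any handler: returning `False` drops the record. The minimal, self-contained check below uses a throwaway logger named `check_types_demo` (a hypothetical name, chosen so the demo leaves the root logger untouched) and a small list-collecting handler to show that only the `check_types` message is dropped.

```python
import logging


class CheckTypesFilter(logging.Filter):
    """Drop any log record whose message mentions `check_types`."""

    def filter(self, record: logging.LogRecord) -> bool:
        return "check_types" not in record.getMessage()


class ListHandler(logging.Handler):
    """Collect the messages that reach this handler, so they can be inspected."""

    def __init__(self) -> None:
        super().__init__()
        self.messages: list[str] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(record.getMessage())


# Hypothetical demo logger, isolated from the root logger.
demo_logger = logging.getLogger("check_types_demo")
demo_logger.propagate = False
demo_logger.setLevel(logging.WARNING)
handler = ListHandler()
demo_logger.addHandler(handler)
demo_logger.addFilter(CheckTypesFilter())

demo_logger.warning("The use of `check_types` is deprecated and does not have any effect.")
demo_logger.warning("Some other warning")

print(handler.messages)  # ['Some other warning']
```

One caveat, from the Python `logging` documentation: filters attached to a logger are not applied to records propagated up from descendant loggers. Attaching `CheckTypesFilter` to the root logger therefore only catches messages logged directly on the root logger, which is the case here (the warning is prefixed with `WARNING:root:`).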

We first define the training and eval environments:

In [None]:
num_envs = 10  # the number of rollouts in parallel during training
env, eval_env, env_info = gym_vec_env('HalfCheetah-v4', num_envs)
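`env_info` summarizes the environment for the agent. For goal-conditioned environments (`is_goalenv` true), the agent's input size in the next cell is the observation size plus the desired-goal size, which corresponds to feeding the agent the concatenation of the two vectors. The sketch below makes that explicit; `agent_input_dim` is a hypothetical helper, the goal-conditioned `env_info` uses made-up sizes, and the observation-then-goal ordering is an assumption, not taken from xpag's source.

```python
import numpy as np


def agent_input_dim(env_info: dict) -> int:
    """Hypothetical helper mirroring the first argument passed to SAC(...)."""
    if env_info["is_goalenv"]:
        # goal-conditioned env: the agent sees [observation, desired_goal]
        return env_info["observation_dim"] + env_info["desired_goal_dim"]
    return env_info["observation_dim"]


# HalfCheetah-v4 is not a goal env; its observations have 17 dimensions.
cheetah_info = {"is_goalenv": False, "observation_dim": 17}
# Made-up goal-conditioned env: 10-dim observation, 3-dim desired goal.
goal_info = {"is_goalenv": True, "observation_dim": 10, "desired_goal_dim": 3}

observation = np.zeros(goal_info["observation_dim"])
desired_goal = np.ones(goal_info["desired_goal_dim"])
agent_input = np.concatenate([observation, desired_goal])  # assumed ordering

print(agent_input_dim(cheetah_info))                     # 17
print(agent_input.shape[0], agent_input_dim(goal_info))  # 13 13
```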

We then define the agent, the sampler, the buffer and the setter (the DefaultSetter does nothing):

In [None]:
agent = SAC(
    env_info['observation_dim'] if not env_info['is_goalenv']
    else env_info['observation_dim'] + env_info['desired_goal_dim'],
    env_info['action_dim'],
    {
        "actor_lr": 3e-3,
        "critic_lr": 3e-3,
        "tau": 5e-2,
        "seed": 0
    }
)
sampler = DefaultSampler()
buffer = DefaultBuffer(
    buffer_size=1_000_000,
    sampler=sampler
)
setter = DefaultSetter()
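In SAC, `tau` conventionally sets the rate of the soft (Polyak) update that makes the target critic slowly track the online critic. Assuming xpag's SAC follows this standard formulation (the sketch below is not its code), `"tau": 5e-2` moves the target parameters 5% of the way toward the online parameters at each update. The dictionaries of NumPy arrays are hypothetical stand-ins for network parameters.

```python
import numpy as np


def soft_update(target_params: dict, online_params: dict, tau: float) -> dict:
    """Polyak averaging: target <- tau * online + (1 - tau) * target."""
    return {
        name: tau * online_params[name] + (1.0 - tau) * target_params[name]
        for name in target_params
    }


tau = 5e-2  # same value as "tau" in the SAC config above
online = {"w": np.ones(3)}
target = {"w": np.zeros(3)}

for _ in range(20):
    target = soft_update(target, online, tau)

# With fixed online weights of 1 and a target starting at 0,
# the target after n updates equals 1 - (1 - tau) ** n.
print(target["w"])
```

A larger `tau` makes the target follow the online critic faster but gives less stable bootstrapping targets; a smaller one smooths them more at the cost of slower propagation of improvements.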

We set the hyperparameters:

In [None]:
batch_size = 256
gd_steps_per_step = 1
start_training_after_x_steps = env_info['max_episode_steps'] * 10
max_steps = 10_000_000
evaluate_every_x_steps = 5_000
save_agent_every_x_steps = 100_000
save_dir = os.path.join(os.path.expanduser('~'), 'results', 'xpag', 'train_mujoco')
save_episode = True
plot_projection = None
seed = 0
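The schedule implied by these values can be checked with a little arithmetic. The sketch below assumes `env_info['max_episode_steps']` is 1000 (the episode limit registered for `HalfCheetah-v4`) and that all the `*_x_steps` counters use the same unit of environment steps as `max_steps`; exactly how `learn` counts steps across the `num_envs` parallel rollouts is not spelled out here.

```python
max_episode_steps = 1000  # assumed value of env_info['max_episode_steps'] for HalfCheetah-v4
max_steps = 10_000_000
evaluate_every_x_steps = 5_000
save_agent_every_x_steps = 100_000

# steps collected before training starts (start_training_after_x_steps)
warmup_steps = max_episode_steps * 10
# number of evaluations and of agent checkpoints over a full run
n_evaluations = max_steps // evaluate_every_x_steps
n_checkpoints = max_steps // save_agent_every_x_steps

print(warmup_steps, n_evaluations, n_checkpoints)  # 10000 2000 100
```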

Finally, we run the training loop.  
In Colab, every 5k steps should take just over 11 seconds, and interesting results (reward > 7000) usually appear a little before 100k steps. Since `max_steps` is set to 10 million, stop the cell manually once you are satisfied with the results. Note that even with fixed seeds, JAX/XLA is not deterministic on GPU, so results vary from run to run.
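The quoted throughput also explains why training is stopped by hand rather than run to `max_steps`. Taking 11 seconds per 5k steps as a lower bound (the measured figure is "just over" that), the sketch below estimates wall-clock time; `training_time_s` is a hypothetical helper.

```python
SECONDS_PER_5K_STEPS = 11.0  # "just over 11 seconds" per 5k steps in Colab (lower bound)


def training_time_s(steps: int) -> float:
    """Rough lower bound on wall-clock training time, in seconds."""
    return steps / 5_000 * SECONDS_PER_5K_STEPS


print(training_time_s(100_000) / 60)       # about 3.7 minutes to reach ~100k steps
print(training_time_s(10_000_000) / 3600)  # about 6.1 hours for the full max_steps
```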

In [None]:
learn(
    env,
    eval_env,
    env_info,
    agent,
    buffer,
    setter,
    batch_size=batch_size,
    gd_steps_per_step=gd_steps_per_step,
    start_training_after_x_steps=start_training_after_x_steps,
    max_steps=max_steps,
    evaluate_every_x_steps=evaluate_every_x_steps,
    save_agent_every_x_steps=save_agent_every_x_steps,
    save_dir=save_dir,
    save_episode=save_episode,
    plot_projection=plot_projection,
    custom_eval_function=None,
    additional_step_keys=None,
    seed=seed
)

After manually stopping the training, we can replay the last evaluation episode.  
**Remark:** in Colab, the animation is very slow because the images are rendered on the server side. Instead, you can use the "Generate gif" button, which saves the episode as a gif (in Colab, use the file-explorer pane on the left to display or download the file), or the "Generate mp4" button, which saves the episode as an mp4 video.
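The "Generate gif" button handles this for you. For reference, the general technique is to write a list of RGB frames as an animated GIF with Pillow's `Image.save(..., save_all=True, append_images=...)`; the sketch below does that and is not xpag's implementation. `save_gif` is a hypothetical helper, and the solid-color frames stand in for rendered MuJoCo images.

```python
import os
import tempfile

import numpy as np
from PIL import Image


def save_gif(frames, path, fps=30):
    """Write a list of HxWx3 uint8 RGB arrays to an animated GIF."""
    images = [Image.fromarray(frame) for frame in frames]
    images[0].save(
        path,
        save_all=True,             # write every frame, not only the first one
        append_images=images[1:],  # frames after the first
        duration=int(1000 / fps),  # display time of each frame, in milliseconds
        loop=0,                    # loop forever
    )


# Five 64x64 frames, each filled with a different shade of red.
frames = [np.full((64, 64, 3), (40 * i, 0, 0), dtype=np.uint8) for i in range(5)]
path = os.path.join(tempfile.mkdtemp(), "episode.gif")
save_gif(frames, path)

with Image.open(path) as gif:
    print(gif.n_frames)  # 5
```

Recent Pillow versions merge consecutive identical frames when writing a GIF, so the saved file can have fewer frames than the input list if an episode contains static stretches; the frames above are all distinct.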

In [None]:
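if 'google.colab' in str(get_ipython()):
    # headless rendering on the Colab server, using the EGL backend
    os.environ['MUJOCO_GL'] = "egl"
# load the Font Awesome icon stylesheet used by the replay widget
display(HTML('''<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css"> '''))
mujoco_notebook_replay(save_dir)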