<a href="https://colab.research.google.com/github/pwspen/pterobot/blob/master/Pterobot_Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title Installs
# Based on: https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb#scrollTo=d-UhypudApBy

!pip install mujoco
!pip install mujoco_mjx
!pip install brax
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy

Collecting mujoco
  Downloading mujoco-3.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.3/5.3 MB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
Collecting glfw (from mujoco)
  Downloading glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-manylinux2014_x86_64.whl (211 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.8/211.8 kB[0m [31m18.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: glfw, mujoco
Successfully installed glfw-2.7.0 mujoco-3.1.2
Collecting mujoco_mjx
  Downloading mujoco_mjx-3.1.2-py3-none-any.whl (10.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.6/10.6 MB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
Collecting trimesh (from mujoco_mjx)
  Downloading trimesh-4.1.6-py3-none-any.whl (690 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m690.1/690.1 kB[0m [31m29.8 MB/s[0m eta [36

In [2]:
#@title Mujoco installation check + config for GPU

from google.colab import files

import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')

Setting environment variable to use GPU rendering:
env: MUJOCO_GL=egl
Checking that the installation succeeded:
Installation successful.


In [3]:
#@title Imports

# ackages for plotting and creating graphics
import time
import itertools
import json
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model

from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx

import os
import shutil

from IPython.display import display, Javascript
def clear_output_in_colab():
    display(Javascript('''
    google.colab.output.clear()
    '''))

In [4]:
#@title Grab pterobot repo
exists = os.path.isdir('pterobot')
if exists:
  shutil.rmtree('pterobot') # Dangerous function - be careful with local use!

!git clone https://github.com/pwspen/pterobot

Cloning into 'pterobot'...
remote: Enumerating objects: 95, done.[K
remote: Counting objects: 100% (95/95), done.[K
remote: Compressing objects: 100% (71/71), done.[K
remote: Total 95 (delta 30), reused 86 (delta 21), pack-reused 0[K
Receiving objects: 100% (95/95), 1.31 MiB | 6.55 MiB/s, done.
Resolving deltas: 100% (30/30), done.


In [7]:
#@title Get pterobot env + most recent XML

try:
  shutil.copytree('pterobot/load','/load') # Lets class Pterobot have single xml pointer that works on cloud and local
except FileExistsError:
  pass

from pterobot.pterobot_brax_env import Pterobot
from pterobot.train_json_viz import create_plotly_figure

envs.register_environment('pterobot', Pterobot)

xml_file = 'pterobot/load/pterobot_v0.xml'


  and should_run_async(code)


In [None]:
#@title Environment Test (optional)

# instantiate the environment
env_name = 'pterobot'
env = envs.get_environment(env_name, xml_file=xml_file)

print(f'action size: {env.action_size}')
print(f'obs size: {env.observation_size}')
# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

# # initialize the state
# state = jit_reset(jax.random.PRNGKey(0))
# rollout = [state.pipeline_state]

# # grab a trajectory
# for i in range(10):
#   ctrl = -0.1 * jp.ones(env.sys.nu)
#   state = jit_step(state, ctrl)
#   rollout.append(state.pipeline_state)

# media.show_video(env.render(rollout, camera='side'), fps=1.0 / env.dt)

action size: 17
obs size: 45


In [None]:
#@title Train

# instantiate the environment
env_name = 'pterobot'
env = envs.get_environment(env_name, xml_file=xml_file)

train_fn = functools.partial(
    ppo.train, num_timesteps=5_000_000, num_evals=10, reward_scaling=0.1,
    episode_length=1000, normalize_observations=True, action_repeat=1,
    unroll_length=10, num_minibatches=32, num_updates_per_batch=8,
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=2048,
    batch_size=1024, seed=0)


x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]

max_y, min_y = 100, -100
metric_dict = {}
metrics_filename = 'train.json'
def progress(num_steps, metrics):
  # print(num_steps, metrics)
  # print(type(num_steps), type(metrics))
  # for name, val in metrics.items():
  #   print(f'{name}: {type(val)}')
  # print(f'timesteps: {num_steps}')
  metric_dict[num_steps] = {str(name): float(val) for name, val in metrics.items()}
  with open(metrics_filename, 'w') as f:
    json.dump(metric_dict, f, indent=4)

  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics['eval/episode_reward'])
  ydataerr.append(metrics['eval/episode_reward_std'])

  clear_output_in_colab() # JS function
  create_plotly_figure(metrics_filename)

  # Old matplotlib
  # plt.xlim([0, max(x_data)])
  # plt.ylim([min_y, max_y])
  # plt.xlabel('# environment steps')
  # plt.ylabel('reward per episode')
  # plt.title(f'y={y_data[-1]:.3f}')
  # plt.errorbar(
  #     x_data, y_data, yerr=ydataerr)
  # plt.show()

make_inference_fn, params, _= train_fn(environment=env, progress_fn=progress)

print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

model_path = 'policyx.zip'
model.save_params(model_path, params)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
#@title Load and Visualize
# print(metrics)

params = model.load_params(model_path)

inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)

eval_env = envs.get_environment(env_name, xml_file=xml_file)

jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

# initialize the state
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
rollout = [state.pipeline_state]

# grab a trajectory
n_steps = 5
render_every = 1

for i in range(n_steps):
  print(f'step: {i}')
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

  if state.done:
    break

media.write_video(env.render(rollout[::render_every], camera='side'), fps=1.0 / env.dt / render_every)

ValueError: The camera "side" does not exist.

In [None]:
1!pip list

  and should_run_async(code)


Package                          Version
-------------------------------- ---------------------
absl-py                          1.4.0
aiohttp                          3.9.3
aiosignal                        1.3.1
alabaster                        0.7.16
albumentations                   1.3.1
altair                           4.2.2
annotated-types                  0.6.0
anyio                            3.7.1
appdirs                          1.4.4
argon2-cffi                      23.1.0
argon2-cffi-bindings             21.2.0
array-record                     0.5.0
arviz                            0.15.1
astropy                          5.3.4
astunparse                       1.6.3
async-timeout                    4.0.3
atpublic                         4.0
attrs                            23.2.0
audioread                        3.0.1
autograd                         1.6.2
Babel                            2.14.0
backcall                         0.2.0
beautifulsoup4                   4.12.3
bi