![Brax banner](https://raw.githubusercontent.com/google/brax/main/docs/img/brax_logo.gif)

**A Colab runtime with GPU acceleration is required.** If you're using a CPU-only runtime, you can switch using the menu "Runtime > Change runtime type".

# Install MuJoCo, MJX, and Brax

In [None]:
!pip install mujoco
!pip install mujoco_mjx
!pip install brax
!pip install wandb

In [None]:
#@title Check if MuJoCo installation was successful
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select "Change runtime type".')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Change runtime type".')

print('Installation successful.')

In [None]:
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
!pip install tqdm
import mediapy as media
import matplotlib.pyplot as plt
from tqdm import tqdm

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

In [None]:
#@title Import MuJoCo, mjx, Brax, Wandb, & Visual Tools
from datetime import datetime
import functools
from IPython.display import HTML, clear_output
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union

from brax import base
from brax import actuator
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, MjxEnv, State, PipelineEnv
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model

from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
import wandb

# Training a Policy Description

Running large batch physics simulation is useful for training RL policies. Here we demonstrate training RL policies with MJX using the RL library from [Brax](https://github.com/google/brax).

Below, we implement the classic Humanoid environment using MJX and Brax. We inherit from the `MjxEnv` implementation in Brax so that we can step the physics with MJX while training with Brax RL implementations.

***Continuous vs Discrete:*** Remember that the action space is continuous: each action is a vector of real numbers within a range, not a choice from a discrete set of actions.


## Making a Full Environment

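An environment defines a few things:
1. the agent and its physical model (`__init__`)
2. how the simulation advances and how the agent is rewarded (`step`)
3. the feedback the environment gives about how the agent is doing, i.e. the observation (`_get_obs`)
4. how the environment is reset at the start of an episode (`reset`)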
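The four responsibilities above can be seen in a deliberately tiny example. The sketch below is a hypothetical one-dimensional point mass (`ToyPointMassEnv` is not part of Brax) written with plain NumPy; it only mirrors the shape of the Humanoid classes later in this notebook (`__init__`, `reset`, `step`, `_get_obs`). Unlike the Brax environments, it returns plain tuples instead of a `State` object and uses NumPy's random generator instead of `jax.random` keys.

```python
import numpy as np


class ToyPointMassEnv:
  """Hypothetical 1D point mass that should move to `target`.

  Not a Brax class: it only illustrates the environment interface.
  """

  def __init__(self, dt=0.015, ctrl_cost_weight=0.1, target=1.0):
    # 1. Define the agent and its parameters.
    self.dt = dt
    self.ctrl_cost_weight = ctrl_cost_weight
    self.target = target

  def reset(self, rng):
    # 4. Reset: small uniform noise on the position, zero velocity.
    pos = rng.uniform(-0.01, 0.01)
    state = np.array([pos, 0.0])
    return state, self._get_obs(state)

  def step(self, state, action):
    # 2. Advance the dynamics (semi-implicit Euler) and compute the reward.
    action = float(np.clip(action, -1.0, 1.0))
    pos, vel = state
    vel = vel + action * self.dt
    pos = pos + vel * self.dt
    new_state = np.array([pos, vel])
    reward = -abs(self.target - pos) - self.ctrl_cost_weight * action**2
    return new_state, self._get_obs(new_state), reward

  def _get_obs(self, state):
    # 3. Feedback to the agent: distance to the target and velocity.
    pos, vel = state
    return np.array([self.target - pos, vel])


env = ToyPointMassEnv()
state, obs = env.reset(np.random.default_rng(0))
for _ in range(10):
  state, obs, reward = env.step(state, action=1.0)
print(obs, reward)
```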

The 3D bipedal robot is designed to simulate a human. It has a torso (abdomen) with a pair of legs and a pair of arms. The legs and the arms each consist of two links, joined at the knees and elbows respectively. The goal of the environment is to walk forward as fast as possible without falling over.

### Action Space
The agent takes a 17-element action vector. The action space is continuous, with every element in `[-1, 1]`; each element is the control signal for the torque applied at one hinge joint.

  | Num | Action                                                                             | Control Min | Control Max | Name (in corresponding config)   | Joint | Unit         |
  |-----|------------------------------------------------------------------------------------|-------------|-------------|----------------------------------|-------|--------------|
  | 0   | Torque applied on the hinge in the y-coordinate of the abdomen                     | -1.0        | 1.0         | abdomen_yz                       | hinge | torque (N m) |
  | 1   | Torque applied on the hinge in the z-coordinate of the abdomen                     | -1.0        | 1.0         | abdomen_yz                       | hinge | torque (N m) |
  | 2   | Torque applied on the hinge in the x-coordinate of the abdomen                     | -1.0        | 1.0         | abdomen_x                        | hinge | torque (N m) |
  | 3   | Torque applied on the rotor between torso/abdomen and the right hip (x-coordinate) | -1.0        | 1.0         | right_hip_xyz (right_thigh)      | hinge | torque (N m) |
  | 4   | Torque applied on the rotor between torso/abdomen and the right hip (y-coordinate) | -1.0        | 1.0         | right_hip_xyz (right_thigh)      | hinge | torque (N m) |
  | 5   | Torque applied on the rotor between torso/abdomen and the right hip (z-coordinate) | -1.0        | 1.0         | right_hip_xyz (right_thigh)      | hinge | torque (N m) |
  | 6   | Torque applied on the rotor between the right hip/thigh and the right shin         | -1.0        | 1.0         | right_knee                       | hinge | torque (N m) |
  | 7   | Torque applied on the rotor between torso/abdomen and the left hip (x-coordinate)  | -1.0        | 1.0         | left_hip_xyz (left_thigh)        | hinge | torque (N m) |
  | 8   | Torque applied on the rotor between torso/abdomen and the left hip (y-coordinate)  | -1.0        | 1.0         | left_hip_xyz (left_thigh)        | hinge | torque (N m) |
  | 9   | Torque applied on the rotor between torso/abdomen and the left hip (z-coordinate)  | -1.0        | 1.0         | left_hip_xyz (left_thigh)        | hinge | torque (N m) |
  | 10  | Torque applied on the rotor between the left hip/thigh and the left shin           | -1.0        | 1.0         | left_knee                        | hinge | torque (N m) |
  | 11  | Torque applied on the rotor between the torso and right upper arm (coordinate -1)  | -1.0        | 1.0         | right_shoulder12                 | hinge | torque (N m) |
  | 12  | Torque applied on the rotor between the torso and right upper arm (coordinate -2)  | -1.0        | 1.0         | right_shoulder12                 | hinge | torque (N m) |
  | 13  | Torque applied on the rotor between the right upper arm and right lower arm        | -1.0        | 1.0         | right_elbow                      | hinge | torque (N m) |
  | 14  | Torque applied on the rotor between the torso and left upper arm (coordinate -1)   | -1.0        | 1.0         | left_shoulder12                  | hinge | torque (N m) |
  | 15  | Torque applied on the rotor between the torso and left upper arm (coordinate -2)   | -1.0        | 1.0         | left_shoulder12                  | hinge | torque (N m) |
  | 16  | Torque applied on the rotor between the left upper arm and left lower arm          | -1.0        | 1.0         | left_elbow                       | hinge | torque (N m) |
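To make the continuous action space concrete, the sketch below builds actions as 17 real numbers in `[-1, 1]`, matching the table above. The raw policy output is a hypothetical stand-in for a network's output, and clipping is just one simple way to keep it inside the control range; NumPy is used here, but `jp.clip` behaves the same way.

```python
import numpy as np

ACTION_SIZE = 17          # number of actuators in the table above
CTRL_MIN, CTRL_MAX = -1.0, 1.0

rng = np.random.default_rng(0)

# A random continuous action: 17 real numbers, each anywhere in [-1, 1].
random_action = rng.uniform(CTRL_MIN, CTRL_MAX, size=ACTION_SIZE)

# A hypothetical raw policy output may fall outside the control range;
# clipping keeps every element valid without discretizing it.
raw_policy_output = rng.normal(0.0, 1.5, size=ACTION_SIZE)
action = np.clip(raw_policy_output, CTRL_MIN, CTRL_MAX)

# Index 6 drives the right knee and index 10 the left knee (see the table).
print(f"right_knee: {action[6]:.3f}, left_knee: {action[10]:.3f}")
```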

### Observation Space
The observation consists of the positions of the Humanoid's body parts, followed by their velocities (their time derivatives), with all positions ordered before all velocities. In Gymnasium's version of this environment, the observation is an `ndarray` with shape `(376,)`, whose first elements correspond to the following:

  | Num | Observation                                                                                                     | Min  | Max | Name (in corresponding config)   | Joint | Unit                     |
  |-----|-----------------------------------------------------------------------------------------------------------------|------|-----|----------------------------------|-------|--------------------------|
  | 0   | z-coordinate of the torso (centre)                                                                              | -Inf | Inf | root                             | free  | position (m)             |
  | 1   | w-orientation of the torso (centre)                                                                             | -Inf | Inf | root                             | free  | angle (rad)              |
  | 2   | x-orientation of the torso (centre)                                                                             | -Inf | Inf | root                             | free  | angle (rad)              |
  | 3   | y-orientation of the torso (centre)                                                                             | -Inf | Inf | root                             | free  | angle (rad)              |
  | 4   | z-orientation of the torso (centre)                                                                             | -Inf | Inf | root                             | free  | angle (rad)              |
  | 5   | z-angle of the abdomen (in lower_waist)                                                                         | -Inf | Inf | abdomen_yz                       | hinge | angle (rad)              |
  | 6   | y-angle of the abdomen (in lower_waist)                                                                         | -Inf | Inf | abdomen_yz                       | hinge | angle (rad)              |
  | 7   | x-angle of the abdomen (in pelvis)                                                                              | -Inf | Inf | abdomen_x                        | hinge | angle (rad)              |
  | 8   | x-coordinate of angle between pelvis and right hip (in right_thigh)                                             | -Inf | Inf | right_hip_xyz                    | hinge | angle (rad)              |
  | 9   | y-coordinate of angle between pelvis and right hip (in right_thigh)                                             | -Inf | Inf | right_hip_xyz                    | hinge | angle (rad)              |
  | 10  | z-coordinate of angle between pelvis and right hip (in right_thigh)                                             | -Inf | Inf | right_hip_xyz                    | hinge | angle (rad)              |
  | 11  | angle between right hip and the right shin (in right_knee)                                                      | -Inf | Inf | right_knee                       | hinge | angle (rad)              |
  | 12  | x-coordinate of angle between pelvis and left hip (in left_thigh)                                               | -Inf | Inf | left_hip_xyz                     | hinge | angle (rad)              |
  | 13  | y-coordinate of angle between pelvis and left hip (in left_thigh)                                               | -Inf | Inf | left_hip_xyz                     | hinge | angle (rad)              |
  | 14  | z-coordinate of angle between pelvis and left hip (in left_thigh)                                               | -Inf | Inf | left_hip_xyz                     | hinge | angle (rad)              |
  | 15  | angle between left hip and the left shin (in left_knee)                                                         | -Inf | Inf | left_knee                        | hinge | angle (rad)              |
  | 16  | coordinate-1 (multi-axis) angle between torso and right arm (in right_upper_arm)                                | -Inf | Inf | right_shoulder12                 | hinge | angle (rad)              |
  | 17  | coordinate-2 (multi-axis) angle between torso and right arm (in right_upper_arm)                                | -Inf | Inf | right_shoulder12                 | hinge | angle (rad)              |
  | 18  | angle between right upper arm and right_lower_arm                                                               | -Inf | Inf | right_elbow                      | hinge | angle (rad)              |
  | 19  | coordinate-1 (multi-axis) angle between torso and left arm (in left_upper_arm)                                  | -Inf | Inf | left_shoulder12                  | hinge | angle (rad)              |
  | 20  | coordinate-2 (multi-axis) angle between torso and left arm (in left_upper_arm)                                  | -Inf | Inf | left_shoulder12                  | hinge | angle (rad)              |
  | 21  | angle between left upper arm and left_lower_arm                                                                 | -Inf | Inf | left_elbow                       | hinge | angle (rad)              |
  | 22  | x-coordinate velocity of the torso (centre)                                                                     | -Inf | Inf | root                             | free  | velocity (m/s)           |
  | 23  | y-coordinate velocity of the torso (centre)                                                                     | -Inf | Inf | root                             | free  | velocity (m/s)           |
  | 24  | z-coordinate velocity of the torso (centre)                                                                     | -Inf | Inf | root                             | free  | velocity (m/s)           |
  | 25  | x-coordinate angular velocity of the torso (centre)                                                             | -Inf | Inf | root                             | free  | angular velocity (rad/s) |
  | 26  | y-coordinate angular velocity of the torso (centre)                                                             | -Inf | Inf | root                             | free  | angular velocity (rad/s) |
  | 27  | z-coordinate angular velocity of the torso (centre)                                                             | -Inf | Inf | root                             | free  | angular velocity (rad/s) |
  | 28  | z-coordinate of angular velocity of the abdomen (in lower_waist)                                                | -Inf | Inf | abdomen_z                        | hinge | angular velocity (rad/s) |
  | 29  | y-coordinate of angular velocity of the abdomen (in lower_waist)                                                | -Inf | Inf | abdomen_y                        | hinge | angular velocity (rad/s) |
  | 30  | x-coordinate of angular velocity of the abdomen (in pelvis)                                                     | -Inf | Inf | abdomen_x                        | hinge | angular velocity (rad/s) |
  | 31  | x-coordinate of the angular velocity of the angle between pelvis and right hip (in right_thigh)                 | -Inf | Inf | right_hip_xyz                    | hinge | angular velocity (rad/s) |
  | 32  | y-coordinate of the angular velocity of the angle between pelvis and right hip (in right_thigh)                 | -Inf | Inf | right_hip_z                      | hinge | angular velocity (rad/s) |
  | 33  | z-coordinate of the angular velocity of the angle between pelvis and right hip (in right_thigh)                 | -Inf | Inf | right_hip_y                      | hinge | angular velocity (rad/s) |
  | 34  | angular velocity of the angle between right hip and the right shin (in right_knee)                              | -Inf | Inf | right_knee                       | hinge | angular velocity (rad/s) |
  | 35  | x-coordinate of the angular velocity of the angle between pelvis and left hip (in left_thigh)                   | -Inf | Inf | left_hip_xyz                     | hinge | angular velocity (rad/s) |
  | 36  | y-coordinate of the angular velocity of the angle between pelvis and left hip (in left_thigh)                   | -Inf | Inf | left_hip_z                       | hinge | angular velocity (rad/s) |
  | 37  | z-coordinate of the angular velocity of the angle between pelvis and left hip (in left_thigh)                   | -Inf | Inf | left_hip_y                       | hinge | angular velocity (rad/s) |
  | 38  | angular velocity of the angle between left hip and the left shin (in left_knee)                                 | -Inf | Inf | left_knee                        | hinge | angular velocity (rad/s) |
  | 39  | coordinate-1 (multi-axis) of the angular velocity of the angle between torso and right arm (in right_upper_arm) | -Inf | Inf | right_shoulder12                 | hinge | angular velocity (rad/s) |
  | 40  | coordinate-2 (multi-axis) of the angular velocity of the angle between torso and right arm (in right_upper_arm) | -Inf | Inf | right_shoulder12                 | hinge | angular velocity (rad/s) |
  | 41  | angular velocity of the angle between right upper arm and right_lower_arm                                       | -Inf | Inf | right_elbow                      | hinge | angular velocity (rad/s) |
  | 42  | coordinate-1 (multi-axis) of the angular velocity of the angle between torso and left arm (in left_upper_arm)   | -Inf | Inf | left_shoulder12                  | hinge | angular velocity (rad/s) |
  | 43  | coordinate-2 (multi-axis) of the angular velocity of the angle between torso and left arm (in left_upper_arm)   | -Inf | Inf | left_shoulder12                  | hinge | angular velocity (rad/s) |
  | 44  | angular velocity of the angle between left upper arm and left_lower_arm                                         | -Inf | Inf | left_elbow                       | hinge | angular velocity (rad/s) |

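  Additionally, after all the position and velocity values in the table,
  the observation contains (in order):

  1. *cinert:* Mass and inertia of each rigid body relative to the center of mass (an intermediate quantity of MuJoCo's forward computation). It has shape 14 * 10 (*nbody * 10*) and hence adds another 140 elements to the observation.

  2. *cvel:* Center-of-mass-based velocity of each body. It has shape 14 * 6 (*nbody * 6*) and hence adds another 84 elements to the observation.

  3. *qfrc_actuator:* Actuator forces expressed in generalized (joint) coordinates. This has shape `(23,)` *(nv * 1)* and hence adds another 23 elements to the observation.

  4. *cfrc_ext:* External (contact) forces on each body. It has shape 14 * 6 (*nbody * 6*) and adds the final 84 elements that bring the total to 376.

  The `_get_obs` methods in this notebook exclude the external contact forces, and the MJX version also drops the first (world) body from `cinert` and `cvel`, so their observations are shorter than 376 elements.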

  The (x,y,z) coordinates are translational DOFs while the orientations are
  rotational DOFs expressed as quaternions.
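The 376 total can be checked with a little arithmetic. The sketch below assumes the Gymnasium Humanoid sizes (14 bodies including the world body, `nq = 24`, `nv = 23`) and includes the external contact forces `cfrc_ext` (`nbody * 6`), which Gymnasium's observation appends after `qfrc_actuator`.

```python
# Observation-size bookkeeping for the Gymnasium-style Humanoid layout.
# The model sizes below are assumptions for illustration.
NBODY = 14   # bodies, including the world body
NQ = 24      # generalized positions: 7 for the free root joint + 17 hinges
NV = 23      # generalized velocities: 6 for the free root joint + 17 hinges

sizes = {
    "qpos (without x, y)": NQ - 2,
    "qvel": NV,
    "cinert": NBODY * 10,
    "cvel": NBODY * 6,
    "qfrc_actuator": NV,
    "cfrc_ext": NBODY * 6,
}
total = sum(sizes.values())
print(total)  # 376
```

Dropping `cfrc_ext` alone would leave 292 elements, which is why the Gymnasium shape cannot be reproduced from the position, velocity, `cinert`, `cvel`, and `qfrc_actuator` terms by themselves.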

### Rewards
The reward consists of three parts:
- *reward_alive*: Every timestep that the humanoid is alive, it gets a reward of 5.

- *forward_reward*: A reward for walking forward, measured as *1.25 * (x-coordinate of the center of mass after the action - x-coordinate of the center of mass before the action) / dt*. *dt* is the time between actions; the default is *dt = 0.015*. This reward is positive when the humanoid walks forward (in the +x direction), as desired. The center-of-mass calculation is defined in the Humanoid environment code; the environment classes below additionally multiply this term by 2 (MJX) and by 0.5 (pipeline).

- *reward_quadctrl*: A negative reward that penalizes the humanoid for using too large a control force. If there are *nu* actuators/controls, the control has shape `nu x 1`, and the penalty is *0.1 * sum(control²)*.
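The three reward terms can be checked by hand. The NumPy sketch below implements the formulas as described in this list (`humanoid_reward` is a hypothetical helper); it leaves out the extra scale factors and the health gating used in the environment classes. In the worked example the center of mass moves 0.015 m forward in one `dt = 0.015` step, i.e. 1 m/s, and all 17 controls are at 0.5, giving 1.25 + 5 - 0.1 * 17 * 0.25 = 5.825.

```python
import numpy as np


def humanoid_reward(com_before, com_after, action, dt=0.015,
                    forward_reward_weight=1.25, healthy_reward=5.0,
                    ctrl_cost_weight=0.1):
  """Reward terms from the description above (NumPy sketch)."""
  velocity = (com_after - com_before) / dt               # center-of-mass velocity
  forward_reward = forward_reward_weight * velocity[0]   # x-component only
  ctrl_cost = ctrl_cost_weight * np.sum(np.square(action))
  return forward_reward + healthy_reward - ctrl_cost


com_before = np.array([0.0, 0.0, 1.3])
com_after = np.array([0.015, 0.0, 1.3])   # 0.015 m forward in one step
action = np.full(17, 0.5)                 # every actuator at half strength
r = humanoid_reward(com_before, com_after, action)
print(r)
```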

### Starting State
The humanoid starts in the configuration (0.0, 0.0, 1.4, 1.0, 0.0, ..., 0.0): x = y = 0, height z = 1.4, the identity orientation quaternion (1, 0, 0, 0), and all joint angles at zero. Uniform noise in the range [-0.01, 0.01] is added to the position and velocity values (the values in the table) for stochasticity. The initial z-coordinate is intentionally high, corresponding to a standing humanoid, and the initial orientation makes it face forward.
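The starting state can be sketched with NumPy in the same way the `reset` methods below build it: the nominal pose plus uniform noise on positions, and pure noise on velocities. The sizes `NQ = 24` and `NV = 23` are assumptions (a free root joint plus 17 hinges); the real environments read the nominal pose from the model (`self.sys.qpos0` or `self.sys.init_q`) and draw the noise with `jax.random.uniform` instead of NumPy.

```python
import numpy as np

NQ, NV = 24, 23          # assumed Humanoid sizes
RESET_NOISE_SCALE = 1e-2

# Nominal standing pose: x, y, z = 0, 0, 1.4; identity quaternion (1, 0, 0, 0);
# all hinge angles at zero.
qpos0 = np.zeros(NQ)
qpos0[2] = 1.4
qpos0[3] = 1.0

rng = np.random.default_rng(0)
low, hi = -RESET_NOISE_SCALE, RESET_NOISE_SCALE
qpos = qpos0 + rng.uniform(low, hi, size=NQ)   # noisy positions
qvel = rng.uniform(low, hi, size=NV)           # noisy velocities around zero
```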

### Episode Termination
The episode terminates when any of the following happens:

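1. The episode duration reaches 1000 timesteps.
2. The z-coordinate of the torso (index 0 of the observation, i.e. index 2 of `qpos`) is **not** in the range `[0.8, 2.1]` (the humanoid has fallen or is about to fall beyond recovery).

In the environment classes below, this range is the `healthy_z_range` argument, with defaults `(1.0, 1.5)` for the MJX class and `(1.0, 2.0)` for the pipeline class, and falling only ends the episode when `terminate_when_unhealthy=True`.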
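The termination rule can be written as a small function. The sketch follows the `jp.where` pattern used in the `step` methods below (NumPy's `np.where` here), which avoids a Python `if` on array values; that matters under `jax.jit`, where such values are traced. The helpers `is_healthy` and `is_done` are hypothetical, and the time-limit check is a plain Python comparison only because `t` is an ordinary integer in this sketch. In Brax training, the time limit is usually applied outside the environment, for example through the `episode_length` argument of `ppo.train`.

```python
import numpy as np


def is_healthy(qpos, healthy_z_range=(0.8, 2.1)):
  """Return 1.0 if the torso height qpos[2] lies inside healthy_z_range, else 0.0."""
  min_z, max_z = healthy_z_range
  z = qpos[2]
  healthy = np.where(z < min_z, 0.0, 1.0)
  return np.where(z > max_z, 0.0, healthy)


def is_done(qpos, t, max_steps=1000, healthy_z_range=(0.8, 2.1)):
  """Return 1.0 if the humanoid left the healthy range or hit the time limit."""
  fell = 1.0 - is_healthy(qpos, healthy_z_range)
  timeout = 1.0 if t >= max_steps else 0.0
  return max(float(fell), timeout)


# Only the first three qpos entries (x, y, z) matter for this check.
standing = np.array([0.0, 0.0, 1.4])
fallen = np.array([0.0, 0.0, 0.5])
print(is_done(standing, t=10), is_done(fallen, t=10), is_done(standing, t=1000))
```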

# Intro to Backend Env Settings
Brax offers three distinct physics pipelines that are easy to swap; all of them are selected through the `backend` argument of the `PipelineEnv` parent class. All three are implemented in JAX directly within Brax, in contrast to the `mjx` backend described below, which builds on MuJoCo (Multi-Joint dynamics with Contact):

* ***Generalized*** calculates motion in generalized coordinates using the same accurate robot dynamics algorithms as MuJoCo.
* ***Positional*** uses Position Based Dynamics, a fast but stable method of resolving joint and collision constraints.
* ***Spring*** provides fast and cheap simulation for rapid experimentation, using simple impulse-based methods often found in video games.

Physics pipelines typically become less accurate and less stable as the step size grows. Try adjusting the step size to get a feel for each pipeline's accuracy and stability tradeoffs.
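For an intuition of this tradeoff without Brax, the toy sketch below integrates a unit mass on a spring (`k = 100`, so the angular frequency is 10 rad/s, starting with 50 J of energy) using two simple integrators: explicit Euler and semi-implicit (symplectic) Euler. It only illustrates step-size effects and is not how any of the Brax pipelines is implemented. Explicit Euler gains energy every step, and the gain explodes as `dt` grows; semi-implicit Euler keeps the energy bounded near 50 J as long as `dt` stays small relative to the oscillation period.

```python
import numpy as np


def simulate_spring(dt, t_end=10.0, k=100.0, m=1.0, semi_implicit=False):
  """Integrate x'' = -(k/m) x from x=1, v=0 and return the final energy."""
  x, v = 1.0, 0.0
  for _ in range(int(round(t_end / dt))):
    a = -(k / m) * x
    if semi_implicit:
      v = v + a * dt              # update velocity first...
      x = x + v * dt              # ...then move with the new velocity
    else:
      x, v = x + v * dt, v + a * dt   # explicit Euler: both use old values
  return 0.5 * m * v**2 + 0.5 * k * x**2


initial_energy = 0.5 * 100.0 * 1.0**2   # 50 J
for dt in (0.001, 0.01, 0.05):
  explicit = simulate_spring(dt)
  semi = simulate_spring(dt, semi_implicit=True)
  print(f"dt={dt}: explicit={explicit:.3g} J, semi-implicit={semi:.3g} J")
```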

All of the above backends are implemented by Brax in JAX. Brax also provides an alternative backend:

* ***mjx*** is MuJoCo XLA (MJX), a JAX reimplementation of MuJoCo developed by the MuJoCo team. In this notebook it is used through a different parent class, `MjxEnv`, instead of `PipelineEnv` (although the `PipelineEnv` class below also accepts `backend='mjx'`). ***Brax's `mjx` backend is a thin link to MuJoCo rather than a separate implementation: it uses MuJoCo, through MJX, as the underlying simulation engine.***

The environment is organized in two layers:
1. Lowest level (the physics pipeline): the bodies, the connections between them, and their physical constraints, with no notion of actions or rewards.
2. Gym-like environment: bookkeeping on top of the physics state, such as actions, observations, rewards, termination, and metrics.

***Observations:***
1. The `PipelineEnv` class seems to store more pipeline state: the HTML render reaches Python's recursion limit when using the `MjxEnv` class but not the `PipelineEnv` class.
2. The Brax-native backends did not perform as well, but they are computationally much more efficient.
3. The generalized backend can also take MuJoCo XML files: the `PipelineEnv` class below loads `humanoid.xml` with `mjcf.load`. Whether it supports every MuJoCo feature that the mjx backend does has not been checked.

# Training a Policy with mjx/Pipeline Envs

### Humanoid Env w/ mjx Backend

In [None]:
class Humanoid(MjxEnv):
  '''
  Humanoid environment on the MJX backend. The reward is highly customizable
  (reward engineering).
  '''
  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=5.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(1.0, 1.5),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,):
    '''
    Initializes the humanoid model, solver settings, and reward weights.
    '''

    path = epath.Path(epath.resource_path('mujoco')) / ('mjx/benchmark/model/humanoid')
    mj_model = mujoco.MjModel.from_xml_path((path / 'humanoid.xml').as_posix())

    # Constraint solver: conjugate gradient (CG)
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG

    # Iterations for the solver and for its line search
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6

    # Default to 5 physics steps per control step; override with kwargs['n_frames']
    physics_steps_per_control_step = 5
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step)

    # Initialize the MjxEnv parent class
    super().__init__(model=mj_model, **kwargs)

    # Store the configuration as instance attributes for later use
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (exclude_current_positions_from_observation)

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""

    # Split the random number generator key: rng1 perturbs the initial
    # positions and rng2 the initial velocities
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale

    # Generalized joint positions in configuration space, plus noise
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )

    # Generalized joint velocities, drawn as pure noise around zero
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    data = self.pipeline_init(qpos, qvel)

    # Initial observation, reward, done flag, and metrics
    obs = self._get_obs(data.data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    # Previous pipeline state
    data0 = state.pipeline_state

    # Advance the physics by one control step (n_frames physics steps)
    data = self.pipeline_step(data0, action)

    # Forward reward from the center-of-mass velocity
    com_before = data0.data.subtree_com[1]
    com_after = data.data.subtree_com[1]
    velocity = (com_after - com_before) / self.dt
    forward_reward = self._forward_reward_weight * velocity[0] * 2

    # Healthy if the torso height q[2] is inside healthy_z_range
    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)

    # Alive bonus: constant if unhealthy states end the episode, else gated by health
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    # Quadratic control cost
    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    # Observation and total reward
    obs = self._get_obs(data.data, action)
    reward = forward_reward + healthy_reward - ctrl_cost

    # Episode ends when unhealthy (only if termination is enabled)
    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0

    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(self, data: mjx.Data, action: jp.ndarray) -> jp.ndarray:
    """Observes humanoid body position, velocities, and angles."""
    position = data.qpos
    if self._exclude_current_positions_from_observation:
      position = position[2:]

    # External contact forces are excluded; the observation components are
    # described in the note below
    return jp.concatenate([
        position,
        data.qvel,
        data.cinert[1:].ravel(),
        data.cvel[1:].ravel(),
        data.qfrc_actuator,
    ])

# Register the environment with Brax under the name 'humanoid_mjx'
envs.register_environment('humanoid_mjx', Humanoid)

#### Note: Obs_tracks
Observation Components:

1. ***Position***: Extracts the positions (qpos) of the humanoid body. If `self._exclude_current_positions_from_observation` is True, the first two elements of the position vector, the global x- and y-coordinates of the torso, are excluded. This makes the observation independent of where the humanoid is on the ground plane.

2. ***Velocities***: Appends the velocities (qvel) of the humanoid body.

3. ***Inertia***: Appends the center-of-mass-based inertia of each body (cinert), excluding the first row (the world body). Inertia describes how mass is distributed across the humanoid's body segments, which determines how the body accelerates or decelerates in response to the forces applied to it.

4. ***Center-of-Mass Velocities***: Appends the center-of-mass-based velocity of each body (cvel, with angular and linear components), excluding the first row (the world body).

5. ***Actuator Forces***: Appends the actuator forces (qfrc_actuator), expressed in generalized (joint) coordinates. Actuators are typically modeled as components that generate forces or torques to drive the joints of a simulated robot; these forces or torques are applied to the joints of the simulated body and affect its motion.
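The concatenation in `_get_obs` can be reproduced with dummy arrays to see how long the observation is. This sketch assumes the Gymnasium-style sizes from the tables above (`nq = 24`, `nv = 23`, 14 bodies including the world body); the MJX benchmark model loaded in the class may have different sizes, so checking `obs.shape` on a real environment is the reliable test. `mock_get_obs` is a hypothetical helper, not part of Brax or MJX.

```python
import numpy as np

# Assumed sizes for illustration only.
NQ, NV, NBODY = 24, 23, 14


def mock_get_obs(qpos, qvel, cinert, cvel, qfrc_actuator,
                 exclude_current_positions=True):
  """NumPy sketch of the concatenation performed by `_get_obs` above."""
  position = qpos[2:] if exclude_current_positions else qpos
  return np.concatenate([
      position,             # 1. positions, without global x and y
      qvel,                 # 2. velocities
      cinert[1:].ravel(),   # 3. com-based inertia, world body dropped
      cvel[1:].ravel(),     # 4. com-based velocities, world body dropped
      qfrc_actuator,        # 5. actuator forces in joint space
  ])


obs = mock_get_obs(np.zeros(NQ), np.zeros(NV), np.zeros((NBODY, 10)),
                   np.zeros((NBODY, 6)), np.zeros(NV))
print(obs.shape)  # (276,)
```

Under these assumed sizes the result is 22 + 23 + 130 + 78 + 23 = 276 elements, noticeably fewer than Gymnasium's 376.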

### Humanoid Env w/ Pipeline Setup

With the generalized or positional backend, the `PipelineEnv` class stores the full physics state in a [`State`](https://github.com/google/brax/blob/main/brax/generalized/base.py) object (the link points to the generalized backend's definition).

In [None]:
class Humanoid(PipelineEnv):
  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=5.0,
      terminate_when_unhealthy=False,
      healthy_z_range=(1.0, 2.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      backend='generalized',
      **kwargs,
  ):
    path = epath.resource_path('brax') / 'envs/assets/humanoid.xml'
    sys = mjcf.load(path)

    n_frames = 5

    if backend in ['spring', 'positional']:
      # Use a smaller physics time step for these backends
      sys = sys.replace(dt=0.0015)
      # More physics steps per control step
      n_frames = 10

      # Gear ratio per actuator: scales each control input into joint torque
      gear = jp.array([
          350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0, 350.0,
          350.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0])  # pyformat: disable
      sys = sys.replace(actuator=sys.actuator.replace(gear=gear))

    if backend == 'mjx':
      sys._model.opt.solver = mujoco.mjtSolver.mjSOL_NEWTON
      sys._model.opt.disableflags = mujoco.mjtDisableBit.mjDSBL_EULERDAMP
      sys._model.opt.iterations = 1
      sys._model.opt.ls_iterations = 4

    kwargs['n_frames'] = kwargs.get('n_frames', n_frames)

    super().__init__(sys=sys, backend=backend, **kwargs)

    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

  def reset(self, rng: jax.Array) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    qpos = self.sys.init_q + jax.random.uniform(
        rng1, (self.sys.q_size(),), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.qd_size(),), minval=low, maxval=hi
    )

    pipeline_state = self.pipeline_init(qpos, qvel)

    obs = self._get_obs(pipeline_state, jp.zeros(self.sys.act_size()))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(pipeline_state, obs, reward, done, metrics)

  def step(self, state: State, action: jax.Array) -> State:
    """Runs one timestep of the environment's dynamics."""

    pipeline_state0 = state.pipeline_state
    pipeline_state = self.pipeline_step(pipeline_state0, action)

    com_before, *_ = self._com(pipeline_state0)
    com_after, *_ = self._com(pipeline_state)
    velocity = (com_after - com_before) / self.dt
    forward_reward = self._forward_reward_weight * 0.5 * velocity[0]

    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, 0.0, 1.0)
    is_healthy = jp.where(pipeline_state.x.pos[0, 2] > max_z, 0.0, is_healthy)

    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    obs = self._get_obs(pipeline_state, action)

    reward = forward_reward + healthy_reward - ctrl_cost

    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0

    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )

    return state.replace(
        pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, pipeline_state: base.State, action: jax.Array
  ) -> jax.Array:
    """Observes humanoid body position, velocities, and angles."""
    position = pipeline_state.q
    velocity = pipeline_state.qd

    if self._exclude_current_positions_from_observation:
      position = position[2:]

    com, inertia, mass_sum, x_i = self._com(pipeline_state)
    cinr = x_i.replace(pos=x_i.pos - com).vmap().do(inertia)
    com_inertia = jp.hstack(
        [cinr.i.reshape((cinr.i.shape[0], -1)), inertia.mass[:, None]]
    )

    xd_i = (
        base.Transform.create(pos=x_i.pos - pipeline_state.x.pos)
        .vmap()
        .do(pipeline_state.xd)
    )
    com_vel = inertia.mass[:, None] * xd_i.vel / mass_sum
    com_ang = xd_i.ang
    com_velocity = jp.hstack([com_vel, com_ang])

    qfrc_actuator = actuator.to_tau(
        self.sys, action, pipeline_state.q, pipeline_state.qd)

    # external_contact_forces are excluded
    return jp.concatenate([
        position,
        velocity,
        com_inertia.ravel(),
        com_velocity.ravel(),
        qfrc_actuator,
    ])

  # Changing the center-of-mass computation to a more customized standard
  def _com(self, pipeline_state: base.State) -> jax.Array:
    inertia = self.sys.link.inertia
    if self.backend in ['spring', 'positional']:
      inertia = inertia.replace(
          i=jax.vmap(jp.diag)(
              jax.vmap(jp.diagonal)(inertia.i)
              ** (1 - self.sys.spring_inertia_scale)
          ),
          mass=inertia.mass ** (1 - self.sys.spring_mass_scale),
      )
    mass_sum = jp.sum(inertia.mass)

    x_i = pipeline_state.x.vmap().do(inertia.transform)
    com = (
        jp.sum(jax.vmap(jp.multiply)(inertia.mass, x_i.pos), axis=0) / mass_sum
    )

    return com, inertia, mass_sum, x_i  # pytype: disable=bad-return-type  # jax-ndarray

envs.register_environment('humanoid_generalized', Humanoid)
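The forward reward in `step` comes from a finite-difference estimate of the center-of-mass velocity. The sketch below reproduces that arithmetic with NumPy in place of `jax.numpy`. The link masses, positions, `dt`, reward weight, and helper names are all made up for illustration. It keeps the 0.5 factor that this notebook's `step` applies to the forward reward.

```python
import numpy as np


def center_of_mass(masses: np.ndarray, positions: np.ndarray) -> np.ndarray:
  """Mass-weighted mean of link positions, as computed in `_com` (shape (3,))."""
  return np.sum(masses[:, None] * positions, axis=0) / np.sum(masses)


def com_forward_reward(masses, pos_before, pos_after, dt, forward_reward_weight):
  """Hypothetical helper: forward reward from the COM velocity along x."""
  com_before = center_of_mass(masses, pos_before)
  com_after = center_of_mass(masses, pos_after)
  velocity = (com_after - com_before) / dt
  # Same formula as in `step`, including the 0.5 factor.
  return forward_reward_weight * 0.5 * velocity[0], velocity


# Two made-up links; every link moves 0.03 m along +x during one control step.
masses = np.array([2.0, 1.0])
pos_before = np.array([[0.0, 0.0, 1.3], [0.1, 0.0, 1.0]])
pos_after = pos_before + np.array([0.03, 0.0, 0.0])
reward, velocity = com_forward_reward(
    masses, pos_before, pos_after, dt=0.015, forward_reward_weight=1.25)  # illustrative values
print(reward, velocity)
```

Because the reward uses the velocity of the whole body's center of mass rather than a single link, swinging one limb forward earns little unless the body as a whole moves.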

## For a More Robust Policy
## Domain Randomization
Domain randomization is a reinforcement-learning technique for improving how well a learned policy generalizes across environments. The agent is trained on many simulated environments whose parameters (physical properties, dynamics, or visual appearance) are randomized. Exposure to this diversity during training is expected to produce a more robust policy that adapts well to novel environments. In Brax, such a randomization function is handed to the trainer through the `randomization_fn` argument of `ppo.train`.

***Note that domain randomization only works with MJX models.***
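To make the batching concrete, here is a NumPy stand-in for the vmapped `rand` function in the next cell. It is a sketch, not the MJX code: it uses `numpy.random.default_rng` instead of JAX keys, and the helper name `randomize_batch` and the array sizes in the usage lines are hypothetical. Like the notebook's version, it draws one friction value per environment (shared by all geoms) and one gain offset per environment (shared by all actuators), and it keeps the position-actuator relation `biasprm[:, 1] = -gainprm[:, 0]`.

```python
import numpy as np


def randomize_batch(geom_friction, gainprm, biasprm, num_envs, seed=0):
  """NumPy stand-in for the vmapped `rand` in `domain_randomize` (illustrative only)."""
  rng = np.random.default_rng(seed)
  # One sliding-friction coefficient per environment, applied to every geom.
  friction = rng.uniform(0.6, 1.4, size=(num_envs, 1))
  frictions = np.repeat(geom_friction[None], num_envs, axis=0)
  frictions[:, :, 0] = friction
  # One gain offset per environment, added to every actuator's nominal gain.
  offset = rng.uniform(-5.0, 5.0, size=(num_envs, 1))
  gains = np.repeat(gainprm[None], num_envs, axis=0)
  gains[:, :, 0] = gainprm[None, :, 0] + offset
  # Position actuators need biasprm[1] = -gainprm[0] to stay consistent.
  biases = np.repeat(biasprm[None], num_envs, axis=0)
  biases[:, :, 1] = -gains[:, :, 0]
  return frictions, gains, biases


geom_friction = np.tile([1.0, 0.005, 0.0001], (4, 1))  # 4 hypothetical geoms
gainprm = np.zeros((3, 10))                            # 3 hypothetical actuators
gainprm[:, 0] = 35.0
biasprm = np.zeros((3, 10))
f, g, b = randomize_batch(geom_friction, gainprm, biasprm, num_envs=10)
print(f.shape, g.shape)  # (10, 4, 3) (10, 3, 10)
```

Every randomized array gains a leading axis of size `num_envs`, which is exactly the axis that the `in_axes` tree marks with `0`.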

## Network Factory
A network factory lets you change the layer sizes of the neural networks used by Brax's built-in PPO implementation, for example to suit different training goals. Larger networks may improve performance, but they come at a significant cost in training time.
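One way to see the time cost is to count parameters. The sketch below counts the weights and biases of a fully connected network. Three things are assumptions for illustration: the observation and action sizes (244 and 17), the policy head emitting two values per action dimension (a mean and a scale), and the default `policy_hidden_layer_sizes` of `(32, 32, 32, 32)` in `make_ppo_networks`, which is our recollection and should be checked against the Brax source.

```python
def mlp_param_count(input_size, hidden_sizes, output_size):
  """Number of weights plus biases in a fully connected network."""
  sizes = [input_size, *hidden_sizes, output_size]
  return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))


obs_size, action_size = 244, 17  # hypothetical sizes for illustration
# Assumption: the policy outputs a mean and a scale for each action dimension.
default_policy = mlp_param_count(obs_size, (32, 32, 32, 32), 2 * action_size)
custom_policy = mlp_param_count(obs_size, (64, 64, 64, 64), 2 * action_size)
print(default_policy, custom_policy)  # 12130 30370
```

Doubling the width roughly quadruples the cost of each hidden-to-hidden layer, and that extra compute is paid on every forward and backward pass across all parallel environments.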

In [None]:
# @title Domain Randomization & Network Factory
from brax.training.agents.ppo import networks as ppo_networks
# Policy network with four hidden layers of 64 units each.
make_networks_factory = functools.partial(
    ppo_networks.make_ppo_networks, policy_hidden_layer_sizes=(64, 64, 64, 64))


def domain_randomize(sys, rng):
  """Randomizes friction and actuator gains of the mjx.Model."""
  # TODO: add further domain randomization terms.

  @jax.vmap
  def rand(rng):
    _, key = jax.random.split(rng, 2)

    # friction
    friction = jax.random.uniform(key, (1,), minval=0.6, maxval=1.4)
    friction = sys.geom_friction.at[:, 0].set(friction)

    # actuator
    _, key = jax.random.split(key, 2)
    gain_range = (-5, 5)
    param = jax.random.uniform(
        key, (1,), minval=gain_range[0], maxval=gain_range[1]
    ) + sys.actuator_gainprm[:, 0]
    gain = sys.actuator_gainprm.at[:, 0].set(param)
    bias = sys.actuator_biasprm.at[:, 1].set(-param)
    return friction, gain, bias

  friction, gain, bias = rand(rng)

  in_axes = jax.tree_util.tree_map(lambda x: None, sys)
  in_axes = in_axes.tree_replace({
      'geom_friction': 0,
      'actuator_gainprm': 0,
      'actuator_biasprm': 0,
  })

  sys = sys.tree_replace({
      'geom_friction': friction,
      'actuator_gainprm': gain,
      'actuator_biasprm': bias,
  })

  return sys, in_axes

In [None]:
# # If we wanted 10 environments with randomized friction and actuator params,
# # we can call `domain_randomize`, which returns a batched `mjx.Model` along with
# # a dictionary specifying the axes that are batched.

# rng = jax.random.PRNGKey(0)
# rng = jax.random.split(rng, 10)
# batched_sys, _ = domain_randomize(env.sys, rng)

# print('Single env friction shape: ', env.sys.geom_friction.shape)
# print('Batched env friction shape: ', batched_sys.geom_friction.shape)

# print('Friction on geom 0: ', env.sys.geom_friction[0, 0])
# print('Random frictions on geom 0: ', batched_sys.geom_friction[:, 0, 0])
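Besides the batched model, `domain_randomize` returns an `in_axes` tree that tells `jax.vmap` which fields carry a leading per-environment axis (`0`) and which are shared by every environment (`None`). The sketch below shows what that mapping means using plain dictionaries instead of an `mjx.Model`; the field `opt_timestep` and the helper `select_env` are hypothetical.

```python
import numpy as np


def select_env(batched_sys: dict, in_axes: dict, i: int) -> dict:
  """Return environment i's view: index axis 0 where in_axes is 0, share otherwise."""
  return {
      name: value[i] if in_axes[name] == 0 else value
      for name, value in batched_sys.items()
  }


# Hypothetical model with one randomized field and one shared field.
batched_sys = {
    'geom_friction': np.array([[[0.7, 0.005]], [[1.2, 0.005]]]),  # 2 envs, 1 geom
    'opt_timestep': np.array(0.003),                             # shared by all envs
}
in_axes = {'geom_friction': 0, 'opt_timestep': None}
env1 = select_env(batched_sys, in_axes, 1)
print(env1['geom_friction'])  # [[1.2   0.005]]
```

Marking unrandomized fields as `None` avoids copying the whole model once per environment; only the randomized fields are stored in batched form.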

# Instantiate the Environment & Visualize a Rollout

Let's instantiate the environment and visualize a short rollout. This happens before any training, so nothing has been learned yet: the humanoid does not know how to stand or run. The rollout simply applies a constant control of -0.1 to every actuator, so under the simulated physics the humanoid falls to the ground.

NOTE: Since episodes terminate early if the torso drops below the healthy z-range, the only relevant contacts for this task are between the feet and the plane. We turn off the other contacts.
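The early-termination rule referred to in the note is the health check in `step`: the first body's height must lie inside `healthy_z_range`, and `done` is set when it does not. Below is a plain-Python mirror of that logic (floats instead of `jp.where`). The range `(1.0, 2.0)` is an illustrative value, since the constructor defaults are outside this excerpt.

```python
def health_and_done(torso_z, healthy_z_range=(1.0, 2.0), terminate_when_unhealthy=True):
  """Mirror of the health check in `step`: healthy iff min_z <= z <= max_z."""
  min_z, max_z = healthy_z_range  # illustrative range, not read from the environment
  is_healthy = 0.0 if torso_z < min_z else 1.0
  is_healthy = 0.0 if torso_z > max_z else is_healthy
  # Episodes only end on an unhealthy state when termination is enabled.
  done = 1.0 - is_healthy if terminate_when_unhealthy else 0.0
  return is_healthy, done


print(health_and_done(1.3))  # (1.0, 0.0): standing
print(health_and_done(0.4))  # (0.0, 1.0): fallen, episode ends
```

Since any fall ends the episode, contacts between the torso or limbs and the floor never matter for learning, which is why only the feet-plane contacts are kept.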

There are two environment choices:
1. `humanoid_mjx`: the MJX backend
2. `humanoid_generalized`: a Brax pipeline backend (`positional`, `spring`, or `generalized`)

In [None]:
# Instantiate one of two environments
env_name = 'humanoid_mjx' # @param ['humanoid_generalized','humanoid_mjx']
backend = 'positional' # @param ['generalized', 'positional', 'spring']

if env_name == 'humanoid_mjx':
  # The MJX environment takes no backend argument, so `backend` is ignored here
  env = envs.get_environment(env_name=env_name)

  # define the jit reset/step functions
  jit_reset = jax.jit(env.reset)
  jit_step = jax.jit(env.step)

  # initialize the state
  state = jit_reset(jax.random.PRNGKey(0))

  # Create a container for the rollout states
  rollout = [state.pipeline_state]

  # grab a trajectory
  for i in tqdm(range(100)):
    ctrl = -0.1 * jp.ones(env.sys.nu)
    state = jit_step(state, ctrl)
    rollout.append(state.pipeline_state)

  media.show_video(env.render(rollout, camera='side'), fps=1.0 / env.dt)
  #HTML(html.render(env.sys, rollout))

else:
  env = envs.get_environment(env_name=env_name, backend=backend)

  # define the jit reset/step functions
  jit_reset = jax.jit(env.reset)
  jit_step = jax.jit(env.step)

  # initialize the state
  state = jit_reset(jax.random.PRNGKey(0))

  # Expressions inside an if/else block are not auto-displayed, so call display().
  display(HTML(html.render(env.sys, [state.pipeline_state])))

## Train Humanoid Policy

Let's now train a policy with PPO to make the Humanoid run forwards. Training takes about 9-10 minutes on a Tesla A100 GPU.

There is little to do here for the training loop itself, because Brax already packages the entire PPO training loop. More detail can be found in the implementation of [Proximal Policy Optimization](https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py).

***The `train_fn` we define here is given our environment. Brax's PPO implementation then repeatedly collects rollouts from it, updates the policy from those rollouts, and reports progress through `progress_fn`, which closes the training loop.***
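`ppo.train` hides the optimization step itself. As a reminder of what PPO optimizes, here is a NumPy sketch of the textbook clipped surrogate loss from the PPO paper. It is not Brax's code: Brax's loss also adds a value-function term and an entropy term (weighted by `entropy_cost`). The `clip_epsilon=0.3` default is our recollection of the `ppo.train` default and should be treated as an assumption.

```python
import numpy as np


def ppo_clip_loss(log_prob_new, log_prob_old, advantages, clip_epsilon=0.3):
  """Negative clipped surrogate objective, averaged over samples."""
  ratio = np.exp(log_prob_new - log_prob_old)  # pi_new(a|s) / pi_old(a|s)
  unclipped = ratio * advantages
  clipped = np.clip(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages
  # PPO maximizes the elementwise minimum, so the loss is its negative mean.
  return -np.mean(np.minimum(unclipped, clipped))


old = np.log(np.array([0.2, 0.4]))
new = np.log(np.array([0.3, 0.2]))  # probability ratios 1.5 and 0.5
adv = np.array([1.0, -1.0])
loss = ppo_clip_loss(new, old, adv)
print(loss)
```

With a positive advantage, the ratio of 1.5 is capped at 1.3, and with a negative advantage the ratio of 0.5 is floored at 0.7. The resulting loss is -(1.3 - 0.7) / 2 = -0.3, so the update cannot profit from moving the policy far from the one that collected the data.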

In [None]:
config = {
    "env_name": env_name,
    "algo_name": "ppo",
    "task_name": "run",
    "num_timesteps": 30_000_000,
    "num_evals": 5,
    "episode_length": 1000,
    "num_envs": 4096,
    "batch_size": 1024,
    "num_minibatches": 32,
    "num_updates_per_batch": 8,
    "unroll_length": 10,
    "network_factory": make_networks_factory,
    "randomization_fn": domain_randomize # logged in the config only; not passed to train_fn below
    }

# functools.partial() pre-binds arguments, so train_fn only needs the environment
# and a progress callback. This is the only training code we need to write.
train_fn = functools.partial(
    ppo.train, num_timesteps=config['num_timesteps'], num_evals=config['num_evals'], reward_scaling=0.1,
    episode_length=config['episode_length'], normalize_observations=True, action_repeat=1,
    unroll_length=config['unroll_length'], num_minibatches=config['num_minibatches'], num_updates_per_batch=config['num_updates_per_batch'],
    discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=config['num_envs'],
    batch_size=config['batch_size'], network_factory=config['network_factory'], seed=0)

# Saving everything to Wandb
run = wandb.init(project="vnl_backend_switch", config=config)
wandb.run.name = f"{config['env_name']}_{config['task_name']}_{config['algo_name']}_brax"
def wandb_progress(num_steps, metrics):
    metrics["num_steps"] = num_steps
    wandb.log(metrics)

# Train with wandb_progress as the progress callback; returns an inference-function factory and the trained params
make_inference_fn, params, _= train_fn(environment=env, progress_fn=wandb_progress)
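A quick back-of-the-envelope check of the config helps relate `num_timesteps` to the number of policy updates. The step accounting below is our reading of `brax/training/agents/ppo/train.py`, not a verified copy of it. We assume that one training step collects `batch_size * num_minibatches` trajectories of `unroll_length` steps, that this batch must divide evenly across `num_envs`, and that the first evaluation happens before training, ignoring other options such as `num_resets_per_eval`. The effective horizon of the discount factor, 1 / (1 - 0.97), is standard arithmetic.

```python
import math

num_timesteps = 30_000_000
num_evals = 5
num_envs = 4096
batch_size = 1024
num_minibatches = 32
unroll_length = 10
action_repeat = 1
discounting = 0.97

# Assumed accounting: environment steps consumed by one training step.
env_steps_per_training_step = batch_size * unroll_length * num_minibatches * action_repeat
# Assumed constraint: the collected trajectories split evenly across the parallel envs.
assert (batch_size * num_minibatches) % num_envs == 0

# Assumed: the first evaluation runs before training, so training is split over the rest.
num_evals_after_init = max(num_evals - 1, 1)
training_steps_per_eval = math.ceil(
    num_timesteps / (num_evals_after_init * env_steps_per_training_step))
effective_horizon = 1.0 / (1.0 - discounting)  # rewards beyond ~33 steps are heavily discounted
print(env_steps_per_training_step, training_steps_per_eval, round(effective_horizon, 1))  # 327680 23 33.3
```

Under these assumptions, each evaluation in wandb is separated by about 23 PPO training steps of roughly 330k environment steps each.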

<!-- ## Save and Load Policy -->

We can save and load the policy using Brax's `model` API.
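`model.save_params` and `model.load_params` serialize the parameter pytree returned by `train_fn`. To our understanding, they pickle it to the given path. The sketch below does the same round trip with the standard library on a small, made-up pytree. The three-part structure of `params` (normalizer, policy, and value parameters) is an assumption for illustration, not the exact layout PPO returns.

```python
import os
import pickle
import tempfile

import numpy as np

# Hypothetical params pytree: (normalizer state, policy params, value params).
params = (
    {'mean': np.zeros(3), 'std': np.ones(3)},
    {'hidden_0': {'kernel': np.ones((3, 4)), 'bias': np.zeros(4)}},
    {'hidden_0': {'kernel': np.ones((3, 1)), 'bias': np.zeros(1)}},
)

path = os.path.join(tempfile.mkdtemp(), 'policy_params')
with open(path, 'wb') as f:
  pickle.dump(params, f)  # write the whole pytree as one pickled object
with open(path, 'rb') as f:
  restored = pickle.load(f)  # nested tuples/dicts and arrays come back intact
```

Because pickle restores the exact nested structure, the loaded parameters can be passed straight back to `make_inference_fn`, but only with a network factory that matches the one used for training.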

In [None]:
#@title Save Model
model_path = '/tmp/mjx_brax_policy'
model.save_params(model_path, params)

In [None]:
#@title Load Model and Define Inference Function
params = model.load_params(model_path)

## Visualize One Policy Trajectory

In [None]:
env_name = 'humanoid_mjx' # @param ['humanoid_generalized','humanoid_mjx']
backend = 'positional' # @param ['generalized', 'positional', 'spring']

if env_name == 'humanoid_generalized':
  env = envs.create(env_name=env_name, backend=backend)

  jit_env_reset = jax.jit(env.reset)
  jit_env_step = jax.jit(env.step)
  inference_fn = make_inference_fn(params)
  jit_inference_fn = jax.jit(inference_fn)

  rollout = []
  rng = jax.random.PRNGKey(seed=1)
  state = jit_env_reset(rng=rng)
  for _ in range(1000):
    rollout.append(state.pipeline_state)
    act_rng, rng = jax.random.split(rng)
    act, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_env_step(state, act)

  # Expressions inside an if/else block are not auto-displayed, so call display().
  display(HTML(html.render(env.sys.replace(dt=env.dt), rollout)))

else:
  env = envs.create(env_name=env_name)

  jit_env_reset = jax.jit(env.reset)
  jit_env_step = jax.jit(env.step)
  inference_fn = make_inference_fn(params)
  jit_inference_fn = jax.jit(inference_fn)

  # initialize the state
  rng = jax.random.PRNGKey(0)
  state = jit_env_reset(rng)
  rollout = [state.pipeline_state]

  # grab a trajectory
  n_steps = 1000
  render_every = 1

  for i in range(n_steps):
    act_rng, rng = jax.random.split(rng)
    ctrl, _ = jit_inference_fn(state.obs, act_rng)
    state = jit_env_step(state, ctrl)
    rollout.append(state.pipeline_state)

    if state.done:
      break

  media.show_video(env.render(rollout[::render_every], camera='side'), fps=1.0 / env.dt / render_every)
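The last line subsamples the trajectory with `render_every` and divides the frame rate by the same factor, so the video still plays in real time (each control step lasts `env.dt` seconds). Below is a small sketch of that arithmetic, with a hypothetical `dt` of 0.015 s and integers standing in for pipeline states; `playback_frames` is a hypothetical helper.

```python
def playback_frames(rollout, dt, render_every):
  """Subsample a rollout and return (frames, fps) that keep playback in real time."""
  frames = rollout[::render_every]  # keep every render_every-th state
  fps = 1.0 / dt / render_every     # each kept frame now spans render_every * dt seconds
  return frames, fps


rollout = list(range(1001))  # stand-ins for the initial state plus 1000 steps
frames, fps = playback_frames(rollout, dt=0.015, render_every=2)  # hypothetical dt
duration = len(frames) / fps
print(len(frames), round(fps, 2), round(duration, 2))  # 501 33.33 15.03
```

Raising `render_every` shortens rendering time and shrinks the video without changing its playback speed, at the cost of choppier motion.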