In [3]:
!pip install mujoco
!pip install mujoco_mjx
!pip install brax



In [4]:
import os
import subprocess
import distutils.util

# Check for GPU (remove or modify if you're running on a CPU-only machine)
if subprocess.run('nvidia-smi', shell=True).returncode:
    raise RuntimeError(
        'Cannot communicate with GPU. '
        'Make sure you have an NVIDIA GPU with the proper drivers installed.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
    with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
        f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Set environment variable to use GPU rendering:
os.environ['MUJOCO_GL'] = 'egl'
print('Environment variable MUJOCO_GL set to:', os.environ['MUJOCO_GL'])

try:
    print('Checking that the installation succeeded:')
    import mujoco
    mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
    raise RuntimeError(
        'Something went wrong during installation. Check the terminal output for more information.\n'
        'If using a hosted runtime, make sure GPU acceleration is enabled.'
    ) from e

print('Installation successful.')

# Tell XLA to use Triton GEMM for improved performance (if applicable)
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
print('XLA_FLAGS set to:', os.environ['XLA_FLAGS'])


Tue Mar  4 17:30:13 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A4000               On  |   00000000:65:00.0  On |                  Off |
| 41%   50C    P5             23W /  140W |   13706MiB /  16376MiB |     24%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX A4000               On  |   00

In [5]:
# If you haven't installed mediapy yet:
# %pip install mediapy

import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Instead of the apt-install commands, just ensure ffmpeg is installed in your environment.

import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

print("All packages imported successfully!")


All packages imported successfully!


In [6]:
#@title Import MuJoCo, MJX, and Brax
from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
from ml_collections import config_dict


import jax
from jax import numpy as jp
import numpy as np
from flax.training import orbax_utils
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model


In [12]:
#@title Humanoid Env
#from pathlib import path
import jax.numpy as jnp    

#HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'
HUMANOID_ROOT_PATH = epath.Path("/home/wrschiff/Desktop/mujoco/Robust_Locomotion/")

class Humanoid(PipelineEnv):

   def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=5.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(1.0, 2.0),
      ########### WEIGHT FOR BALANCING TRAY ############
      #balance_tray_reward_weight=0, # setting this to 0 for now
      #terminate_when_boxfall=True,
      ##################################################
      #left_arm_pose_penalty_weight=0, #also setting this to 0 for initial policy
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,
   ):
   #
      mj_model = mujoco.MjModel.from_xml_path(os.getcwd() + '/basic_humanoid_no_tray.xml')
      mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
      mj_model.opt.iterations = 6
      mj_model.opt.ls_iterations = 6

      sys = mjcf.load_model(mj_model)

      physics_steps_per_control_step = 5
      kwargs['n_frames'] = kwargs.get(
         'n_frames', physics_steps_per_control_step)
      kwargs['backend'] = 'mjx'

      super().__init__(sys, **kwargs)

      self._forward_reward_weight = forward_reward_weight
      self._ctrl_cost_weight = ctrl_cost_weight
      self._healthy_reward = healthy_reward
      self._terminate_when_unhealthy = terminate_when_unhealthy
      self._healthy_z_range = healthy_z_range
      self._reset_noise_scale = reset_noise_scale
      ########### WEIGHT FOR BALANCING TRAY ############
      #self._balance_tray_reward_weight = balance_tray_reward_weight
      #self._terminate_when_boxfall = terminate_when_boxfall
      ##################################################
      self._exclude_current_positions_from_observation = (
         exclude_current_positions_from_observation
      )

      ########### INDICES FOR BOX AND TRAY IN XPOS ARRAY ############
      #self.tray_x_id = mujoco.mj_name2id(self.sys.mj_model, mujoco.mjtObj.mjOBJ_XBODY, "tray")
      #self.box_x_id = mujoco.mj_name2id(self.sys.mj_model, mujoco.mjtObj.mjOBJ_XBODY, "box")
      ###############################################################

      #self._left_arm_pose_penalty_weight = left_arm_pose_penalty_weight

      #self.target_left_arm_pose = jax.numpy.array([0.0, 0.0, 0.0])

      # Identify left arm joint indices (adjust joint names as needed)
      self.left_arm_joint_ids = jnp.array([
         mujoco.mj_name2id(self.sys.mj_model, mujoco.mjtObj.mjOBJ_JOINT, "shoulder1_left"),
         mujoco.mj_name2id(self.sys.mj_model, mujoco.mjtObj.mjOBJ_JOINT, "shoulder2_left"),
         mujoco.mj_name2id(self.sys.mj_model, mujoco.mjtObj.mjOBJ_JOINT, "elbow_left")
      ])


   def reset(self, rng: jp.ndarray) -> State:
      """Resets the environment to an initial state."""
      rng, rng1, rng2 = jax.random.split(rng, 3)

      low, hi = -self._reset_noise_scale, self._reset_noise_scale
      qpos = self.sys.qpos0 + jax.random.uniform(
         rng1, (self.sys.nq,), minval=low, maxval=hi
      )
      qvel = jax.random.uniform(
         rng2, (self.sys.nv,), minval=low, maxval=hi
      )

      data = self.pipeline_init(qpos, qvel)

      
      obs = self._get_obs(data, jp.zeros(self.sys.nu))
      reward, done, zero = jp.zeros(3)
      metrics = {
         'forward_reward': zero,
         'reward_linvel': zero,
         'reward_quadctrl': zero,
         'reward_alive': zero,
         'x_position': zero,
         'y_position': zero,
         'distance_from_origin': zero,
         'x_velocity': zero,
         'y_velocity': zero,
      }
      return State(data, obs, reward, done, metrics)
   def step(self, state: State, action: jp.ndarray) -> State:
      """Runs one timestep of the environment's dynamics."""
      data0 = state.pipeline_state
      data = self.pipeline_step(data0, action)

      com_before = data0.subtree_com[1]
      com_after = data.subtree_com[1]
      velocity = (com_after - com_before) / self.dt
      forward_reward = self._forward_reward_weight * velocity[0]

      min_z, max_z = self._healthy_z_range
      is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
      is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
      if self._terminate_when_unhealthy:
         healthy_reward = self._healthy_reward
      else:
         healthy_reward = self._healthy_reward * is_healthy

      ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

      ############## CALCULATE BOX-TRAY REWARD ##################
      #euclid_dist_tb = jp.linalg.norm(data.x.pos[self.tray_x_id] - data.x.pos[self.box_x_id])
      #balance_cost = euclid_dist_tb * self._balance_tray_reward_weight
      ###########################################################

      obs = self._get_obs(data, action)
      ############## ADD TO OVERALL REWARD ##################
      reward = forward_reward + healthy_reward - ctrl_cost - balance_cost
      #######################################################

      reward = forward_reward + healthy_reward - ctrl_cost - balance_cost

      # Compute left arm steadiness penalty
      #left_arm_joint_angles = data.q[self.left_arm_joint_ids]      
      
      # Compute pose deviation error
      #pose_error = left_arm_joint_angles - self.target_left_arm_pose
      #pose_penalty = self._left_arm_pose_penalty_weight * jp.sum(jp.square(pose_error))
        
       # Subtract the penalty from the total reward
      #reward = reward - pose_penalty

      #print(f'CTRL COST BEFORE SCALAR (as benchmark): {ctrl_cost}')
      #print(f'EUCLID DISTANCE: {euclid_dist_tb} \t\tSCALED REWARD: {balance_cost}\t\tTOTAL REWARD: {reward}')

      ########## ADDING TERMINATION CONSTRAINT IF BOX FALLS OFF TRAY ############
      #is_balanced = data.x.pos[self.tray_x_id][2] < data.x.pos[self.box_x_id][2]
      #done = 0.0
      #done = jp.where(self._terminate_when_unhealthy, 1.0 - is_healthy, 0.0)
      #done = jp.where(
      #   (self._terminate_when_boxfall) & (done == 0.0),  # both must be True
      #   1.0 - is_balanced,
      #   done
      #)

      ###########################################################################
      # PREVIOUS METHOD: done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0

      #print(f'TRAY HEIGHT: {data.x.pos[self.tray_x_id][2]}\tBOX HEIGHT: {data.x.pos[self.box_x_id][2]}\tDONE:{done}')
      
      state.metrics.update(
         forward_reward=forward_reward,
         reward_linvel=forward_reward,
         reward_quadctrl=-ctrl_cost,
         reward_alive=healthy_reward,
         x_position=com_after[0],
         y_position=com_after[1],
         distance_from_origin=jp.linalg.norm(com_after),
         x_velocity=velocity[0],
         y_velocity=velocity[1],
      )

      return state.replace(
         pipeline_state=data, obs=obs, reward=reward, done=done
      )

   def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
   ) -> jp.ndarray:
      """Observes humanoid body position, velocities, and angles."""
      position = data.qpos
      if self._exclude_current_positions_from_observation:
         position = position[2:]

      # external_contact_forces are excluded
      return jp.concatenate([
         position,
         data.qvel,
         data.cinert[1:].ravel(),
         data.cvel[1:].ravel(),
         data.qfrc_actuator,
      ])
envs.register_environment('humanoid', Humanoid)

In [17]:
# instantiate the environment
env_name = 'humanoid'
env = envs.get_environment(env_name)

# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

In [41]:
class NetworkWrapper:
  """A simple wrapper to emulate the interface expected by PPO.

  It wraps a Flax model (a callable) along with a dummy input so that
  the model's init and apply methods can be called without additional arguments.
  """
  def __init__(self, model_fn, dummy_input):
    self.model_fn = model_fn
    self.dummy_input = dummy_input

  def init(self, key):
    return self.model_fn.init(key, self.dummy_input)

  def apply(self, params, key, inputs):
    return self.model_fn.apply(params, key, inputs)


In [44]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen.recurrent import OptimizedLSTMCell
from typing import Sequence, Tuple
from types import SimpleNamespace
from brax.training import distribution
from brax.training import types

class LSTMPolicy(nn.Module):
    input_dim: int       # Dimensionality of a single observation.
    hidden_dim: int      # Hidden size for the LSTM layers.
    output_dim: int      # Number of action logits.
    num_layers: int = 2  # Number of LSTM layers.

    @nn.compact
    def __call__(self, inputs, initial_state=None):
        """
        Args:
          inputs: jnp.ndarray of shape [batch, time, input_dim]
          initial_state: Optional list of LSTM states for each layer.
        Returns:
          logits: jnp.ndarray of shape [batch, output_dim]
          final_state: list of final LSTM states for each layer.
        """
        batch_size, time_steps, _ = inputs.shape

        # Initialize LSTM states if not provided.
        if initial_state is None:
            initial_state = [
                OptimizedLSTMCell.initialize_carry(
                    jax.random.PRNGKey(0),  # use a fixed key for initialization
                    (batch_size,),
                    self.hidden_dim
                )
                for _ in range(self.num_layers)
            ]
        state = initial_state

        # Unroll the LSTM over the time dimension.
        x = inputs  # shape: [batch, time, input_dim]
        for t in range(time_steps):
            xt = x[:, t, :]  # current timestep: [batch, input_dim] (or hidden_dim for subsequent layers)
            new_state = []
            h = xt
            for i in range(self.num_layers):
                # Each layer uses its own OptimizedLSTMCell.
                cell = OptimizedLSTMCell(
                    in_features=self.input_dim if i == 0 else self.hidden_dim,
                    hidden_features=self.hidden_dim,
                )
                s, h = cell(state[i], h)
                new_state.append(s)
            state = new_state
        final_output = h  # final hidden output from the last LSTM layer
        logits = nn.Dense(self.output_dim)(final_output)
        return logits, state

# Example of how to wrap this into a network factory for PPO:

def make_recurrent_ppo_networks(
    obs_shape: types.ObservationSize,
    action_size: int,
    unroll_length: int = 20,
    hidden_dim: int = 128,
    num_layers: int = 2,
    preprocess_observations_fn = lambda x: x,
) -> SimpleNamespace:
    """
    Returns a PPONetworks-like object with a recurrent LSTM-based policy network.
    The value network remains a feedforward MLP.
    """
    # Flatten the observation shape (e.g. (336,))
    obs_dim = int(jnp.prod(jnp.array(obs_shape)))
    
    # Define the recurrent policy function.
    def policy_fn(x):
        # x is expected to be of shape [batch, unroll_length, obs_dim]
        x = preprocess_observations_fn(x)
        # Instantiate the LSTMPolicy module.
        lstm_policy = LSTMPolicy(
            input_dim=obs_dim,
            hidden_dim=hidden_dim,
            output_dim=action_size,
            num_layers=num_layers,
        )
        # We ignore the final LSTM state here.
        logits, _ = lstm_policy(x)
        return logits

    # Transform the policy function so it can be initialized and applied.
    # (You may need to adjust the wrapper to match Brax’s expectations.)
    dummy_policy_input = jnp.zeros((1, unroll_length) + obs_shape)
    policy_network = NetworkWrapper(policy_fn, dummy_policy_input)

    # For the value network, we use a standard MLP.
    class ValueMLP(nn.Module):
        hidden_sizes: Sequence[int]
        
        @nn.compact
        def __call__(self, x):
            for size in self.hidden_sizes:
                x = nn.Dense(size)(x)
                x = nn.swish(x)
            x = nn.Dense(1)(x)
            return x

    def value_fn(x):
        # x is of shape [batch, obs_dim]
        mlp = ValueMLP(hidden_sizes=[256, 256])
        return mlp(x)
    
    dummy_value_input = jnp.zeros((1,) + obs_shape)
    value_network = ppo_networks.NetworkWrapper(value_fn, dummy_value_input)
    
    # The parametric action distribution can remain the same.
    parametric_action_distribution = distribution.NormalTanhDistribution(event_size=action_size)
    
    return SimpleNamespace(
        policy_network=policy_network,
        value_network=value_network,
        parametric_action_distribution=parametric_action_distribution,
    )


In [39]:
def domain_randomize(sys, rng):
  """Randomizes the mjx.Model."""
  @jax.vmap
  def rand(rng):
    _, key = jax.random.split(rng, 2)
    # friction
    friction = jax.random.uniform(key, (1,), minval=0.9, maxval=1.1)
    friction = sys.geom_friction.at[:, 0].set(friction)
    # actuator
    _, key = jax.random.split(key, 2)
    gain_range = (-2, 2)
    param = jax.random.uniform(
        key, (1,), minval=gain_range[0], maxval=gain_range[1]
    ) + sys.actuator_gainprm[:, 0]
    gain = sys.actuator_gainprm.at[:, 0].set(param)
    bias = sys.actuator_biasprm.at[:, 1].set(-param)
    return friction, gain, bias

  friction, gain, bias = rand(rng)

  in_axes = jax.tree_util.tree_map(lambda x: None, sys)
  in_axes = in_axes.tree_replace({
      'geom_friction': 0,
      'actuator_gainprm': 0,
      'actuator_biasprm': 0,
  })

  sys = sys.tree_replace({
      'geom_friction': friction,
      'actuator_gainprm': gain,
      'actuator_biasprm': bias,
  })

  return sys, in_axes

In [45]:
ckpt_path = epath.Path('/tmp/humanoid_base/ckpts')
ckpt_path.mkdir(parents=True, exist_ok=True)

def policy_params_fn(current_step, make_policy, params):
  # save checkpoints
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  save_args = orbax_utils.save_args_from_target(params)
  path = ckpt_path / f'{current_step}'
  orbax_checkpointer.save(path, params, force=True, save_args=save_args)

# Use your custom recurrent network factory.
make_networks_factory = functools.partial(
    make_recurrent_ppo_networks,  # This factory builds your LSTM-based policy.
    hidden_dim=128,                # LSTM hidden state size.
    num_layers=2                   # Number of LSTM layers.
)

train_fn = functools.partial(
      ppo.train,
      num_timesteps=100_000_000,
      num_evals=10,
      reward_scaling=1,
      episode_length=1000,
      normalize_observations=True,
      action_repeat=1,
      unroll_length=20,          # Unroll length for the recurrent policy.
      num_minibatches=32,
      num_updates_per_batch=4,
      discounting=0.97,
      learning_rate=3.0e-4,
      entropy_cost=1e-2,
      num_envs=8192,
      batch_size=256,
      network_factory=make_networks_factory,  # Pass the recurrent factory.
      randomization_fn=domain_randomize,
      policy_params_fn=policy_params_fn,
      seed=0)

x_data = []
y_data = []
ydataerr = []
times = [datetime.now()]
max_y, min_y = 1000, 0

def progress(num_steps, metrics):
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics['eval/episode_reward'])
  ydataerr.append(metrics['eval/episode_reward_std'])

  plt.xlim([0, train_fn.keywords['num_timesteps'] * 1.25])
  plt.ylim([min_y, max_y])
  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.title(f'y={y_data[-1]:.3f}')
  plt.errorbar(x_data, y_data, yerr=ydataerr)
  plt.show()

# Reset environments (using Brax's environment API)
env = envs.get_environment(env_name)
eval_env = envs.get_environment(env_name)

make_inference_fn, params, _ = train_fn(
    environment=env,
    progress_fn=progress,
    eval_env=eval_env
)

print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

# Save model parameters.
model_path = '/home/wrschiff/Desktop/mujoco/Robust_Locomotion/model_paths/mjx_brax_initial_policy'
model.save_params(model_path, params)


AttributeError: module 'brax.training.agents.ppo.networks' has no attribute 'NetworkWrapper'