In [6]:
import rodent
from brax.io import mjcf as mjcf_brax
from dm_control import mjcf as mjcf_dm
import mujoco

/home/scotty/.local/lib/python3.10/site-packages/glfw/__init__.py:916: GLFWError: (65544) b'X11: The DISPLAY environment variable is missing'


In [7]:
from dm_control import mujoco as mujoco_dm

## Adapt the MJCF in brax and in dm_control

> TODO: We can adapt the `rodent.py`'s mjcf system to brax, to allow it to interact with the brax system.

- The MJCF support in brax is pretty bare-bones, and only supports loading the XML output from the mjcf package.
- If we want the full flexibility of quick access to joints and contacts, we still need to import mjcf from dm_control, which introduces a whole new set of dependencies.
    - better logic for contact termination: terminate when the system is unhealthy
    - a simplified observation space instead of the huge array of joint angles and positions

**How-to:**
- Insert the `dm_control.rodent()` (loaded by mjcf) into the brax training system.
    - _Challenges_: we need to bind the physics to `Mjx_model` in order to pass the simulation data to dm_control.

In [8]:
# This is how brax uses MJCF: build a raw MuJoCo model from the rodent XML,
# adjust the solver options, then hand the model to brax's MJCF loader.
mj_model = mujoco.MjModel.from_xml_path(rodent._XML_PATH)
# Select the CG solver and cap its iteration counts.
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
mj_model.opt.iterations = 6
mj_model.opt.ls_iterations = 6
# `sys` is the brax system inspected by the cells below (e.g. `sys.link_names`).
sys = mjcf_brax.load_model(mj_model)

In [9]:
dm_rodent = rodent.Rodent()
# `ground_contact_geoms` directly returns the contact geoms (as MJCF
# elements); show the first five.
print(dm_rodent.ground_contact_geoms[:5])
# However, we would need to pass a physics object to dm_control to get a
# better understanding of the simulation, which might be hard to integrate
# with brax/mjx.

(MJCF Element: <geom name="foot_L_collision" class="collision_primitive_paw" size="0.003210712207833968 0.01238417565878816" density="2041.1597822799999" pos="0.0082744640327606835 0.0022566720089347309 -0.0022566720089347309" euler="1.570796326794897 -1.2217304763960311 0"/>, MJCF Element: <geom name="toe_L0_collision" class="collision_primitive_paw" size="0.002383208424149982 0.0072532430300216834" pos="0.0084966561208825436 -0.0008496656120882542 0" euler="1.570796326794897 -1.2217304763960311 0"/>, MJCF Element: <geom name="toe_L1_collision" class="collision_primitive_paw" size="0.002383208424149982 0.0077713318178803739" pos="0.0080718233148384163 0.0025489968362647632 0" euler="1.570796326794897 -1.2217304763960311 0"/>, MJCF Element: <geom name="toe_L2_collision" class="collision_primitive_paw" size="0.002383208424149982 0.0072532430300216834" pos="0.0067973248967060336 0.0055228264785736536 0" euler="1.570796326794897 -1.2217304763960311 0"/>, MJCF Element: <geom name="foot_R_c

In [56]:
# Inspect a single ground-contact geom (an MJCF `<geom>` element).
dm_rodent.ground_contact_geoms[0]

MJCF Element: <geom name="foot_L_collision" class="collision_primitive_paw" size="0.003210712207833968 0.01238417565878816" density="2041.1597822799999" pos="0.0082744640327606835 0.0022566720089347309 -0.0022566720089347309" euler="1.570796326794897 -1.2217304763960311 0"/>

In [10]:
# The system stores inertia information for each named link; compare the
# number of inertia positions with the number of link names.
len(sys.link.inertia.transform.pos), len(sys.link_names)

(65, 65)

In [11]:
# List the link names of the brax system.
sys.link_names

['torso',
 'vertebra_1',
 'vertebra_2',
 'vertebra_3',
 'vertebra_4',
 'vertebra_5',
 'vertebra_6',
 'pelvis',
 'upper_leg_L',
 'lower_leg_L',
 'foot_L',
 'toe_L',
 'upper_leg_R',
 'lower_leg_R',
 'foot_R',
 'toe_R',
 'vertebra_C1',
 'vertebra_C2',
 'vertebra_C3',
 'vertebra_C4',
 'vertebra_C5',
 'vertebra_C6',
 'vertebra_C7',
 'vertebra_C8',
 'vertebra_C9',
 'vertebra_C10',
 'vertebra_C11',
 'vertebra_C12',
 'vertebra_C13',
 'vertebra_C14',
 'vertebra_C15',
 'vertebra_C16',
 'vertebra_C17',
 'vertebra_C18',
 'vertebra_C19',
 'vertebra_C20',
 'vertebra_C21',
 'vertebra_C22',
 'vertebra_C23',
 'vertebra_C24',
 'vertebra_C25',
 'vertebra_C26',
 'vertebra_C27',
 'vertebra_C28',
 'vertebra_C29',
 'vertebra_C30',
 'vertebra_cervical_5',
 'vertebra_cervical_4',
 'vertebra_cervical_3',
 'vertebra_cervical_2',
 'vertebra_cervical_1',
 'vertebra_axis',
 'vertebra_atlant',
 'skull',
 'jaw',
 'scapula_L',
 'upper_arm_L',
 'lower_arm_L',
 'hand_L',
 'finger_L',
 'scapula_R',
 'upper_arm_R',
 'lowe

## Insert the MJCF
Enabled the rodent model in brax.

**TODO:** It might also be helpful to just reuse the dm_control logic for rewards and contact termination.

> Review [the code of these functions](https://github.com/google-deepmind/dm_control/blob/b002aa350c3a48e44ecef84bd5dab6d08e1b2f74/dm_control/locomotion/tasks/corridors.py#L127) in `dm_control`, where they calculate the reward and the physical simulation steps.

> Maybe build a named mapping `{joint: index}` for each joint and geom for easier access. See how [the brax system might have done it](https://github.com/google/brax/blob/2329ae76759e37b0b1f1861cf34e5a67d0f7efa8/brax/base.py#L470).

In [12]:
import jax
from jax import numpy as jp

from brax import envs
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf as mjcf_brax

import mujoco
from mujoco import mjx

import yaml
from typing import Dict, Text


**Notes:** Brax system `brax.System` is a subclass of the `mjx.Model`

**Notes:** Brax State `brax.State` is a subclass of `mjx.Data`

In [107]:
class Rodent(PipelineEnv):
  """Brax pipeline environment for the dm_control rodent, simulated with MJX.

  The per-step reward is a forward-velocity term plus a "healthy" bonus minus
  a quadratic control cost. Observations are currently a one-element
  placeholder (see `_get_obs`).
  """

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=5.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(0.70, 0.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,
  ):
    """Builds the rodent system and stores the reward configuration.

    Args:
      forward_reward_weight: multiplier applied to the x-velocity of the
        tracked centre of mass.
      ctrl_cost_weight: multiplier applied to the summed squared action.
      healthy_reward: bonus granted per step (scaled by the health flag only
        when termination is disabled).
      terminate_when_unhealthy: if true, `done` is raised once the health
        check fails.
      healthy_z_range: `(min_z, max_z)` pair; only `min_z` is read by `step`.
        NOTE(review): the default lower bound is above the upper bound -
        confirm the intended values.
      reset_noise_scale: half-width of the uniform noise applied at reset.
      exclude_current_positions_from_observation: stored on the instance but
        not read anywhere in this class.
      **kwargs: forwarded to `PipelineEnv.__init__`; `n_frames` defaults to 3
        and `backend` is always forced to 'mjx'.
    """
    # Compile the walker through dm_control and borrow its raw MuJoCo model
    # (alternative: mujoco.MjModel.from_xml_path(params["XML_PATH"])).
    walker = rodent.Rodent()
    dm_physics = mjcf_dm.Physics.from_mjcf_model(walker.mjcf_model)
    raw_model = dm_physics.model.ptr
    raw_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    raw_model.opt.iterations = 2
    raw_model.opt.ls_iterations = 4

    brax_system = mjcf_brax.load_model(raw_model)

    # Three physics steps per control step unless the caller overrides it.
    kwargs.setdefault('n_frames', 3)
    kwargs['backend'] = 'mjx'

    super().__init__(brax_system, **kwargs)

    # Reward shaping.
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    # Termination / health check.
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    # Reset and observation options.
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

  def reset(self, rng: jp.ndarray) -> State:
    """Returns a fresh state with uniformly perturbed positions/velocities."""
    _, pos_key, vel_key = jax.random.split(rng, 3)

    noise = self._reset_noise_scale
    # Perturb the default pose; velocities are pure noise around zero.
    qpos = self.sys.qpos0 + jax.random.uniform(
        pos_key, (self.sys.nq,), minval=-noise, maxval=noise
    )
    qvel = jax.random.uniform(
        vel_key, (self.sys.nv,), minval=-noise, maxval=noise
    )

    pipeline_state = self.pipeline_init(qpos, qvel)

    # The first observation is computed with an all-zero action.
    first_obs = self._get_obs(pipeline_state, jp.zeros(self.sys.nu))
    initial_reward, initial_done, zero = jp.zeros(3)
    metric_names = (
        'forward_reward',
        'reward_linvel',
        'reward_quadctrl',
        'reward_alive',
        'x_position',
        'y_position',
        'distance_from_origin',
        'x_velocity',
        'y_velocity',
    )
    metrics = {name: zero for name in metric_names}
    return State(
        pipeline_state=pipeline_state,
        obs=first_obs,
        reward=initial_reward,
        done=initial_done,
        metrics=metrics,
    )

  def step(self, state: State, action: jp.ndarray) -> State:
    """Advances the simulation by one control step and scores the result."""
    prev_data = state.pipeline_state
    next_data = self.pipeline_step(prev_data, action)

    # Finite-difference velocity of the subtree centre of mass at index 1.
    com_prev = prev_data.subtree_com[1]
    com_next = next_data.subtree_com[1]
    com_velocity = (com_next - com_prev) / self.dt
    forward_reward = self._forward_reward_weight * com_velocity[0]

    # Only the lower bound is enforced; the upper bound is unpacked but unused.
    # q[2] is presumably the root height - TODO confirm.
    min_z, _unused_max_z = self._healthy_z_range
    is_healthy = jp.where(next_data.q[2] < min_z, 0.0, 1.0)
    if self._terminate_when_unhealthy:
      # Episode ends when unhealthy, so the bonus is paid unconditionally.
      healthy_reward = self._healthy_reward
      done = 1.0 - is_healthy
    else:
      healthy_reward = self._healthy_reward * is_healthy
      done = 0.0

    squared_effort = jp.sum(jp.square(action))
    ctrl_cost = self._ctrl_cost_weight * squared_effort

    obs = self._get_obs(next_data, action)
    reward = forward_reward + healthy_reward - ctrl_cost
    state.metrics.update({
        'forward_reward': forward_reward,
        'reward_linvel': forward_reward,
        'reward_quadctrl': -ctrl_cost,
        'reward_alive': healthy_reward,
        'x_position': com_next[0],
        'y_position': com_next[1],
        'distance_from_origin': jp.linalg.norm(com_next),
        'x_velocity': com_velocity[0],
        'y_velocity': com_velocity[1],
    })

    return state.replace(
        pipeline_state=next_data, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Placeholder observation: a single zero, independent of the inputs."""
    # No real features yet (external contact forces are excluded as well).
    placeholder = jp.zeros(1)
    return jp.concatenate([placeholder])


envs.register_environment('rodent', Rodent)

In [46]:
# Roll the rodent environment forward for 1000 control steps with a zero
# action. NOTE: `jit_step` and `state` come from the cell that instantiates
# `rodent_brax` (further down in this notebook); run that cell first.
# The action is loop-invariant, so allocate it once.
# Assumes the rodent has 30 actuators - TODO confirm against `sys.nu`.
rodent_zero_action = jp.zeros(30)
for _ in range(1000):
    next_state = jit_step(state, rodent_zero_action)
    state = next_state

The previous cell took 18.8 s to run the simulation after `jit`, which corresponds to approximately 53.2 steps per second (sps) (on an AMD Ryzen 7 3700X 8-Core Processor / Nvidia 2070s).

TODO: Benchmark this on A5000


<div class="alert alert-success">
  <strong>BUG:</strong> Why does the state not have a contact field? Is contact not being calculated in the mjx step? <br/>
  <strong>REASON:</strong> The state generated by `pipeline_step` does have the contact field, i.e. `pipeline_state` has contact but the regular state does not.
</div>

The `State` object returned by the step function nests another `State` object called `pipeline_state`, which contains richer information about the physical simulation. As for why that is the case, I don't know.


In [106]:
# The contact data lives on the nested pipeline state, so instead of
# `next_state.contact`, do:
next_state.pipeline_state.contact

Contact(dist=Array([0.00676617, 0.0183332 , 0.03712968, ..., 1.        , 1.        ,
       1.        ], dtype=float32), pos=Array([[ 0.06609569,  0.10668103,  0.00338309],
       [ 0.05427406,  0.08824398,  0.0091666 ],
       [ 0.01166843,  0.16785012,  0.01856484],
       ...,
       [-0.04209168,  0.10930085,  0.09420121],
       [-0.04226784,  0.11437095,  0.09025861],
       [-0.04327573,  0.11223526,  0.09468883]], dtype=float32), frame=Array([[[ 0.        ,  0.        ,  1.        ],
        [ 0.53976357,  0.8418166 ,  0.        ],
        [-0.8418166 ,  0.53976357,  0.        ]],

       [[ 0.        ,  0.        ,  1.        ],
        [ 0.53976357,  0.8418166 ,  0.        ],
        [-0.8418166 ,  0.53976357,  0.        ]],

       [[ 0.        ,  0.        ,  1.        ],
        [-0.97036415, -0.24164754,  0.        ],
        [ 0.24164754, -0.97036415,  0.        ]],

       ...,

       [[ 0.8604781 ,  0.50058156, -0.09484512],
        [ 0.08198173,  0.04769272,  0.99549

In [98]:
from brax import envs
# Build the "ant" environment on the MJX backend as a reference for comparison.
ant_env = envs.get_environment("ant", backend='mjx')
# NOTE: the jitted reset/step are left commented out here; any cell that calls
# `jit_ant_step` needs it to be defined before it runs.
# jit_ant_reset = jax.jit(ant_env.reset)
# jit_ant_step = jax.jit(ant_env.step)
rng = jax.random.PRNGKey(0)
# Un-jitted reset; this rebinds the notebook-wide `state` to the ant's state.
state = ant_env.reset(rng)


In [75]:
# Step the ant for 1000 control steps with a zero action and print the
# top-level `contact` attribute whenever the returned state exposes one.
# `jit_ant_step` is defined here so the cell also runs on a fresh kernel.
jit_ant_step = jax.jit(ant_env.step)
# The action is loop-invariant, so allocate it once.
ant_zero_action = jp.zeros(8)
for _ in range(1000):
    next_ant_state = jit_ant_step(state, ant_zero_action)
    state = next_ant_state
    # Explicit check instead of swallowing AttributeError.
    if hasattr(next_ant_state, "contact"):
        print(next_ant_state.contact)

<div class="alert alert-success">
  <strong>Implemented: </strong>Complete a helper function to imitate the observation space and contact termination in dm_control using the `sys.link` and `sys.link_names` mentioned above.
</div>

In [109]:
# generate the rodent model
# NOTE(review): `System` is imported here but not used in this cell.
from brax.base import System

# Fixed PRNG key so the reset below is reproducible.
rng = jax.random.PRNGKey(0)
rodent_brax = Rodent()

# define the jit reset/step functions
jit_reset = jax.jit(rodent_brax.reset)
jit_step = jax.jit(rodent_brax.step)

# Initial environment state; this rebinds the notebook-wide `state`.
state = jit_reset(rng)

<div class="alert alert-danger">
  <strong>BUG:</strong> The contact space is unreasonably large: the rodent model currently has (4840) pairs of contacts. In contrast, the ant model only has (4) pairs of contacts, which is expected because the ant only has four legs.

  This will be the precursor for the contact termination logic.
</div>


In [125]:
# Shape of the contact positions on the pipeline state; the leading
# dimension is the number of contact entries.
state.pipeline_state.contact.pos.shape

(4840, 3)

In [129]:
# Rebuild the ant environment to compare its contact count with the rodent's.
ant_env = envs.get_environment("ant", backend='mjx')
# jit_ant_reset = jax.jit(ant_env.reset)
# jit_ant_step = jax.jit(ant_env.step)
rng = jax.random.PRNGKey(0)
state = ant_env.reset(rng)

# Build a new list instead of appending to `ant_env.sys.link_names` through an
# alias, which would mutate the environment's own system.
link_name = [*ant_env.sys.link_names, "floor"]
print(state.pipeline_state.contact.pos.shape)
# geom1s, geom2s = state.pipeline_state.contact.link_idx
# for g1, g2 in zip(geom1s, geom2s):
#     print(f"Contact between: {link_name[g1]} and {link_name[g2]}")

(4, 3)


<div class="alert alert-warning">
  <strong>TODO: </strong> Implement a better observation space. In `dm_control`'s rodent model, the observables only have 30 entries.

  <strong>???: </strong> I cannot find the observation space for the virtual rodent in the `dm_control` repo... Did I miss anything? It only provides the observation space for the elements in `dm_control.suite`. The following is the observation space for the Humanoid:

  ```python
    def get_observation(self, physics):
    """Returns a set of egocentric features."""
    obs = collections.OrderedDict()
    obs['joint_angles'] = physics.joint_angles()
    obs['head_height'] = physics.head_height()
    obs['extremities'] = physics.extremities()
    obs['torso_vertical'] = physics.torso_vertical_orientation()
    obs['com_velocity'] = physics.center_of_mass_velocity()
    obs['velocity'] = physics.velocity()
    return obs
  ```
  I can also just try to imitate this logic for our rodent friend.
</div>

In [150]:
# Count the observables exposed by the dm_control rodent.
len(dm_rodent.observables.as_dict())

30

In [None]:
class Rodent(PipelineEnv):
  """Brax pipeline environment for the dm_control rodent on the MJX backend.

  The per-step reward is a forward-velocity term plus a "healthy" bonus minus
  a quadratic control cost. Observations are currently a one-element
  placeholder (see `_get_obs`). `link_name_to_idx` maps each link name of the
  loaded system to its position in `sys.link_names`.
  """

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=5.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(0.70, 0.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,
  ):
    """Builds the rodent system and stores the reward configuration.

    Args:
      forward_reward_weight: multiplier applied to the x-velocity of the
        tracked centre of mass.
      ctrl_cost_weight: multiplier applied to the summed squared action.
      healthy_reward: bonus granted per step (scaled by the health flag only
        when termination is disabled).
      terminate_when_unhealthy: if true, `done` is raised once the health
        check fails.
      healthy_z_range: `(min_z, max_z)` pair; only `min_z` is read by `step`.
        NOTE(review): the default lower bound is above the upper bound -
        confirm the intended values.
      reset_noise_scale: half-width of the uniform noise applied at reset.
      exclude_current_positions_from_observation: stored on the instance but
        not read anywhere in this class.
      **kwargs: forwarded to `PipelineEnv.__init__`; `n_frames` defaults to 3
        and `backend` is always forced to 'mjx'.
    """
    # Load the rodent model via dm_control
    dm_rodent = rodent.Rodent()
    physics = mjcf_dm.Physics.from_mjcf_model(dm_rodent.mjcf_model)
    # mj_model = mujoco.MjModel.from_xml_path(params["XML_PATH"])
    mj_model = physics.model.ptr
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 2
    mj_model.opt.ls_iterations = 4

    sys = mjcf_brax.load_model(mj_model)

    physics_steps_per_control_step = 3
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step
    )
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    self.link_name_to_idx = self.make_mapping(sys.link_names)
    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

  def make_mapping(self, lst: list) -> Dict:
    """Maps each entry of `lst` to its position in the list.

    Args:
      lst: sequence of hashable entries (here, link names).

    Returns:
      A dict `{entry: index}`. If an entry occurs more than once, the last
      index wins.
    """
    # Keys must be the names: `get_com` looks entries up by link name.
    return {name: idx for idx, name in enumerate(lst)}

  def get_com(self, pipeline_state: mjx.Data, link_name: str):
    """Returns the `subtree_com` entry selected by the named link's index.

    Args:
      pipeline_state: simulation data exposing `subtree_com`.
      link_name: a name contained in `self.link_name_to_idx`.

    Raises:
      KeyError: if `link_name` is not a known link name.
    """
    if link_name not in self.link_name_to_idx:
      raise KeyError(
          f"Link name {link_name!r} does not exist in the system's link names."
      )
    link_idx = self.link_name_to_idx[link_name]
    # NOTE(review): `step` reads `subtree_com[1]` for the tracked body while
    # the link list appears to start at index 0; if `subtree_com` also has an
    # entry for the world body, this needs `link_idx + 1` - confirm.
    return pipeline_state.subtree_com[link_idx]

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    # initialize random position
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    # initialize the data space
    data = self.pipeline_init(qpos, qvel)

    # get observation (computed with an all-zero action)
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    # Finite-difference velocity of the subtree centre of mass at index 1.
    com_before = data0.subtree_com[1]
    com_after = data.subtree_com[1]
    velocity = (com_after - com_before) / self.dt
    forward_reward = self._forward_reward_weight * velocity[0]

    # Only the lower bound is enforced; the upper bound is not used.
    # q[2] is presumably the root height - TODO confirm.
    min_z, _ = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    if self._terminate_when_unhealthy:
      # Episode ends when unhealthy, so the bonus is paid unconditionally.
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    obs = self._get_obs(data, action)
    reward = forward_reward + healthy_reward - ctrl_cost
    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Placeholder observation: a single zero, independent of the inputs."""

    # external_contact_forces are excluded
    return jp.concatenate([
      jp.zeros(1)
    ])


envs.register_environment('rodent', Rodent)