In [1]:
import mjcf_vnl.rodent as rodent_vnl
from brax.io import mjcf as mjcf_brax
from dm_control import mjcf as mjcf_dm
import mujoco

In [2]:
from dm_control import mujoco as mujoco_dm

## Adapt the MJCF in brax and in dm_control

> TODO: We can adapt the MJCF handling in `rodent.py` to brax, so that it can interact with the brax system.

- The MJCF support in brax is fairly bare-bones and only supports loading the XML output from the `mjcf` package.
- If we want the full flexibility of quick access to joints and contacts, we still need to import `mjcf` from `dm_control`, which introduces a whole new set of dependencies.
    - better logic for contact termination (terminate when the system is unhealthy)
    - a simplified observation space, instead of the huge array of joint angles and positions

**How-to:**
- Insert the `dm_control.rodent()` (loaded via `mjcf`) into the brax training system.
    - _Challenge_: we need to bind the physics to the `Mjx_model` in order to pass the simulation data to `dm_control`.

In [3]:
# This is how brax uses MJCF
import os
_XML_PATH = os.path.join("models", "rodent_optimized.xml")
mj_model = mujoco.MjModel.from_xml_path(_XML_PATH)
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
mj_model.opt.iterations = 6
mj_model.opt.ls_iterations = 6
sys = mjcf_brax.load_model(mj_model)

2024-03-20 18:00:42.601486: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [4]:
dm_rodent = rodent_vnl.Rodent(xml_path=_XML_PATH)
# ground_contact_geoms directly returns the geoms that can touch the ground
print(dm_rodent.ground_contact_geoms[:5])
# However, dm_control needs a physics object before these elements can be
# used to inspect the simulation, which might be hard to integrate with
# brax/mjx.

(MJCF Element: <geom name="foot_L_collision" class="nonself_collision_primitive_paw" size="0.003210712207833968 0.01238417565878816" density="2041.1597822799999" pos="0.0082744640327606835 0.0022566720089347309 -0.0022566720089347309" euler="1.570796326794897 -1.2217304763960311 0"/>, MJCF Element: <geom name="toe_L0_collision" class="nonself_collision_primitive_paw" size="0.002383208424149982 0.0072532430300216834" pos="0.0084966561208825436 -0.0008496656120882542 0" euler="1.570796326794897 -1.2217304763960311 0"/>, MJCF Element: <geom name="toe_L1_collision" class="nonself_collision_primitive_paw" size="0.002383208424149982 0.0077713318178803739" pos="0.0080718233148384163 0.0025489968362647632 0" euler="1.570796326794897 -1.2217304763960311 0"/>, MJCF Element: <geom name="toe_L2_collision" class="nonself_collision_primitive_paw" size="0.002383208424149982 0.0072532430300216834" pos="0.0067973248967060336 0.0055228264785736536 0" euler="1.570796326794897 -1.2217304763960311 0"/>, MJ

In [5]:
dm_rodent.ground_contact_geoms[0]

MJCF Element: <geom name="foot_L_collision" class="nonself_collision_primitive_paw" size="0.003210712207833968 0.01238417565878816" density="2041.1597822799999" pos="0.0082744640327606835 0.0022566720089347309 -0.0022566720089347309" euler="1.570796326794897 -1.2217304763960311 0"/>

In [6]:
# The system stores inertia information for every named link
len(sys.link.inertia.transform.pos), len(sys.link_names)

(65, 65)

In [7]:
sys.link_names

['torso',
 'vertebra_1',
 'vertebra_2',
 'vertebra_3',
 'vertebra_4',
 'vertebra_5',
 'vertebra_6',
 'pelvis',
 'upper_leg_L',
 'lower_leg_L',
 'foot_L',
 'toe_L',
 'upper_leg_R',
 'lower_leg_R',
 'foot_R',
 'toe_R',
 'vertebra_C1',
 'vertebra_C2',
 'vertebra_C3',
 'vertebra_C4',
 'vertebra_C5',
 'vertebra_C6',
 'vertebra_C7',
 'vertebra_C8',
 'vertebra_C9',
 'vertebra_C10',
 'vertebra_C11',
 'vertebra_C12',
 'vertebra_C13',
 'vertebra_C14',
 'vertebra_C15',
 'vertebra_C16',
 'vertebra_C17',
 'vertebra_C18',
 'vertebra_C19',
 'vertebra_C20',
 'vertebra_C21',
 'vertebra_C22',
 'vertebra_C23',
 'vertebra_C24',
 'vertebra_C25',
 'vertebra_C26',
 'vertebra_C27',
 'vertebra_C28',
 'vertebra_C29',
 'vertebra_C30',
 'vertebra_cervical_5',
 'vertebra_cervical_4',
 'vertebra_cervical_3',
 'vertebra_cervical_2',
 'vertebra_cervical_1',
 'vertebra_axis',
 'vertebra_atlant',
 'skull',
 'jaw',
 'scapula_L',
 'upper_arm_L',
 'lower_arm_L',
 'hand_L',
 'finger_L',
 'scapula_R',
 'upper_arm_R',
 'lowe

## Insert the MJCF
The rodent model is now enabled in brax.

**TODO:** It might also be helpful to reuse the `dm_control` logic for rewards and contact termination.

> Review [this code](https://github.com/google-deepmind/dm_control/blob/b002aa350c3a48e44ecef84bd5dab6d08e1b2f74/dm_control/locomotion/tasks/corridors.py#L127) in `dm_control`, where the reward and the physics simulation steps are computed.

> Maybe build a named mapping `{joint: index}` for each joint and geom for easier access. See how [the brax system might have done it](https://github.com/google/brax/blob/2329ae76759e37b0b1f1861cf34e5a67d0f7efa8/brax/base.py#L470).
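
The named mapping suggested above could be sketched as follows. This is a minimal illustration, not how brax implements it: `build_name_index` is a hypothetical helper, and the short `link_names` list is a stand-in for `sys.link_names`.

```python
from typing import Dict, List, Sequence


def build_name_index(names: Sequence[str]) -> Dict[str, int]:
    """Map each name to its position in the sequence (first occurrence wins)."""
    index: Dict[str, int] = {}
    for i, name in enumerate(names):
        index.setdefault(name, i)
    return index


# Hypothetical excerpt; in the notebook this would be `sys.link_names`.
link_names: List[str] = ["torso", "vertebra_1", "pelvis", "foot_L", "toe_L"]
link_index = build_name_index(link_names)

# Look up several links at once, e.g. to slice per-link arrays by name.
paw_ids = [link_index[name] for name in ("foot_L", "toe_L")]
print(paw_ids)
```

Building the dictionary once, outside of any jitted function, keeps the lookups in plain Python, so only the resulting integer indices need to reach the compiled code.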

In [8]:
import jax
from jax import numpy as jp

from brax import envs
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf as mjcf_brax

import mujoco
from mujoco import mjx

import yaml
from typing import Dict, Text


**Note:** The brax system, `brax.System`, is a subclass of `mjx.Model`.

**Note:** The brax pipeline state, `brax.State`, is a subclass of `mjx.Data` when the `mjx` backend is used.

In [9]:
class Rodent(PipelineEnv):

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=5.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(0.70, 0.0),  # (min_z, max_z); step() only checks min_z
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,
  ):
    # Load the rodent model via dm_control
    dm_rodent = rodent_vnl.Rodent()
    physics = mjcf_dm.Physics.from_mjcf_model(dm_rodent.mjcf_model)
    # mj_model = mujoco.MjModel.from_xml_path(params["XML_PATH"])
    mj_model = physics.model.ptr
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 2
    mj_model.opt.ls_iterations = 4

    sys = mjcf_brax.load_model(mj_model)

    physics_steps_per_control_step = 3
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step
    )
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    # initialize random position
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    # initialize the data space
    data = self.pipeline_init(qpos, qvel)

    # get observation
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    com_before = data0.subtree_com[1]
    com_after = data.subtree_com[1]
    velocity = (com_after - com_before) / self.dt
    forward_reward = self._forward_reward_weight * velocity[0]

    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    obs = self._get_obs(data, action)
    reward = forward_reward + healthy_reward - ctrl_cost
    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Returns a placeholder observation (a single zero) for now."""

    # TODO: observe rodent body positions, velocities, and joint angles;
    # `data` and `action` are currently unused.
    return jp.zeros(1)


envs.register_environment('rodent', Rodent)
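
The `healthy_z_range` tuple above is unpacked as `(min_z, max_z)`, but `step` only uses the lower bound. Below is a minimal sketch of a check that uses both bounds and rejects an inverted range. The function name `is_healthy` and the example bounds are hypothetical and are not values taken from the rodent model.

```python
from typing import Tuple


def is_healthy(z: float, z_range: Tuple[float, float]) -> bool:
    """Return True when the height z lies inside [min_z, max_z]."""
    min_z, max_z = z_range
    if min_z > max_z:
        raise ValueError(f"inverted range: min_z={min_z} > max_z={max_z}")
    return min_z <= z <= max_z


# Hypothetical bounds in metres, chosen only for illustration.
example_range = (0.02, 0.5)
print(is_healthy(0.05, example_range))  # inside the range
print(is_healthy(0.01, example_range))  # below min_z
```

Inside a jitted `step`, the same comparison would have to be written with array operations such as `jp.where`, because Python branching on traced values is not allowed there.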

In [10]:
rodent = Rodent()
jit_step, jit_reset = jax.jit(rodent.step), jax.jit(rodent.reset)
rng = jax.random.PRNGKey(0)

In [11]:
# state = jit_reset(rng)

# for _ in range(1000):
#     next_state = jit_step(state, jp.zeros(30))
#     state = next_state

KeyboardInterrupt: 

After `jit`, the previous cell took 18.8 s to run the simulation, which corresponds to approximately 53.2 steps per second (sps) on an AMD Ryzen 7 3700X 8-core processor with an Nvidia 2070s.

On an i7-11700K with an A5000, it took 11.9 s, which corresponds to about 84 sps.
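
The throughput figures follow from dividing the step count by the wall-clock time. The sketch below assumes the timed cell ran the 1000-step loop shown (commented out) above; `steps_per_second` is a hypothetical helper name.

```python
def steps_per_second(n_steps: int, wall_seconds: float) -> float:
    """Simulation throughput as steps divided by elapsed wall-clock time."""
    return n_steps / wall_seconds


timings = [
    ("Ryzen 7 3700X / 2070s", 18.8),
    ("i7-11700K / A5000", 11.9),
]
for label, seconds in timings:
    print(f"{label}: {steps_per_second(1000, seconds):.1f} sps")
```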


<div class="alert alert-success">
  <strong>BUG:</strong> Why does the state not have a contact field? Is contact not calculated in the mjx step? <br/>
  <strong>REASON:</strong> The state generated by `pipeline_step` does have the contact field, i.e. `pipeline_state` carries the contact information, but the regular (environment) state does not.
</div>

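The `State` object returned by the step function nests another state object called `pipeline_state`, which contains richer information about the physical simulation. The outer `State` bundles the environment-level quantities (`obs`, `reward`, `done`, `metrics`), as in the `State(data, obs, reward, done, metrics)` call in `reset`, while `pipeline_state` holds the physics data produced by `pipeline_init` and `pipeline_step`.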


In [12]:
# instead of next_state.contact, do:
# next_state.pipeline_state.contact

<div class="alert alert-success">
  <strong>Implemented: </strong>Complete a helper function to imitate the observation space and contact termination in dm_control using the `sys.link` and `sys.link_names` mentioned above.
</div>

In [13]:
# generate the rodent model
from brax.base import System

rng = jax.random.PRNGKey(0)
rodent_brax = Rodent()

# define the jit reset/step functions
jit_reset = jax.jit(rodent_brax.reset)
jit_step = jax.jit(rodent_brax.step)

state = jit_reset(rng)

<div class="alert alert-success">
  <strong>BUG RESOLVED:</strong> The contact space was unreasonably large, at 4840 contact pairs for the rodent model. In contrast, the ant model only has 4 contact pairs, which is expected because the ant only has four legs.

  Charles' optimized model for the rodent has significantly fewer contacts compared to the previous XML file.

  This will be the precursor for the contact termination logic.
</div>
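
One way the contact termination logic could look is sketched below with NumPy. It assumes that each candidate contact comes with a pair of link indices and a signed distance, that the floor has link index `-1`, and that a distance at or below zero means the pair is actually touching. These are assumptions for illustration, and `has_illegal_contact` is a hypothetical helper, not a brax or `dm_control` function. The example indices follow the `sys.link_names` listing (0 = `torso`, 10 = `foot_L`, 11 = `toe_L`).

```python
import numpy as np


def has_illegal_contact(link_pairs, dist, allowed_links, floor_id=-1, tol=0.0):
    """Return True if any link outside `allowed_links` touches the floor.

    link_pairs: (2, n) integer array with the two link indices per contact.
    dist: (n,) signed distances; values <= tol are treated as touching.
    """
    link_pairs = np.asarray(link_pairs)
    touching = np.asarray(dist) <= tol
    involves_floor = (link_pairs == floor_id).any(axis=0)
    # The "other" link of a floor contact is whichever entry is not the floor.
    other = np.where(link_pairs[0] == floor_id, link_pairs[1], link_pairs[0])
    illegal = ~np.isin(other, list(allowed_links))
    return bool((touching & involves_floor & illegal).any())


pairs = np.array([[-1, -1, -1], [10, 11, 0]])
paws = {10, 11}
# Only foot_L penetrates the floor: no termination.
print(has_illegal_contact(pairs, [-0.001, 0.02, 0.02], paws))
# The torso also penetrates the floor: terminate.
print(has_illegal_contact(pairs, [-0.001, 0.02, -0.002], paws))
```

For use inside a jitted `step`, the same array operations exist in `jax.numpy`, but the final `bool(...)` conversion would have to be dropped so that the result stays an array.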


In [15]:
state.pipeline_state.contact.pos.shape

(42, 3)

In [29]:
state.pipeline_state.contact

Contact(dist=Array([0.02128432, 0.02099292, 0.02121474, 0.02108133, 0.0213101 ,
       0.02116716, 0.02140143, 0.02126803, 0.0193865 , 0.02018749,
       0.01933481, 0.01968523, 0.01927798, 0.01965343, 0.01926983,
       0.01962025, 0.03382435, 0.02736348, 0.01820131, 0.01872309,
       0.01868518, 0.03833075, 0.02313393, 0.03811677, 0.01435135,
       0.01431087, 0.01438749, 0.01437128, 0.01442973, 0.01443351,
       0.0182746 , 0.03778837, 0.02265196, 0.03757623, 0.01402474,
       0.01398875, 0.01401761, 0.01398232, 0.01399838, 0.01397169,
       0.01357788, 0.01314746], dtype=float32), pos=Array([[-0.05125693,  0.01235715,  0.00564216],
       [-0.02829816,  0.02164591,  0.00549646],
       [-0.0253654 ,  0.01955279,  0.00560737],
       [-0.01191798,  0.02499221,  0.00554066],
       [-0.02638887,  0.02273907,  0.00565505],
       [-0.01198093,  0.02856702,  0.00558358],
       [-0.02728597,  0.02585937,  0.00570071],
       [-0.01383855,  0.03129879,  0.00563401],
       [-0.0260

In [16]:
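# Copy the list so that sys.link_names is not mutated on every re-run; the
# extra "floor" entry is what the floor's link index resolves to.
link_names = list(rodent_brax.sys.link_names) + ["floor"]
print(state.pipeline_state.contact.pos.shape)
# link_idx holds link indices (not geom indices) for both sides of a contact
link1s, link2s = state.pipeline_state.contact.link_idx
for l1, l2 in zip(link1s, link2s):
    print(f"Contact between: {link_names[l1]} and {link_names[l2]}")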

(42, 3)
Contact between: floor and foot_L
Contact between: floor and foot_L
Contact between: floor and toe_L
Contact between: floor and toe_L
Contact between: floor and toe_L
Contact between: floor and toe_L
Contact between: floor and toe_L
Contact between: floor and toe_L
Contact between: floor and foot_R
Contact between: floor and foot_R
Contact between: floor and toe_R
Contact between: floor and toe_R
Contact between: floor and toe_R
Contact between: floor and toe_R
Contact between: floor and toe_R
Contact between: floor and toe_R
Contact between: floor and vertebra_C9
Contact between: floor and vertebra_C9
Contact between: floor and vertebra_C20
Contact between: floor and vertebra_C20
Contact between: floor and lower_arm_L
Contact between: floor and lower_arm_L
Contact between: floor and lower_arm_L
Contact between: floor and lower_arm_L
Contact between: floor and finger_L
Contact between: floor and finger_L
Contact between: floor and finger_L
Contact between: floor and finger_L
Co

<div class="alert alert-warning">
  <strong>TODO: </strong> Implement a better observation space. The `dm_control` rodent model only exposes 30 observables.

  <strong>???: </strong> I cannot find the observation space for the virtual rodent in the `dm_control` repo... Did I miss anything? It only provides the observation space for the elements in `dm_control.suite`. The following is the observation space for the Humanoid:

  ```python
    def get_observation(self, physics):
      """Returns a set of egocentric features."""
      obs = collections.OrderedDict()
      obs['joint_angles'] = physics.joint_angles()
      obs['head_height'] = physics.head_height()
      obs['extremities'] = physics.extremities()
      obs['torso_vertical'] = physics.torso_vertical_orientation()
      obs['com_velocity'] = physics.center_of_mass_velocity()
      obs['velocity'] = physics.velocity()
      return obs
  ```
  I can also just try to imitate this logic for our rodent friend.
</div>
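
A rough sketch of how a Humanoid-style observation could be imitated from `qpos` and `qvel` alone is given below. It assumes the rodent root is a free joint (7 position and 6 velocity entries), which is consistent with `step` reading `data.q[2]` as the height. The function name `egocentric_obs` and the feature names are hypothetical.

```python
import numpy as np


def egocentric_obs(qpos, qvel, n_root_qpos=7, n_root_qvel=6):
    """Concatenate a few named feature groups into one flat observation."""
    qpos = np.asarray(qpos, dtype=np.float32)
    qvel = np.asarray(qvel, dtype=np.float32)
    features = {
        # Joint angles: everything after the root free-joint coordinates.
        "joint_angles": qpos[n_root_qpos:],
        # Root height: the z component of the root position.
        "torso_height": qpos[2:3],
        # Root linear and angular velocity.
        "root_velocity": qvel[:n_root_qvel],
        # Velocities of the remaining joints.
        "joint_velocities": qvel[n_root_qvel:],
    }
    return np.concatenate(list(features.values()))


# Toy model with 3 hinge joints: 7 + 3 positions and 6 + 3 velocities.
obs = egocentric_obs(np.zeros(10), np.zeros(9))
print(obs.shape)
```

Within `_get_obs`, `np` would be swapped for `jax.numpy` so that the function stays jittable. Keeping the features in a dictionary before concatenating makes it easy to add or drop groups while comparing against the 30 observables of the `dm_control` rodent.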

In [None]:
len(dm_rodent.observables.as_dict())

30