In [2]:
import rodent
from brax.io import mjcf as mjcf_brax
from dm_control import mjcf as mjcf_dm
import mujoco

In [4]:
from dm_control import mujoco as mujoco_dm

In [None]:
mujoco_dm.Physics

## Adapt the MJCF in brax and in dm_control

> TODO: We can adapt the `rodent.py`'s mjcf system to brax, to allow it to interact with the brax system.

- The MJCF in brax is pretty bare-bone, and only support loading the xml output form the mjcf package
- if we want full flexibility of quick access to joints and contact, we still need to import mjcf from dm_control, which introduce a whole new sets of dependencies.
    - better logics for contact termination, terminate when the system is unhealthy
    - simplified observation space other than the huge array of joint angles and positions

**How-to:**
- Insert the `dm_control.rodent()` (loaded by mjcf) to the brax training system. 
    - _Challenges_: we need to bind the physics to `Mjx_model` in order to pass the simulation data to dm_control.

In [5]:
# This is how brax uses MJCF
mj_model = mujoco.MjModel.from_xml_path(rodent._XML_PATH)
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
mj_model.opt.iterations = 6
mj_model.opt.ls_iterations = 6
sys = mjcf_brax.load_model(mj_model)

In [6]:
dm_rodent = rodent.Rodent()
# this will directly get out the contacting geoms
print(dm_rodent.ground_contact_geoms[:5])
# However, we need physics object to pass in dm_control to see how we
# can have better understanding of the simulation
# might be hard to integrate with brax/mjx

(MJCF Element: <geom name="foot_L_collision" class="collision_primitive_paw" size="0.003210712207833968 0.01238417565878816" density="2041.1597822799999" pos="0.0082744640327606835 0.0022566720089347309 -0.0022566720089347309" euler="1.570796326794897 -1.2217304763960311 0"/>, MJCF Element: <geom name="toe_L0_collision" class="collision_primitive_paw" size="0.002383208424149982 0.0072532430300216834" pos="0.0084966561208825436 -0.0008496656120882542 0" euler="1.570796326794897 -1.2217304763960311 0"/>, MJCF Element: <geom name="toe_L1_collision" class="collision_primitive_paw" size="0.002383208424149982 0.0077713318178803739" pos="0.0080718233148384163 0.0025489968362647632 0" euler="1.570796326794897 -1.2217304763960311 0"/>, MJCF Element: <geom name="toe_L2_collision" class="collision_primitive_paw" size="0.002383208424149982 0.0072532430300216834" pos="0.0067973248967060336 0.0055228264785736536 0" euler="1.570796326794897 -1.2217304763960311 0"/>, MJCF Element: <geom name="foot_R_c

In [18]:
dm_rodent.apply_action()

<rodent.Rodent at 0x7c95c8236120>

In [11]:
dm_rodent.observables

<dm_control.composer.observation.observable.mjcf.MJCFFeature at 0x7c95c8234ec0>

## Insert the MJCF
Enabled rodent model to brax. 

**TODO:** Might also be helpful to just use the dm_control logics for reward and contact terminations.

> Review [these functions code](https://github.com/google-deepmind/dm_control/blob/b002aa350c3a48e44ecef84bd5dab6d08e1b2f74/dm_control/locomotion/tasks/corridors.py#L127) in `dm_control` where they calculate the reward and physical simulation steps.

> Maybe a named mapping {joint: index} to each joint and geom for better access. Look how [brax system might have done it](https://github.com/google/brax/blob/2329ae76759e37b0b1f1861cf34e5a67d0f7efa8/brax/base.py#L470). 

In [3]:
import jax
from jax import numpy as jp

from brax import envs
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf_brax

import mujoco
from mujoco import mjx

import yaml
from typing import Dict, Text


In [None]:
class Rodent(PipelineEnv):

  def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=5.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(0.70, 0.0),
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,
  ):
    # Load the rodent model via dm_control
    dm_rodent = rodent.Rodent()
    physics = mjcf_dm.Physics.from_mjcf_model(dm_rodent.mjcf_model)
    # mj_model = mujoco.MjModel.from_xml_path(params["XML_PATH"])
    mj_model = physics.model.ptr
    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 2
    mj_model.opt.ls_iterations = 4

    sys = mjcf_brax.load_model(mj_model)

    physics_steps_per_control_step = 3
    kwargs['n_frames'] = kwargs.get(
        'n_frames', physics_steps_per_control_step
    )
    kwargs['backend'] = 'mjx'

    super().__init__(sys, **kwargs)

    self._forward_reward_weight = forward_reward_weight
    self._ctrl_cost_weight = ctrl_cost_weight
    self._healthy_reward = healthy_reward
    self._terminate_when_unhealthy = terminate_when_unhealthy
    self._healthy_z_range = healthy_z_range
    self._reset_noise_scale = reset_noise_scale
    self._exclude_current_positions_from_observation = (
        exclude_current_positions_from_observation
    )

  def reset(self, rng: jp.ndarray) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jax.random.split(rng, 3)

    low, hi = -self._reset_noise_scale, self._reset_noise_scale
    # initialize random position
    qpos = self.sys.qpos0 + jax.random.uniform(
        rng1, (self.sys.nq,), minval=low, maxval=hi
    )
    qvel = jax.random.uniform(
        rng2, (self.sys.nv,), minval=low, maxval=hi
    )

    # initialize the data space
    data = self.pipeline_init(qpos, qvel)

    # get observation
    obs = self._get_obs(data, jp.zeros(self.sys.nu))
    reward, done, zero = jp.zeros(3)
    metrics = {
        'forward_reward': zero,
        'reward_linvel': zero,
        'reward_quadctrl': zero,
        'reward_alive': zero,
        'x_position': zero,
        'y_position': zero,
        'distance_from_origin': zero,
        'x_velocity': zero,
        'y_velocity': zero,
    }
    return State(data, obs, reward, done, metrics)

  def step(self, state: State, action: jp.ndarray) -> State:
    """Runs one timestep of the environment's dynamics."""
    data0 = state.pipeline_state
    data = self.pipeline_step(data0, action)

    com_before = data0.subtree_com[1]
    com_after = data.subtree_com[1]
    velocity = (com_after - com_before) / self.dt
    forward_reward = self._forward_reward_weight * velocity[0]

    min_z, max_z = self._healthy_z_range
    is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
    if self._terminate_when_unhealthy:
      healthy_reward = self._healthy_reward
    else:
      healthy_reward = self._healthy_reward * is_healthy

    ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

    obs = self._get_obs(data, action)
    reward = forward_reward + healthy_reward - ctrl_cost
    done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
    state.metrics.update(
        forward_reward=forward_reward,
        reward_linvel=forward_reward,
        reward_quadctrl=-ctrl_cost,
        reward_alive=healthy_reward,
        x_position=com_after[0],
        y_position=com_after[1],
        distance_from_origin=jp.linalg.norm(com_after),
        x_velocity=velocity[0],
        y_velocity=velocity[1],
    )

    return state.replace(
        pipeline_state=data, obs=obs, reward=reward, done=done
    )

  def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
  ) -> jp.ndarray:
    """Observes rodent body position, velocities, and angles."""

    # external_contact_forces are excluded
    return jp.concatenate([
    ])


envs.register_environment('rodent', Rodent)