In [1]:
import rodent
from brax.io import mjcf as mjcf_brax
import mujoco

## Adapt the MJCF in brax and in dm_control

> TODO: Adapt the MJCF handling in `rodent.py` to brax, so that the dm_control rodent can interact with the brax system.

- The MJCF support in brax is pretty bare-bones: it only supports loading the XML output from the mjcf package.
- If we want the full flexibility of quick access to joints and contacts, we still need to import mjcf from dm_control, which introduces a whole new set of dependencies. That access would give us:
    - better logic for contact-based termination (terminate when the system is unhealthy)
    - a simplified observation space, instead of the huge array of joint angles and positions
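The contact-termination idea above can be sketched without any simulator dependency. This is a minimal illustration, not the notebook's implementation: the function name `should_terminate`, the integer geom ids, and the assumption that contacts arrive as pairs of geom ids are all hypothetical. In practice the ids would have to be derived from something like `ground_contact_geoms` and the contact data of the simulation.

```python
from typing import Iterable, Set, Tuple


def should_terminate(
    contacts: Iterable[Tuple[int, int]],
    floor_geom_id: int,
    allowed_geom_ids: Set[int],
) -> bool:
    """Return True if any geom outside `allowed_geom_ids` touches the floor."""
    for geom_a, geom_b in contacts:
        if floor_geom_id not in (geom_a, geom_b):
            continue  # this contact does not involve the floor
        other = geom_b if geom_a == floor_geom_id else geom_a
        if other not in allowed_geom_ids:
            return True  # e.g. the torso hit the ground: unhealthy state
    return False


# Hypothetical ids: 0 is the floor, 1-4 are paw geoms, 7 is the torso.
paws = {1, 2, 3, 4}
only_paws_on_floor = should_terminate([(0, 1), (3, 0)], 0, paws)
torso_on_floor = should_terminate([(0, 1), (7, 0)], 0, paws)
print(only_paws_on_floor, torso_on_floor)
```

Contacts between two body geoms (neither being the floor) are ignored here; whether those should also count as unhealthy is a design choice left open.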

**How-to:**
- Insert the `dm_control.rodent()` walker (loaded via mjcf) into the brax training system.
    - _Challenge_: we need to bind the physics to `Mjx_model` in order to pass the simulation data to dm_control.

In [2]:
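# This is how brax uses MJCF: build a MuJoCo model from the XML file first
mj_model = mujoco.MjModel.from_xml_path(rodent._XML_PATH)
# Use the conjugate-gradient solver with a small iteration budget
mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
mj_model.opt.iterations = 6
mj_model.opt.ls_iterations = 6  # line-search iterations
# Convert the MuJoCo model into a brax system
sys = mjcf_brax.load_model(mj_model)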

2024-03-12 10:42:06.416770: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [18]:
dm_rodent = rodent.Rodent()
# This directly returns the ground-contact geoms
print(dm_rodent.ground_contact_geoms[:5])
# However, dm_control needs a physics object before we can read
# simulation data through these elements, and that might be hard
# to integrate with brax/mjx

(MJCF Element: <geom name="foot_L_collision" class="collision_primitive_paw" size="0.003210712207833968 0.01238417565878816" density="2041.1597822799999" pos="0.0082744640327606835 0.0022566720089347309 -0.0022566720089347309" euler="1.570796326794897 -1.2217304763960311 0"/>, MJCF Element: <geom name="toe_L0_collision" class="collision_primitive_paw" size="0.002383208424149982 0.0072532430300216834" pos="0.0084966561208825436 -0.0008496656120882542 0" euler="1.570796326794897 -1.2217304763960311 0"/>, MJCF Element: <geom name="toe_L1_collision" class="collision_primitive_paw" size="0.002383208424149982 0.0077713318178803739" pos="0.0080718233148384163 0.0025489968362647632 0" euler="1.570796326794897 -1.2217304763960311 0"/>, MJCF Element: <geom name="toe_L2_collision" class="collision_primitive_paw" size="0.002383208424149982 0.0072532430300216834" pos="0.0067973248967060336 0.0055228264785736536 0" euler="1.570796326794897 -1.2217304763960311 0"/>, MJCF Element: <geom name="foot_R_c

In [19]:
# Connect `apply_action` to the `step` function in brax.
# The placeholder strings below fail: the first argument must support `.bind`.
dm_rodent.apply_action("","","")

AttributeError: 'str' object has no attribute 'bind'