# Cross Validation With COMIC Environment From dm_control
Data is where intelligence comes form, not just fancy math, we need to make sure the integrity of the observation space, or the data. When we have one that is already available to compares with (COMIC setup) we should use it to garantee that the setup is correct. **This notebook is for checking the integirty of the observation space.**

In [1]:
%load_ext autoreload
%autoreload 2
import jax
from jax import random
from envs.humanoid_cmu_validate import HumanoidTracking
import numpy as np
import mediapy as media
import jax.numpy as jp
import mujoco
from dm_control import mjcf

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 

# CMU Humanoid env

In [2]:
_CMU_MOCAP_JOINTS = (
    'lfemurrz', 'lfemurry', 'lfemurrx', 'ltibiarx', 'lfootrz', 'lfootrx',
    'ltoesrx', 'rfemurrz', 'rfemurry', 'rfemurrx', 'rtibiarx', 'rfootrz',
    'rfootrx', 'rtoesrx', 'lowerbackrz', 'lowerbackry', 'lowerbackrx',
    'upperbackrz', 'upperbackry', 'upperbackrx', 'thoraxrz', 'thoraxry',
    'thoraxrx', 'lowerneckrz', 'lowerneckry', 'lowerneckrx', 'upperneckrz',
    'upperneckry', 'upperneckrx', 'headrz', 'headry', 'headrx', 'lclaviclerz',
    'lclaviclery', 'lhumerusrz', 'lhumerusry', 'lhumerusrx', 'lradiusrx',
    'lwristry', 'lhandrz', 'lhandrx', 'lfingersrx', 'lthumbrz', 'lthumbrx',
    'rclaviclerz', 'rclaviclery', 'rhumerusrz', 'rhumerusry', 'rhumerusrx',
    'rradiusrx', 'rwristry', 'rhandrz', 'rhandrx', 'rfingersrx', 'rthumbrz',
    'rthumbrx')
len(_CMU_MOCAP_JOINTS)

56

Checking the body positions and make sure they are aligned

In [3]:
root = mjcf.from_path("./assets/humanoid_CMU_V2019.xml")

physics = mjcf.Physics.from_mjcf_model(root).model.ptr

physics.body_pos.shape

physics.name_bodyadr

mj_model = mujoco.MjModel.from_xml_path("./assets/humanoid_CMU_V2019.xml")

params = np.array([
    # 'world',
    # 'root',
    'lhipjoint', 'lfemur', 'ltibia', 'lfoot', 'ltoes', 'rhipjoint', 'rfemur', 'rtibia', 'rfoot',
    'rtoes', 'lowerback', 'upperback', 'thorax', 'lowerneck', 'upperneck', 'head', 'lclavicle', 'lhumerus',
    'lradius', 'lwrist', 'lhand', 'lfingers', 'lthumb', 'rclavicle', 'rhumerus', 'rradius', 'rwrist',
    'rhand', 'rfingers', 'rthumb'])

jp.array([mujoco.mj_name2id(mj_model, mujoco.mju_str2Type("body"), body) for body in params])

Array([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
       19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], dtype=int32)

In [4]:
params = np.array(
    ['lowerbackrz', 'lowerbackry', 'lowerbackrx', 'upperbackrz', 'upperbackry', 'upperbackrx', 'thoraxrz',
     'thoraxry', 'thoraxrx', 'lowerneckrz', 'lowerneckry', 'lowerneckrx', 'upperneckrz', 'upperneckry', 'upperneckrx', 'headrz', 'headry',
     'headrx', 'lclaviclerz', 'lclaviclery', 'lhumerusrz', 'lhumerusry', 'lhumerusrx', 'lradiusrx', 'lwristry', 'lhandrz', 'lhandrx', 'lfingersrx',
     'lthumbrz', 'lthumbrx', 'rclaviclerz', 'rclaviclery', 'rhumerusrz', 'rhumerusry', 'rhumerusrx', 'rradiusrx', 'rwristry', 'rhandrz',
     'rhandrx', 'rfingersrx', 'rthumbrz', 'rthumbrx', 'lfemurrz', 'lfemurry', 'lfemurrx', 'ltibiarx', 'lfootrz', 'lfootrx',
     'ltoesrx', 'rfemurrz', 'rfemurry', 'rfemurrx', 'rtibiarx', 'rfootrz', 'rfootrx', 'rtoesrx'])

jp.array([mujoco.mj_name2id(mj_model, mujoco.mju_str2Type("joint"), joint) for joint in params])

Array([15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
       32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
       49, 50, 51, 52, 53, 54, 55, 56,  1,  2,  3,  4,  5,  6,  7,  8,  9,
       10, 11, 12, 13, 14], dtype=int32)

In [5]:
params = np.array(
    ['lfemurrz', 'lfemurry', 'lfemurrx', 'ltibiarx', 'lfootrz', 'lfootrx',
    'ltoesrx', 'rfemurrz', 'rfemurry', 'rfemurrx', 'rtibiarx', 'rfootrz',
    'rfootrx', 'rtoesrx', 'lowerbackrz', 'lowerbackry', 'lowerbackrx',
    'upperbackrz', 'upperbackry', 'upperbackrx', 'thoraxrz', 'thoraxry',
    'thoraxrx', 'lowerneckrz', 'lowerneckry', 'lowerneckrx', 'upperneckrz',
    'upperneckry', 'upperneckrx', 'headrz', 'headry', 'headrx', 'lclaviclerz',
    'lclaviclery', 'lhumerusrz', 'lhumerusry', 'lhumerusrx', 'lradiusrx',
    'lwristry', 'lhandrz', 'lhandrx', 'lfingersrx', 'lthumbrz', 'lthumbrx',
    'rclaviclerz', 'rclaviclery', 'rhumerusrz', 'rhumerusry', 'rhumerusrx',
    'rradiusrx', 'rwristry', 'rhandrz', 'rhandrx', 'rfingersrx', 'rthumbrz',
    'rthumbrx'])

jp.array([mujoco.mj_name2id(mj_model, mujoco.mju_str2Type("joint"), joint) for joint in params])

Array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
       35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
       52, 53, 54, 55, 56], dtype=int32)

no freejoint `root` in joints for reference_joints list

In [6]:
params = {
    "solver": "cg",
    "iterations": 6,
    "ls_iterations": 6,
    "clip_path": "clips/test_traj.p",
    "body_names":
    ['lhipjoint', 'lfemur', 'ltibia', 'lfoot', 'ltoes', 'rhipjoint', 'rfemur', 'rtibia', 'rfoot',
    'rtoes', 'lowerback', 'upperback', 'thorax', 'lowerneck', 'upperneck', 'head', 'lclavicle', 'lhumerus',
    'lradius', 'lwrist', 'lhand', 'lfingers', 'lthumb', 'rclavicle', 'rhumerus', 'rradius', 'rwrist',
    'rhand', 'rfingers', 'rthumb'],
    "joint_names":
    ['headry', 'headrz', 'upperneckrx', 'lclaviclerz', 'headrx',
       'lfemurry', 'lfemurrz', 'root', 'lhandrx', 'lfootrz', 'ltibiarx',
       'lhandrz', 'lwristry', 'lhumerusry', 'lhumerusrz', 'lclaviclery',
       'lowerbackry', 'lowerbackrz', 'rtoesrx', 'lowerneckry',
       'lowerneckrz', 'thoraxrx', 'lhumerusrx', 'lthumbrz', 'lfingersrx',
       'lfemurrx', 'lfootrx', 'lradiusrx', 'rclaviclerz', 'lthumbrx',
       'rfemurry', 'rfemurrz', 'ltoesrx', 'rhandrx', 'rfootrz',
       'rtibiarx', 'rhandrz', 'rwristry', 'rhumerusry', 'rhumerusrz',
       'rclaviclery', 'rhumerusrx', 'rthumbrz', 'rfingersrx', 'rfemurrx',
       'rfootrx', 'rradiusrx', 'thoraxry', 'thoraxrz', 'upperbackrx',
       'upperbackry', 'upperbackrz', 'lowerbackrx', 'upperneckry',
       'upperneckrz', 'lowerneckrx']
}

In [7]:
env_brax = HumanoidTracking(params)

In [8]:
env_brax.sys.nv # nq=63, nv=62

62

In [9]:
lst =[]
for n in range(100):
      joint = mujoco.mj_id2name(env_brax.sys.mj_model, 
                          mujoco.mju_str2Type("joint"), 
                          n)
      lst.append(joint)
# root joint is a free joint in any direction (adds extra dimensions), rotation locks, quaternion
# something for full range of motion

lst[:5]

['root', 'lfemurrz', 'lfemurry', 'lfemurrx', 'ltibiarx']

In [10]:
state = env_brax.reset(jax.random.PRNGKey(0))

## Stepping brax environment

In [11]:
# key = random.PRNGKey(0)
# jit_step = jax.jit(env_brax.step)
# jit_reset = jax.jit(env_brax.reset)

In [12]:
# next = jit_reset(key)
# print("starting rollout")
# mu = 0
# sigma = .1
# rollout = []
# for _ in range(500):
#     _, key = jax.random.split(key)
#     next = jit_step(next, mu + sigma * random.normal(key, shape=(env.sys.nu,)))
#     rollout.append(next)

In [13]:
# import mediapy as media
# import os

# os.environ["MUJOCO_GL"] = "glfw"
# rollout_data = [s.pipeline_state for s in rollout]

# video = env.render(rollout_data, camera='side')

# media.show_video(video, fps=1.0 / env.dt)

# COMIC Validation Observation

In [14]:
from dm_control import composer
from dm_control.locomotion import arenas
from dm_control.locomotion import walkers
from walker import Rat
from walker_validate_cmu import CMUHumanoidPositionControlled
from dm_control.locomotion.mocap import props
from dm_control.locomotion.tasks.reference_pose import tracking
from dm_control.locomotion.tasks.reference_pose import types
from dm_control.utils import io as resources
import os
import h5py
import pickle

maybe reset doesn't replace values in dm control? instantiate it in init? The clip is actually instantiated in init with flag `always_init_at_clip_start=True`

In [15]:
# Define the path to your motion capture data file
file_name = 'clips/test_traj.h5' #'clips/all_snips_250.h5'
current_directory = os.getcwd()
TEST_FILE_PATH = os.path.join(current_directory, file_name)

with h5py.File(TEST_FILE_PATH, 'r') as f:
    dataset_keys = tuple(f.keys())
    dataset = types.ClipCollection(ids=dataset_keys,)

# Set up the mocap tracking task
task = tracking.MultiClipMocapTracking(
    walker=CMUHumanoidPositionControlled,
    arena=arenas.Floor(),
    ref_path=resources.GetResourceFilename(TEST_FILE_PATH),
    ref_steps=(1, 2, 3, 4, 5),
    min_steps=1,
    dataset=dataset,
    reward_type='comic',
    always_init_at_clip_start=True
)

# Initialize the environment
env = composer.Environment(task=task)
reset = env.reset()

# action_spec = env.action_spec()
# env.step(np.zeros(action_spec.shape))

In [16]:
reset.observation['walker/reference_rel_bodies_pos_global'].shape

(1, 450)

In [17]:
reset.observation['walker/reference_rel_root_pos_local'].shape

(1, 15)

In [18]:
reset.observation['walker/reference_rel_joints'].shape # 280 / 5 = 56 makes sense

(1, 280)

If the joints in brax environment is indexed with shape mismatch, no warning is given, `np` does not compain and dimension becomes (56,57) and give a flattened dimension of 3000ish, which is not the correct one. **Should use `[:,env_brax._joint_orders]`, it is a 2d indexing and also the `root` joint should not be included, it is there for mjx.**

In [19]:
env.task._walker_features.keys()

dict_keys(['position', 'quaternion', 'joints', 'center_of_mass', 'end_effectors', 'appendages', 'body_positions', 'body_quaternions', 'velocity', 'angular_velocity', 'joints_velocity'])

In [20]:
[item for item in reset.observation]

['reference_props_pos_global',
 'reference_props_quat_global',
 'walker/actuator_activation',
 'walker/appendages_pos',
 'walker/body_height',
 'walker/end_effectors_pos',
 'walker/joints_pos',
 'walker/joints_vel',
 'walker/sensors_accelerometer',
 'walker/sensors_force',
 'walker/sensors_gyro',
 'walker/sensors_torque',
 'walker/sensors_touch',
 'walker/sensors_velocimeter',
 'walker/world_zaxis',
 'walker/reference_rel_joints',
 'walker/reference_rel_bodies_pos_global',
 'walker/reference_rel_bodies_quats',
 'walker/reference_rel_bodies_pos_local',
 'walker/reference_ego_bodies_quats',
 'walker/reference_rel_root_quat',
 'walker/reference_rel_root_pos_local',
 'walker/reference_appendages_pos',
 'walker/clip_id',
 'walker/velocimeter_control',
 'walker/gyro_control',
 'walker/joints_vel_control',
 'walker/time_in_clip']

In [21]:
check_lst = ['walker/reference_rel_bodies_pos_local',
             'walker/reference_rel_bodies_pos_global',
             'walker/reference_rel_root_pos_local',
             'walker/reference_rel_joints'
            ]

COMIC_lst = reset.observation[check_lst[0]]
for check in check_lst[1:]:
    COMIC_lst = np.concatenate((COMIC_lst, reset.observation[check]), axis=1)

COMIC_lst

array([[ 8.62317918e-06, -2.78480039e-02, -6.42844045e-02, ...,
         3.10079384e-01, -1.54169243e-03,  5.53967392e-05]])

In [22]:
state.obs.shape

(1395,)

In [23]:
COMIC_lst.shape

(1, 1195)

# Check reference_rel_bodies_pos_local
Sample check, rest of the functions follow this

## Check qpos matching: position is first 3 initilizatio same

In [24]:
# env.task._clip_reference_features

In [25]:
env.physics.data.qpos

array([0.        , 0.        , 0.94      , 0.46075298, 0.53638297,
       0.53638297, 0.46075298, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        ])

In [26]:
state.pipeline_state.qpos

Array([0.        , 0.        , 0.94      , 0.46075296, 0.536383  ,
       0.536383  , 0.46075296, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        ], dtype=float32)

## Each Transformation Checking (body index check)
it is probably a ordering issue? al;l the numbers looks the same, but the ordering may be what is causing the issue

In [27]:
time_steps = env.task._time_step + env.task._ref_steps
obs = env.task._walker.transform_vec_to_egocentric_frame(env.physics, (env.task._clip_reference_features['body_positions'][time_steps] -
                  env.task._walker_features['body_positions'])[:, env.task._body_idxs])
obs_flattened = np.concatenate([o.flatten() for o in obs])

env.task._walker_features['body_positions'].flatten()[:10]

array([ 0.        ,  0.        ,  0.94      ,  0.03503343,  0.101937  ,
        0.84178506, -0.02235977,  0.24043609,  0.46561425, -0.07985898])

Both from brax environments:
1. Sending to brax then grab body positions `xpos` flattened:

    ```python
    [ 0.02206819 -0.38698062  0.05154771  0.03503342  0.10193701  0.8417851
    -0.02235974  0.24043609  0.46561432 -0.07985894]
    ```

2. Grabbing body positions then sending into brax `xpos` flattened:

    ```python
    [ 0.          0.          0.94        0.03503342  0.10193701  0.8417851
    -0.02235974  0.24043609  0.46561432 -0.07985894]
    ```

Brax environment instantiation does aweird reordering that distort the original indexes

qpos values are the same but xpos are different, body position indexing wrong

In [28]:
env_brax._body_locations

Array([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
       19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], dtype=int32)

positions grabbing now equivalence with the one on the top directly calling from notebook

## Each Transformation Checking (check ref_traj as well)

In [29]:
env.task._clip_reference_features['body_positions'].flatten()[:10]

array([ 0.        ,  0.        ,  0.94      ,  0.03503343,  0.101937  ,
        0.84178506, -0.02235977,  0.24043609,  0.46561425, -0.07985898])

In [30]:
env_brax._ref_traj.body_positions.flatten()[:10]

Array([ 0.        ,  0.        ,  0.94      ,  0.03503343,  0.101937  ,
        0.8417851 , -0.02235977,  0.24043609,  0.46561426, -0.07985898],      dtype=float32)

Similar as well, but...

There is a extra step from `_ref_traj.body_positions` to `_walker_features['body_positions']`, the data still match up

In [31]:
env.task._walker_features['body_positions'].flatten()[:10]

array([ 0.        ,  0.        ,  0.94      ,  0.03503343,  0.101937  ,
        0.84178506, -0.02235977,  0.24043609,  0.46561425, -0.07985898])

## Each Transformation Checking (differences check)

In [32]:
(env.task._clip_reference_features['body_positions'][time_steps] -
                  env.task._walker_features['body_positions'])[:, env.task._body_idxs].flatten()[:10]

array([-6.77492290e-02,  8.62317918e-06, -1.78336186e-02, -1.07166655e-01,
        3.48603853e-05, -2.37769555e-02,  1.68382568e-01, -2.82965011e-02,
        2.95138832e-02,  8.94642469e-03])

In [33]:
def f(x):
      if len(x.shape) != 1:
        return jax.lax.dynamic_slice_in_dim(
          x, 
          state.info['cur_frame'] + 1, 
          env_brax._ref_traj_length, 
        )
      return jp.array([])

In [34]:
ref_traj = jax.tree_util.tree_map(f, env_brax._ref_traj)

xpos_flatten = state.pipeline_state.xpos[env_brax._body_locations].flatten()

(ref_traj.body_positions - xpos_flatten).reshape([env_brax._ref_traj_length, 30, 3]).flatten()[:10]

Array([-6.7749232e-02,  8.6231794e-06, -1.7833591e-02, -1.0716664e-01,
        3.4846365e-05, -2.3776948e-02,  1.6838254e-01, -2.8296500e-02,
        2.9513806e-02,  8.9463890e-03], dtype=float32)

The differences is the same, may be the problem with transformation then

## Each Transformation Checking (transform function check)

In [35]:
obs = env.task._walker.transform_vec_to_egocentric_frame(env.physics,
                                                         (env.task._clip_reference_features['body_positions'][time_steps] - 
                                                          env.task._walker_features['body_positions'])[:, env.task._body_idxs])

obs_flattened = np.concatenate([o.flatten() for o in obs])
obs_flattened[:10]

array([ 8.62317918e-06, -2.78480039e-02, -6.42844045e-02,  3.48603853e-05,
       -3.96685538e-02, -1.02354489e-01, -2.82965011e-02,  5.45728515e-02,
        1.62004822e-01, -4.03528938e-03])

brax environment after transformation, seems not to be the same, they seems to not be changed at all, `xmat` currently identity matrix
```python
[-6.7749232e-02  8.6231794e-06 -1.7833591e-02 -1.0716664e-01
  3.4846365e-05 -2.3776948e-02  1.6838254e-01 -2.8296500e-02
  2.9513806e-02  8.9463890e-03]
```

In [36]:
def global_vector_to_local_frame(data, vec_in_world_frame):
    xmat = jp.reshape(data.xmat[1], (3, 3))
    print(xmat)
    if vec_in_world_frame.shape[-1] == 2:
      return jp.dot(vec_in_world_frame, xmat[:2, :2])
    elif vec_in_world_frame.shape[-1] == 3:
      return jp.dot(vec_in_world_frame, xmat)
    else:
      raise ValueError('`vec_in_world_frame` should have shape with final '
                       'dimension 2 or 3: got {}'.format(
                           vec_in_world_frame.shape))

In [37]:
ref_traj = jax.tree_util.tree_map(f, env_brax._ref_traj)

xpos_flatten = state.pipeline_state.xpos[env_brax._body_locations].flatten()

obs = global_vector_to_local_frame(state.pipeline_state,
                             (ref_traj.body_positions - xpos_flatten)
                             .reshape([env_brax._ref_traj_length, 30, 3]))

obs_flattened = jp.concatenate([o.flatten() for o in obs])
obs_flattened[:10]

[[ 0.          0.15082681  0.9885602 ]
 [ 1.          0.          0.        ]
 [ 0.          0.9885602  -0.15082681]]


Array([ 8.62317938e-06, -2.78479792e-02, -6.42844066e-02,  3.48463655e-05,
       -3.96685489e-02, -1.02354474e-01, -2.82965004e-02,  5.45727760e-02,
        1.62004814e-01, -4.03526425e-03], dtype=float32)

In [38]:
state.pipeline_state.xmat[1]

Array([[ 0.        ,  0.15082681,  0.9885602 ],
       [ 1.        ,  0.        ,  0.        ],
       [ 0.        ,  0.9885602 , -0.15082681]], dtype=float32)

COMIC paper uses the `xmat[1]`, the second `xmat` transformation matrix, might need this to be a changing values

## Comparing Obs

In [39]:
brax_body_pos_local = state.obs[:450]
brax_body_pos_local[:10]

Array([ 8.62317938e-06, -2.78479792e-02, -6.42844066e-02,  3.48463655e-05,
       -3.96685489e-02, -1.02354474e-01, -2.82965004e-02,  5.45727760e-02,
        1.62004814e-01, -4.03526425e-03], dtype=float32)

In [40]:
COMIC_body_pos_local = obs_flattened[:450]
COMIC_body_pos_local[:10]

Array([ 8.62317938e-06, -2.78479792e-02, -6.42844066e-02,  3.48463655e-05,
       -3.96685489e-02, -1.02354474e-01, -2.82965004e-02,  5.45727760e-02,
        1.62004814e-01, -4.03526425e-03], dtype=float32)

In [41]:
COMIC_body_pos_local = reset.observation['walker/reference_rel_bodies_pos_local'].flatten()
COMIC_body_pos_local[:10]

array([ 8.62317918e-06, -2.78480039e-02, -6.42844045e-02,  3.48603853e-05,
       -3.96685538e-02, -1.02354489e-01, -2.82965011e-02,  5.45728515e-02,
        1.62004822e-01, -4.03528938e-03])

In [42]:
np.isclose(brax_body_pos_local,
           COMIC_body_pos_local,
           atol=1e-04).sum()

450

Local being the same meaning global is the same

# Check reference_rel_bodies_pos_global

In [43]:
reset.observation['walker/reference_rel_bodies_pos_global'].flatten()[:10]

array([-6.77492290e-02,  8.62317918e-06, -1.78336186e-02, -1.07166655e-01,
        3.48603853e-05, -2.37769555e-02,  1.68382568e-01, -2.82965011e-02,
        2.95138832e-02,  8.94642469e-03])

In [44]:
ref_traj = jax.tree_util.tree_map(f, env_brax._ref_traj)
xpos_flatten = state.pipeline_state.xpos[env_brax._body_locations].flatten()
obs = (ref_traj.body_positions - xpos_flatten)
obs_flattened = obs.flatten()
obs_flattened[:10]

Array([-6.7749232e-02,  8.6231794e-06, -1.7833591e-02, -1.0716664e-01,
        3.4846365e-05, -2.3776948e-02,  1.6838254e-01, -2.8296500e-02,
        2.9513806e-02,  8.9463890e-03], dtype=float32)

In [45]:
np.isclose(obs_flattened,
           reset.observation['walker/reference_rel_bodies_pos_global'].flatten(),
           atol=1e-04).sum()

450

# Comparing Joints Positions

In [46]:
env_brax._joint_orders

Array([31, 30, 29, 33, 32,  2,  1,  0, 41,  5,  4, 40, 39, 36, 35, 34, 16,
       15, 14, 25, 24, 23, 37, 43, 42,  3,  6, 38, 45, 44,  9,  8,  7, 53,
       12, 11, 52, 51, 48, 47, 46, 49, 55, 54, 10, 13, 50, 22, 21, 20, 19,
       18, 17, 28, 27, 26], dtype=int32)

Same with the one shown in notebook earlier

In [47]:
(ref_traj.joints - state.pipeline_state.qpos[7:]).shape

(5, 56)

In [48]:
(ref_traj.joints - state.pipeline_state.qpos[7:])[:,env_brax._joint_orders].shape

(5, 56)

In [49]:
(ref_traj.joints - state.pipeline_state.qpos[7:][env_brax._joint_orders]).shape

(5, 56)

Now the 2d slicing would work and the dimensionality would match up, no (56,57) dimension

In [50]:
brax_body_joints = state.obs[state.obs.shape[0]-280:] # 280 joints
brax_body_joints[:10]

Array([-9.90474820e-02,  7.79400945e-01,  3.47748071e-01,  3.71101469e-01,
        2.09856662e-03,  5.76202932e-04,  4.40269202e-01,  1.17130265e-01,
       -1.34602770e-01,  1.42922294e+00], dtype=float32)

In [51]:
reset.observation['walker/reference_rel_joints'].flatten()[:10]

array([ 3.51097703e-02,  2.23367805e-04,  6.32195582e-04,  1.06107106e-01,
        4.23162204e-03, -1.06353902e+00,  9.13721649e-02,  7.30717459e-02,
        4.87872551e-01,  1.46089771e-01])

In [68]:
time_steps = env.task._time_step + env.task._ref_steps
diff = (env.task._clip_reference_features['joints'][time_steps] -
            env.task._walker_joints)
diff_flattened = diff[:, env.task._walker.mocap_to_observable_joint_order].flatten()
diff_flattened[:10]

array([ 3.51097703e-02,  2.23367805e-04,  6.32195582e-04,  1.06107106e-01,
        4.23162204e-03, -1.06353902e+00,  9.13721649e-02,  7.30717459e-02,
        4.87872551e-01,  1.46089771e-01])

Seems like the ordering of the joints is not the same between COMIC and brax, this is causing the issue

In [69]:
qpos_ref = ref_traj.joints
diff = (qpos_ref - state.pipeline_state.qpos[7:])[:, env.task._walker.mocap_to_observable_joint_order]#[:,env_brax._joint_orders]
brax_env_diff_flattened = diff.flatten()
brax_env_diff_flattened[:10]

Array([ 3.5109770e-02,  2.2336781e-04,  6.3219556e-04,  1.0610711e-01,
        4.2316220e-03, -1.0635390e+00,  9.1372162e-02,  7.3071748e-02,
        4.8787254e-01,  1.4608978e-01], dtype=float32)

In [70]:
qpos_ref = ref_traj.joints
diff = (qpos_ref - state.pipeline_state.qpos[7:][env_brax._joint_orders])
brax_env_diff_flattened_2 = diff.flatten()
brax_env_diff_flattened_2[:10]

Array([ 0.07307175,  0.09137216, -1.063539  ,  1.1439872 , -0.24278677,
        0.14608978, -0.4831119 , -0.07377312, -0.09156196, -1.0633751 ],      dtype=float32)

Checking the order from COMIC and reverse finding the name

In [71]:
np.array(env.task._walker.mocap_to_observable_joint_order)

array([31, 30, 29, 33, 32,  2,  1,  0, 41,  5,  4, 40, 39, 36, 35, 34, 16,
       15, 14, 25, 24, 23, 37, 43, 42,  3,  6, 38, 45, 44,  9,  8,  7, 53,
       12, 11, 52, 51, 48, 47, 46, 49, 55, 54, 10, 13, 50, 22, 21, 20, 19,
       18, 17, 28, 27, 26])

Work backward to get the names

In [72]:
np.array([mujoco.mj_id2name(mj_model, mujoco.mju_str2Type("joint"), joint)
          for joint in env.task._walker.mocap_to_observable_joint_order])

array(['headry', 'headrz', 'upperneckrx', 'lclaviclerz', 'headrx',
       'lfemurry', 'lfemurrz', 'root', 'lhandrx', 'lfootrz', 'ltibiarx',
       'lhandrz', 'lwristry', 'lhumerusry', 'lhumerusrz', 'lclaviclery',
       'lowerbackry', 'lowerbackrz', 'rtoesrx', 'lowerneckry',
       'lowerneckrz', 'thoraxrx', 'lhumerusrx', 'lthumbrz', 'lfingersrx',
       'lfemurrx', 'lfootrx', 'lradiusrx', 'rclaviclerz', 'lthumbrx',
       'rfemurry', 'rfemurrz', 'ltoesrx', 'rhandrx', 'rfootrz',
       'rtibiarx', 'rhandrz', 'rwristry', 'rhumerusry', 'rhumerusrz',
       'rclaviclery', 'rhumerusrx', 'rthumbrz', 'rfingersrx', 'rfemurrx',
       'rfootrx', 'rradiusrx', 'thoraxry', 'thoraxrz', 'upperbackrx',
       'upperbackry', 'upperbackrz', 'lowerbackrx', 'upperneckry',
       'upperneckrz', 'lowerneckrx'], dtype='<U11')

In [73]:
np.isclose(brax_env_diff_flattened, reset.observation['walker/reference_rel_joints'].flatten()).sum()

280

# Comparing root position
correct position should be first three in `qpos`

In [58]:
reset.observation['walker/reference_rel_root_pos_local'].flatten()[:10]

array([ 8.62317918e-06, -2.78480039e-02, -6.42844045e-02, -1.32226590e-05,
       -1.14394152e-01, -1.21373602e-01, -2.92599295e-05, -1.87009094e-01,
       -1.26428729e-01, -2.53498893e-05])

In [59]:
com = state.pipeline_state.qpos[:3]#state.pipeline_state.subtree_com[0]
thing = (ref_traj.position - com)
obs = env_brax.global_vector_to_local_frame(state.pipeline_state, thing)
brax_obs_flattened = jp.concatenate([o.flatten() for o in obs])
brax_obs_flattened[:10]

Array([ 8.62317938e-06, -2.78479792e-02, -6.42844066e-02, -1.32226587e-05,
       -1.14394173e-01, -1.21373594e-01, -2.92599289e-05, -1.87009111e-01,
       -1.26428723e-01, -2.53498893e-05], dtype=float32)

In [60]:
com

Array([0.  , 0.  , 0.94], dtype=float32)

In [61]:
time_steps = env.task._time_step + env.task._ref_steps
obs = env.task._walker.transform_vec_to_egocentric_frame(
        env.physics, (env.task._clip_reference_features['position'][time_steps] -
                  env.task._walker_features['position']))
obs_flattened = np.concatenate([o.flatten() for o in obs])

obs_flattened[:10]

array([ 8.62317918e-06, -2.78480039e-02, -6.42844045e-02, -1.32226590e-05,
       -1.14394152e-01, -1.21373602e-01, -2.92599295e-05, -1.87009094e-01,
       -1.26428729e-01, -2.53498893e-05])

Check step by step

In [62]:
env.task._walker_features['position']

array([0.  , 0.  , 0.94])

In [63]:
np.isclose(brax_obs_flattened, reset.observation['walker/reference_rel_root_pos_local'].flatten()).sum()

15

# Comparing All Observations

In [64]:
np.isclose(state.obs[:450+450+15+280],COMIC_lst, atol=1e-04).sum()

1195

In [65]:
state.obs[:450+450+15+280].shape

(1195,)

All match!!!

# Joints Order for Rodent Model
1. Hard to instantiate because of agent/proto mismatch with our MOCAP data
2. dm_control should use the list of joints just as what is written in the documentation

# Comparing ref_appendges_pos

In [66]:
reset.observation['walker/reference_appendages_pos']

array([[-0.45634599,  0.41476   , -0.11430375,  0.45844934,  0.41457251,
        -0.11405968, -0.37539589, -0.62795519,  0.3692763 ,  0.37496014,
        -0.62812286,  0.36945789, -0.00102365,  0.50932764, -0.10341592,
        -0.40951064,  0.49287611,  0.26847143,  0.41188137,  0.49241393,
         0.26842684, -0.33230921, -0.47547592,  0.17988624,  0.33214898,
        -0.47571259,  0.1796875 , -0.00077913,  0.47470758,  0.1990824 ,
        -0.3761327 ,  0.32668301,  0.424146  ,  0.37814751,  0.3266576 ,
         0.42429442, -0.31386993, -0.44234694,  0.17498533,  0.31418303,
        -0.44217935,  0.17485666, -0.00105083,  0.37097046,  0.31755521,
        -0.41046867,  0.23070891,  0.3622488 ,  0.41254475,  0.23100998,
         0.36249511, -0.3279663 , -0.52494438,  0.29433282,  0.32857207,
        -0.52458837,  0.29418791, -0.0012832 ,  0.32515618,  0.33978611,
        -0.44742211,  0.28107662,  0.24599315,  0.44943593,  0.28141067,
         0.24624748, -0.35384757, -0.55873849,  0.4

In [67]:
ref_traj.appendages.flatten()

Array([-0.456346  ,  0.41476   , -0.11430375,  0.45844933,  0.4145725 ,
       -0.11405968, -0.3753959 , -0.6279552 ,  0.3692763 ,  0.37496015,
       -0.62812287,  0.3694579 , -0.00102365,  0.50932765, -0.10341591,
       -0.40951064,  0.4928761 ,  0.26847142,  0.41188136,  0.49241394,
        0.26842684, -0.33230922, -0.47547594,  0.17988624,  0.332149  ,
       -0.4757126 ,  0.1796875 , -0.00077913,  0.47470757,  0.19908239,
       -0.3761327 ,  0.326683  ,  0.424146  ,  0.3781475 ,  0.3266576 ,
        0.4242944 , -0.31386992, -0.44234693,  0.17498533,  0.31418303,
       -0.44217935,  0.17485666, -0.00105083,  0.37097046,  0.31755522,
       -0.41046867,  0.23070891,  0.3622488 ,  0.41254476,  0.23100998,
        0.3624951 , -0.3279663 , -0.52494437,  0.29433283,  0.32857206,
       -0.52458835,  0.2941879 , -0.0012832 ,  0.32515618,  0.3397861 ,
       -0.44742212,  0.2810766 ,  0.24599315,  0.44943592,  0.28141066,
        0.24624747, -0.35384756, -0.55873847,  0.47058687,  0.35