# Cross Validation With COMIC Environment From dm_control
Data is where intelligence comes from, not just fancy math, so we need to ensure the integrity of the observation space, i.e. the data. When a reference setup is already available to compare against (the COMIC setup), we should use it to guarantee that our setup is correct. **This notebook checks the integrity of the observation space.**

In [1]:
%load_ext autoreload
%autoreload 2
import jax
from jax import random
from envs.humanoid import HumanoidTracking
import numpy as np
import mediapy as media
import jax.numpy as jp
import mujoco
from dm_control import mjcf

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 

# CMU Humanoid env

In [2]:
_CMU_MOCAP_JOINTS = (
    'lfemurrz', 'lfemurry', 'lfemurrx', 'ltibiarx', 'lfootrz', 'lfootrx',
    'ltoesrx', 'rfemurrz', 'rfemurry', 'rfemurrx', 'rtibiarx', 'rfootrz',
    'rfootrx', 'rtoesrx', 'lowerbackrz', 'lowerbackry', 'lowerbackrx',
    'upperbackrz', 'upperbackry', 'upperbackrx', 'thoraxrz', 'thoraxry',
    'thoraxrx', 'lowerneckrz', 'lowerneckry', 'lowerneckrx', 'upperneckrz',
    'upperneckry', 'upperneckrx', 'headrz', 'headry', 'headrx', 'lclaviclerz',
    'lclaviclery', 'lhumerusrz', 'lhumerusry', 'lhumerusrx', 'lradiusrx',
    'lwristry', 'lhandrz', 'lhandrx', 'lfingersrx', 'lthumbrz', 'lthumbrx',
    'rclaviclerz', 'rclaviclery', 'rhumerusrz', 'rhumerusry', 'rhumerusrx',
    'rradiusrx', 'rwristry', 'rhandrz', 'rhandrx', 'rfingersrx', 'rthumbrz',
    'rthumbrx')
len(_CMU_MOCAP_JOINTS)

56

Check the body positions and make sure they are aligned.

In [3]:
MODEL_XML_PATH = "./assets/humanoid_CMU_V2019.xml"

# Compile the model through dm_control's mjcf (raw MjModel pointer) and also load it
# directly with MuJoCo; both come from the same XML file.
root = mjcf.from_path(MODEL_XML_PATH)
physics = mjcf.Physics.from_mjcf_model(root).model.ptr
mj_model = mujoco.MjModel.from_xml_path(MODEL_XML_PATH)

# Bodies used by the tracking observation; 'world' and 'root' are deliberately left out.
# (Named `body_names` instead of reusing `params`, which later holds the env config dict.)
body_names = np.array([
    'lhipjoint', 'lfemur', 'ltibia', 'lfoot', 'ltoes', 'rhipjoint', 'rfemur', 'rtibia', 'rfoot',
    'rtoes', 'lowerback', 'upperback', 'thorax', 'lowerneck', 'upperneck', 'head', 'lclavicle', 'lhumerus',
    'lradius', 'lwrist', 'lhand', 'lfingers', 'lthumb', 'rclavicle', 'rhumerus', 'rradius', 'rwrist',
    'rhand', 'rfingers', 'rthumb'])

body_ids = jp.array([mujoco.mj_name2id(mj_model, mujoco.mju_str2Type("body"), name)
                     for name in body_names])
# mj_name2id returns -1 for a name missing from the model; used as an index, -1 would
# silently pick the last body, so fail loudly instead.
missing = body_names[np.asarray(body_ids) < 0]
assert missing.size == 0, f"bodies not found in the model: {missing.tolist()}"

body_ids

Array([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
       19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], dtype=int32)

In [4]:
# Configuration for HumanoidTracking: MJX solver settings, the reference clip, and the
# body / joint names the tracking observations are built from.
params = dict(
    solver="cg",
    iterations=6,
    ls_iterations=6,
    clip_path="clips/test_traj.p",
    # 30 tracked bodies ('world' and 'root' are not included).
    body_names=[
        'lhipjoint', 'lfemur', 'ltibia', 'lfoot', 'ltoes',
        'rhipjoint', 'rfemur', 'rtibia', 'rfoot', 'rtoes',
        'lowerback', 'upperback', 'thorax', 'lowerneck', 'upperneck', 'head',
        'lclavicle', 'lhumerus', 'lradius', 'lwrist', 'lhand', 'lfingers', 'lthumb',
        'rclavicle', 'rhumerus', 'rradius', 'rwrist', 'rhand', 'rfingers', 'rthumb',
    ],
    # 'root' followed by the 56 hinge joints.
    # NOTE(review): this order (spine/neck/head and arms before legs) differs from
    # _CMU_MOCAP_JOINTS, which lists the legs first. Confirm how HumanoidTracking uses
    # it; an ordering mismatch would scramble the joint observations.
    joint_names=[
        'root',
        'lowerbackrz', 'lowerbackry', 'lowerbackrx', 'upperbackrz', 'upperbackry', 'upperbackrx',
        'thoraxrz', 'thoraxry', 'thoraxrx', 'lowerneckrz', 'lowerneckry', 'lowerneckrx',
        'upperneckrz', 'upperneckry', 'upperneckrx', 'headrz', 'headry', 'headrx',
        'lclaviclerz', 'lclaviclery', 'lhumerusrz', 'lhumerusry', 'lhumerusrx', 'lradiusrx',
        'lwristry', 'lhandrz', 'lhandrx', 'lfingersrx', 'lthumbrz', 'lthumbrx',
        'rclaviclerz', 'rclaviclery', 'rhumerusrz', 'rhumerusry', 'rhumerusrx', 'rradiusrx',
        'rwristry', 'rhandrz', 'rhandrx', 'rfingersrx', 'rthumbrz', 'rthumbrx',
        'lfemurrz', 'lfemurry', 'lfemurrx', 'ltibiarx', 'lfootrz', 'lfootrx', 'ltoesrx',
        'rfemurrz', 'rfemurry', 'rfemurrx', 'rtibiarx', 'rfootrz', 'rfootrx', 'rtoesrx',
    ],
)

In [5]:
# Build the brax tracking environment from the `params` config dict.
env_brax = HumanoidTracking(params)

In [6]:
# Velocity-space dimension of the brax model. The free root joint has 7 position
# coordinates (3D position + quaternion) but only 6 velocity dofs, hence nq = nv + 1.
env_brax.sys.nv # nq=63, nv=62

62

In [7]:
# List the brax model's joints by id. The count comes from the model itself (`njnt`)
# instead of a hard-coded 100: ids past the last joint have no name and would only
# pad the list with None.
# The root joint is a free joint (3D position + orientation quaternion in qpos), which
# is why nq=63 while nv=62 for the 56 hinge joints.
brax_mj_model = env_brax.sys.mj_model
lst = [
    mujoco.mj_id2name(brax_mj_model, mujoco.mju_str2Type("joint"), joint_id)
    for joint_id in range(brax_mj_model.njnt)
]
assert None not in lst, "every joint id below njnt should have a name"

lst[:5]

['root', 'lfemurrz', 'lfemurry', 'lfemurrx', 'ltibiarx']

In [8]:
# Reset the brax env with a fixed PRNG key so the comparison is repeatable.
# NOTE(review): the array printed below appears to come from a print inside
# HumanoidTracking.reset (it matches the brax "after transformation" values quoted
# later in this notebook). Confirm, and remove the print once debugging is done.
state = env_brax.reset(jax.random.PRNGKey(0))

[-6.7749232e-02  8.6231794e-06 -1.7833591e-02 -1.0716664e-01
  3.4846365e-05 -2.3776948e-02  1.6838254e-01 -2.8296500e-02
  2.9513806e-02  8.9463890e-03]


## Stepping brax environment

In [9]:
# key = random.PRNGKey(0)
# jit_step = jax.jit(env_brax.step)
# jit_reset = jax.jit(env_brax.reset)

In [10]:
# Random-action rollout of the brax env (currently disabled).
# NOTE(review): if re-enabled, `env` below refers to the dm_control env created further
# down, not the brax env. Use `env_brax.sys.nu` instead. `next` also shadows the builtin.
# next = jit_reset(key)
# print("starting rollout")
# mu = 0
# sigma = .1
# rollout = []
# for _ in range(500):
#     _, key = jax.random.split(key)
#     next = jit_step(next, mu + sigma * random.normal(key, shape=(env.sys.nu,)))
#     rollout.append(next)

In [11]:
# Render the (disabled) brax rollout to video.
# NOTE(review): if re-enabled, `env.render` / `env.dt` should be `env_brax.render` /
# `env_brax.dt`, because `env` is the dm_control environment by then.
# import mediapy as media
# import os

# os.environ["MUJOCO_GL"] = "glfw"
# rollout_data = [s.pipeline_state for s in rollout]

# video = env.render(rollout_data, camera='side')

# media.show_video(video, fps=1.0 / env.dt)

# COMIC Validation Observation

In [12]:
from dm_control import composer
from dm_control.locomotion import arenas
from dm_control.locomotion import walkers
from walker import Rat
from cmu_humanoid import CMUHumanoidPositionControlled
from dm_control.locomotion.mocap import props
from dm_control.locomotion.tasks.reference_pose import tracking
from dm_control.locomotion.tasks.reference_pose import types
from dm_control.utils import io as resources
import os
import h5py
import pickle

Open question: maybe `reset` doesn't replace the values in dm_control, so should the clip be instantiated in `__init__` instead? In fact, the clip is already instantiated in `__init__` when the flag `always_init_at_clip_start=True` is set.

In [13]:
# Define the path to your motion capture data file
file_name = 'clips/test_traj.h5' #'clips/all_snips_250.h5'
current_directory = os.getcwd()
TEST_FILE_PATH = os.path.join(current_directory, file_name)

with h5py.File(TEST_FILE_PATH, 'r') as f:
    dataset_keys = tuple(f.keys())
    dataset = types.ClipCollection(ids=dataset_keys,)

# Set up the mocap tracking task
task = tracking.MultiClipMocapTracking(
    walker=CMUHumanoidPositionControlled,
    arena=arenas.Floor(),
    ref_path=resources.GetResourceFilename(TEST_FILE_PATH),
    ref_steps=(1, 2, 3, 4, 5),
    min_steps=1,
    dataset=dataset,
    reward_type='comic',
    always_init_at_clip_start=True
)

# Initialize the environment
env = composer.Environment(task=task)
reset = env.reset()

# action_spec = env.action_spec()
# env.step(np.zeros(action_spec.shape))

In [14]:
# Presumably 5 reference steps x 30 bodies x 3 coordinates = 450.
np.shape(reset.observation['walker/reference_rel_bodies_pos_global'])

(1, 450)

In [15]:
# Presumably 5 reference steps x 3 coordinates = 15.
np.shape(reset.observation['walker/reference_rel_root_pos_local'])

(1, 15)

In [16]:
# 5 reference steps x 56 hinge joints = 280.
np.shape(reset.observation['walker/reference_rel_joints'])

(1, 280)

If the joints in the brax environment are indexed with mismatched shapes, no warning is given: `np` does not complain, the array silently broadcasts to shape (56, 57), and the flattened result has about 3000 entries instead of the correct size.

In [17]:
# Feature names the COMIC task keeps for the walker (private attribute of the task).
env.task._walker_features.keys()

dict_keys(['position', 'quaternion', 'joints', 'center_of_mass', 'end_effectors', 'appendages', 'body_positions', 'body_quaternions', 'velocity', 'angular_velocity', 'joints_velocity'])

In [18]:
# All observation keys exposed by the COMIC environment at reset.
list(reset.observation)

['reference_props_pos_global',
 'reference_props_quat_global',
 'walker/actuator_activation',
 'walker/appendages_pos',
 'walker/body_height',
 'walker/end_effectors_pos',
 'walker/joints_pos',
 'walker/joints_vel',
 'walker/sensors_accelerometer',
 'walker/sensors_force',
 'walker/sensors_gyro',
 'walker/sensors_torque',
 'walker/sensors_touch',
 'walker/sensors_velocimeter',
 'walker/world_zaxis',
 'walker/reference_rel_joints',
 'walker/reference_rel_bodies_pos_global',
 'walker/reference_rel_bodies_quats',
 'walker/reference_rel_bodies_pos_local',
 'walker/reference_ego_bodies_quats',
 'walker/reference_rel_root_quat',
 'walker/reference_rel_root_pos_local',
 'walker/reference_appendages_pos',
 'walker/clip_id',
 'walker/velocimeter_control',
 'walker/gyro_control',
 'walker/joints_vel_control',
 'walker/time_in_clip']

In [19]:
# COMIC observation entries compared against the brax observation, in concatenation order.
# NOTE(review): this order presumably must match the layout of `state.obs` in
# HumanoidTracking. Confirm.
check_lst = ['walker/reference_rel_bodies_pos_local',
             'walker/reference_rel_bodies_pos_global',
             'walker/reference_rel_root_pos_local',
             'walker/reference_rel_joints'
            ]

# Each entry is batched with a leading dim of 1, so join them along the feature axis in
# a single concatenate rather than re-growing the array once per key.
COMIC_lst = np.concatenate([reset.observation[key] for key in check_lst], axis=1)

COMIC_lst

array([[ 8.62317918e-06, -2.78480039e-02, -6.42844045e-02, ...,
         3.10079384e-01, -1.54169243e-03,  5.53967392e-05]])

In [20]:
# The brax observation is a flat vector.
state.obs.shape

(1195,)

In [21]:
# The COMIC vector carries a leading batch dim of 1; otherwise its size must equal the
# brax observation's, or the element-wise comparisons below are meaningless.
assert COMIC_lst.shape == (1, *state.obs.shape), (COMIC_lst.shape, state.obs.shape)
COMIC_lst.shape

(1, 1195)

## Check qpos matching: the first 3 entries (initial root position) are the same

In [22]:
# env.task._clip_reference_features

In [23]:
# qpos of the dm_control (COMIC) physics right after reset.
env.physics.data.qpos

array([0.        , 0.        , 0.94      , 0.46075298, 0.53638297,
       0.53638297, 0.46075298, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        ])

In [24]:
# Brax's qpos is float32 and COMIC's is float64, so compare with a small absolute
# tolerance instead of expecting exact equality.
np.testing.assert_allclose(np.asarray(state.pipeline_state.qpos), env.physics.data.qpos,
                           rtol=0, atol=1e-6)
state.pipeline_state.qpos

Array([0.        , 0.        , 0.94      , 0.46075296, 0.536383  ,
       0.536383  , 0.46075296, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        ], dtype=float32)

## Transformation check (body indices)
It is probably an ordering issue: all the numbers look the same, but their ordering may be what is causing the mismatch.

In [25]:
# COMIC reference frames: the current step offset by each of the task's ref_steps.
time_steps = env.task._time_step + env.task._ref_steps
ref_body_pos = env.task._clip_reference_features['body_positions'][time_steps]
walker_body_pos = env.task._walker_features['body_positions']

# Reference-minus-current offsets of the tracked bodies, rotated by the walker.
tracked_offsets = (ref_body_pos - walker_body_pos)[:, env.task._body_idxs]
obs = env.task._walker.transform_vec_to_egocentric_frame(env.physics, tracked_offsets)
obs_flattened = np.concatenate([frame.flatten() for frame in obs])

walker_body_pos.flatten()[:10]

array([ 0.        ,  0.        ,  0.94      ,  0.03503343,  0.101937  ,
        0.84178506, -0.02235977,  0.24043609,  0.46561425, -0.07985898])

Both of the following come from the brax environment:
1. Sending the state to brax, then grabbing the body positions (`xpos`, flattened):

    ```python
    [ 0.02206819 -0.38698062  0.05154771  0.03503342  0.10193701  0.8417851
    -0.02235974  0.24043609  0.46561432 -0.07985894]
    ```

2. Grabbing the body positions, then sending them into brax (`xpos`, flattened):

    ```python
    [ 0.          0.          0.94        0.03503342  0.10193701  0.8417851
    -0.02235974  0.24043609  0.46561432 -0.07985894]
    ```

Instantiating the brax environment does a weird reordering that distorts the original indices.

The `qpos` values are the same but the `xpos` values differ, so the body-position indexing is wrong.

In [26]:
# Body ids the brax env uses to gather xpos. The values shown match the mj_name2id
# lookup near the top of the notebook.
env_brax._body_locations

Array([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
       19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], dtype=int32)

The grabbed positions are now equivalent to the ones obtained earlier by calling MuJoCo directly from the notebook.

## Transformation check (reference trajectory as well)

In [27]:
# COMIC reference-clip body positions (first 10 flattened values).
env.task._clip_reference_features['body_positions'].flatten()[:10]

array([ 0.        ,  0.        ,  0.94      ,  0.03503343,  0.101937  ,
        0.84178506, -0.02235977,  0.24043609,  0.46561425, -0.07985898])

In [28]:
# Brax reference-trajectory body positions (first 10 flattened values), to compare with COMIC's clip.
env_brax._ref_traj.body_positions.flatten()[:10]

Array([ 0.        ,  0.        ,  0.94      ,  0.03503343,  0.101937  ,
        0.8417851 , -0.02235977,  0.24043609,  0.46561426, -0.07985898],      dtype=float32)

These match as well, but...

There is an extra step from `_ref_traj.body_positions` to `_walker_features['body_positions']`; the data still match up.

In [29]:
# COMIC walker's current body positions (first 10 flattened values).
env.task._walker_features['body_positions'].flatten()[:10]

array([ 0.        ,  0.        ,  0.94      ,  0.03503343,  0.101937  ,
        0.84178506, -0.02235977,  0.24043609,  0.46561425, -0.07985898])

## Transformation check (transform function)

In [30]:
# COMIC observation: reference-minus-current offsets of the tracked bodies, passed
# through the walker's transform_vec_to_egocentric_frame, then flattened frame by frame.
obs = env.task._walker.transform_vec_to_egocentric_frame(
    env.physics,
    (
        env.task._clip_reference_features['body_positions'][time_steps]
        - env.task._walker_features['body_positions']
    )[:, env.task._body_idxs],
)

obs_flattened = np.concatenate(list(map(np.ravel, obs)))
obs_flattened[:10]

array([ 8.62317918e-06, -2.78480039e-02, -6.42844045e-02,  3.48603853e-05,
       -3.96685538e-02, -1.02354489e-01, -2.82965011e-02,  5.45728515e-02,
        1.62004822e-01, -4.03528938e-03])

The brax environment's output after the transformation does not seem to be the same:
```python
[-6.7749232e-02  8.6231794e-06 -1.7833591e-02 -1.0716664e-01
  3.4846365e-05 -2.3776948e-02  1.6838254e-01 -2.8296500e-02
  2.9513806e-02  8.9463890e-03]
```

In [31]:
def global_vector_to_local_frame(data, vec_in_world_frame, body_idx=1):
    """Rotate world-frame vectors into the frame of body `body_idx`.

    COMIC's observation rotates the offsets by the walker's root-body orientation.
    Body 0 of a MuJoCo model is the world body, whose orientation is the identity, so
    reading `data.xmat[0]` left the vectors unchanged. The body-id lookup in this notebook
    places 'lhipjoint' at id 2 after 'world' and 'root', so the root body is id 1 (the
    default). Confirm this for other models.

    Args:
      data: object exposing `xmat`, indexed by body id; each entry is reshaped to 3x3.
      vec_in_world_frame: array whose last dimension is 2 (planar x/y) or 3.
      body_idx: id of the body whose frame to rotate into (default 1, the root body).

    Returns:
      `vec_in_world_frame` right-multiplied by the body's rotation matrix (its top-left
      2x2 block for planar vectors).

    Raises:
      ValueError: if the last dimension is neither 2 nor 3.
    """
    xmat = jp.reshape(data.xmat[body_idx], (3, 3))
    if vec_in_world_frame.shape[-1] == 2:
        return jp.dot(vec_in_world_frame, xmat[:2, :2])
    elif vec_in_world_frame.shape[-1] == 3:
        return jp.dot(vec_in_world_frame, xmat)
    raise ValueError('`vec_in_world_frame` should have shape with final '
                     'dimension 2 or 3: got {}'.format(
                         vec_in_world_frame.shape))
    
def f(x):
    """Slice out the `env_brax._ref_traj_length` frames that follow the current frame.

    Slicing is along the leading axis, starting at `state.info['cur_frame'] + 1`. This reads
    the notebook globals `state` and `env_brax`. Leaves with exactly one dimension are
    replaced by an empty array.
    """
    if len(x.shape) == 1:
        return jp.array([])
    first_frame = state.info['cur_frame'] + 1
    return jax.lax.dynamic_slice_in_dim(x, first_frame, env_brax._ref_traj_length)

In [32]:
# Reference frames that follow the current one, for every leaf of the brax trajectory.
ref_traj = jax.tree_util.tree_map(f, env_brax._ref_traj)

# Derive the body count from the env instead of hard-coding 30.
num_bodies = len(env_brax._body_locations)
xpos_flatten = state.pipeline_state.xpos[env_brax._body_locations].flatten()

# Per-frame reference-minus-current offsets of each tracked body, rotated into the local frame.
obs = global_vector_to_local_frame(
    state.pipeline_state,
    (ref_traj.body_positions - xpos_flatten).reshape([env_brax._ref_traj_length, num_bodies, 3]),
)

# Flattened like COMIC's `obs_flattened`, for a side-by-side comparison.
obs_brax_flattened = obs.flatten()
obs_brax_flattened[:10]

## Transformation check (differences)

In [33]:
# COMIC side, before any rotation: reference body positions at `time_steps` minus the
# walker's current body positions, restricted to the tracked bodies.
comic_offsets = (env.task._clip_reference_features['body_positions'][time_steps]
                 - env.task._walker_features['body_positions'])
comic_offsets[:, env.task._body_idxs].flatten()[:10]

array([-6.77492290e-02,  8.62317918e-06, -1.78336186e-02, -1.07166655e-01,
        3.48603853e-05, -2.37769555e-02,  1.68382568e-01, -2.82965011e-02,
        2.95138832e-02,  8.94642469e-03])

In [34]:
# Brax side, before any rotation. The reshape also checks that the size equals
# ref_traj_length x tracked bodies x 3; the body count is derived from the env, not hard-coded.
num_bodies = len(env_brax._body_locations)
(ref_traj.body_positions - xpos_flatten).reshape([env_brax._ref_traj_length, num_bodies, 3]).flatten()[:10]

Array([-6.7749232e-02,  8.6231794e-06, -1.7833591e-02, -1.0716664e-01,
        3.4846365e-05, -2.3776948e-02,  1.6838254e-01, -2.8296500e-02,
        2.9513806e-02,  8.9463890e-03], dtype=float32)

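The differences are the same, so the problem is probably in the transformation. Note that the brax values "after transformation" quoted above are identical to these raw differences, meaning the rotation applied was the identity. Check which body's `xmat` the brax transform reads: body 0 is the world body, while COMIC rotates into the walker's root-body frame.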