In [1]:
%load_ext autoreload
%autoreload 2
import jax
from jax import random
from envs.rodent import RodentMultiClipTracking, RodentTracking
from preprocessing import mjx_preprocess as mjxp

from brax.training.agents.ppo import networks as ppo_networks

import numpy as np
import mediapy as media
import jax.numpy as jp
import mujoco
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

from brax import envs
from preprocessing.mjx_preprocess import process_clip_to_train

2024-07-15 23:05:12.136490: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
# Build the reference clip data from the clip file on disk; the result is passed
# to the environments below as their `reference_clip` argument.
# NOTE(review): the `.p` suffix suggests a pickle file — only load trusted files.
reference_clip = process_clip_to_train('clips/all_snips.p')

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


In [3]:
# Make the multi-clip tracking environment available under the name "rodent".
envs.register_environment("rodent", RodentMultiClipTracking)

# Keyword arguments forwarded to the environment constructor.
env_args = dict(
    reference_clip=reference_clip,
    # Names of model elements, grouped by the argument they are passed under.
    end_eff_names=["foot_L", "foot_R", "hand_L", "hand_R"],
    appendage_names=["foot_L", "foot_R", "hand_L", "hand_R", "skull"],
    walker_body_names=[
        "torso", "pelvis",
        "upper_leg_L", "lower_leg_L", "foot_L",
        "upper_leg_R", "lower_leg_R", "foot_R",
        "skull", "jaw",
        "scapula_L", "upper_arm_L", "lower_arm_L", "finger_L",
        "scapula_R", "upper_arm_R", "lower_arm_R", "finger_R",
    ],
    # NOTE(review): "finger_R" is listed here without a "finger_L" counterpart,
    # unlike every other left/right pair — confirm this is intended.
    joint_names=[
        "vertebra_1_extend",
        "hip_L_supinate", "hip_L_abduct", "hip_L_extend",
        "knee_L", "ankle_L", "toe_L",
        "hip_R_supinate", "hip_R_abduct", "hip_R_extend",
        "knee_R", "ankle_R", "toe_R",
        "vertebra_C11_extend", "vertebra_cervical_1_bend", "vertebra_axis_twist",
        "atlas", "mandible",
        "scapula_L_supinate", "scapula_L_abduct", "scapula_L_extend",
        "shoulder_L", "shoulder_sup_L", "elbow_L", "wrist_L",
        "scapula_R_supinate", "scapula_R_abduct", "scapula_R_extend",
        "shoulder_R", "shoulder_sup_R", "elbow_R", "wrist_R",
        "finger_R",
    ],
    center_of_mass="torso",
    mjcf_path="./assets/rodent.xml",
    scale_factor=0.9,
    # Physics solver settings.
    solver="cg",
    iterations=6,
    ls_iterations=6,
    # Episode / reference-trajectory settings.
    healthy_z_range=(0.2, 1.0),
    reset_noise_scale=0.1,
    clip_length=250,
    sub_clip_length=10,
    ref_traj_length=5,
    termination_threshold=5,
    body_error_multiplier=1.0,
    # These last two are only passed to the multi-clip environment.
    min_steps=10,
    random_start=True,
)

env = envs.get_environment("rodent", **env_args)

## Checking reset & step functionality

**The reset function supports `vmap` for parallel environment processing, but the step function does not.**

In [4]:
# Seed a PRNG key, then draw one integer in [0, 235) for each of the
# `env._num_clips` entries.
key = jax.random.PRNGKey(100)
jax.random.randint(key, shape=(env._num_clips,), minval=0, maxval=235)

Array([ 28, 118,  71,  54,   3, 182,  27, 200, 106, 140, 109, 113,  12,
       213, 186, 203, 164, 139, 189, 132,  42,  68, 206,  54,  60, 138,
        80, 213,  75, 150,  57, 134,   1, 141, 117,  49, 207, 196, 148,
        97, 203, 205,  99, 170, 113,  99, 191,  44, 107, 163,  81,  30,
        74, 127, 144,  97, 191, 174,  33, 157, 176,  86, 172, 123, 120,
        53,  80,  29, 113, 122,  24, 135,   2,  85, 152,  48, 150,  67,
       210,  24, 148,   3,  50, 149,  47, 215, 218, 197,  94,  10, 106,
       112, 185, 111,   2, 197,  77, 203,  10, 230, 106,  82, 229,  57,
         9, 190,  68, 186, 220, 180, 212, 201, 127, 200,  27,  51, 153,
       179,  45, 191, 213,   0, 176,  55, 170, 112,  23, 112,   4,  68,
        35, 137, 115, 212, 132, 196,  47, 101, 136, 134, 162, 127, 195,
       122, 117, 208,  88,  59, 179,  19,  28,  72, 112,  17,  53, 108,
       116, 220, 115,  96, 142, 231,  15,  68, 124,  23,  80, 219, 107,
       116,  55, 231, 139, 109, 160, 113,  57,  49, 105,  52,  2

In [5]:
# Re-seed, and jit-compile the environment's reset and step entry points.
key = jax.random.PRNGKey(100)
reset_fn, step_fn = jax.jit(env.reset), jax.jit(env.step)

In [6]:
# Split the seed key in two, keep the second half, and reset a single
# (un-batched) environment with it.
_, key_envs = random.split(key, num=2)
env_state = reset_fn(key_envs)

Initialize the environment: state after the first call to reset.

In [7]:
# Show the `info` entries carried by the state returned from the first reset.
env_state.info

{'cur_clip_id': Array(499, dtype=int32),
 'cur_frame': Array(241, dtype=int32),
 'sub_clip_frame': Array(0, dtype=int32, weak_type=True),
 'termination_error': (Array(-6.950203, dtype=float32),
  Array(0., dtype=float32, weak_type=True)),
 'traj': Array([-0.20617531,  0.09366457,  0.06282681, ..., -0.01531352,
        -0.03543425, -0.03543425], dtype=float32)}

Initialize the environment again: state after a second call to reset with a new key.

In [8]:
# Derive a new key from the one used for the first reset and reset again, so the
# two resets can be compared.
_, new_key_env = random.split(key_envs, num=2)
env_state = reset_fn(new_key_env)

In [9]:
# Show the `info` entries carried by the state returned from the second reset.
env_state.info

{'cur_clip_id': Array(738, dtype=int32),
 'cur_frame': Array(108, dtype=int32),
 'sub_clip_frame': Array(0, dtype=int32, weak_type=True),
 'termination_error': (Array(-16.027216, dtype=float32),
  Array(0., dtype=float32, weak_type=True)),
 'traj': Array([-0.16082662,  0.14971048,  0.10406622, ...,  0.02211449,
        -0.14542218, -0.14542218], dtype=float32)}

Stepping the environment with the step function.

In [10]:
# Jit-compile the single-environment step function. This repeats the expression
# already used when `reset_fn`/`step_fn` were first created; presumably kept so
# this section can be run on its own.
step_fn = jax.jit(env.step)

In [11]:
# Check the shape of a normal sample with `env.sys.nu` entries.
jax.random.normal(key, (env.sys.nu,)).shape

(30,)

In [12]:
# Random action: Gaussian noise with mean `mu` and scale `sigma`, with
# `env.sys.nu` entries.
# NOTE(review): `key` was already passed to `jax.random.split` in an earlier
# cell; consider sampling from a fresh subkey rather than reusing it.
mu, sigma = 0, 0.3
action = mu + sigma * jax.random.normal(key, (env.sys.nu,))

# Look at the length of the `q` vector of the current pipeline state.
state = env_state.pipeline_state
state.q.shape

(74,)

In [13]:
# Advance the freshly reset state by one step with the sampled action.
new_state = step_fn(env_state, action)

In [14]:
# Show the `info` entries after the first step.
new_state.info

{'cur_clip_id': Array(738, dtype=int32),
 'cur_frame': Array(109, dtype=int32),
 'sub_clip_frame': Array(2, dtype=int32, weak_type=True),
 'termination_error': Array(-0.16027215, dtype=float32),
 'traj': Array([-0.16100302,  0.14942591,  0.10313655, ..., -0.5201364 ,
        -0.27276134, -0.27276134], dtype=float32)}

Stepping again from the previous result.

In [15]:
# Step once more from the previous result, reusing the same action, and show the
# updated `info` entries.
# Note: `new_state` is overwritten in place, so re-running this cell advances
# the environment by another step each time.
new_state = step_fn(new_state, action)
new_state.info

{'cur_clip_id': Array(738, dtype=int32),
 'cur_frame': Array(110, dtype=int32),
 'sub_clip_frame': Array(4, dtype=int32, weak_type=True),
 'termination_error': Array(-0.22881521, dtype=float32),
 'traj': Array([-0.1609548 ,  0.14878199,  0.10276634, ..., -0.80740184,
        -0.4126497 , -0.4126497 ], dtype=float32)}

## `vmap` reset function for parallel environments

In [16]:
# Number of environments to run side by side, and one PRNG key per environment.
n_envs = 10
key_envs = random.split(key, num=n_envs)

# Batched (vmapped) and jit-compiled versions of reset and step; from here on
# `reset_fn`/`step_fn` expect a leading batch axis on their inputs.
reset_fn, step_fn = jax.jit(jax.vmap(env.reset)), jax.jit(jax.vmap(env.step))

env_state = reset_fn(key_envs)

env_state.info

{'cur_clip_id': Array([485, 401, 632,  71, 651, 258, 284, 312, 366, 194], dtype=int32),
 'cur_frame': Array([235, 187, 136, 149, 123,  10, 210,  94, 146, 112], dtype=int32),
 'sub_clip_frame': Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32, weak_type=True),
 'termination_error': (Array([ -8.206141 ,  -1.6277385, -23.958109 , -10.601182 ,  -9.867305 ,
         -14.988257 , -14.562813 , -11.599106 ,  -5.4344554, -18.814314 ],      dtype=float32),
  Array([0., 0., 0., 0., 0., 1., 1., 0., 1., 0.], dtype=float32, weak_type=True)),
 'traj': Array([[-0.15427066,  0.01494329,  0.05672465, ..., -0.06212887,
         -0.09165404, -0.09165404],
        [ 0.02545187, -0.09089029,  0.05754928, ..., -0.01286346,
         -0.09104721, -0.09104721],
        [-0.18914463,  0.07429746,  0.05606084, ...,  0.03004456,
         -0.03742177, -0.03742177],
        ...,
        [ 0.06721927,  0.00763874,  0.08064112, ...,  0.02001724,
          0.01832652,  0.01832652],
        [ 0.21447162, -0.10602634,  

In [17]:
# Derive a new batch of keys from the first key of the previous batch, reset all
# environments again and show the resulting `info` entries.
new_key_env = random.split(key_envs[0], num=n_envs)
env_state = reset_fn(new_key_env)
env_state.info

{'cur_clip_id': Array([467,  73, 426, 756, 356, 688, 409, 791, 404, 261], dtype=int32),
 'cur_frame': Array([169,  81,  84,  16,  22, 220,  69, 197, 142, 211], dtype=int32),
 'sub_clip_frame': Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32, weak_type=True),
 'termination_error': (Array([-13.7108965, -16.871    ,  -8.502774 , -22.36122  , -11.122095 ,
          -5.816902 ,  -4.333744 , -14.831317 , -12.647524 , -13.258934 ],      dtype=float32),
  Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32, weak_type=True)),
 'traj': Array([[-0.20004347,  0.2385041 ,  0.06029586, ..., -0.08812524,
          0.01174304,  0.01174304],
        [-0.15528704,  0.00930448,  0.11902401, ..., -0.06174397,
          0.01036848,  0.01036848],
        [ 0.08066069,  0.04199532,  0.05919814, ..., -0.15245038,
         -0.18402584, -0.18402584],
        ...,
        [-0.09740867,  0.3776508 ,  0.05755525, ..., -0.12594199,
          0.00241748,  0.00241748],
        [ 0.0057214 ,  0.43774807,  

In [18]:
# One random action per environment: Gaussian noise with mean `mu` and scale
# `sigma`, shaped (n_envs, env.sys.nu).
# NOTE(review): `key` was already passed to `jax.random.split` when the batch of
# reset keys was created; consider sampling from a fresh subkey instead.
mu, sigma = 0, 0.3
action = mu + sigma * jax.random.normal(key, (n_envs, env.sys.nu))
action.shape

(10, 30)

In [19]:
# Step the whole batch of environments at once with the batched action, then
# show the resulting `info` entries.
new_state = step_fn(env_state, action)
new_state.info

{'cur_clip_id': Array([467,  73, 426, 756, 356, 688, 409, 791, 404, 261], dtype=int32),
 'cur_frame': Array([170,  82,  85,  17,  23, 221,  70, 198, 143, 212], dtype=int32),
 'sub_clip_frame': Array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int32, weak_type=True),
 'termination_error': Array([-0.13710897, -0.16871   , -0.08502774, -0.22361219, -0.11122095,
        -0.05816902, -0.04333744, -0.14831316, -0.12647523, -0.13258934],      dtype=float32),
 'traj': Array([[-0.20036991,  0.23884232,  0.06018402, ..., -0.5029912 ,
         -0.15563805, -0.15563805],
        [-0.15533067,  0.00956762,  0.11930922, ..., -0.53285253,
         -0.16768971, -0.16768971],
        [ 0.08127106,  0.05134097,  0.06068143, ..., -0.5477654 ,
         -0.3447373 , -0.3447373 ],
        ...,
        [-0.09816636,  0.37509644,  0.05811431, ...,  2.6401162 ,
         -3.9763784 , -3.9763784 ],
        [ 0.01007128,  0.4416503 ,  0.0586198 , ..., -0.45390052,
         -0.17426883, -0.17426883],
        [ 0.3767978

The batched step runs successfully.

## Reference to SingleClip

In [20]:
# Make the single-clip tracking environment available under "rodent_single".
envs.register_environment("rodent_single", RodentTracking)

# The single-clip environment is configured exactly like the multi-clip one
# above, minus the two options that are only passed to the multi-clip variant.
# Deriving the dict (instead of repeating ~80 lines of names and settings) keeps
# the two configurations from silently drifting apart.
multi_clip_only_keys = {"min_steps", "random_start"}

# `env_args` is rebound so that it ends up with the same contents as before;
# the filter is idempotent, so re-running this cell is harmless.
env_args = {
    arg_name: arg_value
    for arg_name, arg_value in env_args.items()
    if arg_name not in multi_clip_only_keys
}

env_s = envs.get_environment('rodent_single', **env_args)

### `vmap` of the environment not available yet!

In [21]:
# Same smoke test as for the multi-clip environment, now on the single-clip one:
# re-seed, jit-compile its reset/step, and reset once with a derived key.
key = jax.random.PRNGKey(100)
reset_fn, step_fn = jax.jit(env_s.reset), jax.jit(env_s.step)

n_envs = 5
_, key_envs = random.split(key, num=2)
env_state = reset_fn(key_envs)

In [22]:
# Show the `info` entries of the single-clip environment's reset state.
env_state.info

{'cur_frame': Array(36, dtype=int32),
 'sub_clip_frame': Array(0, dtype=int32, weak_type=True),
 'termination_error': Array(-3.7766209, dtype=float32),
 'traj': Array([ 0.39154962,  0.16734575,  0.11736992, ..., -0.01084213,
        -0.14542218, -0.14542218], dtype=float32)}

In [23]:
# Jit-compiled step of the single-clip environment.
step_fn = jax.jit(env_s.step)

# Advance the key before sampling, then draw a Gaussian action with mean `mu`
# and scale `sigma`, with `env_s.sys.nu` entries.
mu, sigma = 0, 0.3
_, key = jax.random.split(key, num=2)
action = mu + sigma * jax.random.normal(key, (env_s.sys.nu,))
action.shape

(30,)

In [24]:
# Advance the single-clip environment by one step with the sampled action.
new_state = step_fn(env_state, action)