In [2]:
%load_ext autoreload
%autoreload 2
import jax
from jax import random
from envs.rodent import RodentMultiClipTracking, RodentTracking
from preprocessing import mjx_preprocess as mjxp

from brax.training.agents.ppo import networks as ppo_networks

import numpy as np
import mediapy as media
import jax.numpy as jp
import mujoco
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

from brax import envs
from preprocessing.mjx_preprocess import process_clip_to_train

2024-07-09 01:04:55.183127: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
reference_clip = process_clip_to_train('clips/all_snips.p')

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


In [4]:
envs.register_environment("rodent", RodentMultiClipTracking)

# Define the environment arguments
env_args = {
    "reference_clip": reference_clip,
    "end_eff_names": [
        "foot_L",
        "foot_R",
        "hand_L",
        "hand_R"
    ],
    "appendage_names": [
        "foot_L",
        "foot_R",
        "hand_L",
        "hand_R",
        "skull"
    ],
    "walker_body_names": [
        "torso",
        "pelvis",
        "upper_leg_L",
        "lower_leg_L",
        "foot_L",
        "upper_leg_R",
        "lower_leg_R",
        "foot_R",
        "skull",
        "jaw",
        "scapula_L",
        "upper_arm_L",
        "lower_arm_L",
        "finger_L",
        "scapula_R",
        "upper_arm_R",
        "lower_arm_R",
        "finger_R"
    ],
    "joint_names": [
        "vertebra_1_extend",
        "hip_L_supinate",
        "hip_L_abduct",
        "hip_L_extend",
        "knee_L",
        "ankle_L",
        "toe_L",
        "hip_R_supinate",
        "hip_R_abduct",
        "hip_R_extend",
        "knee_R",
        "ankle_R",
        "toe_R",
        "vertebra_C11_extend",
        "vertebra_cervical_1_bend",
        "vertebra_axis_twist",
        "atlas",
        "mandible",
        "scapula_L_supinate",
        "scapula_L_abduct",
        "scapula_L_extend",
        "shoulder_L",
        "shoulder_sup_L",
        "elbow_L",
        "wrist_L",
        "scapula_R_supinate",
        "scapula_R_abduct",
        "scapula_R_extend",
        "shoulder_R",
        "shoulder_sup_R",
        "elbow_R",
        "wrist_R",
        "finger_R"
    ],
    "center_of_mass": "torso",
    "mjcf_path": "./assets/rodent.xml",
    "scale_factor": 0.9,
    "solver": "cg",
    "iterations": 6,
    "ls_iterations": 6,
    "healthy_z_range": (0.2, 1.0),
    "reset_noise_scale": 0.1,
    "clip_length": 250,
    "sub_clip_length": 10,
    "ref_traj_length": 5,
    "termination_threshold": 5,
    "body_error_multiplier": 1.0,
    "min_steps": 10,
    "random_start":True,
}

env = envs.get_environment('rodent', **env_args)

## Checking reset & step functionality

**reset function does support `vmap` for parralel env processing, but not the step function**

In [5]:
key = random.PRNGKey(100)
jax.random.randint(key, (env._num_clips,), 0, 235)

Array([ 28, 118,  71,  54,   3, 182,  27, 200, 106, 140, 109, 113,  12,
       213, 186, 203, 164, 139, 189, 132,  42,  68, 206,  54,  60, 138,
        80, 213,  75, 150,  57, 134,   1, 141, 117,  49, 207, 196, 148,
        97, 203, 205,  99, 170, 113,  99, 191,  44, 107, 163,  81,  30,
        74, 127, 144,  97, 191, 174,  33, 157, 176,  86, 172, 123, 120,
        53,  80,  29, 113, 122,  24, 135,   2,  85, 152,  48, 150,  67,
       210,  24, 148,   3,  50, 149,  47, 215, 218, 197,  94,  10, 106,
       112, 185, 111,   2, 197,  77, 203,  10, 230, 106,  82, 229,  57,
         9, 190,  68, 186, 220, 180, 212, 201, 127, 200,  27,  51, 153,
       179,  45, 191, 213,   0, 176,  55, 170, 112,  23, 112,   4,  68,
        35, 137, 115, 212, 132, 196,  47, 101, 136, 134, 162, 127, 195,
       122, 117, 208,  88,  59, 179,  19,  28,  72, 112,  17,  53, 108,
       116, 220, 115,  96, 142, 231,  15,  68, 124,  23,  80, 219, 107,
       116,  55, 231, 139, 109, 160, 113,  57,  49, 105,  52,  2

In [6]:
key = random.PRNGKey(100)
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

In [7]:
# n_envs = 5
_, key_envs= jax.random.split(key)
env_state = reset_fn(key_envs)

In [8]:
env_state.info

{'cur_clip_id': Array(499, dtype=int32),
 'cur_frame': Array(110, dtype=int32),
 'sub_clip_frame': Array(0, dtype=int32, weak_type=True),
 'termination_error': Array(-0.9241526, dtype=float32),
 'traj': Array([-1.77801907e-01,  1.43157870e-01,  5.55636287e-02, -1.73061430e-01,
         1.27704203e-01,  5.36383018e-02, -1.66621372e-01,  1.33204550e-01,
         6.83749514e-03, -1.66621372e-01,  1.33204550e-01,  6.83749514e-03,
        -1.66621372e-01,  1.33204550e-01,  6.83749514e-03, -1.77674338e-01,
         1.43771216e-01,  5.51112145e-02, -1.72757417e-01,  1.28591433e-01,
         5.34426644e-02, -1.67005271e-01,  1.33920848e-01,  6.65747328e-03,
        -1.67005271e-01,  1.33920848e-01,  6.65747328e-03, -1.67005271e-01,
         1.33920848e-01,  6.65747328e-03, -1.77560881e-01,  1.43667117e-01,
         5.49331270e-02, -1.72571808e-01,  1.28556579e-01,  5.33372499e-02,
        -1.66969702e-01,  1.34058267e-01,  6.54165074e-03, -1.66969702e-01,
         1.34058267e-01,  6.54165074e-

In [9]:
_, new_key_env = jax.random.split(key_envs)
env_state = reset_fn(new_key_env)

In [10]:
env_state.info

{'cur_clip_id': Array(738, dtype=int32),
 'cur_frame': Array(28, dtype=int32),
 'sub_clip_frame': Array(0, dtype=int32, weak_type=True),
 'termination_error': Array(-1.0630691, dtype=float32),
 'traj': Array([-0.16659114,  0.14546664,  0.10879885, -0.18231891,  0.1292104 ,
         0.10871847, -0.15377682,  0.12173592,  0.0793609 , -0.15377682,
         0.12173592,  0.0793609 , -0.15377682,  0.12173592,  0.0793609 ,
        -0.16687953,  0.14501932,  0.10862356, -0.18263307,  0.12906387,
         0.10865918, -0.15436172,  0.12095588,  0.07856713, -0.15436172,
         0.12095588,  0.07856713, -0.15436172,  0.12095588,  0.07856713,
        -0.16707502,  0.1445722 ,  0.10828973, -0.18278287,  0.12885769,
         0.10849562, -0.15487142,  0.12028071,  0.07756346, -0.15487142,
         0.12028071,  0.07756346, -0.15487142,  0.12028071,  0.07756346,
        -0.1669629 ,  0.14372286,  0.10828417, -0.18268488,  0.12831607,
         0.10858818, -0.1553124 ,  0.11863773,  0.07670501, -0.155312

In [11]:
step_fn = jax.jit(env.step)

In [12]:
random.normal(key, shape=(env.sys.nu,)).shape

(30,)

In [13]:
mu = 0
sigma = 0.3
action =  mu + sigma * random.normal(key, shape=(env.sys.nu,))
state = env_state.pipeline_state
# env.pipeline_step(state, action)
state.q.shape

(74,)

In [14]:
new_state = step_fn(env_state, action)

In [15]:
new_state.info

{'cur_clip_id': Array(738, dtype=int32),
 'cur_frame': Array(29, dtype=int32),
 'sub_clip_frame': Array(1, dtype=int32, weak_type=True),
 'termination_error': Array(-0.01063069, dtype=float32),
 'traj': Array([-1.66879535e-01,  1.45019323e-01,  1.08623564e-01, -1.82633072e-01,
         1.29063874e-01,  1.08659178e-01, -1.54361725e-01,  1.20955884e-01,
         7.85671324e-02, -1.54361725e-01,  1.20955884e-01,  7.85671324e-02,
        -1.54361725e-01,  1.20955884e-01,  7.85671324e-02, -1.67075023e-01,
         1.44572198e-01,  1.08289734e-01, -1.82782874e-01,  1.28857687e-01,
         1.08495623e-01, -1.54871419e-01,  1.20280713e-01,  7.75634646e-02,
        -1.54871419e-01,  1.20280713e-01,  7.75634646e-02, -1.54871419e-01,
         1.20280713e-01,  7.75634646e-02, -1.66962907e-01,  1.43722862e-01,
         1.08284168e-01, -1.82684883e-01,  1.28316075e-01,  1.08588181e-01,
        -1.55312404e-01,  1.18637726e-01,  7.67050087e-02, -1.55312404e-01,
         1.18637726e-01,  7.67050087e-

## `vmap` reset function for parralel environments

In [18]:
reset_fn = jax.jit(jax.vmap(env.reset))
step_fn = jax.jit(jax.vmap(env.step))

n_envs = 10
key_envs = jax.random.split(key, n_envs)
env_state = reset_fn(key_envs)

env_state.info

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type int32[842] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was reset at /root/talmolab-smb/kaiwen/VNL-Brax-Imitation/envs/rodent.py:601 traced for jit.
------------------------------
The leaked intermediate value was created on line /root/talmolab-smb/kaiwen/VNL-Brax-Imitation/envs/rodent.py:572:41 (RodentMultiClipTracking._get_possible_starts). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
<frozen runpy>:88:4 (_run_code)
/tmp/ipykernel_530325/2482866467.py:3:12 (<module>)
/root/talmolab-smb/kaiwen/VNL-Brax-Imitation/envs/rodent.py:607:26 (RodentMultiClipTracking.reset)
/root/talmolab-smb/kaiwen/VNL-Brax-Imitation/envs/rodent.py:592:21 (RodentMultiClipTracking._get_clip_id)
/root/talmolab-smb/kaiwen/VNL-Brax-Imitation/envs/rodent.py:572:41 (RodentMultiClipTracking._get_possible_starts)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

In [None]:
new_key_env = jax.random.split(key_envs[0], n_envs)
env_state = reset_fn(new_key_env)
env_state.info

{'cur_clip_id': Array([467,  73, 426, 756, 356, 688, 409, 791, 404, 261], dtype=int32),
 'cur_frame': Array([216, 201, 234,  24, 230,   3,  88,  39,  20, 102], dtype=int32),
 'sub_clip_frame': Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32, weak_type=True),
 'termination_error': Array([0.2852499 , 0.41327906, 0.13209683, 0.19281441, 0.29671246,
        0.05769813, 0.34846103, 0.2727661 , 0.35410887, 0.12264472],      dtype=float32),
 'traj': Array([[-0.18792216,  0.26528084,  0.154288  , ..., -0.04520485,
          0.01174304,  0.01174304],
        [-0.1295034 ,  0.02134102,  0.07011838, ..., -0.00082335,
          0.01036848,  0.01036848],
        [-0.00189915,  0.4155111 ,  0.04983118, ...,  0.48758662,
         -0.18402584, -0.18402584],
        ...,
        [ 0.12375263, -0.04087412,  0.05324259, ..., -0.5494366 ,
          0.00241748,  0.00241748],
        [-0.12025809,  0.05894759,  0.14897513, ..., -0.40574092,
          0.00125737,  0.00125737],
        [ 0.2347006 , -0.0536

In [None]:
mu = 0
sigma = 0.3
action =  mu + sigma * random.normal(key, shape=(env.sys.nu, n_envs))
action.shape

In [None]:
new_state = step_fn(env_state, action)
new_state.info

## Reference to SingleClip

In [None]:
envs.register_environment("rodent_single", RodentTracking)

# Define the environment arguments
env_args = {
    "reference_clip": reference_clip,
    "end_eff_names": [
        "foot_L",
        "foot_R",
        "hand_L",
        "hand_R"
    ],
    "appendage_names": [
        "foot_L",
        "foot_R",
        "hand_L",
        "hand_R",
        "skull"
    ],
    "walker_body_names": [
        "torso",
        "pelvis",
        "upper_leg_L",
        "lower_leg_L",
        "foot_L",
        "upper_leg_R",
        "lower_leg_R",
        "foot_R",
        "skull",
        "jaw",
        "scapula_L",
        "upper_arm_L",
        "lower_arm_L",
        "finger_L",
        "scapula_R",
        "upper_arm_R",
        "lower_arm_R",
        "finger_R"
    ],
    "joint_names": [
        "vertebra_1_extend",
        "hip_L_supinate",
        "hip_L_abduct",
        "hip_L_extend",
        "knee_L",
        "ankle_L",
        "toe_L",
        "hip_R_supinate",
        "hip_R_abduct",
        "hip_R_extend",
        "knee_R",
        "ankle_R",
        "toe_R",
        "vertebra_C11_extend",
        "vertebra_cervical_1_bend",
        "vertebra_axis_twist",
        "atlas",
        "mandible",
        "scapula_L_supinate",
        "scapula_L_abduct",
        "scapula_L_extend",
        "shoulder_L",
        "shoulder_sup_L",
        "elbow_L",
        "wrist_L",
        "scapula_R_supinate",
        "scapula_R_abduct",
        "scapula_R_extend",
        "shoulder_R",
        "shoulder_sup_R",
        "elbow_R",
        "wrist_R",
        "finger_R"
    ],
    "center_of_mass": "torso",
    "mjcf_path": "./assets/rodent.xml",
    "scale_factor": 0.9,
    "solver": "cg",
    "iterations": 6,
    "ls_iterations": 6,
    "healthy_z_range": (0.2, 1.0),
    "reset_noise_scale": 0.1,
    "clip_length": 250,
    "sub_clip_length": 10,
    "ref_traj_length": 5,
    "termination_threshold": 5,
    "body_error_multiplier": 1.0,
}

env_s = envs.get_environment('rodent_single', **env_args)

### `vmap` of the environment not available yet!

In [None]:
key = random.PRNGKey(100)
reset_fn = jax.jit(env_s.reset)
step_fn = jax.jit(env_s.step)

n_envs = 5
_, key_envs= jax.random.split(key)
env_state = reset_fn(key_envs)

In [None]:
env_state.info

{'cur_frame': Array(36, dtype=int32),
 'sub_clip_frame': Array(0, dtype=int32, weak_type=True),
 'termination_error': Array(0.22856492, dtype=float32),
 'traj': Array([ 0.39154962,  0.16734575,  0.11736992,  0.3782915 ,  0.15460956,
         0.11611144,  0.4061749 ,  0.12981644,  0.09224587,  0.4061749 ,
         0.12981644,  0.09224587,  0.4061749 ,  0.12981644,  0.09224587,
         0.39091292,  0.16724986,  0.1150074 ,  0.37745553,  0.15453087,
         0.11410891,  0.40458795,  0.13184093,  0.08867101,  0.40458795,
         0.13184093,  0.08867101,  0.40458795,  0.13184093,  0.08867101,
         0.39045614,  0.16842628,  0.11219793,  0.3764755 ,  0.15536475,
         0.11146784,  0.40267816,  0.1338081 ,  0.08442029,  0.40267816,
         0.1338081 ,  0.08442029,  0.40267816,  0.1338081 ,  0.08442029,
         0.3898809 ,  0.17004865,  0.10895442,  0.37572348,  0.15662089,
         0.10837589,  0.4008316 ,  0.13651828,  0.07850949,  0.4008316 ,
         0.13651828,  0.07850949,  0.

In [None]:
step_fn = jax.jit(env_s.step)
mu = 0
sigma = 0.3
_, key = random.split(key)
action =  mu + sigma * random.normal(key, shape=(env_s.sys.nu,))
action.shape

(30,)

In [None]:
new_state = step_fn(env_state, action)