In [159]:
%load_ext autoreload
%autoreload 2
import jax
from jax import random
from envs.rodent import RodentMultiClipTracking, RodentTracking
from preprocessing import mjx_preprocess as mjxp

from brax.training.agents.ppo import networks as ppo_networks

import numpy as np
import mediapy as media
import jax.numpy as jp
import mujoco
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

from brax import envs
from preprocessing.mjx_preprocess import process_clip_to_train

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [160]:
reference_clip = process_clip_to_train('clips/all_snips.p')

  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


In [161]:
envs.register_environment("rodent", RodentMultiClipTracking)

# Define the environment arguments
env_args = {
    "reference_clip": reference_clip,
    "end_eff_names": [
        "foot_L",
        "foot_R",
        "hand_L",
        "hand_R"
    ],
    "appendage_names": [
        "foot_L",
        "foot_R",
        "hand_L",
        "hand_R",
        "skull"
    ],
    "walker_body_names": [
        "torso",
        "pelvis",
        "upper_leg_L",
        "lower_leg_L",
        "foot_L",
        "upper_leg_R",
        "lower_leg_R",
        "foot_R",
        "skull",
        "jaw",
        "scapula_L",
        "upper_arm_L",
        "lower_arm_L",
        "finger_L",
        "scapula_R",
        "upper_arm_R",
        "lower_arm_R",
        "finger_R"
    ],
    "joint_names": [
        "vertebra_1_extend",
        "hip_L_supinate",
        "hip_L_abduct",
        "hip_L_extend",
        "knee_L",
        "ankle_L",
        "toe_L",
        "hip_R_supinate",
        "hip_R_abduct",
        "hip_R_extend",
        "knee_R",
        "ankle_R",
        "toe_R",
        "vertebra_C11_extend",
        "vertebra_cervical_1_bend",
        "vertebra_axis_twist",
        "atlas",
        "mandible",
        "scapula_L_supinate",
        "scapula_L_abduct",
        "scapula_L_extend",
        "shoulder_L",
        "shoulder_sup_L",
        "elbow_L",
        "wrist_L",
        "scapula_R_supinate",
        "scapula_R_abduct",
        "scapula_R_extend",
        "shoulder_R",
        "shoulder_sup_R",
        "elbow_R",
        "wrist_R",
        "finger_R"
    ],
    "center_of_mass": "torso",
    "mjcf_path": "./assets/rodent.xml",
    "scale_factor": 0.9,
    "solver": "cg",
    "iterations": 6,
    "ls_iterations": 6,
    "healthy_z_range": (0.2, 1.0),
    "reset_noise_scale": 0.1,
    "clip_length": 250,
    "sub_clip_length": 10,
    "ref_traj_length": 5,
    "termination_threshold": 5,
    "body_error_multiplier": 1.0,
    # "min_steps": 10,
    "random_start":True,
}

env = envs.get_environment('rodent', **env_args)

## Checking reset & step functionality

**reset function does support `vmap` for parralel env processing, but not the step function**

In [162]:
key = random.PRNGKey(100)
jax.random.randint(key, (env._num_clips,), 0, 235)

Array([ 28, 118,  71,  54,   3, 182,  27, 200, 106, 140, 109, 113,  12,
       213, 186, 203, 164, 139, 189, 132,  42,  68, 206,  54,  60, 138,
        80, 213,  75, 150,  57, 134,   1, 141, 117,  49, 207, 196, 148,
        97, 203, 205,  99, 170, 113,  99, 191,  44, 107, 163,  81,  30,
        74, 127, 144,  97, 191, 174,  33, 157, 176,  86, 172, 123, 120,
        53,  80,  29, 113, 122,  24, 135,   2,  85, 152,  48, 150,  67,
       210,  24, 148,   3,  50, 149,  47, 215, 218, 197,  94,  10, 106,
       112, 185, 111,   2, 197,  77, 203,  10, 230, 106,  82, 229,  57,
         9, 190,  68, 186, 220, 180, 212, 201, 127, 200,  27,  51, 153,
       179,  45, 191, 213,   0, 176,  55, 170, 112,  23, 112,   4,  68,
        35, 137, 115, 212, 132, 196,  47, 101, 136, 134, 162, 127, 195,
       122, 117, 208,  88,  59, 179,  19,  28,  72, 112,  17,  53, 108,
       116, 220, 115,  96, 142, 231,  15,  68, 124,  23,  80, 219, 107,
       116,  55, 231, 139, 109, 160, 113,  57,  49, 105,  52,  2

In [163]:
key = random.PRNGKey(100)
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

In [164]:
# n_envs = 5
_, key_envs= jax.random.split(key)
env_state = reset_fn(key_envs)

In [165]:
env_state.info

{'cur_clip_id': Array(499, dtype=int32),
 'cur_frame': Array(110, dtype=int32),
 'sub_clip_frame': Array(0, dtype=int32, weak_type=True),
 'termination_error': Array(0.11066836, dtype=float32),
 'traj': Array([-1.77801907e-01,  1.43157870e-01,  5.55636287e-02, -1.73061430e-01,
         1.27704203e-01,  5.36383018e-02, -1.66621372e-01,  1.33204550e-01,
         6.83749514e-03, -1.66621372e-01,  1.33204550e-01,  6.83749514e-03,
        -1.66621372e-01,  1.33204550e-01,  6.83749514e-03, -1.77674338e-01,
         1.43771216e-01,  5.51112145e-02, -1.72757417e-01,  1.28591433e-01,
         5.34426644e-02, -1.67005271e-01,  1.33920848e-01,  6.65747328e-03,
        -1.67005271e-01,  1.33920848e-01,  6.65747328e-03, -1.67005271e-01,
         1.33920848e-01,  6.65747328e-03, -1.77560881e-01,  1.43667117e-01,
         5.49331270e-02, -1.72571808e-01,  1.28556579e-01,  5.33372499e-02,
        -1.66969702e-01,  1.34058267e-01,  6.54165074e-03, -1.66969702e-01,
         1.34058267e-01,  6.54165074e-

In [166]:
_, new_key_env = jax.random.split(key_envs)
env_state = reset_fn(new_key_env)

In [167]:
env_state.info

{'cur_clip_id': Array(738, dtype=int32),
 'cur_frame': Array(28, dtype=int32),
 'sub_clip_frame': Array(0, dtype=int32, weak_type=True),
 'termination_error': Array(0.22899085, dtype=float32),
 'traj': Array([-0.16659114,  0.14546664,  0.10879885, -0.18231891,  0.1292104 ,
         0.10871847, -0.15377682,  0.12173592,  0.0793609 , -0.15377682,
         0.12173592,  0.0793609 , -0.15377682,  0.12173592,  0.0793609 ,
        -0.16687953,  0.14501932,  0.10862356, -0.18263307,  0.12906387,
         0.10865918, -0.15436172,  0.12095588,  0.07856713, -0.15436172,
         0.12095588,  0.07856713, -0.15436172,  0.12095588,  0.07856713,
        -0.16707502,  0.1445722 ,  0.10828973, -0.18278287,  0.12885769,
         0.10849562, -0.15487142,  0.12028071,  0.07756346, -0.15487142,
         0.12028071,  0.07756346, -0.15487142,  0.12028071,  0.07756346,
        -0.1669629 ,  0.14372286,  0.10828417, -0.18268488,  0.12831607,
         0.10858818, -0.1553124 ,  0.11863773,  0.07670501, -0.155312

In [168]:
step_fn = jax.jit(env.step)

In [169]:
random.normal(key, shape=(env.sys.nu,)).shape

(30,)

In [170]:
mu = 0
sigma = 0.3
action =  mu + sigma * random.normal(key, shape=(env.sys.nu,))
state = env_state.pipeline_state
# env.pipeline_step(state, action)
state.q.shape

(74,)

In [171]:
new_state = step_fn(env_state, action)

In [172]:
new_state.info

{'cur_clip_id': Array(738, dtype=int32),
 'cur_frame': Array(29, dtype=int32),
 'sub_clip_frame': Array(1, dtype=int32, weak_type=True),
 'termination_error': Array(0.00228991, dtype=float32),
 'traj': Array([-1.66879535e-01,  1.45019323e-01,  1.08623564e-01, -1.82633072e-01,
         1.29063874e-01,  1.08659178e-01, -1.54361725e-01,  1.20955884e-01,
         7.85671324e-02, -1.54361725e-01,  1.20955884e-01,  7.85671324e-02,
        -1.54361725e-01,  1.20955884e-01,  7.85671324e-02, -1.67075023e-01,
         1.44572198e-01,  1.08289734e-01, -1.82782874e-01,  1.28857687e-01,
         1.08495623e-01, -1.54871419e-01,  1.20280713e-01,  7.75634646e-02,
        -1.54871419e-01,  1.20280713e-01,  7.75634646e-02, -1.54871419e-01,
         1.20280713e-01,  7.75634646e-02, -1.66962907e-01,  1.43722862e-01,
         1.08284168e-01, -1.82684883e-01,  1.28316075e-01,  1.08588181e-01,
        -1.55312404e-01,  1.18637726e-01,  7.67050087e-02, -1.55312404e-01,
         1.18637726e-01,  7.67050087e-0

## `vmap` reset function for parralel environments

In [123]:
reset_fn = jax.jit(jax.vmap(env.reset))
step_fn = jax.jit(jax.vmap(env.step))

n_envs = 10
key_envs = jax.random.split(key, n_envs)
env_state = reset_fn(key_envs)

env_state.info

{'cur_clip_id': Array([485, 401, 632,  71, 651, 258, 284, 312, 366, 194], dtype=int32),
 'cur_frame': Array([208, 159, 155,  19, 204, 201,  91,  90, 119, 136], dtype=int32),
 'sub_clip_frame': Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32, weak_type=True),
 'termination_error': Array([0.24269372, 0.22189927, 0.11183977, 0.23087865, 0.2814368 ,
        0.34638757, 0.07063329, 0.23716456, 0.26817143, 0.34759504],      dtype=float32),
 'traj': Array([[-0.15032037,  0.01754861,  0.05537911, ..., -0.08783695,
         -0.09165404, -0.09165404],
        [ 0.07968964, -0.03570389,  0.05808613, ..., -0.20993605,
         -0.09104721, -0.09104721],
        [-0.17849676,  0.06983046,  0.04860817, ...,  0.38336328,
         -0.03742177, -0.03742177],
        ...,
        [ 0.07268548,  0.01696111,  0.09246162, ...,  0.13834521,
          0.01832652,  0.01832652],
        [ 0.20566644, -0.09198424,  0.15571444, ..., -0.22191952,
          0.01180918,  0.01180918],
        [-0.2095694 ,  0.1857

In [126]:
new_key_env = jax.random.split(key_envs[0], n_envs)
env_state = reset_fn(new_key_env)
env_state.info

{'cur_clip_id': Array([467,  73, 426, 756, 356, 688, 409, 791, 404, 261], dtype=int32),
 'cur_frame': Array([216, 201, 234,  24, 230,   3,  88,  39,  20, 102], dtype=int32),
 'sub_clip_frame': Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32, weak_type=True),
 'termination_error': Array([0.2852499 , 0.41327906, 0.13209683, 0.19281441, 0.29671246,
        0.05769813, 0.34846103, 0.2727661 , 0.35410887, 0.12264472],      dtype=float32),
 'traj': Array([[-0.18792216,  0.26528084,  0.154288  , ..., -0.04520485,
          0.01174304,  0.01174304],
        [-0.1295034 ,  0.02134102,  0.07011838, ..., -0.00082335,
          0.01036848,  0.01036848],
        [-0.00189915,  0.4155111 ,  0.04983118, ...,  0.48758662,
         -0.18402584, -0.18402584],
        ...,
        [ 0.12375263, -0.04087412,  0.05324259, ..., -0.5494366 ,
          0.00241748,  0.00241748],
        [-0.12025809,  0.05894759,  0.14897513, ..., -0.40574092,
          0.00125737,  0.00125737],
        [ 0.2347006 , -0.0536

## Reference to SingleClip

In [68]:
envs.register_environment("rodent_single", RodentTracking)

# Define the environment arguments
env_args = {
    "reference_clip": reference_clip,
    "end_eff_names": [
        "foot_L",
        "foot_R",
        "hand_L",
        "hand_R"
    ],
    "appendage_names": [
        "foot_L",
        "foot_R",
        "hand_L",
        "hand_R",
        "skull"
    ],
    "walker_body_names": [
        "torso",
        "pelvis",
        "upper_leg_L",
        "lower_leg_L",
        "foot_L",
        "upper_leg_R",
        "lower_leg_R",
        "foot_R",
        "skull",
        "jaw",
        "scapula_L",
        "upper_arm_L",
        "lower_arm_L",
        "finger_L",
        "scapula_R",
        "upper_arm_R",
        "lower_arm_R",
        "finger_R"
    ],
    "joint_names": [
        "vertebra_1_extend",
        "hip_L_supinate",
        "hip_L_abduct",
        "hip_L_extend",
        "knee_L",
        "ankle_L",
        "toe_L",
        "hip_R_supinate",
        "hip_R_abduct",
        "hip_R_extend",
        "knee_R",
        "ankle_R",
        "toe_R",
        "vertebra_C11_extend",
        "vertebra_cervical_1_bend",
        "vertebra_axis_twist",
        "atlas",
        "mandible",
        "scapula_L_supinate",
        "scapula_L_abduct",
        "scapula_L_extend",
        "shoulder_L",
        "shoulder_sup_L",
        "elbow_L",
        "wrist_L",
        "scapula_R_supinate",
        "scapula_R_abduct",
        "scapula_R_extend",
        "shoulder_R",
        "shoulder_sup_R",
        "elbow_R",
        "wrist_R",
        "finger_R"
    ],
    "center_of_mass": "torso",
    "mjcf_path": "./assets/rodent.xml",
    "scale_factor": 0.9,
    "solver": "cg",
    "iterations": 6,
    "ls_iterations": 6,
    "healthy_z_range": (0.2, 1.0),
    "reset_noise_scale": 0.1,
    "clip_length": 250,
    "sub_clip_length": 10,
    "ref_traj_length": 5,
    "termination_threshold": 5,
    "body_error_multiplier": 1.0,
}

env_s = envs.get_environment('rodent_single', **env_args)

### `vmap` of the environment not available yet!

In [94]:
key = random.PRNGKey(100)
reset_fn = jax.jit(env_s.reset)
step_fn = jax.jit(env_s.step)

n_envs = 5
_, key_envs= jax.random.split(key)
env_state = reset_fn(key_envs)

In [95]:
env_state.info

{'cur_frame': Array(36, dtype=int32),
 'sub_clip_frame': Array(0, dtype=int32, weak_type=True),
 'termination_error': Array(0.22856492, dtype=float32),
 'traj': Array([ 0.39154962,  0.16734575,  0.11736992,  0.3782915 ,  0.15460956,
         0.11611144,  0.4061749 ,  0.12981644,  0.09224587,  0.4061749 ,
         0.12981644,  0.09224587,  0.4061749 ,  0.12981644,  0.09224587,
         0.39091292,  0.16724986,  0.1150074 ,  0.37745553,  0.15453087,
         0.11410891,  0.40458795,  0.13184093,  0.08867101,  0.40458795,
         0.13184093,  0.08867101,  0.40458795,  0.13184093,  0.08867101,
         0.39045614,  0.16842628,  0.11219793,  0.3764755 ,  0.15536475,
         0.11146784,  0.40267816,  0.1338081 ,  0.08442029,  0.40267816,
         0.1338081 ,  0.08442029,  0.40267816,  0.1338081 ,  0.08442029,
         0.3898809 ,  0.17004865,  0.10895442,  0.37572348,  0.15662089,
         0.10837589,  0.4008316 ,  0.13651828,  0.07850949,  0.4008316 ,
         0.13651828,  0.07850949,  0.

In [96]:
step_fn = jax.jit(env_s.step)
mu = 0
sigma = 0.3
_, key = random.split(key)
action =  mu + sigma * random.normal(key, shape=(env_s.sys.nu,))
action.shape

(30,)

In [97]:
new_state = step_fn(env_state, action)