In [2]:
%load_ext autoreload
%autoreload 2
import jax
from jax import jit, vmap

import mujoco 

from mujoco import mjx

from dm_control import mjcf
from dm_control.locomotion.walkers import rescale

import pickle

from preprocessing.mjx_preprocess import process_clip, save_reference_clip_to_h5, ReferenceClip, load_reference_clip_from_h5
from jax import numpy as jp

# setup environment and stac data for preprocessing
scale_factor = 0.9
stac_path = "./clips/all_snips.p"

with open(stac_path, "rb") as file:
        d = pickle.load(file)        
        data_qpos = d["qpos"]
        
# Load rodent mjcf and rescale, then get the mj_model from that.
# TODO: make this all work in mjx? james cotton did rescaling with mjx model:
# https://github.com/peabody124/BodyModels/blob/f6ef1be5c5d4b7e51028adfc51125e510c13bcc2/body_models/biomechanics_mjx/forward_kinematics.py#L92
root = mjcf.from_path("./assets/rodent.xml")
rescale.rescale_subtree(
    root,
    scale_factor,
    scale_factor,
)
mj_model = mjcf.Physics.from_mjcf_model(root).model.ptr
mj_data = mujoco.MjData(mj_model)

    # Place into GPU
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


In [3]:
clip_length = 250
# Split clip like this if you want to run just once
start_step = 0
first_clip_qpos = data_qpos[start_step : start_step + clip_length]

# jit the process_clip function
jit_process_clip = jax.jit(process_clip)

In [4]:
# Reshape qposes to have the batch dimension and vmap the jitted function
all_clips_qpos = data_qpos.reshape((-1, clip_length, mjx_model.nq))
vmap_jit_process_clip = vmap(jit_process_clip, in_axes=(0, None, None))
all_clips_qpos.shape

(842, 250, 74)

In [5]:
all_clips = vmap_jit_process_clip(all_clips_qpos, mjx_model, mjx_data) 

In [6]:
all_clips.position.shape

(842, 250, 3)

### saving and loading (wip)

In [21]:
ref_path = 'clips/all_snips.p'
with open(ref_path, "rb") as f:
            all_traj = pickle.load(f)

names = [traj.split("/")[-1].split(".")[0] for traj in all_traj["snips_order"]]

In [22]:
names

['FastWalk_171',
 'LGroom_158',
 'FastWalk_169',
 'FastWalk_44',
 'RGroom_45',
 'FastWalk_58',
 'Rear_155',
 'RGroom_16',
 'Rear_82',
 'Walk_145',
 'Rear_175',
 'RGroom_164',
 'Walk_115',
 'Walk_98',
 'Rear_104',
 'Rear_166',
 'Walk_174',
 'Rear_158',
 'LGroom_134',
 'Rear_5',
 'Rear_135',
 'FastWalk_87',
 'RGroom_145',
 'FaceGroom_113',
 'Rear_26',
 'LGroom_121',
 'FaceGroom_31',
 'RGroom_17',
 'FaceGroom_126',
 'Walk_118',
 'RGroom_57',
 'RGroom_121',
 'Rear_103',
 'FaceGroom_13',
 'FaceGroom_8',
 'FaceGroom_74',
 'Walk_19',
 'RGroom_36',
 'FastWalk_178',
 'FastWalk_167',
 'Walk_191',
 'Rear_75',
 'LGroom_177',
 'Walk_113',
 'FaceGroom_91',
 'Walk_2',
 'FastWalk_129',
 'RGroom_147',
 'Walk_128',
 'FastWalk_114',
 'FaceGroom_168',
 'LGroom_193',
 'FastWalk_134',
 'FastWalk_5',
 'Rear_122',
 'FastWalk_35',
 'FastWalk_29',
 'Walk_116',
 'RGroom_73',
 'Walk_10',
 'FastWalk_6',
 'FastWalk_11',
 'FastWalk_88',
 'Rear_143',
 'Rear_128',
 'Rear_189',
 'Rear_149',
 'LGroom_87',
 'RGroom_31',


In [23]:
filename = "clips/test_all_clips.h5"
save_reference_clip_to_h5(filename, names, all_clips)

In [24]:
from dataclasses import fields

for field in fields(all_clips):
    print(field.name, getattr(all_clips, field.name).shape)

position (842, 250, 3)
quaternion (842, 250, 4)
joints (842, 250, 67)
body_positions (842, 250, 66, 3)
velocity (842, 250, 3)
joints_velocity (842, 250, 67)
angular_velocity (842, 250, 3)
body_quaternions (842, 250, 66, 4)


In [25]:
def f(x):
    if len(x.shape) != 1:
        return jax.lax.dynamic_slice_in_dim(
                x,
                0,
                1,
                )
    return jp.array([])

ref_traj = jax.tree_util.tree_map(f, all_clips)
ref_traj.position.shape

(1, 250, 3)

In [27]:
loaded_all_clips = load_reference_clip_from_h5(filename, names)