In [1]:
%load_ext autoreload
%autoreload 2
import jax
from jax import jit, vmap

import mujoco 

from mujoco import mjx

from dm_control import mjcf
from dm_control.locomotion.walkers import rescale

import pickle

from preprocessing.mjx_preprocess import process_clip

# setup environment and stac data for preprocessing
scale_factor = 0.9
stac_path = "./clips/all_snips.p"

with open(stac_path, "rb") as file:
        d = pickle.load(file)        
        data_qpos = d["qpos"]
        
# Load rodent mjcf and rescale, then get the mj_model from that.
# TODO: make this all work in mjx? james cotton did rescaling with mjx model:
# https://github.com/peabody124/BodyModels/blob/f6ef1be5c5d4b7e51028adfc51125e510c13bcc2/body_models/biomechanics_mjx/forward_kinematics.py#L92
root = mjcf.from_path("./assets/rodent.xml")
rescale.rescale_subtree(
    root,
    scale_factor,
    scale_factor,
)
mj_model = mjcf.Physics.from_mjcf_model(root).model.ptr
mj_data = mujoco.MjData(mj_model)

    # Place into GPU
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

2024-07-01 21:03:00.221876: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
  return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))


In [2]:
clip_length = 250
# Split clip like this if you want to run just once
start_step = 0
first_clip_qpos = data_qpos[start_step : start_step + clip_length]

# jit the process_clip function
jit_process_clip = jax.jit(process_clip)

In [3]:
# Reshape qposes to have the batch dimension and vmap the jitted function
all_clips_qpos = data_qpos.reshape((-1, clip_length, mjx_model.nq))
vmap_jit_process_clip = vmap(jit_process_clip, in_axes=(0, None, None))
all_clips_qpos.shape

(842, 250, 74)

In [4]:
all_clips = vmap_jit_process_clip(all_clips_qpos, mjx_model, mjx_data) 

In [8]:
all_clips.position.shape

(842, 250, 3)

### saving and loading (wip)

In [5]:
import h5py

def save_reference_clip_to_h5(filename, reference_clip):
    """
    Save the contents of a ReferenceClip object to an .h5 file.

    Args:
        filename (str): The name of the .h5 file to save to.
        reference_clip (ReferenceClip): The ReferenceClip object to save.
    """
    with h5py.File(filename, 'w') as hf:
        for attr, value in reference_clip.__dict__.items():
            if value is not None:
                # Create a group for each batch
                for batch_idx in range(value.shape[0]):
                    #TODO: instead of batch_x, save as the name given by d["snips_order"]
                    # and save the order as its own thing at the top level
                    group_name = f"{attr}/batch_{batch_idx}"
                    hf.create_dataset(group_name, data=value[batch_idx])

In [6]:
filename = "clips/test_all_clips.h5"
save_reference_clip_to_h5(filename, all_clips)

In [7]:
from preprocessing.mjx_preprocess import ReferenceClip
from jax import numpy as jp
def load_reference_clip_from_h5(filename):
    """
    Load the contents of an .h5 file into a ReferenceClip object.

    Args:
        filename (str): The name of the .h5 file to load from.

    Returns:
        ReferenceClip: The reconstructed ReferenceClip object.
    """
    with h5py.File(filename, 'r') as hf:
        clip = ReferenceClip()
        for attr in clip.__dict__.keys():
            batch_data = []
            batch_idx = 0
            while f"{attr}/batch_{batch_idx}" in hf:
                batch_data.append(hf[f"{attr}/batch_{batch_idx}"][:])
                batch_idx += 1
            if batch_data:
                setattr(clip, attr, jp.stack(batch_data))
        return clip

In [13]:
from dataclasses import fields

for field in fields(all_clips):
    print(field.name, getattr(all_clips, field.name).shape)

position (842, 250, 3)
quaternion (842, 250, 4)
joints (842, 250, 67)
body_positions (842, 250, 66, 3)
velocity (842, 250, 3)
joints_velocity (842, 250, 67)
angular_velocity (842, 250, 3)
body_quaternions (842, 250, 66, 4)


In [30]:
def f(x):
    if len(x.shape) != 1:
        return jax.lax.dynamic_slice_in_dim(
                x,
                0,
                1,
                )
    return jp.array([])

ref_traj = jax.tree_util.tree_map(f, all_clips)
ref_traj.position.shape

(1, 250, 3)

In [9]:
loaded_all_clips = load_reference_clip_from_h5(filename)

FrozenInstanceError: cannot assign to field 'position'