# Preprocessing the Rodent Data

In [1]:
import os

import numpy as np
import PIL.ImageDraw

import tensorflow as tf
import tensorflow_probability as tfp
from acme import wrappers

from flybody.fly_envs import (
    flight_imitation,
    vision_guided_flight,
    walk_on_ball,
)
from flybody.basic_rodent_2020 import (
    walk_imitation
)

from flybody.tasks.task_utils import (
    get_random_policy,
    real2canonical,
)
from flybody.agents.utils_tf import TestPolicyWrapper
from flybody.utils import (
    display_video,
    rollout_and_render,
)
from flybody.tasks.trajectory_rodent import (
    read_h5_file,
    read_id,
    extract_feature
)

In [2]:
# Data locations. Set the FLYBODY_ROOT environment variable to point at a
# different checkout/mount; the default is the original shared-drive location,
# so the resulting paths are unchanged when the variable is not set.
FLYBODY_ROOT = os.environ.get("FLYBODY_ROOT", "/root/talmolab-smb/kaiwen/flybody")

# Fly walking reference snippets (used only for comparison with the rodent data).
REF_WALK_FLY = os.path.join(
    FLYBODY_ROOT,
    "demos/data/walking-dataset_female-only_snippets-16252_trk-files-0-9.hdf5",
)
# Raw rodent snippets, input to extract_feature.
REF_WALK_RODENT = os.path.join(FLYBODY_ROOT, "clips/all_snippets.h5")
# Processed rodent snippets written by extract_feature.
OUT = os.path.join(FLYBODY_ROOT, "clips/processed_snippets.h5")

In [None]:
# Print the group/dataset layout of the fly reference file, for comparison
# with the rodent data below.
read_h5_file(REF_WALK_FLY)

In [None]:
# Inspect the ids stored in the fly reference file.
# NOTE(review): presumably lists trajectory/snippet ids — confirm what read_id reports.
read_id(REF_WALK_FLY)

In [None]:
# Print the group/dataset layout of the raw rodent snippets before preprocessing.
read_h5_file(REF_WALK_RODENT)

In [3]:
# Convert the raw rodent snippets into the processed trajectory file at OUT;
# the next cell prints the resulting layout.
extract_feature(REF_WALK_RODENT, OUT)

In [4]:
# Print the processed file's layout: id2name metadata, timestep_seconds, and
# per-trajectory qpos/qvel/root2site/root_qpos/root_qvel datasets.
read_h5_file(OUT)

Group: id2name
id2name:
Dataset: id2name/joints
    shape: (67,)
    dtype: |S26
Dataset: id2name/qpos
    shape: (67,)
    dtype: |S26
Dataset: id2name/sites
    shape: (18,)
    dtype: |S11
Dataset: timestep_seconds
    shape: ()
    dtype: float64
Group: trajectories
trajectories:
Group: trajectories/0
trajectories/0:
Dataset: trajectories/0/qpos
    shape: (250, 67)
    dtype: float32
Dataset: trajectories/0/qvel
    shape: (250, 67)
    dtype: float32
Dataset: trajectories/0/root2site
    shape: (250, 6, 3)
    dtype: float32
Dataset: trajectories/0/root_qpos
    shape: (250, 3)
    dtype: float32
Dataset: trajectories/0/root_qvel
    shape: (250, 3)
    dtype: float32
Group: trajectories/1
trajectories/1:
Dataset: trajectories/1/qpos
    shape: (250, 67)
    dtype: float32
Dataset: trajectories/1/qvel
    shape: (250, 67)
    dtype: float32
Dataset: trajectories/1/root2site
    shape: (250, 6, 3)
    dtype: float32
Dataset: trajectories/1/root_qpos
    shape: (250, 3)
    dtype: 

# Joints and Body Locations

In [1]:
import mujoco
from dm_control import mjcf

_RAT_MOCAP_JOINTS = [
    'vertebra_1_extend', 'vertebra_2_bend', 'vertebra_3_twist',
    'vertebra_4_extend', 'vertebra_5_bend', 'vertebra_6_twist',
    'hip_L_supinate', 'hip_L_abduct', 'hip_L_extend', 'knee_L', 'ankle_L',
    'toe_L', 'hip_R_supinate', 'hip_R_abduct', 'hip_R_extend', 'knee_R',
    'ankle_R', 'toe_R', 'vertebra_C1_extend', 'vertebra_C1_bend',
    'vertebra_C2_extend', 'vertebra_C2_bend', 'vertebra_C3_extend',
    'vertebra_C3_bend', 'vertebra_C4_extend', 'vertebra_C4_bend',
    'vertebra_C5_extend', 'vertebra_C5_bend', 'vertebra_C6_extend',
    'vertebra_C6_bend', 'vertebra_C7_extend', 'vertebra_C9_bend',
    'vertebra_C11_extend', 'vertebra_C13_bend', 'vertebra_C15_extend',
    'vertebra_C17_bend', 'vertebra_C19_extend', 'vertebra_C21_bend',
    'vertebra_C23_extend', 'vertebra_C25_bend', 'vertebra_C27_extend',
    'vertebra_C29_bend', 'vertebra_cervical_5_extend',
    'vertebra_cervical_4_bend', 'vertebra_cervical_3_twist',
    'vertebra_cervical_2_extend', 'vertebra_cervical_1_bend',
    'vertebra_axis_twist', 'vertebra_atlant_extend', 'atlas', 'mandible',
    'scapula_L_supinate', 'scapula_L_abduct', 'scapula_L_extend', 'shoulder_L',
    'shoulder_sup_L', 'elbow_L', 'wrist_L', 'finger_L', 'scapula_R_supinate',
    'scapula_R_abduct', 'scapula_R_extend', 'shoulder_R', 'shoulder_sup_R',
    'elbow_R', 'wrist_R', 'finger_R'
]

_RAT_MOCAP_BODY = [
    "torso","pelvis","upper_leg_L",
    "lower_leg_L","foot_L","upper_leg_R",
    "lower_leg_R","foot_R","skull","jaw",
    "scapula_L","upper_arm_L","lower_arm_L",
    "finger_L","scapula_R","upper_arm_R","lower_arm_R","finger_R"]

In [5]:
# Load the rodent MJCF model and resolve each mocap body name to its MuJoCo body id.
mj_model = mujoco.MjModel.from_xml_path("/root/talmolab-smb/kaiwen/flybody/flybody/fruitfly/assets_rodent/rodent.xml")
body_ids = np.array([
    mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, name)
    for name in _RAT_MOCAP_BODY
])
# mj_name2id returns -1 for an unknown name instead of raising, so a typo or a
# model mismatch would otherwise slip through as a bogus id.
missing_bodies = [name for name, idx in zip(_RAT_MOCAP_BODY, body_ids) if idx < 0]
assert not missing_bodies, f"bodies not found in model: {missing_bodies}"
body_ids

array([ 1,  8,  9, 10, 11, 13, 14, 15, 54, 55, 56, 57, 58, 60, 61, 62, 63,
       65])

In [7]:
# Resolve each mocap joint name to its MuJoCo joint id.
joint_ids = np.array([
    mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_JOINT, name)
    for name in _RAT_MOCAP_JOINTS
])
# mj_name2id returns -1 for an unknown name instead of raising; fail loudly.
missing_joints = [name for name, idx in zip(_RAT_MOCAP_JOINTS, joint_ids) if idx < 0]
assert not missing_joints, f"joints not found in model: {missing_joints}"
joint_ids

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66])

In [8]:
# List every joint name in the model, in joint-id order. Iterate over
# mj_model.njnt rather than a fixed count: ids past the last joint have no
# name, and mj_id2name returns None for them, padding the list with Nones.
lst = [
    mujoco.mj_id2name(mj_model, mujoco.mjtObj.mjOBJ_JOINT, joint_id)
    for joint_id in range(mj_model.njnt)
]
lst

['vertebra_1_extend',
 'vertebra_2_bend',
 'vertebra_3_twist',
 'vertebra_4_extend',
 'vertebra_5_bend',
 'vertebra_6_twist',
 'hip_L_supinate',
 'hip_L_abduct',
 'hip_L_extend',
 'knee_L',
 'ankle_L',
 'toe_L',
 'hip_R_supinate',
 'hip_R_abduct',
 'hip_R_extend',
 'knee_R',
 'ankle_R',
 'toe_R',
 'vertebra_C1_extend',
 'vertebra_C1_bend',
 'vertebra_C2_extend',
 'vertebra_C2_bend',
 'vertebra_C3_extend',
 'vertebra_C3_bend',
 'vertebra_C4_extend',
 'vertebra_C4_bend',
 'vertebra_C5_extend',
 'vertebra_C5_bend',
 'vertebra_C6_extend',
 'vertebra_C6_bend',
 'vertebra_C7_extend',
 'vertebra_C9_bend',
 'vertebra_C11_extend',
 'vertebra_C13_bend',
 'vertebra_C15_extend',
 'vertebra_C17_bend',
 'vertebra_C19_extend',
 'vertebra_C21_bend',
 'vertebra_C23_extend',
 'vertebra_C25_bend',
 'vertebra_C27_extend',
 'vertebra_C29_bend',
 'vertebra_cervical_5_extend',
 'vertebra_cervical_4_bend',
 'vertebra_cervical_3_twist',
 'vertebra_cervical_2_extend',
 'vertebra_cervical_1_bend',
 'vertebra_axi

# Instantiate the Environment

In [3]:
# Build the rodent walking-imitation task from the processed trajectories that
# extract_feature(REF_WALK_RODENT, OUT) wrote in the preprocessing section.
# Reuse OUT instead of repeating the path literal so the two cannot drift apart.
ref_traj_path = OUT
env = walk_imitation(ref_traj_path)

In [4]:
# Mapping of observation name -> array spec (shape, dtype, bounds).
obs_spec = env.observation_spec()
obs_spec

OrderedDict([('walker/actuator_activation',
              Array(shape=(30,), dtype=dtype('float64'), name='walker/actuator_activation')),
             ('walker/appendages_pos',
              Array(shape=(15,), dtype=dtype('float64'), name='walker/appendages_pos')),
             ('walker/body_height',
              Array(shape=(), dtype=dtype('float64'), name='walker/body_height')),
             ('walker/egocentric_camera',
              BoundedArray(shape=(64, 64, 3), dtype=dtype('uint8'), name='walker/egocentric_camera', minimum=0, maximum=255)),
             ('walker/end_effectors_pos',
              Array(shape=(12,), dtype=dtype('float64'), name='walker/end_effectors_pos')),
             ('walker/joints_pos',
              Array(shape=(30,), dtype=dtype('float64'), name='walker/joints_pos')),
             ('walker/joints_vel',
              Array(shape=(30,), dtype=dtype('float64'), name='walker/joints_vel')),
             ('walker/tendons_pos',
              Array(shape=(0,), dtyp

In [6]:
# Start an episode of the imitation task.
# FIXME: on the last run this raised
#   KeyError: "Unable to synchronously open object (object 'joint_quat' doesn't exist)"
# The processed file printed by read_h5_file(OUT) only lists qpos, qvel,
# root2site, root_qpos and root_qvel per trajectory, so the task's trajectory
# loader presumably expects a 'joint_quat' dataset that extract_feature does
# not write — confirm, then either add it in preprocessing or stop reading it.
timestep = env.reset()

KeyError: "Unable to synchronously open object (object 'joint_quat' doesn't exist)"

In [None]:
# Smoke-test the task with a uniform-random policy and render the rollout.
# NOTE: env.reset() above last raised KeyError 'joint_quat'; this cell cannot
# run until that is resolved.

# Derive the action size from the environment instead of hardcoding it; an
# action of the wrong length is rejected by env.step().
n_actions = env.action_spec().shape[0]
# Fixed seed so the random rollout (and its video) is reproducible.
_policy_rng = np.random.default_rng(0)


def random_action_policy(observation):
    """Return a random action, uniform in [-0.5, 0.5) in every dimension.

    Args:
      observation: the current timestep observation; ignored by this dummy policy.

    Returns:
      A float array of length n_actions.
    """
    del observation  # Not used by dummy policy.
    return _policy_rng.uniform(-.5, .5, n_actions)


frames = []
rewards = []

timestep = env.reset()
for _ in range(200):
    action = random_action_policy(timestep.observation)
    timestep = env.step(action)
    rewards.append(timestep.reward)

    pixels = env.physics.render(camera_id=1)
    frames.append(pixels)

    # Stop at the end of the episode: per the dm_env contract, stepping after a
    # LAST timestep starts a new episode, which would splice a second episode
    # into the same video.
    if timestep.last():
        break

display_video(frames)