# Preprocessing the Rodent Data

In [1]:
import os
from pathlib import Path

import numpy as np
import PIL.ImageDraw
import tensorflow as tf
import tensorflow_probability as tfp
from acme import wrappers

from flybody.fly_envs import (
    flight_imitation,
    vision_guided_flight,
    walk_on_ball,
)
from flybody.basic_rodent_2020 import (
    walk_imitation
)
from flybody.tasks.task_utils import (
    get_random_policy,
    real2canonical,
)
from flybody.agents.utils_tf import TestPolicyWrapper
from flybody.utils import (
    display_video,
    rollout_and_render,
)
from flybody.tasks.trajectory_rodent import (
    read_h5_file,
    read_id,
    extract_feature
)
from flybody.fruitfly import rodent

In [2]:
# Root of the flybody checkout / data share. Set the FLYBODY_ROOT environment
# variable to run this notebook on another machine; the default reproduces the
# original absolute paths exactly.
FLYBODY_ROOT = Path(os.environ.get("FLYBODY_ROOT", "/root/talmolab-smb/kaiwen/flybody"))

# Fly walking reference dataset (inspected only to compare file layouts).
REF_WALK_FLY = str(FLYBODY_ROOT / "demos/data/walking-dataset_female-only_snippets-16252_trk-files-0-9.hdf5")
# Raw rodent clips, and the processed file that `extract_feature` writes.
REF_WALK_RODENT = str(FLYBODY_ROOT / "clips/all_snippets.h5")
OUT = str(FLYBODY_ROOT / "clips/processed_snippets.h5")

In [None]:
# Print the group/dataset layout of the fly walking reference file, as a
# format reference for the rodent clips inspected below.
read_h5_file(REF_WALK_FLY)

In [None]:
# NOTE(review): presumably lists the snippet ids stored in the fly reference
# file — TODO confirm against flybody.tasks.trajectory_rodent.read_id.
read_id(REF_WALK_FLY)

In [3]:
# Print the layout of the raw rodent clip file: each clip_N group carries
# attributes (action, dt, num_steps) plus per-walker datasets (see output).
read_h5_file(REF_WALK_RODENT)

Group: clip_0
clip_0:
    action: FastWalk
    dt: 0.02
    num_steps: 250
Group: clip_0/props
clip_0/props:
Group: clip_0/walkers
clip_0/walkers:
Group: clip_0/walkers/walker_0
clip_0/walkers/walker_0:
Dataset: clip_0/walkers/walker_0/angular_velocity
    shape: (3, 250)
    dtype: float64
Dataset: clip_0/walkers/walker_0/appendages
    shape: (15, 250)
    dtype: float64
Dataset: clip_0/walkers/walker_0/body_positions
    shape: (54, 250)
    dtype: float64
Dataset: clip_0/walkers/walker_0/body_quaternions
    shape: (72, 250)
    dtype: float64
Dataset: clip_0/walkers/walker_0/center_of_mass
    shape: (3, 250)
    dtype: float64
Dataset: clip_0/walkers/walker_0/end_effectors
    shape: (12, 250)
    dtype: float64
Dataset: clip_0/walkers/walker_0/joints
    shape: (67, 250)
    dtype: float64
Dataset: clip_0/walkers/walker_0/joints_velocity
    shape: (67, 250)
    dtype: float64
Dataset: clip_0/walkers/walker_0/markers
    shape: (0,)
    dtype: float64
Dataset: clip_0/walkers/wal

In [None]:
# Convert the raw rodent clips into the processed file at OUT.
# NOTE(review): exact features written are defined in
# flybody.tasks.trajectory_rodent.extract_feature — confirm there.
extract_feature(REF_WALK_RODENT, OUT)

In [None]:
# Print the layout of the processed file to check what extract_feature wrote.
read_h5_file(OUT)

# Joints and Body Locations

In [None]:
import mujoco
from dm_control import mjcf

# Joint names of the rodent model, in mocap order (67 entries). Order matters:
# it is the order used for the id lookups below.
# NOTE(review): the raw clips' `joints` dataset also has 67 rows — presumably
# the same ordering, TODO confirm.
_RAT_MOCAP_JOINTS = [
    # Spine: vertebra_1..6.
    'vertebra_1_extend', 'vertebra_2_bend', 'vertebra_3_twist',
    'vertebra_4_extend', 'vertebra_5_bend', 'vertebra_6_twist',
    # Left hind limb.
    'hip_L_supinate', 'hip_L_abduct', 'hip_L_extend',
    'knee_L', 'ankle_L', 'toe_L',
    # Right hind limb.
    'hip_R_supinate', 'hip_R_abduct', 'hip_R_extend',
    'knee_R', 'ankle_R', 'toe_R',
    # vertebra_C* chain (presumably caudal/tail — confirm against rodent.xml).
    'vertebra_C1_extend', 'vertebra_C1_bend',
    'vertebra_C2_extend', 'vertebra_C2_bend',
    'vertebra_C3_extend', 'vertebra_C3_bend',
    'vertebra_C4_extend', 'vertebra_C4_bend',
    'vertebra_C5_extend', 'vertebra_C5_bend',
    'vertebra_C6_extend', 'vertebra_C6_bend',
    'vertebra_C7_extend', 'vertebra_C9_bend',
    'vertebra_C11_extend', 'vertebra_C13_bend',
    'vertebra_C15_extend', 'vertebra_C17_bend',
    'vertebra_C19_extend', 'vertebra_C21_bend',
    'vertebra_C23_extend', 'vertebra_C25_bend',
    'vertebra_C27_extend', 'vertebra_C29_bend',
    # Neck, head and jaw.
    'vertebra_cervical_5_extend', 'vertebra_cervical_4_bend',
    'vertebra_cervical_3_twist', 'vertebra_cervical_2_extend',
    'vertebra_cervical_1_bend', 'vertebra_axis_twist',
    'vertebra_atlant_extend', 'atlas', 'mandible',
    # Left forelimb.
    'scapula_L_supinate', 'scapula_L_abduct', 'scapula_L_extend',
    'shoulder_L', 'shoulder_sup_L', 'elbow_L', 'wrist_L', 'finger_L',
    # Right forelimb.
    'scapula_R_supinate', 'scapula_R_abduct', 'scapula_R_extend',
    'shoulder_R', 'shoulder_sup_R', 'elbow_R', 'wrist_R', 'finger_R',
]

# Body names of the rodent model, in mocap order (18 entries).
# NOTE(review): 18 * 3 = 54 and 18 * 4 = 72 match the row counts of the raw
# clips' body_positions / body_quaternions datasets — presumably this
# ordering, TODO confirm.
_RAT_MOCAP_BODY = [
    "torso", "pelvis",
    "upper_leg_L", "lower_leg_L", "foot_L",
    "upper_leg_R", "lower_leg_R", "foot_R",
    "skull", "jaw",
    "scapula_L", "upper_arm_L", "lower_arm_L", "finger_L",
    "scapula_R", "upper_arm_R", "lower_arm_R", "finger_R",
]

In [None]:
# Compile the rodent MJCF directly into a raw MuJoCo model.
mj_model = mujoco.MjModel.from_xml_path("/root/talmolab-smb/kaiwen/flybody/flybody/fruitfly/assets_rodent/rodent.xml")
# Model id of each mocap body (mj_name2id gives -1 for a name the model lacks).
np.array([mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, body_name)
          for body_name in _RAT_MOCAP_BODY])

In [None]:
# Model id of each mocap joint (mj_name2id gives -1 for a name the model lacks).
np.array([mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_JOINT, joint_name)
          for joint_name in _RAT_MOCAP_JOINTS])

In [None]:
# List every joint name of the compiled rodent model, in id order. The loop is
# bounded by the model's real joint count (`mj_model.njnt`) rather than a fixed
# 100, so the list is neither truncated nor padded with None for ids past the
# last joint. (An unnamed joint still shows up as None.)
lst = [mujoco.mj_id2name(mj_model, mujoco.mjtObj.mjOBJ_JOINT, joint_id)
       for joint_id in range(mj_model.njnt)]
lst

In [None]:
# List every site name of the compiled rodent model, in id order, bounded by
# the model's real site count (`mj_model.nsite`) instead of a fixed 100 that
# would pad the list with None or cut it short.
lst = [mujoco.mj_id2name(mj_model, mujoco.mjtObj.mjOBJ_SITE, site_id)
       for site_id in range(mj_model.nsite)]
lst

In [None]:
from dm_control import mjcf
xml_path = "/root/talmolab-smb/kaiwen/flybody/flybody/fruitfly/assets_rodent/rodent.xml"
xml = mjcf.from_path(xml_path)
# Free joints of the rodent MJCF. `find_all`'s second positional parameter is
# `immediate_children_only`, not a joint type, so `find_all("joint", "free")`
# did not select free joints at all. Filter explicitly instead:
# `<freejoint>` elements carry no `type` attribute (hence getattr), while
# `<joint type="free">` does.
np.array([joint for joint in xml.find_all("joint")
          if joint.tag == "freejoint" or getattr(joint, "type", None) == "free"])

In [None]:
# Switch to the fruit-fly MJCF for comparison. Note this rebinds `xml_path`
# and `xml`, replacing the rodent model loaded in the cell above.
xml_path = "/root/talmolab-smb/kaiwen/flybody/flybody/fruitfly/assets/fruitfly.xml"
xml = mjcf.from_path(xml_path)
# Number of joint-namespace elements in the fly model, shown as a shape tuple.
np.shape(xml.find_all("joint"))

In [None]:
# Number of sites in the currently loaded MJCF (`xml`), shown as a shape tuple.
np.shape(xml.find_all("site"))

In [None]:
# Number of bodies in the currently loaded MJCF (`xml`), shown as a shape tuple.
np.shape(xml.find_all("body"))

# Instantiate the Environment

In [4]:
# Build the rodent walking-imitation environment from the raw clip file.
# Reuse the path constant from the configuration cell instead of repeating
# the literal, so changing the data location only needs one edit.
ref_traj_path = REF_WALK_RODENT
env = walk_imitation(ref_traj_path)

In [5]:
# Show the observation spec: one named array (with shape and dtype) per key.
env.observation_spec()

OrderedDict([('reference_props_pos_global',
              Array(shape=(0,), dtype=dtype('float64'), name='reference_props_pos_global')),
             ('reference_props_quat_global',
              Array(shape=(0,), dtype=dtype('float64'), name='reference_props_quat_global')),
             ('walker/actuator_activation',
              Array(shape=(38,), dtype=dtype('float64'), name='walker/actuator_activation')),
             ('walker/appendages_pos',
              Array(shape=(15,), dtype=dtype('float64'), name='walker/appendages_pos')),
             ('walker/body_height',
              Array(shape=(), dtype=dtype('float64'), name='walker/body_height')),
             ('walker/end_effectors_pos',
              Array(shape=(12,), dtype=dtype('float64'), name='walker/end_effectors_pos')),
             ('walker/joints_pos',
              Array(shape=(30,), dtype=dtype('float64'), name='walker/joints_pos')),
             ('walker/joints_vel',
              Array(shape=(30,), dtype=dtype('floa

In [6]:
# Show the action spec: a bounded float64 vector (shape (38,), range [-1, 1]),
# whose name lists the actuators in order.
env.action_spec()

BoundedArray(shape=(38,), dtype=dtype('float64'), name='walker/lumbar_extend\twalker/lumbar_bend\twalker/lumbar_twist\twalker/cervical_extend\twalker/cervical_bend\twalker/cervical_twist\twalker/caudal_extend\twalker/caudal_bend\twalker/hip_L_supinate\twalker/hip_L_abduct\twalker/hip_L_extend\twalker/knee_L\twalker/ankle_L\twalker/toe_L\twalker/hip_R_supinate\twalker/hip_R_abduct\twalker/hip_R_extend\twalker/knee_R\twalker/ankle_R\twalker/toe_R\twalker/atlas\twalker/mandible\twalker/scapula_L_supinate\twalker/scapula_L_abduct\twalker/scapula_L_extend\twalker/shoulder_L\twalker/shoulder_sup_L\twalker/elbow_L\twalker/wrist_L\twalker/finger_L\twalker/scapula_R_supinate\twalker/scapula_R_abduct\twalker/scapula_R_extend\twalker/shoulder_R\twalker/shoulder_sup_R\twalker/elbow_R\twalker/wrist_R\twalker/finger_R', minimum=[-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.
 -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.
 -1. -1.], maximum=[1. 1. 1.

In [7]:
# Start an episode and keep the first timestep for interactive inspection.
# The rollout cell below calls env.reset() again itself.
timestep = env.reset()

In [8]:
# Size the action vector from the environment instead of hard-coding it
# (the spec printed above has shape (38,)).
n_actions = env.action_spec().shape[0]

# Seeded generator so the random rollout, and therefore the video, is reproducible.
policy_rng = np.random.default_rng(0)


def random_action_policy(observation):
    """Dummy policy: sample a uniform random action in [-0.5, 0.5) per actuator.

    Args:
        observation: Ignored; accepted only so this matches a policy signature.

    Returns:
        A float64 array of shape (n_actions,).
    """
    del observation  # Not used by dummy policy.
    return policy_rng.uniform(-.5, .5, n_actions)


frames = []
rewards = []

timestep = env.reset()
for _ in range(200):
    action = random_action_policy(timestep.observation)
    timestep = env.step(action)
    rewards.append(timestep.reward)
    frames.append(env.physics.render(camera_id=1))
    # Stop when the episode ends. Stepping a finished dm_env episode triggers
    # an automatic reset, and that FIRST timestep has reward None, which
    # would pollute `rewards`.
    if timestep.last():
        break

display_video(frames)