# Try DMC Tracking Env directly

In [1]:
import numpy as np
import PIL.ImageDraw

import tensorflow as tf
import tensorflow_probability as tfp
from acme import wrappers

from flybody.basic_rodent_2020 import (
    walk_imitation
)

from flybody.agents.utils_tf import TestPolicyWrapper
from flybody.utils import (
    display_video,
    rollout_and_render,
)

from flybody.fruitfly import rodent

from dm_control import composer
from dm_control.locomotion import arenas
from dm_control.locomotion import walkers
from dm_control.locomotion.walkers import rodent
from dm_control.locomotion.mocap import props
from dm_control.locomotion.tasks.reference_pose import tracking
from dm_control.locomotion.tasks.reference_pose import types
from dm_control.utils import io as resources
import os
import h5py

In [2]:
REF_WALK_RODENT = "/root/talmolab-smb/kaiwen/flybody/clips/all_snippets.h5"

In [None]:
# Define the path to your motion capture data file
file_name = REF_WALK_RODENT
current_directory = os.getcwd()
TEST_FILE_PATH = os.path.join(current_directory, file_name)

with h5py.File(TEST_FILE_PATH, 'r') as f:
    dataset_keys = tuple(f.keys())
    dataset = types.ClipCollection(ids=dataset_keys,)

# Set up the mocap tracking task
task = tracking.MultiClipMocapTracking(
    walker=rodent.Rat,
    arena=arenas.Floor(),
    ref_path=resources.GetResourceFilename(TEST_FILE_PATH),
    ref_steps=(1, 2, 3, 4, 5),
    min_steps=1,
    dataset=dataset,
    reward_type='comic',
    always_init_at_clip_start=True
)

# Initialize the environment
env = composer.Environment(task=task)
reset = env.reset()

# Instantiate Environment

In [3]:
ref_traj_path = "/root/talmolab-smb/kaiwen/flybody/clips/all_snippets.h5"
env = walk_imitation(ref_traj_path)

In [4]:
env.observation_spec()

OrderedDict([('reference_props_pos_global',
              Array(shape=(0,), dtype=dtype('float64'), name='reference_props_pos_global')),
             ('reference_props_quat_global',
              Array(shape=(0,), dtype=dtype('float64'), name='reference_props_quat_global')),
             ('walker/actuator_activation',
              Array(shape=(38,), dtype=dtype('float64'), name='walker/actuator_activation')),
             ('walker/appendages_pos',
              Array(shape=(15,), dtype=dtype('float64'), name='walker/appendages_pos')),
             ('walker/body_height',
              Array(shape=(), dtype=dtype('float64'), name='walker/body_height')),
             ('walker/end_effectors_pos',
              Array(shape=(12,), dtype=dtype('float64'), name='walker/end_effectors_pos')),
             ('walker/joints_pos',
              Array(shape=(30,), dtype=dtype('float64'), name='walker/joints_pos')),
             ('walker/joints_vel',
              Array(shape=(30,), dtype=dtype('floa

In [5]:
env.action_spec()

BoundedArray(shape=(38,), dtype=dtype('float64'), name='walker/lumbar_extend\twalker/lumbar_bend\twalker/lumbar_twist\twalker/cervical_extend\twalker/cervical_bend\twalker/cervical_twist\twalker/caudal_extend\twalker/caudal_bend\twalker/hip_L_supinate\twalker/hip_L_abduct\twalker/hip_L_extend\twalker/knee_L\twalker/ankle_L\twalker/toe_L\twalker/hip_R_supinate\twalker/hip_R_abduct\twalker/hip_R_extend\twalker/knee_R\twalker/ankle_R\twalker/toe_R\twalker/atlas\twalker/mandible\twalker/scapula_L_supinate\twalker/scapula_L_abduct\twalker/scapula_L_extend\twalker/shoulder_L\twalker/shoulder_sup_L\twalker/elbow_L\twalker/wrist_L\twalker/finger_L\twalker/scapula_R_supinate\twalker/scapula_R_abduct\twalker/scapula_R_extend\twalker/shoulder_R\twalker/shoulder_sup_R\twalker/elbow_R\twalker/wrist_R\twalker/finger_R', minimum=[-1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.
 -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1. -1.
 -1. -1.], maximum=[1. 1. 1.

In [7]:
n_actions = env.action_spec().shape[0]

def random_action_policy(observation):
    del observation  # Not used by dummy policy.
    random_action = np.random.uniform(-.5, .5, n_actions)
    return random_action

frames = []
rewards = []

timestep = env.reset()
for _ in range(100):
    
    action = random_action_policy(timestep.observation)
    timestep = env.step(action)
    rewards.append(timestep.reward)

    pixels = env.physics.render(camera_id=1)
    frames.append(pixels)

display_video(frames)