# Test Environment
Need to ensure that control steps in data is the same as control steps in training

In [1]:
import sys
from pathlib import Path

main_path = Path().resolve().parent
if str(main_path) not in sys.path:
    sys.path.append(str(main_path))

import os

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"
os.environ["MUJOCO_GL"] = "egl"
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_triton_softmax_fusion=true --xla_gpu_triton_gemm_any=True "
)
os.environ["PYOPENGL_PLATFORM"] = "egl"

from absl import flags
import hydra
from omegaconf import DictConfig, OmegaConf
import uuid

from pathlib import Path

import functools
import jax
from typing import Dict
import wandb
import imageio
from brax import envs
from dm_control import mjcf as mjcf_dm
from dm_control.locomotion.walkers import rescale

import track_mjx.agent.custom_ppo as ppo
from track_mjx.agent import custom_ppo
from brax.io import model
import numpy as np
import pickle
import warnings
from jax import numpy as jp

from track_mjx.environment.task.multi_clip_tracking import MultiClipTracking
from track_mjx.environment.task.single_clip_tracking import SingleClipTracking
from track_mjx.io.preprocess.mjx_preprocess import process_clip_to_train
from track_mjx.io import preprocess as preprocessing  # the pickle file needs it
from track_mjx.environment import custom_wrappers
from track_mjx.agent import custom_ppo_networks
from track_mjx.agent.logging import setup_training_logging

from track_mjx.environment.walker.rodent import Rodent
from track_mjx.environment.walker.fly import Fly

from track_mjx.environment.task.reward import RewardConfig

FLAGS = flags.FLAGS
warnings.filterwarnings("ignore", category=DeprecationWarning)



In [2]:
import yaml

main_path = Path().resolve().parent
print(main_path)
if str(main_path) not in sys.path:
    sys.path.append(str(main_path))
    
config_path = "track_mjx/config/fly-mc-intention.yaml"
data_path = main_path / "data/FlyReferenceClip.p"

with open(main_path / config_path, "r") as f:
    cfg = yaml.safe_load(f)
    
env_args = cfg["env_config"]["env_args"]
env_rewards = cfg["env_config"]["reward_weights"]
# train_cfg = cfg["train_setup"]["train_config"]
walker_cfg = cfg["walker_config"]
# traj_cfg = cfg["reference_config"]
walker_type = cfg["walker_type"]

try:
    n_devices = jax.device_count(backend="gpu")
    print(f"Using {n_devices} GPUs")
except:
    n_devices = 1
    print("Not using GPUs")

envs.register_environment("rodent_single_clip", SingleClipTracking)
envs.register_environment("rodent_multi_clip", MultiClipTracking)
envs.register_environment("fly_multi_clip", MultiClipTracking)

sys.modules["preprocessing"] = preprocessing
print(f"Loading data: {data_path}")
with open(data_path, "rb") as file:
    reference_clip = pickle.load(file)

walker_map = {
    "rodent": Rodent,
    "fly": Fly,
}
walker_class = walker_map[walker_type]
walker = walker_class(**walker_cfg)

reward_config = RewardConfig(
    too_far_dist = env_rewards["too_far_dist"],
    bad_pose_dist = env_rewards["bad_pose_dist"],
    bad_quat_dist = env_rewards["bad_quat_dist"],
    ctrl_cost_weight = env_rewards["ctrl_cost_weight"],
    ctrl_diff_cost_weight = env_rewards["ctrl_diff_cost_weight"],
    pos_reward_weight = env_rewards["pos_reward_weight"],
    quat_reward_weight = env_rewards["quat_reward_weight"],
    joint_reward_weight = env_rewards["joint_reward_weight"],
    angvel_reward_weight = env_rewards["angvel_reward_weight"],
    bodypos_reward_weight = env_rewards["bodypos_reward_weight"],
    endeff_reward_weight = env_rewards["endeff_reward_weight"],
    healthy_z_range = env_rewards["healthy_z_range"],
    pos_reward_exp_scale = env_rewards["pos_reward_exp_scale"],
    quat_reward_exp_scale = env_rewards["quat_reward_exp_scale"],
    joint_reward_exp_scale = env_rewards["joint_reward_exp_scale"],
    angvel_reward_exp_scale = env_rewards["angvel_reward_exp_scale"],
    bodypos_reward_exp_scale = env_rewards["bodypos_reward_exp_scale"],
    endeff_reward_exp_scale = env_rewards["endeff_reward_exp_scale"],
    penalty_pos_distance_scale = jp.array(env_rewards["penalty_pos_distance_scale"]),
)

env = envs.get_environment(
    env_name=cfg["env_config"]["env_name"],
    reference_clip=reference_clip,
    walker=walker,
    reward_config=reward_config,
    **env_args,
)

print("Environment created successfully!")

/root/vast/kaiwen/track-mjx
Using 1 GPUs
Loading data: /root/vast/kaiwen/track-mjx/data/FlyReferenceClip.p
self._steps_for_cur_frame: 2.0
Environment created successfully!


In [3]:
rng_key = jax.random.PRNGKey(0)
state = env.reset(rng_key)
print("Initial observation:", state.obs)

Initial observation: [ 0.00230872 -0.00026385 -0.00083562 ...  0.0002936  -0.0001037
 -0.00062133]


In [13]:
for i in range(3):
    rng_key = jax.random.PRNGKey(i)
    action = jax.random.uniform(rng_key, shape=(env.action_size,), minval=-1.0, maxval=1.0)
    print("Random action:", action)
    state = env.step(state, action)

    print("Next observation:", state.obs)
    print("Reward:", state.reward)
    print("Done:", state.done)
    print("Current info is:", state.info)

Random action: [-0.15536189  0.20095396  0.43345404  0.34053254 -0.60310864  0.24935126
  0.48199916  0.02686977 -0.6259172  -0.20960021 -0.88146996 -0.47657156
 -0.3413067   0.63707423 -0.01317739  0.55167127  0.7239373  -0.6022897
  0.86403346  0.06278658 -0.77842927  0.6587422   0.16663074  0.83435416
 -0.2646935   0.9968438  -0.6801503  -0.39389205  0.229635   -0.1140871
  0.8935478   0.89969826  0.09094357  0.20086813 -0.80845547 -0.98660445]
Next observation: [ 4.5492849e-03 -1.5185314e-03 -1.0338974e-02 ...  9.0616312e+00
 -1.9432434e+02 -9.8465431e+01]
Reward: 1.7235847
Done: 1.0
Current info is: {'clip_idx': Array(10, dtype=int32), 'cur_frame': Array(12, dtype=int32), 'steps_taken_cur_frame': Array(1, dtype=int32, weak_type=True), 'summed_pos_distance': Array(2.1323545e-05, dtype=float32), 'quat_distance': Array(0.00170804, dtype=float32), 'joint_distance': Array(7.3317356, dtype=float32), 'prev_ctrl': Array([-0.15536189,  0.20095396,  0.43345404,  0.34053254, -0.60310864,
   

In [6]:
state = env.reset(rng_key)
print('reward after reset is:', state.reward)

reward after reset is: 0.0


In [60]:
env._reference_clips.position.shape

(17, 1000, 3)

In [66]:
state = env.reset(rng_key)
state.pipeline_state.qpos[:3]

Array([0.02204722, 0.00628723, 0.00635291], dtype=float32)

In [70]:
reference_clip = jax.tree_map(
            lambda x: x[state.info["cur_frame"]], env._get_reference_clip(state.info)
        )
reference_clip.position

Array([0.02156236, 0.00640273, 0.00592261], dtype=float32)