# Simplified Environment Setup for Controller Policy

We can reduce the observation space of the encoder directly to the future location of the five timestep, and we can use a generative encoder layer to use the model.

In [None]:
from track_mjx.environment.task import joysticks_brax
from track_mjx.environment.task.joysticks_brax import RodentJoystick
import jax.numpy as jnp
import jax
import os

import logging

from brax.training import acting
from brax.training.acme import running_statistics
from track_mjx.environment import wrappers
from track_mjx.environment.walker.rodent import Rodent
from track_mjx.agent import ppo_networks
from track_mjx.agent import losses
from brax.training.acme import specs
from track_mjx.io import load
from track_mjx.environment.task.multi_clip_tracking import MultiClipTracking
from track_mjx.environment.task.single_clip_tracking import SingleClipTracking
from track_mjx.environment.task.joysticks_brax import RodentJoystick
from track_mjx.environment import wrappers
from track_mjx.agent import checkpointing
from track_mjx.agent.logging import rollout_logging_fn, make_rollout_renderer
from track_mjx.environment.walker.rodent import Rodent
from track_mjx.environment.walker.fly import Fly
from track_mjx.environment.task.reward import RewardConfig
import functools

network_factory = functools.partial(
    ppo_networks.make_intention_ppo_networks,
    encoder_hidden_layer_sizes=tuple([256, 256, 256]),
    decoder_hidden_layer_sizes=tuple([256, 256, 256]),
    value_hidden_layer_sizes=tuple([256, 256, 256]),
    intention_latent_size=60,
)

ppo_network = network_factory(
    200,
    60,
    50,
    preprocess_observations_fn=running_statistics.normalize,
)
# eval_env = MultiClipTracking()

eval_env = wrappers.wrap(eval_env)
make_policy = ppo_networks.make_inference_fn(ppo_network)

eval_key = jax.random.PRNGKey(0)

evaluator = acting.Evaluator(
    eval_env,
    functools.partial(make_policy, deterministic=True),
    num_eval_envs=16,
    episode_length=200,
    action_repeat=1,
    key=eval_key,
)

key_policy, key_value = jax.random.split(eval_key)

init_params = losses.PPONetworkParams(
    policy=ppo_network.policy_network.init(key_policy),
    value=ppo_network.value_network.init(key_value),
)


normalizer_param = running_statistics.init_state(
    specs.Array((168,), jnp.dtype("float32"))
)

policy_params = (normalizer_param, init_params.policy)

TypeError: MultiClipTracking.__init__() missing 5 required positional arguments: 'reference_clip', 'walker', 'reward_config', 'physics_steps_per_control_step', and 'reset_noise_scale'

In [25]:
init_params

PPONetworkParams(policy={'params': {'encoder': {'hidden_0': {'kernel': Array([[-0.08926564, -0.17920424,  0.02216863, ...,  0.21081652,
        -0.02835298, -0.11204038],
       [ 0.19155349,  0.20925015, -0.12205659, ...,  0.04108075,
        -0.03041829, -0.18860896],
       [-0.17244923,  0.09366575,  0.19230407, ..., -0.0376111 ,
        -0.05099188, -0.11037949],
       ...,
       [-0.17868535, -0.02287448,  0.04045236, ..., -0.1842983 ,
         0.04223165,  0.20337617],
       [ 0.22031485, -0.08721233,  0.10910401, ...,  0.07660884,
         0.18766336,  0.06630293],
       [ 0.04210658,  0.13295233, -0.04475411, ..., -0.04505638,
         0.13203079,  0.18462384]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [34]:
from track_mjx.agent.checkpointing import (
    load_training_state,
    load_config_from_checkpoint,
    make_abstract_policy,
    load_checkpoint_for_eval,
)
from track_mjx.agent.ppo import TrainingState


from brax.training import acting
from brax.training.acme import running_statistics
from track_mjx.environment import wrappers
from track_mjx.agent import ppo_networks
from track_mjx.agent import losses
from brax.training.acme import specs
import functools
from jax import numpy as jnp

import optax


ckpt = load_checkpoint_for_eval(
    "/root/vast/scott-yang/track-mjx/model_checkpoints/250306_194809"
)

Loading checkpoint from /root/vast/scott-yang/track-mjx/model_checkpoints/250306_194809 at step 144


In [32]:
raw_decoder_param = ckpt["policy"][1]["params"]
raw_decoder_param

{'decoder': {'LayerNorm_0': {'bias': Array([ 1.96061194e-01,  1.38670459e-01,  1.69686630e-01,  1.11867584e-01,
           5.24064675e-02,  2.12541535e-01,  1.62234366e-01, -1.55701174e-03,
           2.18237355e-01,  1.25661403e-01,  1.83614895e-01,  1.68569446e-01,
           1.88297242e-01,  2.15668499e-01, -3.39656621e-01, -4.09842879e-02,
           1.18521914e-01,  8.26877877e-02, -8.75907466e-02,  1.17526025e-01,
           3.68886709e-01,  2.63486207e-01, -2.79327810e-01, -3.22497040e-01,
           5.13098761e-02,  7.14108795e-02, -1.14651151e-01,  1.35195538e-01,
          -5.41948378e-02,  3.13909531e-01,  4.25555445e-02,  9.82386395e-02,
          -6.64724708e-02,  2.73678154e-02, -4.72065024e-02, -3.16686854e-02,
           1.65353902e-02, -1.10959493e-01, -1.13595538e-02,  2.50108898e-01,
           1.16768047e-01,  1.20152086e-01,  1.68995216e-01,  2.98994750e-01,
           5.97618259e-02,  7.14793578e-02, -3.01619917e-02,  2.24365145e-01,
           8.83947685e-02,  9.

In [37]:
init_params.policy["params"]["decoder"] = raw_decoder_param

In [31]:
from flax import traverse_util


def create_policy_module_mask(
    ppo_params: losses.PPONetworkParams, freeze_submodule_name: str
):
    # Flatten the policy params
    flat_policy = traverse_util.flatten_dict(ppo_params.policy)

    # Create a mask dict for policy params
    policy_mask = {}
    for param_path in flat_policy:
        # param_path is a tuple like ('Dense_0', 'kernel')
        if freeze_submodule_name in param_path:
            policy_mask[param_path] = False  # Freeze this submodule
        else:
            policy_mask[param_path] = True  # Train other modules

    # Unflatten policy mask
    policy_mask = traverse_util.unflatten_dict(policy_mask)

    # The value network stays fully trainable here:
    value_mask = jax.tree_util.tree_map(lambda _: True, ppo_params.value)

    # Reconstruct the PPONetworkParams mask
    return losses.PPONetworkParams(policy=policy_mask, value=value_mask)


# Create a mask for the parameters
mask = create_policy_module_mask(init_params, "decoder")
mask

PPONetworkParams(policy={'params': {'encoder': {'hidden_0': {'kernel': True, 'bias': True}, 'LayerNorm_0': {'scale': True, 'bias': True}, 'hidden_1': {'kernel': True, 'bias': True}, 'LayerNorm_1': {'scale': True, 'bias': True}, 'hidden_2': {'kernel': True, 'bias': True}, 'LayerNorm_2': {'scale': True, 'bias': True}, 'fc2_mean': {'kernel': True, 'bias': True}, 'fc2_logvar': {'kernel': True, 'bias': True}}, 'decoder': {'hidden_0': {'kernel': False, 'bias': False}, 'LayerNorm_0': {'scale': False, 'bias': False}, 'hidden_1': {'kernel': False, 'bias': False}, 'LayerNorm_1': {'scale': False, 'bias': False}, 'hidden_2': {'kernel': False, 'bias': False}, 'LayerNorm_2': {'scale': False, 'bias': False}, 'hidden_3': {'kernel': False, 'bias': False}}}}, value={'params': {'hidden_0': {'bias': True, 'kernel': True}, 'hidden_1': {'bias': True, 'kernel': True}, 'hidden_2': {'bias': True, 'kernel': True}, 'hidden_3': {'bias': True, 'kernel': True}}})

In [22]:
init_params.policy["params"]

{'hidden_0': {'kernel': Array([[-0.03332756, -0.09089205, -0.05331829, ..., -0.03696887,
           0.0122158 , -0.031691  ],
         [-0.09496804,  0.11920086,  0.06245698, ..., -0.00633664,
           0.13164641, -0.06281761],
         [-0.04815069, -0.07481214,  0.10452467, ..., -0.06743849,
          -0.08343321,  0.09350343],
         ...,
         [-0.07320067,  0.02895315,  0.03827051, ..., -0.00677172,
           0.07714961, -0.07113333],
         [ 0.12086396,  0.11034053, -0.07441606, ..., -0.01116904,
          -0.06460512, -0.12223592],
         [ 0.09761509,  0.0153269 , -0.00826315, ...,  0.08571751,
           0.07862228,  0.00115547]], dtype=float32),
  'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       

In [30]:
init_params

PPONetworkParams(policy={'params': {'encoder': {'hidden_0': {'kernel': Array([[-0.08926564, -0.17920424,  0.02216863, ...,  0.21081652,
        -0.02835298, -0.11204038],
       [ 0.19155349,  0.20925015, -0.12205659, ...,  0.04108075,
        -0.03041829, -0.18860896],
       [-0.17244923,  0.09366575,  0.19230407, ..., -0.0376111 ,
        -0.05099188, -0.11037949],
       ...,
       [-0.17868535, -0.02287448,  0.04045236, ..., -0.1842983 ,
         0.04223165,  0.20337617],
       [ 0.22031485, -0.08721233,  0.10910401, ...,  0.07660884,
         0.18766336,  0.06630293],
       [ 0.04210658,  0.13295233, -0.04475411, ..., -0.04505638,
         0.13203079,  0.18462384]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [46]:
cfg_dict = ckpt["cfg"]
walker = Rodent(**cfg_dict["walker_config"])
reward_config = RewardConfig(**cfg_dict["env_rewards"])
reference_clip = load.load_data(cfg_dict["data_path"])

env = envs.get_environment(
    env_name=cfg.env_config.env_name,
    reference_clip=reference_clip,
    walker=walker,
    reward_config=reward_config,
    **env_args,
    **traj_config,
)

ConfigKeyError: Missing key env_rewards
    full_key: env_rewards
    object_type=dict

In [56]:
"""
Entry point for track-mjx. Load the config file, create environments, initialize network, and start training.
"""

import os
import sys

# set default env variable if not set
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = os.environ.get(
#     "XLA_PYTHON_CLIENT_MEM_FRACTION", "0.9"
# )

# limit to 1 GPU
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use only GPU 0

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["MUJOCO_GL"] = os.environ.get("MUJOCO_GL", "egl")
os.environ["PYOPENGL_PLATFORM"] = os.environ.get("PYOPENGL_PLATFORM", "egl")
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_triton_softmax_fusion=true --xla_gpu_triton_gemm_any=True --xla_dump_to=/tmp/foo"
)

os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"

# os.environ["JAX_LOG_COMPILES"] = "1"

# # (Optional) For more detailed logging
# os.environ["JAX_LOG_COMPILES_VERBOSE"] = "1"

import jax

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update(
    "jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir"
)

import hydra
from omegaconf import DictConfig, OmegaConf
import functools
import jax
import wandb
from brax import envs
import orbax.checkpoint as ocp
from track_mjx.agent import ppo, ppo_networks
import warnings
from pathlib import Path
from datetime import datetime
import logging

from track_mjx.io import load
from track_mjx.environment.task.multi_clip_tracking import MultiClipTracking
from track_mjx.environment.task.single_clip_tracking import SingleClipTracking
from track_mjx.environment.task.joysticks_brax import RodentJoystick
from track_mjx.environment import wrappers
from track_mjx.agent import checkpointing
from track_mjx.agent.logging import rollout_logging_fn, make_rollout_renderer
from track_mjx.environment.walker.rodent import Rodent
from track_mjx.environment.walker.fly import Fly
from track_mjx.environment.task.reward import RewardConfig

warnings.filterwarnings("ignore", category=DeprecationWarning)

_WALKERS = {
    "rodent": Rodent,
    "fly": Fly,
}

"""Main function using Hydra configs"""
try:
    n_devices = jax.device_count(backend="gpu")
    logging.info(f"Using {n_devices} GPUs")
except:
    n_devices = 1
    logging.info("Not using GPUs")

envs.register_environment("rodent_single_clip", SingleClipTracking)
envs.register_environment("rodent_multi_clip", MultiClipTracking)
envs.register_environment("fly_multi_clip", MultiClipTracking)
envs.register_environment("rodent_joystick", RodentJoystick)

hydra.core.global_hydra.GlobalHydra.instance().clear()

hydra.initialize(config_path="../track_mjx/config", job_name="train")

cfg = hydra.compose(config_name="rodent-joystick")
cfg_dict = OmegaConf.to_container(cfg, resolve=True)

# Generate a new run_id and associated checkpoint path
run_id = datetime.now().strftime("%y%m%d_%H%M%S")
# TODO: Use a base path given by the config
checkpoint_path = hydra.utils.to_absolute_path(
    f"./{cfg.logging_config.model_path}/{run_id}"
)

# Load the checkpoint's config
if (
    cfg.train_setup["checkpoint_to_restore"] is not None and False
):  # TODO: Only load decoder
    # TODO: We set the restored config's checkpoint_to_restore to itself
    # Because that restored config is used from now on. This is a hack.
    checkpoint_to_restore = cfg.train_setup["checkpoint_to_restore"]
    # Load the checkpoint's config and update the run_id and checkpoint path
    cfg = OmegaConf.create(
        checkpointing.load_config_from_checkpoint(
            cfg.train_setup["checkpoint_to_restore"]
        )
    )
    cfg.train_setup["checkpoint_to_restore"] = checkpoint_to_restore
    checkpoint_path = Path(checkpoint_to_restore)
    run_id = checkpoint_path.name

# Initialize checkpoint manager
mgr_options = ocp.CheckpointManagerOptions(
    create=True,
    max_to_keep=cfg.train_setup["checkpoint_max_to_keep"],
    keep_period=cfg.train_setup["checkpoint_keep_period"],
    step_prefix="PPONetwork",
)

ckpt_mgr = ocp.CheckpointManager(checkpoint_path, options=mgr_options)

logging.info(f"run_id: {run_id}")
logging.info(f"Training checkpoint path: {checkpoint_path}")

env_args = cfg.env_config["env_args"]
env_rewards = cfg.env_config["reward_weights"]
train_config = cfg.train_setup["train_config"]
walker_config = cfg["walker_config"]
traj_config = cfg["reference_config"]

logging.info(f"Loading data: {cfg.data_path}")
reference_clip = load.load_data(cfg.data_path)

if cfg.env_config["env_name"] == "rodent_joystick":
    env = envs.get_environment("rodent_joystick")
    episode_length = cfg.train_setup.episode_length
    network_factory = functools.partial(
        ppo_networks.make_mlp_ppo_networks,
        policy_hidden_layer_sizes=tuple(cfg.network_config.policy_hidden_layer_sizes),
        value_hidden_layer_sizes=tuple(cfg.network_config.critic_layer_sizes),
    )
else:
    walker = _WALKERS[cfg_dict["walker_type"]](**walker_config)
    reward_config = RewardConfig(**env_rewards)

    env = envs.get_environment(
        env_name=cfg.env_config.env_name,
        reference_clip=reference_clip,
        walker=walker,
        reward_config=reward_config,
        **env_args,
        **traj_config,
    )

    # Episode length is equal to (clip length - random init range - traj length) * steps per cur frame.
    episode_length = (
        traj_config.clip_length
        - traj_config.random_init_range
        - traj_config.traj_length
    ) * env._steps_for_cur_frame
    print(f"episode_length {episode_length}")
    logging.info(f"episode_length {episode_length}")

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  hydra.initialize(config_path="../track_mjx/config", job_name="train")


env._steps_for_cur_frame: 2.0
episode_length 390.0


In [59]:
rng = jax.random.PRNGKey(0)
state = env.reset(rng)

reference_obs.shape: (Array(15, dtype=int32, weak_type=True),)


In [64]:
state.obs[:15]

Array([-0.00156613,  0.00447299,  0.00611204, -0.00163687,  0.00477304,
        0.00604369, -0.00171546,  0.00502477,  0.00570974, -0.00173155,
        0.0048846 ,  0.00590869, -0.00205544,  0.00483106,  0.00578621],      dtype=float32)