In [32]:
import orbax.checkpoint as ocp


from track_mjx.agent.custom_ppo import TrainingState
from track_mjx.agent import custom_losses as ppo_losses
from track_mjx.agent import custom_ppo, custom_ppo_networks
import jax
from jax import numpy as jnp
import optax
from brax import envs
from brax.training.acme import running_statistics, specs
from track_mjx.environment import custom_wrappers
from track_mjx.environment.task.multi_clip_tracking import RodentMultiClipTracking
from track_mjx.environment.task.single_clip_tracking import RodentTracking
from track_mjx.environment.walker.rodent import Rodent
import hydra
from hydra import initialize, compose
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf
import pickle
import functools

# Hydra keeps process-wide global state; clear any instance left over from a
# previous execution of this cell before initializing again.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
initialize(version_base=None, config_path="./../track_mjx/config")

# Compose the experiment config exactly once and reuse it everywhere below.
cfg = compose(config_name="rodent-mc-intention")

# Register both tracking tasks; `cfg.env_config.env_name` is later passed to
# `envs.get_environment` to pick one of them.
envs.register_environment("single clip", RodentTracking)
envs.register_environment("multi clip", RodentMultiClipTracking)

# Config sections used by the rest of the notebook.
# `env_cfg` used to be a second, identical compose() of the same config name;
# it is kept as an alias so any code still referring to it keeps working.
env_cfg = cfg
env_args = cfg.env_config["env_args"]
env_rewards = cfg.env_config["reward_weights"]
train_config = cfg.train_setup["train_config"]
walker_config = cfg["walker_config"]
# (sic) misspelled name kept because the following cell reads `wlaker_config`.
wlaker_config = walker_config

In [None]:
# TODO (Kevin): add this as a yaml config
walker = Rodent(**wlaker_config)

# TODO(Scott): move this to track_mjx.io module
# Resolve the reference-clip pickle through Hydra's path helper and load it.
input_data_path = hydra.utils.to_absolute_path("../data/twoClips.p")
print(f"Loading data: {input_data_path}")
with open(input_data_path, "rb") as clip_file:
    # NOTE(review): pickle can execute arbitrary code on load; only point this
    # at reference clips from a trusted source.
    reference_clip = pickle.load(clip_file)

# The env args and reward weights from the config are unpacked straight into
# keyword arguments, so their keys must match what the environment accepts.
environment = envs.get_environment(
    env_name=cfg.env_config.env_name,
    reference_clip=reference_clip,
    walker=walker,
    **env_args,
    **env_rewards,
)

# Episode length = (clip length - random init range - traj length) * steps per cur frame.
# These values are still hardcoded; naming them at least makes the formula readable.
clip_length = 250
random_init_range = 50
traj_length = 5
usable_frames = clip_length - random_init_range - traj_length
episode_length = usable_frames * environment._steps_for_cur_frame
print(f"episode_length {episode_length}")


# put your model path here
model_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/930248f3-2319-4028-b77a-5d503cb58f6b"

# initialize orbax checkpoint manager
# TODO (Scott): add the checkpoint parameter to config file.
mgr_options = ocp.CheckpointManagerOptions(create=True, max_to_keep=3, keep_period=2, step_prefix="PPONetwork")
ckpt_mgr = ocp.CheckpointManager(model_path, ocp.Checkpointer(ocp.PyTreeCheckpointHandler()), mgr_options)

# Build the network factory once. It is used both by `train_fn` and, later in
# this notebook, to construct the networks whose parameters serve as the
# restore target, so the two must be configured identically.
network_factory = functools.partial(
    custom_ppo_networks.make_intention_ppo_networks,
    encoder_hidden_layer_sizes=tuple(cfg.network_config.encoder_layer_sizes),
    decoder_hidden_layer_sizes=tuple(cfg.network_config.decoder_layer_sizes),
    value_hidden_layer_sizes=tuple(cfg.network_config.critic_layer_sizes),
)

# Training entry point with everything from the config bound, plus the
# checkpoint manager created above.
train_fn = functools.partial(
    custom_ppo.train,
    **train_config,
    num_evals=int(cfg.train_setup.train_config.num_timesteps / cfg.train_setup.eval_every),
    episode_length=episode_length,
    kl_weight=cfg.network_config.kl_weight,
    network_factory=network_factory,
    ckpt_mgr=ckpt_mgr,
)


seed = 42

# Derive all PRNG keys from the single seed.
key = jax.random.PRNGKey(seed)
global_key, local_key = jax.random.split(key)
local_key, key_env, eval_key = jax.random.split(local_key, 3)
# key_networks should be global, so that networks are initialized the same
# way for different processes.
key_policy, key_value, policy_params_fn_key = jax.random.split(global_key, 3)

# No randomization function is passed to the wrapper.
v_randomization_fn = None

# The previous `isinstance(environment, envs.Env)` check picked
# `custom_wrappers.wrap` in both branches, so the wrapper is assigned directly.
wrap_for_training = custom_wrappers.wrap

env = wrap_for_training(
    environment,
    episode_length=episode_length,
    action_repeat=1,
    randomization_fn=v_randomization_fn,
)

# Reset with a batch containing a single key; the resulting state is only used
# below to read observation / reference sizes.
reset_fn = env.reset
key_envs = jax.random.split(key_env, 1)
env_state = reset_fn(key_envs)

def _no_normalization(observations, normalizer_params):
    """Return the observations untouched, ignoring the normalizer parameters."""
    return observations


# Explicit switch replacing the former hard-wired `if True:`; True keeps the
# previous behavior (running-statistics normalization).
normalize_observations = True
normalize = running_statistics.normalize if normalize_observations else _no_normalization

# Sizes are read from the reset state and the wrapped env.
# NOTE(review): positional argument meaning (observation size, reference
# observation size, action size) is inferred from the values passed; confirm
# against `make_intention_ppo_networks`.
ppo_network = network_factory(
    env_state.obs.shape[-1],
    int(env_state.info["reference_obs_size"][0]),
    env.action_size,
    preprocess_observations_fn=normalize,
)
make_policy = custom_ppo_networks.make_inference_fn(ppo_network)

optimizer = optax.adam(learning_rate=1e-4)

# Freshly initialized parameters; they provide the tree structure used as the
# restore target in the next cell.
init_params = ppo_losses.PPONetworkParams(
    policy=ppo_network.policy_network.init(key_policy),
    value=ppo_network.value_network.init(key_value),
)

training_state = TrainingState(  # pytype: disable=wrong-arg-types  # jax-ndarray
    optimizer_state=optimizer.init(init_params),  # pytype: disable=wrong-arg-types  # numpy-scalars
    params=init_params,
    normalizer_params=running_statistics.init_state(specs.Array(env_state.obs.shape[-1:], jnp.dtype("float32"))),
    env_steps=0,
)

# (normalizer params, policy params) pair used as the target when restoring
# only the policy item.
abstract_policy = (training_state.normalizer_params, training_state.params.policy)

In [41]:
# restore the whole model (both policy module, and training state)

# Reuse `model_path` from the setup cell instead of repeating the directory
# literal, so the save and restore locations cannot drift apart.
options = ocp.CheckpointManagerOptions(step_prefix="PPONetwork")
with ocp.CheckpointManager(model_path, options=options) as mngr:
    # Each restore reads step 0 and requests a single named item, using the
    # freshly initialized structures from the setup cell as the target.
    cp = mngr.restore(0, args=ocp.args.Composite(policy=ocp.args.StandardRestore(abstract_policy)))
    ts = mngr.restore(0, args=ocp.args.Composite(train_state=ocp.args.StandardRestore(training_state)))

In [None]:
# Display what the `train_state` restore call above returned.
ts

# The following is how we save multiple items -- with the new `ocp` API

In [23]:
# Start from an empty checkpoint directory for this demo.
options = ocp.CheckpointManagerOptions()
ckpt_dir = ocp.test_utils.erase_and_create_empty("/root/vast/scott-yang/track-mjx/notebooks/test/ckpt4")
mngr = ocp.CheckpointManager(ckpt_dir, options=options)

# Two toy items: a pytree saved with the standard handler and a JSON blob.
pytree = {"A": 1}
extra_metadata = {"B": 2}

# With the new API, each item is named and typed at save time via a Composite.
save_args = ocp.args.Composite(
    state=ocp.args.StandardSave(pytree),
    extra_metadata=ocp.args.JsonSave(extra_metadata),
)
mngr.save(0, args=save_args)
mngr.wait_until_finished()

In [None]:
# Legacy API: the per-item checkpointers are declared when the manager is built.
# `pytree` and `extra_metadata` come from the previous cell.
options = ocp.CheckpointManagerOptions()
ckpt_dir = ocp.test_utils.erase_and_create_empty("/root/vast/scott-yang/track-mjx/notebooks/test/ckpt3")
item_checkpointers = {
    "state": ocp.Checkpointer(ocp.PyTreeCheckpointHandler()),
    "extra_metadata": ocp.Checkpointer(ocp.JsonCheckpointHandler()),
}
mngr = ocp.CheckpointManager(ckpt_dir, item_checkpointers, options=options)

restore_args = ocp.checkpoint_utils.construct_restore_args(pytree)
mngr.save(0, {"state": pytree, "extra_metadata": extra_metadata})
mngr.wait_until_finished()

# Restore both items; only the pytree item gets explicit restore arguments.
restore_targets = {"state": pytree, "extra_metadata": None}
per_item_kwargs = {"state": {"restore_args": restore_args}, "extra_metadata": None}
mngr.restore(0, items=restore_targets, restore_kwargs=per_item_kwargs)

In [28]:
# Re-open the directory written by the new-API save cell, this time without
# erasing it, so its contents can be restored below.
ckpt_dir = "/root/vast/scott-yang/track-mjx/notebooks/test/ckpt4"
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(ckpt_dir, options=options)

# Notes

It is possible to load only part of a checkpoint, but doing so requires first building the abstract pytree for that part using the methods shown above, which should not be hard.

The following cell demonstrates how to load only the state, but not the extra metadata.

In [None]:
# Target structure for the `state` item only; `extra_metadata` is not requested.
abstract_pytree = {"A": 1}
state_only_args = ocp.args.Composite(state=ocp.args.StandardRestore(abstract_pytree))
mngr.restore(0, args=state_only_args)