# Notebook for testing joystick environment for rodent

In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
from track_mjx.environment.task import joysticks_brax
from track_mjx.environment.task.joysticks_brax import RodentJoystick
import jax
import jax.numpy as jnp
import os

import logging

logging.basicConfig(level=logging.DEBUG, force=True)

os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"
os.environ["JAX_LOG_COMPILES"] = "1"
# (Optional) For more detailed logging
os.environ["JAX_LOG_COMPILES_VERBOSE"] = "1"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_triton_softmax_fusion=true --xla_gpu_triton_gemm_any=True --xla_dump_to=/tmp/foo"
)

# Optionally, you can also redirect stderr to stdout so that you see logs in the normal notebook output:
import sys

sys.stderr = sys.stdout

# Enable persistent compilation cache.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update(
    "jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir"
)

DEBUG:2025-03-15 17:42:31,366:jax._src.dispatch:182: Finished tracing + transforming reset for pjit in 19.575861692 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming reset for pjit in 19.575861692 sec
DEBUG:2025-03-15 17:42:31,786:jax._src.interpreters.pxla:1913: Compiling reset with global shapes and types [ShapedArray(uint32[4,16,2])]. Argument mapping: (UnspecifiedValue,).
DEBUG:jax._src.interpreters.pxla:Compiling reset with global shapes and types [ShapedArray(uint32[4,16,2])]. Argument mapping: (UnspecifiedValue,).
DEBUG:2025-03-15 17:42:34,896:jax._src.dispatch:182: Finished jaxpr to MLIR module conversion jit(reset) in 3.098076582 sec
DEBUG:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(reset) in 3.098076582 sec
DEBUG:2025-03-15 17:42:34,899:jax._src.compiler:168: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CudaDevice(id=0)]]
DEBUG:jax._src.compiler:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[Cu

In [3]:
env = joysticks_brax.RodentJoystick()
rng = jax.random.PRNGKey(0)


jit_reset, jit_step = jax.jit(env.reset), jax.jit(env.step)

In [4]:
init_state = jit_reset(rng)
state = init_state

In [5]:
# state = jit_step(state, jnp.zeros(env.action_size))

In [6]:
state.obs.shape

(181,)

In [7]:
from brax.training import acting
from brax.training.acme import running_statistics
from track_mjx.environment import wrappers
from track_mjx.agent import ppo_networks
from track_mjx.agent import losses
from brax.training.acme import specs
import functools

network_factory = functools.partial(
    ppo_networks.make_mlp_ppo_networks,
    policy_hidden_layer_sizes=tuple([256, 256, 256]),
    value_hidden_layer_sizes=tuple([256, 256, 256]),
)

ppo_network = network_factory(
    state.obs.shape[-1],
    # int(_unpmap(env_state.info["reference_obs_size"])[0]),
    env.action_size,
    preprocess_observations_fn=running_statistics.normalize,
)
eval_env = joysticks_brax.RodentJoystick()

eval_env = wrappers.wrap(eval_env)
make_policy = ppo_networks.make_mlp_inference_fn(ppo_network)

eval_key = jax.random.PRNGKey(0)

evaluator = acting.Evaluator(
    eval_env,
    functools.partial(make_policy, deterministic=True),
    num_eval_envs=16,
    episode_length=200,
    action_repeat=1,
    key=eval_key,
)

key_policy, key_value = jax.random.split(eval_key)

init_params = losses.PPONetworkParams(
    policy=ppo_network.policy_network.init(key_policy),
    value=ppo_network.value_network.init(key_value),
)


normalizer_param = running_statistics.init_state(
    specs.Array(state.obs.shape[-1:], jnp.dtype("float32"))
)

policy_params = (normalizer_param, init_params.policy)

# metrics = evaluator.run_evaluation(policy_params, training_metrics={})

In [None]:
metrics = evaluator.run_evaluation(policy_params, training_metrics={})
metrics

{'eval/walltime': 313.2666697502136,
 'eval/episode_reward': Array(0.01136403, dtype=float32),
 'eval/episode_reward/action_rate': Array(-5.8399577, dtype=float32),
 'eval/episode_reward/ang_vel_xy': Array(-458.61618, dtype=float32),
 'eval/episode_reward/dof_pos_limits': Array(-7.0666866, dtype=float32),
 'eval/episode_reward/energy': Array(0., dtype=float32),
 'eval/episode_reward/feet_air_time': Array(0., dtype=float32),
 'eval/episode_reward/feet_clearance': Array(0., dtype=float32),
 'eval/episode_reward/feet_height': Array(0., dtype=float32),
 'eval/episode_reward/feet_slip': Array(0., dtype=float32),
 'eval/episode_reward/lin_vel_z': Array(-5.1647654, dtype=float32),
 'eval/episode_reward/orientation': Array(-85.58583, dtype=float32),
 'eval/episode_reward/pose': Array(14.885166, dtype=float32),
 'eval/episode_reward/stand_still': Array(0., dtype=float32),
 'eval/episode_reward/termination': Array(-1., dtype=float32),
 'eval/episode_reward/torques': Array(-0.12001748, dtype=floa

# Fuck! Run PPO here!

In [13]:
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Proximal policy optimization training.

See: https://arxiv.org/pdf/1707.06347.pdf
"""

import functools
import time
from typing import Callable, Optional, Tuple, Union

from absl import logging
from brax import base
from brax import envs
from brax.training import acting
from brax.training import pmap
from brax.training import types
from brax.training import gradients
from brax.training.acme import running_statistics
from brax.training.acme import specs
from brax.training.types import Params
from brax.training.types import PRNGKey
from brax.v1 import envs as envs_v1
import flax.training
import wandb

from track_mjx.agent import losses, ppo_networks
from track_mjx.environment import wrappers

import flax
import flax.struct
import jax
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp

InferenceParams = Tuple[running_statistics.NestedMeanStd, Params]
Metrics = types.Metrics

_PMAP_AXIS_NAME = "i"


@flax.struct.dataclass
class TrainingState:
    """Contains training state for the learner."""

    optimizer_state: optax.OptState
    params: losses.PPONetworkParams
    normalizer_params: running_statistics.RunningStatisticsState
    env_steps: jnp.ndarray


from track_mjx.agent import checkpointing


def _unpmap(v):
    return jax.tree_util.tree_map(lambda x: x[0], v)  # TODO: Change Back after debug


def _strip_weak_type(tree):
    # brax user code is sometimes ambiguous about weak_type.  in order to
    # avoid extra jit recompilations we strip all weak types from user input
    def f(leaf):
        leaf = jnp.asarray(leaf)
        return leaf.astype(leaf.dtype)

    return jax.tree_util.tree_map(f, tree)


# TODO: Pass in a loss-specific config instead of throwing them all in individually.
def train(
    environment: Union[envs_v1.Env, envs.Env],
    num_timesteps: int,
    episode_length: int,
    ckpt_mgr: ocp.CheckpointManager,
    checkpoint_to_restore: str | None = None,
    action_repeat: int = 1,
    num_envs: int = 64,
    max_devices_per_host: Optional[int] = None,
    num_eval_envs: int = 128,
    learning_rate: float = 1e-4,
    entropy_cost: float = 1e-4,
    kl_weight: float = 1e-3,
    discounting: float = 0.9,
    seed: int = 0,
    unroll_length: int = 10,
    batch_size: int = 32,
    num_minibatches: int = 16,
    num_updates_per_batch: int = 2,
    num_evals: int = 20,
    num_resets_per_eval: int = 0,
    normalize_observations: bool = False,
    reward_scaling: float = 1.0,
    clipping_epsilon: float = 0.3,
    gae_lambda: float = 0.95,
    deterministic_eval: bool = False,
    network_factory: types.NetworkFactory[
        ppo_networks.PPOImitationNetworks
    ] = ppo_networks.make_intention_ppo_networks,
    progress_fn: Callable[[int, Metrics], None] = lambda *args: None,
    normalize_advantage: bool = True,
    eval_env: Optional[envs.Env] = None,
    policy_params_fn: Callable[..., None] = lambda *args: None,
    randomization_fn: Optional[
        Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]]
    ] = None,
    use_kl_schedule: bool = True,
    kl_ramp_up_frac: float = 0.25,
):
    """PPO training.

    Args:
      environment: the environment to train
      num_timesteps: the total number of environment steps to use during training
      episode_length: the length of an environment episode
      ckpt_mgr: an orbax checkpoint manager for saving policy checkpoints
      config_dict: a dictionary that contains the configuration for the training,
        will be saved to the orbax checkpoint alongside with the policy and training state
      checkpoint_to_restore: Optional path for a checkpoint to load to resume training
      action_repeat: the number of timesteps to repeat an action
      num_envs: the number of parallel environments to use for rollouts
        NOTE: `num_envs` must be divisible by the total number of chips since each
          chip gets `num_envs // total_number_of_chips` environments to roll out
        NOTE: `batch_size * num_minibatches` must be divisible by `num_envs` since
          data generated by `num_envs` parallel envs gets used for gradient
          updates over `num_minibatches` of data, where each minibatch has a
          leading dimension of `batch_size`
      max_devices_per_host: maximum number of chips to use per host process
      num_eval_envs: the number of envs to use for evluation. Each env will run 1
        episode, and all envs run in parallel during eval.
      learning_rate: learning rate for ppo loss
      entropy_cost: entropy reward for ppo loss, higher values increase entropy
        of the policy
      discounting: discounting rate
      seed: random seed
      unroll_length: the number of timesteps to unroll in each environment. The
        PPO loss is computed over `unroll_length` timesteps
      batch_size: the batch size for each minibatch SGD step
      num_minibatches: the number of times to run the SGD step, each with a
        different minibatch with leading dimension of `batch_size`
      num_updates_per_batch: the number of times to run the gradient update over
        all minibatches before doing a new environment rollout
      num_evals: the number of evals to run during the entire training run.
        Increasing the number of evals increases total training time
      num_resets_per_eval: the number of environment resets to run between each
        eval. The environment resets occur on the host
      normalize_observations: whether to normalize observations
      reward_scaling: float scaling for reward
      clipping_epsilon: clipping epsilon for PPO loss
      gae_lambda: General advantage estimation lambda
      deterministic_eval: whether to run the eval with a deterministic policy
      network_factory: function that generates networks for policy and value
        functions
      progress_fn: a user-defined callback function for reporting/plotting metrics
      normalize_advantage: whether to normalize advantage estimate
      eval_env: an optional environment for eval only, defaults to `environment`
      policy_params_fn: a user-defined callback function that can be used for
        saving policy checkpoints
      randomization_fn: a user-defined callback function that generates randomized
        environments
      use_kl_schedule: whether to use a ramping schedule for the kl weight in the PPO loss
        (intention network variational layer)
      kl_ramp_up_frac: the fraction of the total number of evals to ramp up max kl weight


    Returns:
      Tuple of (make_policy function, network params, metrics)
    """
    assert batch_size * num_minibatches % num_envs == 0, (
        batch_size * num_minibatches % num_envs
    )
    xt = time.time()

    process_count = jax.process_count()
    process_id = jax.process_index()
    local_device_count = jax.local_device_count()
    local_devices_to_use = local_device_count
    if max_devices_per_host:
        local_devices_to_use = min(local_devices_to_use, max_devices_per_host)
    logging.info(
        "Device count: %d, process count: %d (id %d), local device count: %d, "
        "devices to be used count: %d",
        jax.device_count(),
        process_count,
        process_id,
        local_device_count,
        local_devices_to_use,
    )
    device_count = local_devices_to_use * process_count

    # The number of environment steps executed for every training step.
    env_step_per_training_step = (
        batch_size * unroll_length * num_minibatches * action_repeat
    )
    num_evals_after_init = max(num_evals - 1, 1)
    # The number of training_step calls per training_epoch call.
    # equals to ceil(num_timesteps / (num_evals * env_step_per_training_step *
    #                                 num_resets_per_eval))
    num_training_steps_per_epoch = np.ceil(
        num_timesteps
        / (
            num_evals_after_init
            * env_step_per_training_step
            * max(num_resets_per_eval, 1)
        )
    ).astype(int)

    key = jax.random.PRNGKey(seed)
    global_key, local_key = jax.random.split(key)
    del key
    local_key = jax.random.fold_in(local_key, process_id)
    local_key, key_env, eval_key = jax.random.split(local_key, 3)
    # key_networks should be global, so that networks are initialized the same
    # way for different processes.
    key_policy, key_value, policy_params_fn_key = jax.random.split(global_key, 3)
    del global_key

    print(device_count)

    assert num_envs % device_count == 0

    v_randomization_fn = None
    if randomization_fn is not None:
        randomization_batch_size = num_envs // local_device_count
        # all devices gets the same randomization rng
        randomization_rng = jax.random.split(key_env, randomization_batch_size)
        v_randomization_fn = functools.partial(randomization_fn, rng=randomization_rng)

    if isinstance(environment, envs.Env):
        wrap_for_training = wrappers.wrap
    else:
        wrap_for_training = envs_v1.wrappers.wrap_for_training

    env = wrap_for_training(
        environment,
        episode_length=episode_length,
        action_repeat=action_repeat,
        randomization_fn=v_randomization_fn,
    )

    reset_fn = jax.jit(jax.vmap(env.reset))
    key_envs = jax.random.split(key_env, num_envs // process_count)
    key_envs = jnp.reshape(key_envs, (local_devices_to_use, -1) + key_envs.shape[1:])
    env_state = reset_fn(key_envs)

    normalize = lambda x, y: x
    if normalize_observations:
        normalize = running_statistics.normalize
    ppo_network = network_factory(
        env_state.obs.shape[-1],
        # int(_unpmap(env_state.info["reference_obs_size"])[0]),
        env.action_size,
        preprocess_observations_fn=normalize,
    )
    make_policy = ppo_networks.make_mlp_inference_fn(ppo_network)

    make_logging_policy = ppo_networks.make_logging_inference_fn(ppo_network)
    jit_logging_inference_fn = jax.jit(make_logging_policy(deterministic=True))

    optimizer = optax.adam(learning_rate=learning_rate)

    kl_schedule = None
    if use_kl_schedule:
        kl_schedule = losses.create_ramp_schedule(
            max_value=kl_weight, ramp_steps=int(num_evals * kl_ramp_up_frac)
        )

    loss_fn = functools.partial(
        losses.compute_ppo_loss,
        ppo_network=ppo_network,
        entropy_cost=entropy_cost,
        kl_weight=kl_weight,
        discounting=discounting,
        reward_scaling=reward_scaling,
        gae_lambda=gae_lambda,
        clipping_epsilon=clipping_epsilon,
        normalize_advantage=normalize_advantage,
        kl_schedule=kl_schedule,
        network_type="mlp",  # TODO: use network type from config
    )

    gradient_update_fn = gradients.gradient_update_fn(
        loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True
    )

    def minibatch_step(
        carry,
        data: types.Transition,
        normalizer_params: running_statistics.RunningStatisticsState,
    ):
        optimizer_state, params, key, it = carry
        key, key_loss = jax.random.split(key)
        (_, metrics), params, optimizer_state = gradient_update_fn(
            params,
            normalizer_params,
            data,
            key_loss,
            it,
            optimizer_state=optimizer_state,
        )

        return (optimizer_state, params, key, it), metrics

    def sgd_step(
        carry,
        unused_t,
        data: types.Transition,
        normalizer_params: running_statistics.RunningStatisticsState,
    ):
        optimizer_state, params, key, it = carry
        key, key_perm, key_grad = jax.random.split(key, 3)

        def convert_data(x: jnp.ndarray):
            x = jax.random.permutation(key_perm, x)
            x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:])
            return x

        shuffled_data = jax.tree_util.tree_map(convert_data, data)
        (optimizer_state, params, _, _), metrics = jax.lax.scan(
            functools.partial(minibatch_step, normalizer_params=normalizer_params),
            (optimizer_state, params, key_grad, it),
            shuffled_data,
            length=num_minibatches,
        )
        return (optimizer_state, params, key, it), metrics

    def training_step(
        carry: Tuple[TrainingState, envs.State, PRNGKey, int], unused_t
    ) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey, int], Metrics]:
        training_state, state, key, it = carry
        key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)

        policy = make_policy(
            (training_state.normalizer_params, training_state.params.policy)
        )

        def f(carry, unused_t):
            current_state, current_key = carry
            current_key, next_key = jax.random.split(current_key)
            next_state, data = acting.generate_unroll(
                env,
                current_state,
                policy,
                current_key,
                unroll_length,
                extra_fields=("truncation",),
            )
            return (next_state, next_key), data

        (state, _), data = jax.lax.scan(
            f,
            (state, key_generate_unroll),
            (),
            length=batch_size * num_minibatches // num_envs,
        )
        # Have leading dimensions (batch_size * num_minibatches, unroll_length)
        data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data)
        data = jax.tree_util.tree_map(
            lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), data
        )
        assert data.discount.shape[1:] == (unroll_length,)

        # Update normalization params and normalize observations.
        normalizer_params = running_statistics.update(
            training_state.normalizer_params,
            data.observation,
            pmap_axis_name=_PMAP_AXIS_NAME,
        )

        (optimizer_state, params, _, _), metrics = jax.lax.scan(
            functools.partial(sgd_step, data=data, normalizer_params=normalizer_params),
            (training_state.optimizer_state, training_state.params, key_sgd, it),
            (),
            length=num_updates_per_batch,
        )

        new_training_state = TrainingState(
            optimizer_state=optimizer_state,
            params=params,
            normalizer_params=normalizer_params,
            env_steps=jnp.int32(
                training_state.env_steps + env_step_per_training_step / 1e3
            ),  # env step in thousands
        )
        return (new_training_state, state, new_key, it), metrics

    def training_epoch(
        training_state: TrainingState, state: envs.State, key: PRNGKey, it: int
    ) -> Tuple[TrainingState, envs.State, Metrics]:
        (training_state, state, _, _), loss_metrics = jax.lax.scan(
            training_step,
            (training_state, state, key, it),
            (),
            length=num_training_steps_per_epoch,
        )
        loss_metrics = jax.tree_util.tree_map(jnp.mean, loss_metrics)
        return training_state, state, loss_metrics

    # TODO: remove this after debugging
    training_epoch = jax.pmap(training_epoch, axis_name=_PMAP_AXIS_NAME)

    # Note that this is NOT a pure jittable method.
    def training_epoch_with_timing(
        training_state: TrainingState, env_state: envs.State, key: PRNGKey, it: int
    ) -> Tuple[TrainingState, envs.State, Metrics]:
        nonlocal training_walltime
        t = time.time()
        training_state, env_state = _strip_weak_type((training_state, env_state))
        step = jnp.ones_like(training_state.env_steps) * it
        result = training_epoch(training_state, env_state, key, step)
        training_state, env_state, metrics = _strip_weak_type(result)

        metrics = jax.tree_util.tree_map(jnp.mean, metrics)
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics)

        epoch_training_time = time.time() - t
        training_walltime += epoch_training_time
        sps = (
            num_training_steps_per_epoch
            * env_step_per_training_step
            * max(num_resets_per_eval, 1)
        ) / epoch_training_time
        metrics = {
            "training/sps": sps,
            "training/walltime": training_walltime,
            **{f"training/{name}": value for name, value in metrics.items()},
        }
        return (
            training_state,
            env_state,
            metrics,
        )  # pytype: disable=bad-return-type  # py311-upgrade

    init_params = losses.PPONetworkParams(
        policy=ppo_network.policy_network.init(key_policy),
        value=ppo_network.value_network.init(key_value),
    )
    training_state = TrainingState(  # pytype: disable=wrong-arg-types  # jax-ndarray
        optimizer_state=optimizer.init(
            init_params
        ),  # pytype: disable=wrong-arg-types  # numpy-scalars
        params=init_params,
        normalizer_params=running_statistics.init_state(
            specs.Array(env_state.obs.shape[-1:], jnp.dtype("float32"))
        ),
        env_steps=0,
    )

    # Load the checkpoint if it exists
    if checkpoint_to_restore is not None:
        training_state = checkpointing.load_training_state(
            checkpoint_to_restore, training_state
        )
        logging.info(f"Restored latest checkpoint at {checkpoint_to_restore}")

    training_state = jax.device_put_replicated(
        training_state, jax.local_devices()[:local_devices_to_use]
    )

    if not eval_env:
        eval_env = environment
    if randomization_fn is not None:
        v_randomization_fn = functools.partial(
            randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)
        )
    eval_env = wrap_for_training(
        eval_env,
        episode_length=episode_length,
        action_repeat=action_repeat,
        randomization_fn=v_randomization_fn,
    )

    evaluator = acting.Evaluator(
        eval_env,
        functools.partial(make_policy, deterministic=deterministic_eval),
        num_eval_envs=num_eval_envs,
        episode_length=episode_length,
        action_repeat=action_repeat,
        key=eval_key,
    )

    # Logic to restore iteration count from checkpoint
    start_it = 0
    if ckpt_mgr is not None:
        if ckpt_mgr.latest_step() is not None:
            num_evals_after_init -= ckpt_mgr.latest_step()
            start_it = ckpt_mgr.latest_step()

    print(f"Starting at iteration: {start_it} with {num_evals_after_init} evals left")

    # Run initial eval
    metrics = {}
    if process_id == 0 and num_evals > 1:
        policy_param = _unpmap(
            (training_state.normalizer_params, training_state.params.policy)
        )
        metrics = evaluator.run_evaluation(
            policy_param,
            training_metrics={},
        )
        logging.info(metrics)
        progress_fn(start_it, metrics)
        # Save checkpoints
        logging.info("Saving initial checkpoint")
        if ckpt_mgr is not None:
            # new orbax API
            ckpt_mgr.save(
                step=0,
                args=ocp.args.Composite(
                    policy=ocp.args.StandardSave(policy_param),
                    train_state=ocp.args.StandardSave(_unpmap(training_state)),
                    config=ocp.args.JsonSave({}),
                ),
            )
        else:
            logging.info("Skipping checkpoint save as ckpt_mgr is None")

    training_metrics = {}
    training_walltime = 0
    start_it += 1
    current_step = 0
    for it in range(start_it, num_evals_after_init + start_it):
        logging.info("starting iteration %s %s", it, time.time() - xt)
        for _ in range(max(num_resets_per_eval, 1)):
            # optimization
            epoch_key, local_key = jax.random.split(local_key)
            epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
            (training_state, env_state, training_metrics) = training_epoch_with_timing(
                training_state, env_state, epoch_keys, it
            )
            current_step = int(_unpmap(training_state.env_steps))

            key_envs = jax.vmap(
                lambda x, s: jax.random.split(x[0], s), in_axes=(0, None)
            )(key_envs, key_envs.shape[1])
            # TODO: move extra reset logic to the AutoResetWrapper.
            env_state = reset_fn(key_envs) if num_resets_per_eval > 0 else env_state

        if process_id == 0:
            # Run evaluation rollout, logging and checkpointing.
            metrics = evaluator.run_evaluation(
                _unpmap(
                    (training_state.normalizer_params, training_state.params.policy)
                ),
                training_metrics,
            )
            logging.info(metrics)
            progress_fn(current_step, metrics)

            policy_param = _unpmap(
                (training_state.normalizer_params, training_state.params.policy)
            )
            if policy_params_fn is not None:
                # Do policy evaluation and logging.
                _, policy_params_fn_key = jax.random.split(policy_params_fn_key)
                policy_params_fn(
                    current_step=it,
                    jit_logging_inference_fn=jit_logging_inference_fn,
                    params=policy_param,
                    policy_params_fn_key=policy_params_fn_key,
                )
            # Save checkpoint
            if ckpt_mgr is not None:
                checkpointing.save(
                    ckpt_mgr, it, policy_param, _unpmap(training_state), {}
                )

    total_steps = current_step
    assert total_steps >= num_timesteps

    # If there was no mistakes the training_state should still be identical on all
    # devices.
    pmap.assert_is_replicated(training_state)
    params = _unpmap((training_state.normalizer_params, training_state.params.policy))
    logging.info("total steps: %s", total_steps)
    pmap.synchronize_hosts()
    return (make_policy, params, metrics)

In [9]:
checkpoint_path = "checkpoints/rodent_joystick"
# Initialize checkpoint manager
mgr_options = ocp.CheckpointManagerOptions(
    create=True,
    max_to_keep=3,
    keep_period=6,
    step_prefix="PPONetwork",
)

ckpt_mgr = ocp.CheckpointManager(checkpoint_path, options=mgr_options)

train_fn = functools.partial(
    train,
    environment=env,
    num_timesteps=1_000_000,
    num_evals=300,
    num_resets_per_eval=1,
    episode_length=250,
    kl_weight=0,
    network_factory=network_factory,
    ckpt_mgr=ckpt_mgr,
)

run_id = f"debug_in_notebook"
wandb.init(
    project="joystick",
    id=run_id,
    resume="allow",
)

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): api.wandb.ai:443
DEBUG:urllib3.connectionpool:https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 None
DEBUG:urllib3.connectionpool:https://api.wandb.ai:443 "POST /graphql HTTP/1.1" 200 None


[34m[1mwandb[0m: Currently logged in as: [33myuy004[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [None]:
def wandb_progress(num_steps, metrics):
    metrics["num_steps_thousands"] = num_steps
    wandb.log(metrics, commit=True)


rollout_env = RodentJoystick(evaluator=True)

train_fn(
    environment=env,
    progress_fn=wandb_progress,
)

4
Starting at iteration: 0 with 299 evals left
