In [1]:
import os
import subprocess
import distutils.util

# Check for GPU (remove or modify if you're running on a CPU-only machine)
if subprocess.run('nvidia-smi', shell=True).returncode:
    raise RuntimeError(
        'Cannot communicate with GPU. '
        'Make sure you have an NVIDIA GPU with the proper drivers installed.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
    with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
        f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Set environment variable to use GPU rendering:
os.environ['MUJOCO_GL'] = 'egl'
print('Environment variable MUJOCO_GL set to:', os.environ['MUJOCO_GL'])

try:
    print('Checking that the installation succeeded:')
    import mujoco
    mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
    raise RuntimeError(
        'Something went wrong during installation. Check the terminal output for more information.\n'
        'If using a hosted runtime, make sure GPU acceleration is enabled.'
    ) from e

print('Installation successful.')

# Tell XLA to use Triton GEMM for improved performance (if applicable)
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
print('XLA_FLAGS set to:', os.environ['XLA_FLAGS'])


Thu Mar  6 17:56:16 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A4000               On  |   00000000:65:00.0  On |                  Off |
| 41%   45C    P8             17W /  140W |      89MiB /  16376MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX A4000               On  |   00

In [2]:
# If you haven't installed mediapy yet:
# %pip install mediapy

import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Instead of the apt-install commands, just ensure ffmpeg is installed in your environment.

import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)

#@title Import MuJoCo, MJX, and Brax
from datetime import datetime
from etils import epath
import functools
from IPython.display import HTML
from typing import Any, Dict, Sequence, Tuple, Union
import os
from ml_collections import config_dict


import jax
from jax import numpy as jp
import numpy as np
from flax.training import orbax_utils
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from orbax import checkpoint as ocp

import mujoco
from mujoco import mjx

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.base import State as PipelineState
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model


print("All packages imported successfully!")


All packages imported successfully!


If you want the humanoid with the trey, uncomment the rewards accordingly, also change the xml path. (note: left arm reward might be messed up)

In [3]:
#@title Humanoid Env
#from pathlib import path
import jax.numpy as jnp    

#HUMANOID_ROOT_PATH = epath.Path(epath.resource_path('mujoco')) / 'mjx/test_data/humanoid'
HUMANOID_ROOT_PATH = epath.Path("/home/wrschiff/Desktop/mujoco/Robust_Locomotion/")

class Humanoid(PipelineEnv):

   def __init__(
      self,
      forward_reward_weight=1.25,
      ctrl_cost_weight=0.1,
      healthy_reward=5.0,
      terminate_when_unhealthy=True,
      healthy_z_range=(1.0, 2.0),
      ########### WEIGHT FOR BALANCING TRAY ############
      #balance_tray_reward_weight=0, # setting this to 0 for now
      #terminate_when_boxfall=True,
      ##################################################
      #left_arm_pose_penalty_weight=0, #also setting this to 0 for initial policy
      reset_noise_scale=1e-2,
      exclude_current_positions_from_observation=True,
      **kwargs,
   ):
   #
      mj_model = mujoco.MjModel.from_xml_path(os.getcwd() + '/basic_humanoid_no_tray.xml')
      mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
      mj_model.opt.iterations = 6
      mj_model.opt.ls_iterations = 6

      sys = mjcf.load_model(mj_model)

      physics_steps_per_control_step = 5
      kwargs['n_frames'] = kwargs.get(
         'n_frames', physics_steps_per_control_step)
      kwargs['backend'] = 'mjx'

      super().__init__(sys, **kwargs)

      self._forward_reward_weight = forward_reward_weight
      self._ctrl_cost_weight = ctrl_cost_weight
      self._healthy_reward = healthy_reward
      self._terminate_when_unhealthy = terminate_when_unhealthy
      self._healthy_z_range = healthy_z_range
      self._reset_noise_scale = reset_noise_scale
      ########### WEIGHT FOR BALANCING TRAY ############
      #self._balance_tray_reward_weight = balance_tray_reward_weight
      #self._terminate_when_boxfall = terminate_when_boxfall
      ##################################################
      self._exclude_current_positions_from_observation = (
         exclude_current_positions_from_observation
      )

      ########### INDICES FOR BOX AND TRAY IN XPOS ARRAY ############
      #self.tray_x_id = mujoco.mj_name2id(self.sys.mj_model, mujoco.mjtObj.mjOBJ_XBODY, "tray")
      #self.box_x_id = mujoco.mj_name2id(self.sys.mj_model, mujoco.mjtObj.mjOBJ_XBODY, "box")
      ###############################################################

      #self._left_arm_pose_penalty_weight = left_arm_pose_penalty_weight

      #self.target_left_arm_pose = jax.numpy.array([0.0, 0.0, 0.0])

      # Identify left arm joint indices (adjust joint names as needed)
      self.left_arm_joint_ids = jnp.array([
         mujoco.mj_name2id(self.sys.mj_model, mujoco.mjtObj.mjOBJ_JOINT, "shoulder1_left"),
         mujoco.mj_name2id(self.sys.mj_model, mujoco.mjtObj.mjOBJ_JOINT, "shoulder2_left"),
         mujoco.mj_name2id(self.sys.mj_model, mujoco.mjtObj.mjOBJ_JOINT, "elbow_left")
      ])


   def reset(self, rng: jp.ndarray) -> State:
      """Resets the environment to an initial state."""
      rng, rng1, rng2 = jax.random.split(rng, 3)

      low, hi = -self._reset_noise_scale, self._reset_noise_scale
      qpos = self.sys.qpos0 + jax.random.uniform(
         rng1, (self.sys.nq,), minval=low, maxval=hi
      )
      qvel = jax.random.uniform(
         rng2, (self.sys.nv,), minval=low, maxval=hi
      )

      data = self.pipeline_init(qpos, qvel)

      
      obs = self._get_obs(data, jp.zeros(self.sys.nu))
      reward, done, zero = jp.zeros(3)
      metrics = {
         'forward_reward': zero,
         'reward_linvel': zero,
         'reward_quadctrl': zero,
         'reward_alive': zero,
         'x_position': zero,
         'y_position': zero,
         'distance_from_origin': zero,
         'x_velocity': zero,
         'y_velocity': zero,
      }
      return State(data, obs, reward, done, metrics)
   def step(self, state: State, action: jp.ndarray) -> State:
      """Runs one timestep of the environment's dynamics."""
      data0 = state.pipeline_state
      data = self.pipeline_step(data0, action)

      com_before = data0.subtree_com[1]
      com_after = data.subtree_com[1]
      velocity = (com_after - com_before) / self.dt
      forward_reward = self._forward_reward_weight * velocity[0]

      min_z, max_z = self._healthy_z_range
      is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
      is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
      if self._terminate_when_unhealthy:
         healthy_reward = self._healthy_reward
      else:
         healthy_reward = self._healthy_reward * is_healthy

      ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

      ############## CALCULATE BOX-TRAY REWARD ##################
      #euclid_dist_tb = jp.linalg.norm(data.x.pos[self.tray_x_id] - data.x.pos[self.box_x_id])
      #balance_cost = euclid_dist_tb * self._balance_tray_reward_weight
      ###########################################################

      obs = self._get_obs(data, action)
      ############## ADD TO OVERALL REWARD ##################
      reward = forward_reward + healthy_reward - ctrl_cost #- balance_cost
      #######################################################

      reward = forward_reward + healthy_reward - ctrl_cost #- balance_cost

      # Compute left arm steadiness penalty
      #left_arm_joint_angles = data.q[self.left_arm_joint_ids]      
      
      # Compute pose deviation error
      #pose_error = left_arm_joint_angles - self.target_left_arm_pose
      #pose_penalty = self._left_arm_pose_penalty_weight * jp.sum(jp.square(pose_error))
        
       # Subtract the penalty from the total reward
      #reward = reward - pose_penalty

      #print(f'CTRL COST BEFORE SCALAR (as benchmark): {ctrl_cost}')
      #print(f'EUCLID DISTANCE: {euclid_dist_tb} \t\tSCALED REWARD: {balance_cost}\t\tTOTAL REWARD: {reward}')

      ########## ADDING TERMINATION CONSTRAINT IF BOX FALLS OFF TRAY ############
      #is_balanced = data.x.pos[self.tray_x_id][2] < data.x.pos[self.box_x_id][2]
      #done = 0.0
      #done = jp.where(self._terminate_when_unhealthy, 1.0 - is_healthy, 0.0)
      #done = jp.where(
      #   (self._terminate_when_boxfall) & (done == 0.0),  # both must be True
      #   1.0 - is_balanced,
      #   done
      #)

      ###########################################################################
      # PREVIOUS METHOD: done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0

      #print(f'TRAY HEIGHT: {data.x.pos[self.tray_x_id][2]}\tBOX HEIGHT: {data.x.pos[self.box_x_id][2]}\tDONE:{done}')
      
      state.metrics.update(
         forward_reward=forward_reward,
         reward_linvel=forward_reward,
         reward_quadctrl=-ctrl_cost,
         reward_alive=healthy_reward,
         x_position=com_after[0],
         y_position=com_after[1],
         distance_from_origin=jp.linalg.norm(com_after),
         x_velocity=velocity[0],
         y_velocity=velocity[1],
      )

      return state.replace(
         pipeline_state=data, obs=obs, reward=reward, done=done
      )

   def _get_obs(
      self, data: mjx.Data, action: jp.ndarray
   ) -> jp.ndarray:
      """Observes humanoid body position, velocities, and angles."""
      position = data.qpos
      if self._exclude_current_positions_from_observation:
         position = position[2:]

      # external_contact_forces are excluded
      return jp.concatenate([
         position,
         data.qvel,
         data.cinert[1:].ravel(),
         data.cvel[1:].ravel(),
         data.qfrc_actuator,
      ])
envs.register_environment('humanoid', Humanoid)

In [10]:
# instantiate the environment
env_name = 'humanoid'
env = envs.get_environment(env_name)

# define the jit reset/step functions
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

Note: This example is still simplified – for instance, the minibatch selection does not do random shuffling. You can adapt it to your own needs (e.g., randomizing minibatch order). Also, it’s jitted only partially. For a full, production‐grade Brax PPO, you may want to pmap or vmap certain computations. But this script should demonstrate the correct logic for truncated BPTT, advantage handling, and KL early stopping.

ALSO: The Reward Cycler hasn't been fully implemented yet. Follow the reward cycle from either of these papers: https://arxiv.org/pdf/2204.04340 or https://arxiv.org/pdf/2011.01387 

In [5]:
import time
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
import brax.envs as envs
from functools import partial

# --------------------------------------------------------------------
# 1) Recurrent Actor-Critic Module
# --------------------------------------------------------------------
class RecurrentActorCritic(nn.Module):
    action_dim: int
    hidden_dim: int = 128

    @nn.compact
    def __call__(self, obs, carry):
        # One-step LSTM cell.
        lstm_cell = nn.LSTMCell()
        new_carry, hidden = lstm_cell(carry, obs)
        # Policy head.
        mean = nn.Dense(self.action_dim)(hidden)
        log_std = self.param('log_std', nn.initializers.zeros, (self.action_dim,))
        # Value head.
        value = nn.Dense(1)(hidden)
        return mean, log_std, value, new_carry

def init_carry(batch_size, hidden_dim):
    return nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,), hidden_dim)

# --------------------------------------------------------------------
# 2) Utility Functions
# --------------------------------------------------------------------
def sample_action(mean, log_std, key, deterministic=False):
    std = jnp.exp(log_std)
    if deterministic:
        return mean
    noise = jax.random.normal(key, shape=mean.shape)
    return mean + std * noise

def gaussian_log_prob(mean, log_std, action):
    std = jnp.exp(log_std)
    var = std ** 2
    logp = -0.5 * (((action - mean) ** 2) / var + 2 * log_std + jnp.log(2 * jnp.pi))
    return jnp.sum(logp, axis=-1)

# --------------------------------------------------------------------
# 3) Reward Cycler (Optional)
# --------------------------------------------------------------------
class RewardCycler:
    def __init__(self, period):
        self.period = period
        self.step_count = 0
    def __call__(self, reward):
        phase = (self.step_count // self.period) % 2
        mod_reward = reward if phase == 0 else -reward
        self.step_count += 1
        return mod_reward

# --------------------------------------------------------------------
# 4) Rollout Generation
# --------------------------------------------------------------------
def rollout(params, env_state, carry, key, unroll_length, reward_cycler):
    """
    Collect a rollout of length unroll_length from the environment.
    Returns transitions (shape (T, B, ...)) and updated (carry, env_state, key).
    """
    def step_fn(carry_env, _):
        carry, state, key = carry_env
        key, subkey = jax.random.split(key)
        mean, log_std, value, new_carry = RecurrentActorCritic(
            action_dim=params['action_dim'], hidden_dim=params['hidden_dim']
        ).apply(params['model'], state.obs, carry)
        action = sample_action(mean, log_std, subkey)
        logp = gaussian_log_prob(mean, log_std, action)
        next_state = state.step(action)
        mod_reward = reward_cycler(next_state.reward)
        done = next_state.done
        # Reset LSTM state where done.
        new_carry = jax.tree_util.tree_map(
            lambda x: jnp.where(done[:, None], jnp.zeros_like(x), x),
            new_carry
        )
        transition = {
            'obs': state.obs,
            'action': action,
            'reward': mod_reward,
            'value': value,
            'done': done,
            'logp': logp,
        }
        return (new_carry, next_state, key), transition

    init = (carry, env_state, key)
    (final_carry, final_env_state, final_key), transitions = jax.lax.scan(
        step_fn, init, None, length=unroll_length
    )
    transitions = {k: jnp.stack([t[k] for t in transitions]) for k in transitions[0].keys()}
    return transitions, final_carry, final_env_state, final_key

# --------------------------------------------------------------------
# 5) Trajectory Processing: Splitting, Padding, and Masking
# --------------------------------------------------------------------
def process_trajectories(transitions, fixed_seq_length):
    """
    Splits each environment's trajectory (of length T) into fixed-length sequences.
    Returns a dict with keys: 'obs', 'action', 'reward', 'value', 'logp', 'mask'
    with shape (num_seq, fixed_seq_length, ...).
    """
    T, B = transitions['reward'].shape[:2]
    processed = {k: [] for k in ['obs', 'action', 'reward', 'value', 'logp']}
    masks = []
    for b in range(B):
        traj = {k: transitions[k][:, b] for k in transitions}
        done_indices = (traj['done'] == 1).nonzero()[0].tolist()
        if not done_indices:
            done_indices = [T - 1]
        start = 0
        for d in done_indices:
            ep_len = d - start + 1
            n_frags = (ep_len + fixed_seq_length - 1) // fixed_seq_length
            for i in range(n_frags):
                frag_start = start + i * fixed_seq_length
                frag_end = min(frag_start + fixed_seq_length, start + ep_len)
                frag = {}
                for key in processed.keys():
                    frag_data = traj[key][frag_start:frag_end]
                    pad_len = fixed_seq_length - frag_data.shape[0]
                    if pad_len > 0:
                        pad_shape = (pad_len,) + frag_data.shape[1:]
                        frag_data = jnp.concatenate([frag_data, jnp.zeros(pad_shape, dtype=frag_data.dtype)], axis=0)
                    frag[key] = frag_data
                for key in processed.keys():
                    processed[key].append(frag[key])
                valid = jnp.ones((frag_end - frag_start,), dtype=jnp.float32)
                pad_mask = jnp.zeros((pad_len,), dtype=jnp.float32) if pad_len > 0 else jnp.array([])
                mask = jnp.concatenate([valid, pad_mask])
                masks.append(mask)
            start = d + 1
    for key in processed:
        processed[key] = jnp.stack(processed[key], axis=0)
    processed['mask'] = jnp.stack(masks, axis=0)
    return processed

def process_scalar_trajectories(scalar_array, done_array, fixed_seq_length):
    """
    Process a scalar array (advantages or returns) with shape (T, B)
    into shape (num_seq, fixed_seq_length) using similar logic.
    """
    T, B = scalar_array.shape
    result = []
    for b in range(B):
        arr = scalar_array[:, b]
        dones = (done_array[:, b] == 1).nonzero()[0].tolist()
        if not dones:
            dones = [T - 1]
        start = 0
        for d in dones:
            ep_len = d - start + 1
            n_frags = (ep_len + fixed_seq_length - 1) // fixed_seq_length
            for i in range(n_frags):
                frag_start = start + i * fixed_seq_length
                frag_end = min(frag_start + fixed_seq_length, start + ep_len)
                frag_data = arr[frag_start:frag_end]
                pad_len = fixed_seq_length - frag_data.shape[0]
                if pad_len > 0:
                    frag_data = jnp.concatenate([frag_data, jnp.zeros((pad_len,), dtype=frag_data.dtype)], axis=0)
                result.append(frag_data)
            start = d + 1
    return jnp.stack(result, axis=0)

# --------------------------------------------------------------------
# 6) Advantage & Return Computation
# --------------------------------------------------------------------
def normalize_advantages(advantages, eps=1e-8):
    mean = jnp.mean(advantages)
    std = jnp.std(advantages)
    return (advantages - mean) / (std + eps)

def compute_advantages(transitions, gamma=0.99, lam=0.95):
    T, B = transitions['reward'].shape
    rewards = transitions['reward']
    values = transitions['value']
    dones = transitions['done']
    advantages = jnp.zeros((T, B))
    returns = jnp.zeros((T, B))
    last_advantage = jnp.zeros((B,))
    def scan_fn(carry, t):
        next_value, next_adv = carry
        reward = rewards[t]
        value = values[t]
        mask = 1.0 - dones[t]
        delta = reward + gamma * next_value * mask - value
        adv = delta + gamma * lam * next_adv * mask
        return (value, adv), adv
    (_, adv_seq) = jax.lax.scan(scan_fn, (values[-1], last_advantage), jnp.arange(T-1, -1, -1))
    advantages = adv_seq[::-1]
    advantages = normalize_advantages(advantages)
    returns = advantages + values
    return advantages, returns

# --------------------------------------------------------------------
# 7) PPO Loss & KL Computation
# --------------------------------------------------------------------
def compute_loss(params, batch, advantages, returns, clip_eps, c1, c2):
    """
    Compute PPO loss for a batch (a single sequence).
    Returns total loss and auxiliary outputs (policy_loss, value_loss, new_logp_seq).
    """
    def scan_step(carry, inp):
        obs, act = inp
        mean, log_std, val, new_carry = RecurrentActorCritic(
            action_dim=params['action_dim'], hidden_dim=params['hidden_dim']
        ).apply(params['model'], obs, carry)
        new_logp = gaussian_log_prob(mean, log_std, act)
        return new_carry, (new_logp, val)
    init_carry_seq = init_carry(1, params['hidden_dim'])
    new_logp_seq, val_seq = jax.lax.scan(
        scan_step, init_carry_seq, (batch['obs'], batch['action'])
    )[1]
    new_logp_seq = new_logp_seq.squeeze(-1)  # shape (seq_len,)
    val_seq = val_seq.squeeze(-1)            # shape (seq_len,)
    old_logp_seq = batch['logp']              # shape (seq_len,)
    mask = batch['mask']                     # shape (seq_len,)
    ratio = jnp.exp(new_logp_seq - old_logp_seq)
    surr1 = ratio * advantages
    surr2 = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    policy_loss = -jnp.sum(jnp.minimum(surr1, surr2) * mask) / jnp.sum(mask)
    value_loss = jnp.sum(((val_seq - returns) ** 2) * mask) / jnp.sum(mask)
    entropy_bonus = 0.0  # For brevity.
    total_loss = policy_loss + c1 * value_loss - c2 * entropy_bonus
    return total_loss, (policy_loss, value_loss, new_logp_seq)

def compute_kl(new_logp, old_logp, mask):
    return jnp.sum((old_logp - new_logp) * mask) / jnp.sum(mask)

def update_minibatch(params, opt_state, batch, adv, ret, optimizer, clip_eps, c1, c2, target_kl):
    def loss_fn(p):
        loss_val, aux = compute_loss(p, batch, adv, ret, clip_eps, c1, c2)
        return loss_val, aux
    (loss_val, (pol_loss, val_loss, new_logp_seq)), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    kl_val = compute_kl(new_logp_seq, batch['logp'], batch['mask'])
    return new_params, opt_state, kl_val, loss_val

def update_policy(params, opt_state, processed, advantages, returns, optimizer,
                  clip_eps, c1, c2, target_kl, noptepochs):
    num_seq = processed['obs'].shape[0]
    for epoch in range(noptepochs):
        for i in range(num_seq):
            batch = {k: processed[k][i] for k in ['obs', 'action', 'logp', 'mask']}
            mb_adv = advantages[i]
            mb_ret = returns[i]
            params, opt_state, kl, loss_val = update_minibatch(params, opt_state, batch, mb_adv, mb_ret,
                                                               optimizer, clip_eps, c1, c2, target_kl)
            if kl > target_kl:
                print(f"Early stopping: KL {kl:.4f} exceeded target {target_kl} at epoch {epoch}, seq {i}")
                return params, opt_state
    return params, opt_state

# --------------------------------------------------------------------
# 8) Main Training Loop
# --------------------------------------------------------------------
def train_recurrent_ppo(
    env_name: str,
    num_timesteps: int,
    unroll_length: int,
    batch_size: int,
    learning_rate: float,
    reward_cycle_period: int,
    fixed_seq_length: int,
    # PPO hyperparameters:
    clip_eps=0.2,
    c1=0.5,
    c2=0.0,
    gamma=0.99,
    lam=0.95,
    noptepochs=10,
    target_kl=0.02
):
    """
    Main training loop:
      1) Collect rollouts.
      2) Process trajectories into fixed-length sequences.
      3) Compute advantages and returns.
      4) Update policy with multiple epochs and KL early stopping.
    """
    env = envs.get_environment(env_name)
    key = jax.random.PRNGKey(42)
    key, reset_key = jax.random.split(key)
    env_state = env.reset(reset_key)
    
    # Ensure observations are JAX arrays with a batch dimension.
    init_obs = jnp.array(env_state.obs)
    if init_obs.ndim == 1:
        init_obs = init_obs[None, ...]
    
    action_dim = env.action_size
    hidden_dim = 128
    model = RecurrentActorCritic(action_dim=action_dim, hidden_dim=hidden_dim)
    key, init_key = jax.random.split(key)
    carry = init_carry(batch_size, hidden_dim)
    params_model = model.init(init_key, init_obs, carry)
    params = {'model': params_model, 'action_dim': action_dim, 'hidden_dim': hidden_dim}
    
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(params)
    
    reward_cycler = RewardCycler(reward_cycle_period)
    
    total_updates = num_timesteps // unroll_length
    start_time = time.time()
    
    for t in range(total_updates):
        key, rollout_key = jax.random.split(key)
        transitions, new_carry, env_state, key = rollout(params, env_state, carry, rollout_key, unroll_length, reward_cycler)
        carry = new_carry  # update carry for next rollout
        
        # Process trajectories: (num_seq, fixed_seq_length, ...)
        processed = process_trajectories(transitions, fixed_seq_length)
        adv_raw, ret_raw = compute_advantages(transitions, gamma, lam)
        adv_proc = process_scalar_trajectories(adv_raw, transitions['done'], fixed_seq_length)
        ret_proc = process_scalar_trajectories(ret_raw, transitions['done'], fixed_seq_length)
        
        params, opt_state = update_policy(
            params, opt_state, processed, adv_proc, ret_proc,
            optimizer, clip_eps, c1, c2, target_kl, noptepochs
        )
        if t % 100 == 0:
            elapsed = time.time() - start_time
            print(f"Update {t}/{total_updates}, elapsed {elapsed:.2f}s")
    return params

if __name__ == "__main__":
    trained_params = train_recurrent_ppo(
        env_name="humanoid",
        num_timesteps=100000,
        unroll_length=20,
        batch_size=32,
        learning_rate=1e-4,
        reward_cycle_period=100,
        fixed_seq_length=10,
        clip_eps=0.2,
        c1=0.5,
        c2=0.0,
        gamma=0.99,
        lam=0.95,
        noptepochs=10,
        target_kl=0.02
    )
    print("Training complete!")


TypeError: 'int' object is not subscriptable