In [1]:
import jax
import jax.numpy as jnp
from flax import struct

# --- Data Structures ---
@struct.dataclass
class Transition:
    obs: jax.Array
    act: jax.Array
    rew: jax.Array
    next_obs: jax.Array
    done: jax.Array 

@struct.dataclass
class BufferState:
    data: Transition
    ptr: jax.Array      # Shape (NumEnvs,)
    count: jax.Array    # Shape (NumEnvs,)
    capacity: int

# --- The Logic ---
class BatchedReplayBuffer:
    def __init__(self, num_envs, capacity, obs_dim, act_dim):
        self.num_envs = num_envs
        self.capacity = capacity
        # We pre-allocate the memory on GPU
        self.obs_shape = (num_envs, capacity, obs_dim)
        self.act_shape = (num_envs, capacity, act_dim)
        self.scalar_shape = (num_envs, capacity)

    def init(self):
        return BufferState(
            data=Transition(
                obs=jnp.zeros(self.obs_shape),
                act=jnp.zeros(self.act_shape),
                rew=jnp.zeros(self.scalar_shape),
                next_obs=jnp.zeros(self.obs_shape),
                done=jnp.zeros(self.scalar_shape),
            ),
            ptr=jnp.zeros(self.num_envs, dtype=jnp.int32),
            count=jnp.zeros(self.num_envs, dtype=jnp.int32),
            capacity=self.capacity
        )

    def add(self, state: BufferState, batch: Transition):
        # 1. Expand data to broadcast correctly: (NumEnvs, 1, Dim)
        # '...' detects remaining dims, making this part universal
        def fix(x): return x[:, None, ...] if x.ndim > 1 else x[:, None]
        
        # 2. Use vmap to insert into each env's circular buffer
        # This is the "Independent Buffer" logic        
        def insert_slice(buf, update, p):
            return jax.lax.dynamic_update_slice_in_dim(buf, update, start_index=p, axis=0)

        # We vmap over the "Environment" axis (0)
        new_data = Transition(
            obs      = jax.vmap(insert_slice)(state.data.obs     , fix(batch.obs)     , state.ptr),
            act      = jax.vmap(insert_slice)(state.data.act     , fix(batch.act)     , state.ptr),
            rew      = jax.vmap(insert_slice)(state.data.rew     , fix(batch.rew)     , state.ptr),
            next_obs = jax.vmap(insert_slice)(state.data.next_obs, fix(batch.next_obs), state.ptr),
            done     = jax.vmap(insert_slice)(state.data.done    , fix(batch.done)    , state.ptr)
        )

        new_ptr = (state.ptr + 1) % self.capacity
        new_count = jnp.minimum(state.count + 1, self.capacity)
        return state.replace(data=new_data, ptr=new_ptr, count=new_count)

    def sample(self, state: BufferState, rng, batch_size):
        k1, k2 = jax.random.split(rng)
        
        # 1. Randomly pick WHICH environments to sample from
        # Shape: (batch_size,)
        env_indices = jax.random.randint(k1, (batch_size,), 0, self.num_envs)
        
        # 2. Lookup the SPECIFIC count for each chosen environment
        # We use the indices from step 1 to grab the correct 'count' for that env.
        # Shape: (batch_size,) e.g. [20, 5, 100, 20, ...]
        batch_counts = state.count[env_indices]
        
        # 3. Sample time indices using the SPECIFIC bounds
        # jax.random.randint supports "Broadcasting". 
        # Since 'maxval' (batch_counts) matches the requested shape (batch_size,),
        # JAX will automatically use batch_counts[i] as the limit for item i.
        time_indices = jax.random.randint(k2, (batch_size,), 0, batch_counts)

        # 4. Gather the data
        def get(arr): return arr[env_indices, time_indices]
        
        return Transition(
            obs      = get(state.data.obs),
            act      = get(state.data.act),
            rew      = get(state.data.rew),
            next_obs = get(state.data.next_obs),
            done     = get(state.data.done)
        )

In [2]:
import jax
import jax.numpy as jnp
from flax import nnx
import distrax
from typing import Sequence, Callable

# --- Helper: The Universal MLP Builder ---
# Matches Spinning Up's 'mlp' function but handles NNX RNGs
def build_mlp(
    sizes: Sequence[int], 
    activation: Callable, 
    rngs: nnx.Rngs, 
    output_activation: Callable = None
):
    layers = []
    for i in range(len(sizes) - 1):
        # Add Linear Layer
        layers.append(nnx.Linear(
            sizes[i], sizes[i+1], 
            kernel_init=nnx.initializers.orthogonal(1.414),
            bias_init=nnx.initializers.constant(0.0),
            rngs=rngs
            )
        )
        if i < len(sizes) - 2:
            layers.append(activation)
        elif output_activation is not None:
            layers.append(output_activation)
            
    return nnx.Sequential(*layers)


# --- 1. The Flexible Critic ---
class Critic(nnx.Module):
    def __init__(
        self, 
        obs_dim: int, 
        act_dim: int, 
        hidden_sizes: Sequence[int] = (256, 256), 
        activation: Callable = nnx.relu,
        rngs: nnx.Rngs = None
    ):
        # Input to Q-net is Obs + Act
        input_dim = obs_dim + act_dim
        
        # Full architecture: [Input, ...Hidden..., 1]
        # We append [1] because Q-function outputs a single scalar value
        layer_sizes = [input_dim] + list(hidden_sizes) + [1]
        self.net1 = build_mlp(sizes=layer_sizes, activation=activation, rngs=rngs, output_activation=None)
        self.net2 = build_mlp(sizes=layer_sizes, activation=activation, rngs=rngs, output_activation=None)

    def __call__(self, obs, act):
        x = jnp.concatenate([obs, act], axis=-1)
        # Squeeze output to be (Batch,) instead of (Batch, 1)
        return self.net1(x).squeeze(-1), self.net2(x).squeeze(-1)
    

class Actor(nnx.Module):
    def __init__(self, obs_dim, act_dim, hidden_sizes=(256, 256), rngs=None):
        # 1. Base Network
        self.net = build_mlp([obs_dim] + list(hidden_sizes), nnx.relu, rngs, nnx.relu)
        
        # 2. Heads
        last_size = hidden_sizes[-1]
        self.mu_layer = nnx.Linear(last_size, act_dim, 
                                   kernel_init=nnx.initializers.orthogonal(0.01), rngs=rngs)
        self.log_std_layer = nnx.Linear(last_size, act_dim, 
                                        kernel_init=nnx.initializers.orthogonal(0.01), rngs=rngs)
        self.act_limit = 1.0 # Standard for Brax/Gym

    def __call__(self, obs):
        """Returns the Distribution object (used for Training Loss)"""
        x = self.net(obs)
        mu = self.mu_layer(x)
        log_std = self.log_std_layer(x)
        log_std = jnp.clip(log_std, -20, 2)
        
        base_dist = distrax.MultivariateNormalDiag(mu, jnp.exp(log_std))
        return distrax.Transformed(base_dist, distrax.Block(distrax.Tanh(), ndims=1))

    def get_deterministic_action(self, obs):
        """Used for Evaluation (No Noise)"""
        x = self.net(obs)
        mu = self.mu_layer(x)
        # We just squash the mean directly. No sampling.
        return jnp.tanh(mu) * self.act_limit

    def get_stochastic_action(self, obs, key):
        """Used for Rollouts (With Noise)"""
        dist = self(obs)
        # Sample and scale (Distrax Tanh handles the squashing)
        # We act_limit is usually 1.0, but good to keep explicit
        return dist.sample(seed=key) * self.act_limit

In [3]:
import jax
import jax.numpy as jnp
from flax import nnx
import optax

# --- 1. The Actor Loss (Policy Optimization) ---
def actor_loss_fn(actor: Actor, critic: Critic, batch: Transition, alpha, rng):
    # 1. Sample actions from the CURRENT policy
    # We need a key here because the policy is stochastic (Gaussian)
    dist = actor(batch.obs)
    pi_action, log_prob = dist.sample_and_log_prob(seed=rng)
    
    # 2. Get Q-values for these NEW actions
    # Note: We use the 'critic' (not target) to grade the actor
    q1, q2 = critic(batch.obs, pi_action)
    min_q = jnp.minimum(q1, q2)
    
    # 3. SAC Objective: Maximize (MinQ - alpha * LogProb)
    # Since we are minimizing loss, we flip the signs:
    # Minimize: alpha * LogProb - MinQ
    loss = (alpha * log_prob - min_q).mean()
    
    return loss, -log_prob.mean() # Return entropy for logging

# --- 2. The Critic Loss (Q-Learning) ---
def critic_loss_fn(critic: Critic, target_critic: Critic, actor: Actor, batch: Transition, alpha, rng):
    # 1. Generate Target Actions (from Next State)
    # We don't want gradients flowing through the target generation!
    dist = actor(batch.next_obs)
    next_action, next_log_prob = dist.sample_and_log_prob(seed=rng)
    
    # 2. Target Q-Values (Double Q-Learning)
    # Use target_critic weights
    target_q1, target_q2 = target_critic(batch.next_obs, next_action)
    target_min_q = jnp.minimum(target_q1, target_q2)
    
    # 3. The Bellman Backup (Soft Q)
    # Target = R + gamma * (1 - D) * (TargetQ - alpha * TargetEntropy)
    gamma = 0.99
    target_y = batch.rew + gamma * (1 - batch.done) * (target_min_q - alpha * next_log_prob)
    
    # 4. Current Q-Values
    # We execute both heads on the current batch
    current_q1, current_q2 = critic(batch.obs, batch.act)
    
    # 5. MSE Loss
    loss_q1 = ((current_q1 - target_y) ** 2).mean()
    loss_q2 = ((current_q2 - target_y) ** 2).mean()
    
    return loss_q1 + loss_q2, jnp.mean(target_min_q) # Log Q-vals

# --- 3. The Update Step (Combining Everything) ---
# This is the function that runs inside the Scan Loop
def train_step(
    actor: Actor, critic: Critic, target_critic: Critic,       # The Models
    actor_opt: nnx.Optimizer, critic_opt: nnx.Optimizer,       # The Optimizers
    batch: Transition,                                         # The Data
    key,                                                       # The Master Key
    alpha=0.2,                                                 # Fixed Entropy Temp (Simplified)
    polyak=0.995                                               # Target Update Rate
):
    # A. Split Keys for the two stochastic operations
    # key -> (next_key, key_for_actor_loss, key_for_critic_loss)
    next_key, k1, k2 = jax.random.split(key, 3)
    
    # B. Update Critic
    # nnx.value_and_grad gives us both the loss value (for logs) and gradients
    (c_loss, c_log), c_grads = nnx.value_and_grad(critic_loss_fn, has_aux=True)(
        critic, target_critic, actor, batch, alpha, k1
    )
    critic_opt.update(critic, c_grads)

    # C. Update Actor
    # Note: Actor update depends on Critic, so we do it second
    (a_loss, entropy), a_grads = nnx.value_and_grad(actor_loss_fn, has_aux=True)(
        actor, critic, batch, alpha, k2
    )
    actor_opt.update(actor, a_grads)
    
    # D. Update Target Networks (Polyak Averaging)
    # Standard JAX tree map: new_target = polyak * target + (1-polyak) * source
    # We access the parameters via 'nnx.state(model, nnx.Param)'
    
    # Helper to smooth weights
    def soft_update(target_node, source_node):
        return polyak * target_node + (1.0 - polyak) * source_node
        
    # We update the state of target_critic IN PLACE (conceptually)
    # nnx.update performs the replacement safely
    current_params = nnx.state(critic, nnx.Param)
    target_params = nnx.state(target_critic, nnx.Param)
    
    new_target_params = jax.tree.map(soft_update, target_params, current_params)
    nnx.update(target_critic, new_target_params)

    # E. Return Logs and Key
    metrics = {
        "loss_critic": c_loss,
        "loss_actor": a_loss,
        "q_val": c_log,
        "entropy": entropy
    }
    return next_key, metrics

In [None]:
import os
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax
import wandb
from collections import namedtuple

# Playground Imports
from mujoco_playground import registry
from brax.envs import Wrapper, State

# --- 1. CONFIGURATION ---
jax.config.update("jax_enable_x64", True)
os.environ['JAX_DEFAULT_MATMUL_PRECISION'] = 'highest'

ENV_NAME = 'CartpoleSwingup'
NUM_ENVS = 2048
TOTAL_STEPS = 50_000
BATCH_SIZE = 256
HIDDEN_SIZES = (256, 256)
LR = 3e-4
ALPHA = 0.2
WARMUP_STEPS = 100
LOG_EVERY = 10

# --- 2. ADAPTERS & HELPERS ---
_PseudoStateBase = namedtuple('PseudoState', ['data', 'obs', 'reward', 'done', 'metrics', 'info'])

class PseudoState(_PseudoStateBase):
    def replace(self, **kwargs):
        return self._replace(**kwargs)

class BraxAdapter(Wrapper):
    def reset(self, rng):
        state = self.env.reset(rng)
        return self._to_brax_state(state)

    def step(self, state, action):
        mock_state = PseudoState(
            data=state.pipeline_state,
            obs=state.obs,
            reward=state.reward,
            done=state.done,
            metrics=state.metrics,
            info=state.info
        )
        next_state = self.env.step(mock_state, action)
        return self._to_brax_state(next_state)

    def _to_brax_state(self, play_state):
        return State(
            pipeline_state=play_state.data,
            obs=play_state.obs,
            reward=play_state.reward,
            done=play_state.done,
            metrics=play_state.metrics,
            info=play_state.info
        )

# --- 3. THE CONTAINER (To keep scan clean) ---
# We use this to bundle everything that changes during training
class AgentState(nnx.Variable):
    # This is just a dummy class to hold references if we wanted
    # But simpler is just a Python Dataclass or Tuple.
    # Let's use a simple dictionary-like structure for JAX to carry.
    pass

import orbax.checkpoint as ocp
# --- 3.5 CHECKPOINTING HELPERS ---
def create_checkpoint_manager(base_dir, max_to_keep=3):
    abs_path = os.path.abspath(base_dir)
    options = ocp.CheckpointManagerOptions(max_to_keep=max_to_keep, create=True)
    checkpointer = ocp.StandardCheckpointer()
    return ocp.CheckpointManager(abs_path, checkpointer, options=options)

def save_ckpt(manager, step, actor, critic, target_critic, actor_opt, critic_opt):
    # Extract state from NNX modules
    payload = {
        'actor': nnx.state(actor),
        'critic': nnx.state(critic),
        'target_critic': nnx.state(target_critic),
        'actor_opt': nnx.state(actor_opt),
        'critic_opt': nnx.state(critic_opt),
    }
    # Save (blocking)
    print(f"Saving checkpoint step {step}...")
    manager.save(step, payload)

def restore_ckpt(manager, actor, critic, target_critic, actor_opt, critic_opt):
    latest_step = manager.latest_step()
    if latest_step is None:
        print("No checkpoint found. Starting fresh.")
        return 0

    print(f"Restoring from step {latest_step}...")
    
    # Create structure template
    target_payload = {
        'actor': nnx.state(actor),
        'critic': nnx.state(critic),
        'target_critic': nnx.state(target_critic),
        'actor_opt': nnx.state(actor_opt),
        'critic_opt': nnx.state(critic_opt),
    }
    
    restored = manager.restore(latest_step, items=target_payload)
    
    # Update objects in-place
    nnx.update(actor, restored['actor'])
    nnx.update(critic, restored['critic'])
    nnx.update(target_critic, restored['target_critic'])
    nnx.update(actor_opt, restored['actor_opt'])
    nnx.update(critic_opt, restored['critic_opt'])
    
    return latest_step


# --- 4. MAIN LOOP ---
def main():

    wandb.init(
        project="jax-sac-playground",
        config={
            "env": ENV_NAME,
            "num_envs": NUM_ENVS,
            "total_steps": TOTAL_STEPS,
            "batch_size": BATCH_SIZE,
            "lr": LR,
            "hidden_sizes": HIDDEN_SIZES
        }
    )

    ckpt_dir = "./sac_checkpoints"
    ckpt_manager = create_checkpoint_manager(ckpt_dir)

    print(f"Loading Playground Env: {ENV_NAME}")
    base_env = registry.load(ENV_NAME)
    env = BraxAdapter(base_env)
    
    jit_reset = jax.jit(jax.vmap(env.reset))
    jit_step = jax.jit(jax.vmap(env.step))
    
    obs_dim = env.observation_size
    act_dim = env.action_size
    print(f"Obs: {obs_dim} | Act: {act_dim} | Envs: {NUM_ENVS}")

    # --- INITIALIZATION ---
    key = nnx.Rngs(0)
    actor = Actor(obs_dim, act_dim, HIDDEN_SIZES, rngs=key)
    critic = Critic(obs_dim, act_dim, HIDDEN_SIZES, rngs=key)
    target_critic = Critic(obs_dim, act_dim, HIDDEN_SIZES, rngs=nnx.Rngs(0))

    actor_opt = nnx.Optimizer(actor, optax.adam(LR), wrt=nnx.Param)
    critic_opt = nnx.Optimizer(critic, optax.adam(LR), wrt=nnx.Param)

    buffer = BatchedReplayBuffer(NUM_ENVS, capacity=10_000, obs_dim=obs_dim, act_dim=act_dim)
    buffer_state = buffer.init()

    # --- 5. ROLLOUT STEP (Fixed Signature) ---
    # We pass the models IN through the carry arguments
    def rollout_step(carry, _):
        # UNPACK EVERYTHING
        (env_state, buf_state, key, 
         actor, critic, target_critic, 
         actor_opt, critic_opt) = carry
        
        # 1. Action
        key, act_key = jax.random.split(key)
        act_keys = jax.random.split(act_key, NUM_ENVS)
        # Actor is now a "Tracer" object local to this loop, so we can use it safely
        action = jax.vmap(actor.get_stochastic_action)(env_state.obs, act_keys)
        
        # 2. Step
        next_env_state = jit_step(env_state, action)
        real_next_obs = next_env_state.obs
        
        trans = Transition(
            obs=env_state.obs,
            act=action,
            rew=next_env_state.reward,
            next_obs=real_next_obs,
            done=next_env_state.done
        )
        buf_state = buffer.add(buf_state, trans)
        
        # 3. Train
        batch = buffer.sample(buf_state, key, BATCH_SIZE)
        
        # This function updates the local 'actor/critic' variables in-place
        key, train_metrics = train_step(
            actor, critic, target_critic, 
            actor_opt, critic_opt, 
            batch, key, ALPHA
        )

        avg_reward = next_env_state.reward.mean()

        metrics = {
            **train_metrics,          # q_val, loss_actor, etc.
            "env_reward": avg_reward, # The raw reward signal
            "episode_return": env_state.metrics.get('episode_return', 0.0) # If playground provides it
        }
        
        # REPACK EVERYTHING
        # We must return the modified actor/critic objects
        new_carry = (next_env_state, buf_state, key, 
                     actor, critic, target_critic, 
                     actor_opt, critic_opt)
        
        return new_carry, metrics

    # --- 6. EXECUTION ---
    print("Initializing state...")
    master_key = jax.random.PRNGKey(42)
    master_key, reset_key = jax.random.split(master_key)
    reset_keys = jax.random.split(reset_key, NUM_ENVS)
    env_state = jit_reset(reset_keys)

    # Warmup
    print("Warmup...")
    def warmup_fn(carry, _):
        es, bs, k = carry
        k, ak = jax.random.split(k)
        action = jax.random.uniform(ak, (NUM_ENVS, act_dim), minval=-1, maxval=1)
        nes = jit_step(es, action)
        trans = Transition(es.obs, action, nes.reward, nes.obs, nes.done)
        bs = buffer.add(bs, trans)
        return (nes, bs, k), None
        
    (env_state, buffer_state, master_key), _ = jax.lax.scan(
        warmup_fn, (env_state, buffer_state, master_key), None, length=WARMUP_STEPS
    )

    # Train
    print("Training...")
    steps_per_epoch = 1000
    num_epochs = TOTAL_STEPS // steps_per_epoch
    
    # We construct the MEGA CARRY tuple
    carry = (env_state, buffer_state, master_key, 
             actor, critic, target_critic, 
             actor_opt, critic_opt)

    for epoch in range(num_epochs):
        # scan returns the new tuple with updated models
        carry, metrics_history = jax.lax.scan(
            rollout_step, 
            carry, 
            None, 
            length=steps_per_epoch
        )
        
        # <--- 4. AGGREGATE & LOG
        # We average the 1000 steps to get one data point for WandB
        # jax.tree.map applies jnp.mean to every item in the dict
        avg_metrics = jax.tree.map(lambda x: jnp.mean(x), metrics_history)
        
        # Convert JAX arrays to standard Python floats for WandB
        log_dict = {k: float(v) for k, v in avg_metrics.items()}
        log_dict["step"] = (epoch + 1) * steps_per_epoch
        
        wandb.log(log_dict)
        
        print(f"Epoch {epoch}: Reward = {log_dict['env_reward']:.4f}, Q-Val = {log_dict['q_val']:.2f}")

        # Unpack the current models from carry to save them
        current_step = (epoch + 1) * steps_per_epoch
        
        # Note: carry[3] is actor, carry[4] is critic, etc. based on your tuple
        (_, _, _, cur_actor, cur_critic, cur_target, cur_act_opt, cur_crit_opt) = carry
        
        # Save every 5 epochs (or whatever frequency you prefer)
        if (epoch + 1) % LOG_EVERY == 0:
            print("logging")
            save_ckpt(
                ckpt_manager, 
                current_step, 
                cur_actor, cur_critic, cur_target, cur_act_opt, cur_crit_opt
            )

    (_, _, _, cur_actor, cur_critic, cur_target, cur_act_opt, cur_crit_opt) = carry
    save_ckpt(
        ckpt_manager, 
        TOTAL_STEPS, 
        cur_actor, cur_critic, cur_target, cur_act_opt, cur_crit_opt
    )

if __name__ == "__main__":
    main()

[34m[1mwandb[0m: Currently logged in as: [33mnagababa[0m ([33msbp_team[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Loading Playground Env: CartpoleSwingup
Obs: 5 | Act: 1 | Envs: 2048
Initializing state...
Warmup...
Training...
Epoch 0: Reward = 0.0322, Q-Val = 0.16
Epoch 1: Reward = 0.0771, Q-Val = 0.93
Epoch 2: Reward = 0.1122, Q-Val = 1.86
Epoch 3: Reward = 0.1377, Q-Val = 2.90
Epoch 4: Reward = 0.1691, Q-Val = 4.01
Epoch 5: Reward = 0.2072, Q-Val = 5.20
Epoch 6: Reward = 0.2601, Q-Val = 6.53
Epoch 7: Reward = 0.2974, Q-Val = 7.98
Epoch 8: Reward = 0.3192, Q-Val = 9.46
Epoch 9: Reward = 0.3513, Q-Val = 10.95
Epoch 10: Reward = 0.3906, Q-Val = 12.92
Epoch 11: Reward = 0.4161, Q-Val = 15.24
Epoch 12: Reward = 0.4476, Q-Val = 17.67
Epoch 13: Reward = 0.4884, Q-Val = 20.25
Epoch 14: Reward = 0.5111, Q-Val = 22.87
Epoch 15: Reward = 0.5498, Q-Val = 25.66
Epoch 16: Reward = 0.5794, Q-Val = 28.43
Epoch 17: Reward = 0.6184, Q-Val = 31.31
Epoch 18: Reward = 0.6339, Q-Val = 34.24
Epoch 19: Reward = 0.6499, Q-Val = 37.25
Epoch 20: Reward = 0.6464, Q-Val = 40.04
Epoch 21: Reward = 0.6450, Q-Val = 42.70
Epoc

In [None]:
import os
os.environ["MUJOCO_GL"] = "egl"
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax
import mujoco
import numpy as np
import imageio
import orbax.checkpoint as ocp
from collections import namedtuple
from mujoco_playground import registry
from brax.envs import Wrapper, State
import absl.logging

# --- 1. CONFIGURATION ---
ENV_NAME = 'CartpoleSwingup'
CHECKPOINT_DIR = "./sac_checkpoints" # Must match your training checkpoint directory
HIDDEN_SIZES = (256, 256) # Must match training
LR = 3e-4

# --- 3. ADAPTERS (Standard) ---
_PseudoStateBase = namedtuple('PseudoState', ['data', 'obs', 'reward', 'done', 'metrics', 'info'])
class PseudoState(_PseudoStateBase):
    def replace(self, **kwargs): return self._replace(**kwargs)

class BraxAdapter(Wrapper):
    def reset(self, rng):
        state = self.env.reset(rng)
        return self._to_brax_state(state)
    def step(self, state, action):
        mock_state = PseudoState(
            data=state.pipeline_state,
            obs=state.obs,
            reward=state.reward,
            done=state.done,
            metrics=state.metrics,
            info=state.info
        )
        next_state = self.env.step(mock_state, action)
        return self._to_brax_state(next_state)
    def _to_brax_state(self, s):
        return State(pipeline_state=s.data, obs=s.obs, reward=s.reward, done=s.done, metrics=s.metrics, info=s.info)

# --- 4. CHECKPOINT LOADER ---
# We reuse the training helpers `create_checkpoint_manager` and `restore_ckpt`
# defined in the previous cell.

import orbax.checkpoint as ocp
# --- 3.5 CHECKPOINTING HELPERS ---
def create_checkpoint_manager(base_dir, max_to_keep=3):
    abs_path = os.path.abspath(base_dir)
    options = ocp.CheckpointManagerOptions(max_to_keep=max_to_keep, create=True)
    checkpointer = ocp.StandardCheckpointer()
    return ocp.CheckpointManager(abs_path, checkpointer, options=options)


def restore_ckpt(manager, actor, critic, target_critic, actor_opt, critic_opt):
    latest_step = manager.latest_step()
    if latest_step is None:
        print("No checkpoint found. Starting fresh.")
        return 0

    print(f"Restoring from step {latest_step}...")
    
    # Create structure template
    target_payload = {
        'actor': nnx.state(actor),
        'critic': nnx.state(critic),
        'target_critic': nnx.state(target_critic),
        'actor_opt': nnx.state(actor_opt),
        'critic_opt': nnx.state(critic_opt),
    }
    
    restored = manager.restore(latest_step, items=target_payload)
    
    # Update objects in-place
    nnx.update(actor, restored['actor'])
    nnx.update(critic, restored['critic'])
    nnx.update(target_critic, restored['target_critic'])
    nnx.update(actor_opt, restored['actor_opt'])
    nnx.update(critic_opt, restored['critic_opt'])

def main():
    print(f"Initializing Env: {ENV_NAME}")
    raw_env = registry.load(ENV_NAME)
    env = BraxAdapter(raw_env)
    obs_dim = env.observation_size
    act_dim = env.action_size

    # --- INIT ACTOR & CRITICS, OPTIMIZERS ---
    key = nnx.Rngs(0)
    actor = Actor(obs_dim, act_dim, HIDDEN_SIZES, rngs=key)
    critic = Critic(obs_dim, act_dim, HIDDEN_SIZES, rngs=key)
    target_critic = Critic(obs_dim, act_dim, HIDDEN_SIZES, rngs=nnx.Rngs(0))

    actor_opt = nnx.Optimizer(actor, optax.adam(LR), wrt=nnx.Param)
    critic_opt = nnx.Optimizer(critic, optax.adam(LR), wrt=nnx.Param)

    # --- SETUP CHECKPOINT MANAGER (reuse training helper) ---
    ckpt_manager = create_checkpoint_manager(CHECKPOINT_DIR)
    latest_step = restore_ckpt(ckpt_manager, actor, critic, target_critic, actor_opt, critic_opt)
    if latest_step == 0:
        raise ValueError(f"No checkpoints found in {os.path.abspath(CHECKPOINT_DIR)}")
    print(f"Restored checkpoint from step {latest_step}.")

    # --- SETUP RENDERING ---
    mj_model = raw_env.mj_model
    mj_data = mujoco.MjData(mj_model)
    renderer = mujoco.Renderer(mj_model, height=480, width=640)

    jit_reset = jax.jit(env.reset)
    jit_step = jax.jit(env.step)

    @jax.jit
    def get_action(obs):
        obs_batched = obs[None, :] 
        action_batched = actor.get_deterministic_action(obs_batched)
        return action_batched[0]

    print("Simulating...")
    frames = []
    rng = jax.random.PRNGKey(42)
    state = jit_reset(rng)

    for i in range(500):
        action = get_action(state.obs)
        state = jit_step(state, action)
        mj_data.qpos = np.array(state.pipeline_state.qpos)
        mj_data.qvel = np.array(state.pipeline_state.qvel)
        mujoco.mj_forward(mj_model, mj_data)
        renderer.update_scene(mj_data)
        frames.append(renderer.render())
        if state.done:
            print(f"Episode ended at step {i}")
            # Optional: state = jit_reset(rng)

    output_name = f'{ENV_NAME}_sac.mp4'
    imageio.mimsave(output_name, frames, fps=60)
    print(f"Video saved to: {output_name}")

if __name__ == "__main__":
    main()



Initializing Env: CartpoleSwingup
Restoring from step 50000...
Restored checkpoint from step None.
Simulating...


  self.pid = _fork_exec(


Video saved to: CartpoleSwingup_sac.mp4
