# Soft Actor-Critic (SAC) for BipedalWalker-v3

This notebook implements a Soft Actor-Critic (SAC) agent to solve the BipedalWalker-v3 environment from Gymnasium.
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy reinforcement learning framework.

In [None]:
!pip install swig
!pip install gymnasium[box2d]

import gymnasium as gym
import numpy as np
import cv2  # Added for parallel window rendering
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import torch.nn.functional as F
from collections import deque
import random
import os
import matplotlib.pyplot as plt
# Change to standard tqdm to avoid notebook widget errors
from tqdm import tqdm

from google.colab import drive
drive.mount('/content/drive')

In [41]:
# Pick the fastest available backend: Apple-silicon MPS (local Mac training)
# first, then CUDA, falling back to CPU. Checks run lazily and in order, so
# CUDA is only probed when MPS is unavailable.
_backend_checks = (
    ("mps", torch.backends.mps.is_available),
    ("cuda", torch.cuda.is_available),
)
device = torch.device(
    next((backend for backend, is_available in _backend_checks if is_available()), "cpu")
)
print(f"Using device: {device}")

if device.type == "cuda":
    # Best-effort GPU report: any failure is printed, never raised.
    try:
        torch.cuda.set_device(0)
        gpu_name = torch.cuda.get_device_name(0)
        props = torch.cuda.get_device_properties(0)
        print(f"CUDA GPU: {gpu_name} | Memory: {props.total_memory / 1024**3:.1f} GB")
    except Exception as e:
        print(f"GPU info not available: {e}")

# Set random seeds for reproducibility
def set_seed(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True

set_seed(42)

Using device: mps


## Obstacles Environment Wrapper

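Enhanced BipedalWalker with procedurally generated obstacles: static rectangular boxes of random width and height, placed on the terrain surface at regular intervals. A higher `difficulty` gives more boxes, spaced more closely. The obstacles increase difficulty and encourage adaptive gait strategies.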


In [43]:
import math
import numpy as np
import gymnasium as gym
from Box2D.b2 import (
    polygonShape,
    edgeShape,
)

class ObstacleBipedalWrapper(gym.Wrapper):
    """
    BipedalWalker with static box obstacles generated directly on the terrain surface.

    On every ``reset`` the wrapped env regenerates its terrain. This wrapper then
    destroys the boxes it created for the previous episode and places a new
    course: ``int(5 + 5 * difficulty)`` boxes, starting at x = 10 and spaced
    ``8 - 2 * difficulty`` apart, each 0.4-0.8 wide and 0.4-1.2 tall.

    Args:
        env: BipedalWalker environment (possibly already wrapped). The wrapper
            reads ``world``, ``terrain_x``, ``terrain_y`` and ``terrain_poly``
            from ``env.unwrapped``.
        difficulty: Controls obstacle count and spacing (see above).
        seed: If not None, every episode gets the same box sizes. They are drawn
            from a private generator, so the global NumPy RNG is never reseeded.
            If None, sizes come from the global ``np.random`` stream.
    """

    # Fallback x-spacing between terrain samples. It is used only when the base
    # env exposes no ``terrain_x``. NOTE(review): assumed to equal gymnasium's
    # BipedalWalker TERRAIN_STEP (14 / SCALE with SCALE = 30) — confirm.
    TERRAIN_STEP = 14 / 30.0

    # Fill colour for obstacle polygons appended to ``terrain_poly``.
    # NOTE(review): assumes the base renderer takes 0-255 RGB like the
    # pygame-based gymnasium BipedalWalker; 0-1 floats would not render as intended.
    OBSTACLE_COLOR = (204, 76, 76)

    def __init__(self, env, difficulty=0.5, seed=None):
        super().__init__(env)
        self.difficulty = difficulty
        self.seed = seed
        self.world = None  # Box2D world of the base env, refreshed on reset
        self.terrain_poly = None  # base env's render polygon list, refreshed on reset
        self.obstacle_bodies = []  # Box2D bodies created by this wrapper
        self.obstacle_polys = []  # (vertices, color) entries added to terrain_poly

    # ------------------------------------------------------------
    # Core lifecycle
    # ------------------------------------------------------------

    def reset(self, *, seed=None, options=None):
        """Reset the base env, then rebuild the obstacle course on its new terrain."""
        # The base env must be reset first so that its terrain arrays exist.
        obs, info = self.env.reset(seed=seed, options=options)

        self._extract_world()
        self._clear_obstacles()
        self._spawn_obstacle_course()

        return obs, info

    def step(self, action):
        return self.env.step(action)

    def render(self):
        """Render the base env, also drawing obstacles onto a legacy ``viewer`` if one exists.

        If the base env has no ``viewer`` attribute, the base frame is returned
        unchanged. Obstacles presumably still appear in it, because they are
        registered in the base env's ``terrain_poly`` on reset.
        """
        mode = getattr(self.env.unwrapped, "render_mode", None)
        base_frame = self.env.render()

        viewer = getattr(self.env.unwrapped, "viewer", None)
        if viewer is None:
            return base_frame

        self._draw_obstacles_to_viewer(viewer)

        if mode == "rgb_array":
            return viewer.render(return_rgb_array=True)
        else:
            viewer.render()
            return base_frame

    # ------------------------------------------------------------
    # World & cleanup
    # ------------------------------------------------------------

    def _extract_world(self):
        self.world = self.env.unwrapped.world
        self.terrain_poly = getattr(self.env.unwrapped, "terrain_poly", None)

    def _clear_obstacles(self):
        """Destroy obstacle bodies and drop their render polygons (best effort)."""
        if self.world is not None:
            for body in self.obstacle_bodies:
                try:
                    self.world.DestroyBody(body)
                except Exception:
                    # Deliberately best-effort: a body that is already gone must
                    # not stop the reset.
                    pass
            self.obstacle_bodies.clear()

        if self.terrain_poly is not None and self.obstacle_polys:
            for poly in self.obstacle_polys:
                try:
                    self.terrain_poly.remove(poly)
                except ValueError:
                    # The base env may have replaced terrain_poly with a new
                    # list on reset, so the old entry is simply not present.
                    pass
            self.obstacle_polys.clear()

    # ------------------------------------------------------------
    # Obstacle generation
    # ------------------------------------------------------------

    def _spawn_obstacle_course(self):
        """Generate boxes that sit exactly on the terrain surface."""
        # A fixed seed should give the same layout every episode. Reseeding the
        # *global* NumPy RNG here would also rewind every other consumer of it
        # (e.g. replay-buffer sampling) at each reset, so use a private generator.
        rng = np.random.default_rng(self.seed) if self.seed is not None else np.random

        # Start after the initial flat zone at the beginning of the terrain.
        x_pos = 10.0
        spacing = 8.0 - (2.0 * self.difficulty)
        n_obs = int(5 + 5 * self.difficulty)

        for _ in range(n_obs):
            width = 0.4 + rng.random() * 0.4
            height = 0.4 + rng.random() * 0.8
            self._create_object_on_surface(x_pos, width=width, height=height)
            x_pos += spacing

    def _terrain_height_at(self, x_pos):
        """Return the terrain surface height at ``x_pos``, or None if it is off the terrain.

        When the base env exposes ``terrain_x``, the height is linearly
        interpolated between the surrounding terrain samples. Otherwise the
        nearest lower sample at spacing ``TERRAIN_STEP`` is used.
        """
        unwrapped = self.env.unwrapped
        terrain_y = getattr(unwrapped, "terrain_y", None)
        if terrain_y is None or len(terrain_y) == 0:
            return None

        terrain_x = getattr(unwrapped, "terrain_x", None)
        if terrain_x is not None and len(terrain_x) == len(terrain_y):
            if x_pos > terrain_x[-1]:
                return None
            return float(np.interp(x_pos, terrain_x, terrain_y))

        idx = int(x_pos / self.TERRAIN_STEP)
        if idx >= len(terrain_y):
            return None
        return float(terrain_y[idx])

    def _create_object_on_surface(self, x_pos, width, height):
        """Find the terrain height at ``x_pos`` and place a static box resting on it."""
        y_surface = self._terrain_height_at(x_pos)
        if y_surface is None:
            return

        hw, hh = width / 2, height / 2

        # Box2D positions a body by its centre. To rest on the surface the
        # centre must be at surface_y + height / 2.
        body = self.world.CreateStaticBody(position=(x_pos, y_surface + hh))
        body.CreateFixture(shape=polygonShape(box=(hw, hh)), friction=1.0)
        self.obstacle_bodies.append(body)

        # Matching polygon for the renderer, in world coordinates.
        vertices = [
            (x_pos - hw, y_surface),
            (x_pos + hw, y_surface),
            (x_pos + hw, y_surface + height),
            (x_pos - hw, y_surface + height),
        ]
        self._register_render_poly(vertices, color=self.OBSTACLE_COLOR)

    def _register_render_poly(self, vertices, color):
        if self.terrain_poly is None:
            return
        poly = (vertices, color)
        self.terrain_poly.append(poly)
        self.obstacle_polys.append(poly)

    def _draw_obstacles_to_viewer(self, viewer):
        """Direct drawing for environments with active (legacy) viewers."""
        for body in self.obstacle_bodies:
            for fixture in body.fixtures:
                shape = fixture.shape
                if isinstance(shape, polygonShape):
                    verts = [body.transform * v for v in shape.vertices]
                    # Slightly darker tone for visibility.
                    viewer.draw_polygon(verts, color=(0.5, 0.2, 0.2))

## Replay Buffer

The replay buffer stores experience tuples (state, action, reward, next_state, done) to be sampled during training.

In [44]:
class ReplayBuffer:
    """Fixed-capacity FIFO ring buffer of (state, action, reward, next_state, done) transitions.

    All fields are stored in preallocated float32 NumPy arrays. Once the buffer
    is full, new transitions overwrite the oldest ones.

    Args:
        state_dim: Length of each observation vector.
        action_dim: Length of each action vector.
        buffer_size: Maximum number of transitions kept.
    """

    def __init__(self, state_dim, action_dim, buffer_size=int(1e6)):
        self.buffer_size = buffer_size
        self.ptr = 0  # index of the next slot to write
        self.size = 0  # number of valid transitions currently stored

        self.state = np.zeros((buffer_size, state_dim), dtype=np.float32)
        self.action = np.zeros((buffer_size, action_dim), dtype=np.float32)
        self.reward = np.zeros((buffer_size, 1), dtype=np.float32)
        self.next_state = np.zeros((buffer_size, state_dim), dtype=np.float32)
        self.done = np.zeros((buffer_size, 1), dtype=np.float32)

    def add(self, state, action, reward, next_state, done):
        """Store one transition, or a batch of transitions, overwriting the oldest when full.

        A 1-D ``state`` is treated as a single transition with scalar ``reward``
        and ``done``. Otherwise the first axis of every argument is the batch
        axis, e.g. one row per sub-environment of a vector env. Arguments may be
        NumPy arrays, lists or scalars.
        """
        state = np.asarray(state, dtype=np.float32)
        action = np.asarray(action, dtype=np.float32)
        next_state = np.asarray(next_state, dtype=np.float32)
        if state.ndim == 1:
            state = state[None, :]
            action = action.reshape(1, -1)
            next_state = next_state[None, :]
        reward = np.asarray(reward, dtype=np.float32).reshape(-1, 1)
        done = np.asarray(done, dtype=np.float32).reshape(-1, 1)

        batch_size = len(state)
        if batch_size == 0:
            return

        if batch_size > self.buffer_size:
            # Only the newest `buffer_size` rows can survive. Drop the rest, but
            # advance ptr as if they had been written, so that ptr still ends at
            # (ptr + batch_size) % buffer_size.
            skip = batch_size - self.buffer_size
            state, action, reward, next_state, done = (
                arr[skip:] for arr in (state, action, reward, next_state, done)
            )
            self.ptr = (self.ptr + skip) % self.buffer_size
            batch_size = self.buffer_size

        # Ring-buffer slots for this batch, wrapping past the end of the arrays.
        idx = (self.ptr + np.arange(batch_size)) % self.buffer_size
        self.state[idx] = state
        self.action[idx] = action
        self.reward[idx] = reward
        self.next_state[idx] = next_state
        self.done[idx] = done

        self.ptr = (self.ptr + batch_size) % self.buffer_size
        self.size = min(self.size + batch_size, self.buffer_size)

    def sample(self, batch_size):
        """Draw ``batch_size`` stored transitions uniformly at random (with replacement).

        Uses the global ``np.random`` stream.

        Returns:
            Tuple of float32 arrays ``(state, action, reward, next_state, done)``
            with ``batch_size`` rows each.

        Raises:
            ValueError: if the buffer is empty.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty ReplayBuffer")
        ind = np.random.randint(0, self.size, size=batch_size)
        return (
            self.state[ind],
            self.action[ind],
            self.reward[ind],
            self.next_state[ind],
            self.done[ind]
        )

    def __len__(self):
        return self.size

## Network Architectures

We define the Actor and Critic networks. The Actor outputs the mean and log standard deviation of a Gaussian; sampled actions are squashed through `tanh` into $[-1, 1]$. The Critic estimates the Q-value $Q(s, a)$ for a given state-action pair.

In [45]:
class ActorNetwork(nn.Module):
    """SAC policy: a Gaussian over pre-squash actions, squashed through tanh into [-1, 1].

    Args:
        state_dim: Observation vector length.
        action_dim: Action vector length.
        hidden_dim: Width of the two hidden ReLU layers.
        log_std_min, log_std_max: Clamp range for the predicted log standard
            deviation (numerical stability).
    """

    # log(2), used by the stable tanh log-det-Jacobian below.
    _LOG_2 = 0.6931471805599453

    def __init__(self, state_dim, action_dim, hidden_dim=256, log_std_min=-20.0, log_std_max=2.0):
        super(ActorNetwork, self).__init__()
        # Layer names and creation order are unchanged so saved state_dicts
        # still load.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mu = nn.Linear(hidden_dim, action_dim)
        # log_std is predicted per state rather than being a free parameter.
        self.log_std_linear = nn.Linear(hidden_dim, action_dim)
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

    def forward(self, state):
        """Return ``(mu, log_std)`` of the pre-squash Gaussian, each with last dim ``action_dim``."""
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mu = self.mu(x)
        log_std = torch.clamp(self.log_std_linear(x), self.log_std_min, self.log_std_max)
        return mu, log_std

    def sample(self, state):
        """Draw a reparameterised action and its log-probability.

        Returns:
            action: ``tanh(x)`` with ``x ~ N(mu, std)``, values in [-1, 1].
            log_prob: log density of ``action``, summed over the action
                dimensions, with a trailing dim of size 1.
        """
        mu, log_std = self.forward(state)
        dist = Normal(mu, log_std.exp())
        x_t = dist.rsample()  # reparameterised so gradients flow to mu/log_std
        action = torch.tanh(x_t)

        # Change of variables: log p(a) = log p(x) - log|d tanh(x)/dx|, where
        # log(1 - tanh(x)^2) == 2 * (log 2 - x - softplus(-2x)).
        # The right-hand form is exact and stays accurate when tanh saturates.
        # log(1 - tanh^2 + eps) would clamp at log(eps) and bias the entropy term.
        log_det_jacobian = 2.0 * (self._LOG_2 - x_t - F.softplus(-2.0 * x_t))
        log_prob = (dist.log_prob(x_t) - log_det_jacobian).sum(-1, keepdim=True)
        return action, log_prob

class CriticNetwork(nn.Module):
    """Q-network: maps a (state, action) pair to a scalar Q-value through a two-hidden-layer ReLU MLP.

    Args:
        state_dim: Observation vector length.
        action_dim: Action vector length.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        joint_dim = state_dim + action_dim
        # Attribute names and creation order are kept, so seeded initialisation
        # and state_dict keys are unchanged.
        self.fc1 = nn.Linear(joint_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.q_value = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        """Concatenate state and action along dim 1 and return Q(s, a) with a last dim of 1."""
        hidden = torch.cat((state, action), dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.q_value(hidden)

## SAC Agent

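The SAC agent selects actions and performs the gradient updates; the training loop below handles interaction with the environment. The agent maintains:
- the actor;
- two critics, whose minimum forms the TD target (clipped double-Q learning);
- a target network for each critic;
- a learnable entropy temperature $\alpha$.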

In [46]:
class SACAgent:
    """Soft Actor-Critic agent.

    It has a tanh-Gaussian actor, twin Q critics with Polyak-averaged target
    copies, and automatic entropy-temperature (alpha) tuning.

    Args:
        state_dim, action_dim: Observation / action vector lengths.
        action_scale: Stored on the agent; not applied inside this class
            (``select_action`` returns actions in [-1, 1]).
        device: Torch device (or device string) for all networks.
        learning_rate: Adam learning rate shared by actor, critics and alpha.
        gamma: Discount factor.
        tau: Polyak coefficient for the soft target update.
        alpha: Initial entropy temperature (must be > 0); tuned during ``update``.
        batch_size: Minibatch size sampled per update.
        buffer_size: Replay-buffer capacity.
        target_entropy: Entropy target for alpha tuning; defaults to -action_dim.
        hidden_dim: Width of the hidden layers of every network.

    Raises:
        ValueError: if ``alpha`` is not positive.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        action_scale=1.0,
        device="cpu",
        learning_rate=3e-4,
        gamma=0.99,
        tau=0.001,
        alpha=0.2,
        batch_size=1024,
        buffer_size=int(1e6),
        target_entropy=None,
        hidden_dim=256,
    ):
        if alpha <= 0:
            raise ValueError(f"alpha must be positive (it is stored as log(alpha)), got {alpha}")
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.action_scale = action_scale
        self.device = device
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.tau = tau  # Soft target update rate
        self.alpha = alpha
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        # Automatic entropy tuning target: -dim(A) unless overridden
        self.target_entropy = target_entropy if target_entropy is not None else -float(action_dim)

        # Networks
        self.actor = ActorNetwork(state_dim, action_dim, hidden_dim=hidden_dim).to(self.device)
        self.critic1 = CriticNetwork(state_dim, action_dim, hidden_dim=hidden_dim).to(self.device)
        self.critic2 = CriticNetwork(state_dim, action_dim, hidden_dim=hidden_dim).to(self.device)
        self.target_critic1 = CriticNetwork(state_dim, action_dim, hidden_dim=hidden_dim).to(self.device)
        self.target_critic2 = CriticNetwork(state_dim, action_dim, hidden_dim=hidden_dim).to(self.device)

        # Target networks start as exact copies of the online critics.
        self.target_critic1.load_state_dict(self.critic1.state_dict())
        self.target_critic2.load_state_dict(self.critic2.state_dict())

        # Optimizers
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.learning_rate)
        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=self.learning_rate)
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=self.learning_rate)

        # Replay buffer (NumPy ring buffer with batch support)
        self.replay_buffer = ReplayBuffer(state_dim, action_dim, self.buffer_size)

        # Learnable log-temperature, initialised so that exp(log_alpha) == alpha.
        # Starting from log_alpha = 0 instead would silently jump the
        # temperature from `alpha` to ~1.0 after the first update.
        self.log_alpha = torch.tensor(
            [float(np.log(alpha))], dtype=torch.float32, device=self.device, requires_grad=True
        )
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=self.learning_rate)

    def select_action(self, state, deterministic=False):
        """Return an action in [-1, 1] for one state or a batch of states.

        Args:
            state: A single observation (1-D) or a batch of observations (2-D).
            deterministic: If True, return ``tanh(mu)`` instead of sampling.

        Returns:
            np.ndarray of shape (action_dim,) for a single state, otherwise one
            row per input state.
        """
        with torch.no_grad():
            state = np.array(state)
            if state.ndim == 1:
                state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            else:
                state_t = torch.FloatTensor(state).to(self.device)

            if deterministic:
                mu, _ = self.actor(state_t)
                action = torch.tanh(mu).cpu().numpy()
            else:
                action, _ = self.actor.sample(state_t)
                action = action.cpu().numpy()

            # A single input state (ndim == 1) gets a single, flat action back.
            if state.ndim == 1:
                return action.flatten()

            return action

    def update(self):
        """Run one SAC gradient step (critics, actor, temperature, targets) on a replay minibatch.

        Does nothing until the buffer holds at least ``batch_size`` transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return

        states, actions, rewards, next_states, dones = (
            torch.as_tensor(batch, dtype=torch.float32, device=self.device)
            for batch in self.replay_buffer.sample(self.batch_size)
        )

        # --- Critics: regress both onto the clipped double-Q soft target ---
        with torch.no_grad():
            next_actions, next_log_pi = self.actor.sample(next_states)
            next_q = torch.min(
                self.target_critic1(next_states, next_actions),
                self.target_critic2(next_states, next_actions),
            ) - self.alpha * next_log_pi
            target_q = rewards + (1 - dones) * self.gamma * next_q

        for critic, optimizer in (
            (self.critic1, self.critic1_optimizer),
            (self.critic2, self.critic2_optimizer),
        ):
            critic_loss = F.mse_loss(critic(states, actions), target_q)
            optimizer.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(critic.parameters(), 1.0)
            optimizer.step()

        # --- Actor: maximise min-Q plus entropy ---
        actions_pred, log_pi = self.actor.sample(states)
        q = torch.min(self.critic1(states, actions_pred), self.critic2(states, actions_pred))
        actor_loss = (self.alpha * log_pi - q).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0)
        self.actor_optimizer.step()

        # Defensive default, in case the attribute is missing on this object.
        if not hasattr(self, 'target_entropy'):
            self.target_entropy = -float(self.action_dim)

        # --- Temperature: drive policy entropy towards target_entropy ---
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        # Detached, so later actor losses do not backpropagate into log_alpha.
        self.alpha = self.log_alpha.detach().exp()

        # --- Targets ---
        self._soft_update(self.critic1, self.target_critic1)
        self._soft_update(self.critic2, self.target_critic2)

    def _soft_update(self, source, target):
        """Polyak-average ``source`` into ``target``: theta' <- tau * theta + (1 - tau) * theta'."""
        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.mul_(1.0 - self.tau).add_(param, alpha=self.tau)

## Training Loop

We train the agent in the environment. We'll also log rewards and save checkpoints.

In [47]:
import functools
import logging
import sys
from datetime import datetime
import os
import csv
import time
from collections import deque

import numpy as np
import cv2
import torch
import gymnasium as gym

# Define a custom handler to work with tqdm
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that prints records with ``tqdm.write``.

    Routing output through tqdm keeps log lines from tearing an active
    progress bar. Any failure is reported through ``Handler.handleError``.
    """

    def __init__(self, level=logging.NOTSET):
        logging.Handler.__init__(self, level)

    def emit(self, record):
        try:
            tqdm.write(self.format(record))
            self.flush()
        except Exception:
            self.handleError(record)

def setup_logger(log_dir="logs"):
    """Create or reconfigure the "BipedalWalker" logger.

    INFO-and-above records are written to a new timestamped file
    ``<log_dir>/training_<YYYYmmdd_HHMMSS>.log``. They are also sent to the
    console through ``TqdmLoggingHandler`` so they do not break tqdm bars.
    Safe to call repeatedly (e.g. when re-running a notebook cell): handlers
    from a previous call are closed and removed first.

    Args:
        log_dir: Directory for the log file; created if missing.

    Returns:
        logging.Logger: the configured logger.
    """
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger("BipedalWalker")
    logger.setLevel(logging.INFO)
    # Don't also hand records to the root logger. If the notebook host has
    # attached a root handler, every message would otherwise print twice.
    logger.propagate = False

    # Close before removing; a dropped FileHandler would otherwise keep its
    # log file open for the rest of the kernel session.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(log_dir, f"training_{timestamp}.log")
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # File handler
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    # Console handler that cooperates with tqdm progress bars
    ch = TqdmLoggingHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    return logger

class WalkingRewardWrapper(gym.Wrapper):
    """Add gait-shaping terms to the BipedalWalker reward.

    The extra reward per step is the sum of:
      * an energy penalty, -0.0005 * sum(action**2);
      * a stability penalty on |hull angle| (-0.15 with both feet down, else -0.05);
      * while moving forward (obs[2] > 0.05), a gait term that rewards wide,
        counter-phase hip motion and single-foot contact and penalises
        same-direction hip angles, speeds and actions;
      * a crouch penalty of -0.2 when the hull is low above the terrain.

    Args:
        env: BipedalWalker environment (possibly wrapped).
        crouch_height: Hull-centre height above the terrain (world units) below
            which the crouch penalty applies. NOTE(review): 1.5 is an estimate
            (straight-legged standing is assumed to be ~2.0 and kneeling ~1.4);
            tune it against observed hull heights.
    """

    def __init__(self, env, crouch_height=1.5):
        super().__init__(env)
        self.crouch_height = crouch_height

    def _hull_height_above_terrain(self):
        """Return the hull's height above the terrain directly below it, or None.

        Reads the ``hull`` body and the ``terrain_x``/``terrain_y`` profile from
        the unwrapped env, interpolating the terrain at the hull's x position.
        Returns None if any of them is unavailable.
        """
        base = self.env.unwrapped
        hull = getattr(base, "hull", None)
        terrain_x = getattr(base, "terrain_x", None)
        terrain_y = getattr(base, "terrain_y", None)
        if hull is None or terrain_x is None or terrain_y is None:
            return None
        if len(terrain_x) == 0 or len(terrain_x) != len(terrain_y):
            return None
        ground_y = np.interp(hull.position.x, terrain_x, terrain_y)
        return float(hull.position.y - ground_y)

    def step(self, action):
        state, reward, done, truncated, info = self.env.step(action)

        # --- State Mapping for BipedalWalker-v3 ---
        hull_angle = state[0]
        fwd_vel = state[2]
        # Leg 1
        hip1_angle, hip1_speed, leg1_contact = state[4], state[5], state[8]
        # Leg 2
        hip2_angle, hip2_speed, leg2_contact = state[9], state[10], state[13]

        # 1. Energy Penalty (Keep it efficient)
        energy_penalty = -0.0005 * np.sum(np.square(action))

        # 2. Stability Reward (Keep the body level)
        # Penalize leaning; softened while a foot is in the air (e.g. over obstacles)
        if leg1_contact and leg2_contact:
            stability_reward = -0.15 * abs(hull_angle)
        else:
            stability_reward = -0.05 * abs(hull_angle)

        # 3. Advanced Gait (Scissoring) Reward, only while moving forward
        gait_reward = 0
        if fwd_vel > 0.05:

            stride_width = abs(hip1_angle - hip2_angle)

            # Positive when the hips swing in opposite directions.
            coordination = np.maximum(0, -hip1_speed * hip2_speed)

            contact_bonus = 0.2 if leg1_contact != leg2_contact else -0.1

            # Penalize both hips bent the same way / moving the same way.
            scissor_penalty = -0.3 * np.maximum(0.0, hip1_angle * hip2_angle)

            same_vel_penalty = -0.2 * np.maximum(0.0, hip1_speed * hip2_speed)

            # action[0] and action[2] drive the two hips, per the mapping above.
            hip_action_corr = action[0] * action[2]
            action_scissor_penalty = -0.1 * np.maximum(0.0, hip_action_corr)

            gait_reward = (
                0.4 * np.tanh(stride_width)
                + 0.4 * np.tanh(coordination)
                + contact_bonus
                + scissor_penalty
                + same_vel_penalty
                + action_scissor_penalty
            )

        # 4. Anti-Crouch Penalty: discourage "crawling" with the hull near the ground.
        # NOTE: in BipedalWalker-v3 state[1] is the hull's angular velocity, not
        # its height, and the observation carries no hull height at all. So the
        # height is measured from the physics state instead.
        crouch_penalty = 0
        hull_height = self._hull_height_above_terrain()
        if hull_height is not None and hull_height < self.crouch_height:
            crouch_penalty = -0.2

        # Apply the shaped rewards
        reward += energy_penalty + stability_reward + gait_reward + crouch_penalty

        return state, reward, done, truncated, info

def save_checkpoint(agent, episode, learning_rate, log_dir="logs", filename=None):
    """Save the agent's networks (and, when available, its training state) with ``torch.save``.

    Keys always written (same as earlier checkpoints): ``actor_state_dict``,
    ``critic1_state_dict``, ``critic2_state_dict``, ``episode``, ``learning_rate``.

    Keys written only when the agent has the attribute, so training can be
    resumed rather than only evaluated:
    ``target_critic1_state_dict``, ``target_critic2_state_dict``,
    ``actor_optimizer_state_dict``, ``critic1_optimizer_state_dict``,
    ``critic2_optimizer_state_dict``, ``alpha_optimizer_state_dict``,
    and ``log_alpha`` (a CPU tensor).

    Args:
        agent: Object exposing at least ``actor``, ``critic1``, ``critic2``.
        episode: Episode number, stored and used in the default filename.
        learning_rate: Stored for reference.
        log_dir: Output directory; created if missing.
        filename: Defaults to ``checkpoint_ep{episode}.pth``.

    Returns:
        str: path of the written file.
    """
    if filename is None:
        filename = f"checkpoint_ep{episode}.pth"
    if log_dir:  # os.makedirs("") raises, and "" means the current directory
        os.makedirs(log_dir, exist_ok=True)
    path = os.path.join(log_dir, filename)

    checkpoint = {
        'actor_state_dict': agent.actor.state_dict(),
        'critic1_state_dict': agent.critic1.state_dict(),
        'critic2_state_dict': agent.critic2.state_dict(),
        'episode': episode,
        'learning_rate': learning_rate,
    }
    for attr in (
        "target_critic1",
        "target_critic2",
        "actor_optimizer",
        "critic1_optimizer",
        "critic2_optimizer",
        "alpha_optimizer",
    ):
        component = getattr(agent, attr, None)
        if component is not None:
            checkpoint[f"{attr}_state_dict"] = component.state_dict()
    log_alpha = getattr(agent, "log_alpha", None)
    if log_alpha is not None:
        checkpoint["log_alpha"] = log_alpha.detach().cpu()

    torch.save(checkpoint, path)
    return path

# CSV logging of model parameters per episode

def write_params_csv(agent, episode, csv_path):
    """Append a full snapshot of the agent's parameters to ``csv_path``.

    Writes one row per tensor in the actor/critic1/critic2 state dicts, plus a
    row for ``log_alpha``. Columns: timestamp, episode, component, parameter,
    shape (as a list), and values (all elements flattened, space-separated).
    A header is written when the file does not exist yet.

    Args:
        agent: Object exposing ``actor``, ``critic1``, ``critic2`` modules and a
            ``log_alpha`` tensor holding a single element.
        episode: Episode number recorded in every row.
        csv_path: Output file. Its directory is created if missing; a bare
            filename writes to the current directory.
    """
    csv_dir = os.path.dirname(csv_path)
    if csv_dir:  # os.makedirs("") raises for bare filenames
        os.makedirs(csv_dir, exist_ok=True)
    write_header = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["timestamp", "episode", "component", "parameter", "shape", "values"])

        ts = datetime.now().isoformat(timespec="seconds")

        for component in ("actor", "critic1", "critic2"):
            for name, tensor in getattr(agent, component).state_dict().items():
                arr = tensor.detach().cpu().numpy()
                writer.writerow([
                    ts,
                    episode,
                    component,
                    name,
                    list(arr.shape),
                    " ".join(map(str, arr.reshape(-1).tolist()))
                ])

        writer.writerow([
            ts,
            episode,
            "alpha",
            "log_alpha",
            [1],
            str(agent.log_alpha.detach().cpu().item())
        ])

# CSV logging of rewards and training metrics per episode

def write_rewards_csv(agent, episode, total_reward, total_steps, learning_rate, num_envs, csv_path, avg10=None, avg100=None, log_dir_name=None):
    """Append one row of per-episode training metrics to ``csv_path``.

    Besides the passed-in values, each row records ``alpha = exp(log_alpha)``
    and, per network, the sum of the L2 norms of its parameter tensors (not a
    single global norm). Missing ``avg10``/``avg100``/``log_dir_name`` are
    written as empty cells. A header is written when the file does not exist yet.

    Args:
        agent: Object exposing ``actor``, ``critic1``, ``critic2`` modules and a
            ``log_alpha`` tensor holding a single element.
        csv_path: Output file. Its directory is created if missing; a bare
            filename writes to the current directory.
    """
    csv_dir = os.path.dirname(csv_path)
    if csv_dir:  # os.makedirs("") raises for bare filenames
        os.makedirs(csv_dir, exist_ok=True)
    write_header = not os.path.exists(csv_path)

    alpha = float(agent.log_alpha.detach().cpu().exp().item())
    actor_norm = float(sum(p.data.norm().item() for p in agent.actor.parameters()))
    critic1_norm = float(sum(p.data.norm().item() for p in agent.critic1.parameters()))
    critic2_norm = float(sum(p.data.norm().item() for p in agent.critic2.parameters()))

    with open(csv_path, "a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow([
                "timestamp", "episode", "reward", "avg10", "avg100", "total_steps",
                "learning_rate", "num_envs", "alpha",
                "actor_param_norm", "critic1_param_norm", "critic2_param_norm", "log_dir"
            ])
        ts = datetime.now().isoformat(timespec="seconds")
        writer.writerow([
            ts, episode, float(total_reward),
            float(avg10) if avg10 is not None else "",
            float(avg100) if avg100 is not None else "",
            int(total_steps),
            float(learning_rate), int(num_envs), float(alpha),
            float(actor_norm), float(critic1_norm), float(critic2_norm),
            log_dir_name or ""
        ])


def train_agent(
    env_name="BipedalWalker-v3",
    max_episodes=1000,
    max_steps=1000,
    device="cpu",
    render_freq=50,
    learning_rate=3e-4,
    updates_per_step=1,
    start_steps=20000,
    num_envs=1,
    save_interval=10,
    log_dir=None,
    use_obstacles=False,
    obstacle_difficulty=1.0,
    seed=42,
    tau=0.001,
    batch_size=1024,
    buffer_size=int(1e6),
    target_entropy=None,
    hidden_dim=256,
):
    # Set seeds for reproducibility per run
    set_seed(seed)

    # Determine log directory
    if log_dir is None:
        log_dir = os.path.join("logs", datetime.now().strftime("run_%Y%m%d_%H%M%S"))
    
    # Add obstacle info to log directory if enabled
    if use_obstacles:
        log_dir = log_dir.replace("run_", f"run_obstacles_d{obstacle_difficulty}_")

    # Prepare CSV paths for logging
    params_csv_path = os.path.join(log_dir, "params.csv")
    rewards_csv_path = os.path.join(log_dir, "rewards.csv")

    # Setup logger
    logger = setup_logger(log_dir=log_dir)
    logger.info(f"Starting training with device: {device}, LR: {learning_rate}, Num Envs: {num_envs}")
    logger.info(f"Hyperparams: tau={tau}, batch_size={batch_size}, buffer_size={buffer_size}, target_entropy={target_entropy}, hidden_dim={hidden_dim}, seed={seed}")
    logger.info(f"Obstacles enabled: {use_obstacles}, Difficulty: {obstacle_difficulty if use_obstacles else 'N/A'}")
    logger.info(f"Logs and checkpoints will be saved to: {log_dir}")
    
    # Determine render mode
    render_mode = "rgb_array"
    
    # Track per-environment progress
    env_step_counts = np.zeros(num_envs, dtype=int)
    env_episode_counts = np.zeros(num_envs, dtype=int)
    
    # Create environment
    if num_envs > 1:
        vec_mode = "async"
        wrappers = [WalkingRewardWrapper]
        if use_obstacles:
            wrappers.append(lambda env: ObstacleBipedalWrapper(env, difficulty=obstacle_difficulty))
        env = gym.make_vec(env_name, num_envs=num_envs, vectorization_mode=vec_mode, wrappers=wrappers, render_mode=render_mode)
        logger.info(f"Using {num_envs} vectorized environments ({vec_mode}) with WalkingRewardWrapper")
        if use_obstacles:
            logger.info(f"Obstacle environment enabled with difficulty={obstacle_difficulty}")
    else:
        env = gym.make(env_name, render_mode=render_mode)
        env = WalkingRewardWrapper(env)
        if use_obstacles:
            env = ObstacleBipedalWrapper(env, difficulty=obstacle_difficulty)
            logger.info(f"Using ObstacleBipedalWrapper with difficulty={obstacle_difficulty}")
        logger.info("Using WalkingRewardWrapper")
        
    if num_envs > 1:
        state_dim = env.single_observation_space.shape[0]
        action_dim = env.single_action_space.shape[0]
        action_scale = float(env.single_action_space.high[0])
    else:
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]
        action_scale = float(env.action_space.high[0])
    
    # Initialize agent
    logger.info(f"Initializing SAC Agent on device: {device}")
    agent = SACAgent(
        state_dim,
        action_dim,
        action_scale,
        device=device,
        learning_rate=learning_rate,
        gamma=0.99,
        tau=tau,
        alpha=0.2,
        batch_size=batch_size,
        buffer_size=buffer_size,
        target_entropy=target_entropy,
        hidden_dim=hidden_dim,
    )
    
    # Training loop
    total_steps = 0
    episode_rewards = []
    recent_rewards = deque(maxlen=100)
    
    os.makedirs(log_dir, exist_ok=True)
    
    pbar = tqdm(range(max_episodes), desc=f"Training Progress", unit="ep")
    
    current_episode = 0
    training_complete = False
    
    # Reset env
    state, _ = env.reset()
    
    try:
        while current_episode < max_episodes and not training_complete:
            episode_reward = 0 
            if num_envs > 1:
                current_rewards = np.zeros(num_envs)
                
            # Determine if we should render this episode
            should_render = (render_freq > 0) and (current_episode % render_freq == 0)
            
            for step in range(max_steps):
                # Select action
                if total_steps < start_steps:
                    if num_envs > 1:
                        action = np.array([env.single_action_space.sample() for _ in range(num_envs)])
                    else:
                        action = env.action_space.sample()
                else:
                    if num_envs > 1:
                        action = agent.select_action(state, deterministic=False) 
                    else:
                        action = agent.select_action(state)
                        if isinstance(action, np.ndarray) and action.ndim > 1:
                            action = action.flatten()
                
                # Take step
                next_state, reward, done, truncated, info = env.step(action)
                
                # Rendering logic (only for single environment)
                if should_render and num_envs == 1:
                    try:
                        frame = env.render()
                        if isinstance(frame, np.ndarray):
                            bgr_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
                            cv2.putText(bgr_frame, f"Ep: {current_episode}", (10, 30), 
                                       cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
                            cv2.imshow("BipedalWalker Training", bgr_frame)
                            cv2.waitKey(1)
                    except Exception as e:
                        # Rendering not available or window closed
                        pass
                elif should_render and num_envs > 1 and step == 0:
                    # Log once per episode that rendering is disabled for parallel envs
                    logger.info(f"Rendering disabled for parallel environments (num_envs={num_envs})")
                
                # Handle done/truncated
                if num_envs > 1:
                    done_flag = done | truncated
                    current_rewards += reward
                    
                    for i in range(num_envs):
                        if done_flag[i]:
                            episode_rewards.append(current_rewards[i])
                            recent_rewards.append(current_rewards[i])
                            current_rewards[i] = 0
                            current_episode += 1
                            pbar.update(1)
                            
                            # Update progress bar
                            avg_reward = np.mean(recent_rewards) if len(recent_rewards) > 0 else 0.0
                            pbar.set_postfix({
                                'Last': f'{episode_rewards[-1]:.1f}',
                                'Avg': f'{avg_reward:.1f}',
                                'Steps': total_steps
                            })

                            # Write rewards CSV for this episode
                            avg10 = np.mean(episode_rewards[-10:]) if len(episode_rewards) >= 10 else np.mean(episode_rewards)
                            write_rewards_csv(
                                agent=agent,
                                episode=current_episode,
                                total_reward=episode_rewards[-1],
                                total_steps=total_steps,
                                learning_rate=learning_rate,
                                num_envs=num_envs,
                                csv_path=rewards_csv_path,
                                avg10=avg10,
                                avg100=avg_reward,
                                log_dir_name=os.path.basename(log_dir)
                            )

                            # Update render status
                            should_render = (render_freq > 0) and (current_episode % render_freq == 0)
                            
                            # Checkpoint logic
                            if current_episode % save_interval == 0:
                                path = save_checkpoint(agent, current_episode, learning_rate, log_dir=log_dir)
                                logger.info(f"Checkpoint saved at episode {current_episode}: {path}")
                                # Append parameters to CSV
                                write_params_csv(agent, current_episode, params_csv_path)
                                
                            if current_episode >= max_episodes:
                                training_complete = True
                                break
                    
                    # Buffer addition for vec env
                    real_next_states = next_state.copy()
                    if "_final_observation" in info:
                        mask = info["_final_observation"]
                        for i, is_final in enumerate(mask):
                            if is_final and "final_observation" in info:
                                real_next_states[i] = info["final_observation"][i]
                    
                    agent.replay_buffer.add(state, action, reward, real_next_states, done_flag)
                    
                    # Exit early if we've reached max episodes
                    if training_complete:
                        break
                    
                else:
                    done_flag = done or truncated
                    agent.replay_buffer.add(state, action, reward, next_state, done_flag)
                    episode_reward += reward
                    
                    if done_flag:
                        episode_rewards.append(episode_reward)
                        recent_rewards.append(episode_reward)
                        current_episode += 1
                        pbar.update(1)
                        
                        # Update progress bar
                        avg_reward = np.mean(recent_rewards) if len(recent_rewards) > 0 else 0.0
                        pbar.set_postfix({
                            'Last': f'{episode_rewards[-1]:.1f}',
                            'Avg': f'{avg_reward:.1f}',
                            'Steps': total_steps
                        })

                        # Write rewards CSV for this episode
                        avg10 = np.mean(episode_rewards[-10:]) if len(episode_rewards) >= 10 else np.mean(episode_rewards)
                        write_rewards_csv(
                            agent=agent,
                            episode=current_episode,
                            total_reward=episode_reward,
                            total_steps=total_steps,
                            learning_rate=learning_rate,
                            num_envs=num_envs,
                            csv_path=rewards_csv_path,
                            avg10=avg10,
                            avg100=avg_reward,
                            log_dir_name=os.path.basename(log_dir)
                        )

                        should_render = (render_freq > 0) and (current_episode % render_freq == 0)
                        
                        # Checkpoint logic
                        if current_episode % save_interval == 0:
                             path = save_checkpoint(agent, current_episode, learning_rate, log_dir=log_dir)
                             logger.info(f"Checkpoint saved at episode {current_episode}: {path}")
                             # Append parameters to CSV
                             write_params_csv(agent, current_episode, params_csv_path)
                        
                        # Check if we've reached max episodes
                        if current_episode >= max_episodes:
                            training_complete = True
                        
                        state, _ = env.reset()
                        break
                
                state = next_state
                total_steps += num_envs
                
                # Update agent
                if len(agent.replay_buffer) > agent.batch_size and total_steps >= start_steps:
                    for _ in range(updates_per_step * num_envs):
                        agent.update()

                # Update progress bar steps periodically
                if total_steps % 1000 == 0:
                    last_r = episode_rewards[-1] if episode_rewards else 0
                    avg_r = np.mean(recent_rewards) if len(recent_rewards) > 0 else 0
                    pbar.set_postfix({
                        'Last': f'{last_r:.1f}',
                        'Avg': f'{avg_r:.1f}',
                        'Steps': total_steps
                    })
                        
            # Logging progress occasionally
            if len(episode_rewards) > 0 and current_episode % 10 == 0:
                avg_reward = np.mean(episode_rewards[-10:])
                logger.info(f"Episode {current_episode}: Avg Reward (10) = {avg_reward:.2f}, Total Steps = {total_steps}")
                
    except KeyboardInterrupt:
        logger.warning("Training interrupted by user. Saving emergency checkpoint...")
        path = save_checkpoint(agent, current_episode, learning_rate, log_dir=log_dir, filename=f"emergency_checkpoint_ep{current_episode}.pth")
        logger.info(f"Emergency checkpoint saved: {path}")
    except Exception as e:
        logger.error(f"Error occurred: {e}")
        raise e
    finally:
        try:
            cv2.destroyAllWindows()
            cv2.waitKey(1)
        except:
            pass
        env.close()
        logger.info("Training finished/stopped.")

    return episode_rewards, agent

## Training Execution

Train the agent with configurable obstacles, difficulty level, and hyperparameters. The baseline (no-obstacle) run is disabled by default; enable it to compare the baseline and obstacle-based environments.


In [None]:
# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================

print(f"Starting BipedalWalker training on {device}...")

# Configurable Hyperparameters
MAX_EPISODES = 2000  # Episodes per run (reduced for comparison study)
UPDATES_PER_STEP = 1  # Forwarded to train_agent as updates_per_step
NUM_ENVS = 32  # Number of parallel environments
START_STEPS = 10000  # Agent updates only begin once this many env steps have been taken
RENDER_FREQ = 0  # 0 disables rendering
LEARNING_RATE = 1e-4

# Run switches. The baseline (no obstacles) run is skipped by default; set TRAIN_BASELINE = True
# to train it as well and obtain `rewards_baseline` / `agent_baseline`.
TRAIN_BASELINE = False
OBSTACLE_DIFFICULTY = 0.7  # Single source for both the banner text and the train_agent call

# Keyword arguments shared by every run, so the two options cannot drift apart.
common_train_kwargs = dict(
    max_episodes=MAX_EPISODES,
    device=device,
    updates_per_step=UPDATES_PER_STEP,
    start_steps=START_STEPS,
    num_envs=NUM_ENVS,
    render_freq=RENDER_FREQ,
    learning_rate=LEARNING_RATE,
)

# ============================================================================
# OPTION 1: Train Baseline (No Obstacles) -- only when TRAIN_BASELINE is True
# ============================================================================

if TRAIN_BASELINE:
    print("\n" + "="*70)
    print("TRAINING: Baseline Agent (No Obstacles)")
    print("="*70)

    rewards_baseline, agent_baseline = train_agent(**common_train_kwargs, use_obstacles=False)
    print("✓ Baseline training completed!")

# ============================================================================
# OPTION 2: Train with Obstacles
# ============================================================================

print("\n" + "="*70)
print(f"TRAINING: Agent with Obstacles (difficulty={OBSTACLE_DIFFICULTY})")
print("="*70)

rewards_obstacles, agent_obstacles = train_agent(
    **common_train_kwargs,
    use_obstacles=True,
    obstacle_difficulty=OBSTACLE_DIFFICULTY,
)
print("✓ Obstacle training completed!")


## Training Results Visualization

Visualize training curves and compare performance between baseline and obstacle-based training runs.

In [None]:
# ============================================================================
# RESULTS: Visualization & Comparison
# ============================================================================

def smooth_rewards(rewards, window=30):
    """Trailing moving average of per-episode rewards.

    Element i is the mean of the `window` rewards ending at episode i (fewer at the
    start of the run). An empty input gives an empty list.
    """
    return [float(np.mean(rewards[max(0, i - window + 1):i + 1])) for i in range(len(rewards))]


def tail_mean(rewards, n=100):
    """Mean of the last `n` rewards (all of them if fewer); NaN when there are none."""
    return float(np.mean(rewards[-n:])) if len(rewards) > 0 else float("nan")


def _cell(value, fmt="{:.2f}"):
    """Format one summary-table cell; None/NaN (metric unavailable) is shown as 'N/A'."""
    if value is None or np.isnan(value):
        return "N/A"
    return fmt.format(value)


print("\n" + "="*70)
print("RESULTS: Baseline vs Obstacles Comparison")
print("="*70)

# The baseline run is optional (skipped by default in the training cell). Only compare against it
# when it actually ran, instead of plotting/printing placeholder zeros as if they were results.
try:
    baseline_rewards = list(rewards_baseline)
except NameError:
    baseline_rewards = []
has_baseline = len(baseline_rewards) > 0
obstacle_rewards = list(rewards_obstacles)

window = 30  # moving-average window, in episodes
baseline_smoothed = smooth_rewards(baseline_rewards, window)
obstacles_smoothed = smooth_rewards(obstacle_rewards, window)

fig, axes = plt.subplots(1, 2, figsize=(16, 5))
comparison = "Baseline vs Obstacles" if has_baseline else "Obstacles"

# Plot 1: Individual training curves (raw + smoothed)
if has_baseline:
    axes[0].plot(baseline_rewards, alpha=0.2, color='blue', label='Raw (Baseline)')
    axes[0].plot(baseline_smoothed, linewidth=2.5, color='blue', label='Smoothed (Baseline)')
axes[0].plot(obstacle_rewards, alpha=0.2, color='red', label='Raw (Obstacles)')
axes[0].plot(obstacles_smoothed, linewidth=2.5, color='red', label='Smoothed (Obstacles)')
axes[0].set_title(f"Training Curves: {comparison}", fontsize=12, fontweight='bold')
axes[0].set_xlabel("Episode")
axes[0].set_ylabel("Reward")
axes[0].legend(fontsize=9)
axes[0].grid(True, alpha=0.3)

# Plot 2: Direct comparison (smoothed only)
if has_baseline:
    axes[1].plot(baseline_smoothed, linewidth=3, color='blue', label='Baseline (No Obstacles)', marker='o', markersize=3, markevery=20)
axes[1].plot(obstacles_smoothed, linewidth=3, color='red', label='With Obstacles (difficulty=0.7)', marker='s', markersize=3, markevery=20)
axes[1].set_title("Smoothed Reward Comparison", fontsize=12, fontweight='bold')
axes[1].set_xlabel("Episode")
axes[1].set_ylabel("Smoothed Reward")
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary statistics (baseline column shows N/A when the baseline run was not trained)
print(f"\n{'Metric':<30} {'Baseline':<20} {'With Obstacles':<20}")
print("-" * 70)
summary_rows = [
    ("Total Episodes", len(baseline_rewards) if has_baseline else None, len(obstacle_rewards), "{}"),
    ("Best Reward", max(baseline_rewards, default=None), max(obstacle_rewards, default=None), "{:.2f}"),
    ("Worst Reward", min(baseline_rewards, default=None), min(obstacle_rewards, default=None), "{:.2f}"),
    ("Avg (Last 100 eps)", tail_mean(baseline_rewards), tail_mean(obstacle_rewards), "{:.2f}"),
    ("Final Smoothed Reward",
     baseline_smoothed[-1] if baseline_smoothed else None,
     obstacles_smoothed[-1] if obstacles_smoothed else None,
     "{:.2f}"),
]
for metric, baseline_value, obstacle_value, fmt in summary_rows:
    print(f"{metric:<30} {_cell(baseline_value, fmt):<20} {_cell(obstacle_value, fmt):<20}")

if has_baseline and obstacles_smoothed:
    gap = baseline_smoothed[-1] - obstacles_smoothed[-1]
    print(f"{'Difficulty Gap (B-O)':<30} {gap:<20.2f}")
print("="*70)

## Video Recording and Evaluation

After training, we record a video of the agent's performance to visually verify its walking ability. The video is required for the application.

In [None]:
from gymnasium.wrappers import RecordVideo
from IPython.display import Video
import glob
import os

def load_checkpoint(checkpoint_path, state_dim, action_dim, device="cpu", learning_rate=3e-4):
    """Load a trained agent from a checkpoint.

    Builds a fresh ``SACAgent`` and loads the actor and both critics from the
    ``actor_state_dict`` / ``critic1_state_dict`` / ``critic2_state_dict`` entries.

    Loading is non-strict so checkpoints from older architectures (old ``log_std``
    parameter vs. the newer ``log_std_linear`` layer) still load. Any missing or
    unexpected keys are reported, because silently skipping them could leave layers
    at their fresh initialization and make the evaluation meaningless.

    Args:
        checkpoint_path: path to a ``.pth`` checkpoint file.
        state_dim: observation dimension used to build the agent.
        action_dim: action dimension used to build the agent.
        device: device passed to ``SACAgent`` and used as ``map_location``.
        learning_rate: forwarded to the ``SACAgent`` constructor.

    Returns:
        The agent with the checkpoint weights loaded.
    """
    agent = SACAgent(state_dim, action_dim, device=device, learning_rate=learning_rate)
    # torch.load unpickles the file: only load checkpoints you produced/trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    for module_name in ("actor", "critic1", "critic2"):
        result = getattr(agent, module_name).load_state_dict(
            checkpoint[f"{module_name}_state_dict"], strict=False
        )
        if result.missing_keys or result.unexpected_keys:
            print(
                f"Warning: {module_name} loaded non-strictly | "
                f"missing keys: {result.missing_keys} | unexpected keys: {result.unexpected_keys}"
            )

    print(f"Loaded checkpoint from: {checkpoint_path}")
    return agent

def record_video(agent, env_name="BipedalWalker-v3", filename="bipedal_walker", device="cpu", use_obstacles=False, obstacle_difficulty=0.5, video_folder="videos", seed=None):
    """Run one evaluation episode with the deterministic policy and record it as an MP4.

    Args:
        agent: object exposing ``select_action(state, deterministic=True)``.
        env_name: Gymnasium environment id to evaluate on.
        filename: ``name_prefix`` passed to ``RecordVideo``; the file is looked up as
            ``<video_folder>/<filename>-episode-0.mp4``.
        device: unused; kept so existing callers that pass it keep working.
        use_obstacles: wrap the environment in ``ObstacleBipedalWrapper`` when True.
        obstacle_difficulty: ``difficulty`` forwarded to the obstacle wrapper.
        video_folder: directory for the recording (created if missing).
        seed: optional seed for ``env.reset``; None keeps the unseeded reset.

    Returns:
        Path to the recorded MP4, or None if that file does not exist afterwards.
    """
    # Create environment with render mode (frames are needed for recording)
    env = gym.make(env_name, render_mode="rgb_array")

    # Apply obstacle wrapper if requested
    if use_obstacles:
        env = ObstacleBipedalWrapper(env, difficulty=obstacle_difficulty)

    os.makedirs(video_folder, exist_ok=True)
    # Trigger on every episode; only one episode is run, so exactly episode 0 is recorded.
    env = RecordVideo(env, video_folder=video_folder, name_prefix=filename, episode_trigger=lambda episode_id: True)

    total_reward = 0.0
    try:
        state, _ = env.reset(seed=seed)
        done = truncated = False
        while not (done or truncated):
            # Use deterministic policy for evaluation
            action = agent.select_action(state, deterministic=True)

            # Ensure action is 1D if it comes back as a 2D batch (legacy/stale agent instances)
            if isinstance(action, np.ndarray) and action.ndim > 1:
                action = action.flatten()

            state, reward, done, truncated, _ = env.step(action)
            total_reward += reward
    finally:
        # Close even if the episode raised, so the recorder is finalized and the env released.
        env.close()

    print(f"Evaluation Run - Total Reward: {total_reward:.2f}")

    video_path = os.path.join(video_folder, f"{filename}-episode-0.mp4")
    if os.path.exists(video_path):
        print(f"Video saved to {video_path}")
        return video_path
    return None

# Find the best checkpoint across all log directories
log_dirs = glob.glob("logs/run_*")
best_checkpoint_path = None
best_reward = -float('inf')
best_log_dir = None

print("Searching through all log directories for best performance...")

if log_dirs:
    for log_dir in log_dirs:
        # Find training log file to get final reward metrics
        log_files = glob.glob(f"{log_dir}/training_*.log")
        if log_files:
            # Try to extract best reward from log file
            log_file = log_files[0]
            try:
                with open(log_file, 'r') as f:
                    lines = f.readlines()
                    # Look for lines with "Avg Reward" or similar metrics
                    for line in reversed(lines):  # Check from end for latest stats
                        if "Avg Reward" in line or "Average Reward" in line:
                            # Extract the reward value
                            import re
                            numbers = re.findall(r"[-+]?\d*\.?\d+", line)
                            if numbers:
                                try:
                                    reward_value = float(numbers[-1])  # Usually the last number is the reward
                                    if reward_value > best_reward:
                                        best_reward = reward_value
                                        best_log_dir = log_dir
                                    break
                                except:
                                    continue
            except:
                pass
        
        # Also check checkpoints directly for episode numbers
        checkpoints = glob.glob(f"{log_dir}/checkpoint_ep*.pth")
        if checkpoints:
            def _ep_num(path):
                name = os.path.basename(path)
                try:
                    return int(name.split("checkpoint_ep")[1].split(".pth")[0])
                except Exception:
                    return -1

            latest_checkpoint = max(checkpoints, key=_ep_num)
            ep_num = _ep_num(latest_checkpoint)
            print(f"  {log_dir}: Latest episode {ep_num}")

# If we found a best log directory, use its highest episode checkpoint
if best_log_dir:
    checkpoints = glob.glob(f"{best_log_dir}/checkpoint_ep*.pth")
    if checkpoints:
        def _ep_num(path):
            name = os.path.basename(path)
            try:
                return int(name.split("checkpoint_ep")[1].split(".pth")[0])
            except Exception:
                return -1

        best_checkpoint_path = max(checkpoints, key=_ep_num)
        print(f"\n✓ Found best performing run: {best_log_dir}")
        print(f"  Using checkpoint from episode: {_ep_num(best_checkpoint_path)}")
else:
    # Fallback: just use latest checkpoint from most recent log
    if log_dirs:
        latest_log_dir = max(log_dirs, key=os.path.getctime)
        checkpoints = glob.glob(f"{latest_log_dir}/checkpoint_ep*.pth")
        if checkpoints:
            def _ep_num(path):
                name = os.path.basename(path)
                try:
                    return int(name.split("checkpoint_ep")[1].split(".pth")[0])
                except Exception:
                    return -1
            best_checkpoint_path = max(checkpoints, key=_ep_num)
            best_log_dir = latest_log_dir
            print(f"Using latest log directory: {latest_log_dir}")

if best_checkpoint_path:
    print(f"Loading agent from: {best_checkpoint_path}")

    # Create a fresh agent and load the checkpoint
    env_temp = gym.make("BipedalWalker-v3")
    state_dim = env_temp.observation_space.shape[0]
    action_dim = env_temp.action_space.shape[0]
    env_temp.close()

    checkpoint_agent = load_checkpoint(best_checkpoint_path, state_dim, action_dim, device=device)

    # Use log directory name as video filename
    log_dir_name = os.path.basename(best_log_dir)

    # Record video with the loaded agent
    # Check if the agent was trained with obstacles by looking at the log directory name
    use_obstacles = "with_obstacles" in log_dir_name or "_obs1" in log_dir_name
    obstacle_difficulty = 0.7 if use_obstacles else 0.0
    
    print(f"Recording video with obstacles={use_obstacles}, difficulty={obstacle_difficulty}")
    video_path = record_video(checkpoint_agent, filename=log_dir_name, device=device, 
                               use_obstacles=use_obstacles, obstacle_difficulty=obstacle_difficulty)
    if video_path:
        display(Video(video_path, embed=True, html_attributes="controls autoplay loop"))
else:
    print("No log directories or checkpoints found. Train the agent first.")


## Ablation Runner
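Set up a lightweight sweep loop that can vary seeds, entropy targets, learning rate, update ratio, tau, batch size, buffer size, hidden width, env count, and obstacle difficulty. The current grid sweeps only learning rate (3 values) and entropy target (2 values), i.e. 6 runs. The execution call is commented out, so the sweep starts only when you uncomment it.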

In [None]:
from itertools import product
import json
import time
import os
from pathlib import Path

# Build a grid of hyperparameters to sweep
def build_ablation_grid():
    """Return the ablation configs (one dict per run) as the Cartesian product of the options below.

    Only learning rate (3 values) and target entropy (2 values) vary, giving 6 runs; every
    other dimension has a single value and is listed so it can be widened later.
    """
    # Dimension order fixes the iteration order: learning rate is the outermost loop.
    options = {
        "learning_rate": [3e-4, 1e-4, 3e-5],   # For future: add more rates
        "target_entropy": [-1.0, -4.0],         # Exploration strength; For future: [-1, -2, -3]
        # Fixed parameters (single value each - not currently being studied)
        "seed": [42],                            # use build_seed_study_grid() for seed variance
        "updates_per_step": [1],                 # Gradient steps per env step; For future: [1, 2, 4]
        "tau": [0.005],                          # Target network update rate; For future: [0.005, 0.01, 0.02]
        "batch_size": [1024],                    # For future: [256, 512, 1024]
        "buffer_size": [int(1e6)],               # For future: [5e5, 1e6]
        "num_envs": [32],                        # Parallel environments; For future: [8, 16, 32]
        "obstacles": [(False, 0.0)],             # (use_obstacles, difficulty); For future: add (True, 0.3), (True, 0.7)
        "hidden_dim": [256],                     # Network width; For future: [256, 512]
    }

    return [
        {
            "seed": seed,
            "target_entropy": entropy,
            "learning_rate": lr,
            "updates_per_step": updates,
            "tau": tau,
            "batch_size": batch,
            "buffer_size": buffer,
            "num_envs": envs,
            "use_obstacles": use_obstacles,
            "obstacle_difficulty": difficulty,
            "hidden_dim": hidden,
        }
        for lr, entropy, seed, updates, tau, batch, buffer, envs, (use_obstacles, difficulty), hidden
        in product(*options.values())
    ]

# Build a seed study grid using the best hyperparameters from ablation
def build_seed_study_grid(best_lr=1e-4, best_entropy=-1.0):
    """Run the same config with multiple seeds to measure variance.

    Returns one config dict per seed (42, 123, 789). All of them share `best_lr` and
    `best_entropy`; the remaining settings equal the fixed values used in build_ablation_grid().
    """
    study_seeds = (42, 123, 789)  # Multiple seeds for reproducibility study
    shared_settings = {
        "target_entropy": best_entropy,
        "learning_rate": best_lr,
        "updates_per_step": 1,
        "tau": 0.005,
        "batch_size": 1024,
        "buffer_size": int(1e6),
        "num_envs": 32,
        "use_obstacles": False,
        "obstacle_difficulty": 0.0,
        "hidden_dim": 256,
    }
    # Each dict literal creates a fresh config, so runs never share a mutable object.
    return [{"seed": seed_value, **shared_settings} for seed_value in study_seeds]

def _tag_from_cfg(cfg):
    return (
        f"seed{cfg['seed']}_lr{cfg['learning_rate']}_ent{cfg['target_entropy']}"
        f"_upd{cfg['updates_per_step']}_tau{cfg['tau']}_bs{cfg['batch_size']}"
        f"_buf{cfg['buffer_size']}_env{cfg['num_envs']}_obs{int(cfg['use_obstacles'])}d{cfg['obstacle_difficulty']}"
        f"_h{cfg['hidden_dim']}_t{int(time.time())}"
    )

def run_ablation(grid, base_config, results_path="logs/ablations/results.json"):
    """Train one agent per config in `grid` and record a summary of each run.

    Each run calls ``train_agent`` with ``{**base_config, **cfg}`` (per-run values override
    the shared ones) plus a ``log_dir`` under ``logs/ablations/<tag>``. After every run the
    accumulated results are rewritten to `results_path`, so a partially completed sweep can
    be recovered.

    Args:
        grid: list of config dicts (e.g. from build_ablation_grid()).
        base_config: keyword arguments shared by every run.
        results_path: JSON file the result list is written to; may be a bare filename.

    Returns:
        List of dicts: each run's cfg plus "episodes", "mean_last_50" (mean reward of the
        last 50 episodes, NaN if the run produced no episodes) and "log_dir".
    """
    results = []
    results_dir = os.path.dirname(results_path)
    if results_dir:  # os.makedirs("") raises, so skip it for a bare filename
        os.makedirs(results_dir, exist_ok=True)

    for idx, cfg in enumerate(grid):
        run_cfg = {**base_config, **cfg}
        run_tag = _tag_from_cfg(cfg)
        run_cfg["log_dir"] = os.path.join("logs", "ablations", run_tag)

        print("\n" + "=" * 70)
        print(f"[Ablation {idx+1}/{len(grid)}] {run_tag}")
        print("=" * 70)

        rewards, _ = train_agent(**run_cfg)
        # NaN rather than 0.0 for an empty run: 0.0 could outrank genuine runs with negative rewards.
        mean_last_50 = float(np.mean(rewards[-50:])) if len(rewards) else float("nan")
        results.append({
            **cfg,
            "episodes": len(rewards),
            "mean_last_50": mean_last_50,
            "log_dir": run_cfg["log_dir"],
        })

        # Save running results so partial sweeps are recoverable. Write to a temp file and
        # atomically swap it in, so an interrupt mid-write cannot truncate the results file.
        tmp_path = results_path + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump(results, f, indent=2)
        os.replace(tmp_path, results_path)

    return results

# Base config shared across all sweeps (1000 episodes per run)
base_ablation_cfg = dict(
    env_name="BipedalWalker-v3",
    max_episodes=1000,  # Episodes per run; lower it (e.g. 500) for a quicker sweep
    max_steps=1000,
    device=device,
    render_freq=0,
    start_steps=5000,
    save_interval=100,
)

# Build grid; comment/uncomment to launch
ablation_grid = build_ablation_grid()
print(f"Planned ablation runs: {len(ablation_grid)} total configurations")

# Derive the swept dimensions from the grid itself so this message cannot drift from
# build_ablation_grid(): a dimension is "studied" if it takes more than one value.
distinct_counts = {
    key: len({cfg[key] for cfg in ablation_grid})
    for key in (ablation_grid[0] if ablation_grid else {})
}
varied_dims = {key: count for key, count in distinct_counts.items() if count > 1}
print(
    f"Studying {len(varied_dims)} dimensions: "
    + ", ".join(f"{key} ({count})" for key, count in varied_dims.items())
    + f" = {len(ablation_grid)} runs"
)

# To run the full sweep, uncomment the line below:
# ablation_results = run_ablation(ablation_grid, base_ablation_cfg)

## Ablation Results Visualization

Visualize the ablation study results across seeds, learning rates, and entropy targets. Load results from the saved JSON and generate comparison plots.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

import json  # imported here too so this cell runs even if the ablation cell was not executed
import traceback

# Load ablation results from JSON
results_path = "logs/ablations/results.json"

try:
    with open(results_path, 'r') as f:
        ablation_results = json.load(f)

    # Convert to DataFrame for easier analysis
    df = pd.DataFrame(ablation_results)
    if df.empty:
        raise ValueError(f"{results_path} contains no runs")

    print(f"Loaded {len(df)} ablation runs from {results_path}")
    print(f"\nColumns: {df.columns.tolist()}")
    print(f"\nSummary statistics:")
    print(df[['seed', 'learning_rate', 'target_entropy', 'mean_last_50']].describe())

    # Create visualization
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Plot 1: Bar plot by learning rate (std is NaN, i.e. no error bar, for single-run groups)
    ax1 = axes[0]
    lr_summary = df.groupby('learning_rate')['mean_last_50'].agg(['mean', 'std']).reset_index()
    lr_summary['learning_rate_str'] = lr_summary['learning_rate'].apply(lambda x: f"{x:.0e}")
    ax1.bar(lr_summary['learning_rate_str'], lr_summary['mean'],
            yerr=lr_summary['std'], capsize=5, alpha=0.7, color='steelblue')
    ax1.set_title('Performance by Learning Rate', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Learning Rate')
    ax1.set_ylabel('Mean Reward (Last 50 Episodes)')
    ax1.grid(True, alpha=0.3, axis='y')

    # Plot 2: Bar plot by entropy target
    ax2 = axes[1]
    ent_summary = df.groupby('target_entropy')['mean_last_50'].agg(['mean', 'std']).reset_index()
    ent_summary['target_entropy_str'] = ent_summary['target_entropy'].astype(str)
    ax2.bar(ent_summary['target_entropy_str'], ent_summary['mean'],
            yerr=ent_summary['std'], capsize=5, alpha=0.7, color='coral')
    ax2.set_title('Performance by Entropy Target', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Entropy Target (Exploration Strength)')
    ax2.set_ylabel('Mean Reward (Last 50 Episodes)')
    ax2.grid(True, alpha=0.3, axis='y')

    # Plot 3: Heatmap (learning rate × entropy target)
    ax3 = axes[2]
    pivot_table = df.pivot_table(values='mean_last_50',
                                 index='learning_rate',
                                 columns='target_entropy',
                                 aggfunc='mean')
    # Format the learning-rate labels on the data itself instead of re-parsing rendered tick text.
    pivot_table.index = [f"{lr:.0e}" for lr in pivot_table.index]
    sns.heatmap(pivot_table, annot=True, fmt='.1f', cmap='YlOrRd', ax=ax3, cbar_kws={'label': 'Mean Reward'})
    ax3.set_title('Interaction: LR × Entropy', fontsize=14, fontweight='bold')
    ax3.set_xlabel('Entropy Target')
    ax3.set_ylabel('Learning Rate')
    ax3.tick_params(axis='y', labelrotation=0)

    plt.tight_layout()
    plt.savefig('logs/ablations/ablation_results.png', dpi=150, bbox_inches='tight')
    print(f"\nVisualization saved to logs/ablations/ablation_results.png")
    plt.show()

    # Print best configuration
    best_idx = df['mean_last_50'].idxmax()
    best_config = df.loc[best_idx]
    print("\n" + "="*70)
    print("BEST CONFIGURATION:")
    print("="*70)
    print(f"Learning Rate: {best_config['learning_rate']:.0e}")
    print(f"Entropy Target: {best_config['target_entropy']}")
    print(f"Mean Reward (Last 50 eps): {best_config['mean_last_50']:.2f}")
    print(f"Log Directory: {best_config['log_dir']}")
    print("="*70)

    # Print summary table grouped by hyperparameters
    print("\n" + "="*70)
    print("SUMMARY BY HYPERPARAMETER:")
    print("="*70)

    # Group by learning rate
    lr_summary = df.groupby('learning_rate')['mean_last_50'].agg(['mean', 'std', 'min', 'max', 'count'])
    print("\nLearning Rate:")
    for lr, row in lr_summary.iterrows():
        print(f"  {lr:.0e}: {row['mean']:.2f} ± {row['std']:.2f} [{row['min']:.2f}, {row['max']:.2f}] (n={int(row['count'])})")

    # Group by entropy target
    ent_summary = df.groupby('target_entropy')['mean_last_50'].agg(['mean', 'std', 'min', 'max', 'count'])
    print("\nEntropy Target:")
    for ent, row in ent_summary.iterrows():
        print(f"  {ent:.1f}: {row['mean']:.2f} ± {row['std']:.2f} [{row['min']:.2f}, {row['max']:.2f}] (n={int(row['count'])})")

except FileNotFoundError:
    print(f"Results file not found: {results_path}")
    print("Run the ablation study first by uncommenting the execution line in the previous cell.")
except Exception as e:
    # Keep the notebook running, but show where it failed instead of a one-line message only.
    print(f"Error loading or visualizing results: {e}")
    traceback.print_exc()

## Seed Study Execution

Run the same configuration with multiple seeds to measure variance and robustness. Use the best hyperparameters identified from the ablation study above.

In [None]:
# Seed study: re-run one fixed configuration (the chosen ablation winner) under several seeds.
SEED_STUDY_LR = 1e-4
SEED_STUDY_ENTROPY = -1.0

seed_study_grid = build_seed_study_grid(best_lr=SEED_STUDY_LR, best_entropy=SEED_STUDY_ENTROPY)
planned_count = len(seed_study_grid)
print(f"Planned seed study runs: {planned_count} configurations")
print("Studying seed variance with fixed hyperparameters")

# Long-running: uncomment the line below to launch the seed study.
# seed_results = run_ablation(seed_study_grid, base_ablation_cfg, results_path="logs/ablations/seed_study.json")

## Seed Study Visualization

Visualize seed variance using the best hyperparameters from the ablation study. This measures the robustness of the configuration across different random seeds.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

import json  # imported here too so this cell runs even if the ablation cell was not executed

# Load seed study results from JSON
seed_results_path = "logs/ablations/seed_study.json"

try:
    with open(seed_results_path, 'r') as f:
        seed_data = json.load(f)

    # Convert to DataFrame
    seed_df = pd.DataFrame(seed_data)
    if seed_df.empty:
        raise ValueError(f"{seed_results_path} contains no runs")

    print(f"Loaded {len(seed_df)} seed study runs from {seed_results_path}")
    print(f"\nColumns: {seed_df.columns.tolist()}")

    # Calculate statistics (std is NaN when only one run is available)
    seed_summary = seed_df.groupby('seed')['mean_last_50'].agg(['mean']).reset_index()
    overall_mean = seed_df['mean_last_50'].mean()
    overall_std = seed_df['mean_last_50'].std()

    # Create visualization
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Plot 1: Bar plot by seed
    ax1 = axes[0]
    seed_summary['seed_str'] = seed_summary['seed'].astype(str)
    ax1.bar(seed_summary['seed_str'], seed_summary['mean'], alpha=0.7, color='green')
    ax1.axhline(y=overall_mean, color='red', linestyle='--', linewidth=2, label=f'Mean: {overall_mean:.2f}')
    if not np.isnan(overall_std):
        ax1.axhline(y=overall_mean + overall_std, color='orange', linestyle=':', linewidth=1, alpha=0.7, label=f'±1 Std: {overall_std:.2f}')
        ax1.axhline(y=overall_mean - overall_std, color='orange', linestyle=':', linewidth=1, alpha=0.7)
    ax1.set_title('Performance by Seed', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Random Seed')
    ax1.set_ylabel('Mean Reward (Last 50 Episodes)')
    ax1.legend()
    ax1.grid(True, alpha=0.3, axis='y')

    # Plot 2: Box plot. The tick label is set separately because boxplot's `labels=` keyword
    # was renamed `tick_labels` in Matplotlib 3.9; NaN scores (empty runs) are dropped.
    ax2 = axes[1]
    ax2.boxplot([seed_df['mean_last_50'].dropna()])
    ax2.set_xticklabels(['All Seeds'])
    ax2.set_title('Seed Variance Distribution', fontsize=14, fontweight='bold')
    ax2.set_ylabel('Mean Reward (Last 50 Episodes)')
    ax2.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('logs/ablations/seed_study_results.png', dpi=150, bbox_inches='tight')
    print(f"\nVisualization saved to logs/ablations/seed_study_results.png")
    plt.show()

    # Print seed study summary
    print("\n" + "="*70)
    print("SEED STUDY RESULTS:")
    print("="*70)
    print(f"\nOverall Performance: {overall_mean:.2f} ± {overall_std:.2f}")
    print("\nBy Seed:")
    for _, row in seed_summary.iterrows():
        print(f"  Seed {int(row['seed'])}: {row['mean']:.2f}")
    print("\n" + "="*70)

except FileNotFoundError:
    print(f"Seed study results not found: {seed_results_path}")
    print("Run the seed study first by uncommenting the execution line in the ablation cell.")
except Exception as e:
    print(f"Error loading or visualizing seed study results: {e}")

## Back up logs and videos after training
Used only when running on Colab: copies the local `logs` and `videos` folders to Google Drive.

In [None]:
import shutil
import os
from google.colab import drive

# Step 1: make Google Drive available under /content/drive.
drive.mount('/content/drive')

# Step 2: map each local Colab folder to its backup location on Drive.
DRIVE_BACKUP_ROOT = '/content/drive/MyDrive/log'
source_folders = {
    f'/content/{folder}': f'{DRIVE_BACKUP_ROOT}/{folder}'
    for folder in ('logs', 'videos')
}

# Step 3: copy each existing folder, merging into any existing destination; skip missing ones.
for src, dest in source_folders.items():
    if not os.path.exists(src):
        print(f"⚠️ Source not found, skipping: {src}")
        continue
    shutil.copytree(src, dest, dirs_exist_ok=True)
    print(f"✅ Successfully synced: {src} -> {dest}")

print("\nBackup complete. You can view your files in the 'log' folder of your Drive.")