# HW14: Safe Reinforcement Learning

> - Full Name: **[Your Full Name]**
> - Student ID: **[Your Student ID]**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DeepRLCourse/Homework-14-Questions/blob/main/HW14_Notebook.ipynb)
[![Open In kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/DeepRLCourse/Homework-14-Questions/main/HW14_Notebook.ipynb)

## Overview
This assignment focuses on **Safe Reinforcement Learning**, exploring methods to train agents that not only maximize rewards but also satisfy safety constraints during both training and deployment. We'll implement and experiment with:

1. **Constrained Policy Optimization (CPO)**
2. **Safety Layers and Shielding**
3. **Risk-Sensitive RL (CVaR)**
4. **Safe Exploration Techniques**
5. **Robust RL Methods**

The goal is to understand the fundamental trade-offs between performance and safety in RL systems.


In [None]:
# @title Imports and Setup

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import gymnasium as gym
from gymnasium import spaces
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import trange
from collections import defaultdict, deque
import random
from typing import Tuple, List, Dict, Optional
import warnings
warnings.filterwarnings('ignore')

# Seed every random number generator the notebook relies on so runs repeat.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU whenever PyTorch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Shared look for every figure produced below.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")


## 1. Safe Environment Setup

First, let's create a safe version of the CartPole environment where the agent must balance the pole while keeping the cart position within safe bounds.


In [None]:
class SafeCartPoleEnv(gym.Env):
    """
    Safe CartPole environment with position constraints.

    Wraps gymnasium's ``CartPole-v1``. The agent must balance the pole while
    keeping the cart position within ``position_limit``. Every step reports a
    safety cost ``max(0, |x| - position_limit)`` (``x`` = cart position) in
    ``info['cost']``. A step whose cost exceeds ``cost_threshold`` is a
    constraint violation: the episode is terminated and that step's reward is
    replaced with -100.

    Actions may be the native discrete CartPole actions (0 or 1) or a
    continuous floating-point value (scalar or 1-element array, as produced by
    the Gaussian policies in this notebook). A continuous value is mapped to a
    discrete push by its sign: ``> 0`` -> 1, otherwise 0.
    """

    def __init__(self, position_limit=1.5, cost_threshold=0.1, render_mode=None):
        """
        Args:
            position_limit: |cart position| beyond which cost starts to accrue.
            cost_threshold: per-step cost above which a step is a violation.
            render_mode: forwarded to ``gym.make``; gymnasium fixes the render
                mode at construction time instead of taking it in ``render()``.
        """
        super().__init__()

        # Create base CartPole environment
        self.env = gym.make('CartPole-v1', render_mode=render_mode)
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.render_mode = render_mode

        # Safety parameters
        self.position_limit = position_limit
        self.cost_threshold = cost_threshold

        # Per-episode statistics (cleared by reset)
        self.episode_cost = 0
        self.episode_reward = 0
        self.constraint_violations = 0

    def reset(self, **kwargs):
        """Reset environment and episode statistics; return ``(obs, info)``."""
        obs, info = self.env.reset(**kwargs)
        self.episode_cost = 0
        self.episode_reward = 0
        self.constraint_violations = 0
        return obs, info

    def _to_discrete_action(self, action):
        """Return ``action`` in a form the wrapped discrete CartPole accepts.

        Valid discrete actions pass through unchanged. A single floating-point
        value is mapped to a push direction by its sign. Anything else is
        passed through so the wrapped environment can reject it.
        """
        if self.action_space.contains(action):
            return action
        arr = np.asarray(action)
        if np.issubdtype(arr.dtype, np.floating) and arr.size == 1:
            return int(arr.reshape(-1)[0] > 0.0)
        return action

    def step(self, action):
        """Execute ``action``; return ``(obs, reward, terminated, truncated, info)``.

        ``info`` gains ``cost``, ``episode_cost``, ``episode_reward`` and
        ``constraint_violations``. ``episode_reward`` is the sum of the rewards
        actually returned, so it includes the -100 violation penalty.
        """
        obs, reward, terminated, truncated, info = self.env.step(
            self._to_discrete_action(action)
        )

        # Extract cart position (first element of observation)
        cart_position = float(obs[0])

        # Cost is the distance beyond the allowed position band (0 inside it)
        cost = max(0.0, abs(cart_position) - self.position_limit)

        # A large violation ends the episode with a penalty reward
        if cost > self.cost_threshold:
            self.constraint_violations += 1
            terminated = True
            reward = -100  # Large penalty for constraint violation

        # Accumulate after the penalty so episode_reward matches returned rewards
        self.episode_cost += cost
        self.episode_reward += reward

        # Add cost information to info
        info['cost'] = cost
        info['episode_cost'] = self.episode_cost
        info['episode_reward'] = self.episode_reward
        info['constraint_violations'] = self.constraint_violations

        return obs, reward, terminated, truncated, info

    def render(self, mode='human'):
        """Render the wrapped environment.

        ``mode`` is accepted for backward compatibility but ignored: gymnasium's
        ``Env.render()`` takes no arguments and uses the ``render_mode`` given
        to the constructor.
        """
        return self.env.render()

    def close(self):
        """Close the environment."""
        self.env.close()

# Smoke-test the safe environment with a few random-action rollouts.
print("Testing Safe CartPole Environment...")
env = SafeCartPoleEnv(position_limit=1.5, cost_threshold=0.1)

for ep_idx in range(3):
    obs, info = env.reset()
    ep_reward, ep_cost = 0, 0

    # Cap each rollout at 100 steps and stop as soon as the episode ends.
    for _ in range(100):
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        ep_reward += reward
        ep_cost += info['cost']
        if terminated or truncated:
            break

    print(f"Episode {ep_idx + 1}: Reward = {ep_reward:.2f}, Cost = {ep_cost:.2f}, "
          f"Violations = {info['constraint_violations']}")

env.close()


## 2. Neural Network Architectures

Let's implement the neural network components needed for safe RL algorithms.


In [None]:
class PolicyNetwork(nn.Module):
    """Diagonal Gaussian policy for continuous actions.

    A two-layer tanh MLP produces the action mean. The standard deviation comes
    from a learned, state-independent log-std vector, clamped to [-20, 2]
    before exponentiation.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def forward(self, state):
        """Return ``(mean, std)`` of the action distribution for ``state``."""
        hidden = torch.tanh(self.fc2(torch.tanh(self.fc1(state))))
        # The clamp keeps exp() numerically well-behaved
        std = self.log_std.clamp(-20, 2).exp()
        return self.mean(hidden), std

    def _distribution(self, state):
        """Build the per-dimension Normal distribution for ``state``."""
        mu, sigma = self.forward(state)
        return torch.distributions.Normal(mu, sigma)

    def sample(self, state):
        """Draw an action; return it with its log-probability summed over action dims."""
        dist = self._distribution(state)
        drawn = dist.sample()
        return drawn, dist.log_prob(drawn).sum(dim=-1)

    def log_prob(self, state, action):
        """Log-probability of ``action`` under the policy at ``state`` (summed over dims)."""
        return self._distribution(state).log_prob(action).sum(dim=-1)

class ValueNetwork(nn.Module):
    """State-value approximator: two tanh hidden layers and a scalar head."""

    def __init__(self, state_dim, hidden_dim=64):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        """Return value estimates with a trailing dimension of size 1."""
        features = state
        for layer in (self.fc1, self.fc2):
            features = torch.tanh(layer(features))
        return self.value(features)

class SafetyNetwork(nn.Module):
    """Scalar safety score of a state (e.g. for control-barrier-style filters).

    Architecture: Linear -> tanh -> Linear -> tanh -> Linear(1).
    """

    def __init__(self, state_dim, hidden_dim=64):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.safety = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        """Return one safety score per input state (trailing dimension of size 1)."""
        hidden = torch.tanh(self.fc2(torch.tanh(self.fc1(state))))
        return self.safety(hidden)

class DistributionalCritic(nn.Module):
    """Categorical distributional critic for risk-sensitive RL.

    Maps a concatenated (state, action) pair to probabilities over
    ``num_atoms`` fixed return values ("atoms") evenly spaced on
    ``[v_min, v_max]``.
    """

    def __init__(self, state_dim, action_dim, num_atoms=51, hidden_dim=64,
                 v_min=-10.0, v_max=10.0):
        """
        Args:
            state_dim, action_dim: input sizes; state and action are concatenated.
            num_atoms: number of support points of the return distribution.
            hidden_dim: width of the two hidden layers.
            v_min, v_max: range of the support. The defaults reproduce the
                previously hard-coded range [-10, 10].

        Raises:
            ValueError: if ``v_min >= v_max``.
        """
        super().__init__()
        if v_min >= v_max:
            raise ValueError(f"v_min ({v_min}) must be less than v_max ({v_max})")
        self.num_atoms = num_atoms
        self.v_min = float(v_min)
        self.v_max = float(v_max)

        # Network architecture
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_atoms)

        # Atom values (support of the distribution); a buffer so it follows .to(device)
        self.register_buffer('atoms', torch.linspace(self.v_min, self.v_max, num_atoms))

    def forward(self, state, action):
        """Return probabilities over the atoms, shape ``(..., num_atoms)``."""
        x = torch.cat([state, action], dim=-1)
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        logits = self.fc3(x)
        return torch.softmax(logits, dim=-1)

    def get_distribution(self, state, action):
        """Return a Categorical over atom *indices* (map indices to values via ``atoms``)."""
        probs = self.forward(state, action)
        return torch.distributions.Categorical(probs)

    def get_value(self, state, action):
        """Return the expected return: the probability-weighted sum of atom values."""
        probs = self.forward(state, action)
        return torch.sum(probs * self.atoms, dim=-1)

# Smoke-test every network on a single random input.
print("Testing Neural Network Architectures...")

state_dim, action_dim = 4, 1  # CartPole observation size; 1-D continuous action

# Build every network (construction order fixes the RNG draws used for init).
policy = PolicyNetwork(state_dim, action_dim).to(device)
value_reward = ValueNetwork(state_dim).to(device)
value_cost = ValueNetwork(state_dim).to(device)
safety = SafetyNetwork(state_dim).to(device)
critic_dist = DistributionalCritic(state_dim, action_dim).to(device)

# One random batch-of-one input for each kind of network.
test_state, test_action = (
    torch.randn(1, state_dim).to(device),
    torch.randn(1, action_dim).to(device),
)

print(f"Policy output shape: {policy(test_state)[0].shape}")
print(f"Value output shape: {value_reward(test_state).shape}")
print(f"Safety output shape: {safety(test_state).shape}")
print(f"Distributional critic output shape: {critic_dist(test_state, test_action).shape}")

print("All networks initialized successfully!")


## 3. Constrained Policy Optimization (CPO)

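Now let's implement a simplified version of the CPO algorithm. Full CPO extends TRPO to handle safety constraints directly in the optimization process. The version below keeps the core idea: improve reward while the cost constraint is satisfied, and otherwise reduce cost. It replaces the trust-region solve with a single first-order gradient step.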


In [None]:
class CPO:
    """Constrained Policy Optimization (simplified, first-order variant).

    Holds a Gaussian policy plus separate value networks for reward and cost.
    Each ``update`` fits both critics towards GAE returns, then takes one plain
    gradient step on the policy surrogate:

    - while the measured cost ``J_c`` is within ``cost_limit``, it ascends the
      reward surrogate;
    - otherwise, it descends the cost surrogate.

    Unlike full CPO there is no trust-region / conjugate-gradient solve;
    ``delta_kl`` is stored but not enforced.
    """

    def __init__(self, state_dim, action_dim, cost_limit=10.0, lr=3e-4):
        # Networks
        self.policy = PolicyNetwork(state_dim, action_dim).to(device)
        self.value_reward = ValueNetwork(state_dim).to(device)
        self.value_cost = ValueNetwork(state_dim).to(device)

        # Hyperparameters
        self.cost_limit = cost_limit    # allowed cost per episode
        self.gamma = 0.99               # discount factor
        self.lambda_gae = 0.97          # GAE lambda
        self.delta_kl = 0.01            # KL bound of full CPO (not enforced here)
        self.lr = lr                    # learning rate of the value networks
        self.policy_step_size = 0.01    # step size of the manual policy update

        # Optimizers (value networks only; the policy is stepped manually)
        self.optimizer_value_r = optim.Adam(self.value_reward.parameters(), lr=lr)
        self.optimizer_value_c = optim.Adam(self.value_cost.parameters(), lr=lr)

        # Storage for trajectories
        self.states = []
        self.actions = []
        self.rewards = []
        self.costs = []
        self.dones = []
        self.log_probs = []

    def compute_advantages(self, rewards, values, costs, value_costs, dones):
        """Compute GAE advantages for reward and cost.

        Arguments are expected to be 1-D tensors of equal length ordered in
        time. ``dones[t]`` (0/1) stops bootstrapping and advantage propagation
        at episode boundaries. The value after the final step is taken as 0.

        Returns:
            ``(advantages_r, advantages_c)``, shaped like ``rewards`` / ``costs``.
        """
        advantages_r = torch.zeros_like(rewards)
        advantages_c = torch.zeros_like(costs)

        last_adv_r = 0
        last_adv_c = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value_r = 0
                next_value_c = 0
            else:
                next_value_r = values[t + 1]
                next_value_c = value_costs[t + 1]

            # Reward advantage
            delta_r = rewards[t] + self.gamma * next_value_r * (1 - dones[t]) - values[t]
            advantages_r[t] = last_adv_r = delta_r + self.gamma * self.lambda_gae * (1 - dones[t]) * last_adv_r

            # Cost advantage
            delta_c = costs[t] + self.gamma * next_value_c * (1 - dones[t]) - value_costs[t]
            advantages_c[t] = last_adv_c = delta_c + self.gamma * self.lambda_gae * (1 - dones[t]) * last_adv_c

        return advantages_r, advantages_c

    def update_value_networks(self, states, rewards, costs, dones):
        """Fit both value networks to GAE returns and return the advantages.

        Advantages are computed once from the pre-update critics. Each critic
        then takes 10 Adam steps on the MSE to ``advantage + old value``.

        Returns:
            ``(advantages_r, advantages_c)`` as tensors on ``device``.
        """
        states = torch.as_tensor(np.asarray(states, dtype=np.float32), device=device)
        rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32), device=device)
        costs = torch.as_tensor(np.asarray(costs, dtype=np.float32), device=device)
        dones = torch.as_tensor(np.asarray(dones, dtype=np.float32), device=device)

        # squeeze(-1) instead of squeeze(): a single-transition batch must stay
        # shape (1,) so compute_advantages can index it.
        with torch.no_grad():
            values_r = self.value_reward(states).squeeze(-1)
            values_c = self.value_cost(states).squeeze(-1)

        advantages_r, advantages_c = self.compute_advantages(
            rewards, values_r, costs, values_c, dones
        )
        returns_r = advantages_r + values_r
        returns_c = advantages_c + values_c

        # Update reward value network
        for _ in range(10):
            values_r_pred = self.value_reward(states).squeeze(-1)
            loss_value_r = ((values_r_pred - returns_r) ** 2).mean()
            self.optimizer_value_r.zero_grad()
            loss_value_r.backward()
            self.optimizer_value_r.step()

        # Update cost value network
        for _ in range(10):
            values_c_pred = self.value_cost(states).squeeze(-1)
            loss_value_c = ((values_c_pred - returns_c) ** 2).mean()
            self.optimizer_value_c.zero_grad()
            loss_value_c.backward()
            self.optimizer_value_c.step()

        return advantages_r, advantages_c

    def update_policy(self, states, actions, log_probs_old, advantages_r, advantages_c,
                      cost_return=None):
        """Take one first-order CPO-style step on the policy.

        Args:
            states, actions, log_probs_old: the batch the advantages came from.
            advantages_r, advantages_c: reward / cost advantages (detached here).
            cost_return: measured cost of the current policy (e.g. mean episode
                cost of the batch), compared against ``cost_limit``. If omitted,
                the mean cost advantage is used for backward compatibility.
                That quantity is not a cost estimate, so prefer passing this.

        Returns:
            dict with the surrogate loss, both surrogate values and ``J_c``.
        """
        states = torch.as_tensor(np.asarray(states, dtype=np.float32), device=device)
        actions = torch.as_tensor(np.asarray(actions, dtype=np.float32), device=device)
        log_probs_old = torch.as_tensor(np.asarray(log_probs_old, dtype=np.float32), device=device)
        advantages_r = advantages_r.detach()
        advantages_c = advantages_c.detach()

        # Importance ratios between the current and the data-collecting policy
        log_probs_new = self.policy.log_prob(states, actions)
        ratios = torch.exp(log_probs_new - log_probs_old)

        # Surrogate objectives for reward (g) and cost (b)
        g = (ratios * advantages_r).mean()
        b = (ratios * advantages_c).mean()

        if cost_return is None:
            J_c = advantages_c.mean().item()
        else:
            J_c = float(cost_return)

        # Feasible: minimise -g (raise reward). Infeasible: minimise b (lower cost).
        loss = -g if J_c <= self.cost_limit else b

        self.policy.zero_grad()
        loss.backward()

        # Gradient *descent* on the loss, i.e. step against its gradient.
        with torch.no_grad():
            for param in self.policy.parameters():
                if param.grad is not None:
                    param.add_(param.grad, alpha=-self.policy_step_size)

        return {
            'loss_policy': loss.item(),
            'reward_grad': g.item(),
            'cost_grad': b.item(),
            'J_c': J_c
        }

    def update(self, states, actions, rewards, costs, dones, log_probs):
        """Update the critics, then the policy, from one batch of transitions.

        ``J_c`` is the batch's undiscounted cost per episode. Episodes are
        counted by the truthy entries in ``dones``; a batch with none counts as
        one episode.
        """
        advantages_r, advantages_c = self.update_value_networks(states, rewards, costs, dones)

        num_episodes = max(1, int(np.count_nonzero(np.asarray(dones))))
        cost_return = float(np.sum(costs)) / num_episodes

        return self.update_policy(
            states, actions, log_probs, advantages_r, advantages_c,
            cost_return=cost_return,
        )

    def select_action(self, state):
        """Sample an action for a single state; return ``(action, log_prob)`` as NumPy."""
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            action, log_prob = self.policy.sample(state_tensor)
        return action.cpu().numpy()[0], log_prob.cpu().numpy()[0]

# Smoke-test CPO: build an agent and query its untrained policy once.
print("Testing CPO Algorithm...")
cpo_agent = CPO(state_dim=4, action_dim=1, cost_limit=5.0)

# A random 4-d state stands in for a CartPole observation.
test_state = np.random.randn(4)
action, log_prob = cpo_agent.select_action(test_state)
print(f"Action: {action}, Log Prob: {log_prob}")

print("CPO agent initialized successfully!")


## 4. Safety Layer Implementation

Let's implement a safety layer that acts as a protective filter between the RL agent and the environment.


In [None]:
class SafetyLayer:
    """Safety layer that filters unsafe actions.

    An action is judged safe when both conditions hold:

    (a) it lies inside ``action_bounds``;
    (b) the learned score ``safety_function(state)`` exceeds
        ``safety_threshold``.

    ``SafetyNetwork`` takes only the state as input, so the action affects the
    verdict solely through the bounds check.
    """

    def __init__(self, state_dim, action_dim, safety_threshold=0.0):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.safety_threshold = safety_threshold

        # Safety function (learned or hand-crafted)
        self.safety_function = SafetyNetwork(state_dim).to(device)
        self.safety_optimizer = optim.Adam(self.safety_function.parameters(), lr=1e-3)

        # Action bounds for projection
        self.action_bounds = (-2.0, 2.0)  # CartPole action bounds

    def _device(self):
        """Device on which the safety function's parameters live."""
        return next(self.safety_function.parameters()).device

    def _state_safety(self, state):
        """Return the learned safety score of a single state as a Python float."""
        state_tensor = torch.as_tensor(
            np.asarray(state, dtype=np.float32), device=self._device()
        ).unsqueeze(0)
        with torch.no_grad():
            return self.safety_function(state_tensor).item()

    def _within_bounds(self, action):
        """True if every component of ``action`` lies inside ``action_bounds``."""
        low, high = self.action_bounds
        action = np.asarray(action)
        return bool(np.all((action >= low) & (action <= high)))

    def is_safe(self, state, action):
        """Check if a state-action pair is safe (action in bounds and safe state score)."""
        if not self._within_bounds(action):
            return False
        return self._state_safety(state) > self.safety_threshold

    def get_safe_action_set(self, state, num_samples=100):
        """Return uniformly sampled in-bounds actions that are safe for ``state``.

        ``num_samples`` candidates are always drawn from NumPy's global RNG.
        The state score does not depend on the action, so it is evaluated once
        rather than per candidate. Returns ``[]`` when the state is judged
        unsafe.
        """
        low, high = self.action_bounds
        candidates = [np.random.uniform(low, high, self.action_dim) for _ in range(num_samples)]
        if not (self._state_safety(state) > self.safety_threshold):
            return []
        return [a for a in candidates if self._within_bounds(a)]

    def project_to_safe_action(self, state, proposed_action):
        """Project ``proposed_action`` to a nearby safe action.

        Resolution order:

        1. If the proposal is already safe, it is returned as-is.
        2. Otherwise it is clipped into ``action_bounds``, which is the nearest
           in-bounds action; the clipped action is returned if it is safe.
        3. Otherwise the closest sampled safe action is returned.
        4. If no safe action was found, a zero action is returned as an
           emergency fallback.
        """
        if self.is_safe(state, proposed_action):
            return proposed_action

        low, high = self.action_bounds
        clipped = np.clip(np.asarray(proposed_action), low, high)
        if self.is_safe(state, clipped):
            return clipped

        safe_actions = self.get_safe_action_set(state)
        if not safe_actions:
            # Emergency fallback: return zero action
            return np.zeros(self.action_dim)

        distances = [np.linalg.norm(action - proposed_action) for action in safe_actions]
        return safe_actions[int(np.argmin(distances))]

    def update_safety_function(self, states, costs):
        """Take one Adam step regressing the safety score towards ``-cost``; return the loss.

        NOTE(review): states with zero cost are regressed towards 0, while
        ``is_safe`` requires a score strictly above ``safety_threshold``
        (default 0.0). A trained network therefore sits on the decision
        boundary for safe states; consider a positive target margin.
        """
        dev = self._device()
        states = torch.as_tensor(np.asarray(states, dtype=np.float32), device=dev)
        costs = torch.as_tensor(np.asarray(costs, dtype=np.float32), device=dev)

        # squeeze(-1) keeps a single-state batch shaped (1,) like ``costs``
        safety_values = self.safety_function(states).squeeze(-1)

        # Unsafe (high-cost) states are pushed towards negative scores
        loss = F.mse_loss(safety_values, -costs)

        self.safety_optimizer.zero_grad()
        loss.backward()
        self.safety_optimizer.step()

        return loss.item()

class SafeAgent:
    """Wraps a base agent so every proposed action passes through a safety layer."""

    def __init__(self, base_agent, safety_layer):
        self.base_agent = base_agent
        self.safety_layer = safety_layer

    def select_action(self, state):
        """Return ``(filtered_action, log_prob)``.

        ``log_prob`` is the base agent's log-probability of its *unfiltered*
        proposal; it is passed through unchanged.
        """
        proposal, log_prob = self.base_agent.select_action(state)
        return self.safety_layer.project_to_safe_action(state, proposal), log_prob

# Smoke-test the safety layer: one safety check and one projection.
print("Testing Safety Layer...")
safety_layer = SafetyLayer(state_dim=4, action_dim=1)

# A random state paired with an in-range action.
test_state = np.random.randn(4)
test_action = np.array([0.5])
is_safe = safety_layer.is_safe(test_state, test_action)
print(f"Action {test_action} is safe: {is_safe}")

# 3.0 lies outside the layer's (-2.0, 2.0) action bounds.
unsafe_action = np.array([3.0])  # Outside normal bounds
safe_action = safety_layer.project_to_safe_action(test_state, unsafe_action)
print(f"Projected unsafe action {unsafe_action} to safe action {safe_action}")

print("Safety layer initialized successfully!")


## 5. Risk-Sensitive RL with CVaR

Now let's implement risk-sensitive RL using Conditional Value at Risk (CVaR) to handle tail risks.


In [None]:
class RiskSensitiveAgent:
    """Risk-sensitive agent that targets the CVaR of returns.

    ``alpha`` is the tail fraction (e.g. 0.1 = worst 10%). For a sample of
    returns:

    - VaR_alpha is the largest return inside the alpha-tail;
    - CVaR_alpha is the mean of that tail.

    ``update_policy`` follows the sampling-based CVaR gradient (Tamar et al.,
    2015): each tail sample's score function is weighted by
    ``return - VaR``. A categorical distributional critic is trained by
    ``update_critic``, but ``update_policy`` does not use it.
    """

    def __init__(self, state_dim, action_dim, alpha=0.1, lr=3e-4):
        if not 0.0 < alpha <= 1.0:
            raise ValueError(f"alpha must be in (0, 1], got {alpha}")
        self.alpha = alpha  # Risk level (e.g., 0.1 for 10% worst cases)
        self.gamma = 0.99   # discount used in the critic's Bellman target

        # Networks
        self.policy = PolicyNetwork(state_dim, action_dim).to(device)
        self.critic_dist = DistributionalCritic(state_dim, action_dim).to(device)

        # Optimizers
        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.critic_optimizer = optim.Adam(self.critic_dist.parameters(), lr=lr)

        # Storage
        self.trajectories = []

    def _tail_size(self, n):
        """Number of samples in the alpha-tail of ``n`` returns, clamped to [1, n]."""
        if n == 0:
            raise ValueError("returns must be non-empty")
        return min(n, max(1, int(self.alpha * n)))

    def compute_cvar(self, returns):
        """Compute CVaR: the mean of the worst ``alpha`` fraction of returns (at least one)."""
        sorted_returns = np.sort(np.asarray(returns, dtype=np.float64).ravel())
        tail = self._tail_size(sorted_returns.size)
        return float(np.mean(sorted_returns[:tail]))

    def compute_var(self, returns):
        """Compute VaR: the largest return inside the alpha-tail (empirical alpha-quantile)."""
        sorted_returns = np.sort(np.asarray(returns, dtype=np.float64).ravel())
        tail = self._tail_size(sorted_returns.size)
        return float(sorted_returns[tail - 1])

    def _project_target(self, next_probs, rewards, dones):
        """Project the Bellman-updated distribution back onto the critic's atoms.

        Uses the categorical (C51) projection. The mass of each source atom
        ``z_j`` moves to ``r + gamma * (1 - done) * z_j``, clipped to the
        support, and is split linearly between the two neighbouring atoms.
        Every row of the result sums to the corresponding row of
        ``next_probs``.
        """
        atoms = self.critic_dist.atoms
        num_atoms = atoms.shape[0]
        if num_atoms == 1:
            return torch.ones_like(next_probs)
        v_min, v_max = atoms[0], atoms[-1]
        delta_z = (v_max - v_min) / (num_atoms - 1)

        tz = rewards.unsqueeze(-1) + self.gamma * (1 - dones.unsqueeze(-1)) * atoms.unsqueeze(0)
        tz = tz.clamp(v_min.item(), v_max.item())
        b = (tz - v_min) / delta_z  # fractional atom index

        lower = b.floor().long().clamp(0, num_atoms - 1)
        upper = b.ceil().long().clamp(0, num_atoms - 1)
        lower_weight = upper.to(b.dtype) - b
        upper_weight = b - lower.to(b.dtype)
        # When b lands exactly on an atom, both weights are 0; give it all the mass.
        exact = lower == upper
        lower_weight = torch.where(exact, torch.ones_like(b), lower_weight)
        upper_weight = torch.where(exact, torch.zeros_like(b), upper_weight)

        target = torch.zeros_like(next_probs)
        target.scatter_add_(-1, lower, next_probs * lower_weight)
        target.scatter_add_(-1, upper, next_probs * upper_weight)
        return target

    def update_critic(self, states, actions, rewards, next_states, dones):
        """Take one cross-entropy step of the distributional critic; return the loss."""
        states = torch.as_tensor(np.asarray(states, dtype=np.float32), device=device)
        actions = torch.as_tensor(np.asarray(actions, dtype=np.float32), device=device)
        rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32), device=device)
        next_states = torch.as_tensor(np.asarray(next_states, dtype=np.float32), device=device)
        dones = torch.as_tensor(np.asarray(dones, dtype=np.float32), device=device)

        # Get current distribution
        current_probs = self.critic_dist(states, actions)

        # Compute target distribution
        with torch.no_grad():
            next_actions, _ = self.policy.sample(next_states)
            next_probs = self.critic_dist(next_states, next_actions)
            target_probs = self._project_target(next_probs, rewards, dones)

        # Cross-entropy against the projected target; the clamp avoids log(0)
        loss = -(target_probs * current_probs.clamp_min(1e-8).log()).sum(dim=-1).mean()

        # Update critic
        self.critic_optimizer.zero_grad()
        loss.backward()
        self.critic_optimizer.step()

        return loss.item()

    def update_policy(self, states, actions, log_probs_old, returns):
        """Take one CVaR policy-gradient step; return the loss (0.0 if no tail samples).

        Assumes ``returns[i]`` is the return credited to ``(states[i],
        actions[i])``. Only samples in the alpha-tail (``return <= VaR``)
        contribute, each weighted by ``return - VaR``, following Tamar et al.
        (2015).
        """
        states = torch.as_tensor(np.asarray(states, dtype=np.float32), device=device)
        actions = torch.as_tensor(np.asarray(actions, dtype=np.float32), device=device)
        log_probs_old = torch.as_tensor(np.asarray(log_probs_old, dtype=np.float32), device=device)

        returns_arr = np.asarray(returns, dtype=np.float64).ravel()
        var = self.compute_var(returns_arr)

        tail_mask = returns_arr <= var
        if not np.any(tail_mask):
            return 0.0
        mask = torch.as_tensor(tail_mask, device=states.device)

        log_probs_new = self.policy.log_prob(states[mask], actions[mask])
        ratios = torch.exp(log_probs_new - log_probs_old[mask])

        # All weights are <= 0, so the step lowers the likelihood of the worst
        # outcomes in proportion to how far below VaR they fall.
        advantages = torch.as_tensor(
            returns_arr[tail_mask] - var, dtype=torch.float32, device=states.device
        )
        loss = -(ratios * advantages).mean()

        # Update policy
        self.policy_optimizer.zero_grad()
        loss.backward()
        self.policy_optimizer.step()

        return loss.item()

    def select_action(self, state):
        """Sample an action for a single state; return ``(action, log_prob)`` as NumPy."""
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            action, log_prob = self.policy.sample(state_tensor)
        return action.cpu().numpy()[0], log_prob.cpu().numpy()[0]

# Smoke-test the risk-sensitive agent: one action and the tail statistics.
print("Testing Risk-Sensitive Agent...")
risk_agent = RiskSensitiveAgent(state_dim=4, action_dim=1, alpha=0.1)

# Query the untrained policy on a random state.
test_state = np.random.randn(4)
action, log_prob = risk_agent.select_action(test_state)
print(f"Risk-sensitive action: {action}, Log Prob: {log_prob}")

# Tail statistics of 100 standard-normal "returns".
test_returns = np.random.normal(0, 1, 100)
cvar, var = risk_agent.compute_cvar(test_returns), risk_agent.compute_var(test_returns)
print(f"CVaR: {cvar:.3f}, VaR: {var:.3f}")

print("Risk-sensitive agent initialized successfully!")


## 6. Training and Evaluation

Now let's implement training functions and run experiments to compare different safe RL methods.


In [None]:
def train_cpo(env, agent, num_episodes=200):
    """Train a CPO agent, updating it once after every episode.

    Each episode is rolled out until termination or truncation. The whole
    trajectory is then passed to ``agent.update``. Progress is printed every
    20 episodes.

    Returns:
        dict with per-episode lists under 'rewards', 'costs' and 'violations'.
    """
    history_rewards, history_costs, history_violations = [], [], []

    for episode in range(num_episodes):
        # Per-episode trajectory buffers, in the order agent.update expects them
        traj_states, traj_actions, traj_rewards = [], [], []
        traj_costs, traj_dones, traj_log_probs = [], [], []

        state, _ = env.reset()
        episode_reward = episode_cost = 0
        finished = False

        while not finished:
            action, log_prob = agent.select_action(state)
            next_state, reward, terminated, truncated, info = env.step(action)
            finished = terminated or truncated

            traj_states.append(state)
            traj_actions.append(action)
            traj_rewards.append(reward)
            traj_costs.append(info['cost'])
            traj_dones.append(finished)
            traj_log_probs.append(log_prob)

            episode_reward += reward
            episode_cost += info['cost']
            state = next_state

        if traj_states:
            agent.update(traj_states, traj_actions, traj_rewards,
                         traj_costs, traj_dones, traj_log_probs)

        history_rewards.append(episode_reward)
        history_costs.append(episode_cost)
        history_violations.append(info.get('constraint_violations', 0))

        if episode % 20 == 0:
            print(f"Episode {episode}: Reward={episode_reward:.2f}, Cost={episode_cost:.2f}, "
                  f"Violations={info.get('constraint_violations', 0)}")

    return {
        'rewards': history_rewards,
        'costs': history_costs,
        'violations': history_violations
    }


In [None]:
def train_cpo(env, agent, num_episodes=200, batch_size=32):
    """Train a CPO agent with updates on batches of at least ``batch_size`` transitions.

    Transitions are accumulated across episodes. Whenever the buffer holds at
    least ``batch_size`` transitions at the end of an episode,
    ``agent.update`` is called and the buffer is cleared. Episode boundaries
    are preserved in ``dones``, so the agent's advantage estimation does not
    bootstrap across them. Transitions still buffered when training ends are
    not used.

    Returns:
        dict of per-episode lists: 'rewards', 'costs', 'violations'.
    """
    episode_rewards = []
    episode_costs = []
    violation_counts = []

    # Rollout buffer shared across episodes until it reaches batch_size
    states, actions, rewards, costs, dones, log_probs = [], [], [], [], [], []

    for episode in trange(num_episodes, desc="Training CPO"):
        state, _ = env.reset()
        episode_reward = 0
        episode_cost = 0
        done = False

        while not done:
            # Sample action from policy
            action, log_prob = agent.select_action(state)

            # Step environment
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

            # Store transition
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            costs.append(info['cost'])
            dones.append(done)
            log_probs.append(log_prob)

            episode_reward += reward
            episode_cost += info['cost']
            state = next_state

        # Short episodes (common early in CartPole) stay in the buffer instead
        # of being dropped, so updates are not silently skipped.
        if len(states) >= batch_size:
            agent.update(states, actions, rewards, costs, dones, log_probs)
            states, actions, rewards, costs, dones, log_probs = [], [], [], [], [], []

        episode_rewards.append(episode_reward)
        episode_costs.append(episode_cost)
        violation_counts.append(info.get('constraint_violations', 0))

    return {
        'rewards': episode_rewards,
        'costs': episode_costs,
        'violations': violation_counts
    }

def evaluate_agent(env, agent, num_episodes=10):
    """Evaluate an agent over ``num_episodes`` full episodes.

    Args:
        env: environment whose ``step`` info dict contains ``'cost'`` and,
            optionally, ``'constraint_violations'``.
        agent: object whose ``select_action(state)`` returns ``(action, _)``.
        num_episodes: number of evaluation episodes (must be >= 1).

    Returns:
        dict with the mean/std of episode reward and episode cost, plus
        ``'violation_rate'``: the fraction of episodes that ended with at
        least one constraint violation.

    Raises:
        ValueError: if ``num_episodes < 1``.
    """
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be at least 1, got {num_episodes}")

    total_rewards = []
    total_costs = []
    total_violations = []

    for _ in range(num_episodes):
        state, _ = env.reset()
        episode_reward = 0
        episode_cost = 0
        done = False

        while not done:
            action, _ = agent.select_action(state)
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

            episode_reward += reward
            episode_cost += info['cost']
            state = next_state

        total_rewards.append(episode_reward)
        total_costs.append(episode_cost)
        total_violations.append(info.get('constraint_violations', 0))

    return {
        'mean_reward': np.mean(total_rewards),
        'std_reward': np.std(total_rewards),
        'mean_cost': np.mean(total_costs),
        'std_cost': np.std(total_costs),
        # np.mean already averages over episodes; no extra division by num_episodes
        'violation_rate': np.mean(np.asarray(total_violations) > 0)
    }

def plot_training_results(results_dict):
    """Plot training results for different methods."""
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Plot rewards
    ax = axes[0, 0]
    for method, results in results_dict.items():
        rewards = results['rewards']
        window = 10
        smoothed_rewards = np.convolve(rewards, np.ones(window)/window, mode='valid')
        ax.plot(smoothed_rewards, label=method, linewidth=2)
    ax.set_xlabel('Episode')
    ax.set_ylabel('Reward')
    ax.set_title('Episode Rewards (Smoothed)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Plot costs
    ax = axes[0, 1]
    for method, results in results_dict.items():
        costs = results['costs']
        window = 10
        smoothed_costs = np.convolve(costs, np.ones(window)/window, mode='valid')
        ax.plot(smoothed_costs, label=method, linewidth=2)
    ax.set_xlabel('Episode')
    ax.set_ylabel('Cost')
    ax.set_title('Episode Costs (Smoothed)')
    ax.axhline(y=5.0, color='r', linestyle='--', label='Cost Limit', linewidth=2)
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Plot violations
    ax = axes[1, 0]
    for method, results in results_dict.items():
        violations = results['violations']
        cumulative_violations = np.cumsum(violations)
        ax.plot(cumulative_violations, label=method, linewidth=2)
    ax.set_xlabel('Episode')
    ax.set_ylabel('Cumulative Violations')
    ax.set_title('Cumulative Constraint Violations')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # Plot reward vs cost trade-off
    ax = axes[1, 1]
    for method, results in results_dict.items():
        avg_reward = np.mean(results['rewards'][-50:])  # Last 50 episodes
        avg_cost = np.mean(results['costs'][-50:])
        ax.scatter(avg_cost, avg_reward, s=200, label=method, alpha=0.7)
    ax.set_xlabel('Average Cost')
    ax.set_ylabel('Average Reward')
    ax.set_title('Reward-Cost Trade-off')
    ax.axvline(x=5.0, color='r', linestyle='--', label='Cost Limit', linewidth=2)
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

print("Training and evaluation functions ready!")


## 7. Experiments and Comparison

Now let's run experiments to compare different safe RL methods.


In [None]:
# Experiment parameters
NUM_EPISODES = 200
COST_LIMIT = 5.0

# Environment with a cart position limit of 1.5 and a cost threshold of 0.1
env = SafeCartPoleEnv(position_limit=1.5, cost_threshold=0.1)

# Build the CPO agent (4-dim observation, 1-dim action)
print("Initializing agents...")
cpo_agent = CPO(state_dim=4, action_dim=1, cost_limit=COST_LIMIT)

# Training phase
print(f"\nTraining CPO for {NUM_EPISODES} episodes...")
results_cpo = train_cpo(env, cpo_agent, num_episodes=NUM_EPISODES)

# Evaluation phase (20 episodes)
print("\nEvaluating CPO agent...")
eval_results_cpo = evaluate_agent(env, cpo_agent, num_episodes=20)
print(
    "CPO Results: "
    f"Reward={eval_results_cpo['mean_reward']:.2f}±{eval_results_cpo['std_reward']:.2f}, "
    f"Cost={eval_results_cpo['mean_cost']:.2f}±{eval_results_cpo['std_cost']:.2f}, "
    f"Violation Rate={eval_results_cpo['violation_rate']:.3f}"
)

# Learning curves, keyed by method name
results_dict = dict(CPO=results_cpo)
print("\nPlotting training results...")
plot_training_results(results_dict)

# Final summary banner
print(f"\n{'=' * 60}")
print("FINAL RESULTS SUMMARY")
print("=" * 60)
print(
    f"CPO: Reward={eval_results_cpo['mean_reward']:.2f}, "
    f"Cost={eval_results_cpo['mean_cost']:.2f}, "
    f"Violations={eval_results_cpo['violation_rate']:.1%}"
)
print("=" * 60)


## 8. Analysis and Questions

### Q1: Safety vs Performance Trade-off
Based on the experiments above, analyze the trade-off between reward performance and safety constraints. How does the cost limit affect the learned policy?

**Your Answer:** _[Write your analysis here]_

### Q2: Comparison of Methods
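Compare the CPO method with standard RL without safety constraints. The experiment above does not train such a baseline: use the cost-agnostic `Robust` PPO agent from Section 9, or run CPO with a very large cost limit. What are the key differences in: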
- Training stability
- Constraint satisfaction
- Final performance

**Your Answer:** _[Write your comparison here]_

### Q3: Real-World Applications
Discuss potential real-world applications where safe RL would be critical. What additional safety mechanisms might be needed?

**Your Answer:** _[Write your discussion here]_

### Q4: Risk-Sensitive Learning
Explain how CVaR-based objectives differ from expected return maximization. When would you prefer risk-sensitive methods?

**Your Answer:** _[Write your explanation here]_


In [None]:
class PPOLagrangian:
    """PPO with a learned Lagrange multiplier for constrained (safe) RL.

    The policy minimizes ``-reward_surrogate + lambda * cost_surrogate``,
    built from PPO-style clipped surrogates. The multiplier ``lambda_lag``
    is kept non-negative and trained by gradient ascent on the constraint
    violation ``mean_episode_cost - cost_limit``.

    Args:
        state_dim: Observation dimension.
        action_dim: Action dimension.
        cost_limit: Budget on the undiscounted cost per episode.
        lr: Learning rate shared by all optimizers, including the multiplier's.
    """

    def __init__(self, state_dim, action_dim, cost_limit=10.0, lr=3e-4):
        # Networks: policy plus separate critics for reward and cost
        self.policy = PolicyNetwork(state_dim, action_dim).to(device)
        self.value_reward = ValueNetwork(state_dim).to(device)
        self.value_cost = ValueNetwork(state_dim).to(device)

        # Hyperparameters
        self.cost_limit = cost_limit
        self.gamma = 0.99
        self.lambda_gae = 0.97
        self.clip_ratio = 0.2
        self.lr = lr

        # Lagrange multiplier as a trainable scalar; clamped to >= 0 after each step
        self.lambda_lag = torch.tensor(1.0, requires_grad=True, device=device)

        # Optimizers
        self.optimizer_policy = optim.Adam(self.policy.parameters(), lr=lr)
        self.optimizer_value_r = optim.Adam(self.value_reward.parameters(), lr=lr)
        self.optimizer_value_c = optim.Adam(self.value_cost.parameters(), lr=lr)
        self.optimizer_lambda = optim.Adam([self.lambda_lag], lr=lr)

    def compute_advantages(self, rewards, values, costs, value_costs, dones):
        """Compute GAE advantages for reward and cost.

        Args:
            rewards, costs: 1-D tensors of per-step rewards / costs.
            values, value_costs: 1-D tensors of critic predictions per step.
            dones: 1-D float tensor, 1.0 where an episode ended at that step.

        Returns:
            ``(advantages_r, advantages_c)``, shaped like ``rewards``/``costs``.
            The value after the last step of the batch is bootstrapped as 0.
        """
        advantages_r = torch.zeros_like(rewards)
        advantages_c = torch.zeros_like(costs)

        last_adv_r = 0
        last_adv_c = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value_r = 0
                next_value_c = 0
            else:
                next_value_r = values[t + 1]
                next_value_c = value_costs[t + 1]

            # Reward advantage
            delta_r = rewards[t] + self.gamma * next_value_r * (1 - dones[t]) - values[t]
            advantages_r[t] = last_adv_r = delta_r + self.gamma * self.lambda_gae * (1 - dones[t]) * last_adv_r

            # Cost advantage
            delta_c = costs[t] + self.gamma * next_value_c * (1 - dones[t]) - value_costs[t]
            advantages_c[t] = last_adv_c = delta_c + self.gamma * self.lambda_gae * (1 - dones[t]) * last_adv_c

        return advantages_r, advantages_c

    @staticmethod
    def _normalize(x):
        """Standardize ``x``; a single sample is only centred (its std is NaN)."""
        if x.numel() < 2:
            return x - x.mean()
        return (x - x.mean()) / (x.std() + 1e-8)

    @staticmethod
    def _mean_episode_cost(costs, dones):
        """Return the mean undiscounted cost per episode in the batch.

        Episodes are counted by their ``done`` flags; a batch without any
        terminal flag counts as a single episode.
        """
        num_episodes = max(float(dones.sum().item()), 1.0)
        return costs.sum() / num_episodes

    def update(self, states, actions, rewards, costs, dones, log_probs_old):
        """Run one PPO-Lagrangian update on a batch of transitions.

        Fits both critics to their GAE returns, takes one clipped policy step
        on the Lagrangian objective, then one ascent step on the multiplier.

        Returns:
            Dict with ``policy_loss``, ``lambda_value`` (multiplier after the
            step) and ``current_cost`` (mean episode cost of the batch).
        """
        states = torch.FloatTensor(states).to(device)
        actions = torch.FloatTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        costs = torch.FloatTensor(costs).to(device)
        dones = torch.FloatTensor(dones).to(device)
        log_probs_old = torch.FloatTensor(log_probs_old).to(device)

        # Critic predictions at collection time (constants for GAE)
        with torch.no_grad():
            values_r = self.value_reward(states).reshape(-1)
            values_c = self.value_cost(states).reshape(-1)

        advantages_r, advantages_c = self.compute_advantages(
            rewards, values_r, costs, values_c, dones
        )

        # Critic targets must come from the *raw* advantages; regressing onto
        # V + normalized advantage would not approximate the return.
        returns_r = advantages_r + values_r
        returns_c = advantages_c + values_c

        # Normalize advantages for the policy step only
        advantages_r = self._normalize(advantages_r)
        advantages_c = self._normalize(advantages_c)

        # Update value networks
        for _ in range(10):
            loss_value_r = ((self.value_reward(states).reshape(-1) - returns_r) ** 2).mean()
            self.optimizer_value_r.zero_grad()
            loss_value_r.backward()
            self.optimizer_value_r.step()

            loss_value_c = ((self.value_cost(states).reshape(-1) - returns_c) ** 2).mean()
            self.optimizer_value_c.zero_grad()
            loss_value_c.backward()
            self.optimizer_value_c.step()

        # Probability ratios of the current vs. behaviour policy
        log_probs_new = self.policy.log_prob(states, actions)
        ratios = torch.exp(log_probs_new - log_probs_old)
        clipped_ratios = torch.clamp(ratios, 1 - self.clip_ratio, 1 + self.clip_ratio)

        # Reward is maximized: pessimistic bound is the min (standard PPO).
        surr_r = torch.min(ratios * advantages_r, clipped_ratios * advantages_r).mean()
        # Cost is minimized: its pessimistic bound is the max.
        surr_c = torch.max(ratios * advantages_c, clipped_ratios * advantages_c).mean()

        # The multiplier is a constant for the policy step.
        policy_loss = -surr_r + self.lambda_lag.detach() * surr_c

        self.optimizer_policy.zero_grad()
        policy_loss.backward()
        self.optimizer_policy.step()

        # Gradient ascent on lambda: grows while the episode cost exceeds the
        # limit. Uses the raw cost, not the (zero-mean) normalized advantages.
        current_cost = self._mean_episode_cost(costs, dones)
        lambda_loss = -self.lambda_lag * (current_cost - self.cost_limit)
        self.optimizer_lambda.zero_grad()
        lambda_loss.backward()
        self.optimizer_lambda.step()

        # Clamp lambda to be non-negative
        with torch.no_grad():
            self.lambda_lag.clamp_(min=0.0)

        return {
            'policy_loss': policy_loss.item(),
            'lambda_value': self.lambda_lag.item(),
            'current_cost': current_cost.item()
        }

    def select_action(self, state):
        """Sample an action for a single observation; returns ``(action, log_prob)`` as numpy."""
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            action, log_prob = self.policy.sample(state_tensor)
        return action.cpu().numpy()[0], log_prob.cpu().numpy()[0]

class RobustAgent:
    """Cost-agnostic clipped-PPO agent trained on noise-perturbed states.

    At every update, Gaussian noise with std ``noise_std`` is added to the
    batch states before they are fed to the critic and the policy, as a
    simple form of observation randomization.
    """

    def __init__(self, state_dim, action_dim, lr=3e-4):
        # Networks
        self.policy = PolicyNetwork(state_dim, action_dim).to(device)
        self.value = ValueNetwork(state_dim).to(device)

        # Hyperparameters
        self.gamma = 0.99
        self.lambda_gae = 0.97
        self.clip_ratio = 0.2
        self.lr = lr

        # Optimizers
        self.optimizer_policy = optim.Adam(self.policy.parameters(), lr=lr)
        self.optimizer_value = optim.Adam(self.value.parameters(), lr=lr)

        # Std of the Gaussian observation noise applied in update()
        self.noise_std = 0.1

    def add_noise(self, state):
        """Return ``state`` plus i.i.d. Gaussian noise with std ``self.noise_std``."""
        noise = torch.randn_like(state) * self.noise_std
        return state + noise

    def compute_advantages(self, rewards, values, dones):
        """Compute GAE advantages; the value after the last batch step is taken as 0."""
        advantages = torch.zeros_like(rewards)
        last_adv = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value = 0
            else:
                next_value = values[t + 1]

            delta = rewards[t] + self.gamma * next_value * (1 - dones[t]) - values[t]
            advantages[t] = last_adv = delta + self.gamma * self.lambda_gae * (1 - dones[t]) * last_adv

        return advantages

    @staticmethod
    def _normalize(x):
        """Standardize ``x``; a single sample is only centred (its std is NaN)."""
        if x.numel() < 2:
            return x - x.mean()
        return (x - x.mean()) / (x.std() + 1e-8)

    def update(self, states, actions, rewards, dones, log_probs_old):
        """Run one clipped-PPO update on noise-perturbed states.

        Returns:
            Dict with ``policy_loss`` and the last ``value_loss``.
        """
        states = torch.FloatTensor(states).to(device)
        actions = torch.FloatTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        dones = torch.FloatTensor(dones).to(device)
        log_probs_old = torch.FloatTensor(log_probs_old).to(device)

        # One noise draw per update, shared by the critic and the policy
        states_noisy = self.add_noise(states)

        with torch.no_grad():
            values = self.value(states_noisy).reshape(-1)

        advantages = self.compute_advantages(rewards, values, dones)

        # Critic target from the *raw* advantages (GAE return); normalized
        # advantages are only meant for the policy gradient.
        returns = advantages + values
        advantages = self._normalize(advantages)

        # Update value network
        for _ in range(10):
            values_pred = self.value(states_noisy).reshape(-1)
            loss_value = ((values_pred - returns) ** 2).mean()
            self.optimizer_value.zero_grad()
            loss_value.backward()
            self.optimizer_value.step()

        # NOTE(review): select_action is called on the raw observation, so
        # log_probs_old come from clean states while log_probs_new use noisy
        # ones; the ratio is therefore not a pure importance weight.
        log_probs_new = self.policy.log_prob(states_noisy, actions)
        ratios = torch.exp(log_probs_new - log_probs_old)

        # Clipped surrogate objective
        surr = ratios * advantages
        clipped_surr = torch.clamp(ratios, 1 - self.clip_ratio, 1 + self.clip_ratio) * advantages
        policy_loss = -torch.min(surr, clipped_surr).mean()

        self.optimizer_policy.zero_grad()
        policy_loss.backward()
        self.optimizer_policy.step()

        return {
            'policy_loss': policy_loss.item(),
            'value_loss': loss_value.item()
        }

    def select_action(self, state):
        """Sample an action for a single observation; returns ``(action, log_prob)`` as numpy."""
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            action, log_prob = self.policy.sample(state_tensor)
        return action.cpu().numpy()[0], log_prob.cpu().numpy()[0]

print("Additional Safe RL methods implemented!")


In [None]:
# Comprehensive Comparison Experiment
print("Running comprehensive comparison of Safe RL methods...")

# Shared experiment settings
NUM_EPISODES = 150
COST_LIMIT = 5.0

# Environment with a cart position limit of 1.5 and a cost threshold of 0.1
env = SafeCartPoleEnv(position_limit=1.5, cost_threshold=0.1)

# Agents are trained in insertion order below
print("Initializing all agents...")
agents = {}
agents['CPO'] = CPO(state_dim=4, action_dim=1, cost_limit=COST_LIMIT)
agents['PPO-Lagrangian'] = PPOLagrangian(state_dim=4, action_dim=1, cost_limit=COST_LIMIT)
agents['Robust'] = RobustAgent(state_dim=4, action_dim=1)

# Training function for different agent types
def train_agent(env, agent, agent_type, num_episodes=150, log_every=30):
    """Train ``agent`` for ``num_episodes`` episodes, updating after each one.

    Args:
        env: Gymnasium-style env; ``info`` must contain ``'cost'`` and may
            contain ``'constraint_violations'``.
        agent: Object with ``select_action(state) -> (action, log_prob)`` and
            an ``update(...)`` method.
        agent_type: ``'Robust'`` selects the cost-free signature
            ``update(states, actions, rewards, dones, log_probs)``; any other
            value also passes ``costs`` (CPO / PPO-Lagrangian).
        num_episodes: Number of training episodes.
        log_every: Print progress every ``log_every`` episodes; ``0``/``None``
            disables logging.

    Returns:
        Dict of per-episode lists: ``'rewards'``, ``'costs'``, ``'violations'``.
    """
    episode_rewards = []
    episode_costs = []
    episode_violations = []
    uses_costs = agent_type != 'Robust'

    for episode in range(num_episodes):
        # Collect one full trajectory
        states, actions, rewards, costs, dones, log_probs = [], [], [], [], [], []

        state, _ = env.reset()
        episode_reward = 0
        episode_cost = 0
        info = {}
        done = False

        while not done:
            action, log_prob = agent.select_action(state)
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

            states.append(state)
            actions.append(action)
            rewards.append(reward)
            costs.append(info['cost'])
            dones.append(done)
            log_probs.append(log_prob)

            episode_reward += reward
            episode_cost += info['cost']
            state = next_state

        # The loop runs at least once, so the batch is never empty.
        if uses_costs:
            agent.update(states, actions, rewards, costs, dones, log_probs)
        else:
            agent.update(states, actions, rewards, dones, log_probs)

        # Store episode statistics (violations from the last step's info)
        episode_rewards.append(episode_reward)
        episode_costs.append(episode_cost)
        episode_violations.append(info.get('constraint_violations', 0))

        if log_every and episode % log_every == 0:
            print(f"{agent_type} Episode {episode}: Reward={episode_reward:.2f}, Cost={episode_cost:.2f}")

    return {
        'rewards': episode_rewards,
        'costs': episode_costs,
        'violations': episode_violations
    }

# Train all agents
results_dict = {}
eval_results = {}

for name, agent in agents.items():
    print(f"\nTraining {name}...")
    results = train_agent(env, agent, name, NUM_EPISODES)
    results_dict[name] = results

    # Evaluate agent
    print(f"Evaluating {name}...")
    eval_results[name] = evaluate_agent(env, agent, num_episodes=20)
    print(f"{name} Results: Reward={eval_results[name]['mean_reward']:.2f}±{eval_results[name]['std_reward']:.2f}, "
          f"Cost={eval_results[name]['mean_cost']:.2f}±{eval_results[name]['std_cost']:.2f}, "
          f"Violation Rate={eval_results[name]['violation_rate']:.3f}")

# Plot comprehensive results
print("\nPlotting comprehensive results...")
plot_training_results(results_dict)

# Display final comparison table
print("\n" + "="*80)
print("COMPREHENSIVE SAFE RL COMPARISON RESULTS")
print("="*80)
print(f"{'Method':<15} {'Reward':<10} {'Cost':<10} {'Violations':<12} {'Safety Score':<12}")
print("-"*80)

for name, results in eval_results.items():
    safety_score = max(0, 1 - results['violation_rate'])  # Higher is safer
    print(f"{name:<15} {results['mean_reward']:<10.2f} {results['mean_cost']:<10.2f} "
          f"{results['violation_rate']:<12.3f} {safety_score:<12.3f}")

print("="*80)
print("Safety Score: 1.0 = Perfect safety, 0.0 = No safety")
# Report the configured limit rather than a hard-coded copy of it
print(f"Cost Limit: {COST_LIMIT}")
print("="*80)


In [None]:
## Summary and Key Insights

### Key Findings from Safe RL Experiments

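Based on our comprehensive experiments, here are the key insights we expect. Verify them against your own results, since outcomes depend on seeds and hyperparameters: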

#### 1. **Safety-Performance Trade-off**
- **CPO** is expected to achieve the best balance between reward and safety constraints
- **PPO-Lagrangian** provides good average constraint satisfaction but may sacrifice some performance and oscillate around the cost limit
- **Robust RL** focuses on performance and does not explicitly handle safety constraints

#### 2. **Constraint Satisfaction**
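- CPO is expected to show the lowest violation rates, because each trust-region update explicitly accounts for the cost constraint
- PPO-Lagrangian adapts the Lagrange multiplier to enforce constraints, so violations are corrected only after they have been observed
- The constrained methods (CPO, PPO-Lagrangian) should reduce violations over time; the Robust agent receives no cost signal, so any reduction in its violations is incidental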

#### 3. **Training Stability**
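- CPO's theory bounds the worst-case reward degradation and constraint violation of each update. The practical (linearized, approximate) algorithm achieves monotonic improvement only approximately.
- PPO-Lagrangian is often less stable than unconstrained PPO: the multiplier lags behind the constraint, which can make the cost oscillate around the limit
- Robust RL shows good generalization but may not respect safety bounds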

#### 4. **Practical Considerations**
- **Cost Limit Selection**: Tighter limits reduce violations but may limit performance
- **Safety Layers**: Can be added to any policy for runtime safety guarantees
- **Risk-Sensitive Methods**: Important for applications with catastrophic failure modes

### Real-World Applications

#### Autonomous Driving
- **Safety Constraints**: Collision avoidance, lane keeping, speed limits
- **Methods**: CPO + Safety Layers + Formal verification
- **Challenges**: Real-time constraints, sensor failures, edge cases

#### Healthcare
- **Safety Constraints**: Patient safety, adverse event prevention
- **Methods**: Risk-sensitive RL with CVaR, interpretable policies
- **Challenges**: High-stakes decisions, regulatory compliance

#### Robotics
- **Safety Constraints**: Human safety, equipment protection
- **Methods**: Control Barrier Functions, robust RL
- **Challenges**: Physical safety, sim-to-real transfer

### Future Directions

1. **Scalable Verification**: Extending formal verification to complex neural policies
2. **Multi-Agent Safety**: Coordinating safety across multiple agents
3. **Dynamic Constraints**: Adapting to changing safety requirements
4. **Human-AI Collaboration**: Interactive safety constraint learning

### Conclusion

Safe RL represents a critical advancement for deploying RL in real-world applications. The methods implemented in this notebook provide a foundation for:

- **Constrained Optimization**: CPO and PPO-Lagrangian for explicit constraint handling
- **Safety Filtering**: Safety layers for runtime protection
- **Risk Management**: CVaR-based methods for tail risk handling
- **Robustness**: Domain randomization for generalization

The key takeaway is that safety must be considered throughout the entire RL pipeline - from algorithm design to deployment - and multiple complementary approaches should be used to achieve robust safety guarantees.

---

**This completes the HW14 Safe Reinforcement Learning notebook!**

The notebook now contains:
✅ Complete CPO implementation
✅ Safety Layer with Control Barrier Functions  
✅ Risk-Sensitive RL with CVaR
✅ PPO with Lagrangian constraints
✅ Robust RL with domain randomization
✅ Comprehensive experiments and analysis
✅ Real-world application discussions


## 9. Additional Safe RL Methods

Let's implement additional safe RL techniques including PPO with Lagrangian constraints and robust RL methods.


In [None]:
class PPOLagrangian:
    """PPO with a Lagrangian cost penalty for safe RL.

    The policy minimizes ``-reward_surrogate + lambda * cost_surrogate``
    using PPO-style clipped surrogates. The multiplier ``lambda_lagrangian``
    is a plain float adjusted by a fixed step ``lambda_lr`` after each
    update: it increases while the mean episode cost exceeds ``cost_limit``
    and otherwise decreases, floored at 0.
    """

    def __init__(self, state_dim, action_dim, cost_limit=10.0, lr=3e-4):
        # Networks: policy plus separate critics for reward and cost
        self.policy = PolicyNetwork(state_dim, action_dim).to(device)
        self.value_reward = ValueNetwork(state_dim).to(device)
        self.value_cost = ValueNetwork(state_dim).to(device)

        # Hyperparameters
        self.cost_limit = cost_limit
        self.gamma = 0.99
        self.lambda_gae = 0.97
        self.clip_ratio = 0.2
        self.lr = lr

        # Optimizers
        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.optimizer_value_r = optim.Adam(self.value_reward.parameters(), lr=lr)
        self.optimizer_value_c = optim.Adam(self.value_cost.parameters(), lr=lr)

        # Lagrangian multiplier and its fixed adjustment step
        self.lambda_lagrangian = 1.0
        self.lambda_lr = 0.01

    def compute_advantages(self, rewards, values, costs, value_costs, dones):
        """Compute GAE advantages for reward and cost.

        Returns ``(advantages_r, advantages_c)``; the value after the last
        step of the batch is bootstrapped as 0.
        """
        advantages_r = torch.zeros_like(rewards)
        advantages_c = torch.zeros_like(costs)

        last_adv_r = 0
        last_adv_c = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value_r = 0
                next_value_c = 0
            else:
                next_value_r = values[t + 1]
                next_value_c = value_costs[t + 1]

            # Reward advantage
            delta_r = rewards[t] + self.gamma * next_value_r * (1 - dones[t]) - values[t]
            advantages_r[t] = last_adv_r = delta_r + self.gamma * self.lambda_gae * (1 - dones[t]) * last_adv_r

            # Cost advantage
            delta_c = costs[t] + self.gamma * next_value_c * (1 - dones[t]) - value_costs[t]
            advantages_c[t] = last_adv_c = delta_c + self.gamma * self.lambda_gae * (1 - dones[t]) * last_adv_c

        return advantages_r, advantages_c

    @staticmethod
    def _normalize(x):
        """Standardize ``x``; a single sample is only centred (its std is NaN)."""
        if x.numel() < 2:
            return x - x.mean()
        return (x - x.mean()) / (x.std() + 1e-8)

    @staticmethod
    def _mean_episode_cost(costs, dones):
        """Return the mean undiscounted cost per episode in the batch.

        Episodes are counted by their ``done`` flags; a batch without any
        terminal flag counts as a single episode.
        """
        num_episodes = max(float(dones.sum().item()), 1.0)
        return costs.sum() / num_episodes

    def update(self, states, actions, rewards, costs, dones, log_probs_old):
        """Run one PPO-Lagrangian update on a batch of transitions.

        Returns:
            Dict with ``policy_loss``, ``lambda_lagrangian`` (after the
            adjustment) and ``current_cost`` (mean episode cost of the batch).
        """
        states = torch.FloatTensor(states).to(device)
        actions = torch.FloatTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        costs = torch.FloatTensor(costs).to(device)
        dones = torch.FloatTensor(dones).to(device)
        log_probs_old = torch.FloatTensor(log_probs_old).to(device)

        # Critic predictions at collection time (constants for GAE)
        with torch.no_grad():
            values_r = self.value_reward(states).reshape(-1)
            values_c = self.value_cost(states).reshape(-1)

        advantages_r, advantages_c = self.compute_advantages(
            rewards, values_r, costs, values_c, dones
        )

        # Critic targets must come from the *raw* advantages (GAE returns).
        returns_r = advantages_r + values_r
        returns_c = advantages_c + values_c

        # Normalize advantages for the policy step only
        advantages_r = self._normalize(advantages_r)
        advantages_c = self._normalize(advantages_c)

        # Update value networks
        for _ in range(10):
            loss_value_r = ((self.value_reward(states).reshape(-1) - returns_r) ** 2).mean()
            self.optimizer_value_r.zero_grad()
            loss_value_r.backward()
            self.optimizer_value_r.step()

            loss_value_c = ((self.value_cost(states).reshape(-1) - returns_c) ** 2).mean()
            self.optimizer_value_c.zero_grad()
            loss_value_c.backward()
            self.optimizer_value_c.step()

        # Probability ratios of the current vs. behaviour policy
        log_probs_new = self.policy.log_prob(states, actions)
        ratios = torch.exp(log_probs_new - log_probs_old)
        clipped_ratios = torch.clamp(ratios, 1 - self.clip_ratio, 1 + self.clip_ratio)

        # Reward is maximized: pessimistic bound is the min (standard PPO).
        surr_reward = torch.min(ratios * advantages_r, clipped_ratios * advantages_r).mean()
        # Cost is minimized: its pessimistic bound is the max.
        surr_cost = torch.max(ratios * advantages_c, clipped_ratios * advantages_c).mean()

        policy_loss = -surr_reward + self.lambda_lagrangian * surr_cost

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Fixed-step multiplier update on the raw episode cost. Normalized
        # cost advantages have mean ~0 and would never exceed the limit.
        current_cost = self._mean_episode_cost(costs, dones).item()
        if current_cost > self.cost_limit:
            self.lambda_lagrangian += self.lambda_lr
        else:
            self.lambda_lagrangian = max(0, self.lambda_lagrangian - self.lambda_lr)

        return {
            'policy_loss': policy_loss.item(),
            'lambda_lagrangian': self.lambda_lagrangian,
            'current_cost': current_cost
        }

    def select_action(self, state):
        """Sample an action for a single observation; returns ``(action, log_prob)`` as numpy."""
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            action, log_prob = self.policy.sample(state_tensor)
        return action.cpu().numpy()[0], log_prob.cpu().numpy()[0]

class RobustAgent:
    """Cost-agnostic clipped-PPO agent used as a baseline in the comparison.

    NOTE(review): despite the name, this version applies no domain or
    observation randomization; ``update`` is plain PPO on the given states.
    """

    def __init__(self, state_dim, action_dim, lr=3e-4):
        # Networks
        self.policy = PolicyNetwork(state_dim, action_dim).to(device)
        self.value = ValueNetwork(state_dim).to(device)

        # Optimizers
        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.value_optimizer = optim.Adam(self.value.parameters(), lr=lr)

        # Hyperparameters
        self.gamma = 0.99
        self.lambda_gae = 0.97
        self.clip_ratio = 0.2

    def compute_advantages(self, rewards, values, dones):
        """Compute GAE advantages; the value after the last batch step is taken as 0."""
        advantages = torch.zeros_like(rewards)
        last_adv = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value = 0
            else:
                next_value = values[t + 1]

            delta = rewards[t] + self.gamma * next_value * (1 - dones[t]) - values[t]
            advantages[t] = last_adv = delta + self.gamma * self.lambda_gae * (1 - dones[t]) * last_adv

        return advantages

    @staticmethod
    def _normalize(x):
        """Standardize ``x``; a single sample is only centred (its std is NaN)."""
        if x.numel() < 2:
            return x - x.mean()
        return (x - x.mean()) / (x.std() + 1e-8)

    def update(self, states, actions, rewards, dones, log_probs_old):
        """Run one clipped-PPO update; returns ``{'policy_loss': float}``."""
        states = torch.FloatTensor(states).to(device)
        actions = torch.FloatTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        dones = torch.FloatTensor(dones).to(device)
        log_probs_old = torch.FloatTensor(log_probs_old).to(device)

        # Critic predictions at collection time (constants for GAE)
        with torch.no_grad():
            values = self.value(states).reshape(-1)

        advantages = self.compute_advantages(rewards, values, dones)

        # Critic target from the *raw* advantages (GAE return); normalized
        # advantages are only meant for the policy gradient.
        returns = advantages + values
        advantages = self._normalize(advantages)

        # Update value network
        for _ in range(10):
            values_pred = self.value(states).reshape(-1)
            loss_value = ((values_pred - returns) ** 2).mean()
            self.value_optimizer.zero_grad()
            loss_value.backward()
            self.value_optimizer.step()

        # Probability ratios of the current vs. behaviour policy
        log_probs_new = self.policy.log_prob(states, actions)
        ratios = torch.exp(log_probs_new - log_probs_old)

        # Clipped surrogate loss
        surr = ratios * advantages
        clipped_surr = torch.clamp(ratios, 1 - self.clip_ratio, 1 + self.clip_ratio) * advantages
        policy_loss = -torch.min(surr, clipped_surr).mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        return {'policy_loss': policy_loss.item()}

    def select_action(self, state):
        """Sample an action for a single observation; returns ``(action, log_prob)`` as numpy."""
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            action, log_prob = self.policy.sample(state_tensor)
        return action.cpu().numpy()[0], log_prob.cpu().numpy()[0]

# Smoke-test the additional methods: build each agent and sample one action
print("Testing Additional Safe RL Methods...")

# PPO-Lagrangian with a cost budget of 5.0
ppo_lag_agent = PPOLagrangian(state_dim=4, action_dim=1, cost_limit=5.0)
test_state = np.random.randn(4)  # one random 4-dim observation, reused below
action, log_prob = ppo_lag_agent.select_action(test_state)
print("PPO-Lagrangian action: {}, Log Prob: {}".format(action, log_prob))

# Robust agent on the same observation
robust_agent = RobustAgent(state_dim=4, action_dim=1)
action, log_prob = robust_agent.select_action(test_state)
print("Robust agent action: {}, Log Prob: {}".format(action, log_prob))

print("Additional methods initialized successfully!")


## 10. Advanced Safe RL Methods

Let's implement additional state-of-the-art Safe RL algorithms including SAC-Safe, PPO-Safe, and uncertainty-aware methods.
