# HW14: Safe Reinforcement Learning

> - Full Name: **[Your Full Name]**
> - Student ID: **[Your Student ID]**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DeepRLCourse/Homework-14-Questions/blob/main/HW14_Notebook.ipynb)
[![Open In Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/DeepRLCourse/Homework-14-Questions/main/HW14_Notebook.ipynb)

## Overview
This assignment focuses on **Safe Reinforcement Learning**, exploring methods to train agents that not only maximize rewards but also satisfy safety constraints during both training and deployment. We'll implement and experiment with:

1. **Constrained Policy Optimization (CPO)**
2. **Safety Layers and Shielding**
3. **Risk-Sensitive RL (CVaR)**
4. **Safe Exploration Techniques**
5. **Robust RL Methods**

The goal is to understand the fundamental trade-offs between performance and safety in RL systems.


In [None]:
# @title Imports and Setup

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import gymnasium as gym
from gymnasium import spaces
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import trange
from collections import defaultdict, deque
import random
from typing import Tuple, List, Dict, Optional
import warnings
warnings.filterwarnings('ignore')

# One shared seed keeps the Python, NumPy and PyTorch random streams repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Plot appearance: the seaborn-flavoured matplotlib style plus the "husl" palette.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")


## 1. Safe Environment Setup

First, let's create a safe version of the CartPole environment where the agent must balance the pole while keeping the cart position within safe bounds.


In [None]:
class SafeCartPoleEnv(gym.Env):
    """
    Safe CartPole environment with position constraints.

    Wraps Gymnasium's ``CartPole-v1`` and adds a per-step safety cost equal to
    how far the cart position (``obs[0]``) lies outside
    ``[-position_limit, position_limit]``. When that cost exceeds
    ``cost_threshold`` the episode is terminated and the step reward is
    replaced by a penalty of -100.

    Args:
        position_limit: Half-width of the safe cart-position interval.
        cost_threshold: Per-step cost above which the step counts as a
            constraint violation and ends the episode.
    """

    def __init__(self, position_limit=1.5, cost_threshold=0.1):
        super().__init__()

        # Base task; spaces are re-exported unchanged so agents see plain CartPole.
        self.env = gym.make('CartPole-v1')
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space

        # Safety parameters
        self.position_limit = position_limit
        self.cost_threshold = cost_threshold

        # Per-episode running statistics (cleared again in reset()).
        self.episode_cost = 0
        self.episode_reward = 0
        self.constraint_violations = 0

    def reset(self, **kwargs):
        """Reset the wrapped environment and clear the episode statistics.

        Keyword arguments are forwarded unchanged to the wrapped
        environment's ``reset``.

        Returns:
            The ``(observation, info)`` pair produced by the wrapped environment.
        """
        obs, info = self.env.reset(**kwargs)
        self.episode_cost = 0
        self.episode_reward = 0
        self.constraint_violations = 0
        return obs, info

    def step(self, action):
        """Advance one step and attach safety information.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``info`` gains the
            keys ``'cost'`` (this step's cost), ``'episode_cost'``,
            ``'episode_reward'`` and ``'constraint_violations'``.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)

        # Cart position is the first element of the observation.
        cart_position = obs[0]

        # Cost is the distance outside the safe interval (zero inside it).
        # Cast to a Python float so the cost has one type on every step.
        cost = max(0.0, float(abs(cart_position) - self.position_limit))

        # Apply the violation penalty BEFORE updating the running totals, so
        # that episode_reward equals the sum of rewards actually returned.
        if cost > self.cost_threshold:
            self.constraint_violations += 1
            terminated = True
            reward = -100.0  # Large penalty for constraint violation

        self.episode_cost += cost
        self.episode_reward += reward

        info['cost'] = cost
        info['episode_cost'] = self.episode_cost
        info['episode_reward'] = self.episode_reward
        info['constraint_violations'] = self.constraint_violations

        return obs, reward, terminated, truncated, info

    def render(self, mode='human'):
        """Render the wrapped environment.

        ``mode`` is kept only for backward compatibility and is ignored:
        Gymnasium's ``render()`` takes no arguments (the render mode is fixed
        when the environment is created), so forwarding ``mode`` would raise
        ``TypeError``.
        """
        return self.env.render()

    def close(self):
        """Close the wrapped environment."""
        self.env.close()

# Smoke test for the safe environment: run a few short random-action episodes
# and print the accumulated reward, accumulated cost and violation count.
print("Testing Safe CartPole Environment...")
env = SafeCartPoleEnv(position_limit=1.5, cost_threshold=0.1)

# Run a few random episodes to test
for episode in range(3):
    obs, info = env.reset()
    # Totals accumulated locally from the values returned by env.step().
    episode_reward = 0
    episode_cost = 0
    
    for step in range(100):  # Max 100 steps
        action = env.action_space.sample()  # Random action
        obs, reward, terminated, truncated, info = env.step(action)
        
        episode_reward += reward
        episode_cost += info['cost']
        
        # Stop as soon as the environment reports the episode is over.
        if terminated or truncated:
            break
    
    # `info` here is the dict returned by the last step of this episode.
    print(f"Episode {episode + 1}: Reward = {episode_reward:.2f}, Cost = {episode_cost:.2f}, "
          f"Violations = {info['constraint_violations']}")

env.close()


## 2. Neural Network Architectures

Let's implement the neural network components needed for safe RL algorithms.


In [None]:
class PolicyNetwork(nn.Module):
    """Diagonal Gaussian policy for continuous actions.

    A two-hidden-layer tanh MLP produces the action mean; the log standard
    deviation is a learned parameter that does not depend on the state.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        # Trunk and mean head.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        # One log-std entry per action dimension, initialised to 0 (std = 1).
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def forward(self, state):
        """Return ``(mean, std)`` of the action distribution for ``state``."""
        hidden = self.fc2(self.fc1(state).tanh()).tanh()
        # Keep log-std inside [-20, 2] so exp() stays numerically well-behaved.
        bounded_log_std = torch.clamp(self.log_std, min=-20, max=2)
        return self.mean(hidden), bounded_log_std.exp()

    def _action_distribution(self, state):
        """Build the Normal distribution over actions for ``state``."""
        loc, scale = self.forward(state)
        return torch.distributions.Normal(loc, scale)

    def sample(self, state):
        """Draw an action and return it with its summed log-probability."""
        gaussian = self._action_distribution(state)
        action = gaussian.sample()
        return action, gaussian.log_prob(action).sum(dim=-1)

    def log_prob(self, state, action):
        """Log-probability of ``action`` under the policy, summed over the last dim."""
        return self._action_distribution(state).log_prob(action).sum(dim=-1)

class ValueNetwork(nn.Module):
    """State-value approximator: two tanh hidden layers and a scalar output head."""

    def __init__(self, state_dim, hidden_dim=64):
        super().__init__()
        # Hidden trunk followed by a single-output value head.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        """Map ``state`` to a value estimate with a trailing dimension of size 1."""
        hidden = state
        for layer in (self.fc1, self.fc2):
            hidden = torch.tanh(layer(hidden))
        return self.value(hidden)

class SafetyNetwork(nn.Module):
    """Scalar safety-function approximator (intended for Control Barrier Functions)."""

    def __init__(self, state_dim, hidden_dim=64):
        super().__init__()
        # Two hidden layers feeding a single-output safety head.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.safety = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        """Map ``state`` to an unbounded safety score with a trailing dimension of 1."""
        first_hidden = self.fc1(state).tanh()
        second_hidden = self.fc2(first_hidden).tanh()
        return self.safety(second_hidden)

class DistributionalCritic(nn.Module):
    """Distributional critic for risk-sensitive RL.

    Models the return of a state-action pair as a categorical distribution
    over ``num_atoms`` fixed, evenly spaced support values ("atoms").

    Args:
        state_dim: Size of the state vector.
        action_dim: Size of the action vector.
        num_atoms: Number of support points of the return distribution.
        hidden_dim: Width of the two hidden layers.
        v_min: Smallest return value on the support (default -10, the
            previously hard-coded bound).
        v_max: Largest return value on the support (default 10, the
            previously hard-coded bound).
    """

    def __init__(self, state_dim, action_dim, num_atoms=51, hidden_dim=64,
                 v_min=-10.0, v_max=10.0):
        super().__init__()
        self.num_atoms = num_atoms
        self.v_min = v_min
        self.v_max = v_max

        # Network architecture: the state and action are concatenated on input.
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_atoms)

        # Atom values (support of the distribution). Registered as a buffer so
        # they follow the module across devices and appear in the state dict
        # without being trained.
        self.register_buffer('atoms', torch.linspace(v_min, v_max, num_atoms))

    def forward(self, state, action):
        """Return the probability assigned to each atom (last dim sums to 1)."""
        x = torch.cat([state, action], dim=-1)
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        logits = self.fc3(x)
        return torch.softmax(logits, dim=-1)

    def get_distribution(self, state, action):
        """Return a ``Categorical`` over atom indices for the given pair."""
        probs = self.forward(state, action)
        return torch.distributions.Categorical(probs)

    def get_value(self, state, action):
        """Return the expected return: the probability-weighted sum of atoms."""
        probs = self.forward(state, action)
        return torch.sum(probs * self.atoms, dim=-1)

# Smoke test for the networks: build one instance of each architecture on the
# selected device, push a single random input through it and print the shapes.
print("Testing Neural Network Architectures...")

# Test parameters
state_dim = 4  # CartPole state dimension
# Width of the action vector fed to the Gaussian policy and the critic.
# NOTE(review): the environment above takes its action space from gym's
# CartPole-v1 — confirm a 1-D continuous action matches that space.
action_dim = 1

# Create networks (separate value networks for the reward and cost signals)
policy = PolicyNetwork(state_dim, action_dim).to(device)
value_reward = ValueNetwork(state_dim).to(device)
value_cost = ValueNetwork(state_dim).to(device)
safety = SafetyNetwork(state_dim).to(device)
critic_dist = DistributionalCritic(state_dim, action_dim).to(device)

# Test forward passes with a batch of one random state and one random action
test_state = torch.randn(1, state_dim).to(device)
test_action = torch.randn(1, action_dim).to(device)

# policy(...) returns (mean, std); index 0 selects the mean.
print(f"Policy output shape: {policy(test_state)[0].shape}")
print(f"Value output shape: {value_reward(test_state).shape}")
print(f"Safety output shape: {safety(test_state).shape}")
print(f"Distributional critic output shape: {critic_dist(test_state, test_action).shape}")

print("All networks initialized successfully!")
