# 03 - SAC Implementation from Scratch

**Goal:** Implement the Soft Actor-Critic (SAC) algorithm step by step to understand in depth how it works.

**Time:** 3-4 hours

**What you'll learn:**
- The SAC algorithm architecture
- How actor and critic networks work together
- The role of entropy in exploration
- Replay buffer mechanics
- Target network updates

**Why SAC?**
- A strong, widely used baseline for continuous control
- Sample efficient (off-policy, so past experience is reused)
- Stable training through entropy regularization
- Works well on robotic tasks

---

## 1. Setup and Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import gymnasium as gym
try:
    from tqdm.notebook import tqdm
except ImportError:
    from tqdm import tqdm
from IPython.display import display, Markdown, HTML
import sys
import os
from collections import deque
import random

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal

# Add src to path
sys.path.append(os.path.abspath('../src'))

# Plotting setup
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🖥️  Using device: {device}")
print("✅ Libraries imported")

🖥️  Using device: cpu
✅ Libraries imported


## 2. SAC Theory Overview

### What is Soft Actor-Critic?

SAC is an **off-policy** actor-critic algorithm that maximizes both:
1. **Expected return** (cumulative reward)
2. **Entropy** (exploration/randomness)

### Key Components:

1. **Actor (Policy) Network** 🎭
   - Outputs mean and log_std for Gaussian policy
   - Samples actions with reparameterization trick
   - Goal: Maximize Q-value + entropy

2. **Critic (Q) Networks** 🎯
   - Two Q-networks (to reduce overestimation bias)
   - Estimates value of state-action pairs
   - Goal: Minimize TD error

3. **Target Networks** 🎯🎯
   - Slow-moving copies of Q-networks
   - Stabilize training
   - Updated with Polyak averaging

4. **Entropy Temperature (α)** 🌡️
   - Controls exploration vs exploitation
   - Can be learned automatically
   - Higher α → more exploration

5. **Replay Buffer** 💾
   - Stores past experiences
   - Enables off-policy learning
   - Breaks temporal correlations
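Polyak averaging is a one-line update per parameter. The sketch below is a minimal illustration assuming plain `torch.nn` modules; the helper name `soft_update` is hypothetical (the agent in Section 5 inlines the same rule with `data.copy_`). `Tensor.lerp_(end, weight)` computes `self + weight * (end - self)`, which is exactly `τ·θ + (1-τ)·θ̄`.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """Polyak averaging: target <- tau * source + (1 - tau) * target."""
    for t_param, s_param in zip(target.parameters(), source.parameters()):
        # lerp_ computes t + tau * (s - t) in place, i.e. (1 - tau) * t + tau * s
        t_param.lerp_(s_param, tau)


# Usage: the target starts as an exact copy of the online network
online = nn.Linear(4, 2)
target = nn.Linear(4, 2)
target.load_state_dict(online.state_dict())

with torch.no_grad():
    online.weight.add_(1.0)  # stand-in for an optimizer step on the online critic

soft_update(target, online, tau=0.005)
# The target moves only 0.5% of the way, so online.weight - target.weight is about 0.995
```

A small τ means the target critic changes slowly, which keeps the bootstrapped TD targets from chasing the online critic's every update.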

### The SAC Objective:

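$$
J(\pi) = \mathbb{E}_{\tau \sim \pi}\left[\sum_{t} r(s_t, a_t) + \alpha \, \mathcal{H}\big(\pi(\cdot \mid s_t)\big)\right]
$$

Where:
- $r(s_t, a_t)$ = reward
- $\alpha$ = temperature parameter
- $\mathcal{H}(\pi(\cdot \mid s_t))$ = entropy of the policy at state $s_t$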
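To make the entropy bonus concrete, the sketch below scores a single step with reward -1 under two diagonal Gaussian policies over 7 action dimensions (Pusher's size), assuming α = 0.2. The helper `soft_reward` is hypothetical, and it uses the closed-form entropy of the *pre-tanh* Gaussian; the SAC update itself uses the sample estimate -log π(a|s) of the squashed policy instead.

```python
import torch
from torch.distributions import Normal


def soft_reward(reward: float, std: torch.Tensor, alpha: float) -> float:
    """Return r + alpha * H(pi(.|s)) for a diagonal Gaussian policy (pre-tanh)."""
    policy = Normal(torch.zeros_like(std), std)
    # Independent action dimensions, so the joint entropy is the sum over dims
    entropy = policy.entropy().sum().item()
    return reward + alpha * entropy


wide = soft_reward(-1.0, torch.ones(7), alpha=0.2)            # broad, exploratory policy
narrow = soft_reward(-1.0, torch.full((7,), 0.1), alpha=0.2)  # nearly deterministic policy
print(f"wide: {wide:.4f}, narrow: {narrow:.4f}")  # → wide: 0.9865, narrow: -2.2371
```

With the same environment reward, the broad policy earns a higher soft reward, which is the pressure toward exploration that α controls.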

---

## 3. Replay Buffer Implementation

The replay buffer stores transitions `(s, a, r, s', done)` and samples random batches for training.

In [2]:
class ReplayBuffer:
    """
    Simple replay buffer for storing and sampling transitions.
    """

    def __init__(self, capacity=1000000):
        """
        Args:
            capacity: Maximum number of transitions to store
        """
        self.buffer = deque(maxlen=capacity)
        self.capacity = capacity

    def push(self, state, action, reward, next_state, done):
        """
        Add a transition to the buffer.

        Args:
            state: Current state
            action: Action taken
            reward: Reward received
            next_state: Next state
            done: Whether episode terminated
        """
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """
        Sample a random batch of transitions.

        Args:
            batch_size: Number of transitions to sample

        Returns:
            Tuple of (states, actions, rewards, next_states, dones)
        """
        # Random sampling
        batch = random.sample(self.buffer, batch_size)

        # Unzip the batch
        states, actions, rewards, next_states, dones = zip(*batch)

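        # Convert to float32 numpy arrays (rewards and dones as column vectors)
        states = np.array(states, dtype=np.float32)
        actions = np.array(actions, dtype=np.float32)
        rewards = np.array(rewards, dtype=np.float32).reshape(-1, 1)
        next_states = np.array(next_states, dtype=np.float32)
        dones = np.array(dones, dtype=np.float32).reshape(-1, 1)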

        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Return current size of buffer."""
        return len(self.buffer)

    def is_ready(self, batch_size):
        """Check if buffer has enough samples for training."""
        return len(self.buffer) >= batch_size

print("✅ Replay Buffer implemented")

# Test the replay buffer
print("\n🧪 Testing Replay Buffer...")
test_buffer = ReplayBuffer(capacity=100)

# Add some dummy transitions
for i in range(10):
    state = np.random.randn(23)
    action = np.random.randn(7)
    reward = np.random.randn()
    next_state = np.random.randn(23)
    done = False
    test_buffer.push(state, action, reward, next_state, done)

print(f"Buffer size: {len(test_buffer)}")
print(f"Can sample batch of 5: {test_buffer.is_ready(5)}")

# Sample a batch
states, actions, rewards, next_states, dones = test_buffer.sample(5)
print(f"\nSampled batch shapes:")
print(f"  States: {states.shape}")
print(f"  Actions: {actions.shape}")
print(f"  Rewards: {rewards.shape}")
print(f"  Next states: {next_states.shape}")
print(f"  Dones: {dones.shape}")
print("✅ Replay Buffer test passed!")

✅ Replay Buffer implemented

🧪 Testing Replay Buffer...
Buffer size: 10
Can sample batch of 5: True

Sampled batch shapes:
  States: (5, 23)
  Actions: (5, 7)
  Rewards: (5, 1)
  Next states: (5, 23)
  Dones: (5, 1)
✅ Replay Buffer test passed!
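`random.sample` on a `deque` indexes into it element by element, and deque indexing slows to O(n) away from the ends, so sampling from a buffer holding close to 1,000,000 transitions can become a bottleneck. A common alternative, sketched here assuming fixed state and action sizes, preallocates NumPy arrays and overwrites them as a ring. `NumpyReplayBuffer` is a hypothetical name, and unlike `random.sample` it draws indices *with* replacement.

```python
import numpy as np


class NumpyReplayBuffer:
    """Fixed-size ring buffer backed by preallocated NumPy arrays (a sketch)."""

    def __init__(self, state_dim, action_dim, capacity=1_000_000, seed=None):
        self.capacity = capacity
        self.states = np.zeros((capacity, state_dim), dtype=np.float32)
        self.actions = np.zeros((capacity, action_dim), dtype=np.float32)
        self.rewards = np.zeros((capacity, 1), dtype=np.float32)
        self.next_states = np.zeros((capacity, state_dim), dtype=np.float32)
        self.dones = np.zeros((capacity, 1), dtype=np.float32)
        self.ptr = 0   # index of the next slot to write
        self.size = 0  # number of valid transitions stored
        self.rng = np.random.default_rng(seed)

    def push(self, state, action, reward, next_state, done):
        i = self.ptr
        self.states[i] = state
        self.actions[i] = action
        self.rewards[i] = reward
        self.next_states[i] = next_state
        self.dones[i] = float(done)
        # Wrap around and overwrite the oldest transition once full
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        # Uniform indices with replacement; fancy indexing costs O(batch_size)
        idx = self.rng.integers(0, self.size, size=batch_size)
        return (self.states[idx], self.actions[idx], self.rewards[idx],
                self.next_states[idx], self.dones[idx])

    def __len__(self):
        return self.size


buf = NumpyReplayBuffer(state_dim=23, action_dim=7, capacity=100, seed=0)
for _ in range(150):  # more pushes than capacity, so the ring wraps around
    buf.push(np.random.randn(23), np.random.randn(7), 1.0, np.random.randn(23), False)
s, a, r, s2, d = buf.sample(5)
```

The trade-off is memory allocated up front: at full Pusher capacity the two state arrays alone take roughly 184 MB in float32.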


## 4. Neural Network Architectures

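We need two network types, plus one helper:
1. **Actor Network**: Outputs the policy distribution (mean and log-std of a Gaussian)
2. **Critic Network**: Outputs twin Q-values for a state-action pair
3. **Initialization helper**: Orthogonal weight initialization for the linear layers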
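The actor below samples with `rsample()` rather than `sample()`. This minimal check, using only `torch.distributions.Normal`, shows why: the reparameterized draw `a = μ + σ·ε` keeps a gradient path to the mean and log-std, while `sample()` returns a tensor with no graph, so an actor loss built on it could not update the policy.

```python
import torch
from torch.distributions import Normal

mean = torch.zeros(3, requires_grad=True)
log_std = torch.zeros(3, requires_grad=True)
dist = Normal(mean, log_std.exp())

# rsample(): a = mean + std * eps with eps ~ N(0, 1), so a stays in the graph
a_reparam = dist.rsample()
a_reparam.sum().backward()
print(mean.grad)  # tensor([1., 1., 1.]) because d a / d mean = 1

# sample(): drawn under no_grad, so there is no path back to mean or log_std
a_plain = dist.sample()
print(a_plain.requires_grad)  # False
```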

In [3]:
def initialize_weights(layer, gain=1.0):
    """
    Initialize network weights using orthogonal initialization.
    
    Args:
        layer: Neural network layer
        gain: Scaling factor for initialization
    """
    if isinstance(layer, nn.Linear):
        nn.init.orthogonal_(layer.weight, gain=gain)
        nn.init.constant_(layer.bias, 0.0)

class ActorNetwork(nn.Module):
    """
    Actor network that outputs a Gaussian policy.
    
    Outputs:
        - mean: Mean of action distribution
        - log_std: Log standard deviation of action distribution
    """
    
    def __init__(self, state_dim, action_dim, hidden_dim=256, log_std_min=-20, log_std_max=2):
        """
        Args:
            state_dim: Dimension of state space
            action_dim: Dimension of action space
            hidden_dim: Size of hidden layers
            log_std_min: Minimum log standard deviation
            log_std_max: Maximum log standard deviation
        """
        super(ActorNetwork, self).__init__()
        
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        
        # Shared layers
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        
        # Output layers
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)
        
        # Initialize weights
        self.apply(lambda layer: initialize_weights(layer, gain=np.sqrt(2)))
        initialize_weights(self.mean, gain=0.01)
        initialize_weights(self.log_std, gain=0.01)
    
    def forward(self, state):
        """
        Forward pass through the network.
        
        Args:
            state: Input state
            
        Returns:
            mean: Mean of action distribution
            log_std: Log standard deviation (clamped)
        """
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        
        mean = self.mean(x)
        log_std = self.log_std(x)
        
        # Clamp log_std to prevent numerical instability
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        
        return mean, log_std
    
    def sample(self, state):
        """
        Sample an action from the policy.
        
        Uses the reparameterization trick: a = μ + σ * ε, where ε ~ N(0,1)
        
        Args:
            state: Input state
            
        Returns:
            action: Sampled action (squashed with tanh)
            log_prob: Log probability of the action
            mean: Mean of distribution (for evaluation)
        """
        mean, log_std = self.forward(state)
        std = log_std.exp()
        
        # Create normal distribution
        normal = Normal(mean, std)
        
        # Sample using reparameterization trick
        x_t = normal.rsample()  # rsample() uses reparameterization
        
        # Apply tanh squashing
        action = torch.tanh(x_t)
        
        # Compute log probability with correction for tanh squashing
        log_prob = normal.log_prob(x_t)
        # Enforcing Action Bounds (SAC paper, Appendix C): subtract log|d tanh(x)/dx|
        log_prob -= torch.log(1 - action.pow(2) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)
        
        return action, log_prob, torch.tanh(mean)


class CriticNetwork(nn.Module):
    """
    Critic network that outputs Q-values for state-action pairs.
    """
    
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        """
        Args:
            state_dim: Dimension of state space
            action_dim: Dimension of action space
            hidden_dim: Size of hidden layers
        """
        super(CriticNetwork, self).__init__()
        
        # Q1 network
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.q1 = nn.Linear(hidden_dim, 1)
        
        # Q2 network
        self.fc3 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        self.q2 = nn.Linear(hidden_dim, 1)
        
        # Initialize weights
        self.apply(lambda layer: initialize_weights(layer, gain=np.sqrt(2)))
    
    def forward(self, state, action):
        """
        Forward pass through both Q-networks.
        
        Args:
            state: Input state
            action: Input action
            
        Returns:
            q1: Q-value from first network
            q2: Q-value from second network
        """
        # Concatenate state and action
        x = torch.cat([state, action], dim=1)
        
        # Q1 forward pass
        q1 = F.relu(self.fc1(x))
        q1 = F.relu(self.fc2(q1))
        q1 = self.q1(q1)
        
        # Q2 forward pass
        q2 = F.relu(self.fc3(x))
        q2 = F.relu(self.fc4(q2))
        q2 = self.q2(q2)
        
        return q1, q2

print("✅ Neural networks implemented")

# Test the networks
print("\n🧪 Testing Neural Networks...")
state_dim = 23  # Pusher observation space
action_dim = 7  # Pusher action space

# Create networks
actor = ActorNetwork(state_dim, action_dim).to(device)
critic = CriticNetwork(state_dim, action_dim).to(device)

# Test with dummy input
dummy_state = torch.randn(10, state_dim).to(device)
dummy_action = torch.randn(10, action_dim).to(device)

# Test actor
mean, log_std = actor(dummy_state)
action, log_prob, _ = actor.sample(dummy_state)
print(f"Actor output shapes:")
print(f"  Mean: {mean.shape}")
print(f"  Log std: {log_std.shape}")
print(f"  Action: {action.shape}")
print(f"  Log prob: {log_prob.shape}")

# Test critic
q1, q2 = critic(dummy_state, dummy_action)
print(f"\nCritic output shapes:")
print(f"  Q1: {q1.shape}")
print(f"  Q2: {q2.shape}")

print("\n✅ Neural network tests passed!")

✅ Neural networks implemented

🧪 Testing Neural Networks...
Actor output shapes:
  Mean: torch.Size([10, 7])
  Log std: torch.Size([10, 7])
  Action: torch.Size([10, 7])
  Log prob: torch.Size([10, 1])

Critic output shapes:
  Q1: torch.Size([10, 1])
  Q2: torch.Size([10, 1])

✅ Neural network tests passed!
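The tanh correction in `ActorNetwork.sample()` can be checked numerically against PyTorch's `TransformedDistribution` with a `TanhTransform`, which applies the same change of variables using an exact log-Jacobian. This is a sketch of such a check with an arbitrary zero-mean, σ = 0.5 Gaussian; the two results should agree up to the effect of the `1e-6` stabilizer in the manual formula.

```python
import torch
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform

torch.manual_seed(0)
normal = Normal(torch.zeros(10, 7), torch.full((10, 7), 0.5))
x_t = normal.rsample()

# Manual change of variables, as in ActorNetwork.sample()
action = torch.tanh(x_t)
manual = (normal.log_prob(x_t) - torch.log(1 - action.pow(2) + 1e-6)).sum(1, keepdim=True)

# Reference: PyTorch's tanh-transformed distribution (exact log|d tanh/dx|)
tanh = TanhTransform(cache_size=1)
squashed = TransformedDistribution(normal, [tanh])
action_ref = tanh(x_t)  # caches x_t so log_prob() does not need atanh to invert
reference = squashed.log_prob(action_ref).sum(1, keepdim=True)

print((manual - reference).abs().max().item())  # small; driven by the 1e-6 term
```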


## 5. SAC Agent Implementation

Now we'll implement the complete SAC agent that ties everything together.

In [4]:
class SACAgent:
    """
    Soft Actor-Critic agent implementation.
    """
    
    def __init__(
        self,
        state_dim,
        action_dim,
        hidden_dim=256,
        lr=3e-4,
        gamma=0.99,
        tau=0.005,
        alpha=0.2,
        automatic_entropy_tuning=True,
        buffer_capacity=1000000,
        device='cpu'
    ):
        """
        Args:
            state_dim: Dimension of state space
            action_dim: Dimension of action space
            hidden_dim: Size of hidden layers
            lr: Learning rate
            gamma: Discount factor
            tau: Target network update rate (soft update)
            alpha: Entropy temperature (if not learning it)
            automatic_entropy_tuning: Whether to learn alpha
            buffer_capacity: Replay buffer size
            device: Device to use (cpu/cuda)
        """
        self.device = torch.device(device)
        self.gamma = gamma
        self.tau = tau
        self.action_dim = action_dim
        
        # Initialize networks
        self.actor = ActorNetwork(state_dim, action_dim, hidden_dim).to(self.device)
        self.critic = CriticNetwork(state_dim, action_dim, hidden_dim).to(self.device)
        self.critic_target = CriticNetwork(state_dim, action_dim, hidden_dim).to(self.device)
        
        # Copy parameters to target network
        self.critic_target.load_state_dict(self.critic.state_dict())
        
        # Optimizers
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)
        
        # Entropy temperature
        self.automatic_entropy_tuning = automatic_entropy_tuning
        if automatic_entropy_tuning:
            # Target entropy = -dim(A) (heuristic from SAC paper)
            self.target_entropy = -action_dim
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_optimizer = optim.Adam([self.log_alpha], lr=lr)
            # Detached: alpha acts as a constant inside the actor and critic losses
            self.alpha = self.log_alpha.exp().detach()
        else:
            self.alpha = torch.tensor(alpha).to(self.device)
        
        # Replay buffer
        self.replay_buffer = ReplayBuffer(buffer_capacity)
        
        # Training statistics
        self.training_steps = 0
    
    def select_action(self, state, evaluate=False):
        """
        Select an action given a state.
        
        Args:
            state: Current state
            evaluate: If True, use deterministic policy (mean action)
            
        Returns:
            action: Selected action
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        
        with torch.no_grad():
            if evaluate:
                # Use mean action for evaluation
                _, _, action = self.actor.sample(state)
            else:
                # Sample action for training
                action, _, _ = self.actor.sample(state)
        
        return action.cpu().numpy()[0]
    
    def update(self, batch_size):
        """
        Perform one gradient update step.
        
        Args:
            batch_size: Size of batch to sample from replay buffer
            
        Returns:
            Dictionary of training metrics
        """
        # Sample from replay buffer
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)
        
        # Convert to tensors
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        
        # ========== Update Critic ========== #
        with torch.no_grad():
            # Sample next actions
            next_actions, next_log_probs, _ = self.actor.sample(next_states)
            
            # Compute target Q-values
            q1_target, q2_target = self.critic_target(next_states, next_actions)
            min_q_target = torch.min(q1_target, q2_target)
            
            # Add entropy term
            next_q_value = min_q_target - self.alpha * next_log_probs
            
            # Compute target
            target_q = rewards + (1 - dones) * self.gamma * next_q_value
        
        # Get current Q estimates
        q1, q2 = self.critic(states, actions)
        
        # Compute critic loss
        critic_loss = F.mse_loss(q1, target_q) + F.mse_loss(q2, target_q)
        
        # Update critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        
        # ========== Update Actor ========== #
        # Sample actions from current policy
        new_actions, log_probs, _ = self.actor.sample(states)
        
        # Compute Q-values for new actions
        q1_new, q2_new = self.critic(states, new_actions)
        min_q_new = torch.min(q1_new, q2_new)
        
        # Compute actor loss
        actor_loss = (self.alpha * log_probs - min_q_new).mean()
        
        # Update actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        
        # ========== Update Temperature ========== #
        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
            
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            
            self.alpha = self.log_alpha.exp().detach()
            alpha_tlogs = self.alpha.item()
        else:
            alpha_loss = torch.tensor(0.).to(self.device)
            alpha_tlogs = self.alpha.item()
        
        # ========== Update Target Networks ========== #
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        
        self.training_steps += 1
        
        # Return metrics
        return {
            'critic_loss': critic_loss.item(),
            'actor_loss': actor_loss.item(),
            'alpha_loss': alpha_loss.item(),
            'alpha': alpha_tlogs,
            'q1_mean': q1.mean().item(),
            'q2_mean': q2.mean().item(),
            'log_prob_mean': log_probs.mean().item()
        }
    
    def save(self, filepath):
        """Save model checkpoint."""
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'critic_target': self.critic_target.state_dict(),
            'actor_optimizer': self.actor_optimizer.state_dict(),
            'critic_optimizer': self.critic_optimizer.state_dict(),
            'log_alpha': self.log_alpha if self.automatic_entropy_tuning else None,
            'alpha_optimizer': self.alpha_optimizer.state_dict() if self.automatic_entropy_tuning else None,
            'training_steps': self.training_steps
        }, filepath)
    
    def load(self, filepath):
        """Load model checkpoint."""
        checkpoint = torch.load(filepath, map_location=self.device)
        self.actor.load_state_dict(checkpoint['actor'])
        self.critic.load_state_dict(checkpoint['critic'])
        self.critic_target.load_state_dict(checkpoint['critic_target'])
        self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer'])
        self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer'])
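        if self.automatic_entropy_tuning and checkpoint['log_alpha'] is not None:
            # Copy in place so alpha_optimizer keeps tracking the same tensor
            with torch.no_grad():
                self.log_alpha.copy_(checkpoint['log_alpha'])
            self.alpha_optimizer.load_state_dict(checkpoint['alpha_optimizer'])
            self.alpha = self.log_alpha.exp().detach()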
        self.training_steps = checkpoint['training_steps']

print("✅ SAC Agent implemented")

✅ SAC Agent implemented
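One subtlety when restoring the temperature from a checkpoint: `alpha_optimizer` holds a reference to the original `log_alpha` tensor, so binding `self.log_alpha` to the loaded tensor would leave the optimizer updating an orphaned object. The standalone sketch below (values are made up) shows that an in-place `copy_` under `torch.no_grad()` keeps the link intact.

```python
import torch
import torch.optim as optim

log_alpha = torch.zeros(1, requires_grad=True)
alpha_optimizer = optim.Adam([log_alpha], lr=3e-4)
saved_log_alpha = torch.tensor([-1.5])  # stand-in for checkpoint['log_alpha']

# Rebinding: a new tensor object that the optimizer has never seen
rebound = saved_log_alpha.clone().requires_grad_(True)
print(alpha_optimizer.param_groups[0]["params"][0] is rebound)  # False

# In-place copy: same tensor object, new value, optimizer link intact
with torch.no_grad():
    log_alpha.copy_(saved_log_alpha)
print(alpha_optimizer.param_groups[0]["params"][0] is log_alpha)  # True
print(log_alpha.item())  # -1.5
```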


## 6. Test the SAC Agent

Let's verify our implementation works with the Pusher environment.

In [5]:
print("🧪 Testing SAC Agent...\n")

# Create environment
env = gym.make("Pusher-v5")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

print(f"Environment info:")
print(f"  State dimension: {state_dim}")
print(f"  Action dimension: {action_dim}")
print(f"  Action range: [{env.action_space.low[0]:.1f}, {env.action_space.high[0]:.1f}]")

# Create agent
agent = SACAgent(
    state_dim=state_dim,
    action_dim=action_dim,
    hidden_dim=256,
    lr=3e-4,
    gamma=0.99,
    tau=0.005,
    alpha=0.2,
    automatic_entropy_tuning=True,
    buffer_capacity=100000,
    device=device
)

print(f"\n✅ Agent created")
print(f"  Device: {device}")
print(f"  Automatic entropy tuning: {agent.automatic_entropy_tuning}")
print(f"  Initial alpha: {agent.alpha.item():.4f}")

# Test action selection
print("\n🎬 Testing action selection...")
state, _ = env.reset()
action = agent.select_action(state, evaluate=False)
print(f"  State shape: {state.shape}")
print(f"  Action shape: {action.shape}")
print(f"  Action range: [{action.min():.3f}, {action.max():.3f}]")

# Collect some experiences
print("\n📦 Collecting experiences...")
for _ in range(1000):
    action = agent.select_action(state)
    next_state, reward, terminated, truncated, _ = env.step(action)
    # Only true terminations stop bootstrapping; time-limit truncation does not
    agent.replay_buffer.push(state, action, reward, next_state, terminated)
    state = next_state
    
    if terminated or truncated:
        state, _ = env.reset()

print(f"  Buffer size: {len(agent.replay_buffer)}")

# Test update
if agent.replay_buffer.is_ready(256):
    print("\n🔄 Testing update step...")
    metrics = agent.update(batch_size=256)
    print("  Update metrics:")
    for key, value in metrics.items():
        print(f"    {key}: {value:.4f}")
    print("  ✅ Update successful!")
else:
    print("\n⚠️ Not enough samples for update")

env.close()
print("\n✅ All tests passed! SAC agent is ready for training.")

🧪 Testing SAC Agent...

Environment info:
  State dimension: 23
  Action dimension: 7
  Action range: [-2.0, 2.0]

✅ Agent created
  Device: cpu
  Automatic entropy tuning: True
  Initial alpha: 1.0000

🎬 Testing action selection...
  State shape: (23,)
  Action shape: (7,)
  Action range: [-0.816, 0.914]

📦 Collecting experiences...
  Buffer size: 1000

🔄 Testing update step...
  Update metrics:
    critic_loss: 26.1648
    actor_loss: -4.7900
    alpha_loss: -0.0000
    alpha: 0.9997
    q1_mean: -0.0363
    q2_mean: 0.4652
    log_prob_mean: -4.7084
  ✅ Update successful!

✅ All tests passed! SAC agent is ready for training.
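The environment info above reports an action range of [-2, 2], while the actor's tanh output lies in [-1, 1], so as written the agent can only reach half of each joint's range. Below is a minimal sketch of an affine rescaling, assuming Box bounds read from `env.action_space.low` and `env.action_space.high`; the helper name `rescale_action` is hypothetical.

```python
import numpy as np


def rescale_action(action, low, high):
    """Map a tanh-squashed action in [-1, 1] to the env's Box bounds [low, high]."""
    action = np.clip(action, -1.0, 1.0)
    # -1 maps to low, +1 maps to high, linear in between
    return low + 0.5 * (action + 1.0) * (high - low)


low = np.full(7, -2.0, dtype=np.float32)   # Pusher bounds as reported above
high = np.full(7, 2.0, dtype=np.float32)
env_action = rescale_action(np.array([-1.0, 0.0, 1.0, 0.5, -0.5, 0.25, -0.25]), low, high)
print(env_action)
```

In a training loop, the rescaled action would go to `env.step()` while the replay buffer keeps the unscaled [-1, 1] action, since that is what the critic sees from the actor. Gymnasium's `RescaleAction` wrapper is an alternative that performs the mapping on the environment side.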


## 7. Visualization Helper Functions

Let's create some utilities for visualizing training progress.

In [6]:
def plot_training_metrics(metrics_history, save_path=None):
    """
    Plot training metrics over time.
    
    Args:
        metrics_history: Dictionary of metric lists
        save_path: Optional path to save figure
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    axes = axes.flatten()
    
    # Define what to plot
    plot_configs = [
        ('critic_loss', 'Critic Loss', 'blue'),
        ('actor_loss', 'Actor Loss', 'red'),
        ('alpha', 'Alpha (Temperature)', 'green'),
        ('q1_mean', 'Q1 Mean', 'purple'),
        ('q2_mean', 'Q2 Mean', 'orange'),
        ('log_prob_mean', 'Log Prob Mean', 'brown')
    ]
    
    for idx, (key, title, color) in enumerate(plot_configs):
        if key in metrics_history and len(metrics_history[key]) > 0:
            axes[idx].plot(metrics_history[key], color=color, alpha=0.6, linewidth=0.5)
            # Plot moving average
            window = min(100, len(metrics_history[key]) // 10)
            if window > 1:
                moving_avg = pd.Series(metrics_history[key]).rolling(window=window).mean()
                axes[idx].plot(moving_avg, color=color, linewidth=2, label=f'{window}-step MA')
                axes[idx].legend()
            axes[idx].set_title(title)
            axes[idx].set_xlabel('Update Step')
            axes[idx].grid(alpha=0.3)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    
    plt.show()

def plot_episode_rewards(episode_rewards, save_path=None):
    """
    Plot episode rewards over time.
    
    Args:
        episode_rewards: List of episode rewards
        save_path: Optional path to save figure
    """
    fig, ax = plt.subplots(figsize=(14, 6))
    
    ax.plot(episode_rewards, alpha=0.3, color='blue', linewidth=0.5)
    
    # Plot moving average
    window = min(100, len(episode_rewards) // 10)
    if window > 1:
        moving_avg = pd.Series(episode_rewards).rolling(window=window).mean()
        ax.plot(moving_avg, color='red', linewidth=2, label=f'{window}-episode MA')
        ax.legend()
    
    ax.set_xlabel('Episode')
    ax.set_ylabel('Total Reward')
    ax.set_title('Episode Rewards Over Time')
    ax.grid(alpha=0.3)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    
    plt.show()

print("✅ Visualization functions created")

✅ Visualization functions created
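`plot_training_metrics` expects a dictionary mapping each metric name to a list of values. Because `SACAgent.update()` returns one dictionary per step, a small accumulator is enough to build that structure. The sketch below uses made-up numbers in place of real `update()` output; `record_metrics` is a hypothetical helper.

```python
from collections import defaultdict


def record_metrics(history, metrics):
    """Append each value of a per-step metrics dict to the matching list in history."""
    for key, value in metrics.items():
        history[key].append(value)
    return history


# Made-up values standing in for agent.update(batch_size) results
metrics_history = defaultdict(list)
for step in range(3):
    fake_metrics = {"critic_loss": 1.0 / (step + 1), "alpha": 1.0 - 0.01 * step}
    record_metrics(metrics_history, fake_metrics)

# plot_training_metrics(dict(metrics_history)) could then draw these curves
```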


## 8. Save Agent to src/ Directory

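To reuse the implementation in other notebooks, the classes need to live in the `src/` directory. The cell below only creates the target directories; the class definitions themselves still have to be copied into the files by hand.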

In [7]:
# Create the agents/ and utils/ directories if they don't exist
os.makedirs('../src/agents', exist_ok=True)
os.makedirs('../src/utils', exist_ok=True)

print("💾 Saving implementations to src/...")

# Note: this cell only creates the directories; it does not write any files.
# Copy the class definitions above into the files listed below by hand.

print("\n✅ To use in other notebooks, you can import:")
print("   from src.agents.sac import SACAgent")
print("   from src.utils.replay_buffer import ReplayBuffer")
print("\n⚠️ Note: Remember to manually copy the class definitions to:")
print("   - src/agents/sac.py")
print("   - src/utils/replay_buffer.py")

💾 Saving implementations to src/...

✅ To use in other notebooks, you can import:
   from src.agents.sac import SACAgent
   from src.utils.replay_buffer import ReplayBuffer

⚠️ Note: Remember to manually copy the class definitions to:
   - src/agents/sac.py
   - src/utils/replay_buffer.py


## 9. Summary and Key Takeaways

### What We Implemented:

✅ **Replay Buffer**
- Stores transitions for off-policy learning
- Samples random batches
- Breaks temporal correlations

✅ **Actor Network**
- Gaussian policy with learned mean and std
- Reparameterization trick for gradient flow
- Tanh squashing for bounded actions

✅ **Critic Networks**
- Twin Q-networks to reduce overestimation
- Takes state-action pairs as input
- Trained with TD error

✅ **SAC Agent**
- Combines all components
- Automatic entropy tuning
- Soft target updates
- Single-step `update()` method (the full training loop comes in the next notebook)

### Key SAC Concepts:

1. **Maximum Entropy RL**
   - Maximizes reward + entropy
   - Encourages exploration
   - More robust policies

2. **Off-Policy Learning**
   - Uses replay buffer
   - More sample efficient
   - Can reuse old experiences

3. **Twin Critics**
   - Reduces Q-value overestimation
   - Takes minimum of two estimates
   - More stable training

4. **Automatic Tuning**
   - Adjusts α so the policy's entropy tracks a target value (here -dim(A))
   - Balances exploration/exploitation
   - Trades a hand-tuned α for a target entropy, which has a standard default
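The sign logic of the temperature loss can be seen by repeating the notebook's α update with a frozen batch of log-probabilities. In this sketch (the helper `run_alpha_updates` is hypothetical), the target entropy is -7, Pusher's -dim(A). When the entropy estimate -log π is above the target, α shrinks; when it is below, α grows. The first case is consistent with the single update in Section 6, where `log_prob_mean` of about -4.7 moved α from 1.0 to 0.9997.

```python
import torch
import torch.optim as optim


def run_alpha_updates(log_prob_mean, target_entropy=-7.0, steps=100, lr=3e-4):
    """Repeat the notebook's temperature update with a fixed policy log-prob."""
    log_alpha = torch.zeros(1, requires_grad=True)  # alpha starts at 1.0
    opt = optim.Adam([log_alpha], lr=lr)
    log_probs = torch.full((256, 1), log_prob_mean)  # frozen batch of log pi(a|s)
    for _ in range(steps):
        loss = -(log_alpha * (log_probs + target_entropy).detach()).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return log_alpha.exp().item()


# Entropy estimate 4.7 is above the target -7, so alpha shrinks
alpha_down = run_alpha_updates(-4.7)
# Entropy estimate -10 is below the target -7, so alpha grows
alpha_up = run_alpha_updates(10.0)
print(alpha_down, alpha_up)
```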

### 🎯 Next Steps

In the next notebook (`04_sac_training.ipynb`), we'll:
1. Train the SAC agent on Pusher
2. Monitor training progress
3. Evaluate performance
4. Compare against our heuristic baselines
5. Save and visualize the trained agent

---

**Ready to start training? Open `04_sac_training.ipynb`!** 🚀

## 10. Quick Reference: SAC Algorithm

For your reference, here's the complete SAC algorithm:

```
Initialize:
  - Actor network π_φ
  - Critic networks Q_θ1, Q_θ2
  - Target critics Q_θ̄1, Q_θ̄2
  - Replay buffer D
  - Temperature α (or log α if learning)

For each episode:
  Observe state s

  For each step:
    1. Sample action: a ~ π_φ(·|s)
    2. Execute action, observe r, s', and terminal flag d
    3. Store (s, a, r, s', d) in D

    4. Sample mini-batch from D

    5. Update critics:
       - Sample a' ~ π_φ(·|s')
       - Compute target:
         y = r + γ(1 - d)(min(Q_θ̄1(s',a'), Q_θ̄2(s',a')) - α·log π_φ(a'|s'))
       - Minimize: L_Q = (Q_θi(s,a) - y)²  for i = 1, 2

    6. Update actor:
       - Sample a ~ π_φ(·|s)
       - Maximize: J_π = E[min(Q_θ1(s,a), Q_θ2(s,a)) - α·log π_φ(a|s)]

    7. Update temperature (if learning):
       - Minimize: L_α = -α(log π_φ(a|s) + H_target)

    8. Update targets:
       - θ̄i ← τ·θi + (1-τ)·θ̄i
```
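A worked instance of step 5 with made-up numbers shows how the pieces of the critic target combine: the minimum of the two target critics, the entropy bonus -α·log π, and the (1 - d) mask that drops the bootstrap term at a true terminal. The values of γ and α match the agent's defaults; everything else is illustrative.

```python
import torch

gamma, alpha = 0.99, 0.2
reward = torch.tensor([[1.0], [1.0]])
done = torch.tensor([[0.0], [1.0]])             # second transition is a true terminal
q1_next = torch.tensor([[10.0], [10.0]])        # target critic 1 at (s', a')
q2_next = torch.tensor([[8.0], [8.0]])          # target critic 2 at (s', a')
next_log_prob = torch.tensor([[-5.0], [-5.0]])  # log pi(a'|s')

# Clipped double-Q plus entropy bonus: min(10, 8) - 0.2 * (-5) = 9
soft_value = torch.min(q1_next, q2_next) - alpha * next_log_prob
# Bootstrap only where the episode did not terminate
y = reward + gamma * (1.0 - done) * soft_value
print(y.squeeze(1))  # tensor([9.9100, 1.0000])
```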