# Soft Actor-Critic Agent (115 Points)

> Name:

> SID: 



In this notebook, we are going to implement **Soft Actor-Critic (SAC)**
on the **CartPole** environment in online and offline settings. In this framework, the actor aims to maximize expected reward while also maximizing **entropy**, that is, to succeed at the task while acting as randomly as possible. Seeking a high-entropy policy explicitly encourages exploration. For the offline setting, you are going to make SAC conservative using the CQL (Conservative Q-Learning) method.

* SAC is an off-policy algorithm.
* The version of SAC implemented here can only be used for environments with discrete action spaces.
* An alternate version of SAC, which slightly changes the policy update rule, can be implemented to handle continuous action spaces.
* Complete the **TODO** parts in the code accordingly.
* Remember to answer the conceptual questions.




## Overview

This notebook provides a **complete implementation** of the **Soft Actor-Critic (SAC)** algorithm, covering the theory, the agent itself, and online, offline, and conservative (CQL) training.

### What is SAC?

**Soft Actor-Critic** is a widely used off-policy actor-critic algorithm that:
- Maximizes **both reward and entropy** (encourages exploration)
- Uses **clipped double-Q learning** to reduce value overestimation
- Automatically tunes the **temperature parameter** α so that the policy's entropy stays near a target value, balancing exploration and exploitation
- Works for both **continuous** and **discrete** action spaces

### What You'll Learn

1. **Theory**: Understanding SAC's objective function, loss functions, and the role of entropy
2. **Implementation**: Building neural networks, critics, actors, and the full training loop
3. **Online RL**: Training an agent through environment interaction
4. **Offline RL**: Training from a fixed dataset without environment interaction
5. **Conservative Q-Learning (CQL)**: Making offline RL more robust and stable

### Structure

- **Part 1**: Network Architecture - Build feedforward neural networks
- **Part 2**: Conceptual Questions - Understand the theory behind SAC
- **Part 3**: SAC Agent - Implement the complete algorithm with critics, actor, and training
- **Part 4**: Online Training - Train SAC with environment interaction
- **Part 5**: Offline Training - Train from a fixed replay buffer
- **Part 6**: Conservative Training - Add CQL regularization for safer offline learning
- **Part 7**: Analysis - Compare all three approaches

### Prerequisites

- Understanding of reinforcement learning basics (MDP, Q-learning, policy gradients)
- Familiarity with PyTorch and neural networks
- Knowledge of actor-critic methods

Let's begin! üöÄ


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import torch.optim as optim
import numpy as np
import random
import gym
import matplotlib.pyplot as plt


seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

## Network Structure (8 points)
To construct the SAC agent, we use feedforward neural networks with 3 layers. Complete the code below.

In [None]:
class Network(torch.nn.Module):

    def __init__(self, input_dimension, output_dimension, output_activation=torch.nn.Identity()):
        super(Network, self).__init__()
        ##########################################################
        # TODO (4 points): 
        # Define your network layers.
        ##########################################################
        # 3-layer feedforward neural network with hidden size 256
        self.layer_1 = torch.nn.Linear(input_dimension, 256)
        self.layer_2 = torch.nn.Linear(256, 256)
        self.output_layer = torch.nn.Linear(256, output_dimension)
        self.output_activation = output_activation
        ##########################################################

    def forward(self, inpt):  
        output = None      
        ##########################################################
        # TODO (4 points): 
        # Use relu and the output activation functions to calculate the output
        ##########################################################
        # Forward pass through the network with ReLU activations
        x = torch.nn.functional.relu(self.layer_1(inpt))
        x = torch.nn.functional.relu(self.layer_2(x))
        output = self.output_activation(self.output_layer(x))
        return output
        ##########################################################

### Network Architecture Explanation

The `Network` class implements a 3-layer feedforward neural network:

1. **Input Layer → Hidden Layer 1**: Maps from input dimension to 256 neurons
2. **Hidden Layer 1 → Hidden Layer 2**: 256 → 256 neurons
3. **Hidden Layer 2 → Output Layer**: 256 → output dimension

**Activation Functions:**
- **ReLU** is used between hidden layers to introduce non-linearity
- **Output activation** is customizable (e.g., `Softmax` for actor, `Identity` for critics)

This network architecture will be used for both the **actor** (policy network) and **critics** (Q-value networks) in SAC.
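To make the shapes concrete, here is a NumPy-only sketch of the same forward pass: two ReLU hidden layers of width 256, then either a softmax output (actor-style) or an identity output (critic-style). The weights are random placeholders and the helper names (`init_layer`, `forward`) are hypothetical; the PyTorch `Network` class above is what the agent actually uses.

```python
import numpy as np

def init_layer(rng, fan_in, fan_out):
    """Random weights and zero bias for one linear layer (placeholder initialisation)."""
    weight = rng.normal(0.0, 1.0 / np.sqrt(fan_in), size=(fan_in, fan_out))
    bias = np.zeros(fan_out)
    return weight, bias

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)

def forward(params, x, output_activation=softmax):
    """input -> 256 (ReLU) -> 256 (ReLU) -> output, mirroring the Network class."""
    (w1, b1), (w2, b2), (w3, b3) = params
    h1 = np.maximum(0.0, x @ w1 + b1)
    h2 = np.maximum(0.0, h1 @ w2 + b2)
    return output_activation(h2 @ w3 + b3)

rng = np.random.default_rng(0)
state_dim, hidden, action_dim = 4, 256, 2   # CartPole: 4 observations, 2 actions
params = [init_layer(rng, state_dim, hidden),
          init_layer(rng, hidden, hidden),
          init_layer(rng, hidden, action_dim)]

batch = rng.normal(size=(5, state_dim))                            # 5 fake states
action_probs = forward(params, batch)                              # actor-style output
q_values = forward(params, batch, output_activation=lambda z: z)   # critic-style output
print(action_probs.shape, q_values.shape)                          # (5, 2) (5, 2)
```

Each row of `action_probs` is a probability distribution over the two actions, while `q_values` is unconstrained, which is exactly why the actor uses `Softmax` and the critics use `Identity`.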


## Replay Buffer

A SAC agent needs a replay buffer from which previously observed transitions can be sampled. You can use the implementation below. You are going to use the replay buffer of the online-trained agent to train the offline models.
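The buffer packs each transition `(state, action, reward, next_state, done)` into one row of a NumPy structured array whose dtype comes from a comma-separated type string. The standalone sketch below hardcodes a CartPole-like type string (writing `int64` for the action is an assumption; `get_transition_type_str` derives the action type name from `action_space.sample()`) to show how a row is written and how the columns are recovered the same way `train_networks` does it.

```python
import numpy as np

# Assumed CartPole layout: 4-dim float32 state, integer action, float32 reward,
# 4-dim float32 next state, boolean done flag.
transition_type_str = '4float32, int64, float32, 4float32, bool'
capacity = 8
buffer = np.zeros(capacity, dtype=transition_type_str)

state = np.array([0.01, -0.02, 0.03, 0.04], dtype=np.float32)
next_state = np.array([0.02, 0.17, 0.03, -0.25], dtype=np.float32)
buffer[0] = (state, 1, 1.0, next_state, False)   # one row per transition

# Sampling returns whole rows; zip(*rows) splits them back into columns.
rows = buffer[[0, 1]]
states, actions, rewards, next_states, dones = map(np.array, zip(*rows))
print(buffer.dtype.names)   # ('f0', 'f1', 'f2', 'f3', 'f4')
print(states.shape)         # (2, 4)
```

Row 1 was never written, so it comes back as all zeros; unfilled slots look like valid (but meaningless) transitions, which is why sampling should only touch the first `count` rows.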

In [11]:
import numpy as np


class ReplayBuffer:

    def __init__(self, environment, capacity=500000):
        transition_type_str = self.get_transition_type_str(environment)
        self.buffer = np.zeros(capacity, dtype=transition_type_str)
        self.weights = np.zeros(capacity)
        self.head_idx = 0
        self.count = 0
        self.capacity = capacity
        self.max_weight = 10**-2
        self.delta = 10**-4
        self.indices = None
        self.mirror_index = np.random.permutation(range(self.buffer.shape[0]))

    def get_transition_type_str(self, environment):
        state_dim = environment.observation_space.shape[0]
        state_dim_str = '' if state_dim == () else str(state_dim)
        state_type_str = environment.observation_space.sample().dtype.name
        action_dim = environment.action_space.shape
        action_dim_str = '' if action_dim == () else str(action_dim)
        action_type_str = environment.action_space.sample().__class__.__name__

        # type str for transition = 'state type, action type, reward type, state type'
        transition_type_str = '{0}{1}, {2}{3}, float32, {0}{1}, bool'.format(state_dim_str, state_type_str,
                                                                             action_dim_str, action_type_str)

        return transition_type_str

    def add_transition(self, transition):
        self.buffer[self.head_idx] = transition
        self.weights[self.head_idx] = self.max_weight

        self.head_idx = (self.head_idx + 1) % self.capacity
        self.count = min(self.count + 1, self.capacity)

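    def sample_minibatch(self, size=100, batch_deterministic_start=None):
        set_weights = self.weights[:self.count] + self.delta
        probabilities = set_weights / set_weights.sum()
        if batch_deterministic_start is None:
            self.indices = np.random.choice(self.count, size, p=probabilities, replace=False)
        else:
            # Deterministic sweep (offline training): permute only the filled slots so that
            # minibatches never index empty rows when the buffer is not full.
            if len(self.mirror_index) != self.count:
                self.mirror_index = np.random.permutation(self.count)
            self.indices = self.mirror_index[batch_deterministic_start:batch_deterministic_start + size]
        return self.buffer[self.indices]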

    def update_weights(self, prediction_errors):
        max_error = max(prediction_errors)
        self.max_weight = max(self.max_weight, max_error)
        self.weights[self.indices] = prediction_errors

    def get_size(self):
        return self.count

## Questions (18 points)

❓ We know that standard RL maximizes the expected sum of rewards. What is the objective function of the SAC algorithm? Compare it to the standard RL loss.

❓ Write down the actor cost function.

❓ Write down the critic cost function.

❓ Elaborate on the reason why most implementations of SAC use two critics (one local and one target).

❓ What is the difference between training samples in offline and online settings?

❓ How does adding CQL on top of SAC change the objective function?



### Answers to Conceptual Questions

**❓ Q1: What is the objective function of the SAC algorithm? Compare it to the standard RL loss.**

**Answer:**  
Standard RL maximizes:
$$J(\pi) = \mathbb{E}_{\tau \sim \pi}[\sum_{t=0}^{\infty} \gamma^t r(s_t, a_t)]$$

SAC maximizes **entropy-regularized objective**:
$$J(\pi) = \mathbb{E}_{\tau \sim \pi}[\sum_{t=0}^{\infty} \gamma^t (r(s_t, a_t) + \alpha H(\pi(\cdot|s_t)))]$$

where $H(\pi(\cdot|s_t)) = -\mathbb{E}_{a \sim \pi}[\log \pi(a|s_t)]$ is the entropy.

**Key Difference:** SAC encourages exploration by rewarding high entropy (randomness) in the policy, while standard RL focuses only on maximizing rewards.
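A small numeric check of the per-step term inside the SAC objective, assuming a two-action policy like CartPole's: the entropy bonus is largest for the uniform policy, zero for a deterministic one, and setting α = 0 recovers the standard reward. The function names are illustrative.

```python
import math

def entropy(probs):
    """Shannon entropy H(pi) = -sum_a pi(a) log pi(a), in nats (0 log 0 treated as 0)."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def soft_reward(reward, probs, alpha):
    """Per-step term inside the SAC objective: r + alpha * H(pi(.|s))."""
    return reward + alpha * entropy(probs)

uniform = [0.5, 0.5]   # maximally random policy over 2 actions
greedy = [1.0, 0.0]    # deterministic policy

print(entropy(uniform))                 # ln 2, about 0.693
print(soft_reward(1.0, uniform, 0.2))   # 1 + 0.2 * ln 2
print(soft_reward(1.0, uniform, 0.0))   # alpha = 0 recovers the plain reward: 1.0
```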

---

**❓ Q2: Write down the actor cost function.**

**Answer:**  
$$J_\pi(\phi) = \mathbb{E}_{s_t \sim D}\mathbb{E}_{a_t \sim \pi_\phi}[\alpha \log \pi_\phi(a_t|s_t) - Q_\theta(s_t, a_t)]$$

Or equivalently for discrete actions:
$$J_\pi(\phi) = \mathbb{E}_{s_t \sim D}[\sum_a \pi_\phi(a|s_t)(\alpha \log \pi_\phi(a|s_t) - Q_\theta(s_t, a))]$$

The actor minimizes this cost, which balances between maximizing Q-values and maintaining high entropy.
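For discrete actions the expectation over actions is an exact sum, so the cost can be evaluated by hand. In this sketch (illustrative helper names, two actions with made-up Q-values), the minimizer of the cost at a state is the Boltzmann policy $\pi(a|s) \propto \exp(Q(s,a)/\alpha)$, whose cost equals $-\alpha \log \sum_a \exp(Q(s,a)/\alpha)$; any other policy, such as the uniform one, scores worse.

```python
import math

def actor_cost(probs, q_values, alpha):
    """Discrete SAC actor cost at one state: sum_a pi(a) * (alpha * log pi(a) - Q(a))."""
    return sum(p * (alpha * math.log(p) - q) for p, q in zip(probs, q_values) if p > 0.0)

def boltzmann(q_values, alpha):
    """pi(a) proportional to exp(Q(a) / alpha), computed stably."""
    m = max(q_values)
    weights = [math.exp((q - m) / alpha) for q in q_values]
    total = sum(weights)
    return [w / total for w in weights]

q = [1.0, 0.0]
alpha = 1.0
uniform_cost = actor_cost([0.5, 0.5], q, alpha)        # ln(0.5) - 0.5
best_cost = actor_cost(boltzmann(q, alpha), q, alpha)  # -alpha * log(sum_a exp(Q / alpha))
```

This is a useful mental model for the gradient step the actor takes: it pulls the softmax output toward a Boltzmann distribution over the current critic's Q-values, with α setting how sharp that distribution is.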

---

**❓ Q3: Write down the critic cost function.**

**Answer:**  
$$J_Q(\theta) = \mathbb{E}_{(s_t,a_t,r_t,s_{t+1}) \sim D}[(Q_\theta(s_t, a_t) - y_t)^2]$$

where the target is:
$$y_t = r_t + \gamma(1-d_t) \mathbb{E}_{a_{t+1} \sim \pi}[Q_{\theta'}(s_{t+1}, a_{t+1}) - \alpha \log \pi(a_{t+1}|s_{t+1})]$$

For discrete actions:
$$y_t = r_t + \gamma(1-d_t) \sum_a \pi(a|s_{t+1})[Q_{\theta'}(s_{t+1}, a) - \alpha \log \pi(a|s_{t+1})]$$
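The discrete target can be computed for a single transition as below. This sketch uses the element-wise minimum of two target critics $Q_1', Q_2'$ in place of a single $Q_{\theta'}$ (clipped double-Q, as the `SACAgent` code does); all numbers are made up.

```python
import math

def soft_td_target(reward, done, next_probs, next_q1, next_q2, alpha, gamma=0.99):
    """y = r + gamma * (1 - d) * sum_a pi(a|s') [min(Q1'(s',a), Q2'(s',a)) - alpha * log pi(a|s')]."""
    soft_value = sum(
        p * (min(q1, q2) - alpha * math.log(p))
        for p, q1, q2 in zip(next_probs, next_q1, next_q2)
        if p > 0.0
    )
    return reward + gamma * (1.0 - float(done)) * soft_value

y = soft_td_target(reward=1.0, done=False, next_probs=[0.5, 0.5],
                   next_q1=[1.0, 2.0], next_q2=[1.5, 1.0], alpha=0.1)
y_terminal = soft_td_target(reward=1.0, done=True, next_probs=[0.5, 0.5],
                            next_q1=[1.0, 2.0], next_q2=[1.5, 1.0], alpha=0.1)
```

On a terminal transition the $(1-d_t)$ factor removes the bootstrap term entirely, so the target collapses to the reward.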

---

**❓ Q4: Elaborate on the reason why most implementations of SAC use two critics (one local and one target).**

**Answer:**  
SAC uses **two local critics** and **two target critics** (4 critics total):

1. **Two Local Critics (Q1, Q2):** Helps reduce **overestimation bias**. We take the minimum: $Q(s,a) = \min(Q_1(s,a), Q_2(s,a))$. This clipped double-Q learning prevents the critic from being overly optimistic.

2. **Target Networks (Q1_target, Q2_target):** Provides **stable training targets**. Target networks are slowly updated (soft update with $\tau \ll 1$), preventing the "moving target" problem where the Q-values we're trying to match keep changing rapidly.
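The soft (Polyak) update that keeps the target critics trailing the local ones can be sketched on plain NumPy arrays standing in for network parameters (an illustrative assumption; `SACAgent.soft_update` applies the same rule to each PyTorch parameter tensor).

```python
import numpy as np

def soft_update(target, local, tau):
    """Polyak averaging: target <- tau * local + (1 - tau) * target (in place)."""
    target *= (1.0 - tau)
    target += tau * local
    return target

local = np.array([1.0, -2.0])
target = np.zeros(2)

soft_update(target, local, tau=1.0)   # tau = 1 copies the weights (as done at construction)
copied = target.copy()

target = np.zeros(2)
for _ in range(100):                  # tau = 0.01, i.e. SOFT_UPDATE_INTERPOLATION_FACTOR
    soft_update(target, local, tau=0.01)
# Toward a fixed local vector, the remaining gap shrinks by a factor (1 - tau) per step.
```

With τ = 0.01 the target only covers about 63% of the distance to a fixed local vector after 100 updates, which is the "slowly moving target" that stabilizes the TD regression.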

---

**❓ Q5: What is the difference between training samples in offline and online settings?**

**Answer:**  

| Aspect | Online RL | Offline RL |
|--------|-----------|------------|
| **Data Collection** | Agent interacts with environment during training | Uses pre-collected fixed dataset |
| **Exploration** | Can explore new states/actions | Limited to dataset coverage |
| **Distribution Shift** | Policy improves, collects better data | Policy may diverge from dataset distribution |
| **Sample Efficiency** | Requires many environment interactions | No environment interaction needed |
| **Safety** | May take dangerous actions during exploration | Safe (no real-world interaction) |

**Key Challenge in Offline RL:** **Extrapolation error**: the agent may learn to prefer actions that are poorly represented in the dataset, because their Q-values are never corrected by data and can be badly overestimated.

---

**❓ Q6: How does adding CQL on top of SAC change the objective function?**

**Answer:**  
CQL (Conservative Q-Learning) adds a **conservative regularizer** to the critic loss:

**Standard SAC Critic Loss:**
$$J_Q(\theta) = \mathbb{E}_{(s,a,r,s') \sim D}[(Q_\theta(s, a) - y)^2]$$

**CQL Critic Loss:**
$$J_{CQL}(\theta) = \alpha_{CQL} \cdot \underbrace{(\mathbb{E}_{s \sim D, a \sim \mu(a|s)}[Q_\theta(s,a)] - \mathbb{E}_{s,a \sim D}[Q_\theta(s,a)])}_{\text{CQL regularizer}} + J_Q(\theta)$$

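where $\mu$ is a distribution over actions chosen by the regularizer (e.g., uniform or the current policy); it is not the behavior policy that generated $D$. In the commonly used CQL($\mathcal{H}$) variant, $\mu$ is optimized together with an entropy bonus $\mathcal{H}(\mu)$, which turns the first term into $\mathbb{E}_{s \sim D}\left[\log \sum_a \exp Q_\theta(s,a)\right]$; this log-sum-exp form is what the `critic_loss` method of the agent computes.

**Effect:** 
- **Increases** Q-values for actions that appear in the dataset $D$
- **Decreases** Q-values for actions sampled from $\mu$, which include out-of-distribution actions
- This makes the agent **conservative**, steering it away from actions not seen in the offline dataset
- The tradeoff factor $\alpha_{CQL}$ (`TRADEOFF_FACTOR` in the code) controls the strength of this conservatism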
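The CQL($\mathcal{H}$) regularizer can be checked at a single state with made-up Q-values. The penalty $\log \sum_a \exp Q(s,a) - Q(s, a_{data})$ is positive and shrinks as the dataset action's Q-value rises relative to the others. The last lines verify numerically that the log-sum-exp equals $\max_\mu \mathbb{E}_\mu[Q] + \mathcal{H}(\mu)$, attained at $\mu = \mathrm{softmax}(Q)$, which is why the $\mu$-form and the log-sum-exp form agree. Helper names are illustrative.

```python
import math

def logsumexp(values):
    """Stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def cql_penalty(q_values, data_action):
    """CQL(H) regulariser at one state: logsumexp_a Q(s,a) - Q(s, a_data)."""
    return logsumexp(q_values) - q_values[data_action]

q = [1.0, 0.0]
low = cql_penalty(q, data_action=0)    # dataset action already has the higher Q-value
high = cql_penalty(q, data_action=1)   # an unseen action dominates the dataset action

# logsumexp(Q) = E_mu[Q] + H(mu) at mu = softmax(Q).
mu = [math.exp(v - logsumexp(q)) for v in q]
value_at_softmax = sum(p * v for p, v in zip(mu, q)) - sum(p * math.log(p) for p in mu)
```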


## SAC Agent (50 points)

Now complete the following class. You can use the auxiliary methods provided in the class.

In [None]:
class SACAgent:

    ALPHA_INITIAL = 1.
    REPLAY_BUFFER_BATCH_SIZE = 100
    DISCOUNT_RATE = 0.99
    LEARNING_RATE = 10 ** -4
    SOFT_UPDATE_INTERPOLATION_FACTOR = 0.01
    TRADEOFF_FACTOR = 5 # trade-off factor in the CQL

    def __init__(self, environment, replay_buffer=None, use_cql=False, offline=False):

        assert not use_cql or offline, 'Please activate the offline flag for CQL.' 
        assert not offline or not replay_buffer is None, 'Please pass a replay buffer to the offline method.' 

        self.environment = environment
        self.state_dim = self.environment.observation_space.shape[0]
        self.action_dim = self.environment.action_space.n

        self.offline = offline
        self.replay_buffer = ReplayBuffer(self.environment) if replay_buffer is None else replay_buffer
        self.use_cql = use_cql

        ##########################################################
        # TODO (6 points): 
        # Define critics using your implemented feedforward network (10 points).
        # To have easier critic updates, you can use two local critic networks 
        # and two target critics.
        ##########################################################
        # Two local critic networks (clipped double-Q learning)
        self.critic_local = Network(self.state_dim, self.action_dim)
        self.critic_local2 = Network(self.state_dim, self.action_dim)
        
        # Optimizers for each critic
        self.critic_optimiser = optim.Adam(self.critic_local.parameters(), lr=self.LEARNING_RATE)
        self.critic_optimiser2 = optim.Adam(self.critic_local2.parameters(), lr=self.LEARNING_RATE)
        
        # Two target critic networks for stable training
        self.critic_target = Network(self.state_dim, self.action_dim)
        self.critic_target2 = Network(self.state_dim, self.action_dim)
        ##########################################################

        self.soft_update_target_networks(tau=1.)

        ##########################################################
        # TODO (2 points): 
        # Define the actor using your implemented feedforward network (10 points).
        # Define the actor optimizer using torch.optim.Adam (4 points)
        ##########################################################
        # Actor network with Softmax activation for discrete action probabilities
        self.actor_local = Network(self.state_dim, self.action_dim, 
                                   output_activation=torch.nn.Softmax(dim=-1))
        
        # Actor optimizer
        self.actor_optimiser = optim.Adam(self.actor_local.parameters(), lr=self.LEARNING_RATE)
        ##########################################################

        self.target_entropy = 0.98 * -np.log(1 / self.environment.action_space.n)
        self.log_alpha = torch.tensor(np.log(self.ALPHA_INITIAL), requires_grad=True)
        self.alpha = self.log_alpha.exp().detach()  # alpha = exp(log_alpha), stored without gradient
        self.alpha_optimiser = torch.optim.Adam([self.log_alpha], lr=self.LEARNING_RATE)

    def get_next_action(self, state, evaluation_episode=False):
        if evaluation_episode:
            discrete_action = self.get_action_deterministically(state)
        else:
            discrete_action = self.get_action_nondeterministically(state)
        return discrete_action

    def get_action_nondeterministically(self, state):
        action_probabilities = self.get_action_probabilities(state)
        discrete_action = np.random.choice(range(self.action_dim), p=action_probabilities)
        return discrete_action

    def get_action_deterministically(self, state):
        action_probabilities = self.get_action_probabilities(state)
        discrete_action = np.argmax(action_probabilities)
        return discrete_action

    def critic_loss(self, states_tensor, actions_tensor, rewards_tensor, 
                    next_states_tensor, done_tensor):
        ##########################################################
        # TODO (12 points): 
        # You are going to calculate critic losses in this method.
        # Also you should implement the CQL loss if the corresponding 
        # flag is set.
        ##########################################################
        with torch.no_grad():
            # Get action probabilities and log probabilities for next states
            action_probabilities, log_action_probabilities = self.get_action_info(next_states_tensor)
            
            # Get Q-values for next states from target networks
            next_q_values_target = self.critic_target.forward(next_states_tensor)
            next_q_values_target2 = self.critic_target2.forward(next_states_tensor)
            
            # Use minimum of two Q-values (clipped double-Q)
            soft_state_values = (action_probabilities * (
                torch.min(next_q_values_target, next_q_values_target2) - 
                self.log_alpha.exp() * log_action_probabilities
            )).sum(dim=1)
            
            # Compute target: r + gamma * (1 - done) * V(s')
            next_q_values = rewards_tensor + self.DISCOUNT_RATE * (1 - done_tensor.float()) * soft_state_values
        
        # Get current Q-values for the actions taken
        soft_q_values = self.critic_local(states_tensor).gather(1, actions_tensor.unsqueeze(-1).long()).squeeze(-1)
        soft_q_values2 = self.critic_local2(states_tensor).gather(1, actions_tensor.unsqueeze(-1).long()).squeeze(-1)
        
        # Compute TD errors
        critic_loss = F.mse_loss(soft_q_values, next_q_values)
        critic2_loss = F.mse_loss(soft_q_values2, next_q_values)
        
        # Add CQL regularization if enabled
        if self.use_cql:
            # CQL regularizer: pushes down Q-values of all actions
            q_values_all = self.critic_local(states_tensor)
            q_values_all2 = self.critic_local2(states_tensor)
            
            # Log-sum-exp of Q-values (approximates max)
            cql_loss = torch.logsumexp(q_values_all, dim=1).mean() - soft_q_values.mean()
            cql_loss2 = torch.logsumexp(q_values_all2, dim=1).mean() - soft_q_values2.mean()
            
            # Add CQL penalty with tradeoff factor
            critic_loss = critic_loss + self.TRADEOFF_FACTOR * cql_loss
            critic2_loss = critic2_loss + self.TRADEOFF_FACTOR * cql_loss2

        return critic_loss, critic2_loss
        ##########################################################

    def actor_loss(self, states_tensor):
        ##########################################################
        # TODO (8 points): 
        # Now implement the actor loss.
        ##########################################################
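        # Get action probabilities and log probabilities from actor
        action_probabilities, log_action_probabilities = self.get_action_info(states_tensor)
        
        # Q-values from both critics; the actor update must not send gradients into the critics
        with torch.no_grad():
            q_values = self.critic_local(states_tensor)
            q_values2 = self.critic_local2(states_tensor)
            # Use minimum Q-value (clipped double-Q)
            min_q_values = torch.min(q_values, q_values2)
        
        # Detach alpha so the actor loss does not accumulate gradients in log_alpha
        alpha = self.log_alpha.exp().detach()
        
        # Actor loss: E_s[ sum_a pi(a|s) * (alpha * log pi(a|s) - Q(s,a)) ]
        actor_loss = (action_probabilities * (
            alpha * log_action_probabilities - min_q_values
        )).sum(dim=1).mean()

        # Expected log-probability per state, sum_a pi(a|s) log pi(a|s) = -H(pi(.|s)).
        # The temperature loss needs this probability-weighted quantity in the discrete case;
        # an unweighted mean over actions would always drive alpha toward zero.
        expected_log_probabilities = (action_probabilities * log_action_probabilities).sum(dim=1)

        return actor_loss, expected_log_probabilities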
        ##########################################################

    def train_on_transition(self, state, discrete_action, next_state, reward, done):
        transition = (state, discrete_action, reward, next_state, done)
        self.train_networks(transition)

    def train_networks(self, transition=None, batch_deterministic_start=None):
        ##########################################################
        # TODO (6 points): 
        # Set all the gradients stored in the optimisers to zero.
        # add the new transition to the replay buffer for online case.
        ##########################################################
        # Zero all gradients
        self.critic_optimiser.zero_grad()
        self.critic_optimiser2.zero_grad()
        self.actor_optimiser.zero_grad()
        self.alpha_optimiser.zero_grad()
        
        # Add transition to replay buffer (only in online mode)
        if not self.offline and transition is not None:
            self.replay_buffer.add_transition(transition)
        ##########################################################

        if self.replay_buffer.get_size() >= self.REPLAY_BUFFER_BATCH_SIZE:
            minibatch = self.replay_buffer.sample_minibatch(self.REPLAY_BUFFER_BATCH_SIZE,
                                                            batch_deterministic_start=batch_deterministic_start)
            minibatch_separated = list(map(list, zip(*minibatch)))

            states_tensor = torch.tensor(np.array(minibatch_separated[0]))
            actions_tensor = torch.tensor(np.array(minibatch_separated[1]))
            rewards_tensor = torch.tensor(np.array(minibatch_separated[2])).float()
            next_states_tensor = torch.tensor(np.array(minibatch_separated[3]))
            done_tensor = torch.tensor(np.array(minibatch_separated[4]))

            ##########################################################
            # TODO (16 points): 
            # Here, you should compute the gradients based on this loss, i.e. the gradients
            # of the loss with respect to the Q-network parameters.
            # Given a minibatch of 100 transitions from replay buffer,
            # compute the critic loss and perform the backward and step functions,
            # and compute the actor loss and perform the backward and step functions.
            # You also need to update \alpha.
            ##########################################################
            # Convert states and actions to float tensors
            states_tensor = states_tensor.float()
            next_states_tensor = next_states_tensor.float()
            
            # 1. Update Critics
            critic_loss, critic2_loss = self.critic_loss(states_tensor, actions_tensor, 
                                                         rewards_tensor, next_states_tensor, 
                                                         done_tensor)
            
            # Backpropagate critic losses
            critic_loss.backward()
            critic2_loss.backward()
            
            # Update critic parameters
            self.critic_optimiser.step()
            self.critic_optimiser2.step()
            
            # Zero gradients for actor update
            self.critic_optimiser.zero_grad()
            self.critic_optimiser2.zero_grad()
            self.actor_optimiser.zero_grad()
            self.alpha_optimiser.zero_grad()
            
            # 2. Update Actor
            actor_loss, log_action_probabilities = self.actor_loss(states_tensor)
            
            # Backpropagate actor loss
            actor_loss.backward()
            
            # Update actor parameters
            self.actor_optimiser.step()
            
            # 3. Update Temperature (alpha)
            alpha_loss = self.temperature_loss(log_action_probabilities)
            
            # Backpropagate alpha loss
            alpha_loss.backward()
            
            # Update alpha
            self.alpha_optimiser.step()
            
            # Update the alpha value
            self.alpha = self.log_alpha.exp()
            ##########################################################

            self.soft_update_target_networks()

    def temperature_loss(self, log_action_probabilities):
        alpha_loss = -(self.log_alpha * (log_action_probabilities + self.target_entropy).detach()).mean()
        return alpha_loss

    def get_action_info(self, states_tensor):
        action_probabilities = self.actor_local.forward(states_tensor)
        z = action_probabilities == 0.0
        z = z.float() * 1e-8
        log_action_probabilities = torch.log(action_probabilities + z)
        return action_probabilities, log_action_probabilities

    def get_action_probabilities(self, state):
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        action_probabilities = self.actor_local.forward(state_tensor)
        return action_probabilities.squeeze(0).detach().numpy()

    def soft_update_target_networks(self, tau=SOFT_UPDATE_INTERPOLATION_FACTOR):
        self.soft_update(self.critic_target, self.critic_local, tau)
        self.soft_update(self.critic_target2, self.critic_local2, tau)

    def soft_update(self, target_model, origin_model, tau):
        for target_param, local_param in zip(target_model.parameters(), origin_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1 - tau) * target_param.data)

    def predict_q_values(self, state):
        q_values = self.critic_local(state)
        q_values2 = self.critic_local2(state)
        return torch.min(q_values, q_values2)

### SAC Agent Implementation Summary

The `SACAgent` class implements the complete Soft Actor-Critic algorithm with the following components:

**Key Components:**
1. **Two Local Critics + Two Target Critics**: Reduces overestimation bias through clipped double-Q learning
2. **Actor Network**: Policy network with Softmax output for discrete action probabilities  
3. **Automatic Temperature Tuning**: Learns the entropy coefficient α automatically

**Loss Functions:**

**Critic Loss:**
```
L_Q = E[(Q(s,a) - (r + γ * (1 - d) * V(s')))²]
where V(s') = Σ_a π(a|s') * [min(Q1'(s',a), Q2'(s',a)) - α*log π(a|s')], with Q1', Q2' the target critics
```

**Actor Loss:**
```
L_π = E_s E_a~π[α*log π(a|s) - min(Q1(s,a), Q2(s,a))]
```

**CQL Regularization** (for offline RL):
```
L_CQL = L_Q + β * (E_a~μ[Q(s,a)] - E_(s,a)~D[Q(s,a)])
```
This penalizes Q-values for out-of-distribution actions, making the agent more conservative. In the code, the first term is computed as the log-sum-exp of Q over all actions, and β is `TRADEOFF_FACTOR`.

**Training Process:**
1. Sample minibatch from replay buffer
2. Update both critics using TD error + optional CQL penalty
3. Update actor to maximize Q-values while maintaining entropy
4. Update temperature α to match the target entropy
5. Soft update target networks with τ = 0.01
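Step 4 can be illustrated with a scalar version of the temperature update. This sketch assumes the probability-weighted form from the discrete-action SAC formulation, where the loss is $J(\alpha) = -\log\alpha \cdot (\bar{\mathcal{H}} - \mathcal{H}(\pi))$ with target entropy $\bar{\mathcal{H}}$; the helper names and learning rate are illustrative.

```python
import math

def entropy(probs):
    """Shannon entropy in nats (0 log 0 treated as 0)."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def temperature_step(log_alpha, probs, target_entropy, lr):
    """One gradient step on J = -log_alpha * (target_entropy - H(pi)).

    dJ/dlog_alpha = H(pi) - target_entropy, so alpha grows when the policy is
    less random than the target and shrinks when it is more random.
    """
    grad = entropy(probs) - target_entropy
    return log_alpha - lr * grad

target_entropy = 0.98 * math.log(2)   # same 98%-of-maximum target as SACAgent, for 2 actions
near_greedy = [0.99, 0.01]
uniform = [0.5, 0.5]

up = temperature_step(0.0, near_greedy, target_entropy, lr=0.1)   # entropy too low: alpha rises
down = temperature_step(0.0, uniform, target_entropy, lr=0.1)     # entropy above target: alpha falls
```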


## Online SAC training loop (10 points)

Now evaluate your model on the CartPole environment in the online setting. After every 4 episodes, evaluate your model on a separate test environment. Run your model 4 times independently and plot the mean and standard deviation of the evaluation curves.

**NOTE:** Since you are going to use the replay buffer of this agent as the offline dataset, you may want to save it for later use.
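The training cell below performs a single run, while the task asks for 4 runs summarized by their mean and spread. One way to do that is sketched here, assuming each run's evaluation curve is stored as a list of equal length; the helper names are hypothetical, the data are placeholders, and the band uses the population standard deviation (`ddof=0`).

```python
import numpy as np

def summarize_runs(curves):
    """Stack equal-length evaluation curves (one per run) and return per-point mean and std."""
    stacked = np.asarray(curves, dtype=float)   # shape: (num_runs, num_evaluations)
    return stacked.mean(axis=0), stacked.std(axis=0)

def plot_mean_std(x, curves, label):
    """Plot the mean curve with a shaded band of +/- one standard deviation."""
    import matplotlib.pyplot as plt
    mean, std = summarize_runs(curves)
    plt.plot(x, mean, linewidth=2, label=label)
    plt.fill_between(x, mean - std, mean + std, alpha=0.3)
    plt.xlabel('Episode')
    plt.ylabel('Mean Evaluation Reward')
    plt.legend()

# Placeholder rewards from 4 independent runs, 3 evaluation points each.
runs = [[10, 50, 200], [12, 40, 180], [9, 60, 200], [11, 55, 190]]
mean, std = summarize_runs(runs)
```

Collecting each run's `evaluation_rewards` into a list and calling `plot_mean_std(range(TRAINING_EVALUATION_RATIO, EPISODES_PER_RUN + 1, TRAINING_EVALUATION_RATIO), all_runs, 'Online SAC')` (where `all_runs` is that hypothetical list) would produce the requested figure.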

In [None]:
TRAINING_EVALUATION_RATIO = 4
EPISODES_PER_RUN = 1000
STEPS_PER_EPISODE = 200

env = gym.make("CartPole-v1")

##########################################################
# TODO (10 points): 
# Implement the training loop for the online SAC. 
# 1) You need to initialize an agent with the
#    `replay_buffer` set to None, leaving the
#    `use_cql` and `offline` flags False.
# 2) After each epoch, run `EPISODES_PER_RUN` validation
#    episodes and plot the mean return over these 
#    episodes in the end.
# 3) Plot the learning curves.
##########################################################

# Initialize the online SAC agent
agent = SACAgent(env, replay_buffer=None, use_cql=False, offline=False)

# Training metrics
evaluation_rewards = []

print("Starting Online SAC Training...")

for episode in range(EPISODES_PER_RUN):
    state = env.reset()
    episode_reward = 0
    
    for step in range(STEPS_PER_EPISODE):
        # Select action from policy
        action = agent.get_next_action(state, evaluation_episode=False)
        
        # Take action in environment
        next_state, reward, done, _ = env.step(action)
        episode_reward += reward
        
        # Train the agent on this transition
        agent.train_on_transition(state, action, next_state, reward, done)
        
        state = next_state
        
        if done:
            break
    
    # Evaluate agent every TRAINING_EVALUATION_RATIO episodes
    if (episode + 1) % TRAINING_EVALUATION_RATIO == 0:
        eval_rewards = []
        
        for eval_ep in range(10):  # Run 10 evaluation episodes
            eval_state = env.reset()
            eval_episode_reward = 0
            
            for eval_step in range(STEPS_PER_EPISODE):
                # Use deterministic policy for evaluation
                eval_action = agent.get_next_action(eval_state, evaluation_episode=True)
                eval_state, eval_reward, eval_done, _ = env.step(eval_action)
                eval_episode_reward += eval_reward
                
                if eval_done:
                    break
            
            eval_rewards.append(eval_episode_reward)
        
        mean_eval_reward = np.mean(eval_rewards)
        evaluation_rewards.append(mean_eval_reward)
        
        print(f"Episode {episode + 1}/{EPISODES_PER_RUN}, "
              f"Mean Eval Reward: {mean_eval_reward:.2f}, "
              f"Replay Buffer Size: {agent.replay_buffer.get_size()}")

# Plot the learning curve
plt.figure(figsize=(10, 6))
plt.plot(range(TRAINING_EVALUATION_RATIO, EPISODES_PER_RUN + 1, TRAINING_EVALUATION_RATIO), 
         evaluation_rewards, linewidth=2)
plt.xlabel('Episode', fontsize=12)
plt.ylabel('Mean Evaluation Reward', fontsize=12)
plt.title('Online SAC Training on CartPole-v1', fontsize=14)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Save the replay buffer for offline training
print(f"\nFinal Replay Buffer Size: {agent.replay_buffer.get_size()}")
print("Training completed! You can now use this replay buffer for offline training.")

# Store the agent for later use
online_agent_replay_buffer = agent.replay_buffer
##########################################################
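The loops in this notebook assume the pre-0.26 `gym` API, where `env.reset()` returns only the observation and `env.step()` returns `(next_state, reward, done, info)`. From gym 0.26 (and in Gymnasium), `reset()` returns `(observation, info)` and `step()` returns `(observation, reward, terminated, truncated, info)`. If your installed version is newer, small adapters like the hypothetical `reset_env` and `step_env` below can normalize both; this is a sketch, not part of the assignment.

```python
def reset_env(env):
    """Return only the observation, for both old (obs) and new ((obs, info)) reset APIs."""
    result = env.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result

def step_env(env, action):
    """Return (next_state, reward, terminated, truncated) for 4- or 5-tuple step APIs."""
    result = env.step(action)
    if len(result) == 5:
        next_state, reward, terminated, truncated, _ = result
    else:
        next_state, reward, done, info = result
        # Older gym's TimeLimit wrapper flags time-limit endings in info.
        truncated = bool(info.get('TimeLimit.truncated', False))
        terminated = done and not truncated
    return next_state, reward, terminated, truncated
```

A reasonable convention is to store `terminated` (not `truncated`) as the `done` flag in the replay buffer, so the critic still bootstraps when an episode is cut by a time limit, and to end the episode loop on `terminated or truncated`.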

## Offline SAC training loop (10 points)

In this part, you are going to train an SAC agent using the replay buffer from the online agent. During training, you sample from this replay buffer and train the offline agent **without adding transitions to the replay buffer**. The loss function and everything else are the same as in the online setting.
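An offline epoch should ideally visit every stored transition once. The sketch below shows that sweep, assuming only the first `num_filled` slots of the buffer hold data: shuffling just those slots and slicing the permutation in steps of the batch size (as `batch_deterministic_start` does) uses each transition once per epoch and never reads an unfilled, all-zero row. The helper `epoch_minibatches` is hypothetical and drops the last partial batch, matching `buffer_size // batch_size` in the loop below.

```python
import numpy as np

def epoch_minibatches(num_filled, batch_size, rng):
    """Yield index arrays covering every filled slot once per epoch (last partial batch dropped)."""
    order = rng.permutation(num_filled)   # shuffle only the slots that hold data
    for start in range(0, num_filled - batch_size + 1, batch_size):
        yield order[start:start + batch_size]

rng = np.random.default_rng(0)
num_filled, batch_size = 1_050, 100
batches = list(epoch_minibatches(num_filled, batch_size, rng))
seen = np.concatenate(batches)
```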

In [None]:
RUNS = 1
NUM_EPOCHS = 200
EPISODES_PER_RUN = 100

env = gym.make("CartPole-v1")

##########################################################
# TODO (10 points): 
# Implement the training loop for the offline SAC. 
# 1) You need to initialize an agent with the current
#    `replay_buffer` of the online agent. Set the `offline`
#    flag and leave the `use_cql` flag False.
# 2) You can use `batch_deterministic_start` in the
#    `train_networks` method to select all minibatches
#    of the data to train in an offline manner.
# 3) After each epoch, run `EPISODES_PER_RUN` validation
#    episodes and plot the mean return over these 
#    episodes in the end.
##########################################################

# Initialize the offline SAC agent with the replay buffer from online training
offline_agent = SACAgent(env, replay_buffer=online_agent_replay_buffer, 
                         use_cql=False, offline=True)

# Get the size of the replay buffer
buffer_size = offline_agent.replay_buffer.get_size()
batch_size = offline_agent.REPLAY_BUFFER_BATCH_SIZE

# Calculate number of batches per epoch
num_batches_per_epoch = buffer_size // batch_size

# Training metrics
offline_evaluation_rewards = []

print("Starting Offline SAC Training...")
print(f"Replay Buffer Size: {buffer_size}")
print(f"Batches per Epoch: {num_batches_per_epoch}")

for epoch in range(NUM_EPOCHS):
    # Train on all batches in the replay buffer
    for batch_idx in range(num_batches_per_epoch):
        batch_start = batch_idx * batch_size
        # Train without adding new transitions
        offline_agent.train_networks(transition=None, 
                                     batch_deterministic_start=batch_start)
    
    # Evaluate the agent after each epoch
    eval_rewards = []
    
    for eval_ep in range(EPISODES_PER_RUN):
        eval_state = env.reset()
        eval_episode_reward = 0
        
        for eval_step in range(200):  # Cap evaluation at 200 steps (CartPole-v1 itself allows up to 500)
            # Use deterministic policy for evaluation
            eval_action = offline_agent.get_next_action(eval_state, evaluation_episode=True)
            eval_state, eval_reward, eval_done, _ = env.step(eval_action)
            eval_episode_reward += eval_reward
            
            if eval_done:
                break
        
        eval_rewards.append(eval_episode_reward)
    
    mean_eval_reward = np.mean(eval_rewards)
    offline_evaluation_rewards.append(mean_eval_reward)
    
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, "
              f"Mean Eval Reward: {mean_eval_reward:.2f}")

# Plot the learning curve
plt.figure(figsize=(10, 6))
plt.plot(range(1, NUM_EPOCHS + 1), offline_evaluation_rewards, linewidth=2, label='Offline SAC')
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Mean Evaluation Reward', fontsize=12)
plt.title('Offline SAC Training on CartPole-v1', fontsize=14)
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\nOffline SAC Training completed!")
##########################################################
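The evaluation loop in this cell is repeated verbatim in the conservative-SAC cell. One way to factor it out is a small helper like the sketch below. `run_eval_episodes` is a hypothetical name, and the sketch assumes the agent's action selection can be wrapped as a function of the state. It accepts both API styles:

- the older gym API used in this notebook, where `reset()` returns the observation and `step()` returns 4 values;
- the newer gym/gymnasium API, where `reset()` returns `(obs, info)` and `step()` returns 5 values.

```python
import numpy as np


def run_eval_episodes(env, select_action, num_episodes, max_steps=200):
    """Return the undiscounted return of each evaluation episode.

    Works with the older gym API (reset -> obs; step -> 4-tuple) and the
    newer one (reset -> (obs, info); step -> 5-tuple). Assumes the
    observation itself is not a tuple (true for CartPole).
    """
    returns = []
    for _ in range(num_episodes):
        reset_out = env.reset()
        state = reset_out[0] if isinstance(reset_out, tuple) else reset_out
        total = 0.0
        for _ in range(max_steps):
            step_out = env.step(select_action(state))
            if len(step_out) == 5:
                state, reward, terminated, truncated, _ = step_out
                done = terminated or truncated
            else:
                state, reward, done, _ = step_out
            total += reward
            if done:
                break
        returns.append(total)
    return returns
```

For example, `run_eval_episodes(env, lambda s: offline_agent.get_next_action(s, evaluation_episode=True), EPISODES_PER_RUN)` would replace the inner evaluation loop.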

## Conservative SAC training loop (5 points)

Similar to the previous part, you are going to train another offline agent. In this part, you are going to use the conservative version of SAC.

In [None]:
RUNS = 1
NUM_EPOCHS = 200
EPISODES_PER_RUN = 100

env = gym.make("CartPole-v1")

##########################################################
# TODO (5 points): 
# Implement the training loop for the conservative SAC. 
# 1) You need to initialize an agent with the current
#    `replay_buffer` of the online agent. Set both the
#    `offline` and `use_cql` flags to True.
# 2) You can use `batch_deterministic_start` in the
#    `train_networks` method to select all minibatches
#    of the data to train in an offline manner.
# 3) After each epoch, run `EPISODES_PER_RUN` validation
#    episodes and plot the mean return over these 
#    episodes in the end.
##########################################################

# Initialize the conservative SAC agent with CQL enabled
cql_agent = SACAgent(env, replay_buffer=online_agent_replay_buffer, 
                     use_cql=True, offline=True)

# Get the size of the replay buffer
buffer_size = cql_agent.replay_buffer.get_size()
batch_size = cql_agent.REPLAY_BUFFER_BATCH_SIZE

# Calculate number of batches per epoch
num_batches_per_epoch = buffer_size // batch_size

# Training metrics
cql_evaluation_rewards = []

print("Starting Conservative SAC (CQL) Training...")
print(f"Replay Buffer Size: {buffer_size}")
print(f"Batches per Epoch: {num_batches_per_epoch}")
print(f"CQL Tradeoff Factor: {cql_agent.TRADEOFF_FACTOR}")

for epoch in range(NUM_EPOCHS):
    # Train on all batches in the replay buffer
    for batch_idx in range(num_batches_per_epoch):
        batch_start = batch_idx * batch_size
        # Train without adding new transitions (offline)
        cql_agent.train_networks(transition=None, 
                                 batch_deterministic_start=batch_start)
    
    # Evaluate the agent after each epoch
    eval_rewards = []
    
    for eval_ep in range(EPISODES_PER_RUN):
        eval_state = env.reset()
        eval_episode_reward = 0
        
        for eval_step in range(200):  # Cap evaluation at 200 steps (CartPole-v1 itself allows up to 500)
            # Use deterministic policy for evaluation
            eval_action = cql_agent.get_next_action(eval_state, evaluation_episode=True)
            eval_state, eval_reward, eval_done, _ = env.step(eval_action)
            eval_episode_reward += eval_reward
            
            if eval_done:
                break
        
        eval_rewards.append(eval_episode_reward)
    
    mean_eval_reward = np.mean(eval_rewards)
    cql_evaluation_rewards.append(mean_eval_reward)
    
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, "
              f"Mean Eval Reward: {mean_eval_reward:.2f}")

# Plot offline SAC and conservative SAC (CQL) learning curves side by side
plt.figure(figsize=(12, 6))
plt.plot(range(1, NUM_EPOCHS + 1), offline_evaluation_rewards, 
         linewidth=2, label='Offline SAC', alpha=0.8)
plt.plot(range(1, NUM_EPOCHS + 1), cql_evaluation_rewards, 
         linewidth=2, label='Conservative SAC (CQL)', alpha=0.8)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Mean Evaluation Reward', fontsize=12)
plt.title('Offline SAC vs Conservative SAC (CQL) on CartPole-v1', fontsize=14)
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\nConservative SAC (CQL) Training completed!")
##########################################################

## Comparisons (14 points)
Now, analyze your results and justify the trends you see. Then answer the following questions.

❓ What is the reason for the difference between online and offline performance of the agent?

❓ Which one is better: offline SAC or conservative SAC?

❓ What is the effect of `TRADEOFF_FACTOR` in the offline setting? How does changing its value affect the results?

### Answers to Comparison Questions

---

**❓ Q1: What is the reason for the difference between online and offline performance of the agent?**

**Answer:**

The key differences between online and offline performance stem from several factors:

1. **Exploration vs Exploitation:**
   - **Online RL**: The agent actively explores the environment during training. It can discover new states and actions, adjust its policy based on fresh experiences, and improve continuously through interaction.
   - **Offline RL**: The agent is limited to a fixed dataset collected by another policy (often suboptimal). It cannot explore new regions of the state-action space.

2. **Distribution Shift:**
   - **Online**: The policy also changes during training. However, the agent keeps collecting data with its current policy, so the critic is continually corrected on the states and actions that policy actually visits.
   - **Offline**: The learned policy may diverge from the behavior policy that collected the data. When the agent learns to prefer actions not well-represented in the dataset, Q-value estimation becomes unreliable (**extrapolation error**).

3. **Data Coverage:**
   - **Online**: Continuous data collection ensures good coverage of visited states.
   - **Offline**: Limited to whatever states/actions were in the original dataset. Poor coverage leads to overestimation of Q-values for unseen actions.

4. **Adaptability:**
   - **Online**: Can quickly adapt to environment changes or recover from mistakes.
   - **Offline**: Fixed dataset means the agent cannot correct for systematic biases in the data collection process.
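The overestimation described under **Data Coverage** can be reproduced with a toy numpy experiment (an illustration, not the notebook's code). Every action's true value is 0, and each estimate is unbiased noise. Yet the maximum over the estimates, which a greedy or near-greedy policy chases, is biased upward. Online, the agent would try the over-valued action and correct its estimate. Offline, nothing ever corrects it.

```python
import numpy as np


def max_over_noisy_estimates(num_actions, noise_std, num_trials, seed=0):
    """Average of max_a Q_hat(s, a) when every true Q(s, a) is 0.

    Q_hat = true value + zero-mean noise, standing in for estimates of
    actions the dataset barely covers. Each estimate is unbiased, but the
    maximum over them is biased upward.
    """
    rng = np.random.default_rng(seed)
    q_hat = rng.normal(0.0, noise_std, size=(num_trials, num_actions))
    return q_hat.max(axis=1).mean()


print(max_over_noisy_estimates(num_actions=2, noise_std=1.0, num_trials=100_000))
```

With two actions and unit noise, the average maximum should come out near 1/√π ≈ 0.56 rather than 0. With a single action there is nothing to maximize over, and the average stays near 0.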

**Expected Performance:** Online RL typically achieves higher final performance but requires environment interaction during training. Offline RL needs no new interaction, which is safer and cheaper when interaction is costly or risky. However, it may plateau at a lower performance level because of the dataset's limited coverage.

---

**❓ Q2: Which one is better: offline SAC or conservative SAC?**

**Answer:**

**Conservative SAC (with CQL) is generally better for offline RL**, especially when:

1. **Dataset Quality Issues:**
   - If the dataset contains suboptimal or diverse behaviors, standard offline SAC tends to **overestimate Q-values** for actions not in the dataset.
   - CQL explicitly penalizes these overestimated Q-values, making the policy more **conservative** and **safer**.

2. **Stability:**
   - **Offline SAC** may suffer from **unstable training** due to extrapolation errors.
   - **CQL** adds regularization that prevents the critic from assigning high values to out-of-distribution actions, leading to **more stable learning curves**.

3. **Performance:**
   - In most offline RL benchmarks, CQL outperforms vanilla offline SAC, especially with limited or suboptimal datasets.
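   - CQL learns conservative Q-values. For a sufficiently large regularization weight, their expectation under the learned policy lower-bounds the true policy value rather than overestimating it, which leads to more reliable policy improvement.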

**Trade-off:** 
- CQL might be slightly more **conservative** (risk-averse), potentially achieving slightly lower final performance than online RL.
- Standard offline SAC might perform better **if the dataset is near-optimal and has good coverage**, but this is rare in practice.

**Recommendation:** Use **Conservative SAC (CQL)** for offline RL to ensure stable, reliable learning. The `TRADEOFF_FACTOR` can be tuned to balance conservatism vs performance.

---

**❓ Q3: What is the effect of `TRADEOFF_FACTOR` in the offline setting? How does changing its value affect the results?**

**Answer:**

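The `TRADEOFF_FACTOR` (α_CQL or β) controls the strength of the CQL regularization:

$$L_{CQL} = L_{SAC} + \beta \cdot \left(\mathbb{E}_{s \sim \mathcal{D},\, a \sim \mu(\cdot \mid s)}[Q(s,a)] - \mathbb{E}_{(s,a) \sim \mathcal{D}}[Q(s,a)]\right)$$

Here $\mathcal{D}$ is the offline dataset and $\mu$ is a distribution over actions. The penalty pushes down the Q-values of actions drawn from $\mu$ and pushes up the Q-values of actions that actually appear in the dataset.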
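For discrete actions, the expectation under $\mu$ is commonly replaced by a log-sum-exp over all actions. This is the CQL(H) variant from the CQL paper, which corresponds to choosing the $\mu$ that maximizes the penalty plus an entropy bonus. The summary at the end of this notebook also mentions the log-sum-exp trick.

The numpy sketch below computes that penalty for a single critic. The function names are hypothetical. The notebook's agent may compute the penalty differently, for example in PyTorch and for both critics.

```python
import numpy as np


def logsumexp(x, axis=-1):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def cql_penalty(q_values, dataset_actions):
    """CQL(H)-style regularizer for a discrete-action critic.

    q_values:        (batch, num_actions) critic outputs Q(s, .)
    dataset_actions: (batch,) integer actions stored in the offline dataset
    Returns mean_s[logsumexp_a Q(s, a) - Q(s, a_data)], which is always >= 0.
    """
    q_data = q_values[np.arange(len(dataset_actions)), dataset_actions]
    return float(np.mean(logsumexp(q_values, axis=1) - q_data))


def cql_critic_loss(td_loss, q_values, dataset_actions, beta):
    """Total critic loss: the usual soft TD loss plus beta times the CQL penalty."""
    return td_loss + beta * cql_penalty(q_values, dataset_actions)
```

Log-sum-exp is always at least the maximum, so the penalty is never negative. It is smallest when the dataset action already has by far the largest Q-value in that state.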

**Effect of Different Values:**

1. **Low `TRADEOFF_FACTOR` (e.g., β = 0.1 to 1):**
   - **Weak regularization**: The agent behaves more like standard offline SAC
   - **Higher Q-values**: Less conservative, may overestimate Q-values for OOD actions
   - **Risk**: Potential instability and performance degradation due to extrapolation error
   - **Benefit**: If dataset is high-quality, may achieve higher final performance

2. **Medium `TRADEOFF_FACTOR` (e.g., β = 1 to 10):**
   - **Balanced approach**: Good trade-off between conservatism and performance
   - **Stable learning**: Prevents overestimation while still allowing policy improvement
   - **Recommended range**: Often the sweet spot for most offline RL tasks
   - **Current setting**: The code uses β = 5, which is in this range

3. **High `TRADEOFF_FACTOR` (e.g., β = 50 to 100):**
   - **Strong regularization**: Very conservative Q-value estimates
   - **Lower Q-values**: Strongly penalizes actions not in the dataset
   - **Risk**: May be too conservative, preventing the policy from improving beyond the behavior policy
   - **Benefit**: Very stable training, minimal risk of divergence
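A one-state toy problem shows the mechanism behind these regimes. The setup is an illustrative assumption, not the notebook's environment:

- Only action 0 appears in the data, with reward 1.
- Action 1 is out-of-distribution and starts with an inflated estimate of 2.

Gradient descent on a squared regression error for the dataset action, plus β times the log-sum-exp penalty, drives the out-of-distribution value down at a rate proportional to β.

```python
import numpy as np


def toy_cql(beta, steps=200, lr=0.1, reward=1.0, q_init=(0.0, 2.0)):
    """One state, two actions. Only action 0 (reward `reward`) is in the data;
    action 1 is out-of-distribution and starts with an overestimated value.
    Runs gradient descent on (Q0 - reward)^2 + beta * (logsumexp(Q) - Q0).
    Returns Q0 - Q1 (positive means the in-data action is preferred)."""
    q = np.array(q_init, dtype=float)
    for _ in range(steps):
        p = np.exp(q - q.max())
        p /= p.sum()  # softmax(Q) is the gradient of logsumexp(Q)
        grad = beta * p  # CQL term pushes every Q(s, a) down by its softmax weight
        grad[0] += 2.0 * (q[0] - reward) - beta  # regression term, plus the push-up on the dataset action
        q -= lr * grad
    return q[0] - q[1]


for beta in (0.0, 1.0, 5.0):
    print(beta, round(toy_cql(beta), 2))
```

With β = 0, the inflated action keeps its value and is preferred, so the gap is negative. Larger β opens a wider margin in favor of the dataset action. The same push-down would also hit a genuinely better action that happens to be rare in the data, which is the "too conservative" risk of a high β.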

**How to Tune:**

- **Start with β ≈ 1 to 5** and observe training curves
- If training is **unstable** or **diverges**: **Increase β**
- If performance **plateaus too early**: **Decrease β**
- If the dataset is **high-quality**: Use a lower β
- If the dataset is **suboptimal/noisy**: Use a higher β
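The tuning rules above can be turned into a small sweep. The sketch below assumes a hypothetical `train_fn(beta)` that runs the offline CQL loop with a given `TRADEOFF_FACTOR` and returns its per-epoch mean evaluation returns. For each β, it reports the mean over the last epochs (final performance) and the standard deviation (stability).

```python
import numpy as np


def summarize_curve(curve, window=20):
    """Mean and standard deviation of the last `window` epochs of a learning curve."""
    tail = np.asarray(curve[-window:], dtype=float)
    return float(tail.mean()), float(tail.std())


def sweep_tradeoff(train_fn, betas, window=20):
    """Call the hypothetical `train_fn(beta)` -> per-epoch mean eval returns
    for each beta. Return (best_beta, {beta: (final_mean, final_std)})."""
    results = {beta: summarize_curve(train_fn(beta), window) for beta in betas}
    best_beta = max(results, key=lambda b: results[b][0])
    return best_beta, results
```

A `train_fn` could wrap the CQL cell in four steps:

1. Build `SACAgent(env, replay_buffer=online_agent_replay_buffer, use_cql=True, offline=True)`.
2. Set `agent.TRADEOFF_FACTOR = beta`. This assumes the agent reads the attribute during training rather than only at construction.
3. Run the epoch loop.
4. Return the list of mean evaluation rewards.

A large final standard deviation suggests increasing β. A stable but low final mean suggests decreasing it. With `RUNS = 1`, each curve comes from a single seed, so differences between nearby β values may be noise.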

**Practical Recommendation:** 
For CartPole with a reasonably good online dataset, β = 1 to 10 should work well. For more complex environments or lower-quality datasets, β = 10 to 50 might be needed.


---

## Summary and Key Takeaways

### What We Implemented

In this notebook, we implemented a complete **Soft Actor-Critic (SAC)** agent with three training paradigms:

1. **Online SAC**: Agent interacts with environment during training
2. **Offline SAC**: Agent trains on fixed dataset without environment interaction  
3. **Conservative SAC (CQL)**: Offline training with conservative Q-learning regularization

### Key Concepts

**SAC Algorithm:**
- **Entropy Regularization**: Encourages exploration by maximizing both reward and policy entropy
- **Clipped Double-Q Learning**: Two critics to reduce overestimation bias
- **Automatic Temperature Tuning**: Learns the entropy coefficient α, adjusting it so that the policy's entropy stays near a target value
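These three ingredients meet in the critic's target. For discrete actions, the soft value of the next state can be computed exactly as an expectation over the policy's action probabilities instead of by sampling. The numpy sketch below assumes the actor outputs a probability vector and that `alpha` is the current temperature. The default `gamma=0.99` is an assumption, not necessarily the notebook's value.

```python
import numpy as np


def discrete_soft_target(rewards, dones, next_probs, next_q1, next_q2, alpha, gamma=0.99):
    """Soft Bellman target for discrete-action SAC with clipped double-Q.

    next_probs:       (batch, num_actions) pi(a | s') from the actor
    next_q1, next_q2: (batch, num_actions) target-critic values Q(s', .)
    V(s') = sum_a pi(a|s') * (min(Q1, Q2)(s', a) - alpha * log pi(a|s'))
    target = r + gamma * (1 - done) * V(s')
    """
    log_probs = np.log(np.clip(next_probs, 1e-8, 1.0))  # clip avoids log(0)
    min_q = np.minimum(next_q1, next_q2)  # clipped double-Q
    next_v = np.sum(next_probs * (min_q - alpha * log_probs), axis=1)
    return rewards + gamma * (1.0 - dones) * next_v
```

The `-alpha * log_probs` term adds an entropy bonus to the next-state value. For a uniform policy over two actions, that bonus equals α log 2.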

**Offline RL Challenges:**
- **Distribution Shift**: Learned policy differs from data collection policy
- **Extrapolation Error**: Q-values overestimated for out-of-distribution actions
- **Limited Exploration**: Cannot discover new state-action pairs

**CQL Solution:**
- Adds regularization term to push down Q-values of unseen actions
- Prevents overestimation and improves stability
- Trade-off controlled by `TRADEOFF_FACTOR`

### Implementation Highlights

1. **3-Layer Neural Network** (256 hidden units) for both actor and critics
2. **Separate Optimizers** for critics, actor, and temperature
3. **Target Networks** with soft updates (τ = 0.01) for stability
4. **Replay Buffer** for experience replay and offline training
5. **CQL Regularization** using log-sum-exp trick
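The soft target update in item 3 is Polyak averaging. Below is a framework-agnostic numpy sketch that stores parameters in a dict of arrays (an assumption for illustration).

```python
import numpy as np


def soft_update(target_params, online_params, tau=0.01):
    """Polyak averaging: target <- tau * online + (1 - tau) * target.

    Both arguments map parameter names to numpy arrays; the entries of
    `target_params` are replaced and the same dict is returned.
    """
    for name, value in online_params.items():
        target_params[name] = tau * value + (1.0 - tau) * target_params[name]
    return target_params
```

With τ = 0.01, the target network effectively averages over roughly the last 1/τ = 100 updates. In PyTorch, the same rule is commonly written as a loop over `zip(target_net.parameters(), net.parameters())` that calls `target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)`.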

### Performance Expectations

- **Online SAC**: Highest performance, requires environment interaction
- **Offline SAC**: Good baseline, may be unstable with suboptimal data
- **CQL**: Best for offline, more stable and reliable

### Next Steps

To improve the implementation:
- Experiment with different network architectures
- Tune hyperparameters (learning rate, batch size, τ)
- Adjust `TRADEOFF_FACTOR` for your specific dataset
- Try environments with continuous action spaces (this requires the continuous-action variant of the SAC policy update)
- Implement prioritized experience replay
- Add multi-step returns for better credit assignment

**Congratulations!** You've successfully implemented a state-of-the-art RL algorithm with both online and offline capabilities! 🎉
