# Soft Actor Critic Agent (115 Points)

> Name:

> SID: 



In this notebook, we are going to implement **Soft Actor Critic (SAC)** on the **CartPole** environment in online and offline settings. In this framework, the actor aims to maximize expected reward while also maximizing **entropy**, that is, to succeed at the task while acting as randomly as possible. Seeking a high-entropy policy explicitly encourages exploration. For the offline setting, you are going to make SAC conservative using the CQL (Conservative Q-Learning) method.

* SAC is an off-policy algorithm.
* The version of SAC implemented here can only be used for environments with discrete action spaces.
* An alternate version of SAC, which slightly changes the policy update rule, can be implemented to handle continuous action spaces.
* Complete the **TODO** parts in the code accordingly.
* Remember to answer the conceptual questions.




In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import torch.optim as optim
import numpy as np
import random
import gym
import matplotlib.pyplot as plt


seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

## Network Structure (8 points)
To construct the SAC agent, we use feedforward neural networks with 3 layers. Complete the code below.

In [None]:
class Network(torch.nn.Module):

    def __init__(self, input_dimension, output_dimension, output_activation=torch.nn.Identity()):
        super(Network, self).__init__()
        ##########################################################
        # TODO (4 points): 
        # Define your network layers.
        ##########################################################
        # 3-layer feedforward neural network with hidden size 256
        self.layer_1 = torch.nn.Linear(input_dimension, 256)
        self.layer_2 = torch.nn.Linear(256, 256)
        self.output_layer = torch.nn.Linear(256, output_dimension)
        self.output_activation = output_activation
        ##########################################################

    def forward(self, inpt):  
        output = None      
        ##########################################################
        # TODO (4 points): 
        # Use relu and the output activation functions to calculate the output
        ##########################################################
        # Forward pass through the network with ReLU activations
        x = torch.nn.functional.relu(self.layer_1(inpt))
        x = torch.nn.functional.relu(self.layer_2(x))
        output = self.output_activation(self.output_layer(x))
        return output
        ##########################################################

### Network Architecture Explanation

The `Network` class implements a 3-layer feedforward neural network:

1. **Input Layer → Hidden Layer 1**: Maps from input dimension to 256 neurons
2. **Hidden Layer 1 → Hidden Layer 2**: 256 → 256 neurons  
3. **Hidden Layer 2 → Output Layer**: 256 → output dimension

**Activation Functions:**
- **ReLU** is applied after each of the two hidden layers to introduce non-linearity
- **Output activation** is customizable (e.g., `Softmax` for actor, `Identity` for critics)

This network architecture will be used for both the **actor** (policy network) and **critics** (Q-value networks) in SAC.
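
As a quick sanity check of the two roles described above, the sketch below instantiates the same 3-layer architecture once as an actor and once as a critic. The dimensions (4 state features, 2 actions) are CartPole's. The class is repeated only so the snippet runs on its own, and the variable names are illustrative.

```python
import torch


class Network(torch.nn.Module):
    """Same 3-layer feedforward architecture as above, repeated for self-containment."""

    def __init__(self, input_dimension, output_dimension, output_activation=torch.nn.Identity()):
        super().__init__()
        self.layer_1 = torch.nn.Linear(input_dimension, 256)
        self.layer_2 = torch.nn.Linear(256, 256)
        self.output_layer = torch.nn.Linear(256, output_dimension)
        self.output_activation = output_activation

    def forward(self, inpt):
        x = torch.nn.functional.relu(self.layer_1(inpt))
        x = torch.nn.functional.relu(self.layer_2(x))
        return self.output_activation(self.output_layer(x))


state_dim, action_dim = 4, 2  # CartPole: 4 observations, 2 discrete actions

# Actor: softmax over the action dimension turns the outputs into probabilities.
actor = Network(state_dim, action_dim, output_activation=torch.nn.Softmax(dim=1))
# Critic: identity output, one unbounded Q-value per action.
critic = Network(state_dim, action_dim)

states = torch.randn(5, state_dim)  # a batch of 5 dummy states
action_probs = actor(states)        # shape (5, 2), each row sums to 1
q_values = critic(states)           # shape (5, 2)
```

Because the critic outputs one Q-value per action, a single forward pass yields $Q(s, a)$ for every discrete action at once.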


## Replay Buffer

A SAC agent needs a replay buffer, from which previously collected transitions can be sampled. You can use the implementation below. Later, you are going to use the replay buffer of an online-trained agent to train the offline model.
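
The buffer packs each transition into one record of a NumPy structured array. A minimal sketch of that layout, assuming CartPole's 4-dimensional `float32` observations and integer actions. The field names are hypothetical: the class itself builds an equivalent dtype from an unnamed comma-separated string.

```python
import numpy as np

# One record = (state, action, reward, next_state, done).
# Field names are illustrative only.
transition_dtype = np.dtype([
    ("state", np.float32, (4,)),
    ("action", np.int64),
    ("reward", np.float32),
    ("next_state", np.float32, (4,)),
    ("done", np.bool_),
])

buffer = np.zeros(3, dtype=transition_dtype)  # room for 3 transitions
buffer[0] = (np.ones(4, dtype=np.float32), 1, 1.0, np.zeros(4, dtype=np.float32), False)

# Whole columns can be read back at once.
states = buffer["state"]    # shape (3, 4)
actions = buffer["action"]  # shape (3,)
```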

In [11]:
import numpy as np


class ReplayBuffer:

    def __init__(self, environment, capacity=500000):
        transition_type_str = self.get_transition_type_str(environment)
        self.buffer = np.zeros(capacity, dtype=transition_type_str)
        self.weights = np.zeros(capacity)
        self.head_idx = 0
        self.count = 0
        self.capacity = capacity
        self.max_weight = 10**-2
        self.delta = 10**-4
        self.indices = None
        self.mirror_index = np.random.permutation(range(self.buffer.shape[0]))

    def get_transition_type_str(self, environment):
        state_dim = environment.observation_space.shape[0]
        state_dim_str = '' if state_dim == () else str(state_dim)
        state_type_str = environment.observation_space.sample().dtype.name
        action_dim = environment.action_space.shape
        action_dim_str = '' if action_dim == () else str(action_dim)
        action_type_str = environment.action_space.sample().__class__.__name__

        # type str for transition = 'state type, action type, reward type, next state type, done type'
        transition_type_str = '{0}{1}, {2}{3}, float32, {0}{1}, bool'.format(state_dim_str, state_type_str,
                                                                             action_dim_str, action_type_str)

        return transition_type_str

    def add_transition(self, transition):
        self.buffer[self.head_idx] = transition
        self.weights[self.head_idx] = self.max_weight

        self.head_idx = (self.head_idx + 1) % self.capacity
        self.count = min(self.count + 1, self.capacity)

    def sample_minibatch(self, size=100, batch_deterministic_start=None):
        set_weights = self.weights[:self.count] + self.delta
        probabilities = set_weights / sum(set_weights)
        if batch_deterministic_start is None:
            self.indices = np.random.choice(range(self.count), size, p=probabilities, replace=False)
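        else:
            # Restrict the fixed permutation to filled slots so that deterministic
            # batches never contain the zero-initialised, unused part of the buffer.
            valid_indices = self.mirror_index[self.mirror_index < self.count]
            self.indices = valid_indices[batch_deterministic_start:batch_deterministic_start+size]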
        return self.buffer[self.indices]

    def update_weights(self, prediction_errors):
        max_error = max(prediction_errors)
        self.max_weight = max(self.max_weight, max_error)
        self.weights[self.indices] = prediction_errors

    def get_size(self):
        return self.count

## Questions (18 points)

❓ We know that standard RL maximizes the expected sum of rewards. What is the objective function of the SAC algorithm? Compare it to the standard RL loss.

❓ Write down the actor cost function.

❓ Write down the critic cost function.

❓ Elaborate on the reason why most implementations of SAC use two critics (one local and one target).

❓ What is the difference between training samples in offline and online settings?

❓ How does adding CQL on top of SAC change the objective function?



### Answers to Conceptual Questions

**❓ Q1: What is the objective function of the SAC algorithm? Compare it to the standard RL loss.**

**Answer:**  
Standard RL maximizes:
$$J(\pi) = \mathbb{E}_{\tau \sim \pi}[\sum_{t=0}^{\infty} \gamma^t r(s_t, a_t)]$$

SAC maximizes **entropy-regularized objective**:
$$J(\pi) = \mathbb{E}_{\tau \sim \pi}[\sum_{t=0}^{\infty} \gamma^t (r(s_t, a_t) + \alpha H(\pi(\cdot|s_t)))]$$

where $H(\pi(\cdot|s_t)) = -\mathbb{E}_{a \sim \pi}[\log \pi(a|s_t)]$ is the entropy.

**Key Difference:** SAC encourages exploration by rewarding high entropy (randomness) in the policy, while standard RL focuses only on maximizing rewards.
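
To make the entropy bonus concrete, the sketch below computes $H(\pi)$ for two hypothetical two-action policies and the resulting per-step term $r + \alpha H$. The values of `alpha` and the two policies are illustrative and not taken from the notebook.

```python
import math


def entropy(probs):
    """Shannon entropy H(pi) = -sum_a pi(a) log pi(a), skipping zero-probability actions."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


alpha = 0.2   # illustrative temperature
reward = 1.0  # CartPole gives +1 per step

uniform_policy = [0.5, 0.5]   # maximally random over 2 actions
peaked_policy = [0.99, 0.01]  # almost deterministic

# Entropy-augmented per-step reward: r + alpha * H(pi(.|s))
soft_reward_uniform = reward + alpha * entropy(uniform_policy)
soft_reward_peaked = reward + alpha * entropy(peaked_policy)
```

For the same environment reward, the uniform policy collects the larger augmented reward, which is exactly the pressure toward randomness that the SAC objective adds.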

---

**❓ Q2: Write down the actor cost function.**

**Answer:**  
$$J_\pi(\phi) = \mathbb{E}_{s_t \sim D}\mathbb{E}_{a_t \sim \pi_\phi}[\alpha \log \pi_\phi(a_t|s_t) - Q_\theta(s_t, a_t)]$$

Or equivalently for discrete actions:
$$J_\pi(\phi) = \mathbb{E}_{s_t \sim D}[\sum_a \pi_\phi(a|s_t)(\alpha \log \pi_\phi(a|s_t) - Q_\theta(s_t, a))]$$

The actor minimizes this cost, which balances between maximizing Q-values and maintaining high entropy.
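
The discrete-action form above can be sketched in PyTorch as follows. The function name and the use of $\min(Q_1, Q_2)$ in place of a single $Q_\theta$ are assumptions (the latter borrowed from clipped double-Q learning); this is an illustration, not the notebook's reference solution.

```python
import torch


def discrete_actor_loss(action_probs, log_action_probs, q1, q2, alpha):
    """Expected (alpha * log pi - Q) under the policy, averaged over the batch.

    All tensors have shape (batch, num_actions).
    """
    q_min = torch.min(q1, q2)  # assumption: pessimistic estimate of two critics
    inside = alpha * log_action_probs - q_min
    return (action_probs * inside).sum(dim=1).mean()


# Toy check: a uniform policy over 2 actions with all-zero Q-values.
probs = torch.full((3, 2), 0.5)
log_probs = torch.log(probs)
zeros = torch.zeros(3, 2)
loss = discrete_actor_loss(probs, log_probs, zeros, zeros, alpha=1.0)
# With Q = 0 and alpha = 1, the loss equals minus the entropy, i.e. -log(2).
```

Since the expectation over actions is computed exactly as a weighted sum, no reparameterization trick is needed in the discrete case.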

---

**❓ Q3: Write down the critic cost function.**

**Answer:**  
$$J_Q(\theta) = \mathbb{E}_{(s_t,a_t,r_t,s_{t+1}) \sim D}[(Q_\theta(s_t, a_t) - y_t)^2]$$

where the target is:
$$y_t = r_t + \gamma(1-d_t) \mathbb{E}_{a_{t+1} \sim \pi}[Q_{\theta'}(s_{t+1}, a_{t+1}) - \alpha \log \pi(a_{t+1}|s_{t+1})]$$

For discrete actions:
$$y_t = r_t + \gamma(1-d_t) \sum_a \pi(a|s_{t+1})[Q_{\theta'}(s_{t+1}, a) - \alpha \log \pi(a|s_{t+1})]$$
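
A minimal PyTorch sketch of the discrete target and the squared-error loss above. The function names are hypothetical, and taking the minimum over two target critics is an assumption consistent with the two-critic setup used later in the notebook.

```python
import torch


def soft_td_target(rewards, dones, next_probs, next_log_probs,
                   next_q1_target, next_q2_target, alpha, gamma=0.99):
    """y = r + gamma * (1 - d) * sum_a pi(a|s') [min Q_target(s', a) - alpha * log pi(a|s')]."""
    with torch.no_grad():  # targets are treated as constants in the critic update
        next_q = torch.min(next_q1_target, next_q2_target)
        soft_value = (next_probs * (next_q - alpha * next_log_probs)).sum(dim=1)
        return rewards + gamma * (1.0 - dones.float()) * soft_value


def critic_mse(q_values, actions, targets):
    """Mean squared error between Q(s, a) of the taken actions and the targets."""
    q_taken = q_values.gather(1, actions.long().unsqueeze(1)).squeeze(1)
    return torch.nn.functional.mse_loss(q_taken, targets)


# Toy batch: two transitions, the second one terminal.
rewards = torch.tensor([1.0, 1.0])
dones = torch.tensor([False, True])
next_probs = torch.full((2, 2), 0.5)
next_log_probs = torch.log(next_probs)
next_q = torch.zeros(2, 2)
y = soft_td_target(rewards, dones, next_probs, next_log_probs, next_q, next_q, alpha=1.0)
```

For the terminal transition the bootstrap term vanishes and the target reduces to the immediate reward.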

---

**❓ Q4: Elaborate on the reason why most implementations of SAC use two critics (one local and one target).**

**Answer:**  
SAC uses **two local critics** and **two target critics** (4 critics total):

1. **Two local critics (Q1, Q2):** These reduce **overestimation bias**. We use the minimum of the two estimates, $Q(s,a) = \min(Q_1(s,a), Q_2(s,a))$. This clipped double-Q learning keeps the critic from being overly optimistic.

2. **Target networks (Q1_target, Q2_target):** These provide **stable training targets**. Target networks are updated slowly (soft update with $\tau \ll 1$), which avoids the "moving target" problem where the Q-values we are trying to match change rapidly.
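
A toy NumPy illustration of the overestimation argument, under simplifying assumptions: every action has a true Q-value of 0, each critic observes it with independent Gaussian noise, and a greedy max stands in for a policy that favors high-Q actions. The numbers are synthetic and unrelated to CartPole.

```python
import numpy as np

rng = np.random.default_rng(0)
num_states, num_actions = 10_000, 4

# The true Q-value of every action is 0; each critic sees it through independent noise.
q1 = rng.normal(size=(num_states, num_actions))
q2 = rng.normal(size=(num_states, num_actions))

# Greedy value from a single noisy critic: biased upward, since max picks up positive noise.
single_critic_value = q1.max(axis=1).mean()

# Greedy value after first taking the element-wise minimum of the two critics.
clipped_value = np.minimum(q1, q2).max(axis=1).mean()
```

The clipped estimate can never exceed the single-critic one, because the element-wise minimum is at most `q1` for every state and action.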

---

**❓ Q5: What is the difference between training samples in offline and online settings?**

**Answer:**  

| Aspect | Online RL | Offline RL |
|--------|-----------|------------|
| **Data Collection** | Agent interacts with environment during training | Uses pre-collected fixed dataset |
| **Exploration** | Can explore new states/actions | Limited to dataset coverage |
| **Distribution Shift** | Policy improves, collects better data | Policy may diverge from dataset distribution |
| **Sample Efficiency** | Requires many environment interactions | No environment interaction needed |
| **Safety** | May take dangerous actions during exploration | No risky exploration during training (no environment interaction) |

**Key Challenge in Offline RL:** **Extrapolation error** - the agent may learn to take actions not well-represented in the dataset, leading to overestimated Q-values for out-of-distribution actions.

---

**❓ Q6: How does adding CQL on top of SAC change the objective function?**

**Answer:**  
CQL (Conservative Q-Learning) adds a **conservative regularizer** to the critic loss:

**Standard SAC Critic Loss:**
$$J_Q(\theta) = \mathbb{E}_{(s,a,r,s') \sim D}[(Q_\theta(s, a) - y)^2]$$

**CQL Critic Loss:**
$$J_{CQL}(\theta) = \alpha_{CQL} \cdot \underbrace{(\mathbb{E}_{s \sim D, a \sim \mu(a|s)}[Q_\theta(s,a)] - \mathbb{E}_{s,a \sim D}[Q_\theta(s,a)])}_{\text{CQL regularizer}} + J_Q(\theta)$$

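where $\mu$ is the action distribution whose Q-values are pushed down (e.g., uniform or the current policy). It is not the behavior policy that generated the dataset $D$.

**Effect:** 
- **Increases** Q-values for actions in the dataset $D$
- **Decreases** Q-values for actions sampled from $\mu$, which include out-of-distribution actions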
- This makes the agent **conservative**, avoiding actions not seen in the offline dataset
- The tradeoff factor $\alpha_{CQL}$ controls the strength of this conservatism
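
For discrete actions, one common choice is the log-sum-exp variant of the regularizer (often written CQL(H)), where the expectation under $\mu$ is replaced by $\log \sum_a \exp Q_\theta(s,a)$. The sketch below assumes that variant; the notebook does not prescribe which $\mu$ to use, and the function and variable names are hypothetical.

```python
import torch


def cql_regularizer(q_values, actions):
    """Log-sum-exp over all actions minus Q of the dataset action, batch-averaged.

    q_values: (batch, num_actions); actions: (batch,) integer dataset actions.
    """
    pushed_down = torch.logsumexp(q_values, dim=1)  # soft maximum over all actions
    pushed_up = q_values.gather(1, actions.long().unsqueeze(1)).squeeze(1)  # Q(s, a) of dataset actions
    return (pushed_down - pushed_up).mean()


tradeoff_factor = 5.0        # plays the role of alpha_CQL
q_values = torch.zeros(4, 2)
actions = torch.tensor([0, 1, 1, 0])
td_loss = torch.tensor(0.0)  # placeholder for the usual squared TD error

total_critic_loss = td_loss + tradeoff_factor * cql_regularizer(q_values, actions)
```

The penalty is never negative, since the log-sum-exp is at least as large as any single Q-value, and it shrinks as the dataset action's Q-value dominates the others.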


## SAC Agent (50 points)

Now complete the following class. You can use the auxiliary methods provided in the class.
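
Besides the actor and critics, the class tunes the temperature $\alpha$ automatically toward a target entropy. A minimal sketch of that mechanism, assuming two actions as in CartPole. Unlike the `temperature_loss` helper in the class, this illustration weights the log-probabilities by the action probabilities, so the bracketed term is exactly the target entropy minus the policy entropy; the helper name `temperature_grad` is hypothetical.

```python
import math
import torch

num_actions = 2
target_entropy = 0.98 * -math.log(1 / num_actions)  # 98% of the maximum entropy log(2)


def temperature_grad(action_probs):
    """Gradient of the temperature loss w.r.t. log(alpha) for one batch of policies."""
    log_alpha = torch.zeros((), requires_grad=True)  # alpha = exp(0) = 1
    log_probs = torch.log(action_probs)
    # The expected log-probability under the policy is minus the entropy.
    expected_log_prob = (action_probs * log_probs).sum(dim=1)
    loss = -(log_alpha * (expected_log_prob + target_entropy).detach()).mean()
    loss.backward()
    return log_alpha.grad.item()


# Uniform policy: entropy log(2) exceeds the target, so the gradient is positive
# and a gradient-descent step lowers log(alpha), weakening the entropy bonus.
grad_uniform = temperature_grad(torch.tensor([[0.5, 0.5]]))

# Peaked policy: entropy is below the target, so the gradient is negative
# and the step raises log(alpha), strengthening the entropy bonus.
grad_peaked = temperature_grad(torch.tensor([[0.99, 0.01]]))
```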

In [51]:
class SACAgent:

    ALPHA_INITIAL = 1.
    REPLAY_BUFFER_BATCH_SIZE = 100
    DISCOUNT_RATE = 0.99
    LEARNING_RATE = 10 ** -4
    SOFT_UPDATE_INTERPOLATION_FACTOR = 0.01
    TRADEOFF_FACTOR = 5 # trade-off factor in the CQL

    def __init__(self, environment, replay_buffer=None, use_cql=False, offline=False):

        assert not use_cql or offline, 'Please activate the offline flag for CQL.'
        assert not offline or replay_buffer is not None, 'Please pass a replay buffer to the offline method.'

        self.environment = environment
        self.state_dim = self.environment.observation_space.shape[0]
        self.action_dim = self.environment.action_space.n

        self.offline = offline
        self.replay_buffer = ReplayBuffer(self.environment) if replay_buffer is None else replay_buffer
        self.use_cql = use_cql

        ##########################################################
        # TODO (6 points): 
        # Define critics using your implemented feedforward network (10 points).
        # To have easier critic updates, you can use two local critic networks 
        # and two target critics.
        ##########################################################
        self.critic_local = None
        self.critic_local2 = None
        self.critic_optimiser = None
        self.critic_optimiser2 = None
        self.critic_target = None
        self.critic_target2 = None
        ##########################################################

        self.soft_update_target_networks(tau=1.)

        ##########################################################
        # TODO (2 points): 
        # Define the actor using your implemented feedforward network (10 points).
        # Define the actor optimizer using torch.optim.Adam (4 points)
        ##########################################################
        self.actor_local = None
        self.actor_optimiser  = None
        ##########################################################

        self.target_entropy = 0.98 * -np.log(1 / self.environment.action_space.n)
        self.log_alpha = torch.tensor(np.log(self.ALPHA_INITIAL), requires_grad=True)
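        # alpha = exp(log_alpha), so it starts at ALPHA_INITIAL (detached: used as a constant in the losses)
        self.alpha = self.log_alpha.exp().detach()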
        self.alpha_optimiser = torch.optim.Adam([self.log_alpha], lr=self.LEARNING_RATE)

    def get_next_action(self, state, evaluation_episode=False):
        if evaluation_episode:
            discrete_action = self.get_action_deterministically(state)
        else:
            discrete_action = self.get_action_nondeterministically(state)
        return discrete_action

    def get_action_nondeterministically(self, state):
        action_probabilities = self.get_action_probabilities(state)
        discrete_action = np.random.choice(range(self.action_dim), p=action_probabilities)
        return discrete_action

    def get_action_deterministically(self, state):
        action_probabilities = self.get_action_probabilities(state)
        discrete_action = np.argmax(action_probabilities)
        return discrete_action

    def critic_loss(self, states_tensor, actions_tensor, rewards_tensor, 
                    next_states_tensor, done_tensor):
        ##########################################################
        # TODO (12 points): 
        # You are going to calculate critic losses in this method.
        # Also you should implement the CQL loss if the corresponding 
        # flag is set.
        ##########################################################
        critic_loss, critic2_loss = 0, 0

        return critic_loss, critic2_loss
        ##########################################################

    def actor_loss(self, states_tensor):
        ##########################################################
        # TODO (8 points): 
        # Now implement the actor loss.
        ##########################################################
        actor_loss, log_action_probabilities = 0, 0

        return actor_loss, log_action_probabilities
        ##########################################################

    def train_on_transition(self, state, discrete_action, next_state, reward, done):
        transition = (state, discrete_action, reward, next_state, done)
        self.train_networks(transition)

    def train_networks(self, transition=None, batch_deterministic_start=None):
        ##########################################################
        # TODO (6 points): 
        # Set all the gradients stored in the optimisers to zero.
        # Add the new transition to the replay buffer in the online case.
        ##########################################################

        if self.replay_buffer.get_size() >= self.REPLAY_BUFFER_BATCH_SIZE:
            minibatch = self.replay_buffer.sample_minibatch(self.REPLAY_BUFFER_BATCH_SIZE,
                                                            batch_deterministic_start=batch_deterministic_start)
            minibatch_separated = list(map(list, zip(*minibatch)))

            states_tensor = torch.tensor(np.array(minibatch_separated[0]))
            actions_tensor = torch.tensor(np.array(minibatch_separated[1]))
            rewards_tensor = torch.tensor(np.array(minibatch_separated[2])).float()
            next_states_tensor = torch.tensor(np.array(minibatch_separated[3]))
            done_tensor = torch.tensor(np.array(minibatch_separated[4]))

            ##########################################################
            # TODO (16 points): 
            # Here, you should compute the gradients based on this loss, i.e. the gradients
            # of the loss with respect to the Q-network parameters.
            # Given a minibatch of 100 transitions from replay buffer,
            # compute the critic loss and perform the backward and step functions,
            # and compute the actor loss and perform the backward and step functions.
            # You also need to update \alpha.
            ##########################################################

            ##########################################################

            self.soft_update_target_networks()

    def temperature_loss(self, log_action_probabilities):
        alpha_loss = -(self.log_alpha * (log_action_probabilities + self.target_entropy).detach()).mean()
        return alpha_loss

    def get_action_info(self, states_tensor):
        action_probabilities = self.actor_local.forward(states_tensor)
        z = action_probabilities == 0.0
        z = z.float() * 1e-8
        log_action_probabilities = torch.log(action_probabilities + z)
        return action_probabilities, log_action_probabilities

    def get_action_probabilities(self, state):
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        action_probabilities = self.actor_local.forward(state_tensor)
        return action_probabilities.squeeze(0).detach().numpy()

    def soft_update_target_networks(self, tau=SOFT_UPDATE_INTERPOLATION_FACTOR):
        self.soft_update(self.critic_target, self.critic_local, tau)
        self.soft_update(self.critic_target2, self.critic_local2, tau)

    def soft_update(self, target_model, origin_model, tau):
        for target_param, local_param in zip(target_model.parameters(), origin_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1 - tau) * target_param.data)

    def predict_q_values(self, state):
        q_values = self.critic_local(state)
        q_values2 = self.critic_local2(state)
        return torch.min(q_values, q_values2)

## Online SAC training loop (10 points)

Now evaluate your model on the CartPole environment in the online setting. After every 4 training episodes, you should evaluate your model on a separate test environment. Run your model 4 times separately and plot the mean and standard deviation of the evaluation curves.
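
One way to turn the per-run evaluation returns into a mean curve with a deviation band is sketched below. The helper name `summarize_runs` and the dummy numbers are made up for illustration; it assumes every run produced the same number of evaluations.

```python
import numpy as np


def summarize_runs(returns_per_run, training_evaluation_ratio=4):
    """Stack evaluation returns of several runs and return x-axis, mean and std.

    returns_per_run: list of equal-length lists, one evaluation return per entry.
    """
    returns = np.asarray(returns_per_run, dtype=float)  # shape (runs, evaluations)
    # Evaluation k (1-based) happens after k * ratio training episodes.
    episodes = (np.arange(returns.shape[1]) + 1) * training_evaluation_ratio
    return episodes, returns.mean(axis=0), returns.std(axis=0)


# Made-up returns for 4 runs with 3 evaluations each, for illustration only.
dummy_returns = [[10, 50, 200], [12, 40, 180], [8, 60, 200], [10, 50, 220]]
episodes, mean_return, std_return = summarize_runs(dummy_returns)

# Plotting the band (requires matplotlib.pyplot imported as plt):
# plt.plot(episodes, mean_return)
# plt.fill_between(episodes, mean_return - std_return, mean_return + std_return, alpha=0.3)
```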

**NOTE:** Since you are going to use the replay buffer of this agent as the offline dataset, you may want to save it for later use.
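
A minimal sketch of saving and restoring the buffer with the standard `pickle` module. The file name and the stand-in dictionary are hypothetical; in the notebook the object to save would be the agent's `replay_buffer`.

```python
import os
import pickle
import tempfile


def save_buffer(buffer, path):
    """Serialize the replay buffer object to disk."""
    with open(path, "wb") as f:
        pickle.dump(buffer, f)


def load_buffer(path):
    """Load a previously saved replay buffer."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Stand-in object for illustration only.
stand_in_buffer = {"transitions": [(0, 1, 1.0, 1, False)], "count": 1}

path = os.path.join(tempfile.gettempdir(), "online_replay_buffer.pkl")  # hypothetical file name
save_buffer(stand_in_buffer, path)
restored = load_buffer(path)
```

When unpickling an instance of a custom class such as `ReplayBuffer`, the class definition must already be available in the session that loads the file.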

In [None]:
TRAINING_EVALUATION_RATIO = 4
EPISODES_PER_RUN = 1000
STEPS_PER_EPISODE = 200

env = gym.make("CartPole-v1")

##########################################################
# TODO (10 points): 
# Implement the training loop for the online SAC. 
# 1) You need to initialize an agent with the
#    `replay_buffer` set to None. Also, leave the 
#    `use_cql` and `offline` flags set to False.
# 2) After every `TRAINING_EVALUATION_RATIO` training
#    episodes, run an evaluation episode on a separate
#    test environment and record its return.
# 3) Plot the learning curves.
##########################################################

## Offline SAC training loop (10 points)

In this part, you are going to train a SAC agent using the replay buffer from the online agent. During training, you sample from this replay buffer and train the offline agent **without adding transitions to the replay buffer**. The loss function and everything else are the same as in the online setting.
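
One offline epoch can be organized as a single sweep over a fixed shuffle of the dataset, one minibatch per start offset. The sketch below shows only the index bookkeeping under illustrative sizes; the helper name `epoch_batch_starts` is hypothetical, and the commented call indicates where the agent update would go.

```python
import numpy as np


def epoch_batch_starts(num_transitions, batch_size):
    """Start offsets that sweep a dataset of num_transitions once, in batch_size steps."""
    return range(0, num_transitions, batch_size)


num_transitions, batch_size = 1050, 100               # illustrative sizes
permutation = np.random.permutation(num_transitions)  # fixed shuffle of the dataset

seen = []
for start in epoch_batch_starts(num_transitions, batch_size):
    batch_indices = permutation[start:start + batch_size]  # the last batch may be smaller
    seen.extend(batch_indices.tolist())
    # In the notebook, the corresponding update would be
    # agent.train_networks(batch_deterministic_start=start)
```

Sweeping the offsets this way visits every stored transition exactly once per epoch, in contrast to the weighted random sampling used online.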

In [None]:
RUNS = 1
NUM_EPOCHS = 200
EPISODES_PER_RUN = 100

env = gym.make("CartPole-v1")

##########################################################
# TODO (10 points): 
# Implement the training loop for the offline SAC. 
# 1) You need to initialize an agent with the current
#    `replay_buffer` of the online agent. Set the `offline`
#    flag and leave the `use_cql` flag set to False.
# 2) You can use `batch_deterministic_start` in the
#    `train_networks` method to select all minibatches
#    of the data to train in an offline manner.
# 3) After each epoch, run `EPISODES_PER_RUN` validation
#    episodes and plot the mean return over these 
#    episodes in the end.
##########################################################

## Conservative SAC training loop (5 points)

Similar to the previous part, you are going to train another offline agent. In this part, you are going to use the conservative version of SAC.

In [None]:
RUNS = 1
NUM_EPOCHS = 200
EPISODES_PER_RUN = 100

env = gym.make("CartPole-v1")

##########################################################
# TODO (5 points): 
# Implement the training loop for the conservative SAC. 
# 1) You need to initialize an agent with the current
#    `replay_buffer` of the online agent. Set the `offline`
#    and `use_cql` flags.
# 2) You can use `batch_deterministic_start` in the
#    `train_networks` method to select all minibatches
#    of the data to train in an offline manner.
# 3) After each epoch, run `EPISODES_PER_RUN` validation
#    episodes and plot the mean return over these 
#    episodes in the end.
##########################################################

## Comparisons (14 points)
Now, analyze your results and justify the trends you see. Then answer the following questions.

❓ What is the reason for the difference between online and offline performance of the agent?

❓ Which one is better: offline SAC or conservative SAC?

❓ What is the effect of `TRADEOFF_FACTOR` in the offline setting? How does changing its value affect the results?