# ABE tutorial 3
## Looking at actor critic based RL

In this third tutorial, let's dive deeper into the RL algorithm that we used in the first and second tutorials. This time we'll learn how to use two models: a critic model dedicated to estimating the value of states, and an actor model that predicts which actions lead to more reward in the long term (i.e., a policy).

Steps:
* Create an Actor-critic neural network
* Create a custom Policy and training procedure
* Test out the Actor-critic RL algorithm in an environment



## Actor-Critic model

### Neural network model 

To start off, one of the main differences between the SARSA algorithm from the last tutorial and the actor-critic model we'll build here is that we now need two models:

* Actor: estimates the probability of taking each action in a given state, with the goal of favoring the actions that lead to the highest return (total future reward).
* Critic: estimates the expected return of being in a given state (the state value).

Just like with the SARSA code, let's start by outlining a new class:

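```python
import torch.nn as nn


class ActorCriticNet(nn.Module):

    # initialize the network
    def __init__(self, state_shape, action_shape, hidden_size=128):
        super().__init__()

        # actor model: maps a state to one score (logit) per action
        self.actor = nn.Sequential(
        )

        # critic model: maps a state to a single state value
        self.critic = nn.Sequential(
        )

    # get predictions from the models
    def forward(self, obs):
        action_logits = self.actor(obs)
        state_value = self.critic(obs)
        return action_logits, state_value
```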


Now that we know the parts, let's fill in the code. Below we define the network layers for the actor and critic models.


```python
import numpy as np
import torch
import torch.nn as nn


class ActorCriticNet(nn.Module):

    # initialize the network
    def __init__(self, state_shape, action_shape, hidden_size=128):
        super().__init__()

        # build the actor model: state -> one logit per action
        self.actor = nn.Sequential(
            nn.Linear(int(np.prod(state_shape)), hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, int(np.prod(action_shape)))
        )

        # build the critic model: state -> a single state value
        self.critic = nn.Sequential(
            nn.Linear(int(np.prod(state_shape)), hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, 1)
        )
```

Next, let's build the `forward` method, which gets predictions from both the actor and the critic:

```python
    # state and info are unused here; they match the signature Tianshou models usually have
    def forward(self, obs, state=None, info=None):

        # make sure observations are in tensor format
        if isinstance(obs, np.ndarray):
            obs = torch.tensor(obs, dtype=torch.float32)

        # Actor network output: logits for actions, i.e., unnormalized scores
        # where a higher value means the action is more likely to be chosen
        action_logits = self.actor(obs)

        # Critic network output: state value (squeeze turns shape (batch, 1) into (batch,))
        state_value = self.critic(obs).squeeze(-1)

        return action_logits, state_value
```
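As a quick sanity check, here is a sketch that builds a compact copy of the network and feeds it a batch of random observations. The copy has the same layers as above and is written here so the snippet runs on its own. The sizes are assumptions for a CartPole-like task with 4 state features and 2 discrete actions. The actor should return one logit per action for each observation, and the critic one value per observation.

```python
import numpy as np
import torch
import torch.nn as nn


class ActorCriticNet(nn.Module):
    """Compact copy of the tutorial's network so this check is self-contained."""

    def __init__(self, state_shape, action_shape, hidden_size=128):
        super().__init__()
        in_dim, out_dim = int(np.prod(state_shape)), int(np.prod(action_shape))

        def mlp(out_features):
            # Two hidden layers with ReLU + LayerNorm, as in the tutorial
            return nn.Sequential(
                nn.Linear(in_dim, hidden_size), nn.ReLU(), nn.LayerNorm(hidden_size),
                nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.LayerNorm(hidden_size),
                nn.Linear(hidden_size, out_features),
            )

        self.actor = mlp(out_dim)  # one logit per action
        self.critic = mlp(1)       # one state value

    def forward(self, obs, state=None, info=None):
        if isinstance(obs, np.ndarray):
            obs = torch.tensor(obs, dtype=torch.float32)
        return self.actor(obs), self.critic(obs).squeeze(-1)


# Assumed sizes: 4 state features and 2 actions (as in CartPole)
net = ActorCriticNet(state_shape=(4,), action_shape=2)
obs = np.random.default_rng(0).normal(size=(5, 4)).astype(np.float32)
action_logits, state_value = net(obs)
print(action_logits.shape, state_value.shape)  # torch.Size([5, 2]) torch.Size([5])
```

The `squeeze(-1)` is what turns the critic's `(batch, 1)` output into a flat `(batch,)` vector. That is the same shape as the TD target we compute during learning.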

### Custom Policy


Here we'll build a custom policy to better understand how TD learning works in our actor-critic algorithm. As we'll see, its structure stays very close to the SARSA policy from the last tutorial.

To create a custom policy, we need to define how the agent acts and how it learns. To do this, we'll create a new Python class called `A2CPolicy` (A2C stands for Advantage Actor-Critic). It has three pieces:

* `__init__`: an initialization method
* `forward`: a method that takes observations and predicts actions, i.e., which action should the agent take?
* `learn`: a method that lets the agent learn from experience. It updates the critic's estimates of state values and the actor's action probabilities, which in turn changes what `forward` predicts

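```python
from tianshou.policy import BasePolicy


# Custom A2C policy class (outline only; each method is filled in below)
class A2CPolicy(BasePolicy):
    def __init__(self, model, optim, action_space, gamma=0.99):
        # initialize the policy, i.e., set some parameter values
        super().__init__(action_space=action_space)

    def forward(self, batch, state=None, **kwargs):
        # use the actor model to predict actions
        ...

    def learn(self, batch, **kwargs):
        # update the critic's state-value estimates with TD learning,
        # and use them to improve the actor
        ...
```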


#### Initialize the new policy

Let's look at the initialization of the policy first.

When we initialize the policy, we'll provide:
* the *model*, which is used to choose actions given states (actor) and to estimate state values (critic)
* the *optim* (optimizer), which updates the model, i.e., it's how the agent learns
* the *action_space*, the set of actions available in the environment, which is passed on to Tianshou's `BasePolicy`
* *gamma*, the discount factor, which defines how the agent values future rewards vs. immediate rewards
* *Note:* epsilon is no longer needed. Our actor outputs logits that define a probability distribution over actions, and the agent samples its action from that distribution. Exploration is therefore already built into our A2C algorithm, and the agent does not need to take separate random actions.

```python
    def __init__(self, model, optim, action_space, gamma=0.99):
        super().__init__(action_space=action_space)
        self.model = model
        self.optim = optim
        self.gamma = gamma
```

#### Forward

Now we add the `forward` method. It takes the actor's action scores (logits) and samples the action to take from the resulting probability distribution.

* **Note:** the categorical distribution below is a good choice when actions are discrete. We'll see how to modify this so that we can also work with continuous action spaces.


```python
    def forward(self, batch, state=None, **kwargs):
        logits, _ = self.model(batch.obs)
        
        # Sample action based on policy (softmax)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        
        return Batch(act=action.cpu().numpy(), dist=dist)
```
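To see why no epsilon is needed, the sketch below samples many actions from a `Categorical` distribution built from some made-up logits. It then compares how often each action was picked with the softmax probabilities. The logits are arbitrary example values, not outputs of a trained actor.

```python
import torch

torch.manual_seed(0)

# Made-up actor logits for a single state with two actions
logits = torch.tensor([0.5, 1.5])
dist = torch.distributions.Categorical(logits=logits)

# Probabilities implied by the logits (softmax)
probs = torch.softmax(logits, dim=-1)

# Sample many actions and count how often each one was chosen
samples = dist.sample((10_000,))
freqs = torch.bincount(samples, minlength=2).float() / samples.numel()

print("softmax probabilities:", probs)
print("sampled frequencies:  ", freqs)
```

Both actions keep getting tried, and the higher-scoring one is chosen more often. As training makes the logits more lopsided, the agent naturally explores less.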

#### Learn

Now that we've initialized our policy with a model, an optimizer, and some parameters, and set up how the agent chooses actions using the model, we need to think about how the agent will learn. This is an important step. We'll use TD learning to train the critic to predict the value of states, and we'll use the critic's estimates to improve the actor's action probabilities. It works in a few steps:

* Estimate the current state value (critic) and the action probabilities (actor).
* Compute the TD target: the reward plus the discounted value the critic predicts for the next state (zero if the next state is terminal).
* Calculate the advantage: the TD target minus the current state value. It measures how much better (positive) or worse (negative) the action turned out than the critic expected.
* Use the advantage to improve the actor: increase the probability of actions with a positive advantage and decrease the probability of actions with a negative advantage.
* Use the difference between the estimated state value and the TD target to improve the critic.
* If we repeat this many times, the actor should assign higher probabilities to better actions, and the critic should make better estimates of state values.
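Written as equations, the steps above for a single transition $(s, a, r, s')$ look like this. Here $V$ is the critic, $\pi$ is the actor, $\gamma$ is the discount factor, and $d$ is 1 if $s'$ is a terminal state and 0 otherwise. In the code, $y$ and $A$ are treated as constants (no gradient flows through them), $A$ is additionally normalized across the batch, and the losses are averaged over the batch.

$$
\begin{aligned}
y &= r + \gamma\,(1 - d)\,V(s') && \text{(TD target)} \\
A &= y - V(s) && \text{(advantage)} \\
L_{\text{actor}} &= -\log \pi(a \mid s)\,A && \text{(policy loss)} \\
L_{\text{critic}} &= \big(V(s) - y\big)^2 && \text{(value loss)} \\
L &= L_{\text{actor}} + L_{\text{critic}}
\end{aligned}
$$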

```python
    def learn(self, batch, **kwargs):

        # Convert the fields we need to tensors (as_tensor avoids a copy if they already are)
        act = torch.as_tensor(batch.act, dtype=torch.long)
        rew = torch.as_tensor(batch.rew, dtype=torch.float32)
        # Use `terminated` rather than `done`: an episode cut off by a time limit
        # (truncated) should still bootstrap from the value of the next state
        terminated = torch.as_tensor(batch.terminated, dtype=torch.float32)

        # Forward pass to get actor (logits) and critic (value)
        logits, state_values = self.model(batch.obs)
        dist = torch.distributions.Categorical(logits=logits)

        # Compute the log probabilities of the actions that were actually taken
        log_probs = dist.log_prob(act)

        # TD target and advantage are treated as constants (no gradients)
        with torch.no_grad():
            _, next_state_values = self.model(batch.obs_next)
            td_target = rew + self.gamma * (1 - terminated) * next_state_values

            # Advantage: how much better the outcome was than the critic expected
            advantage = td_target - state_values
            # Normalize the advantage to stabilize the actor's updates
            advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)

        # Optional entropy bonus to encourage exploration (disabled here)
        # entropy = dist.entropy().mean()

        # Policy (actor) loss: raise the log-probability of actions with positive advantage
        policy_loss = -(log_probs * advantage).mean()  # - 0.01 * entropy  (adjust weight as needed)

        # Value (critic) loss: move V(s) toward the TD target
        value_loss = nn.functional.mse_loss(state_values, td_target)

        # Combine the losses
        loss = policy_loss + value_loss

        # Backpropagation, with gradient clipping to keep each update small
        self.optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optim.step()

        return {"loss": loss.item(), "policy_loss": policy_loss.item(), "value_loss": value_loss.item()}
```
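To make the TD target and advantage concrete, here is a small worked example in NumPy that follows the same formulas as `learn`. It uses a hypothetical batch of three transitions with made-up rewards, critic values and termination flags; the numbers are illustrative only.

```python
import numpy as np

gamma = 0.99

# Hypothetical batch of 3 transitions (illustrative numbers, not real critic outputs)
rew = np.array([1.0, 1.0, 1.0])
state_values = np.array([10.0, 9.0, 1.0])       # V(s) from the critic
next_state_values = np.array([10.0, 8.0, 5.0])  # V(s') from the critic
terminated = np.array([0.0, 0.0, 1.0])          # the last transition ends the episode

# TD target: reward plus discounted value of the next state (zero if terminal)
td_target = rew + gamma * (1 - terminated) * next_state_values  # ≈ [10.9, 8.92, 1.0]

# Advantage: how much better the outcome was than the critic expected
advantage = td_target - state_values  # ≈ [0.9, -0.08, 0.0]

# Normalize as in learn; ddof=1 matches torch's default (unbiased) std
norm_adv = (advantage - advantage.mean()) / (advantage.std(ddof=1) + 1e-8)

print(td_target, advantage, norm_adv)
```

Here the first transition did better than the critic expected (positive advantage). After normalization the signs are relative to the batch average, so the first action would be made more likely and the other two less likely.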


### Full code

Now that we've gone through the code piece by piece, here it is all together.

```python
import gymnasium as gym
import torch
import torch.nn as nn
import numpy as np
from tianshou.env import DummyVectorEnv
from tianshou.data import Batch, ReplayBuffer, Collector
from tianshou.policy import BasePolicy
from torch.utils.tensorboard import SummaryWriter
import tianshou as ts


class ActorCriticNet(nn.Module):
    def __init__(self, state_shape, action_shape, hidden_size=128):
        super().__init__()

        # Actor: state -> one logit per action
        self.actor = nn.Sequential(
            nn.Linear(int(np.prod(state_shape)), hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, int(np.prod(action_shape)))
        )

        # Critic: state -> a single state value
        self.critic = nn.Sequential(
            nn.Linear(int(np.prod(state_shape)), hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, obs, state=None, info=None):
        # Make sure observations are in tensor format
        if isinstance(obs, np.ndarray):
            obs = torch.tensor(obs, dtype=torch.float32)

        # Actor network output: logits for actions
        action_logits = self.actor(obs)

        # Critic network output: state value
        state_value = self.critic(obs).squeeze(-1)

        return action_logits, state_value


class A2CPolicy(BasePolicy):
    def __init__(self, model, optim, action_space, gamma=0.99):
        super().__init__(action_space=action_space)
        self.model = model
        self.optim = optim
        self.gamma = gamma

    def forward(self, batch, state=None, **kwargs):
        logits, _ = self.model(batch.obs)

        # Sample action based on policy (softmax over the logits)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()

        return Batch(act=action.cpu().numpy(), dist=dist)

    def learn(self, batch, **kwargs):
        # Convert the fields we need to tensors
        act = torch.as_tensor(batch.act, dtype=torch.long)
        rew = torch.as_tensor(batch.rew, dtype=torch.float32)
        # Use `terminated` (not `done`) so time-limit truncations still bootstrap
        terminated = torch.as_tensor(batch.terminated, dtype=torch.float32)

        # Forward pass to get actor (logits) and critic (value)
        logits, state_values = self.model(batch.obs)
        dist = torch.distributions.Categorical(logits=logits)

        # Log probabilities of the actions that were actually taken
        log_probs = dist.log_prob(act)

        # TD target and normalized advantage (treated as constants)
        with torch.no_grad():
            _, next_state_values = self.model(batch.obs_next)
            td_target = rew + self.gamma * (1 - terminated) * next_state_values
            advantage = td_target - state_values
            advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)

        # Optional entropy bonus (disabled here)
        # entropy = dist.entropy().mean()

        # Policy (actor) loss and value (critic) loss
        policy_loss = -(log_probs * advantage).mean()  # - 0.01 * entropy
        value_loss = nn.functional.mse_loss(state_values, td_target)
        loss = policy_loss + value_loss

        # Backpropagation with gradient clipping
        self.optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optim.step()

        return {"loss": loss.item(), "policy_loss": policy_loss.item(), "value_loss": value_loss.item()}
```



### Test out the A2C algorithm

To test out our new A2C algorithm, let's create an environment and train an agent in it!

```python
# Create a single environment instance to access the space information
single_env = gym.make("CartPole-v1")
state_shape = single_env.observation_space.shape
action_shape = single_env.action_space.n
action_space = single_env.action_space
```

Next, let's build the actor-critic network and the optimizer that lets the network learn. Then let's put it all together in the new A2CPolicy.

```python
# Setting up the actor-critic network and A2C policy
net = ActorCriticNet(state_shape, action_shape)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-5, weight_decay=1e-4)
policy = A2CPolicy(model=net, optim=optimizer, action_space=action_space, gamma=0.99)
```



Finally, let's run a training loop so that the agent can interact with the environment, store its observations, and learn from them.

```python
# Custom training loop
max_epoch = 10
step_per_epoch = 5000
keep_n_steps = 30  # number of transitions collected before each update
buffer = ReplayBuffer(size=keep_n_steps)

# Set up collectors, each with its own copy of the environment so that
# testing does not interrupt an episode that is in progress during training
train_env = DummyVectorEnv([lambda: gym.make("CartPole-v1")])
test_env = DummyVectorEnv([lambda: gym.make("CartPole-v1")])
train_collector = Collector(policy, train_env, buffer)
test_collector = Collector(policy, test_env)

# start a logger
logger_a2c = ts.utils.TensorboardLogger(SummaryWriter('log/a2c_custom'))

for epoch in range(max_epoch):
    train_collector.reset()
    for step in range(step_per_epoch):
        # Collect keep_n_steps new transitions and store them in the buffer
        train_collector.collect(n_step=keep_n_steps)

        # Take the most recent transitions from the buffer
        batch = train_collector.buffer[-keep_n_steps:]

        # Manually convert each field to a torch tensor
        batch.obs = torch.tensor(batch.obs, dtype=torch.float32)
        batch.act = torch.tensor(batch.act, dtype=torch.long)
        batch.rew = torch.tensor(batch.rew, dtype=torch.float32)
        batch.done = torch.tensor(batch.done, dtype=torch.float32)
        batch.obs_next = torch.tensor(batch.obs_next, dtype=torch.float32)

        # Perform A2C learning and keep the loss values for reporting
        stats = policy.learn(batch)

    # Testing and evaluation (report the loss from the last update instead of
    # calling learn again, which would perform an extra update)
    result = test_collector.collect(n_episode=10, reset_before_collect=True)
    print(f'Epoch #{epoch + 1}: reward = {result.returns.mean()}, loss = {stats["loss"]}')

    # Log the average reward for the epoch
    logger_a2c.writer.add_scalar("Reward/test_avg", result.returns.mean(), epoch)
```
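The training loop relies on the replay buffer holding only the most recent `keep_n_steps` transitions: once the buffer is full, each new transition overwrites the oldest one. The sketch below uses Python's `collections.deque` as a stand-in to show that overwrite-oldest behavior. It is only an illustration, not how Tianshou's `ReplayBuffer` is implemented.

```python
from collections import deque

keep_n_steps = 30

# A fixed-size buffer: appending to a full deque drops the oldest item
buffer = deque(maxlen=keep_n_steps)

# Pretend we collected 75 transitions, numbered 0..74
for t in range(75):
    buffer.append(t)

print(len(buffer), buffer[0], buffer[-1])  # 30 45 74
```

The loop collects exactly `keep_n_steps` new transitions before each update, so every update sees a fresh batch generated by the current policy. An on-policy method like A2C needs this.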

Did it learn? Do you see rewards increasing? 

If so, let's save the model:

```python
import os

os.makedirs("models", exist_ok=True)  # create the folder if it doesn't exist yet
torch.save(net.state_dict(), "models/A2C_cartpole_model.pth")
```

### Test the model

Let's test the model and watch what it has learned.

First, load the trained model.

```python
# Initialize a new network with the same architecture
loaded_net = ActorCriticNet(state_shape, action_shape)
loaded_net.load_state_dict(torch.load("models/A2C_cartpole_model.pth"))
```
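If you want to confirm that saving and loading a `state_dict` restores the weights exactly, here is a small self-contained check. It uses a toy `nn.Linear` model as a stand-in for `ActorCriticNet` and a temporary directory instead of `models/`.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for the trained network
model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_model.pth")
    torch.save(model.state_dict(), path)

    # A fresh model starts with different random weights...
    restored = nn.Linear(4, 2)
    # ...until we load the saved state_dict into it
    restored.load_state_dict(torch.load(path))

# Every parameter tensor should match the original exactly
same = all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                              restored.state_dict().values()))
print(same)  # True
```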

Let's create an environment and build a policy based on our saved model.

```python
# Create the environment for evaluation with rendering enabled
eval_env = gym.make("CartPole-v1", render_mode="human")

# Wrap the loaded network in a new A2C policy; no optimizer is needed
# because we only evaluate here (learn is never called)
loaded_policy = A2CPolicy(model=loaded_net, optim=None, action_space=action_space, gamma=0.99)
```


Now let's run our agent in the environment. Note: you can change the number of episodes to watch!

```python
# Set the number of episodes you want to watch
num_episodes = 1

for episode in range(num_episodes):
    obs, _ = eval_env.reset()
    terminated = truncated = False
    total_reward = 0

    print(f"Starting episode {episode + 1}")

    # The episode ends when the pole falls (terminated) or the time limit is hit (truncated)
    while not (terminated or truncated):
        # Create a batch for the current observation
        obs_batch = Batch(obs=[obs])

        # Sample an action from the actor's action distribution (no learning happens here)
        action = loaded_policy.forward(obs_batch).act[0]

        # Step the environment with the selected action
        obs, reward, terminated, truncated, _ = eval_env.step(action)
        total_reward += reward

    print(f"Episode {episode + 1} ended with total reward: {total_reward}")

# Close the environment after finishing all episodes
eval_env.close()
```
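The loop above samples each action from the actor's distribution, so even a trained agent may occasionally pick a less likely action. To watch only its single most likely action, one option is to take the argmax of the logits. This is an alternative to `loaded_policy.forward`, not part of the original code. The sketch shows the difference on made-up logits.

```python
import torch

torch.manual_seed(0)

# Made-up logits for one observation with two actions
logits = torch.tensor([[0.2, 2.0]])

# Stochastic choice, as in A2CPolicy.forward
sampled = torch.distributions.Categorical(logits=logits).sample()

# Deterministic choice: always the action with the largest logit
greedy = torch.argmax(logits, dim=-1)

print(sampled.item(), greedy.item())
```

In the evaluation loop, you could replace the `loaded_policy.forward(...)` line with `logits, _ = loaded_net(np.array([obs]))` followed by `action = int(torch.argmax(logits, dim=-1)[0])`.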

**Things to try**

Try changing the environment or changing the hyperparameters:

* **Learning rate** (how fast to learn from new data): too high and updates can become unstable, with the agent learning spurious correlations between actions and outcomes; too low and it might take the agent forever to figure out which actions lead to good rewards.

* **Discount factor** or **gamma** (how much the agent values future vs. near rewards): too high and the agent weighs distant, uncertain rewards almost as much as immediate ones, which can slow learning; too low and the agent might focus too much on the short term and miss longer-term outcomes.
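To get a feel for gamma, here is a small illustrative calculation in plain Python. It computes the discounted return for a CartPole-style reward of 1 per step over a 500-step episode. With discount factor γ, the return is roughly capped at 1/(1 − γ), so γ sets how far ahead the agent effectively looks.

```python
def discounted_return(rewards, gamma):
    """Sum of gamma**t * r_t over the episode."""
    return sum((gamma ** t) * r for t, r in enumerate(rewards))


rewards = [1.0] * 500  # e.g., +1 per step for a 500-step episode

for gamma in (0.9, 0.99, 0.999):
    g = discounted_return(rewards, gamma)
    print(f"gamma={gamma}: return={g:.1f}, 1/(1-gamma)={1 / (1 - gamma):.0f}")
```

With γ = 0.9 the agent effectively looks about 10 steps ahead, while γ = 0.99 spreads its attention over roughly 100 steps.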

Try altering some of these hyperparameters and see how that changes the ability of your agent to learn! Which hyperparameters work best?

