# ABE Tutorial 3  
## Actor-Critic Based Reinforcement Learning

In this tutorial, we dive deeper into the reinforcement learning (RL) algorithm used in previous tutorials. We focus on the Actor-Critic method, which is a hybrid approach that combines ideas from value-based methods and policy-based methods. The actor network selects actions, while the critic network evaluates states, allowing the agent to learn both a policy and a value function simultaneously.

**Key components:**

- **Actor:**  
  Takes the current state as input and produces raw scores (logits) for each possible action. When these logits are passed through a softmax, they yield a probability distribution over actions. The actor is responsible for learning the policy that selects actions expected to yield high rewards.

- **Critic:**  
  Evaluates the current state by outputting a single scalar value, representing the expected cumulative reward (or return) if the agent follows its policy from that state onward. This estimate is used as a baseline to compute the advantage, which guides the actor's updates.

This architecture is used in many state-of-the-art deep RL methods, because the critic’s evaluation stabilizes the policy updates by reducing variance in the gradient estimates.
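
To make the actor's output concrete, here is a minimal sketch in plain Python (no PyTorch needed) of how a softmax turns logits into action probabilities. The function name `softmax` and the example logits are assumptions made for illustration; they are not part of the tutorial code.

```python
import math


def softmax(logits):
    """Convert raw action scores (logits) into probabilities that sum to 1."""
    # Subtract the largest logit first; this avoids overflow in exp()
    # and does not change the result.
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for CartPole's two actions (push left, push right).
probs = softmax([0.5, 1.5])
print(probs)
```

The action with the larger logit gets the larger probability, but the other action keeps a non-zero chance of being selected.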

In this tutorial, you will learn how to:
- Build an Actor-Critic neural network using PyTorch.
- Create a custom policy with TD-learning.
- Train and evaluate the agent in a Gymnasium environment (CartPole-v1).

***
## **Walkthrough**: Actor-Critic Neural Network Model

### 1. Model Architecture

One of the main differences between the SARSA algorithm from the last tutorial and the Actor-Critic model we'll build here is that the Actor-Critic architecture has **two specialized components that work in tandem**:

1. **Actor Branch**  
   - **Input:**         The current state.  
   - **Output:**        Raw scores (logits) for each possible action.  
   - **Processing:**    A softmax converts these logits into a probability distribution over actions.  
   - **Role:**          Learns the policy by favoring actions that are likely to yield higher rewards.

2. **Critic Branch**  
   - **Input:**     The same state.  
   - **Output:**    A single scalar value estimating the expected return (or state value).  
   - **Role:**      Provides a baseline to compute the advantage, i.e., the difference between the actual returns and the estimated value, which guides the policy update.

Together, these branches enable the agent to both choose actions and assess their quality.

Just as with the SARSA code, let's create a new class:

```python
    import torch
    import torch.nn as nn
    import numpy as np


    class ActorCriticNet(nn.Module):
        """
        Actor-Critic network that outputs both the action logits (for the actor)
        and the state value (for the critic).
        """

        def __init__(self, state_shape, action_shape, hidden_size=128):
            """
            Initialize the ActorCriticNet.

            Parameters:
                state_shape (tuple): Shape of the state space.
                action_shape (int or tuple): Number of possible actions.
                hidden_size (int): Number of units in hidden layers.
            """
            super().__init__()

            # Build the actor network: maps state -> action logits.

            # Build the critic network: maps state -> scalar value.

        def forward(self, obs, state=None, info={}):
            """
            Perform a forward pass of the network.

            Parameters:
                obs (np.ndarray or torch.Tensor): Input observation.
                state (optional): Not used here (for compatibility with recurrent models).
                info (dict, optional): Additional information (unused).

            Returns:
                tuple: (action_logits, state_value)
                    - action_logits (torch.Tensor): Logits for actions.
                    - state_value (torch.Tensor): Estimated value of the state.
            """
```

Now that we know the parts, let's fill in the code. Below we define the network layers for the actor and critic models.

```python
    import torch
    import torch.nn as nn
    import numpy as np


    class ActorCriticNet(nn.Module):
        """
        Actor-Critic network that outputs both the action logits (for the actor)
        and the state value (for the critic).
        """

        def __init__(self, state_shape, action_shape, hidden_size=128):
            """
            Initialize the ActorCriticNet.

            Parameters:
                state_shape (tuple): Shape of the state space.
                action_shape (int or tuple): Number of possible actions.
                hidden_size (int): Number of units in hidden layers.
            """
            super().__init__()

            # Build the actor network: maps state -> action logits.
            self.actor = nn.Sequential(
                nn.Linear(np.prod(state_shape), hidden_size),
                nn.ReLU(),
                nn.LayerNorm(hidden_size),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.LayerNorm(hidden_size),
                nn.Linear(hidden_size, np.prod(action_shape))
            )

            # Build the critic network: maps state -> scalar value.
            self.critic = nn.Sequential(
                nn.Linear(np.prod(state_shape), hidden_size),
                nn.ReLU(),
                nn.LayerNorm(hidden_size),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.LayerNorm(hidden_size),
                nn.Linear(hidden_size, 1)
            )
```
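
To get a feel for the size of this network, the following sketch counts its trainable parameters by hand. It assumes CartPole-v1 (4 observation values, 2 actions) and the default `hidden_size=128`; the helper names are hypothetical and only mirror the layer layout shown above.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def layer_norm_params(n_features):
    """One scale and one shift parameter per feature."""
    return 2 * n_features


def branch_params(n_in, hidden, n_out):
    """Parameters of one branch: Linear -> ReLU -> LayerNorm, twice, then Linear."""
    return (
        linear_params(n_in, hidden) + layer_norm_params(hidden)
        + linear_params(hidden, hidden) + layer_norm_params(hidden)
        + linear_params(hidden, n_out)
    )


# CartPole-v1: 4 observation values, 2 discrete actions; hidden_size=128 as above.
actor_params = branch_params(4, 128, 2)
critic_params = branch_params(4, 128, 1)
print(actor_params, critic_params, actor_params + critic_params)
# → 17922 17793 35715
```

Because the actor and critic are separate branches, the two differ only in their final layer, and almost all parameters sit in the hidden 128 x 128 layers.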

Next, let's build the method that returns predictions from the actor and critic networks:

```python
        def forward(self, obs, state=None, info={}):
            """
            Perform a forward pass of the network.

            Parameters:
                obs (np.ndarray or torch.Tensor): Input observation.
                state (optional): Not used here (for compatibility with recurrent models).
                info (dict, optional): Additional information (unused).

            Returns:
                tuple: (action_logits, state_value)
                    - action_logits (torch.Tensor): Logits for actions.
                    - state_value (torch.Tensor): Estimated value of the state.
            """
            # Convert observations to a float tensor on the network's device.
            device = next(self.parameters()).device
            obs = torch.as_tensor(obs, dtype=torch.float32, device=device)
            action_logits = self.actor(obs)
            state_value = self.critic(obs).squeeze(-1)
            return action_logits, state_value
```

### 2. Custom Policy Definition

In deep RL, the policy determines the action selection process. Here we implement an Advantage Actor-Critic (A2C) policy:
- **Forward Method:** Computes action probabilities using the actor network.
- **Learn Method:** Updates both actor and critic networks using TD-learning.  

The critic computes a TD target and the advantage, which is used to update the actor via policy gradients.

The advantage is calculated as:  
$$
\text{advantage} = \text{TD target} - \text{state value},
$$  
where the TD target incorporates the observed reward and the discounted value of the next state. This formulation reduces the variance in policy updates.
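
As a worked example of this formula, the sketch below computes the TD target and the advantage for a single transition. The function names and the numbers are hypothetical; `done` plays the role of the termination flag `d`, which removes the next-state value when the episode has ended.

```python
def td_target(reward, next_value, done, gamma=0.99):
    """One-step TD target: r + gamma * (1 - d) * V(s')."""
    return reward + gamma * (1.0 - float(done)) * next_value


def advantage(reward, value, next_value, done, gamma=0.99):
    """Advantage: TD target minus the critic's estimate V(s)."""
    return td_target(reward, next_value, done, gamma) - value


# Hypothetical numbers: the critic predicted V(s) = 10 and V(s') = 12.
print(advantage(reward=1.0, value=10.0, next_value=12.0, done=False))
print(advantage(reward=1.0, value=10.0, next_value=12.0, done=True))
```

In the first case the advantage is about 2.88, so the outcome was better than the critic expected. In the second case the episode ended, the target is just the reward, and the advantage is -9.0.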

Below is the initial skeleton for our custom A2C policy.

```python
    import torch
    import torch.nn as nn
    from tianshou.policy import BasePolicy
    from tianshou.data import Batch

    class A2CPolicy(BasePolicy):
        """
        Advantage Actor-Critic (A2C) policy combining actor and critic networks.
        Uses the critic's evaluation to compute an advantage for updating the actor.
        """

        def __init__(self, model, optim, action_space, gamma=0.99):
            """
            Initialize the A2C policy.

            Parameters:
                model (nn.Module): The Actor-Critic network.
                optim (torch.optim.Optimizer): Optimizer for training.
                action_space: Environment's action space.
                gamma (float): Discount factor.
            """

        def forward(self, batch, state=None, **kwargs):
            """
            Compute actions for given observations.

            Parameters:
                batch (Batch): Contains environment observations.
                state (optional): Not used.
                kwargs: Additional parameters.

            Returns:
                Batch: Contains chosen actions and the distribution.
            """

        def learn(self, batch, **kwargs):
            """
            Update the network using TD-learning.

            Steps:
            1. Compute the TD target: r + γ(1-d) V(s')
            2. Calculate the advantage: TD target - current state value.
            3. Compute the actor (policy) loss and critic (value) loss.
            4. Backpropagate and update the network weights.

            Parameters:
                batch (Batch): Batch of transitions.
                kwargs: Additional parameters.

            Returns:
                dict: Loss values for monitoring.
            """
```

#### Initialize the New Policy

Let's look at the initialization of the policy first.

When we initialize the training policy, we provide:
* the *model*, which is used to choose actions given states.
* the *optim* (optimizer), which updates the model; this is how the agent learns.
* *gamma*, the discount factor, which defines how the agent weighs future rewards against immediate rewards.
* *Note:* epsilon is no longer needed. The actor outputs logits that define a probability distribution over actions, and the agent samples its action from that distribution. Exploration is therefore built into the A2C algorithm, and the agent does not need to take separate random actions.

```python
    import torch
    import torch.nn as nn
    from tianshou.policy import BasePolicy
    from tianshou.data import Batch

    class A2CPolicy(BasePolicy):
        """
        Advantage Actor-Critic (A2C) policy combining actor and critic networks.
        Uses the critic's evaluation to compute an advantage for updating the actor.
        """

        def __init__(self, model, optim, action_space, gamma=0.99):
            """
            Initialize the A2C policy.

            Parameters:
                model (nn.Module): The Actor-Critic network.
                optim (torch.optim.Optimizer): Optimizer for training.
                action_space: Environment's action space.
                gamma (float): Discount factor.
            """
            super().__init__(action_space=action_space)
            self.model = model
            self.optim = optim
            self.gamma = gamma
```
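
To see what `gamma` does, here is a small sketch of a discounted return in plain Python. The helper `discounted_return` is hypothetical and is not used by the policy itself; it only illustrates how different discount factors weigh the same sequence of rewards.

```python
def discounted_return(rewards, gamma):
    """Sum of rewards where the reward k steps ahead is weighted by gamma**k."""
    total = 0.0
    # Work backwards so each step only needs one multiply and one add.
    for reward in reversed(rewards):
        total = reward + gamma * total
    return total


rewards = [1.0, 1.0, 1.0]  # CartPole gives +1 for every step the pole stays up
print(discounted_return(rewards, gamma=1.0))   # → 3.0
print(discounted_return(rewards, gamma=0.5))   # → 1.75
print(discounted_return(rewards, gamma=0.0))   # → 1.0
```

With `gamma=0.0` only the immediate reward counts, while values close to 1 (such as the default 0.99) make the agent care about rewards far into the future.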

#### Forward

We now add the forward method, which takes the actor's action logits and uses them to choose which action to take.

* **Note**: The categorical distribution below is a good choice when the actions are discrete. We'll see how to modify this so that we can also work with continuous action spaces.


```python
        def forward(self, batch, state=None, **kwargs):
            """
            Compute actions for given observations.

            Parameters:
                batch (Batch): Contains environment observations.
                state (optional): Not used.
                kwargs: Additional parameters.

            Returns:
                Batch: Contains chosen actions and the distribution.
            """
            logits, _ = self.model(batch.obs)
            # Create a categorical distribution (suitable for discrete actions).
            dist = torch.distributions.Categorical(logits=logits)
            action = dist.sample()
            return Batch(act=action.cpu().numpy(), dist=dist)
```
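
The sampling step above is also where exploration comes from. The following plain-Python sketch imitates what sampling from a categorical distribution does, using only the standard library; `sample_action` and the example logits are assumptions for illustration, not the PyTorch implementation.

```python
import math
import random


def sample_action(logits, rng):
    """Sample an action index from softmax(logits); return (action, log_prob)."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return action, math.log(probs[action])


rng = random.Random(0)  # fixed seed so the sketch is repeatable
logits = [0.0, 1.0]     # hypothetical actor output for one state
counts = [0, 0]
for _ in range(1000):
    action, _ = sample_action(logits, rng)
    counts[action] += 1
print(counts)  # action 1 is chosen more often, but action 0 still gets tried
```

As training sharpens the logits, the distribution concentrates on the preferred action and the agent explores less, without any epsilon schedule.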

#### Learn

Now that we have initialized our policy with a model, an optimizer, and some parameters, and set up how the agent chooses actions using the model, we need to think about how the agent will learn. This is an important step. We'll use TD-learning to update the model so that the critic predicts state values and the actor favors better actions. It does this in a few steps:

* Estimate the current state value and the action probabilities.
* Compute the TD target: the observed reward plus the discounted value estimate of the next state (zero if the episode ended).
* Calculate the advantage: the TD target minus the estimated state value. It measures how much better or worse the outcome was than the critic expected.
* Update the actor by weighting the log-probability of each action taken by its advantage, so that actions with positive advantage become more likely and actions with negative advantage become less likely.
* Update the critic by reducing the squared difference between the estimated state value and the TD target.
* If we do this repeatedly, the actor and critic networks should produce better estimates of action probabilities and state values.

```python
        def learn(self, batch, **kwargs):
            """
            Update the network using TD-learning.

            Steps:
            1. Compute the TD target: r + γ(1-d) V(s')
            2. Calculate the advantage: TD target - current state value.
            3. Compute the actor (policy) loss and critic (value) loss.
            4. Backpropagate and update the network weights.

            Parameters:
                batch (Batch): Batch of transitions.
                kwargs: Additional parameters.

            Returns:
                dict: Loss values for monitoring.
            """
            logits, state_values = self.model(batch.obs)
            dist = torch.distributions.Categorical(logits=logits)
            log_probs = dist.log_prob(batch.act)

            # Compute TD target and advantage (without gradient tracking).
            with torch.no_grad():
                _, next_state_values = self.model(batch.obs_next)
                td_target = batch.rew + self.gamma * (1 - batch.done) * next_state_values
                advantage = td_target - state_values
                # Normalize the advantage for stable learning.
                advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)

            # Compute policy (actor) loss and value (critic) loss.
            policy_loss = -(log_probs * advantage.detach()).mean()
            value_loss = nn.functional.mse_loss(state_values, td_target)
            loss = policy_loss + value_loss

            # Backpropagation with gradient clipping.
            self.optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optim.step()

            return {"loss": loss.item(), "policy_loss": policy_loss.item(), "value_loss": value_loss.item()}
```
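
To check your understanding of the two loss terms, the sketch below reproduces the same arithmetic on a tiny hand-made batch using plain Python lists. The function `a2c_losses` and all numbers are hypothetical; it uses the sample standard deviation (dividing by n - 1), which is PyTorch's default for `std()`.

```python
import math


def a2c_losses(log_probs, values, td_targets):
    """Return (policy_loss, value_loss) for one batch, mirroring the steps in `learn`."""
    n = len(values)
    adv = [t - v for t, v in zip(td_targets, values)]
    # Normalize advantages to zero mean and unit (sample) standard deviation.
    mean = sum(adv) / n
    std = math.sqrt(sum((a - mean) ** 2 for a in adv) / (n - 1))
    adv = [(a - mean) / (std + 1e-8) for a in adv]
    # Actor: raise the log-probability of actions with positive advantage.
    policy_loss = -sum(lp * a for lp, a in zip(log_probs, adv)) / n
    # Critic: mean squared error between V(s) and the TD target.
    value_loss = sum((v - t) ** 2 for v, t in zip(values, td_targets)) / n
    return policy_loss, value_loss


# Hypothetical batch of three transitions.
policy_loss, value_loss = a2c_losses(
    log_probs=[-0.7, -0.3, -1.2],
    values=[10.0, 9.0, 8.0],
    td_targets=[11.0, 9.0, 7.0],
)
print(policy_loss, value_loss)
```

Here the advantages are 1, 0, and -1, so the policy loss comes out near -0.167 and the value loss near 0.667. Note that the policy loss can be negative; what matters is the direction of its gradient, not its sign.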

Now that we've defined the policy and an actor/critic model, let's take a look at the full code!

***
## **Full Implementation and Testing**: Actor-Critic Reinforcement Learning

Now that we've gone through the code piece by piece, let's put it all together.

This section combines the network and policy definitions, sets up the environment, runs a training loop, saves the model, and finally tests the trained agent.

### 1. Preliminary Code

In [None]:
# Standard library imports
from datetime import datetime
import os
import shutil
import subprocess
import tempfile
import time

# Third-party imports
import gymnasium as gym
import pygame
import torch
from IPython.display import IFrame, display
from torch.utils.tensorboard import SummaryWriter

# Local application/library-specific imports
import tianshou as ts
from tianshou.data import Collector, ReplayBuffer
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils.net.common import Net

# Timestamped ID for this run to avoid overwriting previous runs and to keep track of different runs
agent_id = datetime.now().strftime("%Y%m%d_%H%M%S")  # Format: YYYYMMDD_HHMMSS

# Setup directories for saving models and logs
actor_critic_dir = f"actor_critic_cartpole{agent_id}"
logs_dir = os.path.join(actor_critic_dir, "logs")
models_dir = os.path.join(actor_critic_dir, "models")

os.makedirs(actor_critic_dir, exist_ok=True) # Ensure the directory exists
print(f"All files for this run will be saved in the directory: {actor_critic_dir}")
os.makedirs(logs_dir, exist_ok=True) # Ensure the directory exists
print(f"Tensorboard logs will be saved in the directory: {logs_dir}")
os.makedirs(models_dir, exist_ok=True) # Ensure the directory exists
print(f"Models will be saved in the directory: {models_dir}")

# Create a logger
logger = ts.utils.TensorboardLogger(SummaryWriter(logs_dir))
print(f"TensorBoard logs are being saved in: {logs_dir}")

### 2. Actor-Critic Neural Network (`ActorCriticNet`) Full Code

In [None]:
import torch
import torch.nn as nn
import numpy as np


class ActorCriticNet(nn.Module):
    """
    Actor-Critic network that outputs both the action logits (for the actor)
    and the state value (for the critic).

    The network consists of two branches:
    - Actor: Predicts the logits for each action.
    - Critic: Estimates the scalar value of the current state.
    """

    def __init__(self, state_shape, action_shape, hidden_size=128):
        """
        Initialize the ActorCriticNet.

        Parameters:
            state_shape (tuple): The shape of the state space.
            action_shape (int or tuple): The number of possible actions.
            hidden_size (int): Number of units in the hidden layers.
        """
        super().__init__()

        # Build the actor network
        self.actor = nn.Sequential(
            nn.Linear(np.prod(state_shape), hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, np.prod(action_shape))
        )

        # Build the critic network
        self.critic = nn.Sequential(
            nn.Linear(np.prod(state_shape), hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, obs, state=None, info={}):
        """
        Perform a forward pass of the network.

        Parameters:
            obs (np.ndarray or torch.Tensor): Input observation from the environment.
            state (optional): Unused parameter for compatibility with recurrent models.
            info (dict, optional): Additional information (unused).

        Returns:
            tuple: (action_logits, state_value)
                - action_logits (torch.Tensor): Logits for each action from the actor network.
                - state_value (torch.Tensor): Estimated value of the state from the critic network.
        """
        # Ensure the observation is a float tensor on the network's device.
        device = next(self.parameters()).device
        obs = torch.as_tensor(obs, dtype=torch.float32, device=device)
        action_logits = self.actor(obs)
        state_value = self.critic(obs).squeeze(-1)
        return action_logits, state_value

### 3. Advantage Actor-Critic (A2C) Policy (`A2CPolicy`) Full Code

In [None]:
import torch
import torch.nn as nn
from tianshou.policy import BasePolicy
from tianshou.data import Batch


class A2CPolicy(BasePolicy):
    """
    Advantage Actor-Critic (A2C) policy that combines actor and critic networks.
    It samples actions based on the actor's logits and updates both networks using TD-learning.
    """

    def __init__(self, model, optim, action_space, gamma=0.99):
        """
        Initialize the A2C policy.

        Parameters:
            model (nn.Module): The Actor-Critic network.
            optim (torch.optim.Optimizer): Optimizer for training the network.
            action_space: The action space from the environment.
            gamma (float): Discount factor for future rewards.
        """
        super().__init__(action_space=action_space)
        self.model = model
        self.optim = optim
        self.gamma = gamma

    def forward(self, batch, state=None, **kwargs):
        """
        Compute the action for given observations.

        Parameters:
            batch (Batch): Contains observations from the environment.
            state (optional): Not used in this implementation.
            kwargs: Additional parameters.

        Returns:
            Batch: Contains the chosen actions and the associated distribution.
        """
        logits, _ = self.model(batch.obs)
        # Create a categorical distribution from the logits.
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        return Batch(act=action.cpu().numpy(), dist=dist)

    def learn(self, batch, **kwargs):
        """
        Update the model based on a batch of experience.

        The update involves:
          1. Computing the TD target and advantage.
          2. Calculating the loss for both actor and critic.
          3. Performing a gradient descent step with gradient clipping.

        Parameters:
            batch (Batch): Batch of transitions containing observations, actions, rewards, etc.
            kwargs: Additional parameters.

        Returns:
            dict: Contains overall loss, actor loss (policy_loss), and critic loss (value_loss).
        """
        # Forward pass to get logits and state values.
        logits, state_values = self.model(batch.obs)
        dist = torch.distributions.Categorical(logits=logits)
        log_probs = dist.log_prob(batch.act)

        # Compute the TD target and advantage using no gradient tracking.
        with torch.no_grad():
            _, next_state_values = self.model(batch.obs_next)
            td_target = batch.rew + self.gamma * (1 - batch.done) * next_state_values
            advantage = td_target - state_values
            # Normalize the advantage for stable training.
            advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)

        # Calculate the actor loss (policy loss) to encourage actions with higher advantage.
        policy_loss = -(log_probs * advantage.detach()).mean()
        # Calculate the critic loss as the mean squared error between predicted and target values.
        value_loss = nn.functional.mse_loss(state_values, td_target)
        # Combine both losses.
        loss = policy_loss + value_loss

        # Perform backpropagation.
        self.optim.zero_grad()
        loss.backward()
        # Clip gradients to avoid exploding gradients.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optim.step()

        return {"loss": loss.item(), "policy_loss": policy_loss.item(), "value_loss": value_loss.item()}
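
The gradient clipping step in `learn` can be pictured with a few lines of plain Python. This is only a sketch of the idea of clipping by global norm, with a hypothetical helper `clip_by_global_norm`; PyTorch's `clip_grad_norm_` works on tensors and differs in implementation details.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients by one factor so their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads), total_norm
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm


# Hypothetical gradient values flattened into one list.
clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm_before)  # → 5.0
print(clipped)
```

The clipped gradient keeps its direction (roughly 0.6 and 0.8 here) but has length 1, which limits how far a single noisy batch can move the weights.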

### 4. Create the Environment

**Recall**:

Here, we extract:
- *State Shape:* The dimensionality of the observation space.
- *Action Shape:* The number of discrete actions.
- *Action Space:* For later use in the policy definition.

In [None]:
import gymnasium as gym

# Create the environment instance.
env = gym.make("CartPole-v1")

# Extract the shape of the observation space.
state_shape = env.observation_space.shape

# Extract the number of actions.
action_shape = env.action_space.n

# Store the action space for later use.
action_space = env.action_space

### 5. Creating the Actor-Critic Network and Policy

**Recall**:

Now we instantiate our Actor-Critic network and set up the optimizer. The network is then used to initialize the A2C policy, which will drive our agent’s learning. Notice how we pass the environment's action space for compatibility.

In [None]:
import torch.optim as optim

# Select the appropriate device: CUDA (NVIDIA GPUs), MPS (Apple GPUs), or CPU.
# AMD GPUs with ROCm support are accessed using the 'cuda' device string.
device = torch.device("cuda" if torch.cuda.is_available() else
                      "mps" if torch.backends.mps.is_available() else
                      "cpu")
print(f"Using device: {device}")

# Instantiate the Actor-Critic network.
net = ActorCriticNet(state_shape, action_shape).to(device)

# Define the optimizer with a small learning rate and weight decay.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)

# Create the A2C policy using the network and optimizer.
policy = A2CPolicy(model=net, optim=optimizer, action_space=action_space, gamma=0.99)

### 6. Training the Agent

**Recall**:

The training loop uses Tianshou’s components to collect experience and update the policy. We use:
- *ReplayBuffer:* To store recent transitions.
- *Collector:* To gather data from interactions with the environment.
- *TensorBoard Logger:* For step-level and epoch-level metric tracking.

During training, each step collects transitions and performs a TD-learning update.

The following code implements the training loop:

In [None]:
import os
import shutil
import subprocess
import tempfile
import time

import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import ReplayBuffer, Collector, Batch
import tianshou as ts
from tqdm.notebook import tqdm
from IPython.display import IFrame, display

def kill_port(port):
    """
    Terminates any processes that are listening on the specified port.
    Works on both Unix-based systems and Windows.
    """
    try:
        if os.name == 'nt':
            # Windows: Use netstat and taskkill to kill processes on the given port.
            # The command below might fail (exit status 1) if no process is found.
            cmd = f'for /f "tokens=5" %a in (\'netstat -aon ^| findstr :{port}\') do taskkill /F /PID %a'
            result = subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            print(f"Killed processes on port {port}.")
        else:
            # Unix (Linux/Mac): Use lsof to find processes on the port and kill them.
            cmd = f"lsof -ti:{port} | xargs kill -9"
            result = subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            print(f"Killed processes on port {port}.")
    except subprocess.CalledProcessError as e:
        # If the error message indicates that no process was found, we can ignore it.
        if "returned non-zero exit status 1" in str(e):
            pass
        else:
            print(f"Could not kill process on port {port}: {e}")

# Kill any processes on port 6006 to ensure it is free.
kill_port(6006)

# Clear previous TensorBoard sessions (cross-platform)
tensorboard_info = os.path.join(tempfile.gettempdir(), ".tensorboard-info")
if os.path.exists(tensorboard_info):
    shutil.rmtree(tensorboard_info)

# Launch TensorBoard in the background on port 6006.
tb_command = [
    "tensorboard",
    "--logdir", logs_dir,
    "--port", "6006",
    "--host", "localhost",
    "--reload_interval", "30"
]
tb_process = subprocess.Popen(tb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

# Allow time for TensorBoard to start and display its dashboard.
time.sleep(5)
display(IFrame(src="http://localhost:6006", width="100%", height="800"))

#------------------------------------------------------------------------------ 

# Hyperparameters for training.
max_epoch = 3               # Total number of epochs for training.
steps_per_epoch = 1000      # Number of training steps per epoch.
keep_n_steps = 30           # Number of recent transitions to use for learning.

# Initialize a ReplayBuffer to store recent transitions.
buffer = ReplayBuffer(size=keep_n_steps)

# Create collectors for training and testing.
train_collector = Collector(policy, env, buffer)
test_collector = Collector(policy, env)

# Lists to store epoch summaries for later analysis.
epoch_training_losses = []
epoch_test_rewards = []
epoch_durations = []

global_start_time = time.time()  # Start the overall training timer.

# Training loop with comprehensive progress tracking and logging.
for epoch in range(max_epoch):
    epoch_start_time = time.time()  # Timer for the current epoch.
    train_collector.reset()         # Reset collector at the start of each epoch.
    running_loss = 0.0              # Accumulate loss to compute the average loss.

    # Set up a tqdm progress bar with dynamic post-fix metrics.
    progress_bar = tqdm(
        range(steps_per_epoch),
        desc=f"Epoch {epoch+1}/{max_epoch}",
        dynamic_ncols=True
    )
    
    for step in progress_bar:
        # Collect a fixed number of steps and store transitions in the buffer.
        train_collector.collect(n_step=keep_n_steps)
        # Retrieve the most recent transitions.
        batch = train_collector.buffer[-keep_n_steps:]
        
        # Convert batch fields to torch tensors on the same device as the network.
        batch.obs = torch.tensor(batch.obs, dtype=torch.float32, device=device)
        batch.act = torch.tensor(batch.act, dtype=torch.long, device=device)
        batch.rew = torch.tensor(batch.rew, dtype=torch.float32, device=device)
        batch.done = torch.tensor(batch.done, dtype=torch.float32, device=device)
        batch.obs_next = torch.tensor(batch.obs_next, dtype=torch.float32, device=device)
        
        # Update the policy using the collected batch and capture the loss.
        learn_info = policy.learn(batch)
        loss_val = learn_info.get("loss", 0)
        running_loss += loss_val
        
        global_step = epoch * steps_per_epoch + step
        
        # Log step-level metrics to TensorBoard.
        logger.writer.add_scalar("Loss/train_step", loss_val, global_step)
        logger.writer.add_scalar("Loss/train_running_avg", running_loss / (step + 1), global_step)
        
        # Flush logs periodically.
        if step % 50 == 0:
            logger.writer.flush()
        
        # Update progress bar with current metrics.
        progress_bar.set_postfix({
            "Step": f"{step}/{steps_per_epoch}",
            "Loss": f"{loss_val:07.3f}",
            "AvgLoss": f"{running_loss / (step + 1):07.3f}"
        })
        
        # Print summary at every 25% of the epoch.
        if step % (steps_per_epoch // 4) == 0:
            print(
                f"Epoch {epoch+1}, Step {step}/{steps_per_epoch}: "
                f"Step Loss = {loss_val}, Running Avg Loss = {running_loss / (step + 1)}"
            )
    
    # Compute average loss over the epoch.
    avg_loss = running_loss / steps_per_epoch

    # Reset the test collector and evaluate the agent on 10 episodes.
    test_collector.reset()
    test_result = test_collector.collect(n_episode=10)
    mean_reward = np.mean(test_result["rews"])
    std_reward = np.std(test_result["rews"])
    min_reward = np.min(test_result["rews"])
    p25_reward = np.percentile(test_result["rews"], 25)
    median_reward = np.median(test_result["rews"])
    p75_reward = np.percentile(test_result["rews"], 75)
    max_reward = np.max(test_result["rews"])

    # Log epoch-level metrics to TensorBoard.
    logger.writer.add_scalar("Reward/test_avg", mean_reward, epoch)
    logger.writer.add_scalar("Loss/train_avg", avg_loss, epoch)
    logger.writer.flush()

    # Calculate epoch elapsed time.
    epoch_elapsed = time.time() - epoch_start_time
    epoch_training_losses.append(avg_loss)
    epoch_test_rewards.append(mean_reward)
    epoch_durations.append(epoch_elapsed)

    # Print detailed epoch summary.
    print(
        f"\nEpoch {epoch+1} Summary:\n"
        f"  - Epoch Elapsed Time      : {epoch_elapsed} seconds\n"
        f"  - Transitions Collected   : {steps_per_epoch * keep_n_steps}\n"
        f"  - Average Training Loss   : {avg_loss}\n"
        f"  - Mean Test Reward        : {mean_reward}\n"
        f"  - Std Test Reward         : {std_reward}\n"
        f"  - Min Test Reward         : {min_reward}\n"
        f"  - 25th Percentile Reward  : {p25_reward}\n"
        f"  - Median Test Reward      : {median_reward}\n"
        f"  - 75th Percentile Reward  : {p75_reward}\n"
        f"  - Max Test Reward         : {max_reward}\n"
    )

# Final flush and close the TensorBoard writer.
logger.writer.close()

# Calculate overall training statistics.
total_elapsed = time.time() - global_start_time
overall_avg_loss = np.mean(epoch_training_losses)
overall_avg_reward = np.mean(epoch_test_rewards)
total_epochs = len(epoch_durations)

# Print overall training summary.
print("\nOverall Training Summary:")
print(f"  - Total Epochs            : {total_epochs}")
print(f"  - Overall Average Loss    : {overall_avg_loss}")
print(f"  - Overall Average Reward  : {overall_avg_reward}")
print(f"  - Total Elapsed Time      : {total_elapsed} seconds")

# Reiterate the final epoch's metrics.
print("\nFinal Epoch Summary:")
print(
    f"  - Epoch {total_epochs}:\n"
    f"      * Average Training Loss : {epoch_training_losses[-1]}\n"
    f"      * Average Test Reward   : {epoch_test_rewards[-1]}\n"
    f"      * Epoch Elapsed Time    : {epoch_durations[-1]} seconds\n"
)

Did the agent learn? Do the test rewards increase from epoch to epoch?
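
One way to answer this numerically rather than by eye is to fit a straight line through the per-epoch test rewards and smooth them with a moving average. The sketch below is a minimal illustration: the helper `reward_trend` is a hypothetical name, and the reward values are placeholders rather than real results, so substitute your own `epoch_test_rewards` list.

```python
import numpy as np

def reward_trend(rewards, window=3):
    """Return (slope, moving_average) for a sequence of per-epoch rewards."""
    rewards = np.asarray(rewards, dtype=float)
    epochs = np.arange(len(rewards))
    # Least-squares line through the rewards; a positive slope suggests improvement.
    slope, _intercept = np.polyfit(epochs, rewards, deg=1)
    # A moving average smooths out epoch-to-epoch noise.
    kernel = np.ones(window) / window
    moving_avg = np.convolve(rewards, kernel, mode="valid")
    return slope, moving_avg

# Placeholder values for illustration only; replace with epoch_test_rewards.
example_rewards = [20.0, 35.0, 30.0, 60.0, 90.0, 85.0, 140.0]
slope, moving_avg = reward_trend(example_rewards)
print(f"Slope per epoch: {slope:.2f}")
print(f"Moving average : {np.round(moving_avg, 1)}")
```

A positive slope together with a rising moving average suggests the agent is learning.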

If so, let's save the model's state dictionary. This allows you to reload the trained agent later without retraining.

In [None]:
import os
import torch

# Save the model's state dictionary for future use.
model_path = os.path.join(models_dir, f"{actor_critic_dir}.pth")
torch.save(net.state_dict(), model_path)
print(f"Model saved to {model_path}")
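
To confirm that a saved state dictionary restores the same weights, you can run a quick round-trip check. The sketch below uses a small stand-in `nn.Sequential` network and a temporary file (both are assumptions for illustration); the tutorial's `ActorCriticNet` and `model_path` would take their place.

```python
import os
import tempfile

import torch
import torch.nn as nn

def make_net():
    # Stand-in network for illustration; not the tutorial's ActorCriticNet.
    return nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))

original = make_net()
restored = make_net()  # Fresh instance with different random weights.

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.pth")
    torch.save(original.state_dict(), path)
    restored.load_state_dict(torch.load(path, map_location="cpu"))

# Every parameter tensor should now match exactly.
weights_match = all(
    torch.equal(p, q)
    for p, q in zip(original.state_dict().values(), restored.state_dict().values())
)
print(f"Weights match after reload: {weights_match}")
```

Note that a state dictionary stores only the weights, so the network class must be available and constructed with the same shapes before loading.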

### 7. Evaluating the Trained Agent

Let's test the model and watch what it has learned.

First, load the trained model.

In [None]:
import os
import torch

model_path = os.path.join(models_dir, f"{actor_critic_dir}.pth")

# Select the appropriate device: CUDA (NVIDIA GPUs), MPS (Apple GPUs), or CPU.
# AMD GPUs with ROCm support are also accessed using the 'cuda' device string.
device = torch.device("cuda" if torch.cuda.is_available() else
                      "mps" if torch.backends.mps.is_available() else
                      "cpu")
print(f"Using device: {device}")

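# Rebuild the network and load the trained weights.
loaded_net = ActorCriticNet(state_shape, action_shape).to(device)
loaded_net.load_state_dict(torch.load(model_path, map_location=device))
loaded_net.eval()  # Inference mode; matters for layers such as dropout or batch norm.
print("Model loaded successfully!")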

Let's create an environment and build a policy based on our saved model.

In [None]:
import gymnasium as gym

# Create an evaluation environment with rendering enabled.
eval_env = gym.make("CartPole-v1", render_mode="human")

# Create a new policy using the loaded network.
# The optimizer is passed only to satisfy the constructor; no learning happens during evaluation.
loaded_policy = A2CPolicy(model=loaded_net, optim=optimizer, action_space=action_space, gamma=0.99)

Now let's run our agent in the environment.

**Note**: You can change the number of episodes to watch!

In [None]:
import numpy as np
import torch
from tianshou.data import Batch

num_episodes = 20     # Number of evaluation episodes.
episode_rewards = []  # To store the total reward of each episode.
episode_lengths = []  # To store the number of steps in each episode.

for episode in range(num_episodes):
    obs, _ = eval_env.reset()
    done = False
    total_reward = 0
    step_count = 0
    print(f"Starting evaluation of episode {episode + 1}")

    while not done:
        step_count += 1
        # Create a batch from the current observation.
        obs_batch = Batch(obs=[obs])
        
        # Get action from the policy (exploitation mode, no exploration noise).
        with torch.no_grad():
            action = loaded_policy.forward(obs_batch).act[0]
        
        # Take the action in the environment.
        obs, reward, done, truncated, _ = eval_env.step(action)
        total_reward += reward

        # End the episode if finished.
        if done or truncated:
            print(f"Episode {episode + 1} ended with total reward: {total_reward} after {step_count} steps.")
            break

    episode_rewards.append(total_reward)
    episode_lengths.append(step_count)

# Convert lists to numpy arrays for statistical calculations.
episode_rewards = np.array(episode_rewards)
episode_lengths = np.array(episode_lengths)

# Compute and print comprehensive performance statistics.
if episode_rewards.size > 0 and episode_lengths.size > 0:
    count_rewards = len(episode_rewards)
    mean_rewards = np.mean(episode_rewards)
    std_rewards = np.std(episode_rewards)
    min_rewards = np.min(episode_rewards)
    p25_rewards = np.percentile(episode_rewards, 25)
    median_rewards = np.median(episode_rewards)
    p75_rewards = np.percentile(episode_rewards, 75)
    max_rewards = np.max(episode_rewards)
    
    count_lengths = len(episode_lengths)
    mean_lengths = np.mean(episode_lengths)
    std_lengths = np.std(episode_lengths)
    min_lengths = np.min(episode_lengths)
    p25_lengths = np.percentile(episode_lengths, 25)
    median_lengths = np.median(episode_lengths)
    p75_lengths = np.percentile(episode_lengths, 75)
    max_lengths = np.max(episode_lengths)
    
    print("\nFinal Evaluation Performance Summary:")
    print(f"Total Episodes Evaluated: {num_episodes}\n")
    header = "{:<22} {:>15} {:>20}".format("Statistic", "Rewards", "Episode Lengths")
    print(header)
    print("-" * len(header))
    print("{:<22} {:>15d} {:>20d}".format("Count", count_rewards, count_lengths))
    print("{:<22} {:>15.2f} {:>20.2f}".format("Mean", mean_rewards, mean_lengths))
    print("{:<22} {:>15.2f} {:>20.2f}".format("Std Dev", std_rewards, std_lengths))
    print("{:<22} {:>15.2f} {:>20.2f}".format("Min", min_rewards, min_lengths))
    print("{:<22} {:>15.2f} {:>20.2f}".format("25th Percentile", p25_rewards, p25_lengths))
    print("{:<22} {:>15.2f} {:>20.2f}".format("Median", median_rewards, median_lengths))
    print("{:<22} {:>15.2f} {:>20.2f}".format("75th Percentile", p75_rewards, p75_lengths))
    print("{:<22} {:>15.2f} {:>20.2f}".format("Max", max_rewards, max_lengths))
else:
    print("No performance data was collected. Please check the evaluation loop above.")

# Close the environment after evaluation.
eval_env.close()
print("Evaluation completed and environment closed.")
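
How an action is picked from the actor's logits affects what you see during evaluation. The sketch below contrasts a greedy choice (argmax) with sampling from the softmax distribution. The logits are hypothetical values for a single state, and whether the tutorial's `A2CPolicy` samples or takes the argmax depends on how its `forward` method was written earlier.

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical actor output for one CartPole state: one logit per action.
logits = torch.tensor([[0.2, 1.5]])

# Softmax turns the logits into a probability distribution over actions.
probs = torch.softmax(logits, dim=-1)

# Greedy (deterministic) choice: always the highest-scoring action.
greedy_action = int(torch.argmax(logits, dim=-1).item())

# Stochastic choice: sample actions from the policy's distribution.
dist = Categorical(logits=logits)
sampled_actions = dist.sample((1000,)).squeeze(-1)
empirical_freq = torch.bincount(sampled_actions, minlength=2).float() / 1000

print(f"Action probabilities : {probs.squeeze(0).tolist()}")
print(f"Greedy action        : {greedy_action}")
print(f"Sampled frequencies  : {empirical_freq.tolist()}")
```

Sampling keeps some randomness in the agent's behaviour, while the greedy choice always commits to the action the actor currently rates highest.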

***
**Things to Try**

Try changing the environment or changing the hyperparameters:

- **Learning Rate:** Experiment with different learning rates. A rate that is too high may cause unstable updates; a rate that is too low may slow down learning.  
- **Discount Factor ($\gamma$):** Adjust the discount factor to balance the importance of immediate versus future rewards.  
- **Network Architecture:** Modify the number of layers or hidden units to observe how it affects performance.  
- **Environment Variations:** Test the agent in other environments to assess the robustness of the algorithm.

Try altering some of these hyperparameters and see how that changes the ability of your agent to learn! Which hyperparameters work best?
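
To build intuition for the discount factor before changing it, the sketch below computes the discounted return $G = \sum_{t} \gamma^{t} r_{t}$ for an episode with a reward of 1 at every step, as in CartPole. The helper `discounted_return` and the 200-step episode length are illustrative assumptions.

```python
def discounted_return(rewards, gamma):
    """Compute sum_t gamma**t * rewards[t] by accumulating backwards."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g

# Illustrative 200-step episode with a reward of 1 per step.
rewards = [1.0] * 200

for gamma in (0.9, 0.99, 0.999):
    g = discounted_return(rewards, gamma)
    print(f"gamma={gamma}: discounted return = {g:.2f}")
```

With constant rewards, the return can never exceed $1 / (1 - \gamma)$, so a smaller $\gamma$ makes the agent effectively short-sighted, while a $\gamma$ close to 1 lets rewards far in the future influence the value estimate.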