# ABE Tutorial 2
## Value Based Reinforcement Learning

In this second tutorial, let's dive deeper into the RL algorithm we used in the first tutorial. This time we'll open up the policy and look at how it learns in more detail.

Steps:
* Create a custom policy
* Check how it updates based on temporal difference learning



## Using a Custom Policy



### 1. Set up a simple RL example

Let's build a simple RL example based on what we learned in the last tutorial.


In [None]:
# Standard library imports
from datetime import datetime
import os
import shutil
import subprocess
import tempfile
import time

# Third-party imports
import gymnasium as gym
import pygame
import torch
from IPython.display import IFrame, display
from torch.utils.tensorboard import SummaryWriter

# Local application/library-specific imports
import tianshou as ts
from tianshou.data import Collector, ReplayBuffer
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils.net.common import Net

In [None]:
# Timestamped ID for this run to avoid overwriting previous runs and to keep track of different runs

agent_id = datetime.now().strftime("%Y%m%d_%H%M%S")  # Format: YYYYMMDD_HHMMSS

In [None]:
# Setup directories for saving models and logs

offline_policy_dir = f"offline_policy_{agent_id}"
logs_dir = os.path.join(offline_policy_dir, "logs")
models_dir = os.path.join(offline_policy_dir, "models")

os.makedirs(offline_policy_dir, exist_ok=True) # Ensure the directory exists
print(f"All files for this run will be saved in the directory: {offline_policy_dir}")
os.makedirs(logs_dir, exist_ok=True) # Ensure the directory exists
print(f"Tensorboard logs will be saved in the directory: {logs_dir}")
os.makedirs(models_dir, exist_ok=True) # Ensure the directory exists
print(f"Models will be saved in the directory: {models_dir}")

In [None]:
# Create a logger

logger = ts.utils.TensorboardLogger(SummaryWriter(logs_dir))
print(f"TensorBoard logs are being saved in: {logs_dir}")

Now let's use the CartPole environment to train our agent, and extract its observation and action spaces.

In [None]:
# Create the environment without opening a window.
env = gym.make("CartPole-v1")

# Reset the environment to obtain the initial observation and return a tuple (observation, info).
obs, info = env.reset()

# Retrieve the observation and action spaces directly from the environment.
observation_space = env.observation_space
action_space = env.action_space

# Get the shape of the observation space (what the agent sees).
state_shape = observation_space.shape

# Determine the action space shape.
# For discrete action spaces, use the number of available actions.
if isinstance(action_space, gym.spaces.Discrete):
    action_shape = action_space.n
else:
    action_shape = action_space.shape

print("State shape:", state_shape)
print("Action shape:", action_shape)

In [None]:
# Print the observation space, which describes what the agent can observe from the environment.
print(f"Observation space: {observation_space}")

# Print the action space, which describes the actions the agent can take in the environment.
print(f"Action space: {action_space}")

Let's start building our agent's brain. 

To start off, let's build a neural network that takes what the agent observes and outputs an estimated value (a Q-value) for each possible action. Then we'll add an optimizer that shifts the network's weights so that it predicts these values more accurately.
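
To make "observations in, action values out" concrete, here is a NumPy sketch of a forward pass through a 4 → 64 → 64 → 2 ReLU network, the same layer sizes as the `Net` built for CartPole in the next cell. The weights are random and the `q_values` helper is hypothetical; Tianshou's `Net` builds its layers with PyTorch, but the shape of the computation is the same.

```python
import numpy as np

# Hypothetical illustration: a 4 -> 64 -> 64 -> 2 ReLU network, the same layer sizes
# as Net(4, 2, hidden_sizes=[64, 64]) for CartPole. Weights are random, not trained.
rng = np.random.default_rng(0)
sizes = [4, 64, 64, 2]
weights = [rng.normal(0.0, 0.1, size=(n_in, n_out)) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros(n_out) for n_out in sizes[1:]]

def q_values(obs: np.ndarray) -> np.ndarray:
    """Map a batch of observations, shape (N, 4), to action values, shape (N, 2)."""
    x = obs
    for i, (w, b) in enumerate(zip(weights, biases)):
        x = x @ w + b
        if i < len(weights) - 1:  # ReLU after hidden layers only; the output layer stays linear
            x = np.maximum(x, 0.0)
    return x

# One CartPole-style observation: cart position, cart velocity, pole angle, pole angular velocity.
obs = np.array([[0.01, -0.02, 0.03, 0.04]])
q = q_values(obs)
greedy_action = int(q.argmax(axis=1)[0])  # the action a greedy agent would pick
print(q.shape, greedy_action)
```

Training changes only the weights; the mapping from a batch of observations to one value per action stays the same.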

In [None]:
# Build the network that maps observations to action values.

# Select the appropriate device: CUDA (NVIDIA GPUs), MPS (Apple GPUs), or CPU.
# AMD GPUs with ROCm support are accessed using the 'cuda' device string.
device = torch.device("cuda" if torch.cuda.is_available() else
                      "mps" if torch.backends.mps.is_available() else
                      "cpu")
print(f"Using device: {device}")

# The Tianshou Net class (from tianshou.utils.net.common) includes several useful parameters:
#   * state_shape: Observation dimensions (passed unpacked here, e.g., 4 for CartPole).
#   * action_shape: Number of possible actions (e.g., 2 for CartPole).
#   * hidden_sizes: List specifying the number of neurons in each hidden layer.
#   * device: Device used by the network (default: 'cpu').
#   * activation: Activation function used between layers (default: torch.nn.ReLU).

net = Net(
    *state_shape,                    # Unpacked input dimensions (state size).
    action_shape,                    # Output size: number of actions.
    hidden_sizes=[64, 64],           # Hidden layer sizes; adjust for desired model capacity.
    device=device,                   # Device selected above.
    # activation=torch.nn.ReLU       # Activation function; change if another function is preferred.
).to(device)                         # Move the network's weights to the same device.

# Print the network architecture to verify its structure.
print("Network architecture:")
print(net)

In [None]:
# Create an Adam optimizer for the network parameters.

# torch.optim.Adam includes several parameters:
#   * params: Iterable of network parameters to optimize.
#   * lr: Learning rate (default here: 0.001). Tuning this affects convergence speed.
#   * betas: Tuple for coefficients used in computing running averages (default: (0.9, 0.999)).
#   * eps: Term added for numerical stability (default: 1e-08).
#   * weight_decay: L2 penalty (default: 0).
#   * amsgrad: Boolean flag to enable the AMSGrad variant (default: False).

optim = torch.optim.Adam(
    net.parameters(),                # Network parameters to be optimized.
    lr=0.001,                        # Learning rate for updating parameters.
    # betas=(0.9, 0.999),            # Coefficients for running averages; adjust for optimization behavior.
    # eps=1e-08,                     # Epsilon for numerical stability; modify if experiencing issues.
    # weight_decay=0,                # L2 regularization factor to prevent overfitting.
    # amsgrad=False                  # Use AMSGrad variant if desired.
)

# Print the optimizer to verify its configuration.
print("Optimizer configuration:")
print(optim)

Now that we have a network and an optimizer, let's define a policy that controls how learning takes place.

Let's use a pre-built policy first (we'll open it up and take a look inside... but first let's get this all working!)

In [None]:
# Create the DQN policy

policy = ts.policy.DQNPolicy(
    model=net,                          # Q-network approximating state-action values.
    optim=optim,                        # Network optimizer.
    discount_factor=0.9,                # Gamma. Balances immediate and future rewards. Values near 1 favor long-term rewards.
    estimation_step=3,                  # Number of steps used in n-step return calculations. Increasing this incorporates more future reward but may also increase variance.
    target_update_freq=320,             # Steps between updating the target network. A higher value results in more stable targets.
    action_space=env.action_space,      # Informs the policy about available actions.
    # reward_normalization=False,       # Flag to normalize rewards. Set to True if rewards have large variance or different scales.
)

# Print the policy to verify its configuration.
print("Policy configuration:")
print(policy)
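
The `discount_factor` and `estimation_step` arguments combine into an n-step TD target. The sketch below computes the standard n-step return by hand with `discount_factor=0.9` and `estimation_step=3`; `n_step_target` is a hypothetical helper, and `bootstrap_value` stands in for the value estimate of the state reached after n steps. Tianshou's internal computation (for example, its use of a target network and how it handles episodes that end inside the window) may differ in detail.

```python
def n_step_target(rewards, bootstrap_value, gamma, terminated=False):
    """Standard n-step return: r_0 + g*r_1 + ... + g^(n-1)*r_(n-1) + g^n * V(s_n).

    If the episode terminated within the n steps there is no next state,
    so nothing is bootstrapped.
    """
    n = len(rewards)
    ret = sum(gamma ** k * r for k, r in enumerate(rewards))
    if not terminated:
        ret += gamma ** n * bootstrap_value
    return ret

# CartPole pays +1 per step; gamma=0.9 and 3 steps, as in the DQNPolicy cell above.
target = n_step_target([1.0, 1.0, 1.0], bootstrap_value=5.0, gamma=0.9)
print(round(target, 4))  # 1 + 0.9 + 0.81 + 0.729 * 5 = 6.355
```

A larger `estimation_step` puts more weight on observed rewards and less on the (possibly inaccurate) bootstrapped estimate, at the cost of higher variance.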

Now let's set up a collector that feeds observations to the policy as the agent interacts with its environment.

> We'll add a test collector that will run tests periodically to see how well our agent is performing.
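
Under the hood, a replay buffer such as `ReplayBuffer(10000)` behaves like a fixed-capacity store that drops its oldest transitions once it is full and hands the trainer random mini-batches. `TinyReplayBuffer` below is a hypothetical, deliberately minimal sketch of that idea (it samples without replacement for simplicity); it is not Tianshou's implementation.

```python
import random
from collections import deque

class TinyReplayBuffer:
    """Hypothetical minimal replay buffer: fixed capacity, oldest entries evicted first,
    uniform random sampling. Not Tianshou's ReplayBuffer."""

    def __init__(self, capacity: int):
        self.storage = deque(maxlen=capacity)  # deque drops the oldest item when full

    def add(self, obs, act, rew, obs_next, done):
        self.storage.append((obs, act, rew, obs_next, done))

    def sample(self, batch_size: int):
        # Uniform sample (without replacement) from whatever is currently stored.
        return random.sample(list(self.storage), k=min(batch_size, len(self.storage)))

    def __len__(self):
        return len(self.storage)

buf = TinyReplayBuffer(capacity=3)
for t in range(5):
    buf.add(obs=t, act=0, rew=1.0, obs_next=t + 1, done=False)
print(len(buf), [transition[0] for transition in buf.storage])  # 3 [2, 3, 4]
```

Sampling at random, rather than replaying transitions in order, breaks up the strong correlation between consecutive steps of an episode.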

In [None]:
# Setup the training and testing data collectors.
# The Collector class in Tianshou collects experiences from interactions with the environment.

# Its useful parameters include:
#   * policy: The policy instance used to interact with the environment.
#   * env: The environment instance from which to collect experiences.
#   * buffer: The memory buffer that stores experiences; here, ReplayBuffer is used.
#   * exploration_noise: (Optional) If True, the collector passes each action through policy.exploration_noise;
#     for DQNPolicy this applies epsilon-greedy exploration with the epsilon set via policy.set_eps.
#   * preprocess_fn: (Optional) A function to preprocess data before storage.
#   * Additional hidden parameters may involve trajectory recording and statistics.

train_collector = Collector(
    policy,                         # Policy used to select actions during data collection.
    env,                            # Environment from which to collect experiences.
    ReplayBuffer(10000),            # Replay buffer with a capacity of 10,000 experiences.
    exploration_noise=True,         # Apply epsilon-greedy exploration (epsilon set by train_fn).
    # preprocess_fn=None,           # Optional preprocessing function for modifying data before storage.
)

# Set up the test collector for evaluation.
# It gets its own environment instance so that test episodes don't interrupt training episodes.
test_collector = Collector(
    policy,                         # Policy used for evaluation.
    gym.make("CartPole-v1"),        # Separate environment instance for testing the policy performance.
    exploration_noise=True,         # Apply the small evaluation epsilon set by test_fn.
)

Now that we have:

1. An environment
2. A Policy with a network model and an optimizer
3. A collector to store the agent experiences

We can now train our agent!

We'll use an off-policy trainer (`OffpolicyTrainer`) for now.

In [None]:
import os
import shutil
import subprocess
import tempfile
import time

import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import ReplayBuffer, Collector, Batch
import tianshou as ts
from tqdm.notebook import tqdm
from IPython.display import IFrame, display

def kill_port(port):
    """
    Terminates any processes that are listening on the specified port.
    Works on both Unix-based systems and Windows.
    """
    try:
        if os.name == 'nt':
            # Windows: Use netstat and taskkill to kill processes on the given port.
            # The command below might fail (exit status 1) if no process is found.
            cmd = f'for /f "tokens=5" %a in (\'netstat -aon ^| findstr :{port}\') do taskkill /F /PID %a'
            result = subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            print(f"Killed processes on port {port}.")
        else:
            # Unix (Linux/Mac): Use lsof to find processes on the port and kill them.
            cmd = f"lsof -ti:{port} | xargs kill -9"
            result = subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            print(f"Killed processes on port {port}.")
    except subprocess.CalledProcessError as e:
        # If the error message indicates that no process was found, we can ignore it.
        if "returned non-zero exit status 1" in str(e):
            pass
        else:
            print(f"Could not kill process on port {port}: {e}")

# Kill any processes on port 6006 to ensure it is free.
kill_port(6006)

# Clear previous TensorBoard sessions (cross-platform)
tensorboard_info = os.path.join(tempfile.gettempdir(), ".tensorboard-info")
if os.path.exists(tensorboard_info):
    shutil.rmtree(tensorboard_info)

# Launch TensorBoard in the background on port 6006.
tb_command = [
    "tensorboard",
    "--logdir", logs_dir,
    "--port", "6006",
    "--host", "localhost",
    "--reload_interval", "30"
]
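# Discard TensorBoard's console output: pipes that are never read can fill up and stall the process.
tb_process = subprocess.Popen(tb_command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)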

# Allow time for TensorBoard to start and display its dashboard.
time.sleep(5)
display(IFrame(src="http://localhost:6006", width="100%", height="800"))

#------------------------------------------------------------------------------ 

# Configure and run the training loop using OffpolicyTrainer.

# OffpolicyTrainer includes numerous parameters for fine-tuning the training process:
#   * policy: The DQN policy instance to be trained.
#   * train_collector: Collector that gathers training experiences.
#   * test_collector: Collector that gathers evaluation data.
#   * max_epoch: Maximum number of epochs for training (here: 5). Increase for longer training.
#   * step_per_epoch: Number of environment steps per epoch (here: 10000). Adjust based on task complexity.
#   * step_per_collect: Steps to collect between policy updates (here: 100). Lower values result in more frequent updates.
#   * episode_per_test: Number of episodes used for evaluation in each test phase (here: 10).
#   * batch_size: Mini-batch size for training updates (here: 64). Adjust based on available memory.
#   * update_per_step: Ratio of gradient updates per collected environment step (here: 1/10, i.e., 10 updates per 100 steps).
#   * train_fn: Function called during training at each epoch; here used to set epsilon for exploration.
#   * test_fn: Function called during testing; typically sets a lower epsilon for evaluation.
#   * stop_fn: Function to determine when to stop training (e.g., when a reward threshold is reached).
#   * logger: Logger instance to record training progress and metrics.
#   * save_checkpoint_fn: Function to save training checkpoints (default: None).

trainer = OffpolicyTrainer(
    policy=policy,                                                              # DQN policy to be trained.
    train_collector=train_collector,                                            # Collector for training experiences.
    test_collector=test_collector,                                              # Collector for evaluation.
    max_epoch=5,                                                                # Total training epochs.
    step_per_epoch=10000,                                                       # Environment steps per epoch.
    step_per_collect=100,                                                       # Steps to collect between each update.
    episode_per_test=10,                                                        # Episodes run during each evaluation phase.
    batch_size=64,                                                              # Mini-batch size for training updates.
    update_per_step=1 / 10,                                                     # Ratio of gradient updates per collected step.
    train_fn=lambda epoch, env_step: policy.set_eps(0.1),                       # Function to adjust training parameters (e.g., epsilon).
    test_fn=lambda epoch, env_step: policy.set_eps(0.05),                       # Function to adjust evaluation parameters.
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,     # Stops training when the mean reward exceeds the threshold.
    logger=logger,                                                              # Logger for tracking training progress.
    # save_checkpoint_fn=None                                                   # Optional function to save training checkpoints.
).run()

# Print the full training summary.
print("\nTraining Summary:\n")
for key, value in trainer.items():
    print(f"{key}: {value}")
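
With the settings passed to `DQNPolicy` and `OffpolicyTrainer` above, it helps to work out how much collecting and learning one run involves. The arithmetic below assumes the trainer performs `round(update_per_step * steps_collected)` gradient updates after each collection and that `target_update_freq` counts gradient updates; the totals are upper bounds, since `stop_fn` can end training early.

```python
# Values copied from the DQNPolicy and OffpolicyTrainer cells above.
max_epoch = 5
step_per_epoch = 10_000
step_per_collect = 100
update_per_step = 1 / 10
target_update_freq = 320

collects_per_epoch = step_per_epoch // step_per_collect          # collection rounds per epoch
updates_per_collect = round(update_per_step * step_per_collect)  # gradient steps after each round
updates_per_epoch = collects_per_epoch * updates_per_collect
total_env_steps = max_epoch * step_per_epoch
total_updates = max_epoch * updates_per_epoch

print(collects_per_epoch, updates_per_collect, updates_per_epoch)  # 100 10 1000
print(total_env_steps, total_updates)                              # 50000 5000
print(updates_per_epoch / target_update_freq)                      # target-network syncs per epoch
```

Roughly three target-network syncs per epoch; raising `update_per_step` would increase both the gradient steps and the sync rate for the same amount of collected experience.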

***
## **Walkthrough**: SARSA Policy Using Temporal-Difference (TD) Learning

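In this section, we build a custom policy that implements the SARSA algorithm using TD-learning.

SARSA is an *on-policy* TD method: it updates its estimate of Q(s, a) using the action a' that the current policy actually takes in the next state s'. DQN, by contrast, is *off-policy*: it learns from a replay buffer of past experience and bootstraps from the highest-valued action in the next state, regardless of which action the agent took next. We define a custom policy class called `SARSAPolicy` that inherits from Tianshou’s `BasePolicy`.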
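
To see the difference concretely, here are both targets for the same hypothetical transition: the SARSA target uses the Q-value of the action the policy actually picked in s', while a Q-learning target (the idea DQN builds on) uses the largest Q-value in s'. All numbers are made up for illustration.

```python
# Hypothetical transition: reward r, discount gamma, and the Q-values of the next state s'.
r, gamma = 1.0, 0.9
q_next = {"left": 2.0, "right": 5.0}  # Q(s', a) for each action a
a_next = "left"                       # action the epsilon-greedy policy actually picked in s'

sarsa_target = r + gamma * q_next[a_next]             # on-policy: uses Q(s', a')
q_learning_target = r + gamma * max(q_next.values())  # off-policy: uses max_a Q(s', a)

print(sarsa_target, q_learning_target)  # 2.8 5.5
```

Because SARSA bootstraps from the exploratory action when one is taken, its value estimates account for the policy's own exploration.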

We will:

- Define a neural network (Q-network) for estimating state–action values.
- Create a custom SARSA policy that selects actions using an epsilon-greedy strategy and updates the network parameters with TD-learning.
- Build a Gymnasium environment, train the agent using a custom training loop, and then save and evaluate the trained model.

For visualization and logging, we use TensorBoard again. Comments within the code explain each parameter and its impact, as well as how the deep learning methods are applied.

The policy includes three key methods:

- **`__init__`**: Initializes the policy by setting the model, optimizer, and hyperparameters.
- **`forward`**: Receives the current observation and returns an action based on predicted state–action values.
- **`learn`**: Updates the model using TD-learning by comparing the estimated state–action value with the TD target.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tianshou.policy import BasePolicy
from tianshou.data import Batch

# Custom SARSA policy class
class SARSAPolicy(BasePolicy):
    """
    Custom SARSA policy implementing TD-learning for continuous online training.
    """
    def __init__(self, model, optim, action_space, gamma=0.99, epsilon=0.1):
        pass

    def forward(self, batch: Batch, state=None, **kwargs) -> Batch:
        """
        Forward pass to compute Q-values and select actions.
        """
        pass

    def learn(self, batch: Batch, next_action: torch.Tensor, **kwargs) -> dict:
        """
        Update the model using TD-learning.
        """
        pass
```

#### 1. **SARSA Policy Initialization**

Let's look at the initialization of the policy first.

The `__init__` method of `SARSAPolicy` sets up the model, optimizer, and important hyperparameters such as the discount factor $ \gamma $ and exploration probability $ \epsilon $. This ensures that the policy has all the necessary components for both action selection and learning.

When we initialize the training policy, we'll provide information about:
* the *model*, which estimates Q-values for a given state and is used to choose actions.
* the *optim* (optimizer), which updates the model's weights; this is how the agent learns.
* the *action_space*, the environment's action space, which is passed on to `BasePolicy`.
* *gamma*, which defines how much the agent values future rewards relative to immediate rewards.
* *epsilon*, the probability that the agent selects an action at random.
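
To get a feel for *gamma*, the sketch below sums the discounted rewards of a CartPole episode that lasts the full 500 steps (CartPole-v1 pays +1 per step). A common rule of thumb is that a discount factor γ looks roughly 1/(1 - γ) steps ahead. `discounted_return` is a hypothetical helper written for this illustration.

```python
def discounted_return(rewards, gamma):
    """Sum of gamma**t * r_t over a sequence of rewards."""
    return sum(gamma ** t * r for t, r in enumerate(rewards))

# CartPole-v1 pays +1 per step and truncates episodes at 500 steps.
rewards = [1.0] * 500
for gamma in (0.9, 0.99):
    print(gamma, round(discounted_return(rewards, gamma), 2), "vs 1/(1-gamma) =", round(1 / (1 - gamma), 2))
```

With γ = 0.9 the return saturates near 10, so rewards beyond a few dozen steps barely matter; with γ = 0.99 it approaches 100, so the agent cares about staying balanced much further into the future.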

```python
class SARSAPolicy(BasePolicy):
    """
    Custom SARSA policy implementing TD-learning for continuous online training.

    Args:
        model (nn.Module): Neural network that predicts Q-values.
        optim (torch.optim.Optimizer): Optimizer for training the model.
        action_space (gym.Space): The environment's action space.
        gamma (float, optional): Discount factor for future rewards. Default is 0.99.
        epsilon (float, optional): Probability for epsilon-greedy exploration. Default is 0.1.
    """
    def __init__(self, model, optim, action_space, gamma=0.99, epsilon=0.1):
        super().__init__(action_space=action_space)
        self.model = model
        self.optim = optim
        self.gamma = gamma
        self.epsilon = epsilon
```

#### 2. **Action Selection with the Forward Method**

- The `forward` method of `SARSAPolicy` computes Q-values for a batch of observations.
- An epsilon-greedy strategy is implemented: the highest-value action is chosen most of the time, but a random action is occasionally selected for exploration.

```python
    def forward(self, batch: Batch, state=None, **kwargs) -> Batch:
        """
        Forward pass to compute Q-values and select actions.

        Uses an epsilon-greedy strategy: with probability epsilon, a random action is chosen;
        otherwise, the action with the highest estimated Q-value is selected.

        Args:
            batch (Batch): A batch of observations.
            state: (optional) The state information for recurrent models.
            **kwargs: Additional keyword arguments.

        Returns:
            Batch: A batch containing the chosen actions.
        """
        # Predict Q-values from the model given current observations.
        q_values, _ = self.model(batch.obs)
        # Greedy action: select the action with the highest Q-value.
        act = q_values.argmax(dim=1).cpu().numpy()

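        # Epsilon-greedy exploration: independently for each observation in the batch,
        # replace the greedy action with a random one with probability epsilon.
        explore = np.random.rand(len(act)) < self.epsilon
        act[explore] = np.random.randint(0, q_values.shape[1], size=explore.sum())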

        return Batch(act=act)
```
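
One detail worth checking in any epsilon-greedy implementation: a random action can coincide with the greedy one, so the greedy action is chosen with probability (1 - ε) + ε/n for n actions, not 1 - ε. The NumPy sketch below uses a standalone, hypothetical `epsilon_greedy` helper that explores independently per row and checks this empirically for ε = 0.1 and two actions.

```python
import numpy as np

def epsilon_greedy(q_values: np.ndarray, epsilon: float, rng: np.random.Generator) -> np.ndarray:
    """Per-row epsilon-greedy: each row independently takes a random action with probability epsilon."""
    act = q_values.argmax(axis=1)
    explore = rng.random(len(act)) < epsilon
    act[explore] = rng.integers(0, q_values.shape[1], size=explore.sum())
    return act

rng = np.random.default_rng(0)
q = np.tile([0.0, 1.0], (100_000, 1))  # action 1 is always the greedy action
acts = epsilon_greedy(q, epsilon=0.1, rng=rng)
# Expected greedy rate with 2 actions: (1 - 0.1) + 0.1 / 2 = 0.95
print(round((acts == 1).mean(), 3))
```

For CartPole's two actions, ε = 0.1 therefore deviates from the greedy action only about 5% of the time.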

#### 3. **Learning with TD-Learning**

We've now initialized our policy with a model, an optimizer, and some hyperparameters, and set up how the agent chooses actions using the model. Next, we need to think about how the agent will learn. This is an important step. We'll use TD-learning to update the model so that it can predict the value of actions, in a few steps:

* Estimate the current state-action values, Q(s, a), i.e., the `q_values`.
* Compute the TD target: the observed reward plus the discounted value of the next state-action pair, Q(s', a').
* Use the difference between the current Q-value and the TD target to update the neural network.
* Repeating this over many transitions should gradually improve the network's estimates of the Q-values of state-action pairs.
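
These steps are easiest to see in their tabular form, where Q is a lookup table instead of a network and the update is an explicit step of size α (a learning rate) toward the TD target. The sketch below is that textbook tabular SARSA update, shown for intuition only; `SARSAPolicy` replaces the table with a network and the α-step with an optimizer step on the squared TD error. `sarsa_update` and the state names are hypothetical.

```python
from collections import defaultdict

def sarsa_update(Q, s, a, r, s_next, a_next, done, alpha=0.1, gamma=0.9):
    """One tabular SARSA step: Q(s,a) += alpha * (r + gamma*(1-done)*Q(s',a') - Q(s,a))."""
    td_target = r + gamma * (0.0 if done else Q[(s_next, a_next)])
    td_error = td_target - Q[(s, a)]
    Q[(s, a)] += alpha * td_error
    return td_error

Q = defaultdict(float)  # unseen (state, action) pairs start at 0.0
Q[("s0", 0)] = 0.5
Q[("s1", 1)] = 2.0
err = sarsa_update(Q, "s0", 0, r=1.0, s_next="s1", a_next=1, done=False)
print(round(err, 3), round(Q[("s0", 0)], 3))  # 2.3 0.73
```

The TD target here is 1 + 0.9 × 2.0 = 2.8, so the estimate for ("s0", 0) moves a tenth of the way from 0.5 toward it.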

```python
    def learn(self, batch: Batch, next_action: torch.Tensor, **kwargs) -> dict:
        """
        Update the model using TD-learning.

        Steps:
        1. Compute current state-action values Q(s,a).
        2. Estimate the TD target using the observed reward and the predicted next Q-values.
        3. Calculate the mean squared error loss between the current Q-values and the TD target.
        4. Perform a gradient step with gradient clipping to update the model.

        Args:
            batch (Batch): A batch containing observations, actions, rewards, etc.
            next_action (torch.Tensor): The actions chosen for the next state.
            **kwargs: Additional keyword arguments.

        Returns:
            dict: A dictionary with the computed loss.
        """
        # Get the predicted Q-values for the current state.
        q_values, _ = self.model(batch.obs)

        # Ensure actions are a tensor and have the correct shape.
        batch_act = (
            torch.tensor(batch.act, dtype=torch.long)
            if not isinstance(batch.act, torch.Tensor)
            else batch.act
        )
        # Gather Q-values corresponding to the chosen actions.
        q_values = q_values.gather(1, batch_act.unsqueeze(1)).squeeze(1)

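        # Compute the TD target without tracking gradients.
        with torch.no_grad():
            # Convert rewards, done flags, and next actions to tensors so dtypes match the Q-values.
            rew = torch.as_tensor(batch.rew, dtype=torch.float32)
            done = torch.as_tensor(batch.done, dtype=torch.float32)
            next_action = torch.as_tensor(next_action, dtype=torch.long)
            next_q_values, _ = self.model(batch.obs_next)
            next_q_values = next_q_values.gather(1, next_action.unsqueeze(1)).squeeze(1)
            td_target = rew + self.gamma * (1 - done) * next_q_values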

        # Compute the mean squared error loss between current Q-values and the TD target.
        loss = F.mse_loss(q_values, td_target)

        # Reset gradients, backpropagate loss, clip gradients, and update weights.
        self.optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optim.step()

        return {"loss": loss.item()}
```
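
The `gather` call in `learn` can be puzzling at first: for each row of the batch, it picks the Q-value of the action that was actually taken. Here is the same selection in NumPy on a made-up batch, using `np.take_along_axis`, which mirrors `gather` along dimension 1.

```python
import numpy as np

# Hypothetical batch of 3 states with Q-values for 2 actions each.
q_values = np.array([[1.0, 2.0],
                     [3.0, 4.0],
                     [5.0, 6.0]])
actions = np.array([1, 0, 1])  # action actually taken in each state

# Same selection as q_values.gather(1, act.unsqueeze(1)).squeeze(1) in PyTorch.
picked = np.take_along_axis(q_values, actions[:, None], axis=1).squeeze(1)
# Equivalent fancy indexing: one row index paired with one action index.
assert np.array_equal(picked, q_values[np.arange(len(actions)), actions])
print(picked)  # [2. 3. 6.]
```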

In the `learn` method, the agent computes the temporal-difference (TD) target using:
  
  $$ TD_{target} = r + \gamma (1 - d) Q(s', a') $$
  
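  where $ r $ is the reward, $ \gamma $ is the discount factor, $ d $ is 1 if the transition ended the episode and 0 otherwise, and $ Q(s', a') $ is the value of the next state $ s' $ paired with the action $ a' $ the policy actually chose there. Bootstrapping from the policy's own next action, rather than the highest-valued one, is what makes this update SARSA.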
The loss is computed using the mean squared error between the current Q-value estimate and the TD target.
The optimizer then updates the model parameters, using gradient clipping to prevent large updates.
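
Here is the TD target and loss from `learn` worked through on a made-up batch of three transitions, in NumPy rather than PyTorch. The last transition is terminal, so its target is just the reward.

```python
import numpy as np

gamma = 0.99
rew = np.array([1.0, 1.0, 1.0])
done = np.array([False, False, True])  # the last transition ended the episode
next_q = np.array([10.0, 5.0, 7.0])    # Q(s', a') for the next actions the policy chose
q = np.array([10.0, 6.0, 2.0])         # current estimates Q(s, a)

# (1 - d) zeroes the bootstrap term for terminal transitions.
td_target = rew + gamma * (1 - done.astype(np.float64)) * next_q
loss = np.mean((q - td_target) ** 2)   # mean squared error, as F.mse_loss computes by default
print(td_target, round(loss, 4))
```

The targets are 10.9, 5.95, and 1.0; the terminal transition contributes the largest error here because its bootstrapped value of 7.0 is ignored.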

Now that we've defined the policy, we need to define a neural network (i.e., the model) that the policy can use!

### Neural Network Model (QNet)

Let's look at how to create an agent brain. The agent’s “brain” is a neural network model that estimates Q-values given the current state. To do this we'll create a new `QNet` class that extends PyTorch’s `nn.Module` and defines the network structure and forward pass.

```python
import numpy as np
import torch
import torch.nn as nn

class QNet(nn.Module):
    """
    Neural network model to predict Q-values from state observations.
    """
    def __init__(self, state_shape, action_shape, hidden_size=128):
        pass

    def forward(self, obs: torch.Tensor, state=None, info: dict = {}) -> tuple:
        """
        Forward pass to compute Q-values.
        """
        pass
```

#### 4. **Neural Network (QNet) Initialization**


Let's take a look at the initialization first. When we initialize the agent's brain, we'll provide information about:
* the *state_shape*, the shape of the observation space, i.e., how many variables the agent receives as input.
* the *action_shape*, the number of actions the agent can take.
* *hidden_size*, the number of nodes in each hidden layer of the neural network.

The `QNet` class constructs a network from fully connected layers. The size of the input layer is the product of the state dimensions, $ \prod(state\_shape) $, computed with `np.prod(state_shape)`; for a flat observation vector such as CartPole's, this is simply the number of observed variables. Layer normalization is applied after each hidden layer to stabilize training.
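
As a quick check on the layer sizes, the hypothetical helper below counts the trainable parameters in the `QNet` layer stack, assuming PyTorch's defaults (`nn.Linear` with a bias, `nn.LayerNorm` with a learnable scale and shift per unit). For a `QNet` instance `model`, `sum(p.numel() for p in model.parameters())` should give the same number.

```python
import math

def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of an nn.Linear(n_in, n_out) layer."""
    return n_in * n_out + n_out

def qnet_param_count(state_shape, action_shape, hidden_size=128) -> int:
    """Trainable parameters in Linear -> ReLU -> LayerNorm -> Linear -> ReLU -> LayerNorm -> Linear."""
    n_in = math.prod(state_shape)  # same value as np.prod(state_shape)
    n_out = action_shape if isinstance(action_shape, int) else math.prod(action_shape)
    layer_norm = 2 * hidden_size   # one learnable scale and one shift per hidden unit
    return (linear_params(n_in, hidden_size) + layer_norm
            + linear_params(hidden_size, hidden_size) + layer_norm
            + linear_params(hidden_size, n_out))

# CartPole: 4 observed variables, 2 actions, default hidden_size=128.
print(qnet_param_count((4,), 2))  # 17922
```

Most of the parameters sit in the 128 × 128 hidden layer, so `hidden_size` is the main knob for model capacity.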

```python
class QNet(nn.Module):
    """
    Neural network model to predict Q-values from state observations.

    Args:
        state_shape (tuple): Shape of the input state.
        action_shape (tuple): Shape of the action space.
        hidden_size (int, optional): Number of units in the hidden layers. Default is 128.
    """
    def __init__(self, state_shape, action_shape, hidden_size=128):
        super().__init__()
        # Build the neural network layers.
        self.net = nn.Sequential(
            nn.Linear(np.prod(state_shape), hidden_size),  # Dense layer to hidden_size units.
            nn.ReLU(),                                      # Activation function.
            nn.LayerNorm(hidden_size),                      # Normalization to stabilize training.
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, np.prod(action_shape))   # Output layer for Q-values.
        )
```

Calling `super().__init__()` initializes the underlying PyTorch `nn.Module`. With that done, we add the layers the agent uses to estimate the value of each action in the current context. Note: here we only use the current state to make predictions; we'll see later how to add a trajectory of states from the recent past to help our agent make better decisions in environments that are dynamic and stochastic.

We'll cover how to build your own neural networks a little further along in the tutorials, but for now it's enough to know that this network has three linear layers, with ReLU activations between them and layer normalization after each hidden layer to keep the activations on a stable scale. The input is the vector of observed variables, and the output is one Q-value per action; the policy then selects an action from these values.
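
Layer normalization rescales each sample's hidden activations to zero mean and unit variance across the layer's units, so the next layer sees inputs on a consistent scale. The NumPy function below sketches that computation at initialization (PyTorch's `nn.LayerNorm` also learns a per-unit scale and shift, omitted here). The two example rows differ in offset and spread but normalize to the same values.

```python
import numpy as np

def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalize each row to zero mean and unit variance across its features.

    Matches nn.LayerNorm at initialization (scale=1, shift=0); the learnable
    scale and shift are omitted in this sketch.
    """
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance, as LayerNorm uses
    return (x - mean) / np.sqrt(var + eps)

hidden = np.array([[0.0, 10.0, 20.0, 30.0],
                   [100.0, 101.0, 102.0, 103.0]])
normed = layer_norm(hidden)
print(normed.round(3))
```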

#### 5. **Model Forward Pass**

Next, let's look at how to build the *forward* method, which computes the Q-values for each action:

```python
    def forward(self, obs: torch.Tensor, state=None, info: dict = {}) -> tuple:
        """
        Forward pass to compute Q-values.

        Converts the input observation to a tensor if it is a numpy array, then passes it
        through the network.

        Args:
            obs (torch.Tensor or np.ndarray): The observation/state input.
            state: (optional) State information for recurrent networks.
            info (dict, optional): Additional information (unused here).

        Returns:
            tuple: A tuple containing:
                - q_values (torch.Tensor): Predicted Q-values for each action.
                - state: Unchanged state (for compatibility with recurrent models).
        """
        # Convert numpy arrays to torch tensors.
        if isinstance(obs, np.ndarray):
            obs = torch.tensor(obs, dtype=torch.float32)
        # Compute Q-values from the network.
        q_values = self.net(obs)
        return q_values, state
```

The `forward` method in `QNet` converts observations to a tensor (if necessary) and feeds them through the network to produce Q-values. This method is critical for both action selection and learning.

***
## **Full Implementation and Testing**: SARSA Policy Using Temporal-Difference (TD) Learning

Now that we've gone through the code piece by piece, let's take a look at the full code.

### 1. Preliminary Code

In [None]:
# Standard library imports
from datetime import datetime
import os
import shutil
import subprocess
import tempfile
import time

# Third-party imports
import gymnasium as gym
import pygame
import torch
from IPython.display import IFrame, display
from torch.utils.tensorboard import SummaryWriter

# Local application/library-specific imports
import tianshou as ts
from tianshou.data import Collector, ReplayBuffer
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils.net.common import Net

# Timestamped ID for this run to avoid overwriting previous runs and to keep track of different runs
agent_id = datetime.now().strftime("%Y%m%d_%H%M%S")  # Format: YYYYMMDD_HHMMSS

# Setup directories for saving models and logs
sarsa_td_dir = f"sarsa_td_cartpole_{agent_id}"
logs_dir = os.path.join(sarsa_td_dir, "logs")
models_dir = os.path.join(sarsa_td_dir, "models")

os.makedirs(sarsa_td_dir, exist_ok=True) # Ensure the directory exists
print(f"All files for this run will be saved in the directory: {sarsa_td_dir}")
os.makedirs(logs_dir, exist_ok=True) # Ensure the directory exists
print(f"Tensorboard logs will be saved in the directory: {logs_dir}")
os.makedirs(models_dir, exist_ok=True) # Ensure the directory exists
print(f"Models will be saved in the directory: {models_dir}")

# Create a logger
logger = ts.utils.TensorboardLogger(SummaryWriter(logs_dir))
print(f"TensorBoard logs are being saved in: {logs_dir}")

### 2. SARSA Policy (`SARSAPolicy`) Full Code

The `SARSAPolicy` class extends Tianshou’s `BasePolicy` and implements an epsilon-greedy strategy. It also performs the learning update using TD-learning by computing the loss between the current Q-values and the TD target.

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tianshou.policy import BasePolicy
from tianshou.data import Batch

class SARSAPolicy(BasePolicy):
    """
    Custom SARSA policy implementing TD-learning.

    Args:
        model (nn.Module): Neural network that predicts Q-values.
        optim (torch.optim.Optimizer): Optimizer for updating model parameters.
        action_space (gym.Space): The action space of the environment.
        gamma (float, optional): Discount factor for future rewards. Default is 0.99.
        epsilon (float, optional): Exploration probability for epsilon-greedy action selection. Default is 0.1.
    """
    def __init__(self, model, optim, action_space, gamma=0.99, epsilon=0.1):
        super().__init__(action_space=action_space)
        self.model = model
        self.optim = optim
        self.gamma = gamma
        self.epsilon = epsilon

    def forward(self, batch: Batch, state=None, **kwargs) -> Batch:
        """
        Select actions using the model's Q-values.

        Implements an epsilon-greedy strategy:
            - With probability ε, a random action is chosen.
            - Otherwise, the action with the highest Q-value is selected.

        Args:
            batch (Batch): Batch containing observations.
            state: Optional state for recurrent models.
            **kwargs: Additional arguments.

        Returns:
            Batch: Batch containing selected actions.
        """
        q_values, _ = self.model(batch.obs)
        # Greedy action selection
        act = q_values.argmax(dim=1).cpu().numpy()
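        # Epsilon-greedy exploration: independently for each observation in the batch,
        # replace the greedy action with a random one with probability epsilon
        explore = np.random.rand(len(act)) < self.epsilon
        act[explore] = np.random.randint(0, q_values.shape[1], size=explore.sum())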
        return Batch(act=act)

    def learn(self, batch: Batch, next_action: torch.Tensor, **kwargs) -> dict:
        """
        Update the model using TD-learning.

        Steps:
            1. Compute current Q-values for the actions taken.
            2. Estimate the TD target: TD_target = r + γ(1 - d) Q(s', a'),
               where d = 1 if the episode terminated at this transition.
            3. Compute the mean squared error (MSE) loss between the current Q-values and the TD target.
            4. Backpropagate the loss to compute gradients.
            5. Clip the gradient norm to limit the size of the update, then step the optimizer.

        Args:
            batch (Batch): Batch containing observations, actions, rewards, etc.
            next_action (torch.Tensor): Next action a' chosen by the policy for each next observation.
            **kwargs: Additional arguments.

        Returns:
            dict: Dictionary containing the computed loss.
        """
        # Predict Q-values for current states
        q_values, _ = self.model(batch.obs)
        device = q_values.device
        # Convert actions, rewards, termination flags, and next actions to tensors on the model's device
        batch_act = torch.as_tensor(batch.act, dtype=torch.long, device=device)
        rew = torch.as_tensor(batch.rew, dtype=torch.float32, device=device)
        # Use `terminated` rather than `done`: a time-limit truncation is not a true terminal
        # state, so the target should still bootstrap from Q(s', a') there.
        terminated = torch.as_tensor(batch.terminated, dtype=torch.float32, device=device)
        next_action = torch.as_tensor(next_action, dtype=torch.long, device=device)
        # Gather the Q-values corresponding to the actions taken
        q_values = q_values.gather(1, batch_act.unsqueeze(1)).squeeze(1)

        # Compute the TD target without tracking gradients
        with torch.no_grad():
            next_q_values, _ = self.model(batch.obs_next)
            next_q_values = next_q_values.gather(1, next_action.unsqueeze(1)).squeeze(1)
            td_target = rew + self.gamma * (1 - terminated) * next_q_values

        # Compute the loss: MSE between current Q-values and TD target
        loss = F.mse_loss(q_values, td_target)

        # Zero gradients, backpropagate loss, clip gradients, and update parameters
        self.optim.zero_grad()
        loss.backward()
        # Clip gradients to avoid large updates; max_norm can be tuned as needed
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optim.step()

        return {"loss": loss.item()}
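
The gradient-clipping step in `learn` relies on `torch.nn.utils.clip_grad_norm_`, which treats all parameter gradients as one long vector and rescales them when that vector's L2 norm exceeds `max_norm`. The NumPy sketch below mirrors that idea for illustration only; `clip_by_global_norm` is a hypothetical helper, and the small `eps` term follows PyTorch's practice of dividing by `total_norm + 1e-6`.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale a list of gradient arrays so their combined L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    # The scale factor is capped at 1.0, so gradients already below max_norm are unchanged
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm

# Two "parameter" gradients whose combined norm is sqrt(3^2 + 4^2) = 5
grads = [np.array([3.0]), np.array([4.0])]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
print(norm_before, norm_after)   # the norm shrinks from 5.0 to (just under) 1.0
```

Because every gradient is multiplied by the same factor, the direction of the update is preserved; only its length is capped.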

### 3. Q-Network (`QNet`) Full Code

The `QNet` class builds a fully connected neural network to predict Q-values. It flattens the state input, passes it through two hidden layers with ReLU activations and layer normalization, and finally outputs the Q-values for each action.
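To see how shapes flow through this architecture, here is a NumPy-only sketch of the same Linear → ReLU → LayerNorm stack with random weights, assuming CartPole's 4-dimensional observation and 2 actions. It is an illustration, not the PyTorch module: `layer_norm` omits the learnable scale and shift that `nn.LayerNorm` adds (they start at 1 and 0), while `count_params` does include them when counting the real network's trainable parameters.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize each row to zero mean and unit variance (no learnable scale/shift)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def qnet_forward(obs, params):
    """Mirror of QNet: flatten -> (Linear -> ReLU -> LayerNorm) x 2 -> Linear."""
    x = obs.reshape(obs.shape[0], -1)                # flatten each observation
    for w, b in params[:-1]:
        x = layer_norm(np.maximum(x @ w + b, 0.0))   # Linear -> ReLU -> LayerNorm
    w_out, b_out = params[-1]
    return x @ w_out + b_out                         # one Q-value per action

def count_params(state_dim, n_actions, hidden):
    """Trainable parameters in QNet: Linear weights + biases, LayerNorm scale + shift."""
    def linear(n_in, n_out):
        return n_in * n_out + n_out
    norm = 2 * hidden
    return linear(state_dim, hidden) + norm + linear(hidden, hidden) + norm + linear(hidden, n_actions)

rng = np.random.default_rng(0)
state_dim, n_actions, hidden = 4, 2, 128             # assumed CartPole sizes
layer_sizes = [(state_dim, hidden), (hidden, hidden), (hidden, n_actions)]
params = [(rng.normal(0.0, 0.1, size=shape), np.zeros(shape[1])) for shape in layer_sizes]

obs = rng.normal(size=(5, state_dim))                # a batch of 5 observations
q_values = qnet_forward(obs, params)
print(q_values.shape)                                # (5, 2)
print(count_params(state_dim, n_actions, hidden))    # 17922
```

Most of the parameters sit in the 128 × 128 hidden layer; the input and output layers are small because CartPole has only 4 observation features and 2 actions.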

In [None]:
import numpy as np
import torch
import torch.nn as nn

class QNet(nn.Module):
    """
    Neural network model to predict Q-values from state observations.

    Args:
        state_shape (tuple): Shape of the input state.
        action_shape (int or tuple): Number of actions available.
        hidden_size (int, optional): Number of neurons in hidden layers. Default is 128.
    """
    def __init__(self, state_shape, action_shape, hidden_size=128):
        super().__init__()
        # Build the sequential model:
        # - First layer: flatten the state and map to hidden_size neurons.
        # - ReLU activation introduces non-linearity.
        # - LayerNorm stabilizes training by normalizing activations.
        # - Second layer: another hidden layer with the same structure.
        # - Output layer: maps to the number of actions.
        self.net = nn.Sequential(
            nn.Linear(int(np.prod(state_shape)), hidden_size),  # Dense layer
            nn.ReLU(),                                           # Activation function
            nn.LayerNorm(hidden_size),                           # Normalization layer
            nn.Linear(hidden_size, hidden_size),                 # Hidden layer
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, int(np.prod(action_shape)))   # Output layer for Q-values
        )

    def forward(self, obs, state=None, info=None) -> tuple:
        """
        Forward pass to compute Q-values.

        Converts input observations to a float32 tensor on the network's device
        (if not already), flattens each observation, and passes the batch through
        the network to produce Q-values.

        Args:
            obs (torch.Tensor or np.ndarray): Batch of input observations.
            state: Optional state for recurrent models (unused here).
            info (dict, optional): Additional information (unused).

        Returns:
            tuple: (q_values, state)
        """
        # Convert numpy arrays or tensors on another device to float32 on the network's device
        device = next(self.parameters()).device
        obs = torch.as_tensor(obs, dtype=torch.float32, device=device)
        # Flatten each observation: (batch, *state_shape) -> (batch, prod(state_shape))
        obs = obs.reshape(obs.shape[0], -1)
        q_values = self.net(obs)
        return q_values, state

### 4. Create the Environment

Let's create the environments!

We initialize two environments: one for training and one for testing. The state and action shapes are extracted to build the network correctly.

In [None]:
import gymnasium as gym

# Create training and testing environments using Gymnasium
env = gym.make("CartPole-v1")
test_env = gym.make("CartPole-v1")

# Extract environment information:
# - state_shape: dimensions of the observation space
# - action_shape: number of actions
state_shape = env.observation_space.shape or (env.observation_space.n,)
action_shape = env.action_space.n if hasattr(env.action_space, "n") else env.action_space.shape
action_space = env.action_space

### 5. Build the Q-Network, Optimizer, and SARSA Policy

Here we instantiate the Q-network and set up the optimizer. We then create our custom SARSA policy by passing the network, optimizer, and hyperparameters.

In [None]:
import torch.optim as optim

# Select the appropriate device: CUDA (NVIDIA GPUs), MPS (Apple GPUs), or CPU.
# AMD GPUs with ROCm support are accessed using the 'cuda' device string.
device = torch.device("cuda" if torch.cuda.is_available() else
                      "mps" if torch.backends.mps.is_available() else
                      "cpu")
print(f"Using device: {device}")

# Initialize the Q-network with state and action shapes
net = QNet(state_shape, action_shape).to(device)

# Set up the Adam optimizer:
# - lr: learning rate (can be tuned for faster or slower convergence)
# - weight_decay: L2 penalty on the weights to discourage overfitting (0 disables it)
optimizer = optim.Adam(net.parameters(), lr=0.01, weight_decay=0)

# Create the SARSA policy using the network and optimizer.
# gamma: discount factor; epsilon: exploration rate for epsilon-greedy strategy.
policy = SARSAPolicy(model=net, optim=optimizer, action_space=action_space, gamma=0.995, epsilon=0.1)

### 6. Custom Online Training Loop

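We train the agent online. In each of the `step_per_epoch` iterations of an epoch, the collector gathers `keep_n_steps` fresh transitions with the current epsilon-greedy policy; the loop then converts them to torch tensors, picks the next action a′ for each transition with the same policy, and applies one SARSA update. The loop logs the training loss at every update and evaluates the test reward at the end of each epoch.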
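In textbook SARSA, a′ is the action the agent *actually took* at the next step, whereas the training loop below re-samples a′ from the current epsilon-greedy policy. For comparison, the sketch below recovers the taken a′ from a chronological, single-environment trajectory. It assumes the transitions are in time order, which Tianshou's circular `ReplayBuffer` does not guarantee once it wraps around, and `next_taken_actions` is a hypothetical helper. Transitions with no following step in the same episode are flagged so a caller could fall back to sampling a′ from the policy.

```python
import numpy as np

def next_taken_actions(act, done):
    """For each transition t, return act[t + 1] if it belongs to the same episode.

    Returns (next_act, has_next). Where has_next is False (the episode ended at t,
    or t is the last transition), next_act holds the placeholder -1.
    """
    act = np.asarray(act)
    done = np.asarray(done, dtype=bool)
    next_act = np.full_like(act, -1)
    has_next = np.zeros(len(act), dtype=bool)
    # The action taken at t + 1 is a' for transition t, unless the episode ended at t
    has_next[:-1] = ~done[:-1]
    next_act[:-1] = np.where(has_next[:-1], act[1:], -1)
    return next_act, has_next

act = [0, 1, 1, 0, 1]
done = [False, False, True, False, False]   # the first episode ends at index 2
next_act, has_next = next_taken_actions(act, done)
# next_act == [1, 1, -1, 1, -1], has_next == [True, True, False, True, False]
print(next_act, has_next)
```

Re-sampling a′, as the tutorial does, avoids any dependence on buffer ordering at the cost of using an action the agent never actually executed.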

In [None]:
import os
import shutil
import subprocess
import tempfile
import time

import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import ReplayBuffer, Collector, Batch
import tianshou as ts
from tqdm.notebook import tqdm
from IPython.display import IFrame, display

def kill_port(port):
    """
    Terminates any processes that are listening on the specified port.
    Works on both Unix-based systems and Windows.
    """
    try:
        if os.name == 'nt':
            # Windows: Use netstat and taskkill to kill processes on the given port.
            # The command below might fail (exit status 1) if no process is found.
            cmd = f'for /f "tokens=5" %a in (\'netstat -aon ^| findstr :{port}\') do taskkill /F /PID %a'
            result = subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            print(f"Killed processes on port {port}.")
        else:
            # Unix (Linux/Mac): Use lsof to find processes on the port and kill them.
            cmd = f"lsof -ti:{port} | xargs kill -9"
            result = subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            print(f"Killed processes on port {port}.")
    except subprocess.CalledProcessError as e:
        # Exit status 1 (no matching process) or 123 (xargs: kill found nothing to kill)
        # usually just means the port was already free, so we can ignore it.
        if e.returncode not in (1, 123):
            print(f"Could not kill process on port {port}: {e}")

# Kill any processes on port 6006 to ensure it is free.
kill_port(6006)

# Clear previous TensorBoard sessions (cross-platform)
tensorboard_info = os.path.join(tempfile.gettempdir(), ".tensorboard-info")
if os.path.exists(tensorboard_info):
    shutil.rmtree(tensorboard_info)

# Launch TensorBoard in the background on port 6006.
tb_command = [
    "tensorboard",
    "--logdir", logs_dir,
    "--port", "6006",
    "--host", "localhost",
    "--reload_interval", "30"
]
tb_process = subprocess.Popen(tb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

# Allow time for TensorBoard to start and display its dashboard.
time.sleep(5)
display(IFrame(src="http://localhost:6006", width="100%", height="800"))

#------------------------------------------------------------------------------ 
# Hyperparameters for training
max_epoch = 3               # Total number of training epochs
step_per_epoch = 300        # Learning updates (loop iterations) per epoch
keep_n_steps = 300          # Fresh transitions collected, and learned from, per update

# Create a ReplayBuffer to store a fixed number of recent transitions.
buffer = ReplayBuffer(size=keep_n_steps)

# Set up data collectors for training and testing.
train_collector = Collector(policy, env, buffer)
test_collector = Collector(policy, test_env)

# Lists to store epoch summaries.
epoch_training_losses = []
epoch_test_rewards = []
epoch_durations = []

global_start_time = time.time()  # Start overall timer

# Training loop with progress reporting, logging, and detailed summaries.
for epoch in range(max_epoch):
    epoch_start_time = time.time()  # Start timer for the epoch
    train_collector.reset()
    running_loss = 0.0              # Accumulate loss to compute average loss per epoch

    # Set up a tqdm progress bar with dynamic post-fix metrics.
    progress_bar = tqdm(range(step_per_epoch),
                        desc=f"Epoch {epoch+1}/{max_epoch}",
                        dynamic_ncols=True)
    
    for step in progress_bar:
        # Collect keep_n_steps fresh transitions and store them in the buffer.
        train_collector.collect(n_step=keep_n_steps)

        # Retrieve the last keep_n_steps transitions from the buffer.
        batch = train_collector.buffer[-keep_n_steps:]

        # Convert batch fields to torch tensors on the same device as the network.
        batch.obs = torch.tensor(batch.obs, dtype=torch.float32, device=device)
        batch.act = torch.tensor(batch.act, dtype=torch.long, device=device)
        batch.rew = torch.tensor(batch.rew, dtype=torch.float32, device=device)
        batch.done = torch.tensor(batch.done, dtype=torch.float32, device=device)
        batch.obs_next = torch.tensor(batch.obs_next, dtype=torch.float32, device=device)

        # Choose the next action a' for each transition with the current epsilon-greedy policy.
        # No gradients are needed here; only the TD update in policy.learn backpropagates.
        with torch.no_grad():
            next_action = policy.forward(Batch(obs=batch.obs_next)).act
        next_action_tensor = torch.tensor(next_action, dtype=torch.long, device=device)
        
        # Perform the SARSA update and capture the loss.
        learn_info = policy.learn(batch, next_action_tensor)
        loss_val = learn_info.get("loss", 0)
        running_loss += loss_val
        
        global_step = epoch * step_per_epoch + step
        
        # Log step-level loss continuously to TensorBoard.
        logger.writer.add_scalar("Loss/train_step", loss_val, global_step)
        logger.writer.add_scalar("Loss/train_running_avg", running_loss / (step + 1), global_step)
        
        # Optionally flush periodically (or at the end of each epoch).
        if step % 50 == 0:
            logger.writer.flush()
        
        # Update progress bar with current metrics.
        progress_bar.set_postfix({
            "Step": f"{step}/{step_per_epoch}",
            "Loss": f"{loss_val:07.3f}",
            "AvgLoss": f"{running_loss / (step + 1):07.3f}"
        })
        
        # Print progress summary every 25% of the epoch.
        if step % (step_per_epoch // 4) == 0:
            print(
                f"Epoch {epoch+1}, Step {step}/{step_per_epoch}: "
                f"Step Loss = {loss_val}, Running Avg Loss = {running_loss/(step+1)}"
            )
    
    # Compute average loss over the epoch.
    avg_loss = running_loss / step_per_epoch

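    # Evaluate the greedy policy (epsilon = 0) on 10 episodes, then restore the training epsilon.
    train_epsilon = policy.epsilon
    policy.epsilon = 0.0
    test_collector.reset()
    test_result = test_collector.collect(n_episode=10)
    policy.epsilon = train_epsilon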
    mean_reward = np.mean(test_result["rews"])
    std_reward = np.std(test_result["rews"])
    min_reward = np.min(test_result["rews"])
    p25_reward = np.percentile(test_result["rews"], 25)
    median_reward = np.median(test_result["rews"])
    p75_reward = np.percentile(test_result["rews"], 75)
    max_reward = np.max(test_result["rews"])

    # Log epoch-level metrics to TensorBoard.
    logger.writer.add_scalar("Reward/test_avg", mean_reward, epoch)
    logger.writer.add_scalar("Loss/train_avg", avg_loss, epoch)
    logger.writer.flush()

    # Calculate epoch elapsed time.
    epoch_elapsed = time.time() - epoch_start_time
    epoch_training_losses.append(avg_loss)
    epoch_test_rewards.append(mean_reward)
    epoch_durations.append(epoch_elapsed)

    # Print detailed epoch summary.
    print(
        f"\nEpoch {epoch+1} Summary:\n"
        f"  - Epoch Elapsed Time      : {epoch_elapsed} seconds\n"
        f"  - Learning Updates        : {step_per_epoch}\n"
        f"  - Env Steps Collected     : {step_per_epoch * keep_n_steps}\n"
        f"  - Average Training Loss   : {avg_loss}\n"
        f"  - Mean Test Reward        : {mean_reward}\n"
        f"  - Std Test Reward         : {std_reward}\n"
        f"  - Min Test Reward         : {min_reward}\n"
        f"  - 25th Percentile Reward  : {p25_reward}\n"
        f"  - Median Test Reward      : {median_reward}\n"
        f"  - 75th Percentile Reward  : {p75_reward}\n"
        f"  - Max Test Reward         : {max_reward}\n"
    )

# Final flush and close the TensorBoard writer.
logger.writer.close()

# Calculate overall training statistics.
total_elapsed = time.time() - global_start_time
overall_avg_loss = np.mean(epoch_training_losses)
overall_avg_reward = np.mean(epoch_test_rewards)
total_epochs = len(epoch_durations)

# Print overall summary of all epochs.
print("\nOverall Training Summary:")
print(f"  - Total Epochs            : {total_epochs}")
print(f"  - Overall Average Loss    : {overall_avg_loss}")
print(f"  - Overall Average Reward  : {overall_avg_reward}")
print(f"  - Total Elapsed Time      : {total_elapsed} seconds")

# Reiterate the final epoch's metrics.
print("\nFinal Epoch Summary:")
print(
    f"  - Epoch {total_epochs}:\n"
    f"      * Average Training Loss : {epoch_training_losses[-1]}\n"
    f"      * Average Test Reward   : {epoch_test_rewards[-1]}\n"
    f"      * Epoch Elapsed Time    : {epoch_durations[-1]} seconds\n"
)

**Notes on Key Parameters:**

- **Learning Rate (`lr`)** in Adam: A smaller value (e.g., 1e-5) means slower but more stable updates. Increase if learning is too slow.
- **Weight Decay:** Regularizes the model by penalizing large weights. This tutorial sets it to 0, which disables it.
- **Gamma (`γ`):** Determines how much future rewards are considered. Values close to 1 favor long-term rewards.
- **Epsilon:** Controls exploration. A higher epsilon increases exploration but may delay convergence.
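A quick way to build intuition for γ: CartPole gives a reward of 1 per step, so the discounted return of an episode lasting T steps is the geometric sum Σ γ^k for k = 0 … T−1, which equals (1 − γ^T)/(1 − γ) and can never exceed 1/(1 − γ), often called the effective horizon. The sketch below compares the two γ values used in this tutorial (0.995 for training, 0.99 as the default) for a 500-step episode, the CartPole-v1 time limit; `discounted_return` is a hypothetical helper.

```python
def discounted_return(rewards, gamma):
    """Sum of gamma**k * r_k over one episode, computed backwards."""
    total = 0.0
    for r in reversed(rewards):
        total = r + gamma * total   # G_t = r_t + gamma * G_{t+1}
    return total

episode = [1.0] * 500   # a full-length CartPole-v1 episode, reward 1 per step
for gamma in (0.99, 0.995):
    g = discounted_return(episode, gamma)
    closed_form = (1 - gamma ** 500) / (1 - gamma)
    horizon = 1 / (1 - gamma)
    print(f"gamma={gamma}: return={g:.1f}, closed form={closed_form:.1f}, horizon={horizon:.0f}")
```

The returns come out at roughly 99 and 184, so raising γ from 0.99 to 0.995 almost doubles the scale of the Q-values the network has to learn.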

Did it learn? Do you see the test rewards increasing?

If so, let's save the model:

In [None]:
import os
import torch

# Save the model's state dictionary for future use.
model_path = os.path.join(models_dir, f"{sarsa_td_dir}.pth")
torch.save(net.state_dict(), model_path)
print(f"Model saved to {model_path}")

### 7. Load the Trained Model for Evaluation

After saving the model, load it into a new network instance and build a policy for evaluation. Note that we set epsilon to 0 for pure exploitation during testing.

In [None]:
import os
import torch
import gymnasium as gym
from tianshou.data import Batch  # Ensure Batch is available

# Select the appropriate device: CUDA (NVIDIA GPUs), MPS (Apple GPUs), or CPU.
# AMD GPUs with ROCm support are accessed using the 'cuda' device string.
device = torch.device("cuda" if torch.cuda.is_available() else
                      "mps" if torch.backends.mps.is_available() else
                      "cpu")
print(f"Using device: {device}")

# Initialize a new network with the same architecture and move it to the device.
loaded_net = QNet(state_shape, action_shape).to(device)

# Load the trained weights, mapping to the appropriate device.
model_path = os.path.join(models_dir, f"{sarsa_td_dir}.pth")
loaded_net.load_state_dict(torch.load(model_path, map_location=device))
loaded_net.eval()  # Set the network to evaluation mode

# Create a testing environment with rendering enabled.
eval_env = gym.make("CartPole-v1", render_mode="human")

# Build a new SARSA policy for evaluation using the loaded network.
# Set epsilon=0 to disable exploration.
loaded_policy = SARSAPolicy(
    model=loaded_net,
    optim=optimizer,
    action_space=action_space,
    gamma=0.99,
    epsilon=0.0
)
print("Model loaded and evaluation environment initialized.")

### 8. Run the Agent in the Evaluation Environment

Finally, let's test out the model and watch it play the game! The agent will use the loaded model to select actions and interact with the environment.

In [None]:
import numpy as np
from tianshou.data import Batch  # Ensure Batch is available

# Number of episodes to run for evaluation.
num_episodes = 20
episode_rewards = []
episode_lengths = []

for episode in range(num_episodes):
    obs, _ = eval_env.reset()
    total_reward = 0
    step_count = 0

    print(f"Starting episode {episode + 1}")

    while True:
        step_count += 1

        # Create a Batch for the current observation.
        obs_batch = Batch(obs=[obs])
        # Disable gradient computations during inference.
        with torch.no_grad():
            action = loaded_policy.forward(obs_batch).act[0]

        # Step the environment with the chosen action.
        obs, reward, terminated, truncated, _ = eval_env.step(action)
        total_reward += reward

        # End the episode if terminated or truncated.
        if terminated or truncated:
            print(
                f"Episode {episode + 1} ended: Total Reward = {total_reward}, "
                f"Steps = {step_count}"
            )
            break

    episode_rewards.append(total_reward)
    episode_lengths.append(step_count)

# Convert lists to numpy arrays for statistical operations.
episode_rewards = np.array(episode_rewards)
episode_lengths = np.array(episode_lengths)

# Compute and print performance statistics if data has been collected.
if episode_rewards.size > 0 and episode_lengths.size > 0:
    # Calculate statistics for episode rewards.
    count_rewards   = len(episode_rewards)
    mean_rewards    = np.mean(episode_rewards)
    std_rewards     = np.std(episode_rewards)
    min_rewards     = np.min(episode_rewards)
    p25_rewards     = np.percentile(episode_rewards, 25)
    median_rewards  = np.median(episode_rewards)
    p75_rewards     = np.percentile(episode_rewards, 75)
    max_rewards     = np.max(episode_rewards)
    
    # Calculate statistics for episode lengths.
    count_lengths   = len(episode_lengths)
    mean_lengths    = np.mean(episode_lengths)
    std_lengths     = np.std(episode_lengths)
    min_lengths     = np.min(episode_lengths)
    p25_lengths     = np.percentile(episode_lengths, 25)
    median_lengths  = np.median(episode_lengths)
    p75_lengths     = np.percentile(episode_lengths, 75)
    max_lengths     = np.max(episode_lengths)
    
    # Print the summary table.
    print("\nFinal Evaluation Performance Summary:")
    print(f"Total Episodes Evaluated: {num_episodes}\n")
    header = "{:<22} {:>15} {:>20}".format("Statistic", "Rewards", "Episode Lengths")
    print(header)
    print("-" * len(header))
    print("{:<22} {:>15d} {:>20d}".format("Count", count_rewards, count_lengths))
    print("{:<22} {:>15.2f} {:>20.2f}".format("Mean", mean_rewards, mean_lengths))
    print("{:<22} {:>15.2f} {:>20.2f}".format("Std Dev", std_rewards, std_lengths))
    print("{:<22} {:>15.2f} {:>20.2f}".format("Min", min_rewards, min_lengths))
    print("{:<22} {:>15.2f} {:>20.2f}".format("25th Percentile", p25_rewards, p25_lengths))
    print("{:<22} {:>15.2f} {:>20.2f}".format("Median", median_rewards, median_lengths))
    print("{:<22} {:>15.2f} {:>20.2f}".format("75th Percentile", p75_rewards, p75_lengths))
    print("{:<22} {:>15.2f} {:>20.2f}".format("Max", max_rewards, max_lengths))
else:
    print("No performance data was collected. Please verify the Collector configuration.")

# Close the environment once evaluation is complete.
eval_env.close()
print("Evaluation completed and environment closed.")

***
**Things to Try**

Experiment with different hyperparameters and settings to see how they affect learning:

- **Epsilon:**  
  - How often should the agent explore instead of exploiting what it already knows?
  - *High value:* The agent explores more, but because it acts randomly so often it cannot take advantage of what it has already learned.
  - *Low value:* The agent exploits learned behavior but may lock onto the first rewarding behavior it finds instead of discovering better actions.

- **Learning Rate:**  
  - How strongly should each batch of new experience change the Q-value estimates?
  - *High learning rate:* Larger updates can overshoot and make training unstable, because each noisy batch of transitions moves the Q-values a lot.
  - *Low learning rate:* Updates are more stable, but learning is slow and the agent may need many epochs to work out which actions lead to good rewards.

- **Discount Factor (γ):**  
  - How much does the agent value future rewards relative to immediate ones?
  - *High gamma:* Emphasizes future rewards; value estimates become larger and take longer to propagate, which can slow learning.
  - *Low gamma:* Focuses on short-term rewards and may miss long-term benefits.

Adjust these parameters based on the environment and your training goals. Visualize the learning progress with TensorBoard to understand the effects of your changes.
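If you want to try these knobs without waiting on a neural network, the same SARSA update can be run in tabular form. The sketch below is a swapped-in toy, not part of the tutorial's CartPole setup: a hypothetical five-state corridor where the agent starts at the left end and receives a reward of 1 for reaching the right end. `epsilon`, `alpha` (the learning rate), and `gamma` play the same roles as above, and the default values are illustrative assumptions.

```python
import numpy as np

N_STATES, GOAL = 5, 4          # states 0..4; reaching state 4 ends the episode
LEFT, RIGHT = 0, 1

def step(state, action):
    """Deterministic corridor: move one cell (a wall blocks at 0); reward 1 only at the goal."""
    next_state = max(0, state - 1) if action == LEFT else state + 1
    reward = 1.0 if next_state == GOAL else 0.0
    return next_state, reward, next_state == GOAL

def epsilon_greedy(q_row, epsilon, rng):
    """Random action with probability epsilon, otherwise greedy with random tie-breaking."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_row)))
    best_actions = np.flatnonzero(q_row == q_row.max())
    return int(rng.choice(best_actions))

def train_tabular_sarsa(episodes=1000, alpha=0.1, gamma=0.9, epsilon=0.1, max_steps=100, seed=0):
    """Tabular SARSA: Q(s,a) += alpha * (r + gamma * (1 - d) * Q(s',a') - Q(s,a))."""
    rng = np.random.default_rng(seed)
    q = np.zeros((N_STATES, 2))
    for _ in range(episodes):
        s = 0
        a = epsilon_greedy(q[s], epsilon, rng)
        for _ in range(max_steps):
            s_next, r, terminated = step(s, a)
            a_next = epsilon_greedy(q[s_next], epsilon, rng)  # a' is the action actually taken next
            target = r + gamma * (1.0 - terminated) * q[s_next, a_next]
            q[s, a] += alpha * (target - q[s, a])             # alpha plays the role of the learning rate
            s, a = s_next, a_next
            if terminated:
                break
    return q

q = train_tabular_sarsa()
greedy = q[:GOAL].argmax(axis=1)   # greedy action in each non-terminal state
print(greedy)                       # after training this should be RIGHT (1) in every state
print(np.round(q, 3))
```

Try, for example, `train_tabular_sarsa(epsilon=0.5)` or `train_tabular_sarsa(gamma=0.5)` and compare the printed Q-tables: with a lower γ the values far from the goal shrink quickly, which is the short-sightedness described above.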