# ABE tutorial 2
## Looking at value-based RL

In this second tutorial, let's dive deeper into the RL algorithm we used in the first tutorial. This time we'll open up the policy and see in more detail how it learns.

Steps:
* Create a custom policy
* Check how it updates based on temporal difference learning



## Custom policy



### 1. Set up a simple RL example

Let's build a simple RL example based on what we learned in the last tutorial.


In [None]:
#import libraries
import gymnasium as gym
import torch
from torch.utils.tensorboard import SummaryWriter
import tianshou as ts
from tianshou.utils.net.common import Net

#start a logger
logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn_custom'))

Now let's use the CartPole example as the environment to train our agent in, and extract the shape of the observations (state_shape) and of the action space (action_shape).

In [None]:
from tianshou.utils.space_info import SpaceInfo

# Create an environment (we leave out render_mode="human" here, since drawing every frame would slow training down)
env = gym.make("CartPole-v1")

#reset the environment to its starting state
env.reset()

#get all the info about it
space_info = SpaceInfo.from_env(env)

#What the agent 'sees'
state_shape = space_info.observation_info.obs_shape

#what actions the agent can take
action_shape = space_info.action_info.action_shape

In [None]:
print(space_info)

Let's start building our agent's brain. 

To start off, let's build a neural network that takes what the agent observes and converts it into an estimate of how valuable each action is (its Q-values). Then we'll add an optimizer to shift the weights in the network so it better predicts those values.

In [None]:


#build a network that takes observations and converts them into action values (Q-values)
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[64, 64])

#this will shift the network to better predict actions/values
optim = torch.optim.Adam(net.parameters(), lr=0.001)

Now that we have a network and an optimizer let's define a policy that will control how learning takes place.

Let's use a pre-built policy first (we'll open it up and take a look inside... but first let's get this all working!)

In [None]:
policy = ts.policy.DQNPolicy(
    model=net,
    optim=optim,
    discount_factor=0.9,
    action_space=env.action_space,
    estimation_step=3,
    target_update_freq=320
)

Now let's set up a collector to feed observations to the policy as the agent interacts with its environment, and to store the resulting experience in a replay buffer.

> We'll add a test collector that will run tests periodically to see how well our agent is performing.

In [None]:
train_collector = ts.data.Collector(policy, env, ts.data.VectorReplayBuffer(20000, 1), exploration_noise=True)
test_collector = ts.data.Collector(policy, env, exploration_noise=True)  # exploration_noise=True applies DQN's epsilon-greedy noise: usually the best action, a random one with probability epsilon
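`VectorReplayBuffer(20000, 1)` above keeps up to 20,000 transitions for one environment; once it is full, the oldest transitions are overwritten. The sketch below is a minimal, hypothetical circular buffer in plain Python that shows the idea (the names `Transition` and `TinyReplayBuffer` are made up for illustration; Tianshou's buffers store more fields and support many environments).

```python
import random
from collections import namedtuple

# One transition: what the agent saw, did, received, and saw next
Transition = namedtuple("Transition", ["obs", "act", "rew", "done", "obs_next"])


class TinyReplayBuffer:
    """Hypothetical fixed-size circular buffer (illustration only)."""

    def __init__(self, size):
        self.size = size
        self.data = []
        self.next_idx = 0  # where the next transition will be written

    def add(self, transition):
        if len(self.data) < self.size:
            self.data.append(transition)
        else:
            # buffer is full: overwrite the oldest transition
            self.data[self.next_idx] = transition
        self.next_idx = (self.next_idx + 1) % self.size

    def sample(self, batch_size):
        # off-policy learners train on random past transitions, not only the latest ones
        return random.sample(self.data, batch_size)

    def __len__(self):
        return len(self.data)


buf = TinyReplayBuffer(size=3)
for t in range(5):
    buf.add(Transition(obs=t, act=0, rew=1.0, done=False, obs_next=t + 1))
print(len(buf), sorted(tr.obs for tr in buf.data))  # 3 [2, 3, 4]
```

After five additions to a buffer of size three, only the three most recent transitions remain.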

Now that we have:

1. An environment
2. A Policy with a network model and an optimizer
3. A collector to store the agent experiences

We can now train our agent!

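We'll use an off-policy trainer (OffpolicyTrainer) for now. DQN is an off-policy algorithm: it learns from batches of past experience sampled from the replay buffer, which may have been collected by an older version of the policy.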

In [None]:
result = ts.trainer.OffpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=10,
    step_per_epoch=10000,
    step_per_collect=30,
    episode_per_test=100,
    batch_size=64,
    update_per_step=1 / 10,
    train_fn=lambda epoch, env_step: policy.set_eps(0.1),
    test_fn=lambda epoch, env_step: policy.set_eps(0.05),
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
    logger=logger,
).run()
print(f"Finished training in {result.timing.total_time} seconds")

The code above can take a while to run! To see the training in progress, you can launch TensorBoard.

If you are using VS Code, you can open the Command Palette and type:

```launch tensorboard```

Or, from a terminal, run `tensorboard --logdir log`.
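To get a feel for the trainer settings used above, here is a rough count of the work done per epoch. This is a back-of-the-envelope sketch, assuming the trainer collects `step_per_collect` environment steps at a time until it reaches `step_per_epoch`, and after each collect runs `round(update_per_step * steps_collected)` gradient updates on minibatches of `batch_size` transitions; check the Tianshou trainer documentation for the exact bookkeeping.

```python
import math

# Values from the OffpolicyTrainer call above
step_per_epoch = 10000
step_per_collect = 30
update_per_step = 1 / 10
batch_size = 64

# Rough bookkeeping (assumption: collect until step_per_epoch env steps are reached)
collects_per_epoch = math.ceil(step_per_epoch / step_per_collect)
updates_per_collect = round(update_per_step * step_per_collect)
updates_per_epoch = collects_per_epoch * updates_per_collect
samples_per_epoch = updates_per_epoch * batch_size

print(collects_per_epoch, updates_per_collect, updates_per_epoch, samples_per_epoch)
```

So each epoch sees roughly 10,000 new environment steps but only about 1,000 gradient updates; raising `update_per_step` makes the agent reuse its replay buffer more heavily.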

### Custom Policy


Here we'll build a custom policy to better understand how TD-learning is working and how our agents are learning!

We'll also move from the off-policy procedure above, where the agent learns from batches of experience sampled from a large replay buffer, to an online, on-policy procedure, where the agent learns continuously from its most recent experience as it interacts with the environment.

To do this, the agent will collect states, actions, and rewards as it interacts with the environment, and then use this information to learn!

To create a custom policy we need to define how the agent will learn. To do this we'll create a new Python class called SARSAPolicy. It will have three pieces:

* *`__init__`*: an initialization method
* *forward*: a method that takes observations and chooses actions, i.e., what action should the agent take?
* *learn*: a method that lets the agent learn the value of state-action pairs, i.e., update the model that the forward method uses

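```python
from tianshou.data import Batch
from tianshou.policy import BasePolicy

# Custom SARSA policy class (skeleton: each method is filled in below)
class SARSAPolicy(BasePolicy):
    def __init__(self, model, optim, action_space, gamma=0.99, epsilon=0.1):
        #initialize the policy, i.e., set some parameter values
        ...

    def forward(self, batch, state=None, **kwargs):
        #use the model to predict action values, then choose an action
        act = ...  # filled in below
        return Batch(act=act)

    def learn(self, batch, next_action, **kwargs):
        # Estimate the value of a state-action pair with TD-learning and update the model
        ...
```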


#### Initialize the new policy

Let's look at the initialization of the policy first.

When we initialize the policy we'll provide information about:
* the *model*, which will be used to choose actions given states.
* the *optim* (optimizer), which will be used to update the model, i.e., it's how the agent will learn.
* the *action_space*, which tells the base class which actions the agent can take.
* *gamma*, which defines how much the agent values future rewards vs. immediate rewards.
* *epsilon*, which defines the probability that the agent will select an action at random.

```python
    def __init__(self, model, optim, action_space, gamma=0.99, epsilon=0.1):
        super().__init__(action_space=action_space)
        self.model = model
        self.optim = optim
        self.gamma = gamma
        self.epsilon = epsilon
```
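To see what gamma does, here is a small worked example in plain Python (the helper `discounted_return` is just for illustration). In CartPole the agent gets a reward of +1 for every step it keeps the pole up, so with gamma = 0.99 the value of balancing forever approaches 1 / (1 - 0.99) = 100; a lower gamma shrinks how far ahead the agent effectively looks.

```python
def discounted_return(rewards, gamma):
    """Sum of rewards, each discounted by gamma once per step into the future."""
    total = 0.0
    for k, r in enumerate(rewards):
        total += (gamma ** k) * r
    return total


# three steps of +1 reward
print(discounted_return([1, 1, 1], gamma=0.9))   # 1 + 0.9 + 0.81, about 2.71
print(discounted_return([1, 1, 1], gamma=0.99))  # 1 + 0.99 + 0.9801, about 2.97

# a very long CartPole-style episode: the value approaches 1 / (1 - gamma)
print(discounted_return([1] * 2000, gamma=0.99))  # very close to 100
```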

#### Forward

We now add the forward method. It passes the observations through the model to estimate the value of each action, Q(s,a), and uses those values to choose which action to take.


```python
    def forward(self, batch, state=None, **kwargs):
        
        #enter the observations into the neural network model to estimate the value of actions
        q_values, _ = self.model(batch.obs)

        #choose the action with the highest value
        act = q_values.argmax(dim=1).cpu().numpy()  # Greedy action

        #with probability epsilon, choose a random action instead
        if np.random.rand() < self.epsilon:  # Epsilon-greedy exploration
            act = np.random.randint(0, q_values.shape[1], size=act.shape)
        
        #return the chosen action
        return Batch(act=act)
```
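One detail in the forward method above: it flips a single coin for the whole batch, so when several observations arrive at once either all actions are random or none are. That's fine for the single environment used here; with several environments in parallel, a per-observation version may be closer to what you want. Below is a minimal NumPy sketch of that variant (the function name `epsilon_greedy` is hypothetical).

```python
import numpy as np


def epsilon_greedy(q_values, epsilon, rng):
    """Pick one action per row of q_values (shape [batch, n_actions])."""
    greedy = q_values.argmax(axis=1)                                     # best action per observation
    random_actions = rng.integers(0, q_values.shape[1], size=len(q_values))
    explore = rng.random(len(q_values)) < epsilon                        # independent coin flip per row
    return np.where(explore, random_actions, greedy)


rng = np.random.default_rng(0)
q = np.array([[0.1, 0.9], [0.7, 0.2], [0.3, 0.4]])
print(epsilon_greedy(q, epsilon=0.0, rng=rng))  # [1 0 1]: always greedy
print(epsilon_greedy(q, epsilon=1.0, rng=rng))  # every action drawn at random
```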

#### Learn

Now that we've initialized our policy with a model, an optimizer, and some parameters, and set up how the agent chooses actions using the model, we need to think about how the agent will learn. This is an important step. We'll use TD-learning to update the model so that it can predict the value of actions. It will do this in a few steps:

* Estimate the current state-action values, i.e., Q(s,a) or q_values.
* Estimate the TD target: the reward received plus the discounted value of the next state-action pair, r + gamma * Q(s',a'), where a' is the action the policy chooses in the next state (or just r if the episode ended). Using the policy's own next action is what makes this SARSA.
* Use the difference between the current q_value and the TD target to update the neural network.
* If we repeat this many times, the network should make better estimates of the q_values of state-action pairs.
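Here is the TD target worked through with made-up numbers, as a plain-Python sketch (the function `sarsa_td_target` and all values are hypothetical). Note the `(1 - done)` term: when the episode has ended there is no next state, so the target is just the reward.

```python
def sarsa_td_target(reward, next_q, gamma, done):
    """TD target r + gamma * Q(s', a'), with no bootstrapping after the episode ends."""
    return reward + gamma * (1 - done) * next_q


q_current = 0.5   # Q(s, a): the network's current estimate
next_q = 0.8      # Q(s', a'): value of the action the policy picked in the next state
gamma = 0.99

target = sarsa_td_target(reward=1.0, next_q=next_q, gamma=gamma, done=0)
td_error = target - q_current
print(round(target, 3), round(td_error, 3))  # 1.792 1.292

# At the end of an episode, the target is just the reward
print(sarsa_td_target(reward=1.0, next_q=next_q, gamma=gamma, done=1))  # 1.0
```

The `learn` method computes the same target for a whole batch at once and then minimizes the squared TD error (the MSE loss).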

```python
    def learn(self, batch, next_action, **kwargs):

        #get the current estimate of state-action values
        q_values, _ = self.model(batch.obs)

        #keep only the q_value of the action actually taken in each transition (this gives a 1d tensor)
        batch_act = torch.tensor(batch.act, dtype=torch.long) if not isinstance(batch.act, torch.Tensor) else batch.act
        q_values = q_values.gather(1, batch_act.unsqueeze(1)).squeeze(1)

        # Estimate the TD target: the reward actually received plus the discounted value of the next state-action pair
        with torch.no_grad(): #keeps these calculations out of the gradient computation, so the target is treated as fixed

            #predict the value of the action the policy chose in the next state
            next_q_values, _ = self.model(batch.obs_next)
            next_q_values = next_q_values.gather(1, next_action.unsqueeze(1)).squeeze(1)

            #the TD target is the reward observed plus the discounted q_value in the next state (zeroed if the episode ended)
            td_target = batch.rew + self.gamma * (1 - batch.done) * next_q_values

        #now that we know the TD target, how close did our estimate of q_values get to it?
        #with TD learning the value of a state is tied to the value of the next state.

        # calculate the difference between the current state-action pairing (q_value) and the TD target (based on the next state-action pair)
        loss = nn.functional.mse_loss(q_values, td_target) #using MSE to measure the difference

        #reset the gradients left over from the previous update
        self.optim.zero_grad()

        #based on the loss, use backpropagation to compute how each weight should change to make better predictions
        loss.backward()

        #limit the size of the gradients so a single update can't change the weights too much
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # Gradient clipping
        
        #Adjust the weights in the neural network
        self.optim.step()

        return {"loss": loss.item()}
```
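The `gather(1, batch_act.unsqueeze(1)).squeeze(1)` line can look cryptic: for each row it picks the Q-value of the action that was actually taken. Here is the same operation written with NumPy's `take_along_axis` on made-up numbers, followed by the MSE loss computed by hand (a sketch for intuition, not the tutorial's code).

```python
import numpy as np

# made-up Q-values for a batch of 3 observations and 2 actions
q_values = np.array([[1.0, 2.0],
                     [3.0, 4.0],
                     [5.0, 6.0]])
actions = np.array([1, 0, 1])  # the action taken in each transition

# equivalent of q_values.gather(1, actions.unsqueeze(1)).squeeze(1) in PyTorch
q_taken = np.take_along_axis(q_values, actions[:, None], axis=1).squeeze(1)
print(q_taken)  # [2. 3. 6.]

# mean squared error between the chosen Q-values and (made-up) TD targets
td_target = np.array([2.5, 3.0, 5.0])
mse = np.mean((q_taken - td_target) ** 2)
print(mse)  # (0.25 + 0 + 1) / 3
```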


Now that we've defined the policy we have to define a neural network (i.e., the model) that the policy can use!

### Neural Network Model

Let's look at how to create an agent brain. To do this we'll create a new class called QNet. It will have two methods:

* *`__init__`*: this will initialize the network model
* *forward*: this will take the state as input and output the estimated value of each action (the q_values), which the policy then uses to choose an action.

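```python
import torch.nn as nn

class QNet(nn.Module):

    def __init__(self, state_shape, action_shape, hidden_size=128):
        super().__init__()
        #initialize the neural network model

    def forward(self, obs, state=None, info=None):

        #take observations (i.e., state) and make predictions about the value of actions (i.e., q_values)
        q_values = ...  # filled in below

        return q_values, state
```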

#### Initializing


Let's take a look at the initialization first. When we initialize the agent's brain we'll provide information about:
* the state_shape, the size of the observation space, i.e., how many variables act as inputs for the agent?
* the action_shape, the number of actions the agent can take.
* hidden_size, which defines the number of nodes in each hidden layer of the neural network.

```python
    #define the information needed to initialize the network
    def __init__(self, state_shape, action_shape, hidden_size=128):
        super().__init__()
```

The super().__init__() call initializes the underlying PyTorch nn.Module. Once this is done, let's add a neural network that the agent can use to predict the best action to take in the current context. Note: here we are only using the current state to make predictions about actions; we'll see later on how to add a trajectory of states from the recent past to help our agent make better decisions in environments that are dynamic and stochastic.

```python
        #build a neural network model
        self.net = nn.Sequential(
            nn.Linear(np.prod(state_shape), hidden_size), #fully connected layer, also known as a dense layer
            nn.ReLU(),                          #rectified linear unit activation function
            nn.LayerNorm(hidden_size),          #normalizes the layer's outputs (keeps activations on a stable scale)
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, np.prod(action_shape))
        )
```

We'll cover how to build your own neural networks a little further along in the tutorials, but for now it's enough to know that we are building a neural network with 3 linear layers, ReLU activation functions between them, and layer normalization to keep the activations on a stable scale. The input is the observed variables and the output is one estimated value (q_value) per action; the policy then picks the action with the highest value.
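A note on `nn.LayerNorm`: it normalizes the *outputs* of a layer (its activations), not the weights. For each input vector it subtracts the mean and divides by the standard deviation across features, then applies a learnable scale and shift (initialized to 1 and 0). The NumPy sketch below reproduces that computation at initialization; `eps=1e-5` matches PyTorch's default.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Normalize each row of x to zero mean and unit variance across its features."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance, as LayerNorm uses
    return (x - mean) / np.sqrt(var + eps)


activations = np.array([[1.0, 2.0, 3.0, 4.0],
                        [100.0, 200.0, 300.0, 400.0]])
normed = layer_norm(activations)
print(normed.round(3))
```

Both rows come out as roughly [-1.342, -0.447, 0.447, 1.342]: the large values in the second row are brought back to the same scale as the first.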


#### Forward

Next let's look at how to build the *forward* method to predict actions:

```python
    def forward(self, obs, state=None, info=None):

        #convert observations to a tensor (PyTorch's array type, which can run on a GPU and track gradients)
        if isinstance(obs, np.ndarray):
            obs = torch.tensor(obs, dtype=torch.float32)

        #use the neural network to predict the value of actions (i.e., q_values)
        q_values = self.net(obs)

        #return the estimated value of actions and the state
        return q_values, state
```
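With QNet complete, it's worth checking how large it is. The plain-Python sketch below (the helper names are hypothetical) counts trainable parameters layer by layer: a `Linear(n_in, n_out)` layer has `n_in * n_out` weights plus `n_out` biases, and a `LayerNorm(n)` has `n` scales plus `n` shifts. With a model built, `sum(p.numel() for p in net.parameters())` should give the same number.

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias vector


def layernorm_params(n):
    return 2 * n  # learnable scale and shift per feature


def qnet_param_count(n_obs, n_actions, hidden_size=128):
    """Parameters in QNet's Linear-ReLU-LayerNorm (x2) + Linear stack."""
    return (linear_params(n_obs, hidden_size)
            + layernorm_params(hidden_size)
            + linear_params(hidden_size, hidden_size)
            + layernorm_params(hidden_size)
            + linear_params(hidden_size, n_actions))


# CartPole: 4 observed variables, 2 actions
print(qnet_param_count(4, 2))                  # 17922
print(qnet_param_count(4, 2, hidden_size=64))  # 4866
```

Most of the parameters sit in the middle hidden_size x hidden_size layer, so doubling hidden_size roughly quadruples the network size.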

### Full code

Now that we've gone through the code piece by piece, let's take a look at the full code.

In [None]:
import gymnasium as gym
import torch
import torch.nn as nn
import numpy as np
from tianshou.data import Batch, ReplayBuffer, Collector
from tianshou.policy import BasePolicy

class QNet(nn.Module):
    def __init__(self, state_shape, action_shape, hidden_size=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(np.prod(state_shape), hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, np.prod(action_shape))
        )

    def forward(self, obs, state=None, info=None):
        if isinstance(obs, np.ndarray):
            obs = torch.tensor(obs, dtype=torch.float32)
        q_values = self.net(obs)
        return q_values, state

# Custom SARSA policy class
class SARSAPolicy(BasePolicy):
    def __init__(self, model, optim, action_space, gamma=0.99, epsilon=0.1):
        super().__init__(action_space=action_space)
        self.model = model
        self.optim = optim
        self.gamma = gamma
        self.epsilon = epsilon

    def forward(self, batch, state=None, **kwargs):
        q_values, _ = self.model(batch.obs)
        act = q_values.argmax(dim=1).cpu().numpy()  # Greedy action
        if np.random.rand() < self.epsilon:  # Epsilon-greedy exploration
            act = np.random.randint(0, q_values.shape[1], size=act.shape)
        return Batch(act=act)

    def learn(self, batch, next_action, **kwargs):

        #use the model to get the current q_values
        q_values, _ = self.model(batch.obs)
        batch_act = torch.tensor(batch.act, dtype=torch.long) if not isinstance(batch.act, torch.Tensor) else batch.act
        q_values = q_values.gather(1, batch_act.unsqueeze(1)).squeeze(1)

        # Estimate TD target
        with torch.no_grad():
            next_q_values, _ = self.model(batch.obs_next)
            next_q_values = next_q_values.gather(1, next_action.unsqueeze(1)).squeeze(1)
            td_target = batch.rew + self.gamma * (1 - batch.done) * next_q_values

        # Use the difference between q_value and TD-target to update the model
        loss = nn.functional.mse_loss(q_values, td_target)
        self.optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)  # Gradient clipping
        self.optim.step()

        return {"loss": loss.item()}



Let's create an environment.

In [None]:
# Create an environment
env = gym.make("CartPole-v1")
test_env = gym.make("CartPole-v1")

#get the environment information
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.n if hasattr(env.action_space, 'n') else env.action_space.shape
action_space = env.action_space

Let's now build our network model and optimizer, and place both into our custom policy.

In [None]:
# Setting up the Q-network and SARSA policy
net = QNet(state_shape, action_shape)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-5, weight_decay=1e-4)
policy = SARSAPolicy(model=net, optim=optimizer, action_space=action_space, gamma=0.99, epsilon=0.1)

Now that we have all the pieces in place, let's train our agent. This time let's take an online approach and get our agent to learn on the fly! To do this we'll write a custom training loop.

In [None]:

# Custom training loop
max_epoch = 10
step_per_epoch = 5000
keep_n_steps = 30
buffer = ReplayBuffer(size=keep_n_steps)

# Set up collectors to store data
train_collector = Collector(policy, env, buffer)
test_collector = Collector(policy, test_env)

#start a logger
logger_sarsa = ts.utils.TensorboardLogger(SummaryWriter('log/sarsa_custom'))


for epoch in range(max_epoch):

    #for each epoch reset the collector
    train_collector.reset()

    for step in range(step_per_epoch):

        # Collect keep_n_steps new transitions (the small buffer only ever holds the most recent ones)
        train_collector.collect(n_step=keep_n_steps)

        # Take the most recent transitions from the buffer
        batch = train_collector.buffer[-keep_n_steps:]

        # Convert each field we need to a torch tensor with an explicit dtype
        batch.obs = torch.tensor(batch.obs, dtype=torch.float32)
        batch.act = torch.tensor(batch.act, dtype=torch.long)
        batch.rew = torch.tensor(batch.rew, dtype=torch.float32)
        batch.done = torch.tensor(batch.done, dtype=torch.float32)
        batch.obs_next = torch.tensor(batch.obs_next, dtype=torch.float32)

        # Choose the next action with the same epsilon-greedy policy, using `obs_next`
        next_action = policy.forward(Batch(obs=batch.obs_next)).act

        # Perform a SARSA update and keep the loss for reporting
        loss = policy.learn(batch, torch.tensor(next_action))["loss"]

    # Testing and evaluation (report the last training loss instead of running an extra update)
    result = test_collector.collect(n_episode=10, reset_before_collect=True)
    print(f'Epoch #{epoch + 1}: reward = {result.returns.mean()}, loss = {loss}')

    # Log the average reward for the epoch to the SARSA logger
    logger_sarsa.writer.add_scalar("Reward/test_avg", result.returns.mean(), epoch)
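The loop above picks `next_action` with the same epsilon-greedy policy that acts in the environment, and uses that action's value in the TD target. That is what makes this SARSA, an on-policy method. DQN, used earlier in this tutorial, is based on Q-learning, which instead uses the greedy (best-looking) next action regardless of what the agent will actually do. The small sketch below uses made-up numbers (both helper names are hypothetical) to show how the two targets differ when exploration picks a non-greedy action.

```python
import numpy as np

gamma = 0.99
reward = 1.0
next_q = np.array([0.2, 0.9])  # Q(s', a) for the two actions in the next state
explored_action = 0            # suppose epsilon-greedy explored and chose action 0


def sarsa_target(reward, next_q, next_action, gamma):
    # on-policy: use the value of the action the policy actually chose
    return reward + gamma * next_q[next_action]


def q_learning_target(reward, next_q, gamma):
    # off-policy: use the value of the best action, whatever the policy does next
    return reward + gamma * next_q.max()


print(sarsa_target(reward, next_q, explored_action, gamma))  # 1 + 0.99 * 0.2
print(q_learning_target(reward, next_q, gamma))              # 1 + 0.99 * 0.9
```

Because SARSA's targets include the effects of its own exploration, it tends to learn more cautious behavior than Q-learning when exploring is risky.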

Did it learn? Do you see rewards increasing? 

If so, let's save the model:

In [None]:
import os; os.makedirs("models", exist_ok=True)  # create the output folder if it doesn't exist yet
torch.save(net.state_dict(), "models/sarsa_cartpole_model.pth")

### Test the model

Let's test out the model and watch what it learned.

Load in the trained model.

In [None]:
# Initialize a new network with the same architecture
loaded_net = QNet(state_shape, action_shape)
loaded_net.load_state_dict(torch.load("models/sarsa_cartpole_model.pth"))

Let's create an environment and build a policy based on our saved model.

In [None]:
# Create the environment for evaluation with rendering enabled
eval_env = gym.make("CartPole-v1", render_mode="human")

# Set the loaded network as the model for a new SARSA policy
loaded_policy = SARSAPolicy(model=loaded_net, optim=optimizer, action_space=action_space, gamma=0.99, epsilon=0.0)  # Set epsilon=0 for pure exploitation


Now let's run our agent in the environment. Note: you can change the number of episodes to watch!

In [None]:

# Set the number of episodes you want to watch
num_episodes = 1

for episode in range(num_episodes):
    obs, _ = eval_env.reset()
    done = False
    total_reward = 0
    
    print(f"Starting episode {episode + 1}")

    while not done:
        # Create a batch for the current observation
        obs_batch = Batch(obs=[obs])
        
        # Get action based on loaded model's Q-values (no exploration)
        action = loaded_policy.forward(obs_batch).act[0]
        
        # Step the environment with the selected action
        obs, reward, done, truncated, _ = eval_env.step(action)
        total_reward += reward

        # Check if the episode has ended
        if done or truncated:
            print(f"Episode {episode + 1} ended with total reward: {total_reward}")
            break  # Break out of the loop to start the next episode


# Close the environment after finishing all episodes
eval_env.close()

**Things to try**

Try changing the environment or changing the hyperparameters:

* **Epsilon** (when to explore!): too high and the agent can't take advantage of what it has learned, because it acts randomly too often; too low and the agent might get stuck on the first useful thing it learns rather than finding the best actions.

* **Learning rate** (how fast to learn from new data): too high and the agent might learn spurious correlations between actions and outcomes; too low and it might take the agent forever to figure out which actions lead to good rewards.

* **Discount factor** or **gamma** (how much the agent values future vs. near rewards): too high and the agent might undervalue near rewards; too low and the agent might be too focused on the short term and miss longer-term outcomes.

Try altering some of these hyperparameters and see how that changes your agent's ability to learn! Which hyperparameters work best?
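One common variation on a fixed epsilon is to start with lots of exploration and reduce it over time. The sketch below is a hypothetical linear schedule in plain Python (the name `linear_epsilon` and its default numbers are illustrative, not tuned). In the DQN example you could call something like this inside `train_fn`, e.g. `policy.set_eps(linear_epsilon(env_step))`; in the custom SARSA loop you could set `policy.epsilon` before each collect.

```python
def linear_epsilon(step, start=1.0, end=0.05, decay_steps=50_000):
    """Linearly decay epsilon from `start` to `end` over `decay_steps`, then hold it at `end`."""
    if step >= decay_steps:
        return end
    fraction = step / decay_steps
    return start + fraction * (end - start)


for step in [0, 25_000, 50_000, 100_000]:
    print(step, round(linear_epsilon(step), 3))
```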

