# <center>Deep Reinforcement Learning - DQN with Target Network</center>

# Activity 0. Setup
    
#### Install additional dependencies

Let us first make sure that all the required dependencies are installed.

In [None]:
# Install packages with the interpreter that runs this kernel (sys.executable),
# so that the kernel can import them afterwards.
import sys
# swig is installed first: building the Box2D bindings pulled in by
# gymnasium[box2d] presumably needs it (confirm for your platform).
!{sys.executable} -m pip install swig
# The box2d extra provides the Box2D environments, such as LunarLander-v2 used below.
!{sys.executable} -m pip install gymnasium[box2d]

In [None]:
# Fetch the course repository; the RL_Support helpers imported below are
# expected under RVSS/Reinforcement_Learning.
!git clone https://github.com/rvss-australia/RVSS.git

In [None]:
import os

# Report where the kernel currently runs, then move into the RL folder of the
# cloned repository so relative paths such as 'RL_Support/...' resolve.
_rl_dir = "./RVSS/Reinforcement_Learning/"
_start_dir = os.getcwd()
print(_start_dir)
os.chdir(_rl_dir)

#### Import required dependencies

In [None]:
from RL_Support.render import *
import random

import gymnasium as gym
import math
import random
import numpy as np
from collections import namedtuple
import copy
from itertools import count
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd 


import matplotlib.pyplot as plt
from matplotlib import animation

import io
import base64
from IPython.display import HTML

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)
# Path of a saved agent file (presumably a reference agent shipped with RL_Support).
sample_agent = 'RL_Support/DQNagent_sample.pt'

def set_seed(env, seed):
    """Seed torch, NumPy and Python's ``random``, then reset ``env`` with ``seed``.

    Args:
        env: Environment exposing ``reset(seed=...)`` (e.g. a gymnasium env).
        seed (int): Seed applied to every generator and to the environment.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    env.reset(seed=seed)

# Activity 1. DQN Algorithm with Target Network

In this notebook we will extend DQN by adding a target network. This revised version of the DQN algorithm is shown below

![DQN_algorithm.png](https://i.postimg.cc/15tMBhjG/DQN-algorithm.png)

### The main changes are:

- We have extended the Agent class' attributes to include 2 DQN networks instead of one (one target and one policy network)

- We have changed the method ``get_next_q(.)`` so the q-values are computed using the target network instead of the policy network

- We have added a new method called ``transfer_parameters(.)``. This method replaces the parameters of the target network with those of the policy network

- We have modified the main loop to call ``transfer_parameters(.)`` after a predefined number of episodes

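Compared to DQN, DQN with a target network is more stable and robust. The TD targets are computed with a target network whose parameters are only refreshed periodically, so the regression target does not shift after every gradient step of the policy network. This helps with the "catastrophic forgetting" problem that you may have observed in DQN.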

## Replay Buffer

We use the same Replay Buffer implementation

**Note**: This implementation of the ReplayMemory class was taken from [***Pytorch DQN tutorial***](https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html)

In [None]:
# This tuple represents one observation (transition) in our environment
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'done'))

class ReplayMemory:
    """
    A cyclic buffer of bounded size (capacity) that holds the transitions
    observed recently.

    Once `capacity` transitions are stored, each new transition overwrites the
    oldest one; `position` is the index of the next slot to write.

    It also implements a sample() method for selecting a random
    batch of transitions for training.
    """
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition built as ``Transition(*args)``.

        Raises:
            TypeError: if ``args`` does not match the Transition fields. The
                buffer is left unchanged in that case.
        """
        # Build the record first: growing the list before construction would
        # leave a stray None in memory if Transition(*args) raised.
        transition = Transition(*args)
        if len(self.memory) < self.capacity:
            # While filling up, `position` always equals len(self.memory).
            self.memory.append(transition)
        else:
            self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Returns a minibatch of `Transition` drawn uniformly without replacement.
        Args:
            batch_size (int): Size of mini-batch (must not exceed len(self))
        Returns:
            List[Transition]: Minibatch of `Transition`
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Returns the number of transitions currently stored."""
        return len(self.memory)

## DQN Network

Let us now define the Multi Layer Perceptron network that will be used as the function approximator for the action-value function (q-function)

In [None]:
class DQN(nn.Module):
    """Multi-layer perceptron approximating the action-value function Q(s, .).

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear, held in
    ``self.layers`` (an ``nn.Sequential``).

    Args:
        num_inputs (int): Dimension of the state vector.
        num_actions (int): Number of discrete actions (one output per action).
        hidden_dim_1 (int): Width of the first hidden layer.
        hidden_dim_2 (int): Width of the second hidden layer.
    """
    def __init__(self, num_inputs=8, num_actions=4, hidden_dim_1=32, hidden_dim_2=32):
        super().__init__()

        # Keep the exact Sequential layout (Linear layers at indices 0, 2, 4) so
        # state_dict keys stay compatible with previously saved agents.
        widths = [num_inputs, hidden_dim_1, hidden_dim_2]
        modules = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        modules.append(nn.Linear(hidden_dim_2, num_actions))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Map states to per-action q-values.

        Args:
            x (torch.Tensor): States; for a batch, shape (n, num_inputs).
        Returns:
            torch.Tensor: Q-values; shape (n, num_actions) for batched input.
        """
        q_values = self.layers(x)
        return q_values

## DQN Agent

This class contains the main steps of the Deep Q-learning algorithm.

In [None]:
class DQNAgent(object):
    """DQN Agent with a target network.

    This class contains the main steps of the DQN algorithm.

    Attributes:
    action_value_net (DQN): Online (policy) network; the only network updated by ``optimize``.
    target_net (DQN): Network used by ``get_next_q`` to compute TD targets; it is
                      refreshed from ``action_value_net`` by ``transfer_parameters``.
    loss_fn (MSELoss): Criterion that measures the mean squared error (squared L2 norm)
                       between each element of the predicted and target q-values.
    optimizer (Adam): Adam optimizer over ``action_value_net`` parameters.
    gamma (torch.Tensor): Discount factor, stored as a float32 scalar tensor on ``device``.
    """

    def __init__(self, input_dim=8, output_dim=4,
                 hidden_dim_1=32, hidden_dim_2=32, gamma=0.99, lr=0.0001):
        """
        Define instance of DQNAgent
        Args:
        input_dim (int): `state` dimension.
        output_dim (int): Number of actions.
        hidden_dim_1 (int): Width of the first hidden layer.
        hidden_dim_2 (int): Width of the second hidden layer.
        gamma (float): Discount factor.
        lr (float): Adam learning rate.
        """
        self.action_value_net = DQN(input_dim, output_dim, hidden_dim_1, hidden_dim_2).to(device)

        # We add a target network. Both the policy and target networks must start with same parameters
        self.target_net = DQN(input_dim, output_dim, hidden_dim_1, hidden_dim_2).to(device)
        self.target_net.load_state_dict(self.action_value_net.state_dict())
        self.target_net.eval()

        self.loss_fn = nn.MSELoss()
        self.optimizer = optim.Adam(self.action_value_net.parameters(), lr=lr)

        self.gamma = torch.tensor(gamma).float().to(device)

    def get_action(self, state, action_space_dim, epsilon):
        """
        Select next action using an epsilon-greedy policy.
        Args:
            state (np.ndarray): Current observation (1-D).
            action_space_dim (int): Number of discrete actions.
            epsilon (float): Probability of taking a uniformly random action
                             instead of the greedy (maximum-value) one.
        Returns:
            torch.Tensor: int64 tensor of shape (1,) holding the action index, on `device`.
        """
        with torch.no_grad():
            cur_q = self.action_value_net(torch.from_numpy(state).float().to(device))
        _, greedy_action = torch.max(cur_q, dim=0)
        # Normalize both branches to a plain int before building the result,
        # instead of nesting a (possibly CUDA) tensor inside torch.tensor([...]).
        if torch.rand(1,).item() > epsilon:
            action = greedy_action.item()
        else:
            action = torch.randint(0, action_space_dim, (1,)).item()
        return torch.tensor([action], device=device)

    def get_next_q(self, state):
        """Returns the target-network Q-value of the best action at each state.
        Args:
            state (torch.Tensor): Next states, 2-D tensor of shape (n, num_inputs)
        Returns:
            torch.Tensor: Q_value, 1-D tensor of shape (n)
        """
        # Targets must not propagate gradients into the target network.
        with torch.no_grad():
            next_q = self.target_net(state)
        q, _ = torch.max(next_q, dim=1)
        return q

    def optimize(self, batch):
        """Computes the TD loss on a mini-batch and takes one optimizer step.
        Args:
            batch (Transition): Transition whose fields are tuples of per-step
                tensors, as built by ``Transition(*zip(*transitions))``: 1-D
                states, (1,) int64 actions, (1,) float rewards, (1,) bool done flags.
        Returns:
            float: loss value
        """
        state_batch = torch.stack(batch.state)                    # (B, state_dim)
        action_batch = torch.stack(batch.action)                  # (B, 1), for gather
        reward_batch = torch.stack(batch.reward).reshape(-1)      # (B,)
        next_state_batch = torch.stack(batch.next_state)          # (B, state_dim)
        done_batch = torch.stack(batch.done).reshape(-1).bool()   # (B,)

        # Mask of non-final states (a final state is one after which the
        # simulation ends, so it contributes no future value).
        non_final_mask = ~done_batch

        # Compute predicted q-values for the actions actually taken
        predicted_q = self.action_value_net(state_batch).gather(1, action_batch).reshape(-1)

        # TD target: r for terminal transitions, r + gamma * max_a' Q_target(s', a') otherwise.
        target_q = torch.zeros(len(batch.state), device=state_batch.device)
        # Guard: a batch made only of terminal transitions has no next states to
        # evaluate (torch.stack on an empty list would raise).
        if non_final_mask.any():
            target_q[non_final_mask] = self.get_next_q(next_state_batch[non_final_mask])
        expected_q = reward_batch + self.gamma * target_q

        # Compute loss (input first, target second)
        loss = self.loss_fn(predicted_q, expected_q)

        # Use loss to compute gradient and update policy parameters through backpropagation.
        # The graph is not reused, so there is no need to retain it.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def transfer_parameters(self):
        """Copy parameters from the action-value (policy) network to the target network.
        """
        self.target_net.load_state_dict(self.action_value_net.state_dict())
        self.target_net.eval()

### Running Parameters

The parameter ``freq_sync`` defines how often parameters are transferred between networks. Here, every 10 episodes we copy the parameters of the policy network over to the target network.

Note that here we are controlling the training duration by the number of episodes, instead of using the number of frames as in DQN. Epsilon is therefore decayed per episode rather than per frame, so ``epsilon_decay`` is now measured in episodes.

In [None]:
# Define running hyper-parameters and epsilon training sequence
# Feel free to change the parameters
memory_capacity = 2500
batch_size = 64
num_episodes = 2500
epsilon_start = 1.0
epsilon_end = 0.05
epsilon_decay = 400
gamma = 0.99
lr = 1e-3
hidden_dim_1 = 32
hidden_dim_2 = 32
freq_sync = 10
seed_value = 42


def epsilon_by_step(ep_idx):
    """Exponentially decay epsilon from epsilon_start towards epsilon_end.

    Args:
        ep_idx (int): Episode counter; epsilon_decay is therefore measured in episodes.
    Returns:
        float: Exploration rate for that episode.
    """
    return epsilon_end + (epsilon_start - epsilon_end) * math.exp(-1. * ep_idx / epsilon_decay)


# Plot epsilon over episodes exactly as the training loop uses it
# (it calls epsilon_by_step(i_episode + 1) for i_episode in range(num_episodes)).
fig, ax = plt.subplots(figsize=(4, 4))
ax.plot([epsilon_by_step(i_episode + 1) for i_episode in range(num_episodes)])
ax.set_xlabel("Num. episodes")
ax.set_ylabel("Epsilon")

plt.show()

### Main Loop and Replay Buffer Control

This is the main loop of our DQN implementation. Here we generate the samples added to the replay memory and train the agent using a batch sampled from the replay memory.

In [None]:
losses_list, rewards_list, episode_len_list = [], [], []
list_epsilon = []
replay_buffer = ReplayMemory(memory_capacity)

env = gym.make("LunarLander-v2")
set_seed(env, seed_value)
n_actions = env.action_space.n
dim_state = env.observation_space.shape[0]

agent = DQNAgent(input_dim=dim_state,
                 output_dim=n_actions,
                 hidden_dim_1=hidden_dim_1,
                 hidden_dim_2=hidden_dim_2,
                 gamma=gamma, lr=lr)

cur_epsilon = epsilon_start


for i_episode in tqdm(range(num_episodes)):

    state = env.reset()[0]
    is_finished, ep_len, losses, rewards = False, 0, 0, 0

    cur_epsilon = epsilon_by_step(i_episode + 1)
    list_epsilon.append(cur_epsilon)
    while not is_finished:
        ep_len += 1
        action = agent.get_action(state, n_actions, cur_epsilon)
        # gymnasium's step returns (obs, reward, terminated, truncated, info);
        # `done` here is the `terminated` flag.
        next_state, reward, done, truncated, _ = env.step(action.item())
        # The episode stops on either termination or truncation (time limit)...
        is_finished = done or truncated
        rewards += reward

        t_s = torch.tensor(state).float().to(device)
        t_r = torch.tensor([reward]).float().to(device)
        t_ns = torch.tensor(next_state).float().to(device)
        t_a = action.to(device)
        # ...but the transition is marked terminal only on real termination. A
        # truncated transition still has a future, so its TD target must
        # bootstrap from next_state instead of being cut to the reward alone.
        t_done = torch.tensor([done]).bool().to(device)

        replay_buffer.push(t_s, t_a, t_ns, t_r, t_done)
        state = next_state

        # Start learning once the buffer holds more than one mini-batch
        if len(replay_buffer) > batch_size:
            transitions = replay_buffer.sample(batch_size)
            # Transpose a list of Transitions into one Transition of tuples
            batch = Transition(*zip(*transitions))
            losses += agent.optimize(batch)

    losses_list.append(losses / ep_len)
    rewards_list.append(rewards)
    episode_len_list.append(ep_len)

    # Add rule that call transfer_parameters() every freq_sync episodes
    if i_episode % freq_sync == 0:
        agent.transfer_parameters()

    # Every 10 episodes we plot the approximator's progress and performance
    if i_episode % 10 == 0:
        plot_ep(i_episode, rewards_list, losses_list)

In [None]:
# save the trained network, this is what you will submit along with your ipynb
# Replace the file name with one containing your student ID, e.g. 'DQNTarget_12345.pt'
file_name = 'DQNTarget_Agent.pt'
# Attach the seed to the agent object so it is pickled along with the networks
agent.seed = seed_value
torch.save(agent, file_name)

In [None]:
# Load the trained agent (change file_name to evaluate a different saved agent)
file_name = file_name

# torch.save pickled the whole DQNAgent object, so it has to be unpickled with
# weights_only=False. PyTorch >= 2.6 defaults to True and would reject it; the
# argument itself needs PyTorch >= 1.13. Unpickling can execute arbitrary code,
# so only load files you trust. map_location lets an agent saved on a GPU be
# evaluated on a CPU-only machine.
agent = torch.load(file_name, map_location=device, weights_only=False)
agent.action_value_net.to(device)

# run n trials to examine how this agent performs
env = gym.make("LunarLander-v2", render_mode="rgb_array")
set_seed(env, seed_value)
n_trials = 5
list_rewards = []

wrapagent = DQNWrapper(agent)

# test on n trials
visualize_trial_k = 2
framebuffer = None  # filled only by the visualized trial
for i in range(n_trials):

    # if you want to see the animation increase max_frames (slower)
    if i == visualize_trial_k:
        framebuffer, ep_return = simulate(wrapagent, env, max_frames=700)
    else:
        _, ep_return = simulate(wrapagent, env, max_frames=1)
    env.close()
    # Store the returns
    list_rewards.append(ep_return)
    print(f"Trial No.{i}, return: {ep_return}")

# summarize the trials
print(f'\nAverage return {np.round(np.mean(list_rewards),3)} +- {np.round(np.std(list_rewards), 3)}')
# Animate only if the visualized trial actually ran (visualize_trial_k < n_trials)
if framebuffer is not None:
    html = animate(framebuffer)
    display(html)