<a href="https://colab.research.google.com/github/sradicwebster/RL_implementation/blob/master/dqn_cartpole.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DQN Cartpole implementation
The RL objective is to learn a policy that maximises return - the discounted, cumulative reward:
$$ R_{t_0} = \sum_{t=t_0}^\infty \gamma^{t-t_0} r_t $$
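
As a small illustration of the return formula, the sketch below accumulates a finite list of rewards from the last step backwards. The function name `discounted_return` and the example numbers are assumptions made for illustration and are not part of the notebook code.

```python
def discounted_return(rewards, gamma):
    """Return sum_k gamma**k * rewards[k], accumulated from the last reward backwards."""
    total = 0.0
    for reward in reversed(rewards):
        total = reward + gamma * total
    return total


# CartPole gives a reward of 1 per step; three steps with gamma = 0.5
print(discounted_return([1.0, 1.0, 1.0], gamma=0.5))  # 1 + 0.5 + 0.25 = 1.75
```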

If we had the optimal Q-function, which maps each state-action pair to its expected return ($Q^*: State \times Action \rightarrow \mathbb{R}$), then we would know which action maximises return in a given state, namely the action with the highest Q-value:
$$ \pi^*(s) = \underset{a}{\mathrm{argmax}} \space Q^*(s,a) $$
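
Choosing the action with the highest Q-value can be sketched in a few lines. The helper `greedy_action` and the Q-values below are hypothetical and only serve to show the selection rule.

```python
def greedy_action(q_values):
    """Index of the largest Q-value (ties go to the lowest index)."""
    return max(range(len(q_values)), key=lambda a: q_values[a])


q_values = [0.2, 1.3]  # hypothetical Q(s, left), Q(s, right)
print(greedy_action(q_values))  # 1
```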

We learn the Q-function by interacting with an environment, using an update equation based on the Bellman equation (the Q-value of being in state $s$ and taking action $a$ is the immediate reward plus the discounted Q-value of the subsequent state when actions follow policy $\pi$):
$$ Q_\pi(s,a) = r + \gamma Q_\pi(s^\prime,\pi(s^\prime)) $$

The right hand side of the equation is known as the target $y$ and the difference between the current Q-value estimation $Q(s,a)$ and $y$ is the temporal difference (TD) error which is minimised in the Q-learning algorithm.
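
A minimal sketch of the target and the TD error for one transition follows. It assumes a greedy policy, so that $Q(s^\prime,\pi(s^\prime))$ is the maximum over actions, and it treats a terminal transition as having no future value. The function names and numbers are illustrative assumptions.

```python
def td_target(reward, next_q_values, gamma, terminal):
    """y = r for a terminal transition, else r + gamma * max_a' Q(s', a')."""
    if terminal:
        return reward
    return reward + gamma * max(next_q_values)


def td_error(q_sa, target):
    """Difference between the target y and the current estimate Q(s, a)."""
    return target - q_sa


y = td_target(reward=1.0, next_q_values=[2.0, 4.0], gamma=0.5, terminal=False)
print(y)                 # 1 + 0.5 * 4 = 3.0
print(td_error(2.5, y))  # 3.0 - 2.5 = 0.5
```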

The Deep Q-Network (DQN) algorithm combines Q-learning with deep neural networks that approximate the Q-function. DQN uses an experience replay memory that stores the transitions the agent observes, so the data can be reused later. Batches of transitions are sampled to update the parameters of the Q-function by gradient descent, minimising the TD error (with an MSE or Huber loss). Sampling each batch at random decorrelates the transitions, which stabilises and improves DQN training.
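
To make the difference between the two losses concrete, the sketch below evaluates both on a single TD error. It uses the common Huber definition with threshold `delta` (an assumed parameter name); with `delta=1` this should match what PyTorch's `SmoothL1Loss` computes per element with its default settings.

```python
def squared_error(td):
    return td ** 2


def huber(td, delta=1.0):
    """Quadratic for |td| <= delta, linear beyond it."""
    abs_td = abs(td)
    if abs_td <= delta:
        return 0.5 * abs_td ** 2
    return delta * (abs_td - 0.5 * delta)


for td in (0.5, 10.0):
    print(td, squared_error(td), huber(td))
# 0.5 0.25 0.125
# 10.0 100.0 9.5
```

The large TD error contributes 100 to the squared error but only 9.5 to the Huber loss, which is why the Huber loss is less sensitive to outlier transitions.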

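Overestimated Q-values introduce a large positive bias into the updates, and bootstrapping from the network that is being trained means the target shifts with every update. In practice a second DNN therefore estimates the target Q-value to stabilise learning. (This two-network setup is part of standard DQN; Double DQN is a separate extension in which the policy network selects the next action and the target network evaluates it.) The target network is held fixed between updates and is used to calculate the expected return $ \hat{Q}(s^\prime,a^\prime;\theta^\prime) $. The TD error is now: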

$$ TD = r + \gamma \space \max_{a^\prime} \hat{Q}(s^\prime,a^\prime;\theta^\prime) - Q(s,a;\theta) $$

The action-value (or policy) neural network $ Q(s,a;\theta) $ is updated as batches of transitions are sampled. The target network parameters are updated every so often (say 1000 iterations) by taking a copy of the policy network.
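
The periodic copy can be illustrated without a neural network. In the sketch below the dictionaries are stand-ins for network parameters, the increment is a stand-in for a gradient step, and `UPDATE_EVERY` is a hypothetical sync period kept tiny so the effect is visible.

```python
import copy

online_params = {"w": [0.0, 0.0]}             # stand-in for the policy network parameters
target_params = copy.deepcopy(online_params)  # target starts as an exact copy

UPDATE_EVERY = 3  # hypothetical sync period
sync_steps = []

for step in range(1, 8):
    # stand-in for one gradient step on the policy network
    online_params["w"] = [w + 0.1 for w in online_params["w"]]
    if step % UPDATE_EVERY == 0:
        target_params = copy.deepcopy(online_params)  # hard update: full copy
        sync_steps.append(step)

print(sync_steps)  # [3, 6]
```

After step 7 the target parameters still hold the values from step 6, so targets computed from them lag behind the policy network until the next copy.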

<img src="dqn_algo.png" style="width:504px;height:410px;">

Mnih et al (2015) [Human Level Control Through Deep Reinforcement Learning](https://deepmind.com/research/publications/human-level-control-through-deep-reinforcement-learning)

#### Dueling DQN
Decomposes the Q-function into a state value (parameterised by $\alpha$) and an advantage for each action (parameterised by $\beta$). The value of a state is independent of the action. Implementation involves adding an aggregation layer to the network:
$$ Q(s, a; \alpha, \beta) = V(s; \alpha) + A(s, a; \beta) - \frac{1}{|A|} \sum_{a^{\prime}} A(s, a^{\prime}; \beta) $$
where $|A|$ is the number of actions.
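
The aggregation step can be checked by hand on plain numbers. The function `dueling_q` and its inputs below are illustrative assumptions; the sketch only shows the arithmetic of the aggregation layer, not a network.

```python
def dueling_q(value, advantages):
    """Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a')."""
    mean_adv = sum(advantages) / len(advantages)
    return [value + adv - mean_adv for adv in advantages]


print(dueling_q(value=2.0, advantages=[1.0, 3.0]))    # [1.0, 3.0]
print(dueling_q(value=2.0, advantages=[11.0, 13.0]))  # [1.0, 3.0]
```

Shifting every advantage by the same constant leaves the Q-values unchanged, and the mean of the Q-values equals the state value. Subtracting the mean is what makes the split into value and advantage well defined.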

<img src="dueling.png" style="width:250px;height:210px;">

Wang et al (2015) [Dueling Network Architectures for Deep Reinforcement Learning](https://arxiv.org/abs/1511.06581v3)

#### Prioritised Experience Replay (to be implemented)
Changes the sampling distribution from uniform (random) to one that favours experiences deemed more important according to the absolute TD error, which indicates how unexpected a transition was. A small constant $e$ is added so that transitions with zero TD error can still be sampled. The probability of transition $i$ being chosen is:
$$ p_i = \frac{(|TD_i| + e)^a}{\sum_k (|TD_k| + e)^a} $$
where $a$ is a hyperparameter: $a=0$ gives uniform sampling and $a=1$ gives sampling fully proportional to priority.
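
A minimal sketch of these probabilities follows, taking the absolute value of each TD error before adding the constant. The function name, the default `e=0.01` and the TD errors are assumptions for illustration.

```python
def sampling_probabilities(td_errors, a, e=0.01):
    """p_i = (|TD_i| + e)**a / sum_k (|TD_k| + e)**a."""
    priorities = [(abs(td) + e) ** a for td in td_errors]
    total = sum(priorities)
    return [p / total for p in priorities]


td_errors = [0.0, -1.0, 3.0]  # hypothetical TD errors of three stored transitions
print(sampling_probabilities(td_errors, a=1.0))  # larger |TD| gives larger probability
print(sampling_probabilities(td_errors, a=0.0))  # every priority becomes 1, so all equal
```

The constant `e` keeps the first transition, whose TD error is zero, from having zero probability when the exponent is positive.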

Estimating Q-values with stochastic updates requires that the updates are drawn from the same distribution as the expectation. However, prioritised sampling introduces a bias towards high-priority experiences, so we risk overfitting to them. To correct this bias, we use importance sampling (IS) weights, which scale down the updates for high-priority experiences:
$$ w_i = \left(\frac{1}{N}\frac{1}{p_i}\right)^\beta $$
where $N$ is the number of stored transitions and $\beta$ is a hyperparameter. Normally $\beta$ starts at a small value and is gradually increased to 1 (at $\beta=1$ the non-uniform probabilities are fully compensated).
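
The weights can be sketched as below. The division by the largest weight is an extra normalisation described by Schaul et al. so that the weights only ever scale updates down; it is not part of the formula above, and the function name and probabilities are illustrative assumptions.

```python
def importance_weights(probabilities, beta):
    """w_i = (1 / (N * p_i))**beta, divided by the largest weight."""
    n = len(probabilities)
    weights = [(1.0 / (n * p)) ** beta for p in probabilities]
    max_weight = max(weights)
    return [w / max_weight for w in weights]


probabilities = [0.5, 0.25, 0.25]  # hypothetical sampling probabilities
print(importance_weights(probabilities, beta=1.0))  # the frequently sampled item is down-weighted
print(importance_weights(probabilities, beta=0.0))  # no correction: all weights equal 1
```

In a training step each weight would multiply the corresponding transition's loss term before averaging over the batch.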

Schaul et al (2016) [Prioritized Experience Replay](https://arxiv.org/abs/1511.05952)

Use [bisect](https://docs.python.org/3/library/bisect.html) for implementation?
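
One possible way to use `bisect` here, offered as a sketch rather than a settled design: build the running sum of the priorities and look up where a uniform draw lands. The function `sample_index` and its arguments are hypothetical.

```python
import bisect
import itertools
import random


def sample_index(priorities, u):
    """Map a uniform draw u in [0, 1) to index i with probability priorities[i] / sum."""
    cumulative = list(itertools.accumulate(priorities))
    threshold = u * cumulative[-1]
    # first index whose running sum exceeds the threshold
    index = bisect.bisect_right(cumulative, threshold)
    # guard against floating-point rounding pushing the threshold up to the total
    return min(index, len(priorities) - 1)


priorities = [1.0, 2.0, 1.0]  # hypothetical priorities
counts = [0, 0, 0]
rng = random.Random(0)
for _ in range(10000):
    counts[sample_index(priorities, rng.random())] += 1
print(counts)  # the middle index should be drawn roughly twice as often as the others
```

Rebuilding the running sum costs O(N) whenever a priority changes, which happens after every update. Schaul et al. describe a sum-tree that brings both sampling and updating down to O(log N), so the `bisect` approach is probably best suited to small memories.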

In [None]:
# Uncomment for colab
#!pip install wandb

In [None]:
import numpy as np
from tqdm import tqdm
import gym
import random
import time
import matplotlib.pyplot as plt
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import wandb
wandb.init(project='dqn_cartpole')

# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('GPU', torch.cuda.is_available())

env = gym.make('CartPole-v0')
#env = gym.wrappers.Monitor(env, '.dqn_video/', video_callable=lambda episode_id: episode_id%100==0, force=True)

# Get size of observation space
obs_size = env.observation_space.shape[0]
print(f'Observation space: {obs_size}')
# Cart Position, Cart Velocity, Pole Angle, Pole Angular Velocity 

# Get number of actions from gym action space
n_actions = env.action_space.n
print(f'Action space: {n_actions}')
# Left, Right

In [None]:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward', 'terminal'))

class ReplayMemory():
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def store(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):  # needed: optimize_model() calls len(memory) before sampling
        return len(self.memory)

    
class Qnetwork(nn.Module):
    def __init__(self):
        super(Qnetwork, self).__init__()
        self.fc1 = nn.Linear(obs_size, 64) 
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, n_actions)
            
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
class Dueling_network(nn.Module):
    def __init__(self):
        super(Dueling_network, self).__init__()

        self.fc1 = nn.Linear(obs_size, 64)
        self.fc_value = nn.Linear(64, 128)
        self.fc_adv = nn.Linear(64, 128)

        self.value = nn.Linear(128, 1)
        self.adv = nn.Linear(128, n_actions)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        value = F.relu(self.fc_value(x))
        adv = F.relu(self.fc_adv(x))

        value = self.value(value)
        adv = self.adv(adv)

        # Average over the action dimension (the last one), so batched input also works
        advAverage = torch.mean(adv, dim=-1, keepdim=True)
        Q = value + adv - advAverage

        return Q

    
def forward_prop(network, state):
    return network(torch.from_numpy(state).float().to(device))

def epsilon_greedy_action(state):
    
    epsilon = EPS_END + (EPS_START - EPS_END) * np.exp(-1. * episode / EPS_DECAY)
    wandb.log({"epsilon": epsilon}, step=episode)
    
    if np.random.rand() < epsilon:
        return torch.tensor(random.randrange(n_actions)).to(device)
    else:
        with torch.no_grad():
            return torch.argmax(forward_prop(Q_net, state)).to(device)

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return

    minibatch = memory.sample(BATCH_SIZE)
    targets = []
    current_qs = []
    for transition in minibatch:

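        if transition.terminal:
            # Terminal transition: no future value, the target is just the reward
            target = torch.tensor(transition.reward, dtype=torch.float32, device=device)
        else:
            # No gradient should flow through the fixed target network
            with torch.no_grad():
                target = transition.reward + GAMMA * forward_prop(target_net, transition.next_state).max()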

        current_q = forward_prop(Q_net, transition.state)[transition.action]

        targets.append(target)
        current_qs.append(current_q)

    loss = loss_fn(torch.stack(current_qs).to(device), torch.stack(targets).to(device))
    wandb.log({"loss": loss.item()}, step=episode)

    optimizer.zero_grad()
    loss.backward()
    for param in Q_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

In [None]:
MEMORY_SIZE = 5000
BATCH_SIZE = 32
GAMMA = 0.99
LEARNING_RATE = 1e-4
UPDATE_TARGET_FREQ = 1000 # steps between target network updates (C parameter)
# epsilon 
EPS_START = 1
EPS_END = 0.0
EPS_DECAY = 250
NETWORK = Qnetwork # Dueling_network

num_episodes = 1000

# Save model inputs and hyperparameters
wandb.config.learning_rate = LEARNING_RATE
wandb.config.batch_size = BATCH_SIZE
wandb.config.update_target = UPDATE_TARGET_FREQ
wandb.config.epsilon_start = EPS_START
wandb.config.epsilon_end = EPS_END
wandb.config.epsilon_decay = EPS_DECAY
wandb.config.memory_size = MEMORY_SIZE


# initialise parameterised action-value functions
Q_net = NETWORK().to(device) # Q net gets updated
target_net = NETWORK().to(device) # target net is set equal to Q net every UPDATE_TARGET_FREQ steps
target_net.load_state_dict(Q_net.state_dict())
wandb.config.network = Q_net.__class__.__name__

nodes = []
params = list(Q_net.parameters())
for i in range(len(params))[1::2]:
    nodes.append(params[i].size()[0])
wandb.config.nn_nodes = nodes

optimizer = optim.Adam(Q_net.parameters(), lr=LEARNING_RATE)
memory = ReplayMemory(MEMORY_SIZE) # comment out for offline
loss_fn = torch.nn.MSELoss() #SmoothL1Loss()# Huber loss 

In [None]:
episode_rewards = []
steps_done = 0

for episode in tqdm(range(num_episodes)):
    
    # reset the episode reward
    episode_reward = 0
    
    # get start state from env
    state = env.reset() 
    
    terminal = False
    while not terminal:
        
        # choose next action
        action = epsilon_greedy_action(state)
        
        # take next step and get reward from env
        next_state, reward, terminal, _ = env.step(action.item())
        
        # store in memory
        memory.store(state, action, next_state, reward, terminal) # comment out for offline
        
        # updates
        state = next_state
        steps_done += 1
        episode_reward += reward
       
        # Perform one optimisation step on the Q network (the target network is only copied into)
        optimize_model()

        # Update the target network, copying all weights and biases from the Q network.
        # Done before the terminal check so an update due on the last step of an episode is not skipped.
        if steps_done % UPDATE_TARGET_FREQ == 0:
            target_net.load_state_dict(Q_net.state_dict())

        if terminal:
            episode_rewards.append(episode_reward)
            wandb.log({"reward": episode_reward})
            break
            

In [None]:
wandb.config.ave_reward = np.mean(episode_rewards[-100:])