# REINFORCE: Monte Carlo Policy Gradient

REINFORCE learns a parameterised policy $\pi(a|s,\theta)$ directly. Using gradient ascent, we move $\theta$ in the direction of the gradient of the performance measure $J(\theta)$ with respect to $\theta$, in order to find the policy that produces the highest return.

$$\theta_{t+1}=\theta_t+\alpha \, \nabla_\theta J(\theta_t)$$

The Policy Gradient Theorem gives an expression for how performance is affected by the policy parameters that does not involve derivatives of the state distribution:

$$\nabla_\theta J(\theta)= \mathop{\mathbb{E}} [Q^\pi(s,a) \, \nabla_\theta \ln \pi_\theta(a|s)]$$

REINFORCE relies on an estimated return by Monte-Carlo sampling of full trajectories. The policy update is:

$$\theta_{t+1} = \theta_t + \alpha \, \gamma^t \, G_t \, \nabla_\theta \ln \pi_\theta(a_t|s_t)$$

A widely used variation of REINFORCE is to subtract a baseline value from the return $G_t$ to reduce the variance of the gradient estimate while keeping the bias unchanged. A common baseline is the state-value function $V(s)$.

Alternatively, the return in the gradient expression can be replaced with $A(s,a)$, the advantage function (how much more return action $a$ receives compared to the average action under the policy). This has not been implemented here.

$$\nabla_\theta J(\theta)= \mathop{\mathbb{E}} [A^\pi(s,a) \, \nabla_\theta \ln \pi_\theta(a|s)]$$

where $A(s,a)=Q(s,a)-V(s)$


In [None]:
import numpy as np
from tqdm import tqdm
import gym
import random
import time
import matplotlib.pyplot as plt
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import wandb
# Start a Weights & Biases run so metrics and config are tracked
wandb.init(project='reinforce_cartpole')

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('GPU', torch.cuda.is_available())

env = gym.make('CartPole-v0')

# Length of the observation vector:
# Cart Position, Cart Velocity, Pole Angle, Pole Angular Velocity
observation_shape = env.observation_space.shape
obs_size = observation_shape[0]
print(f'Observation space: {obs_size}')

# Number of discrete actions offered by the gym action space: Left, Right
n_actions = env.action_space.n
print(f'Action space: {n_actions}')

In [None]:
class MLP_policy(nn.Module):
    """Three-layer MLP that maps an observation to one raw score per action.

    The output is NOT normalised; callers apply softmax / log-softmax
    themselves.

    Args:
        in_features: Length of the observation vector. When omitted, the
            module-level ``obs_size`` is used (the original behaviour).
        out_features: Number of discrete actions. When omitted, the
            module-level ``n_actions`` is used (the original behaviour).
        hidden_sizes: Widths of the two hidden layers.
    """

    def __init__(self, in_features=None, out_features=None,
                 hidden_sizes=(64, 128)):
        super().__init__()
        # Fall back to the notebook globals so ``MLP_policy()`` keeps working.
        if in_features is None:
            in_features = obs_size
        if out_features is None:
            out_features = n_actions
        first_width, second_width = hidden_sizes
        # Attribute names and creation order are unchanged, so state_dict
        # keys and seeded initialisation match the original network.
        self.fc1 = nn.Linear(in_features, first_width)
        self.fc2 = nn.Linear(first_width, second_width)
        self.fc3 = nn.Linear(second_width, out_features)

    def forward(self, x):
        """Return the raw action scores for observation tensor ``x``."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    
class MLP_Vfunction(nn.Module):
    """Three-layer MLP that maps an observation to a single scalar output.

    Used as the state-value baseline network.

    Args:
        in_features: Length of the observation vector. When omitted, the
            module-level ``obs_size`` is used (the original behaviour).
        hidden_sizes: Widths of the two hidden layers.
    """

    def __init__(self, in_features=None, hidden_sizes=(64, 128)):
        super().__init__()
        # Fall back to the notebook global so ``MLP_Vfunction()`` keeps working.
        if in_features is None:
            in_features = obs_size
        first_width, second_width = hidden_sizes
        # Attribute names and creation order are unchanged, so state_dict
        # keys and seeded initialisation match the original network.
        self.fc1 = nn.Linear(in_features, first_width)
        self.fc2 = nn.Linear(first_width, second_width)
        self.fc3 = nn.Linear(second_width, 1)

    def forward(self, x):
        """Return the value estimate for ``x``; the last dimension has size 1."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

def softmax_action(state):
    """Sample an action from the softmax distribution of the current policy.

    Reads the module-level ``policy_net`` and ``device``.

    Args:
        state: Observation as a NumPy array (converted with
            ``torch.from_numpy``).

    Returns:
        The sampled action index as a Python ``int``.
    """
    observation = torch.from_numpy(state).float().to(device)
    # Acting only needs the numbers, not a backward pass: the log-probs used
    # for the update are recomputed later, so skip building an autograd graph.
    with torch.no_grad():
        logits = policy_net(observation)
    # Categorical normalises the logits itself (softmax over the last dim),
    # which avoids an explicit softmax followed by a second normalisation.
    dist = torch.distributions.Categorical(logits=logits)
    return dist.sample().item()

def get_return(rewards):
    """Turn the per-step rewards of one episode into the return signal.

    Reads the module-level ``reward_to_go`` flag and ``device``.

    Args:
        rewards: Sequence of per-step rewards for one episode.

    Returns:
        If ``reward_to_go`` is falsy, a 0-dim tensor holding the sum of all
        rewards. Otherwise a 1-D tensor whose entry ``i`` is
        ``sum(rewards[i:])``.
    """
    rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
    if not reward_to_go:
        return rewards.sum()
    # Suffix sums in O(n): reverse, take the running sum, reverse back.
    # (The original re-summed every slice, which is O(n^2) per episode.)
    return torch.flip(torch.cumsum(torch.flip(rewards, dims=[0]), dim=0), dims=[0])

def get_value(states):
    """Evaluate the value network on every state of an episode.

    Reads the module-level ``v_net`` and ``device``.

    Args:
        states: Sequence of equally-shaped NumPy observation arrays.

    Returns:
        1-D tensor with one value estimate per state, still attached to the
        autograd graph of ``v_net``.
    """
    # One batched forward pass instead of one network call per state.
    batch = torch.from_numpy(np.asarray(states, dtype=np.float32)).to(device)
    # v_net emits a trailing dimension of size 1; drop it to get shape (T,).
    return v_net(batch).squeeze(-1)

def optimise_v_net(current_v, returns):
    """Take one optimiser step on the value network.

    Reads the module-level ``loss_fn``, ``optimizer_v_net``, ``v_net`` and
    ``episode`` (used as the logging step).

    Args:
        current_v: Value predictions, attached to the graph of ``v_net``.
        returns: Regression targets passed to ``loss_fn`` with ``current_v``.
    """
    loss_v_net = loss_fn(current_v, returns)
    # Log a plain float so the logger does not hold on to the autograd graph.
    wandb.log({"value_loss": loss_v_net.item()}, step=episode)
    optimizer_v_net.zero_grad()
    loss_v_net.backward()
    # Clamp every gradient element to [-1, 1]; parameters without a gradient
    # are skipped instead of raising on ``None``.
    torch.nn.utils.clip_grad_value_(v_net.parameters(), 1.0)
    optimizer_v_net.step()

def get_gradient(states, actions, returns, current_v):
    """Build the REINFORCE surrogate objective for one episode.

    Despite the name, the result is not a gradient: it is the scalar
    ``sum(log pi(a_i | s_i) * (returns_i - baseline_i))`` whose gradient with
    respect to the policy parameters is the policy-gradient estimate.
    Reads the module-level ``policy_net`` and ``device``.

    Args:
        states: Sequence of equally-shaped NumPy observation arrays.
        actions: Sequence of integer action indices, aligned with ``states``.
        returns: Tensor of returns (per step, or a single broadcast scalar).
        current_v: Baseline values as a tensor, or the number ``0`` when no
            baseline is used.

    Returns:
        0-dim tensor attached to the autograd graph of ``policy_net``.
    """
    # Single batched forward pass over the whole episode.
    observations = torch.from_numpy(np.asarray(states, dtype=np.float32)).to(device)
    chosen = torch.as_tensor(actions, dtype=torch.long, device=device)
    all_log_probs = F.log_softmax(policy_net(observations), dim=-1)
    # Pick out the log-probability of the action actually taken at each step.
    log_probs = all_log_probs.gather(1, chosen.unsqueeze(1)).squeeze(1)

    # The baseline is a constant w.r.t. the policy: detach it so the policy
    # loss does not back-propagate into the value network.
    if torch.is_tensor(current_v):
        current_v = current_v.detach()
    phi = returns - current_v
    return (log_probs * phi).sum()

def optimise_policy(loss_policy):
    """Take one optimiser step on the policy network.

    Reads the module-level ``optimizer_policy``.

    Args:
        loss_policy: Scalar tensor to minimise (the caller passes the negated
            surrogate objective so that the step performs gradient ascent).
    """
    optimizer_policy.zero_grad()
    # retain_graph=True keeps the graph alive after this backward pass, so
    # any part of it shared with the value baseline can still be
    # back-propagated through by a later value-network update.
    loss_policy.backward(retain_graph=True)
    optimizer_policy.step()

In [None]:
# --- Hyperparameters ---
GAMMA = 0.99                 # discount factor
LEARNING_RATE_POLICY = 5e-4
LEARNING_RATE_VALUE = 1e-6   # only used when baseline is enabled
reward_to_go = True          # per-step reward-to-go instead of the episode total
baseline = False             # subtract a learned state-value baseline

num_episodes = 2000

# A per-step baseline has nothing to pair with when a single episode-level
# return is used, so reject that combination up front.
if baseline and not reward_to_go:
    raise ValueError("if using episode return then cannot use baseline")

# Save model inputs and hyperparameters
wandb.config.learning_rate_policy = LEARNING_RATE_POLICY
if baseline:
    wandb.config.learning_rate_value = LEARNING_RATE_VALUE
wandb.config.reward_to_go = reward_to_go
wandb.config.baseline = baseline

# initialise parameterised policy function
policy_net = MLP_policy().to(device)
optimizer_policy = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE_POLICY)

# The value network, its optimiser and its loss only exist with a baseline
if baseline:
    v_net = MLP_Vfunction().to(device)
    optimizer_v_net = optim.Adam(v_net.parameters(), lr=LEARNING_RATE_VALUE)
    loss_fn = torch.nn.MSELoss()

# Record the layer widths: parameters alternate weight, bias, weight, bias...
# so every second entry is a bias whose length is that layer's output width.
params = list(policy_net.parameters())
nodes = [bias.size()[0] for bias in params[1::2]]
wandb.config.nn_nodes = nodes


# --- Training: one policy update per complete episode ---
episode_rewards = []
for episode in tqdm(range(num_episodes)):

    # per-episode tracking
    step = 0
    episode_reward = 0
    states = []
    actions = []
    rewards = []

    # get start state from env
    state = env.reset()

    terminal = False
    # Truthiness, not ``is False``: an identity test would end the rollout
    # after one step if the env reports ``done`` as a NumPy bool.
    while not terminal:

        # choose next action
        action = softmax_action(state)

        # take next step and get reward from env
        next_state, reward, terminal, _ = env.step(action)

        # Store the reward already discounted by GAMMA**step, so summing the
        # rewards from step t onwards yields GAMMA**t * G_t.
        states.append(state)
        actions.append(action)
        rewards.append(reward*np.power(GAMMA, step))

        # updates
        step += 1
        state = next_state
        episode_reward += reward  # undiscounted total, used for logging

    returns = get_return(rewards)

    # Without a baseline the "value" subtracted from the returns is just 0
    current_v = get_value(states) if baseline else 0

    # Negate the objective: the optimiser minimises, we want gradient ascent
    gradient = get_gradient(states, actions, returns, current_v)
    optimise_policy(-gradient)

    if baseline:
        optimise_v_net(current_v, returns)

    wandb.log({"reward": episode_reward}, step=episode)
    episode_rewards.append(episode_reward)

# Summary metric: mean undiscounted reward over the last 500 episodes
wandb.config.ave_rewards = np.mean(episode_rewards[-500:])