# Actor Critic

The policy update based on the policy gradient theorem is:

$$\theta_{t+1} = \theta_t + \alpha \, \gamma^t \, Q(s,a) \, \nabla_\theta \ln \pi_\theta(a|s)$$

In REINFORCE, a Monte Carlo sample of the episode return is used to estimate $Q(s,a)$, while a parameterised state-value function $V(s)$ is learnt as a baseline to reduce the variance of the gradient estimate (a baseline that does not depend on the action introduces no bias).

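In actor-critic methods, the Monte Carlo return is replaced by a learnt critic. The critic could learn $Q(s,a)$ directly; subtracting $V(s)$ from it gives the advantage $A(s,a)$ (how much more return action $a$ receives than the average action under the current policy). In practice the one-step TD error, which needs only a learnt $V(s)$, can be used as an estimate of the advantage (its expectation equals the advantage when $V$ is the true value function):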

$$A(s,a) = Q(s,a) - V(s) \approx R + \gamma \, V(s^\prime) - V(s)$$

The policy update takes a step proportional to the one-step TD error, computed with a parameterised value function (learnt by regression on the mean-squared TD error):

$$\theta_{t+1} = \theta_t + \alpha \, \gamma^t \, \delta \, \nabla_\theta \ln \pi_\theta(a|s)$$

where $\delta = R + \gamma \, V(s^\prime) - V(s)$.

The value function (the critic) assigns credit to the action selections of the policy (the actor); the critic critiques the actor, hence "actor-critic".

Bootstrapping (updating the value estimate for a state from the estimated values of subsequent states) introduces bias, but it reduces variance and typically speeds up learning.



In [None]:
import numpy as np
from tqdm import tqdm
import gym
import random
import time
import matplotlib.pyplot as plt
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import wandb
# Open a Weights & Biases run; metrics below are logged to this project.
wandb.init(project='a2C_cartpole')

# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
gpu_available = torch.cuda.is_available()
device = torch.device("cuda" if gpu_available else "cpu")
print('GPU', gpu_available)

env = gym.make('CartPole-v0')

# Length of the observation vector:
# Cart Position, Cart Velocity, Pole Angle, Pole Angular Velocity
obs_size = env.observation_space.shape[0]
print(f'Observation space: {obs_size}')

# Number of discrete actions: Left, Right
n_actions = env.action_space.n
print(f'Action space: {n_actions}')

In [None]:
class _MLP(nn.Module):
    """Fully connected network: obs_size -> hidden[0] -> hidden[1] -> out_features.

    Shared body of the policy and value networks, which only differ in the
    width of their output layer. The module-level ``obs_size`` is read when an
    instance is constructed. Layers are kept as ``fc1``/``fc2``/``fc3`` so
    parameter order and state-dict keys match the original separate classes.
    """

    def __init__(self, out_features, hidden_sizes=(64, 128)):
        super().__init__()
        first_hidden, second_hidden = hidden_sizes
        self.fc1 = nn.Linear(obs_size, first_hidden)
        self.fc2 = nn.Linear(first_hidden, second_hidden)
        self.fc3 = nn.Linear(second_hidden, out_features)

    def forward(self, x):
        """Return the raw (no final activation) output of the last layer."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


class MLP_policy(_MLP):
    """Actor: maps an observation to one unnormalised score per action (``n_actions``)."""

    def __init__(self, hidden_sizes=(64, 128)):
        super().__init__(n_actions, hidden_sizes)


class MLP_Vfunction(_MLP):
    """Critic: maps an observation to a single scalar state-value estimate V(s)."""

    def __init__(self, hidden_sizes=(64, 128)):
        super().__init__(1, hidden_sizes)

def softmax_action(state):
    """Sample an action index from the current policy for a single observation.

    Args:
        state: numpy observation for one environment step.

    Returns:
        int: the sampled action index.
    """
    # Sampling needs no gradients (the policy loss recomputes log-probs in
    # get_lob_prob), so avoid building a throwaway autograd graph.
    with torch.no_grad():
        logits = policy_net(torch.from_numpy(state).float().to(device))
        # Categorical applies the softmax itself when given logits.
        dist = torch.distributions.Categorical(logits=logits)
        return dist.sample().item()

def get_value(state):
    """Evaluate the critic ``v_net`` on one numpy observation.

    The result keeps its autograd graph, so it can be used directly in the
    critic's loss.
    """
    obs_tensor = torch.from_numpy(state).float().to(device)
    value = v_net(obs_tensor)
    return value

def optimise_v_net(target, current_v):
    """Take one gradient step moving the critic's V(s) towards a TD target.

    Args:
        target: TD target tensor (the training loop builds R + gamma * V(s')).
            It is detached here, so this is a semi-gradient TD update: the
            bootstrapped V(s') is treated as a constant and only V(s) moves.
        current_v: V(s) as returned by ``get_value``, still attached to the graph.

    Logs the scalar loss to wandb at the module-level ``episode`` step.
    """
    # MSELoss takes (input, target); without the detach, gradients would also
    # flow through V(s') (a residual-gradient update instead of TD).
    loss_v_net = loss_fn(current_v, target.detach())
    # Log a plain float rather than a graph-attached tensor.
    wandb.log({"value_loss": loss_v_net.item()}, step=episode)
    optimizer_v_net.zero_grad()
    loss_v_net.backward()
    optimizer_v_net.step()

def get_lob_prob(state, action):
    """Return log pi(action | state) under the current policy, with gradients.

    The name is a typo for "log prob"; it is kept because callers use it.
    """
    obs_tensor = torch.from_numpy(state).float().to(device)
    log_probs = F.log_softmax(policy_net(obs_tensor), dim=0)
    return log_probs[action]

def optimise_policy(error, I, state, action):
    """Take one policy-gradient step for a single (state, action) pair.

    Minimising ``-I * error * log pi(action|state)`` performs gradient ascent
    along ``I * delta * grad log pi``, the one-step actor-critic update.

    Args:
        error: TD error delta (the training loop passes a Python float).
        I: accumulated discount gamma**t for the current time step.
        state: numpy observation the action was taken in.
        action: index of the action that was taken.
    """
    # Coerce to a plain number so no gradient can leak into the critic
    # through the actor's loss, even if a tensor is passed in.
    error = float(error)
    log_prob = get_lob_prob(state, action)
    loss_policy = -I * error * log_prob
    optimizer_policy.zero_grad()
    loss_policy.backward()
    optimizer_policy.step()

In [None]:
# --- Hyper-parameters -------------------------------------------------------
GAMMA = 0.99
LEARNING_RATE_POLICY = 1e-4
LEARNING_RATE_VALUE = 5e-3  # want critic to update faster

num_episodes = 1000

# Save model inputs and hyperparameters
wandb.config.learning_rate_policy = LEARNING_RATE_POLICY
wandb.config.learning_rate_value = LEARNING_RATE_VALUE

# --- Networks and optimisers ------------------------------------------------
# initialise parameterised policy function (actor)
policy_net = MLP_policy().to(device)
optimizer_policy = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE_POLICY)

# initialise parameterised state-value function (critic)
v_net = MLP_Vfunction().to(device)
optimizer_v_net = optim.Adam(v_net.parameters(), lr=LEARNING_RATE_VALUE)
loss_fn = torch.nn.MSELoss()

# Record layer widths: every second parameter is a bias vector whose length
# is that layer's output width.
params = list(policy_net.parameters())
nodes = [bias.size()[0] for bias in params[1::2]]
wandb.config.nn_nodes = nodes

# --- One-step actor-critic training loop -------------------------------------
episode_rewards = []
for episode in tqdm(range(num_episodes)):

    # tracking
    I = 1  # gamma**t: discount applied to the policy step at time t
    episode_reward = 0

    # get start state from env
    state = env.reset()

    terminal = False
    while not terminal:

        # choose next action
        action = softmax_action(state)

        # take next step and get reward from env
        next_state, reward, terminal, info = env.step(action)

        # A time-limit cut-off is not a true terminal state, so keep
        # bootstrapping from V(s') when the episode was only truncated.
        # (Older gym versions without this key fall back to the old behaviour.)
        truncated = info.get('TimeLimit.truncated', False)
        bootstrap = (not terminal) or truncated

        # TD target (semi-gradient: no gradient through V(s')) and TD error
        with torch.no_grad():
            target = reward + bootstrap * GAMMA * get_value(next_state)
        current_v = get_value(state)
        error = (target - current_v).item()

        # update value net (critic)
        optimise_v_net(target, current_v)

        # update policy net (actor)
        optimise_policy(error, I, state, action)

        # updates
        I *= GAMMA
        state = next_state
        episode_reward += reward

    wandb.log({"reward": episode_reward}, step=episode)
    episode_rewards.append(episode_reward)

wandb.config.ave_rewards = np.mean(episode_rewards[-200:])