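# Proximal Policy Optimization

This notebook trains a PPO-style agent on `CartPole-v0` and logs rewards and losses to Weights & Biases.

- **Networks:** a softmax policy network and a separate state-value network.
- **Update schedule:** both networks are updated every `MINIBATCH` environment steps, or when an episode ends.
- **Losses:** the policy uses a clipped surrogate objective with one-step TD errors as advantage estimates. The value network uses an MSE regression loss.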

In [1]:
import numpy as np
from tqdm import tqdm
import gym
import random
import time
import matplotlib.pyplot as plt
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import wandb
wandb.init(project='ppo_cartpole')

# Decide once whether CUDA is usable; every tensor/network below is moved to `device`.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
print('GPU', _cuda_ok)

# CartPole: observations are (cart position, cart velocity, pole angle,
# pole angular velocity); the two discrete actions push the cart left / right.
env = gym.make('CartPole-v0')
observation_space, action_space = env.observation_space, env.action_space

obs_size = observation_space.shape[0]
print(f'Observation space: {obs_size}')

n_actions = action_space.n
print(f'Action space: {n_actions}')
# Left, Right

[34m[1mwandb[0m: Currently logged in as: [33msradicwebster[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.11 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


GPU False
Observation space: 4
Action space: 2


In [2]:
class MLP_policy(nn.Module):
    """Policy network mapping an observation to one unnormalised logit per action.

    Three fully connected layers with ReLU between them. Callers in this
    notebook apply a softmax to the output.

    Args:
        in_features: observation size. Defaults to the module-level
            ``obs_size`` (read at construction time).
        out_features: number of actions. Defaults to the module-level
            ``n_actions`` (read at construction time).
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
    """

    def __init__(self, in_features=None, out_features=None, hidden1=64, hidden2=128):
        super(MLP_policy, self).__init__()
        # Fall back to the environment-derived globals so `MLP_policy()` still works.
        if in_features is None:
            in_features = obs_size
        if out_features is None:
            out_features = n_actions
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, out_features)

    def forward(self, x):
        """Return raw logits for input ``x`` (the last dim must be ``in_features``)."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    
class MLP_Vfunction(nn.Module):
    """State-value network mapping an observation to a single scalar estimate.

    Three fully connected layers with ReLU between them; the final layer has
    exactly one output unit.

    Args:
        in_features: observation size. Defaults to the module-level
            ``obs_size`` (read at construction time).
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
    """

    def __init__(self, in_features=None, hidden1=64, hidden2=128):
        super(MLP_Vfunction, self).__init__()
        # Fall back to the environment-derived global so `MLP_Vfunction()` still works.
        if in_features is None:
            in_features = obs_size
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, 1)

    def forward(self, x):
        """Return the value estimate for ``x`` with a trailing dimension of size 1."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

def actor(state):
    """Sample an action from the current policy for a single observation.

    Args:
        state: one environment observation (numpy array). It is converted to a
            float32 tensor on ``device`` before the forward pass.

    Returns:
        (action, prob): the sampled action index as a Python int, and the
        policy's probability of that action as a 0-d tensor. ``prob`` carries
        no autograd graph. The rollout loop only reads it via ``.item()``.
    """
    # Rollout sampling is never backpropagated through, so skip building a graph.
    with torch.no_grad():
        logits = policy_net(torch.from_numpy(state).float().to(device))
        probs = F.softmax(logits, dim=0)
        action = torch.distributions.Categorical(probs).sample().item()
    return action, probs[action]

def rewards_to_go(rewards, gamma=1.0):
    """Return the (optionally discounted) reward-to-go for every step.

    ``returns[i] = sum_{t >= i} gamma**(t - i) * rewards[t]``. With the default
    ``gamma=1.0`` this is the plain suffix sum of ``rewards``, as before.

    Args:
        rewards: sequence of scalar rewards (Python/numpy numbers or 0-d tensors).
        gamma: discount factor, applied relative to each step ``i``.

    Returns:
        1-D float32 tensor on ``device`` with one entry per reward. The tensor
        is empty for empty input.
    """
    returns = [0.0] * len(rewards)
    running = 0.0
    # One backward pass, G_i = r_i + gamma * G_{i+1}: O(n) instead of n suffix sums.
    for i in reversed(range(len(rewards))):
        running = float(rewards[i]) + gamma * running
        returns[i] = running
    return torch.tensor(returns, dtype=torch.float32, device=device)

def get_value(states):
    """Evaluate the value network on a batch of observations.

    Args:
        states: non-empty sequence of equally-shaped observation arrays.

    Returns:
        1-D tensor on ``device`` holding ``v_net``'s first output for each
        state, in order. It stays attached to the autograd graph so it can be
        used as the prediction in the value loss.
    """
    # A single batched forward pass instead of one network call per state.
    batch = torch.as_tensor(np.asarray(states, dtype=np.float32), device=device)
    return v_net(batch)[:, 0]

def optimise_v_net(returns, values):
    """Take one gradient step on the value network.

    Args:
        returns: regression targets (tensor). They are detached, so no
            gradient flows into whatever produced them.
        values: ``v_net`` predictions for the same states, attached to the graph.

    Side effects: logs ``value_loss`` to wandb at the current global
    ``episode`` step and updates ``v_net`` through ``optimizer_v_net``.
    """
    # nn.MSELoss expects (input, target). MSE is symmetric, so the loss value
    # is unchanged, but the targets are now explicitly constant.
    loss_v_net = loss_fn(values, returns.detach())
    # Log a plain float rather than a graph-attached tensor.
    wandb.log({"value_loss": loss_v_net.item()}, step=episode)
    optimizer_v_net.zero_grad()
    loss_v_net.backward()
    optimizer_v_net.step()

def get_lob_prob(state, action):
    """Return the log-probability the current policy assigns to ``action`` in ``state``.

    NOTE(review): the name is a misspelling of "get_log_prob" and is kept for
    compatibility. It does not appear to be called in this part of the notebook.
    """
    obs = torch.from_numpy(state).float().to(device)
    log_probs = F.log_softmax(policy_net(obs), dim=0)
    return log_probs[action]

def optimise_policy(states, actions, old_action_probs, adv, n_epochs=1):
    """Update the policy network with the PPO clipped surrogate objective.

    Args:
        states: observations. Only the first ``len(actions)`` are used, because
            the training loop appends the successor of the last transition.
        actions: action indices taken in those states.
        old_action_probs: probabilities (floats) of those actions under the
            policy that collected them.
        adv: advantage estimate per transition (tensor or sequence). It is
            treated as a constant and moved to ``device``.
        n_epochs: number of gradient steps to take on this batch. The default
            of 1 reproduces the original single update.

    NOTE(review): in the training loop ``old_action_probs`` come from the
    still-unchanged network. With ``n_epochs=1`` the ratio is therefore about 1
    on the only step, so the clip never binds and this reduces to a plain
    policy-gradient step. PPO normally takes several steps per batch.
    """
    n = len(actions)
    obs = torch.as_tensor(np.asarray(states[:n], dtype=np.float32), device=device)
    action_idx = torch.as_tensor(actions, dtype=torch.long, device=device)
    old_probs = torch.as_tensor(old_action_probs, dtype=torch.float32, device=device)
    # The caller may build advantages on the CPU. Co-locate them with the
    # policy outputs to avoid a device mismatch on CUDA.
    adv = torch.as_tensor(adv, dtype=torch.float32, device=device).detach()
    rows = torch.arange(n, device=device)

    for _ in range(n_epochs):
        probs = F.softmax(policy_net(obs), dim=-1)
        ratio = probs[rows, action_idx] / old_probs

        clip_adv = torch.clamp(ratio, 1 - CLIP_RATIO, 1 + CLIP_RATIO) * adv
        loss_policy = -torch.min(ratio * adv, clip_adv).mean()

        optimizer_policy.zero_grad()
        loss_policy.backward()
        optimizer_policy.step()

In [3]:
GAMMA = 0.99
LEARNING_RATE_POLICY = 1e-4
LEARNING_RATE_VALUE = 5e-4 # want critic to update faster
MINIBATCH = 32
CLIP_RATIO = 0.1
SEED = 0

num_episodes = 2000

# Seed every source of randomness used below (network init, action sampling,
# environment resets) so CPU runs are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if hasattr(env, "seed"):  # pre-0.26 gym API; newer gym seeds via reset(seed=...)
    env.seed(SEED)

# Save model inputs and hyperparameters
wandb.config.gamma = GAMMA
wandb.config.learning_rate_policy = LEARNING_RATE_POLICY
wandb.config.learning_rate_value = LEARNING_RATE_VALUE
wandb.config.minibatch = MINIBATCH
wandb.config.clipratio = CLIP_RATIO
wandb.config.seed = SEED

# initialise parameterised policy function
policy_net = MLP_policy().to(device)
optimizer_policy = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE_POLICY)

v_net = MLP_Vfunction().to(device)
optimizer_v_net = optim.Adam(v_net.parameters(), lr=LEARNING_RATE_VALUE)
loss_fn = torch.nn.MSELoss()


def _layer_widths(net):
    """Output width of each layer, read from every second parameter.

    Assumes parameters alternate (weight, bias), as they do for MLPs built from
    biased nn.Linear layers.
    """
    return [p.shape[0] for p in list(net.parameters())[1::2]]


wandb.config.policy_nodes = _layer_widths(policy_net)
wandb.config.value_nodes = _layer_widths(v_net)


episode_rewards = []
for episode in tqdm(range(num_episodes)):

    # per-episode tracking
    step = 0
    episode_reward = 0
    states = []
    actions = []
    rewards = []  # raw per-step rewards; discounting is applied at update time
    old_action_probs = []

    # get start state from env
    state = env.reset()

    terminal = False
    while not terminal:

        # sample the next action from the current policy
        action, prob = actor(state)

        # take next step and get reward from env
        next_state, reward, terminal, info = env.step(action)
        # A time-limit cut (old gym TimeLimit wrapper) ends the episode, but the
        # state is not truly terminal, so it should still be bootstrapped.
        env_terminal = terminal and not info.get("TimeLimit.truncated", False)

        # tracking
        states.append(state)
        actions.append(action)
        old_action_probs.append(prob.item())
        rewards.append(reward)

        # updates
        step += 1
        state = next_state
        episode_reward += reward

        if step % MINIBATCH == 0 or terminal:

            # successor of the last transition, so values[i + 1] exists for every i
            states.append(next_state)

            # value estimates for each batch state plus the successor
            values = get_value(states)
            next_values = values[1:].detach()
            rewards_t = torch.tensor(rewards, dtype=torch.float32, device=values.device)

            # only the final transition of a batch can end the episode
            not_done = torch.ones(len(rewards), device=values.device)
            if env_terminal:
                not_done[-1] = 0.0

            # discounted reward-to-go measured from each step, bootstrapped from
            # V(successor) when the batch was cut before a true terminal
            returns = torch.empty_like(rewards_t)
            running = 0.0 if env_terminal else next_values[-1].item()
            for i in reversed(range(len(rewards))):
                running = rewards[i] + GAMMA * running
                returns[i] = running

            # advantage estimation using the one-step TD error
            adv = rewards_t + GAMMA * not_done * next_values - values[:-1].detach()

            # update policy net
            optimise_policy(states, actions, old_action_probs, adv)

            # update value net
            optimise_v_net(returns, values[:-1])

            # keep the final batch of the last episode around for inspection
            if episode != num_episodes - 1 or not terminal:
                states = []
                actions = []
                rewards = []
                old_action_probs = []

    wandb.log({"reward": int(episode_reward)}, step=episode)
    episode_rewards.append(int(episode_reward))

wandb.config.ave_rewards = np.mean(episode_rewards[-200:])

100%|██████████| 2000/2000 [05:50<00:00,  5.71it/s]
