# Imports

In [1]:
from dataclasses import dataclass
from collections import deque
from itertools import count

In [2]:
import gym
from gym.wrappers.monitoring.video_recorder import VideoRecorder

In [3]:
import plotly.graph_objects as go

In [4]:
import numpy as np

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# Parameters

These parameters were taken from https://github.com/yc930401/Actor-Critic-pytorch/blob/master/Actor-Critic.py

In [6]:
# Seed PyTorch's global RNG so network initialisation and action sampling
# (Categorical.sample in the training loop) start from a fixed state.
torch.manual_seed(0)

<torch._C.Generator at 0x7f881078a950>

In [7]:
# Widths of the hidden layers used by both the Actor and the Critic:
# the input layer maps to HIDDEN[0], and each consecutive pair becomes a Linear layer.
HIDDEN = [128, 256]

In [8]:
# Environment and the dimensions the networks are sized from.
env = gym.make("CartPole-v0")
# Read by the Actor: when True its output is passed through a softmax.
CATEGORICAL_ACTION = True
# Number of discrete actions; sizes the Actor's output layer and the Critic's action input.
ACTION_DIM = env.action_space.n
# Length of the observation vector; sizes the state input layer of both networks.
STATE_DIM = env.observation_space.shape[0]

In [9]:
# Learning rate passed to both Adam optimizers (actor and critic).
LR = 0.001
# Number of iterations of the outer training loop; each iteration runs one episode until `done`.
N_ITERS = 1000
# Discount factor applied when accumulating future rewards.
GAMMA = 0.99

# The Neural Network

Actor-Critic methods use two networks (implemented below as two separate modules):

1. A value prediction based on the state and, in this notebook, the actor's action distribution. (Value Network, Critic)
2. An action prediction based on the state. (Policy Network, Actor)

In [15]:
class Actor(nn.Module):
    """Policy network: maps a state to a distribution over actions.

    The network is a feed-forward stack: an input layer, one hidden Linear
    layer per consecutive pair in ``hidden``, and a linear output layer with
    one unit per action. ReLU is applied after every layer except the output.

    Args:
        state_dim: Size of the state vector. Defaults to the notebook's ``STATE_DIM``.
        action_dim: Number of outputs. Defaults to the notebook's ``ACTION_DIM``.
        hidden: Sequence of hidden-layer widths. Defaults to the notebook's ``HIDDEN``.
        categorical: When True the output is passed through a softmax so it can
            parameterise a ``Categorical`` distribution. Defaults to the
            notebook's ``CATEGORICAL_ACTION``.
    """

    def __init__(self, state_dim=None, action_dim=None, hidden=None, categorical=None):
        super().__init__()

        # Fall back to the notebook-level configuration so `Actor()` keeps working.
        state_dim = STATE_DIM if state_dim is None else state_dim
        action_dim = ACTION_DIM if action_dim is None else action_dim
        hidden = HIDDEN if hidden is None else hidden
        self.categorical = CATEGORICAL_ACTION if categorical is None else categorical

        # Feed-forward network of depth len(hidden).
        self.input_layer = nn.Linear(state_dim, hidden[0])
        # nn.ModuleList (not a plain list) so the hidden layers are registered:
        # otherwise `parameters()`, `.to(device)` and `state_dict()` skip them
        # and the optimizer never trains them.
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(hidden, hidden[1:])]
        )

        # Output layer: one logit per action.
        self.policy_output_layer = nn.Linear(hidden[-1], action_dim)

    def forward(self, state):
        """Return the policy output for ``state``.

        Args:
            state: Tensor whose last dimension is the state size; may be a
                single state vector or a batch of them.

        Returns:
            Tensor with the same leading dimensions as ``state`` and a last
            dimension of size ``action_dim``. If the actor is categorical the
            values along the last dimension are probabilities that sum to 1;
            otherwise they are the raw linear outputs.
        """
        temp = F.relu(self.input_layer(state))
        for layer in self.layers:
            temp = F.relu(layer(temp))

        policy_dist = self.policy_output_layer(temp)
        if self.categorical:
            # Normalise over the action axis (the last one). For a single 1-D
            # state this equals dim=0, but it is also correct for batches.
            policy_dist = F.softmax(policy_dist, dim=-1)

        return policy_dist

In [11]:
class Critic(nn.Module):
    """Value network: maps a (state, action) pair to a scalar value estimate.

    State and action are each projected to ``hidden[0]`` units by their own
    input layer (followed by ReLU), the two projections are summed, and the
    result is passed through one hidden Linear+ReLU layer per consecutive pair
    in ``hidden`` before a single-unit linear output.

    Args:
        state_dim: Size of the state vector. Defaults to the notebook's ``STATE_DIM``.
        action_dim: Size of the action input. Defaults to the notebook's ``ACTION_DIM``.
        hidden: Sequence of hidden-layer widths. Defaults to the notebook's ``HIDDEN``.
    """

    def __init__(self, state_dim=None, action_dim=None, hidden=None):
        super().__init__()

        # Fall back to the notebook-level configuration so `Critic()` keeps working.
        state_dim = STATE_DIM if state_dim is None else state_dim
        action_dim = ACTION_DIM if action_dim is None else action_dim
        hidden = HIDDEN if hidden is None else hidden

        # Separate input projections for the state and the action.
        self.input_layer_state = nn.Linear(state_dim, hidden[0])
        self.input_layer_action = nn.Linear(action_dim, hidden[0])
        # nn.ModuleList (not a plain list) so the hidden layers are registered:
        # otherwise `parameters()`, `.to(device)` and `state_dict()` skip them
        # and the optimizer never trains them.
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(hidden, hidden[1:])]
        )

        # Output layer: a single linear unit holding the value estimate.
        self.value_output_layer = nn.Linear(hidden[-1], 1)

    def forward(self, state, action):
        """Return the value estimate for ``state`` and ``action``.

        Args:
            state: Tensor whose last dimension is the state size.
            action: Tensor whose last dimension is the action-input size, with
                leading dimensions matching ``state``. The training loop passes
                the actor's detached output here.

        Returns:
            Tensor with the same leading dimensions as the inputs and a last
            dimension of size 1.
        """
        state_input = F.relu(self.input_layer_state(state))
        action_input = F.relu(self.input_layer_action(action))
        # Merge the two projections by element-wise addition.
        temp = state_input + action_input
        for layer in self.layers:
            temp = F.relu(layer(temp))

        # The value output is a plain linear layer (no activation).
        return self.value_output_layer(temp)

# Training Loop

In [12]:
# Use the GPU when CUDA is available, otherwise the CPU. The training loop
# creates its state, reward and mask tensors on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
# Live plot widget with a single (initially empty) scatter trace; the training
# loop overwrites the trace's y-values with the per-episode scores.
fig = go.FigureWidget()
fig.add_scatter()

FigureWidget({
    'data': [{'type': 'scatter', 'uid': 'd564a08c-9017-45bb-ab8e-204e3edad8f4'}], 'layout': {'t…

In [None]:
# Build both networks on the selected device so their parameters live on the
# same device as the state tensors created inside the loop.
actor, critic = Actor().to(device), Critic().to(device)
actor_optimizer = optim.Adam(actor.parameters(), lr=LR)
critic_optimizer = optim.Adam(critic.parameters(), lr=LR)
scores = []
video_recorder = VideoRecorder(env, './output/03_Cartpole_A2C_Q_Critic.mp4', enabled=True)
try:
    for this_iter in range(N_ITERS):
        # Per-step records for the episode. `masks` holds 1.0 while the episode
        # continues and 0.0 on the step that ended it.
        log_probs, values, rewards, masks = [], [], [], []
        state = env.reset()

        for step in count():
            # Record a video frame on every 10th episode only.
            if (this_iter + 1) % 10 == 0:
                video_recorder.capture_frame()

            state = torch.FloatTensor(state).to(device)
            policy_dist = actor(state)
            policy = Categorical(policy_dist)
            action = policy.sample()
            # The critic sees the action distribution, detached so the critic
            # loss does not back-propagate into the actor.
            value = critic(state, policy_dist.detach())
            # Pass a plain Python int to the environment.
            next_state, reward, done, _ = env.step(action.item())

            log_probs.append(policy.log_prob(action).unsqueeze(0))
            values.append(value)
            rewards.append(torch.tensor([reward], dtype=torch.float, device=device))
            masks.append(torch.tensor([0.0 if done else 1.0], dtype=torch.float, device=device))

            state = next_state

            if done:
                # The score is the index of the last step taken in the episode.
                scores.append(step)
                with fig.batch_update():
                    fig.data[0].y = scores
                break

        log_probs = torch.cat(log_probs)
        values = torch.cat(values)
        rewards = torch.cat(rewards)
        masks = torch.cat(masks)

        # REF: https://github.com/yc930401/Actor-Critic-pytorch/blob/master/Actor-Critic.py
        # Predicted value at the final state, used as the bootstrap target.
        # Computed without gradients so the critic is not trained through its
        # own target, and with a policy evaluated on the final state itself.
        with torch.no_grad():
            final_state = torch.FloatTensor(state).to(device)
            final_value = critic(final_state, actor(final_state)).squeeze()

        # Discounted returns, accumulated backwards from the last step. The
        # mask zeroes the bootstrap across a step that ended the episode.
        cum_reward = final_value
        discounted_future_rewards = torch.zeros_like(rewards)
        for t in reversed(range(len(rewards))):
            cum_reward = rewards[t] + GAMMA * cum_reward * masks[t]
            discounted_future_rewards[t] = cum_reward

        # Advantage: how much better the observed return was than predicted.
        advantage = discounted_future_rewards - values

        # Actor: policy gradient weighted by the (detached) advantage.
        # Critic: mean squared error between return and predicted value.
        actor_loss = -(log_probs * advantage.detach()).mean()
        critic_loss = advantage.pow(2).mean()

        actor_optimizer.zero_grad()
        actor_loss.backward()
        actor_optimizer.step()
        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()
finally:
    # Always finalise the video file and release the environment, even if
    # training is interrupted or raises.
    video_recorder.close()
    env.close()


Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.

