# Imports

In [1]:
from dataclasses import dataclass
from collections import deque
from itertools import count

In [2]:
import gym

In [3]:
import plotly.graph_objects as go

In [4]:
import numpy as np

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# Parameters

These parameters were taken from https://github.com/yc930401/Actor-Critic-pytorch/blob/master/Actor-Critic.py

In [None]:
torch.manual_seed(0)

In [6]:
HIDDEN = [128, 256]

In [7]:
env = gym.make("CartPole-v0").unwrapped
CATEGORICAL_ACTION = True
ACTION_DIM = env.action_space.n
STATE_DIM = env.observation_space.shape[0]

In [8]:
LR = 0.001
N_ITERS = 1000
GAMMA = 0.99

# The Neural Network

Actor Critic methods use two network outputs:

1. A value prediction based on the state. (Value Network, Critic)
2. An action prediction based on the state. (Policy Network, Actor)

In [10]:
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        
        # We are going to make a feed forward network of depth len(HIDDEN)
        self.layers = []
        self.input_layer = nn.Linear(STATE_DIM, HIDDEN[0])
        for i, j in zip(HIDDEN, HIDDEN[1:]):
            self.layers.append(nn.Linear(i,j))
            
        # These will be our output layers
        # If the policy is categorical we need to use a softmax output
        self.value_output_layer = nn.Linear(HIDDEN[-1], 1)
        self.policy_output_layer = nn.Linear(HIDDEN[-1], ACTION_DIM)

    def forward(self, state):
        temp = F.relu(self.input_layer(state))
        for layer in self.layers:
            temp = F.relu(layer(temp))
            
        # The policy network 
        # If the policy is categorical we need to sample from the softmax distribution
        policy = self.policy_output_layer(temp)
        if CATEGORICAL_ACTION:
            policy = Categorical(F.softmax(policy))
            
        # The value output is a simple linear layer
        value = self.value_output_layer(temp)
        
        # Return as a tuple
        return value, policy
    
    def calculate_loss(self, 
                       state: torch.FloatTensor,
                       values: torch.FloatTensor,
                       log_probs: torch.FloatTensor,
                       rewards: torch.FloatTensor,
                       dones: torch.FloatTensor):
        # REF: https://github.com/yc930401/Actor-Critic-pytorch/blob/master/Actor-Critic.py
        # Get the predicted value at the final state
        final_value, _ = self(state)
        
        # Calculate the cumulative rewards using the final predicted value as the terminal value
        cum_reward = final_value
        not_dones = 1 - dones
        discounted_future_rewards = torch.FloatTensor(np.zeros(len(rewards))).to(device)
        for i in range(len(rewards)):
            cum_reward = rewards[-i] + GAMMA * cum_reward * not_dones[-1]
            discounted_future_rewards[-i] = cum_reward
        
        # Now we calculate the advantage function
        advantage = discounted_future_rewards - values

        # And the loss for both the actor and the critic
        actor_loss = -(log_probs * advantage.detach()).mean()
        critic_loss = advantage.pow(2).mean()
        
        return critic_loss, actor_loss

# Training Loop

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [12]:
fig = go.FigureWidget()
fig.add_scatter()

FigureWidget({
    'data': [{'type': 'scatter', 'uid': 'a7283b35-2908-429b-8237-f8587f88b16f'}], 'layout': {'t…

In [13]:
model = ActorCritic()
optimizer = optim.Adam(model.parameters(), lr=LR)
scores = []
for this_iter in range(N_ITERS):
    log_probs, values, rewards, dones = [], [], [], []
    entropy = 0
    state = env.reset()

    for i in count():
        state = torch.FloatTensor(state).to(device)
        value, policy = model(state)
        action = policy.sample()
        next_state, reward, done, _ = env.step(action.cpu().numpy())

        log_prob = policy.log_prob(action).unsqueeze(0)
        entropy += policy.entropy().mean()

        log_probs.append(log_prob)
        values.append(value)
        rewards.append(torch.tensor([reward], dtype=torch.float, device=device))
        dones.append(torch.tensor([1-done], dtype=torch.float, device=device))

        state = next_state

        if done:
            print('Iteration: {}, Score: {}'.format(this_iter, i))
            if this_iter % 100:
                #env.render()
                pass
            scores.append(i)
            with fig.batch_update():
                # fig.data[0].x = list(range(this_iter))
                for i in range(len(scores)):
                    fig.data[0].y = scores
            break

    log_probs = torch.cat(log_probs)
    values = torch.cat(values)
    rewards = torch.cat(rewards)
    dones = torch.cat(dones)
    
    critic_loss, actor_loss = model.calculate_loss(torch.FloatTensor(state).to(device), values, log_probs, rewards, dones)

    optimizer.zero_grad()
    actor_loss.backward(retain_graph=True)
    critic_loss.backward()
    optimizer.step()
    # torch.save(model, 'model.pkl')
env.close()


Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.



Iteration: 0, Score: 11
Iteration: 1, Score: 28
Iteration: 2, Score: 11
Iteration: 3, Score: 14
Iteration: 4, Score: 46
Iteration: 5, Score: 10
Iteration: 6, Score: 16
Iteration: 7, Score: 13
Iteration: 8, Score: 13
Iteration: 9, Score: 21
Iteration: 10, Score: 18
Iteration: 11, Score: 9
Iteration: 12, Score: 31
Iteration: 13, Score: 38
Iteration: 14, Score: 36
Iteration: 15, Score: 16
Iteration: 16, Score: 19
Iteration: 17, Score: 28
Iteration: 18, Score: 17
Iteration: 19, Score: 10
Iteration: 20, Score: 8
Iteration: 21, Score: 29
Iteration: 22, Score: 15
Iteration: 23, Score: 50
Iteration: 24, Score: 32
Iteration: 25, Score: 13
Iteration: 26, Score: 36
Iteration: 27, Score: 16
Iteration: 28, Score: 12
Iteration: 29, Score: 12
Iteration: 30, Score: 10
Iteration: 31, Score: 12
Iteration: 32, Score: 44
Iteration: 33, Score: 22
Iteration: 34, Score: 11
Iteration: 35, Score: 45
Iteration: 36, Score: 22
Iteration: 37, Score: 14
Iteration: 38, Score: 21
Iteration: 39, Score: 21
Iteration: 4

Iteration: 320, Score: 39
Iteration: 321, Score: 50
Iteration: 322, Score: 70
Iteration: 323, Score: 109
Iteration: 324, Score: 75
Iteration: 325, Score: 57
Iteration: 326, Score: 71
Iteration: 327, Score: 89
Iteration: 328, Score: 34
Iteration: 329, Score: 37
Iteration: 330, Score: 25
Iteration: 331, Score: 56
Iteration: 332, Score: 37
Iteration: 333, Score: 43
Iteration: 334, Score: 51
Iteration: 335, Score: 26
Iteration: 336, Score: 141
Iteration: 337, Score: 25
Iteration: 338, Score: 25
Iteration: 339, Score: 20
Iteration: 340, Score: 18
Iteration: 341, Score: 21
Iteration: 342, Score: 45
Iteration: 343, Score: 132
Iteration: 344, Score: 18
Iteration: 345, Score: 67
Iteration: 346, Score: 19
Iteration: 347, Score: 19
Iteration: 348, Score: 108
Iteration: 349, Score: 121
Iteration: 350, Score: 30
Iteration: 351, Score: 19
Iteration: 352, Score: 24
Iteration: 353, Score: 126
Iteration: 354, Score: 51
Iteration: 355, Score: 19
Iteration: 356, Score: 13
Iteration: 357, Score: 91
Iterat

Iteration: 633, Score: 101
Iteration: 634, Score: 75
Iteration: 635, Score: 161
Iteration: 636, Score: 169
Iteration: 637, Score: 63
Iteration: 638, Score: 85
Iteration: 639, Score: 56
Iteration: 640, Score: 174
Iteration: 641, Score: 52
Iteration: 642, Score: 59
Iteration: 643, Score: 164
Iteration: 644, Score: 225
Iteration: 645, Score: 15
Iteration: 646, Score: 165
Iteration: 647, Score: 29
Iteration: 648, Score: 218
Iteration: 649, Score: 208
Iteration: 650, Score: 187
Iteration: 651, Score: 53
Iteration: 652, Score: 52
Iteration: 653, Score: 52
Iteration: 654, Score: 24
Iteration: 655, Score: 575
Iteration: 656, Score: 86
Iteration: 657, Score: 106
Iteration: 658, Score: 93
Iteration: 659, Score: 206
Iteration: 660, Score: 71
Iteration: 661, Score: 25
Iteration: 662, Score: 121
Iteration: 663, Score: 96
Iteration: 664, Score: 95
Iteration: 665, Score: 83
Iteration: 666, Score: 181
Iteration: 667, Score: 25
Iteration: 668, Score: 101
Iteration: 669, Score: 28
Iteration: 670, Score:

Iteration: 942, Score: 490
Iteration: 943, Score: 19
Iteration: 944, Score: 92
Iteration: 945, Score: 286
Iteration: 946, Score: 179
Iteration: 947, Score: 33
Iteration: 948, Score: 346
Iteration: 949, Score: 35
Iteration: 950, Score: 16
Iteration: 951, Score: 96
Iteration: 952, Score: 181
Iteration: 953, Score: 114
Iteration: 954, Score: 53
Iteration: 955, Score: 78
Iteration: 956, Score: 78
Iteration: 957, Score: 93
Iteration: 958, Score: 34
Iteration: 959, Score: 130
Iteration: 960, Score: 137
Iteration: 961, Score: 49
Iteration: 962, Score: 108
Iteration: 963, Score: 242
Iteration: 964, Score: 71
Iteration: 965, Score: 196
Iteration: 966, Score: 196
Iteration: 967, Score: 111
Iteration: 968, Score: 78
Iteration: 969, Score: 49
Iteration: 970, Score: 283
Iteration: 971, Score: 166
Iteration: 972, Score: 218
Iteration: 973, Score: 200
Iteration: 974, Score: 329
Iteration: 975, Score: 50
Iteration: 976, Score: 153
Iteration: 977, Score: 282
Iteration: 978, Score: 188
Iteration: 979, S