The state is represented by 4 numbers:

The cart position x, from -2.4 to 2.4.

The cart velocity v.

The pole angle θ with respect to the vertical, from -12 to 12 degrees (-0.21 to 0.21 radians).

The pole angular velocity ω. This is the rate of change of θ.

In [1]:
import argparse
import gym
import numpy as np
from itertools import count
from random import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [2]:
# Create the CartPole-v1 environment; the module-level `env` is what the
# scoring and training functions below step through.
env = gym.make('CartPole-v1')
# Print the environment's `_max_episode_steps` attribute (recorded output: 500).
print(env._max_episode_steps)

500


In [3]:
def select_action_random(state):
    """Baseline policy: ignore `state` and return 0 or 1 with equal probability."""
    draw = random()
    return 0 if draw < 0.5 else 1

def goodness_score(select_action, num_episodes=100, num_steps=500, environment=None):
    """Score a policy as the average fraction of the step budget it survives.

    Args:
        select_action: callable taking the current state and returning the
            action to pass to the environment's ``step``.
        num_episodes: number of episodes to run (must be >= 1).
        num_steps: maximum number of steps per episode (must be >= 1).
        environment: object exposing ``reset()`` -> state and
            ``step(action)`` -> ``(state, reward, done, info)``. When omitted,
            the notebook's module-level ``env`` is used.

    Returns:
        A float in (0, 1]: total steps survived across all episodes divided
        by ``num_episodes * num_steps``.

    Raises:
        ValueError: if ``num_episodes`` or ``num_steps`` is smaller than 1.
    """
    if num_episodes < 1:
        raise ValueError("num_episodes must be at least 1")
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")

    # Fall back to the shared notebook environment unless one is injected.
    sim = env if environment is None else environment

    episode_lengths = []
    for episode in range(num_episodes):
        state = sim.reset()
        # `t` counts steps from 1, so after the loop it is the episode length
        # whether the episode ended early or used the whole budget.
        for t in range(1, num_steps + 1):
            action = select_action(state)
            state, _, done, _ = sim.step(action)
            if done:
                break
        episode_lengths.append(t)
    return sum(episode_lengths) / (num_episodes * num_steps)

# Baseline: score the coin-flip policy over the default number of episodes.
print(goodness_score(select_action_random))

0.04412


In [4]:
def select_action_simple(state):
    """Return action 0 when the pole angle (state[2]) is negative, otherwise 1."""
    pole_angle = state[2]
    return 0 if pole_angle < 0 else 1

# Score the angle-only heuristic with the default number of episodes.
goodness_score(select_action_simple)

0.08618

In [5]:
def select_action_good(state):
    """Return action 0 when pole angle plus angular velocity is negative, otherwise 1.

    Uses state[2] (angle) and state[3] (angular velocity) summed with equal weight.
    """
    lean = state[2]+state[3]
    return 0 if lean < 0 else 1

# Score the angle-plus-angular-velocity heuristic with the default number of episodes.
goodness_score(select_action_good)

0.9662

In [6]:
class PolicyNN(nn.Module):
    """Minimal policy network: one linear layer from 4 state values to 2 action probabilities."""

    def __init__(self):
        super().__init__()
        # 4 inputs (the state numbers) -> 2 outputs (one score per action).
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        """Map states to action probabilities.

        Args:
            x: float tensor whose last dimension has size 4; it may be a
                single state of shape (4,) or a batch of shape (N, 4).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 2 that sums to 1.
        """
        scores = self.fc(x)
        # Normalise over the action axis (the last one). For the usual
        # (N, 4) batch this equals dim=1, and it also handles an unbatched
        # (4,) state, which dim=1 would reject.
        return F.softmax(scores, dim=-1)

def select_action_from_policy(model, state):
    """Sample an action from the policy's output distribution.

    Args:
        model: callable mapping a float tensor of shape (1, len(state)) to a
            tensor of action probabilities of shape (1, num_actions).
        state: a single state as a NumPy array, list, tuple or tensor.

    Returns:
        ``(action, log_prob)`` where ``action`` is a Python int and
        ``log_prob`` is a tensor of shape (1,) holding the log-probability
        of that action, still attached to the autograd graph of ``model``.
    """
    # as_tensor accepts NumPy arrays as well as plain sequences and tensors;
    # unsqueeze adds the batch dimension the model expects.
    observation = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    distribution = Categorical(model(observation))
    action = distribution.sample()
    return action.item(), distribution.log_prob(action)

def select_action_from_policy_best(model, state):
    """Greedy action choice: return the action the policy rates as more likely.

    Args:
        model: callable mapping a float tensor of shape (1, len(state)) to a
            tensor of two action probabilities of shape (1, 2).
        state: a single state as a NumPy array, list, tuple or tensor.

    Returns:
        0 if the first probability is strictly greater than the second,
        otherwise 1 (so ties resolve to action 1).
    """
    # as_tensor accepts NumPy arrays as well as plain sequences and tensors.
    observation = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    # Pure inference: no gradient is needed, so skip building the autograd graph.
    with torch.no_grad():
        probs = model(observation)
    return 0 if probs[0][0] > probs[0][1] else 1

In [7]:
# Baseline before any training: score a freshly initialised policy two ways,
# first by sampling from its action distribution, then by always taking the
# action it rates as more likely.
model_untrained = PolicyNN()

print(
    goodness_score(lambda state: select_action_from_policy(model_untrained, state)[0]),
    goodness_score(lambda state: select_action_from_policy_best(model_untrained, state))
)

0.04146 0.01874


In [8]:
# Module-level policy and its Adam optimiser (learning rate 0.01); the
# training functions below read both as globals.
model = PolicyNN()
optimizer = optim.Adam(model.parameters(), lr=0.01)

def train_wont_work(num_episodes=100):
    """Deliberately broken training loop, kept to show why a naive loss fails.

    Each episode is played by sampling actions from the module-level
    ``model``; the "loss" is then computed from the episode length alone.
    Because that value is a plain Python float with no connection to the
    model's outputs, ``loss.backward()`` raises ``AttributeError``.

    Args:
        num_episodes: number of episodes the loop would run; in practice the
            first episode already fails at ``loss.backward()``.

    Raises:
        AttributeError: always, on the first call to ``loss.backward()``.
    """
    num_steps = 500
    for episode in range(num_episodes):
        state = env.reset()
        for t in range(1, num_steps+1):
            # Sample from the policy being "trained"; the log-probability is
            # discarded here, which is exactly why no gradient can flow.
            action, _ = select_action_from_policy(model, state)
            state, _, done, _ = env.step(action)
            if done:
                break
        loss = 1.0 - t / num_steps
        # this doesn't actually work, because
        # the loss function is not an explicit
        # function of the model's output; it's
        # a function of book keeping variables
        optimizer.zero_grad()
        loss.backward() # AttributeError: 'float' object has no attribute 'backward'
        optimizer.step()

def train_simple(num_episodes=10*1000, num_steps=500, window=10, success_ratio=0.95):
    """Train the module-level ``model`` with a simple policy-gradient loss.

    For every episode, actions are sampled from the policy and their
    log-probabilities recorded. The loss weights the log-probability of the
    action taken at step ``i`` (0-based) by ``t - i``, the number of steps
    the episode survived from that point on, so actions followed by a longer
    survival are reinforced more strongly. One optimiser step is taken per
    episode, and a line ``episode length loss`` is printed each time.

    Training stops early once the mean length of the last ``window``
    episodes reaches ``success_ratio * num_steps``.

    Args:
        num_episodes: maximum number of training episodes.
        num_steps: maximum number of steps per episode (must be >= 1).
        window: number of most recent episodes averaged for early stopping
            (must be >= 1).
        success_ratio: fraction of ``num_steps`` the windowed mean must reach.

    Returns:
        None. The model and optimiser are updated in place.

    Raises:
        ValueError: if ``num_steps`` or ``window`` is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    if window < 1:
        raise ValueError("window must be at least 1")

    episode_lengths = []
    for episode in range(num_episodes):
        state = env.reset()
        log_probs = []
        # `t` counts steps from 1, so after the loop it is the episode length.
        for t in range(1, num_steps+1):
            action, log_prob = select_action_from_policy(model, state)
            log_probs.append(log_prob)
            state, _, done, _ = env.step(action)
            if done:
                break

        # Minimising -(steps survived afterwards) * log_prob raises the
        # probability of actions that were followed by long survival.
        loss = sum(-(t - i) * log_prob for i, log_prob in enumerate(log_probs))
        print(episode, t, loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        episode_lengths.append(t)
        recent = episode_lengths[-window:]
        # Evaluate as soon as a full window exists (after `window` episodes).
        if len(recent) == window and sum(recent) / window >= num_steps * success_ratio:
            print('Stopping training, looks good...')
            return

# Train the policy with the default settings; each episode prints its index,
# its length in steps, and the loss value.
train_simple()

0 16 85.35948944091797
1 20 133.23199462890625
2 26 230.32022094726562
3 26 236.91815185546875
4 18 105.97496795654297
5 10 38.34042739868164
6 13 56.65700149536133
7 18 106.04237365722656
8 48 808.704833984375
9 30 299.4029541015625
10 27 238.62835693359375
11 33 374.26947021484375
12 34 416.2560119628906
13 23 180.34190368652344
14 30 304.6261291503906
15 78 2066.367919921875
16 34 396.86590576171875
17 35 422.7949523925781
18 75 1896.2210693359375
19 16 89.17831420898438
20 45 671.5740966796875
21 106 3748.981201171875
22 37 451.1553649902344
23 19 134.71336364746094
24 18 110.48934173583984
25 45 670.2572631835938
26 19 152.81546020507812
27 43 647.4814453125
28 13 63.41194534301758
29 67 1425.4854736328125
30 13 61.09122848510742
31 39 496.96319580078125
32 30 313.1767578125
33 61 1243.1768798828125
34 27 233.2957763671875
35 91 2670.95654296875
36 72 1638.4683837890625
37 93 2806.391845703125
38 46 680.6381225585938
39 59 1091.8087158203125
40 92 2690.11767578125
41 69 1588.32800

340 240 16819.33984375
341 233 15935.90234375
342 163 7877.0625
343 223 15377.14453125
344 210 13678.2568359375
345 217 13686.630859375
346 180 9972.6943359375
347 268 20767.55859375
348 242 17234.0703125
349 274 22112.431640625
350 312 29959.029296875
351 29 293.5941467285156
352 23 230.5286865234375
353 167 8213.177734375
354 189 10782.328125
355 177 9567.7294921875
356 290 25861.28515625
357 244 17402.43359375
358 161 8021.90234375
359 233 16108.310546875
360 171 8669.9267578125
361 319 30163.59765625
362 334 32795.08203125
363 175 8999.048828125
364 186 10331.9189453125
365 285 24494.849609375
366 281 23083.56640625
367 283 23728.25
368 500 73870.4140625
369 323 30357.638671875
370 20 150.28189086914062
371 321 30774.380859375
372 500 75192.5390625
373 360 38349.375
374 143 6186.6005859375
375 450 60154.55859375
376 500 74800.71875
377 282 24131.48046875
378 112 3996.891357421875
379 169 8817.0615234375
380 500 74064.4375
381 500 77291.203125
382 445 59216.8203125
383 220 14812.448

In [9]:
# Re-score the trained policy two ways: sampling from its action
# distribution, then always taking the action it rates as more likely.
print(
    goodness_score(lambda state: select_action_from_policy(model, state)[0]),
    goodness_score(lambda state: select_action_from_policy_best(model, state))
)

0.81904 1.0
