The state is represented by 4 numbers:

The cart position x, from -2.4 to 2.4.

The cart velocity v.

The pole angle θ with respect to the vertical, from -12 to 12 degrees (from -0.21 to 0.21 radians).

The pole angular velocity ω. This is the rate of change of θ.

In [1]:
import argparse
import gymnasium as gym
import numpy as np
from itertools import count
from random import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [2]:
# Create the CartPole-v1 environment shared (as the global `env`) by the cells below.
env = gym.make('CartPole-v1')
# Read the per-episode step limit from the public registration spec instead of
# the private `_max_episode_steps` attribute of the time-limit wrapper.
print(env.spec.max_episode_steps)

500


In [3]:
def select_action_random(state):
    """Pick an action by a fair coin flip, ignoring the observation.

    Args:
        state: Unused; accepted so the function has the same signature as
            the other action selectors.

    Returns:
        0 when a uniform draw from [0, 1) is below 0.5, otherwise 1.
    """
    return 0 if random() < 0.5 else 1

def goodness_score(select_action, num_episodes=100, num_steps=500, environment=None):
    """Average fraction of the step budget that a policy survives.

    Runs `num_episodes` episodes. In each one the policy is asked for an
    action until the environment reports that the episode has ended
    (terminated or truncated) or `num_steps` steps have been taken.

    Args:
        select_action: Callable mapping an observation to an action.
        num_episodes: Number of episodes to average over. Zero episodes
            raises ZeroDivisionError, as there is nothing to average.
        num_steps: Step budget per episode (the denominator of the score).
        environment: Object with `reset()` returning `(observation, info)`
            and `step(action)` returning
            `(observation, reward, terminated, truncated, info)`.
            Defaults to the notebook's global `env`.

    Returns:
        Mean number of steps survived divided by `num_steps`.
    """
    if environment is None:
        environment = env

    steps_survived = []
    for _ in range(num_episodes):
        state = environment.reset()[0]
        t = 0  # stays 0 if the step budget is empty
        for t in range(1, num_steps + 1):
            action = select_action(state)
            state, _reward, done, truncated, _info = environment.step(action)
            # Stop on either end-of-episode signal: an environment must not
            # be stepped again after it has terminated or been truncated.
            if done or truncated:
                break
        steps_survived.append(t)
    return sum(steps_survived) / (len(steps_survived) * num_steps)

# Baseline: score the coin-flip policy over the default number of episodes.
print(goodness_score(select_action_random))

0.04576


In [4]:
def select_action_simple(state):
    """Choose an action from the sign of the first state component alone.

    Args:
        state: Indexable observation; only `state[0]` (the cart position x
            in the notebook's state description) is read.

    Returns:
        0 when `state[0]` is negative, otherwise 1.
    """
    cart_position = state[0]
    return 0 if cart_position < 0 else 1

# Score the position-only policy; as the cell's last expression, the value is displayed.
goodness_score(select_action_simple)

0.0188

In [5]:
def select_action_good(state):
    """Choose an action from the sign of (pole angle + pole angular velocity).

    Args:
        state: Indexable observation; `state[2]` and `state[3]` are read
            (θ and ω in the notebook's state description).

    Returns:
        0 when `state[2] + state[3]` is negative, otherwise 1.
    """
    lean = state[2] + state[3]
    return 0 if lean < 0 else 1

# Score the angle-plus-angular-velocity policy; the value is displayed as the cell output.
goodness_score(select_action_good)

0.94296

In [10]:
class PolicyNN(nn.Module):
    """Single-layer softmax policy: 4 state values in, 2 action probabilities out."""

    def __init__(self):
        super().__init__()
        # One linear layer mapping the 4 state numbers to one logit per action.
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        """Return action probabilities for `x`.

        Args:
            x: Float tensor whose last dimension has size 4, either a batch
                of shape (N, 4) or a single unbatched state of shape (4,).

        Returns:
            Tensor of the same leading shape with a last dimension of size 2
            that sums to 1.
        """
        logits = self.fc(x)
        # Normalise over the last axis. For (N, 4) batches this is the same
        # as dim=1, and it also accepts an unbatched 1-D state, which dim=1
        # rejects.
        return F.softmax(logits, dim=-1)

def select_action_from_policy(model, state):
    """Sample an action from the policy's output distribution.

    Args:
        model: Callable taking a float32 tensor of shape (1, len(state)) and
            returning a (1, num_actions) tensor of action probabilities.
        state: NumPy array holding the observation.

    Returns:
        Tuple `(action, log_prob)`: the sampled action as a Python int, and
        the log-probability of that action as a tensor (kept as a tensor so
        a loss built from it can be back-propagated).
    """
    observation = torch.from_numpy(state).float()
    action_probs = model(observation.unsqueeze(0))
    distribution = Categorical(action_probs)
    sampled = distribution.sample()
    log_probability = distribution.log_prob(sampled)
    return sampled.item(), log_probability

def select_action_from_policy_best(model, state):
    """Greedily pick the action the policy considers more probable.

    Args:
        model: Callable taking a float32 tensor of shape (1, len(state)) and
            returning a (1, 2) tensor of action probabilities.
        state: NumPy array holding the observation.

    Returns:
        0 when action 0 is strictly more probable than action 1, otherwise 1
        (ties therefore resolve to 1).
    """
    observation = torch.from_numpy(state).float().unsqueeze(0)
    # Evaluation only: nothing is back-propagated through this call, so skip
    # building an autograd graph for the forward pass.
    with torch.no_grad():
        probs = model(observation)
    return 0 if probs[0][0] > probs[0][1] else 1

In [11]:
# A freshly initialised policy network, scored before any training as a baseline.
model_untrained = PolicyNN()

# First value: actions sampled from the policy's distribution.
# Second value: always taking the more probable action.
print(
    goodness_score(lambda state: select_action_from_policy(model_untrained, state)[0]),
    goodness_score(lambda state: select_action_from_policy_best(model_untrained, state))
)

0.07166 0.1643


In [12]:
# Global policy network and its Adam optimizer (learning rate 0.01); the
# training functions below read both by name.
model = PolicyNN()
optimizer = optim.Adam(model.parameters(), lr=0.01)

def train_wont_work(num_episodes=100):
    """Deliberately broken training loop, kept as a counter-example.

    It plays an episode with the global `model` and then tries to use
    `1 - t / num_steps` (the fraction of the step budget NOT survived) as
    the loss. That value is a plain Python float computed from bookkeeping
    variables, not a tensor derived from the model's output, so there is
    nothing to differentiate.

    The environment calls follow the same API as the rest of the notebook
    (`reset()[0]`, five-value `step`), so the function fails at the point
    it is meant to illustrate rather than earlier.

    Args:
        num_episodes: Number of episodes to attempt.

    Raises:
        AttributeError: at `loss.backward()` after the first episode.
    """
    num_steps = 500
    for episode in range(num_episodes):
        state = env.reset()[0]
        for t in range(1, num_steps+1):
            # Only the sampled action is needed; the log-probability is
            # discarded, which is exactly why no gradient can flow.
            action, _ = select_action_from_policy(model, state)
            state, reward, done, truncated, info = env.step(action)
            if done or truncated:
                break
        loss = 1.0 - t / num_steps
        # this doesn't actually work, because
        # the loss function is not an explicit
        # function of the model's output; it's
        # a function of book keeping variables
        optimizer.zero_grad()
        loss.backward() # AttributeError: 'float' object has no attribute 'backward'
        optimizer.step()

def train_simple(num_episodes=10*1000):
    """Train the global `model` with a simple policy-gradient loop.

    Each episode is played by sampling actions from the policy and
    recording the log-probability of every action taken. The loss is

        loss = -sum_i (t - i) * log_prob_i

    where `t` is the number of steps survived and `i` is the 0-based step
    index, so each action is weighted by how many steps the episode lasted
    from that action onwards. One optimizer step is taken per episode, and
    one progress line (episode index, steps survived, loss) is printed.

    Training stops early once the last 10 episodes average at least 95% of
    the 500-step budget.

    Args:
        num_episodes: Maximum number of episodes to train for.

    Returns:
        None. The global `model` is updated in place through `optimizer`.
    """
    num_steps = 500
    window = 10  # number of most recent episodes used by the stopping rule
    ts = []
    for episode in range(num_episodes):
        state = env.reset()[0]
        log_probs = []
        for t in range(1, num_steps+1):
            action, log_prob = select_action_from_policy(model, state)
            log_probs.append(log_prob)
            state, reward, done, truncated, info = env.step(action)
            # Stop on either end-of-episode signal so the environment is
            # never stepped after it has terminated or been truncated.
            if done or truncated:
                break
        # Exactly one log-probability is recorded per step, so
        # len(log_probs) == t and the weights run t, t-1, ..., 1.
        loss = sum(-(t - i) * step_log_prob for i, step_log_prob in enumerate(log_probs))
        print(episode, t, loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ts.append(t)
        # `>=` so the rule can fire as soon as a full window of episodes exists.
        if len(ts) >= window and sum(ts[-window:]) / window >= num_steps * 0.95:
            print('Stopping training, looks good...')
            return

# Run training; each episode prints its index, the steps survived, and the loss.
train_simple()

0 11 41.889060974121094
1 21 165.8580322265625
2 20 141.98497009277344
3 22 175.63726806640625
4 16 94.69956970214844
5 23 193.51429748535156
6 15 85.28871154785156
7 16 92.27407836914062
8 41 607.623291015625
9 31 350.4656066894531
10 14 71.58489990234375
11 27 264.146240234375
12 16 94.936767578125
13 14 72.85687255859375
14 16 94.10961151123047
15 15 82.8990707397461
16 23 190.28878784179688
17 25 228.7447052001953
18 24 210.80918884277344
19 16 95.28022766113281
20 28 281.9775695800781
21 21 160.26431274414062
22 14 73.12321472167969
23 29 302.3507385253906
24 12 52.952213287353516
25 13 63.191993713378906
26 46 746.6875
27 21 160.20901489257812
28 14 72.79129028320312
29 21 160.1837158203125
30 45 710.5232543945312
31 11 46.253395080566406
32 15 84.12749481201172
33 15 83.43211364746094
34 41 588.076416015625
35 12 55.42134094238281
36 48 801.7495727539062
37 36 453.269287109375
38 13 62.411041259765625
39 24 202.35855102539062
40 14 76.8935775756836
41 17 104.16140747070312
42 49

339 151 7055.08056640625
340 116 3999.48095703125
341 143 6363.71484375
342 105 3320.79248046875
343 107 3376.525390625
344 190 10949.4140625
345 110 3683.079833984375
346 174 9177.0
347 366 40124.98046875
348 265 21237.23046875
349 114 4005.021728515625
350 500 73968.2265625
351 500 74289.0625
352 321 30632.287109375
353 445 59186.0625
354 308 27104.875
355 241 17056.51953125
356 187 10242.3046875
357 187 10388.4775390625
358 147 6415.5615234375
359 273 21662.033203125
360 302 26588.78125
361 118 4192.95703125
362 242 17915.302734375
363 159 7398.41552734375
364 399 45295.98828125
365 232 15214.955078125
366 289 24315.0546875
367 246 17256.095703125
368 252 18745.083984375
369 295 25604.669921875
370 327 32045.205078125
371 192 10760.0888671875
372 120 4233.517578125
373 161 7754.95751953125
374 157 7250.3603515625
375 57 1044.3077392578125
376 71 1594.7340087890625
377 247 17846.345703125
378 200 11506.4794921875
379 391 46206.28125
380 176 8862.7734375
381 206 12375.4072265625
382 3

In [13]:
# Score the trained policy two ways: sampling actions from its distribution
# (first value) and always taking the more probable action (second value).
print(
    goodness_score(lambda state: select_action_from_policy(model, state)[0]),
    goodness_score(lambda state: select_action_from_policy_best(model, state))
)

0.92858 1.0
