The state is represented by 4 numbers:

The cart position x from -2.4 to 2.4.

The cart velocity v

The pole angle θ with respect to the vertical from -12 to 12 degrees (from -0.21 to 0.21 in radians)

The pole angular velocity ω. This is the rate of change of θ.

In [11]:
import argparse
import gymnasium as gym
import numpy as np
from itertools import count
from random import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

from cartpole import CartPoleEnv

In [12]:
#env = gym.make('CartPole-v1', render_mode="human")
env = CartPoleEnv(render_mode = "human")
#print(env._max_episode_steps)

In [13]:
def select_action_random(state):
    if random() < 0.5:
        return 0
    else:
        return 1

def goodness_score(select_action, num_episodes=100):
    num_steps = 500
    ts = []
    #for episode in range(1):
    for episode in range(num_episodes):
        state = env.reset()[0]
        for t in range(1, num_steps+1):
            action = select_action(state)
            print(action)
            state, reward, done, truncated, info = env.step(action)
            if done:
                break
        ts.append(t)
    score = sum(ts) / (len(ts)*num_steps)
    return score

print(goodness_score(select_action_random))

0
0
1
0
0
1
1
0
0
1
1
0
0
0
0
1
0
1
0
1
1
1
1
1
1
1
1
1
1
1
1
0
1
1
0
0
0
1
1
1
1
1
1
1
1
1
0
1
1
1
0
0
0
0
1
0
0
1
1
1
1
1
1
1
0
1
1
1
0
0
0
0
1
0
1
1
0
1
0
1
1
1
1
1
1
0
0
0
1
0
1
1
0
0
1
0
1
0
1
0
0
0
1
0
1
0
0
1
0
1
1
1
0
0
1
1
0
0
1
1
0
1
0
0
0
0
1
1
1
1
0
0
1
1
0
0
1
1
0
0
0
0
0
0
0
1
1
0
1
1
0
0
0
0
1
0
1
1
0
1
0
0
1
0
0
0
1
1
0
1
1
1
0
0
1
0
0
0
1
1
1
0
1
1
1
1
1
1
0
0
1
0
0
1
0
0
0
1
0
0
0
1
1
0
0
0
0
1
0
1
0
0
1
0
1
1
1
1
0
0
0
0
1
1
0
0
0
1
0
1
1
0
0
0
1
1
0
1
0
0
1
0
0
0
0
1
1
1
1
0
0
1
0
0
1
0
1
1
1
0
1
0
0
1
0
1
1
0
1
1
1
0
0
1
0
0
0
0
0
1
1
1
1
0
0
0
1
0
0
1
0
1
0
1
1
1
0
1
0
1
1
1
0
0
1
1
1
1
0
1
0
1
0
0
1
0
0
1
0
1
0
1
0
1
1
1
0
0
0
1
0
0
1
1
1
0
0
1
0
0
1
1
0
0
1
0
0
1
1
0
1
0
1
0
0
0
1
1
1
0
0
0
1
1
1
0
1
0
1
1
0
0
0
1
0
0
1
0
0
1
1
1
0
1
0
1
0
0
1
1
1
0
1
1
1
1
1
0
0
1
0
1
0
0
1
0
0
1
0
0
0
1
1
0
0
0
1
0
1
0
1
0
1
0
0
0
1
0
0
0
0
1
0
1
1
1
0
1
0
0
1
1
1
0
0
0
0
1
0
0
1
1
0
1
1
1
0
1
0
0
1
1
1
1
1
1
0
1
0
0
0
1
0
0
0
0
0
1
0
1
0
1
0
1
1
0
0
0
1
1
0
0
1
1
1
1
1
0
1
0


0
1
0
1
1
0
0
0
1
1
0
0
1
1
0.08236


In [4]:
def select_action_simple(state):
    if (state[0] < 0):
        return 0
    else:
        return 1

goodness_score(select_action_simple)

0.0379

In [5]:
def select_action_good(state):
    if state[2]+state[3] < 0:
        return 0
    else:
        return 1

goodness_score(select_action_good)

0.9527

In [6]:
class PolicyNN(nn.Module):
    def __init__(self):
        super(PolicyNN, self).__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        x = self.fc(x)
        return F.softmax(x, dim=1)

def select_action_from_policy(model, state):
    state = torch.from_numpy(state).float().unsqueeze(0)
    probs = model(state)
    m = Categorical(probs)
    action = m.sample()
    return action.item(), m.log_prob(action)

def select_action_from_policy_best(model, state):
    state = torch.from_numpy(state).float().unsqueeze(0)
    probs = model(state)
    if probs[0][0] > probs[0][1]:
        return 0
    else:
        return 1

In [7]:
model_untrained = PolicyNN()

print(
    goodness_score(lambda state: select_action_from_policy(model_untrained, state)[0]),
    goodness_score(lambda state: select_action_from_policy_best(model_untrained, state))
)

0.07958 0.0391


In [8]:
model = PolicyNN()
optimizer = optim.Adam(model.parameters(), lr=0.01)

def train_wont_work(num_episodes=100):
    num_steps = 500
    for episode in range(num_episodes):
        state = env.reset()
        for t in range(1, num_steps+1):
            action = select_action(state)
            state, _, done, _ = env.step(action)
            if done:
                break
        loss = 1.0 - t / num_steps
        # this doesn't actually work, because
        # the loss function is not an explicit
        # function of the model's output; it's
        # a function of book keeping variables
        optimizer.zero_grad()
        loss.backward() # AttributeError: 'float' object has no attribute 'backward'
        optimizer.step()

def train_simple(num_episodes=10*1000):
    num_steps = 500
    ts = []
    for episode in range(num_episodes):
        state = env.reset()[0]
        probs = []
        for t in range(1, num_steps+1):
            action, prob = select_action_from_policy(model, state)
            probs.append(prob)
            #state, _, done, _ = env.step(action)
            state, reward, done, truncated, info = env.step(action)
            if done:
                break
        loss = 0
        for i, prob in enumerate(probs):
            loss += -1 * (t - i) * prob
        print(episode, t, loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ts.append(t)
        if len(ts) > 10 and sum(ts[-10:])/10.0 >= num_steps * 0.95:
            print('Stopping training, looks good...')
            return

train_simple()

0 52 941.7232666015625
1 45 704.2392578125
2 30 324.8945007324219
3 55 1046.9794921875
4 61 1281.407470703125
5 44 673.77392578125
6 56 1082.5718994140625
7 57 1113.201904296875
8 49 857.1608276367188
9 63 1362.190673828125
10 39 521.0167236328125
11 71 1769.9267578125
12 31 335.93377685546875
13 91 2787.60009765625
14 36 446.0467834472656
15 30 361.3577575683594
16 42 622.5128784179688
17 95 3128.694580078125
18 37 499.397705078125
19 40 569.4281005859375
20 61 1265.5255126953125
21 42 612.5289916992188
22 66 1471.178466796875
23 57 1111.37841796875
24 46 721.0980224609375
25 55 1019.0435180664062
26 66 1470.585693359375
27 63 1352.765380859375
28 38 528.2055053710938
29 38 485.38909912109375
30 56 1061.76611328125
31 92 2860.722900390625
32 56 1070.2940673828125
33 76 1922.0391845703125
34 48 772.8508911132812
35 45 676.9366455078125
36 46 744.9132690429688
37 60 1219.5189208984375
38 37 474.7861328125
39 70 1650.0831298828125
40 42 624.837646484375
41 59 1149.55615234375
42 51 864.0

341 244 17734.30859375
342 266 21774.974609375
343 287 25565.38671875
344 99 2937.874267578125
345 178 9662.064453125
346 307 28500.330078125
347 97 2865.907958984375
348 194 11553.0
349 127 4929.98779296875
350 233 16119.736328125
351 96 2787.98974609375
352 182 10108.619140625
353 262 20877.673828125
354 429 56507.78125
355 413 51429.80078125
356 211 13622.70703125
357 386 43571.75
358 190 10884.4521484375
359 226 15699.6767578125
360 199 12184.90234375
361 175 9603.4921875
362 373 42763.76953125
363 442 55632.4140625
364 500 75643.1328125
365 311 30360.79296875
366 277 23239.25
367 169 8331.736328125
368 108 3669.12109375
369 171 8999.8349609375
370 207 13280.72265625
371 465 61371.9765625
372 500 74975.0078125
373 427 53602.37890625
374 500 74745.0546875
375 262 20242.314453125
376 367 40764.9453125
377 478 69417.828125
378 500 75608.09375
379 171 8879.1123046875
380 380 43082.99609375
381 290 24820.224609375
382 341 35727.05078125
383 290 24426.58203125
384 94 2911.75634765625
385

In [9]:
print(
    goodness_score(lambda state: select_action_from_policy(model, state)[0]),
    goodness_score(lambda state: select_action_from_policy_best(model, state))
)

0.87996 1.0
