In [None]:
import argparse
import gymnasium as gym
import numpy as np
from itertools import count
from collections import deque
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import pygame

# --- Reproducibility -------------------------------------------------------
# Seed torch's global RNG (used by dropout and action sampling) and pass the
# same seed to the first env.reset() so repeated runs start from the same state.
SEED = 543
torch.manual_seed(SEED)

# --- Environment ------------------------------------------------------------
env = gym.make("CartPole-v1", render_mode="rgb_array")
env.reset(seed=SEED)

# --- Hyperparameters --------------------------------------------------------
gamma = 0.9       # discount factor for returns (NOTE(review): 0.99 is the usual CartPole choice; confirm 0.9 is intended)
log_interval = 1  # print a progress line every `log_interval` episodes

class Policy(nn.Module):
    """Two-layer MLP policy that maps observations to action probabilities.

    Architecture: ``Linear(obs_dim, hidden_size) -> Dropout -> ReLU ->
    Linear(hidden_size, n_actions) -> softmax`` over the last dimension.
    The defaults (4 observations, 2 actions) match CartPole.

    The module also carries two per-episode buffers that the rollout loop
    appends to and the update step clears:
    ``saved_log_probs`` (log-probability tensors of sampled actions) and
    ``rewards`` (scalar rewards).

    Args:
        obs_dim: size of the last dimension of the observation input.
        n_actions: number of discrete actions (size of the output distribution).
        hidden_size: width of the hidden layer.
        dropout: dropout probability applied to the hidden pre-activations.
    """

    def __init__(self, obs_dim=4, n_actions=2, hidden_size=128, dropout=0.6):
        super().__init__()
        self.affine1 = nn.Linear(obs_dim, hidden_size)
        # Dropout is active whenever the module is in training mode (the
        # default; nothing here calls .eval()), so action sampling also goes
        # through dropout.
        self.dropout = nn.Dropout(p=dropout)
        self.affine2 = nn.Linear(hidden_size, n_actions)

        # Per-episode buffers, filled during rollout and cleared after each update.
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        """Return action probabilities of shape ``(..., n_actions)`` for ``x`` of shape ``(..., obs_dim)``."""
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        # Normalize over the action axis; dim=-1 also works for unbatched input.
        return F.softmax(action_scores, dim=-1)

# Model and optimizer shared by the training loop below.
policy = Policy()
optimizer = optim.Adam(params=policy.parameters(), lr=1e-2)

# Machine epsilon for float32, added to the return std to avoid dividing by zero.
eps = float(np.finfo(np.float32).eps)

def select_action(state):
    """Sample one action from the current ``policy`` for a single observation.

    ``state`` is a NumPy array. It is cast to a float tensor and given a
    leading batch dimension of size 1 before being fed to ``policy``. The
    log-probability of the sampled action is appended to
    ``policy.saved_log_probs`` for the later gradient update.

    Returns:
        The sampled action index as a Python number (via ``Tensor.item()``).
    """
    obs = torch.from_numpy(state).float().unsqueeze(dim=0)
    dist = Categorical(probs=policy(obs))
    sampled = dist.sample()
    policy.saved_log_probs.append(dist.log_prob(sampled))
    return sampled.item()

def finish_episode():
    """Run one REINFORCE update from the episode buffered on ``policy``.

    Walks ``policy.rewards`` backwards to build discounted returns
    ``G_t = r_t + gamma * G_{t+1}``, standardizes them, and takes one
    ``optimizer`` step on ``sum_t -log_prob_t * G_t``.

    The episode buffers are always cleared afterwards, even if the update
    raises or the episode is empty, so the next episode starts clean.

    Uses the module-level ``policy``, ``optimizer``, ``gamma`` and ``eps``.
    """
    try:
        if not policy.rewards:
            # Nothing recorded; torch.cat([]) below would raise.
            return

        R = 0
        returns = deque()
        for r in reversed(policy.rewards):
            R = r + gamma * R
            returns.appendleft(R)
        returns = torch.tensor(list(returns))

        # Standardizing the returns acts as a simple baseline. With a single
        # step, the (Bessel-corrected) std is NaN and would turn every weight
        # into NaN, so only standardize when there are at least two returns.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + eps)

        # zip() pairs each step's log-prob with its return (truncating to the
        # shorter list, as before).
        policy_loss = [-log_prob * G for log_prob, G in zip(policy.saved_log_probs, returns)]
        optimizer.zero_grad()
        loss = torch.cat(policy_loss).sum()
        loss.backward()
        optimizer.step()
    finally:
        del policy.rewards[:]
        del policy.saved_log_probs[:]

def main(render=False):
    """Train ``policy`` with REINFORCE on ``env`` until it is solved.

    Each episode rolls out the current policy, recording rewards on
    ``policy.rewards``, and then calls ``finish_episode`` for one gradient
    update. Training stops once an exponential moving average of the episode
    reward exceeds ``env.spec.reward_threshold``.

    Args:
        render: if True, call ``env.render()`` after every step. With an
            ``rgb_array`` render mode the returned frame is discarded here, so
            rendering only adds per-step cost (and requires pygame). It is
            therefore off by default.
    """
    running_reward = 10
    for i_episode in count(1):
        state, _ = env.reset()
        ep_reward = 0
        for t in range(1, 10000):  # Don't infinite loop while learning
            action = select_action(state)
            state, reward, terminated, truncated, _ = env.step(action)
            if render:
                env.render()
            policy.rewards.append(reward)
            ep_reward += reward
            # Gymnasium reports episode end as two flags: `terminated` (a
            # terminal state was reached) and `truncated` (e.g. a time limit
            # was hit). Checking only `terminated` keeps stepping past the
            # time limit, so either flag ends the episode.
            if terminated or truncated:
                break

        # Exponential moving average of the episode reward (weight 0.05 on the newest episode).
        running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
        finish_episode()
        if i_episode % log_interval == 0:
            print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}'.format(
                  i_episode, ep_reward, running_reward))
        if running_reward > env.spec.reward_threshold:
            print("Solved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(running_reward, t))
            break


# Run training; main() returns once the running reward exceeds env.spec.reward_threshold.
main()

Episode 1	Last reward: 34.00	Average reward: 11.20
Episode 2	Last reward: 20.00	Average reward: 11.64
Episode 3	Last reward: 12.00	Average reward: 11.66
Episode 4	Last reward: 10.00	Average reward: 11.58
Episode 5	Last reward: 40.00	Average reward: 13.00
Episode 6	Last reward: 11.00	Average reward: 12.90
Episode 7	Last reward: 10.00	Average reward: 12.75
Episode 8	Last reward: 14.00	Average reward: 12.81
Episode 9	Last reward: 21.00	Average reward: 13.22
Episode 10	Last reward: 29.00	Average reward: 14.01
Episode 11	Last reward: 20.00	Average reward: 14.31
Episode 12	Last reward: 13.00	Average reward: 14.25
Episode 13	Last reward: 24.00	Average reward: 14.73
Episode 14	Last reward: 18.00	Average reward: 14.90
Episode 15	Last reward: 10.00	Average reward: 14.65
Episode 16	Last reward: 9.00	Average reward: 14.37
Episode 17	Last reward: 54.00	Average reward: 16.35
Episode 18	Last reward: 12.00	Average reward: 16.13
Episode 19	Last reward: 20.00	Average reward: 16.33
Episode 20	Last reward

In [11]:
%pip install pygame


Collecting pygame
  Downloading pygame-2.2.0-cp310-cp310-macosx_10_9_x86_64.whl (13.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.0/13.0 MB 36.7 MB/s eta 0:00:00
Installing collected packages: pygame
Successfully installed pygame-2.2.0

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.3.1[0m[39;49m -> [0m[32;49m23.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49m/usr/local/opt/python@3.10/bin/python3.10 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
