In [24]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import torch

from tqdm.notebook import trange
from random import random

In [50]:
# Q-network: maps a 4-feature observation to one Q-value per action (2 outputs)
# through two ReLU-activated hidden layers of width 100 and 50.
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=4, out_features=100),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=100, out_features=50),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=50, out_features=2),
)

In [51]:
# Optimisation setup: Adam over every parameter of `model`, trained against a
# mean-squared-error loss.
# NOTE(review): lr=0.1 is 100x Adam's default step size (1e-3) -- confirm this
# is intentional; it is kept unchanged here.
lr = 0.1
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [52]:
# Online one-step Q-learning on CartPole-v1 with an epsilon-greedy policy.
# After every environment step the network is updated towards the TD target
#     r                              if the episode terminated
#     r + gamma * max_a' Q(s', a')   otherwise
max_episodes = 5000
gamma = 0.9      # discount factor
epsilon = 1.0    # exploration probability; annealed per episode below

env = gym.make('CartPole-v1', render_mode=None)

# try/finally guarantees the environment is released even when the (long)
# loop is interrupted from the notebook.
try:
    for episode in (pbar:=trange(max_episodes)):

        g = 0              # running discounted reward sum (progress display only)
        total_reward = 0   # undiscounted episode reward (progress display only)
        obs_, _ = env.reset()
        obs = torch.from_numpy(obs_).float()

        while True:
            # Q-values of the current state; kept in the autograd graph because
            # q[a] is the prediction the loss is taken on.
            q = model(obs)
            if random() < epsilon:
                a = env.action_space.sample()
            else:
                a = torch.argmax(q).item()

            obs_, r, terminated, truncated, _ = env.step(a)
            obs = torch.from_numpy(obs_).float()

            g = gamma*g + r
            total_reward += r

            # The episode ends on either flag, but only a genuine termination
            # means there is no future value; a time-limit truncation must
            # still bootstrap from the next state.
            stop = terminated or truncated
            if terminated:
                target = float(r)
            else:
                with torch.no_grad():
                    # Value of the best next action: max (the Q-value), not
                    # argmax (which would be the action index).
                    next_q_max = torch.max(model(obs)).item()
                target = float(r) + gamma*next_q_max

            # Both tensors are 0-dim so MSELoss compares like with like
            # instead of broadcasting a [1]-shaped target against a scalar.
            Y = torch.tensor(target)
            X = q[a]
            loss = loss_fn(X, Y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if stop:
                break

        # Linear decay from 1.0 by 1e-3 per episode, floored at 0.1.
        epsilon = max(0.1, 1.0 - 1e-3*episode)

        pbar.set_postfix_str(
            f"g={g:5.2f}, tr={total_reward:3.0f} eps={epsilon:.2f}", refresh=False)
finally:
    env.close()

  0%|          | 0/5000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [37]:
# Greedy evaluation: replay the current `model` in a rendered CartPole-v1
# window, always taking the action with the highest predicted Q-value.
max_episodes = 100
env = gym.make('CartPole-v1', render_mode='human')

# try/finally so the environment is closed even if the run is interrupted
# or an error is raised mid-episode.
try:
    # No parameter updates happen here, so gradient tracking is switched off
    # for the whole rollout.
    with torch.no_grad():
        for episode in trange(max_episodes):

            obs_, _ = env.reset()
            obs = torch.from_numpy(obs_).float()

            while True:
                q = model(obs)
                a = torch.argmax(q).item()

                obs_, r, terminated, truncated, _ = env.step(a)
                obs = torch.from_numpy(obs_).float()

                if terminated or truncated:
                    break
finally:
    env.close()

  0%|          | 0/100 [00:00<?, ?it/s]