In [110]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy
import torch.optim as optim
from collections import deque
import random

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [126]:
BATCH_SIZE = 128  # number of transitions drawn from the replay buffer per update


def get_action(
    net: nn.Module, state: torch.Tensor, epsilon: float, n_actions: int = 2
) -> int:
    """Pick an action with an epsilon-greedy policy.

    Args:
        net: Q-network (any callable) mapping a batch of states to a
            ``(batch, n_actions)`` tensor of action values.
        state: A single, un-batched state tensor; a leading batch axis is
            added before it is passed to ``net``.
        epsilon: Probability of choosing a uniformly random action.
        n_actions: Size of the discrete action space used for the random
            draw. Defaults to 2, the value previously hard-coded.

    Returns:
        Index of the selected action as a plain Python ``int``.
    """
    if np.random.rand() < epsilon:
        # Explore: uniform draw over [0, n_actions).
        return int(np.random.randint(0, n_actions))

    # Exploit: action selection is pure inference, so skip autograd
    # bookkeeping instead of building a graph that is thrown away.
    with torch.no_grad():
        q_values = net(state[None])
    return int(torch.argmax(q_values).item())

In [None]:
REPLAY_BUFFER_SIZE = 10000

# Q-network: a 4-input MLP with ReLU-activated hidden layers of width
# 128 -> 64 -> 32 -> 16, followed by a linear head with 2 outputs.
# Layers are instantiated in the same order as an explicit listing would be.
layer_widths = (4, 128, 64, 32, 16)
hidden_layers = [
    module
    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:])
    for module in (nn.Linear(n_in, n_out), nn.ReLU())
]
online_net = nn.Sequential(*hidden_layers, nn.Linear(16, 2)).to(device)

# Independent copy of the online network, used to compute bootstrap targets.
target_net = deepcopy(online_net)

# Bounded FIFO of transitions; the oldest entries drop out once it is full.
replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)

criterion = nn.MSELoss()
opt = optim.Adam(online_net.parameters(), lr=1e-2)

env = gym.make("CartPole-v1")

In [131]:
EPS_DECAY_RATE = 0.005  # fraction of epsilon removed after every episode
MIN_EPS = 0.01  # floor for the exploration rate
GAMMA = 0.99  # discount factor
epsilon = 0.3  # initial exploration rate
episodes = 500  # episodes per epoch
epochs = 10
target_net_update_freq = 100  # episodes between target-network syncs

writer = SummaryWriter()

# DQN training: play one full episode, push its transitions into the replay
# buffer, then take a single gradient step on a random minibatch.
# The try/finally guarantees the TensorBoard event file is flushed and closed
# even when the run is interrupted from the notebook.
try:
    for epoch in range(epochs):
        R_sum = 0
        R_max = 0

        for episode in range(episodes):
            # Monotonic step for per-episode scalars; the bare `episode`
            # index restarts every epoch and would overwrite earlier points.
            global_episode = epoch * episodes + episode

            obs, _ = env.reset()
            epsilon = max(epsilon - EPS_DECAY_RATE * epsilon, MIN_EPS)

            R = 0

            while True:
                state = torch.tensor(obs, device=device, dtype=torch.float32)
                action = get_action(online_net, state, epsilon)

                next_obs, reward, terminated, truncated, _ = env.step(action)
                R += reward
                obs = next_obs

                next_state = torch.tensor(
                    next_obs, device=device, dtype=torch.float32
                )
                # Store only true termination as the "done" flag. A time-limit
                # truncation ends the episode but the next state is not
                # terminal, so the target must still bootstrap from it.
                transition = (state, action, reward, next_state, float(terminated))
                replay_buffer.append(transition)

                if terminated or truncated:
                    break

            R_max = max(R_max, R)
            R_sum += R

            sample = random.sample(
                replay_buffer, k=min(len(replay_buffer), BATCH_SIZE)
            )
            states, actions, rewards, next_states, dones = zip(*sample)

            states_t = torch.stack(states)  # Shape(Bx4)
            actions_t = torch.tensor(
                actions, device=device, dtype=torch.int64
            ).view(-1, 1)  # Shape(Bx1)
            rewards_t = torch.tensor(
                rewards, device=device, dtype=torch.float32
            ).view(-1, 1)  # Shape(Bx1)
            next_states_t = torch.stack(next_states)  # Shape(Bx4)
            dones_t = torch.tensor(
                dones, device=device, dtype=torch.float32
            ).view(-1, 1)  # Shape(Bx1)

            # Q(s, a) for the actions that were actually taken.
            q_pred = online_net(states_t).gather(1, actions_t)  # Shape(Bx1)

            # TD target r + gamma * max_a' Q_target(s', a'), with the
            # bootstrap term zeroed for terminal transitions.
            with torch.no_grad():
                q_next_state = target_net(next_states_t)  # Shape(Bx2)
                q_next_max = q_next_state.max(dim=1)[0].view(-1, 1)  # Shape(Bx1)
                q_target = rewards_t + (1 - dones_t) * GAMMA * q_next_max  # Shape(Bx1)

            loss = criterion(q_pred, q_target)

            opt.zero_grad()
            loss.backward()
            opt.step()

            if global_episode % target_net_update_freq == 0:
                # Copy weights in place rather than re-allocating the module.
                target_net.load_state_dict(online_net.state_dict())

            loss_value = loss.item()
            writer.add_scalar("loss", loss_value, global_episode)
            writer.add_scalar("R_max", R_max, global_episode)

            print(
                f"episode: {episode}, epsilon: {epsilon}, loss: {loss_value}, max_R: {R_max}, curr_R: {R}"
            )

        # Epoch-level summaries use their own tag so they do not collide with
        # the per-episode "R_max" series, which has a different step axis.
        writer.add_scalar("R_max_epoch", R_max, epoch)
        writer.add_scalar("R_avg", R_sum / episodes, epoch)
finally:
    writer.close()

episode: 0, epsilon: 0.2985, loss: 3.697507619857788, max_R: 26.0, curr_R: 26.0
episode: 1, epsilon: 0.2970075, loss: 1.0164258480072021, max_R: 38.0, curr_R: 38.0
episode: 2, epsilon: 0.29552246249999997, loss: 2.0511367321014404, max_R: 46.0, curr_R: 46.0
episode: 3, epsilon: 0.2940448501875, loss: 1.7533681392669678, max_R: 46.0, curr_R: 13.0
episode: 4, epsilon: 0.29257462593656247, loss: 0.9063012003898621, max_R: 93.0, curr_R: 93.0
episode: 5, epsilon: 0.2911117528068797, loss: 3.4469830989837646, max_R: 106.0, curr_R: 106.0
episode: 6, epsilon: 0.28965619404284526, loss: 0.9743015766143799, max_R: 106.0, curr_R: 90.0
episode: 7, epsilon: 0.28820791307263105, loss: 5.22389030456543, max_R: 106.0, curr_R: 16.0
episode: 8, epsilon: 0.28676687350726787, loss: 1.4466012716293335, max_R: 141.0, curr_R: 141.0
episode: 9, epsilon: 0.2853330391397315, loss: 7.060595989227295, max_R: 141.0, curr_R: 81.0
episode: 10, epsilon: 0.28390637394403284, loss: 1.9113752841949463, max_R: 288.0, cur

KeyboardInterrupt: 

In [132]:
# Persist only the online network's weights (state_dict), not the module object.
checkpoint_path = './cartpole_dqn.pth'
online_weights = online_net.state_dict()
torch.save(online_weights, checkpoint_path)