In [110]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy
import torch.optim as optim
from collections import deque
import random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
BATCH_SIZE = 128


def get_action(net: torch.Tensor, state: torch.Tensor, epsilon: float) -> int:
    if np.random.rand() < epsilon:
        return np.random.randint(0, 2)

    return torch.argmax(net(state[None])).item()


def add_transition(
    replay_buffer: deque, elite_buffer: deque, transition: list, min_elite_val=None
) -> int:
    """
    transition is of shape (State, Action, Reward, NextState)
    returns most minimum for future reference
    """
    replay_buffer.append(transition)

    if len(elite_buffer) < BATCH_SIZE / 2:
        elite_buffer.append(transition)

    else:
        if min_elite_val is not None and transition[2] < min_elite_val:
            return min_elite_val

        min_elite = min(elite_buffer, key=lambda x: x[2])
        if transition[2] > min_elite[2]:
            elite_buffer.remove(min_elite)
            elite_buffer.append(transition)

            return min_elite[2]

    return min_elite_val

In [118]:
REPLAY_BUFFER_SIZE = 1000

online_net = nn.Sequential(
    nn.Linear(4, 128),
    nn.ReLU(),
    #
    nn.Linear(128, 64),
    nn.ReLU(),
    #
    nn.Linear(64, 32),
    nn.ReLU(),
    #
    nn.Linear(32, 16),
    nn.ReLU(),
    #
    nn.Linear(16, 2),
).to(device)

target_net = deepcopy(online_net)


replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)
elite_buffer = deque(maxlen=BATCH_SIZE // 2)

opt = optim.Adam(online_net.parameters(), lr=1e-2)
criterion = nn.MSELoss()

env = gym.make("CartPole-v1")

In [None]:
EPS_DECAY_RATE = 0.001
MIN_EPS = 0.1
GAMMA = 0.99
epsilon = 0.5
episodes = 1000
target_net_update_freq = 100

avg_R_div = 0
R_sum = 0
max_R = 0
min_elite_val = None

for episode in range(episodes):
    obs, _ = env.reset()
    epsilon = max(epsilon - EPS_DECAY_RATE * epsilon, MIN_EPS)

    R = 0

    while True:
        state = torch.tensor(obs, device=device)
        action = get_action(online_net, state, epsilon)

        next_obs, reward, terminated, truncated, _ = env.step(action)
        R += reward
        done = terminated or truncated
        obs = next_obs

        next_state = torch.tensor(next_obs, device=device)
        transition = (state, action, reward, next_state, float(done))

        min_elite_val = add_transition(
            replay_buffer, elite_buffer, transition, min_elite_val
        )

        if done:
            break

    max_R = max(max_R, R)
    avg_R_div += 1
    R_sum += R
    avg_R = (R_sum) / (avg_R_div)

    sample = random.sample(
        replay_buffer, k=min(len(replay_buffer), BATCH_SIZE // 2)
    ) + list(elite_buffer)

    states, actions, rewards, next_states, dones = zip(*sample)

    states_t = torch.stack(states)  # Shape(Bx4)
    actions_t = torch.tensor(actions, device=device, dtype=torch.int64).view(
        -1, 1
    )  # Shape(Bx1)
    rewards_t = torch.tensor(rewards, device=device, dtype=torch.float32).view(
        -1, 1
    )  # Shape(Bx1)
    next_states_t = torch.stack(next_states)  # Shape(Bx4)
    dones_t = torch.tensor(dones, device=device, dtype=torch.float32).view(
        -1, 1
    )  # Shape(Bx1)

    q_pred = online_net(states_t)  # Shape(Bx2)
    q_pred = q_pred.gather(1, actions_t)

    with torch.no_grad():
        q_next_state = target_net(next_states_t)  # Shape(Bx2)
        q_next_max = q_next_state.max(dim=1)[0].view(-1, 1)
        q_target = rewards_t + (1 - dones_t) * GAMMA * q_next_max  # Shape(B)

    loss = criterion(q_pred, q_target)

    opt.zero_grad()
    loss.backward()
    opt.step()

    if episode % target_net_update_freq == 0:
        R_sum = 0
        avg_R_div = 0
        max_R = 0
        target_net = deepcopy(online_net)

    print(
        f"episode: {episode}, epsilon: {epsilon}, loss: {loss}, max_R: {max_R}, avg_R: {avg_R}"
    )

episode: 0, epsilon: 0.4995, loss: 0.05066611245274544, max_R: 0, avg_R: 48.0
episode: 1, epsilon: 0.4990005, loss: 0.7128541469573975, max_R: 11.0, avg_R: 11.0
episode: 2, epsilon: 0.4985014995, loss: 0.4006396234035492, max_R: 77.0, avg_R: 44.0
episode: 3, epsilon: 0.4980029980005, loss: 1.277982473373413, max_R: 77.0, avg_R: 33.666666666666664
episode: 4, epsilon: 0.4975049950024995, loss: 0.7608539462089539, max_R: 229.0, avg_R: 82.5
episode: 5, epsilon: 0.497007490007497, loss: 1.1547342538833618, max_R: 229.0, avg_R: 72.2
episode: 6, epsilon: 0.4965104825174895, loss: 0.16062523424625397, max_R: 229.0, avg_R: 73.33333333333333
episode: 7, epsilon: 0.49601397203497205, loss: 1.4712132215499878, max_R: 229.0, avg_R: 69.0
episode: 8, epsilon: 0.4955179580629371, loss: 0.3880242705345154, max_R: 229.0, avg_R: 61.75
episode: 9, epsilon: 0.4950224401048741, loss: 0.77504962682724, max_R: 229.0, avg_R: 56.77777777777778
episode: 10, epsilon: 0.49452741766476926, loss: 0.1936896890401840

KeyboardInterrupt: 

In [102]:
torch.save(online_net.state_dict(), './cartpole_dqn.pth')