In [76]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import gymnasium as gym
from collections import deque
import torch.optim as optim
import random
from torch.utils.tensorboard import SummaryWriter
import copy

In [84]:
def one_hot(observation, n_observations):
    one_hot_observation = torch.zeros(n_observations, dtype=torch.float32)
    one_hot_observation[observation] = 1

    return one_hot_observation


def take_action(dqn, one_hot_observation, n_actions, epsilon):
    if np.random.random() < epsilon:
        return np.random.randint(n_actions)

    return torch.argmax(dqn(one_hot_observation)).item()

In [85]:
REPLAY_BUFFER_SIZE = 1000
GAMMA = 0.99
BATCH_SIZE = 128
LEARNING_RATE = 0.01
EPSILON_DECAY = 0.99
EPSILON_MIN = 0.1

In [86]:
env = gym.make("FrozenLake-v1", is_slippery=True)
n_actions = env.action_space.n
n_observations = env.observation_space.n

In [87]:
writer = SummaryWriter()

class DQN(nn.Module):
    def __init__(self, inp_size, out_size):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(inp_size, 128),
            nn.ReLU(),
            #
            nn.Linear(128, 64),
            nn.ReLU(),
            #
            nn.Linear(64, 32),
            nn.ReLU(),
            #
            nn.Linear(32, 4),
        )

    def forward(self, x):
        return self.layer(x)

online_dqn = DQN(n_observations, n_actions)
target_dqn = copy.deepcopy(online_dqn)
replay_buffer = deque(maxlen=REPLAY_BUFFER_SIZE)

In [88]:
opt = optim.Adam(online_dqn.parameters(), LEARNING_RATE)
criterion = nn.MSELoss()

In [91]:
epsilon = 1
epochs = 20000
target_update_freq = 500



for epoch in range(epochs):
    obs, _ = env.reset()

    obs = one_hot(obs, n_observations)

    epsilon = max(EPSILON_MIN, epsilon * EPSILON_DECAY)

    while True:
        action = take_action(online_dqn, obs, n_actions, epsilon)

        new_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        new_obs = one_hot(new_obs, n_observations)

        reward -= 0.01
        

        replay_buffer.append((obs, action, reward, new_obs, done))

        obs = new_obs
        if done:
            break

    if len(replay_buffer) < BATCH_SIZE:
        continue

    batch = random.sample(replay_buffer, BATCH_SIZE)
    states, actions, rewards, next_states, dones = zip(*batch)

    actions = torch.tensor(actions).unsqueeze(1)  # Shape(Bx1)
    rewards = torch.tensor(rewards).unsqueeze(1)  # Shape(Bx1)
    dones = torch.tensor(dones).unsqueeze(1)  # Shape(Bx1)

    states = torch.stack(states)  # Shape(Bxn_states)
    next_states = torch.stack(next_states)  # Shape(Bxn_states)

    # Training

    q_values = online_dqn(states)  # Shape(Bxn_actions)
    predicted_curr_q_values = q_values.gather(1, actions)  # Shape(Bx1)

    with torch.no_grad():
        next_q_values = target_dqn(next_states)  # Shape(Bxn_actions)
        max_next_q_value, _ = torch.max(next_q_values, dim=1, keepdim=True)  # Shape(Bx1)

    curr_q_value = rewards + GAMMA * max_next_q_value * (1 - dones.float())

    loss = criterion(predicted_curr_q_values, curr_q_value)

    opt.zero_grad()
    loss.backward()
    opt.step()
    
    writer.add_scalar("Loss", loss, epoch)
    writer.add_scalar("Reward", rewards.sum(dim=0).item(), epoch)
    writer.add_scalar("Epsilon", epsilon, epoch)

    if epoch % target_update_freq == 0:
        target_dqn.load_state_dict(online_dqn.state_dict())


    print(f"epoch {epoch}, loss: {loss}, epsilon: {epsilon}")

epoch 0, loss: 0.0042159054428339005, epsilon: 0.99
epoch 1, loss: 0.0014877605717629194, epsilon: 0.9801
epoch 2, loss: 0.0028459224849939346, epsilon: 0.9702989999999999
epoch 3, loss: 0.0023536975495517254, epsilon: 0.96059601
epoch 4, loss: 0.0031694385688751936, epsilon: 0.9509900498999999
epoch 5, loss: 0.0010986742563545704, epsilon: 0.9414801494009999
epoch 6, loss: 0.0014667994109913707, epsilon: 0.9320653479069899
epoch 7, loss: 0.0016141438391059637, epsilon: 0.92274469442792
epoch 8, loss: 0.002596058417111635, epsilon: 0.9135172474836407
epoch 9, loss: 0.0034320331178605556, epsilon: 0.9043820750088043
epoch 10, loss: 0.0027394071221351624, epsilon: 0.8953382542587163
epoch 11, loss: 0.003639703616499901, epsilon: 0.8863848717161291
epoch 12, loss: 0.0014817729825153947, epsilon: 0.8775210229989678
epoch 13, loss: 0.003266245126724243, epsilon: 0.8687458127689781
epoch 14, loss: 0.0014493544586002827, epsilon: 0.8600583546412883
epoch 15, loss: 0.0033512827940285206, epsil

In [92]:
moves = {
    0: 'left',
    1: 'down',
    2: 'right',
    3: 'up'
}

shit = {}
for obs in range(16):
    obs_t = nn.functional.one_hot(torch.tensor(obs), 16).to(torch.float32)

    shit[obs] = torch.argmax(online_dqn(obs_t)).item()
    
print(shit)
for key, val in shit.items():
    print(f'{key}: {moves[val]}')

torch.save(online_dqn.state_dict(), './dqn.pth')

ncols = 4
row, col = 2,1
obs = row*ncols+col
print(obs)
online_dqn(nn.functional.one_hot(torch.tensor(obs), 16).to(torch.float32))

{0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 2, 6: 0, 7: 2, 8: 0, 9: 0, 10: 2, 11: 0, 12: 2, 13: 2, 14: 2, 15: 2}
0: left
1: left
2: left
3: left
4: left
5: right
6: left
7: right
8: left
9: left
10: right
11: left
12: right
13: right
14: right
15: right
9


tensor([ 0.0076,  0.0068, -0.0022, -0.0064], grad_fn=<ViewBackward0>)