In [13]:
import os
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from tqdm import tqdm
import gymnasium as gym
from tictactoe.env import TicTacToeEnv
from gridworld.env import GridWorldEnv
from collections import namedtuple, deque
import matplotlib.pyplot as plt

In [55]:
class DQN(nn.Module):
    """
    (Synchronous) Deep Q-Learning Network agent class
    """
    def __init__(self, n_features, n_actions, hidden_size):
        super().__init__()
        self.layer1 = nn.Linear(n_features, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)
        
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = self.layer2(x)
        return self.layer3(x)

In [65]:
class Agent():
    def __init__(self, device):
        self.device = device
        self.steps_done = 0
        
    def select_action(self, net, states):
        x = torch.Tensor(states).to(self.device)
        logits = net.forward(x)
        return logits.argmax(1), logits

In [68]:
# envs = gym.vector.AsyncVectorEnv([GridWorldEnv for i in range(4)])
# envs_wrapper = gym.wrappers.RecordEpisodeStatistics(envs)
# envs_wrapper.reset()
n_envs = 64
n_episodes = 1000
n_steps_per_update = 9
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

envs = gym.vector.AsyncVectorEnv([TicTacToeEnv for _ in range(n_envs)])
envs_wrapper = gym.wrappers.RecordEpisodeStatistics(envs)
n_observations = np.prod(envs.single_observation_space.shape)
n_actions = envs.single_action_space.n
policy_net = DQN(n_observations, n_actions, 128)
# states, info = envs_wrapper.reset()

In [87]:
ep_pred_q = torch.zeros(n_steps_per_update, n_envs, device=device)
ep_rewards = torch.zeros(n_steps_per_update, n_envs, device=device)
masks = torch.zeros(n_steps_per_update, n_envs, device=device)

agent = Agent(device)
states, info = envs_wrapper.reset()
for step in range(n_steps_per_update):
    actions, logits = agent.select_action(policy_net, states.reshape(n_envs, -1))
    states, rewards, terminated, truncated, infos = envs_wrapper.step(actions.cpu().numpy())
    state_action_values = logits.gather(1, actions).reshape(-1)
    ep_rewards[step] = torch.tensor(rewards, device=device)
    masks[step] = torch.tensor([not term for term in terminated])
    break

In [None]:
for i_episode in range(n_episodes):
    states, info = envs_wrapper.reset()
    

In [2]:
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
BATCH_SIZE = 64
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

state, info = env.reset()
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)

NameError: name 'env' is not defined

In [15]:
steps_done = 0

def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max().indices.view(1, 1)
    else:
        return torch.tensor(env.action_space.sample(), device=device, dtype=torch.long)
    

In [16]:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return 
    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))
    

In [17]:
torch.tensor(env.action_space.sample(), device=device, dtype=torch.long)

tensor([1, 2])

In [18]:
num_episode = 100
# def train(num_episode):
for i_episode in range(num_episode):
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

In [19]:
state, info = env.reset()
state = torch.tensor(state.reshape(-1), dtype=torch.float32, device=device).unsqueeze(0)

In [20]:
states = torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 1., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 1., 0., 0.]])

In [21]:
a = policy_net(states).detach()

In [24]:
idx = torch.argmax(a, dim=1)

In [27]:
actions = torch.tensor([[i, j] for i in range(3) for j in range(3)])

In [29]:
torch.index_select(actions, 0, idx).cpu().numpy()

array([[2, 2],
       [2, 2],
       [2, 2]])