In [1]:
import os
import copy
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from tqdm import tqdm
import gymnasium as gym
from tictactoe.env import TicTacToeEnv
from collections import namedtuple, deque
import matplotlib.pyplot as plt

In [None]:
class DQN(nn.Module):
    """
    Multi-layer perceptron mapping an observation feature vector to one value per action.

    Architecture:
        Linear(n_features, hidden_size) -> ReLU
        -> Linear(hidden_size, hidden_size) -> ReLU
        -> Linear(hidden_size, n_actions)

    n_features: size of the last dimension of the input passed to `forward`.
    n_actions: number of discrete actions (width of the output layer).
    hidden_size: width of both hidden layers; defaults to 128, the width that was
        previously hard-coded, so `DQN(n_features, n_actions)` also works.
    """
    def __init__(self, n_features, n_actions, hidden_size=128):
        super().__init__()
        self.layer1 = nn.Linear(n_features, hidden_size)
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, n_actions)

    def forward(self, x):
        """
        x: float tensor whose last dimension is n_features (e.g. (batch, n_features)).
        Returns the raw per-action values, with the last dimension replaced by n_actions.
        """
        x = F.relu(self.layer1(x))
        # Without a non-linearity here, layer2 and layer3 would collapse into one affine map.
        x = F.relu(self.layer2(x))
        return self.layer3(x)

In [None]:
class Agent():
    """
    Greedy action selector for a batch of environments.

    device: torch device that observations are moved to before the forward pass.
    """
    def __init__(self, device):
        self.device = device
        # Step counter; not read by any code visible in this notebook yet.
        self.steps_done = 0

    def select_action(self, net, states):
        """
        Evaluate `net` on a batch of observations and pick the arg-max action for each row.

        net: module mapping a (batch, n_features) float tensor to (batch, n_actions) values.
        states: array-like of shape (batch, n_features).
        Returns (actions, logits):
            actions: NumPy integer array of shape (batch,) with the arg-max index per row.
            logits: the raw network output, still attached to the autograd graph so it
                can be used in a loss.
        """
        x = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        # Call the module itself (not .forward) so registered hooks run.
        logits = net(x)
        return logits.argmax(dim=1).cpu().numpy(), logits

In [None]:
# Vectorized TicTacToe environments and the Q-network used by the rollout cells below.
n_envs = 6                # number of parallel TicTacToe environments
n_episodes = 1000
n_steps_per_update = 9    # rollout length per update (presumably one full game on a 3x3 board)
eps = 0.1                 # exploration rate for epsilon-greedy action selection
hidden_size = 128         # width of the DQN hidden layers
# Pinned to CPU; use torch.device("cuda" if torch.cuda.is_available() else "cpu") to pick a GPU when present.
device = torch.device("cpu")

envs = gym.vector.AsyncVectorEnv([TicTacToeEnv for _ in range(n_envs)])
# NOTE(review): on Gymnasium >= 1.0 vector envs need gym.wrappers.vector.RecordEpisodeStatistics — confirm the installed version.
envs_wrapper = gym.wrappers.RecordEpisodeStatistics(envs)
# Observations are flattened before reaching the network, so its input width is the
# product of the observation shape (cast to a plain int for nn.Linear).
n_observations = int(np.prod(envs.single_observation_space.shape))
n_actions = envs.single_action_space.n
policy_net = DQN(n_observations, n_actions, hidden_size).to(device)

In [None]:
# Buffers for one rollout of n_steps_per_update steps across n_envs environments.
ep_pred_q = torch.zeros(n_steps_per_update, n_envs, device=device)
ep_rewards = torch.zeros(n_steps_per_update, n_envs, device=device)
masks = torch.zeros(n_steps_per_update, n_envs, device=device)

agent = Agent(device)
states, info = envs_wrapper.reset()
for step in range(n_steps_per_update):
    actions, logits = agent.select_action(policy_net, states.reshape(n_envs, -1))
    # Epsilon-greedy: draw ONE uniform number per env so each env either explores or exploits.
    # (Two independent draws could pick both branches and sum the actions, or neither and yield 0.)
    explore = np.random.rand(n_envs) < eps
    actions = np.where(explore, envs.action_space.sample(), actions)
    next_states, rewards, terminated, truncated, infos = envs_wrapper.step(actions)
    # state_action_values = logits.gather(1, torch.as_tensor(actions, device=device).unsqueeze(-1)).reshape(-1)
    # ep_rewards[step] = torch.tensor(rewards, device=device)
    # masks[step] = torch.tensor([not term for term in terminated])
    states = next_states
    break  # TODO(debug): only the first step runs; remove to collect the full rollout.

In [None]:
# Sanity check: run the Q-network on freshly reset observations.
states, info = envs_wrapper.reset()
# Flatten each environment's observation into one feature row, as the rollout cell does
# with states.reshape(n_envs, -1); a no-op if observations are already 1-D.
x = torch.as_tensor(states, dtype=torch.float32, device=device).reshape(len(states), -1)
logits = policy_net(x)

In [None]:
np.shape(states)  # inspect the layout of the raw observation batch

In [None]:
def mcts(net, envs_wrapper, device=None):
    """
    Stub for an MCTS driver.

    Resets the vectorized environment and evaluates `net` on the initial observations.
    TODO: the tree search itself is not implemented yet.

    net: module mapping flattened observations to per-action values.
    envs_wrapper: vectorized environment whose reset() returns (states, info).
    device: device for the observation tensor; defaults to the device of net's parameters.
    Returns the network output for the reset observations, one row per environment.
    """
    if device is None:
        device = next(net.parameters()).device
    states, info = envs_wrapper.reset()
    # Flatten each environment's observation into one feature row.
    x = torch.as_tensor(states, dtype=torch.float32, device=device).reshape(len(states), -1)
    return net(x)

In [None]:
class TreeNode:
    """
    A node in the MCTS tree.

    Each node keeps track of its mean value Q, its prior probability P, and its
    visit-count-adjusted exploration bonus u.
    TODO: gamma is 1.0 in this case.
    """

    def __init__(self, parent, prior_p):
        self._parent = parent      # parent TreeNode, or None for the root
        self._children = {}        # action -> child TreeNode
        self._n_visits = 0
        self._Q = 0                # running mean of backed-up returns
        self._P = prior_p          # prior probability of this node
        self._u = 0                # exploration bonus from the last get_value call

    def expand(self, action_priors):
        """
        Expand the tree by creating one child per action that is not already present.

        action_priors: iterable of (action, prior_probability) pairs, e.g. from the policy function.
        Existing children are kept unchanged.
        """
        for action, prob in action_priors:
            if action not in self._children:
                self._children[action] = TreeNode(self, prob)

    def select(self, c_puct):
        """
        Return the (action, child) pair whose child maximizes Q + u(P).
        Raises ValueError if this node has no children.
        """
        return max(self._children.items(), key=lambda act_node: act_node[1].get_value(c_puct))

    def update(self, G):
        """
        Update this node's visit count and mean value from a Monte-Carlo return G.
        TODO: alpha is 1.0 in this case.
        """
        self._n_visits += 1
        # Incremental mean: Q_n = Q_{n-1} + (G - Q_{n-1}) / n
        self._Q += 1.0 * (G - self._Q) / self._n_visits

    def update_recursive(self, G):
        """
        Back up return G from this node to the root, ancestors first.

        The sign flips at every level (the parent receives -G), presumably because the two
        players alternate between tree levels.
        """
        if self._parent is not None:
            self._parent.update_recursive(-G)
        self.update(G)

    def get_value(self, c_puct):
        """
        Calculate and return this node's selection score Q + u, where
        u = c_puct * P * sqrt(N_parent) / (1 + N).

        c_puct: a number in (0, inf) controlling the relative impact of the value Q and the
            prior probability P on this node's score.
        Must not be called on the root, since it reads the parent's visit count.
        """
        self._u = (c_puct * self._P * np.sqrt(self._parent._n_visits)) / (1 + self._n_visits)
        return self._Q + self._u

    def is_leaf(self):
        """Return True if this node has not been expanded (has no children)."""
        return not self._children

    def is_root(self):
        """Return True if this node has no parent."""
        return self._parent is None

In [3]:
class MCTS:
    """A simple implementation of Monte Carlo Tree Search."""

    def __init__(self, policy_value_fn, c_puct=5, n_playout=10000):
        """
        policy_value_fn: a function that takes in a board state and returns the policy's
            evaluation of it (described as a list of (action, probability, q) tuples).
            NOTE(review): TreeNode.expand iterates over (action, prior) pairs, so this output
            will need adapting before leaf expansion is enabled.
        c_puct: a number in (0, inf) that controls how quickly exploration converges to the
            maximum-value policy.
        n_playout: number of playouts get_move runs before choosing a move.
        """
        self._root = TreeNode(parent=None, prior_p=1.0)
        self._policy = policy_value_fn
        self._c_puct = c_puct
        self._n_playout = n_playout

    def _playout(self, state):
        """
        Run a single playout from the root to a leaf, get a value at the leaf, and propagate
        it back through its parents. The state is meant to be modified in place, so a copy
        must be provided.
        """
        node = self._root
        # Descend by maximal Q + u until reaching an unexpanded node (possibly the root itself).
        while not node.is_leaf():
            action, node = node.select(self._c_puct)
            # TODO: apply `action` to `state` once the game-state API exists (e.g. state.do_move(action)).

        action_probs = self._policy(state)
        # TODO: expand non-terminal leaves once the state exposes an end-of-game check, e.g.
        #   end, winner = state.game_end()
        #   if not end:
        #       node.expand(action_probs)
        leaf_value = self._evaluate_rollout(state)
        # Negated before backing up; update_recursive flips the sign again at each level.
        node.update_recursive(-leaf_value)

    def _evaluate_rollout(self, state, limit=1000):
        """
        Use the rollout policy to play until the end of the game, returning 1 if the current
        player wins, -1 if the opponent wins and 0 if it is a tie.
        limit: presumably a cap on the rollout length (unused so far).

        Not implemented yet: raises NotImplementedError instead of silently returning None,
        which would only fail later when _playout negates it.
        """
        raise NotImplementedError("MCTS._evaluate_rollout is not implemented yet")

    def get_move(self, state):
        """
        Run all playouts sequentially, each on a fresh deep copy of `state`, and return the
        most visited root action.
        state: the current game state.
        Raises ValueError if the root still has no children after the playouts.
        """
        for _ in range(self._n_playout):
            self._playout(copy.deepcopy(state))
        # Choose only after ALL playouts have run.
        return max(self._root._children.items(), key=lambda act_node: act_node[1]._n_visits)[0]

    def update_with_move(self, last_move):
        """
        Step forward in the tree, keeping everything we already know about the subtree.
        An unknown move resets the tree to a fresh root.
        """
        if last_move in self._root._children:
            self._root = self._root._children[last_move]
            self._root._parent = None
        else:
            self._root = TreeNode(None, 1.0)

In [None]:
class MCTSPlayer:
    """AI player that will choose its moves with MCTS; a placeholder with no behavior yet."""

In [None]:
# for i_episode in range(n_episodes):
#     states, info = envs_wrapper.reset()
#     pass

In [None]:
# DQN training hyperparameters:
# BATCH_SIZE is the number of transitions sampled from the replay memory per optimization step
# GAMMA is the discount factor
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the AdamW optimizer
BATCH_SIZE = 64
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One experience tuple stored in the replay memory (optimize_model regroups a sampled batch
# with Transition(*zip(*transitions))).
# NOTE(review): field layout assumed to be (state, action, next_state, reward) — confirm once
# the training loop pushes transitions.
Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))


class ReplayMemory:
    """Fixed-capacity FIFO buffer of Transitions; the oldest entries are dropped once full."""

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Store one transition built from the positional fields of Transition."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        """Return batch_size distinct stored transitions drawn uniformly at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


# Single (non-vectorized) environment used by the DQN cells below; TicTacToeEnv is built
# without arguments, as the vector-env factories above also do.
env = TicTacToeEnv()
state, info = env.reset()
# Hidden width 128 matches the vectorized policy network defined earlier.
policy_net = DQN(n_observations, n_actions, 128).to(device)
target_net = DQN(n_observations, n_actions, 128).to(device)
target_net.load_state_dict(policy_net.state_dict())
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)

In [None]:
steps_done = 0

def select_action(state):
    """
    Epsilon-greedy action selection for a single observation.

    Epsilon decays exponentially from EPS_START toward EPS_END with time constant EPS_DECAY,
    measured in calls; every call increments the global `steps_done`.

    state: expected to be a (1, n_observations) float tensor (a batch of one), since the
        greedy result is viewed as (1, 1).
    Returns a (1, 1) long tensor holding the chosen action index.
    """
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # max(1) reduces over the action dimension; .indices holds the arg-max of each row.
            return policy_net(state).max(1).indices.view(1, 1)
    # Wrap in [[...]] so the exploratory action has the same (1, 1) shape as the greedy one.
    return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
    

In [None]:
def optimize_model():
    """
    One DQN optimization step (work in progress).

    Does nothing until the replay memory holds at least BATCH_SIZE transitions; after that it
    samples a batch and regroups it field-wise into a single Transition of tuples.
    TODO: the loss computation and optimizer step are not written yet.
    """
    if len(memory) >= BATCH_SIZE:
        # zip(*rows) transposes the sampled Transitions into per-field tuples.
        batch = Transition(*zip(*memory.sample(BATCH_SIZE)))
    

In [None]:
torch.as_tensor(env.action_space.sample(), dtype=torch.long, device=device)  # peek at one random action as a long tensor

In [None]:
num_episode = 100
# Training-loop skeleton (TODO: wrap in a train(num_episode) function once it does more):
# currently it only resets the env and prepares the first observation of each episode.
for i_episode in range(num_episode):
    state, info = env.reset()
    # Flatten the board and add a batch axis -> (1, n_observations), the layout the DQN expects.
    state = torch.tensor(state.reshape(-1), dtype=torch.float32, device=device).unsqueeze(0)

In [None]:
state, info = env.reset()
# One flattened observation with a leading batch axis of size 1.
state = torch.tensor(state, dtype=torch.float32, device=device).reshape(1, -1)

In [None]:
# Three hand-made flattened 9-cell boards: all zeros, a 1.0 at index 2, and a 1.0 at index 6.
states = torch.zeros(3, 9)
states[1, 2] = 1.0
states[2, 6] = 1.0

In [None]:
with torch.no_grad():
    a = policy_net(states)  # network outputs for the sample boards, without autograd history

In [None]:
idx = a.argmax(dim=1)  # best action index for each sample board

In [None]:
# Row k holds the (row, col) pair for flat index k, in row-major order.
actions = torch.cartesian_prod(torch.arange(3), torch.arange(3))

In [None]:
actions[idx].cpu().numpy()  # (row, col) coordinates of each greedy action