In [17]:
import json
import math
import random
from collections import defaultdict
from itertools import count

import numpy as np
import gymnasium as gym
from gymnasium.spaces import Discrete
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

import games

In [18]:
import pygame
import numpy as np
import gymnasium as gym
from gymnasium import spaces


def is_end(board, mark):
    """Check whether ``mark`` has completed a line on ``board``.

    A line is a full row, a full column, the main diagonal or the
    anti-diagonal.

    Args:
        board: 2-D array-like of cell values.
        mark: The cell value belonging to the player being checked.

    Returns:
        ``(1, True)`` when ``mark`` fills at least one line, i.e. a reward of
        1 and a "game over" flag; ``(0, False)`` otherwise.
    """
    # ``shape`` is an attribute, not a method; unpacking also enforces 2-D.
    cells = np.asarray(board)
    _n_rows, _n_cols = cells.shape
    owned = cells == mark

    # Reduce along each axis separately so that rows are checked against the
    # column count and columns against the row count (matters when the board
    # is not square).
    if owned.all(axis=1).any() or owned.all(axis=0).any():
        return 1, True
    if owned.diagonal().all() or np.fliplr(owned).diagonal().all():
        return 1, True
    return 0, False

class TicTacToeEnv(gym.Env):
    """Two-agent tic-tac-toe on a 3x3 board exposed through the Gymnasium API.

    Observations are the 3x3 board: 0 for an empty cell, 1 for agent 0 ("X")
    and -1 for agent 1 ("O"), matching the declared ``Box(-1, 1)`` space.
    Actions are cell indices 0..8 in row-major order. The ``info`` dict
    returned by ``reset`` and ``step`` carries ``"agent_index"``, the agent
    whose mark will be placed by the next ``step`` call.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, render_mode=None):
        """Create the spaces; the board itself is created by ``reset``.

        Args:
            render_mode: ``None``, ``"human"`` or ``"rgb_array"``.
        """
        self.window_size = 512
        # Observation space is a 3 * 3 board.
        self.observation_space = spaces.Box(-1, 1, shape=(3, 3), dtype=np.int8)
        self.action_space = spaces.Discrete(9)
        self.agent_index_space = spaces.Discrete(2)
        # Marks must stay inside the observation space bounds and match the
        # values the renderer draws (1 -> "X", -1 -> "O").
        self.agent_mark_mapping = {
            0: 1,
            1: -1,
        }
        self.render_mode = render_mode
        # If human-rendering is used, `self.window` will be a reference to the
        # window that we draw to and `self.clock` a clock used to render at
        # the configured framerate. Both stay `None` until human-mode is used
        # for the first time.
        self.window = None
        self.clock = None

    def _get_obs(self):
        """Return a copy of the board so callers cannot mutate env state."""
        return self._board.copy()

    def _get_info(self):
        """Return the index of the agent that is ready to act."""
        return {"agent_index": self._agent_index}

    def reset(self, seed=None, options=None):
        """Clear the board and pick the starting agent.

        Returns:
            ``(observation, info)``.
        """
        super().reset(seed=seed)
        # Initialize the board with zeros.
        self._board = np.zeros((3, 3), dtype=np.int8)
        # Pick the starting agent from the env's seeded generator so that
        # ``reset(seed=...)`` is reproducible.
        self._agent_index = int(self.np_random.integers(self.agent_index_space.n))

        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, info

    def step(self, action):
        """Place the current agent's mark on cell ``action`` and pass the turn.

        A move onto an occupied cell leaves the board unchanged and yields
        reward 0; the turn is still passed to the other agent.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. ``reward``
            is 1 when the move completes a line, otherwise 0. ``terminated``
            is True on a win or when the board is full (draw). ``truncated``
            is always False.
        """
        assert self.action_space.contains(action), f"invalid action: {action!r}"

        # The mark belongs to the agent announced in the previous info dict.
        mark = self.agent_mark_mapping[self._agent_index]
        move = (action // 3, action % 3)
        reward, terminated = 0, False
        if self._board[move] == 0:
            self._board[move] = mark
            reward, terminated = is_end(self._board, mark)
            if not terminated and not (self._board == 0).any():
                # No empty cell left and nobody won: the game is a draw.
                terminated = True

        # Hand the turn to the other agent.
        self._agent_index = (self._agent_index + 1) % int(self.agent_index_space.n)

        if self.render_mode == "human":
            self._render_frame()

        return self._get_obs(), reward, terminated, False, self._get_info()

    def _render_frame(self):
        """Draw the board; show it (human) or return it as an RGB array."""
        if self.window is None and self.render_mode == "human":
            pygame.init()
            self.window = pygame.display.set_mode(
                (self.window_size, self.window_size)
            )
        if self.clock is None and self.render_mode == "human":
            self.clock = pygame.time.Clock()

        black = (0, 0, 0)
        white = (255, 255, 255)
        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill(white)
        pix_square_size = self.window_size / 3

        # Draw lines to separate boxes (black, so they show on the white
        # background).
        for k in (1, 2):
            pygame.draw.aaline(
                canvas,
                black,
                (pix_square_size * k, 0),
                (pix_square_size * k, pix_square_size * 3),
            )
            pygame.draw.aaline(
                canvas,
                black,
                (0, pix_square_size * k),
                (pix_square_size * 3, pix_square_size * k),
            )

        # Draw "X"s or "O"s. Board rows map to the y axis and columns to the
        # x axis.
        for row in range(3):
            for col in range(3):
                left = pix_square_size * col
                top = pix_square_size * row
                cell = self._board[row, col]
                if cell == -1:
                    centre = (left + pix_square_size / 2, top + pix_square_size / 2)
                    # Black disc with a smaller white disc on top -> a ring.
                    pygame.draw.circle(canvas, black, centre, pix_square_size / 2 - 5)
                    pygame.draw.circle(canvas, white, centre, pix_square_size / 2 - 10)
                elif cell == 1:
                    pygame.draw.line(
                        canvas,
                        black,
                        (left + 5, top + 5),
                        (left + pix_square_size - 5, top + pix_square_size - 5),
                        7,
                    )
                    pygame.draw.line(
                        canvas,
                        black,
                        (left + pix_square_size - 5, top + 5),
                        (left + 5, top + pix_square_size - 5),
                        7,
                    )

        if self.render_mode == "human":
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()
            self.clock.tick(self.metadata["render_fps"])
        else:
            # pixels3d is indexed (x, y, channel); swap to (row, col, channel).
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
            )

    def render(self):
        """Return an RGB frame in ``"rgb_array"`` mode; otherwise ``None``."""
        if self.render_mode == "rgb_array":
            return self._render_frame()

    def close(self):
        """Shut down the pygame window if one was opened."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            # Forget the handles so a second close() is a no-op.
            self.window = None
            self.clock = None

In [21]:
class EpsilonGreedy:
    """Epsilon-greedy action selection over a batch of environments.

    Each environment independently takes a random action with probability
    ``epsilon`` and the greedy (arg-max Q) action otherwise.
    """

    def __init__(self, epsilon=0.1):
        self.epsilon = epsilon

    def select(self, num_envs, states, action_spaces, qfunction):
        """Choose one action per environment.

        Args:
            num_envs: Number of environments in the batch.
            states: Batched states, passed through to
                ``qfunction.get_argmax_q``.
            action_spaces: Object whose ``sample()`` yields random actions
                (presumably one per environment -- verify against caller).
            qfunction: Object providing ``get_argmax_q(states)``.

        Returns:
            The random action where exploration was drawn, the greedy action
            elsewhere.
        """
        # One uniform draw per environment decides explore vs. exploit.
        explore = np.random.rand(num_envs) < self.epsilon
        exploit = np.invert(explore)
        random_actions = action_spaces.sample()
        greedy_actions = qfunction.get_argmax_q(states)
        # The boolean masks act as 0/1 weights, so exactly one term survives
        # per environment.
        return random_actions * explore + greedy_actions * exploit


class EpsilonDecreasing:
    """Epsilon-greedy selection whose epsilon decays after every call.

    Wraps an ``EpsilonGreedy`` instance; after each ``select`` the wrapped
    epsilon is multiplied by ``alpha`` and clamped from below at
    ``lower_bound``.
    """

    def __init__(self, epsilon=1.0, alpha=0.99, lower_bound=0.01):
        self.initial_epsilon = epsilon
        self.alpha = alpha
        self.lower_bound = lower_bound
        self.epsilon_greedy_bandit = EpsilonGreedy(epsilon)

    def reset(self):
        """Restart the schedule from the initial epsilon."""
        self.epsilon_greedy_bandit = EpsilonGreedy(self.initial_epsilon)

    def select(self, num_envs, states, action_spaces, qfunction):
        """Select actions with the current epsilon, then decay epsilon.

        Arguments are forwarded unchanged to ``EpsilonGreedy.select``.
        """
        bandit = self.epsilon_greedy_bandit
        chosen = bandit.select(num_envs, states, action_spaces, qfunction)
        # Decay once per call (not once per environment), never dropping
        # below the configured floor.
        decayed = bandit.epsilon * self.alpha
        bandit.epsilon = max(decayed, self.lower_bound)
        return chosen

In [None]:
class DeepQFunction(nn.Module):
    """Q-function approximated by a small fully connected network.

    The network maps a state vector of size ``num_states`` to one Q-value per
    action and is trained with AdamW on a Huber (smooth L1) loss.
    """

    def __init__(self, num_states, num_actions, hidden_dim=32, alpha=1e-4, device=None):
        """Build the network and its optimiser.

        Args:
            num_states: Size of the input (state) vector.
            num_actions: Number of discrete actions (network outputs).
            hidden_dim: Width of both hidden layers.
            alpha: Learning rate passed to AdamW.
            device: Torch device for the parameters and all tensors created
                by this object; ``None`` keeps torch's default.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(num_states, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions)
        ).to(device)
        self.optimiser = optim.AdamW(self.parameters(), lr=alpha, amsgrad=True)
        self.device = device

    def forward(self, x):
        """Return the Q-values of every action for the given state tensor."""
        return self.net(x)

    def _predict(self, states):
        """Evaluate the network on ``states`` without tracking gradients."""
        states_tensor = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            return self.forward(states_tensor)

    def get_q(self, states, actions):
        """Return Q(s, a) for each (state, action) pair as a numpy array."""
        actions_tensor = torch.as_tensor(actions, dtype=torch.int64, device=self.device)
        q_values = self._predict(states).gather(-1, actions_tensor.unsqueeze(-1))
        return q_values.squeeze(-1).cpu().numpy()

    def get_max_q(self, states):
        """Return max_a Q(s, a) for each state as a numpy array."""
        return self._predict(states).max(-1).values.cpu().numpy()

    def get_argmax_q(self, states):
        """Return argmax_a Q(s, a) for each state as a numpy array."""
        return self._predict(states).max(-1).indices.cpu().numpy()

    def update(self, states, actions, deltas, non_final_mask):
        """Take one optimiser step that moves Q(s, a) along the TD errors.

        Args:
            states: Batched states.
            actions: Action taken in each state.
            deltas: Temporal-difference error per sample, i.e.
                ``target - Q(s, a)`` as computed by ``QLearning.get_deltas``.
            non_final_mask: Boolean per sample; only samples marked True
                contribute to the loss.
        """
        states_tensor = torch.as_tensor(states, dtype=torch.float32, device=self.device)
        actions_tensor = torch.as_tensor(actions, dtype=torch.int64, device=self.device).unsqueeze(-1)
        deltas_tensor = torch.as_tensor(deltas, dtype=torch.float32, device=self.device).unsqueeze(-1)
        mask = torch.as_tensor(non_final_mask, dtype=torch.bool, device=self.device).unsqueeze(-1)
        if not bool(mask.any()):
            # Nothing to learn from; skipping also avoids a NaN loss and a
            # pure weight-decay step.
            return

        q_values = self.forward(states_tensor).gather(-1, actions_tensor)
        # ``deltas`` are TD errors, not targets. Rebuild the (constant)
        # target as Q(s, a) + delta so the loss gradient is driven by the TD
        # error instead of pulling Q(s, a) towards the error value itself.
        targets = q_values.detach() + deltas_tensor
        loss = F.smooth_l1_loss(
            torch.masked_select(q_values, mask),
            torch.masked_select(targets, mask)
        )
        self.optimiser.zero_grad()
        loss.backward()
        nn.utils.clip_grad_value_(self.parameters(), 100)
        self.optimiser.step()

    def save(self, filename):
        """Write the network parameters (state dict) to ``filename``."""
        torch.save(self.state_dict(), filename)

    @classmethod
    def load(cls, filename, num_states, num_actions, hidden_dim=32, alpha=1e-4, device=None):
        """Rebuild a Q-function and restore parameters saved by ``save``.

        The architecture arguments must match those used when saving.
        ``device`` selects where the model lives and where the checkpoint is
        mapped; ``None`` keeps torch's defaults.
        """
        qfunction = cls(num_states, num_actions, hidden_dim, alpha, device)
        qfunction.load_state_dict(torch.load(filename, map_location=device))
        return qfunction

In [None]:
class QLearning:
    """Batched Q-learning driver.

    Runs episodes on a set of vectorised environments, picks actions with
    ``bandit`` and updates ``qfunction`` from temporal-difference errors.
    """

    def __init__(self, envs, bandit, qfunction, gamma=1):
        self.envs = envs
        self.bandit = bandit
        self.qfunction = qfunction
        self.gamma = gamma

    def execute(self, episodes=2000):
        """Train for ``episodes`` episodes.

        Returns:
            A list with, per episode, the discounted return averaged over all
            environments.
        """
        mean_rewards = []
        for _ in range(episodes):
            # Get initial states and actions.
            observations, infos = self.envs.reset()
            # NOTE(review): the transpose presumably turns the observations
            # into one row per environment -- confirm against the env's
            # observation layout.
            states = np.array(observations).T
            action_spaces = self.envs.action_space
            actions = self.bandit.select(
                self.envs.num_envs, states, action_spaces, self.qfunction
            )

            # Per-environment bookkeeping for this episode.
            episode_rewards = np.zeros(self.envs.num_envs)
            non_final_mask = np.full(self.envs.num_envs, True)
            steps = np.zeros(self.envs.num_envs)

            # Keep stepping until every environment has finished.
            while non_final_mask.any():
                transition = self.envs.step(actions)
                next_observations, rewards, terminateds, truncateds, infos = transition
                next_states = np.array(next_observations).T
                next_actions = self.bandit.select(
                    self.envs.num_envs, next_states, action_spaces, self.qfunction
                )
                deltas = self.get_deltas(
                    states, actions, rewards, next_states, next_actions, terminateds
                )
                # Update the Q-network from the TD errors.
                self.qfunction.update(states, actions, deltas, non_final_mask)

                # Accumulate discounted rewards for still-running environments
                # only, then retire those that just ended.
                discount = self.gamma ** steps
                episode_rewards += rewards * discount * non_final_mask
                finished = terminateds + truncateds
                non_final_mask = np.logical_and(non_final_mask, np.invert(finished))
                steps += 1
                states, actions = next_states, next_actions

            mean_rewards.append(episode_rewards.mean())
        return mean_rewards

    def get_deltas(self, states, actions, rewards, next_states, next_actions, terminateds):
        """Return the TD errors ``r + gamma * V(s') - Q(s, a)``.

        ``V(s')`` is zeroed where ``terminateds`` is True. ``next_actions``
        is accepted but not used here.
        """
        current_q = self.qfunction.get_q(states, actions)
        bootstrap = self.state_value(next_states) * np.invert(terminateds)
        return rewards + self.gamma * bootstrap - current_q

    def state_value(self, states):
        """Return the state value, taken as the maximum Q-value per state."""
        return self.qfunction.get_max_q(states)


def get_ema(rewards, smoothing_factor=0.9):
    """Return the exponential moving average of ``rewards``.

    The first output equals the first reward; every later output is
    ``previous * smoothing_factor + reward * (1 - smoothing_factor)``.

    Args:
        rewards: Iterable of reward values.
        smoothing_factor: Weight given to the previous smoothed value.

    Returns:
        A list with one smoothed value per input reward.
    """
    smoothed = []
    for reward in rewards:
        if not smoothed:
            # Seed the average with the first observation.
            smoothed.append(reward)
            continue
        previous = smoothed[-1]
        smoothed.append(
            previous * smoothing_factor + reward * (1 - smoothing_factor)
        )
    return smoothed