In [73]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [1]:
from src import *

In [2]:
import numpy as np
import random
import pandas as pd
import torch
from torch import nn
from collections import defaultdict
from tqdm.auto import tqdm, trange
import pickle as pkl
from typing import Literal, get_type_hints

import itertools as it
import more_itertools as mit

In [3]:
def player_formatter(player: int) -> str:
    symbol_mapping = {0: " ", 1: "X", -1: "O"}
    return symbol_mapping[player]
    
def state_formatter(state: tuple[int, ...]) -> str:
    size = int(len(state) ** 0.5)
    formatted_state = "\n"
    for i in range(size):
        formatted_state += "+---+---+---+\n"
        row = state[i*size:(i+1)*size]
        formatted_state += "| " + " | ".join(player_formatter(cell)for cell in row) + " |\n"
    formatted_state += "+---+---+---+"
    return formatted_state

In [4]:
ttt = TicTacToe(default_state_formatter=state_formatter)
ttt.render()


+---+---+---+
|   |   |   |
+---+---+---+
|   |   |   |
+---+---+---+
|   |   |   |
+---+---+---+


In [5]:
q_table_policy = QTablePolicy.load('q_table', lr=0.1, name='Deterministic Q Table')
q_table_policy_stochastic = QTablePolicy(q_table_policy.q_table, 
                                         lr=0.1,
                                         stochastic=True,
                                         temperature=0.5,
                                         name='Stochastic Q Table')
human_policy = PromptPolicy(player_formatter)
min_max_policy = MinMaxPolicy(game_cls=TicTacToe)
mcts_policy = MCTS(game_cls=TicTacToe)
# q_table_policy = QTablePolicy(lr=0.1)
# mlp_policy = MLPPolicy(Model(), lr=0.001)
# mcts_policy = MCTS(ttt)
# mlp_policy = MLPPolicy.load(Model(), 'mlp.pt', lr=0.001)

In [6]:
help(MCTS)

Help on class MCTS in module src.policies.q_policies.mcts_policy:

class MCTS(src.policies.q_policies.base_q_policy.BaseQPolicySingle)
 |  MCTS(game_cls, *args, **kwargs)
 |  
 |  Method resolution order:
 |      MCTS
 |      src.policies.q_policies.base_q_policy.BaseQPolicySingle
 |      src.policies.q_policies.base_q_policy.BaseQPolicy
 |      src.policies.base_policy.BasePolicy
 |      abc.ABC
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __init__(self, game_cls, *args, **kwargs)
 |      Initialize self.  See help(type(self)) for accurate signature.
 |  
 |  back_propogate(self, node, reward)
 |  
 |  calculate_reward(self, node)
 |  
 |  execute_round(self)
 |      execute a selection-expansion-simulation-backpropagation round
 |  
 |  expand(self, node)
 |  
 |  find_node_matching_state(self, node, state)
 |  
 |  get_all_Qs(self, state: tuple[int, ...], player: int, action_space: set[int]) -> dict[int, float]
 |  
 |  get_best_move(self, node, exploration_consta

In [7]:
def play(ttt, policy=None, print_state=False):
    if policy is None:
        policy = RandomPolicy()
    if not isinstance(policy, dict):
        policy = {-1: policy, 1: policy}

    tape = [dict(
        player=None,
        action_space=None,
        action=None,
        state=ttt.get_state(),
        winner=0
    )]

    while (action_space := ttt.get_actions()) and not ttt.get_winner():
        if print_state:
            ttt.render()
        player = ttt.player
        action = ttt.agent_move(policy[player])
        state = ttt.get_state()
        winner = ttt.get_winner()
        tape.append(dict(
            player=player,
            action_space=action_space,
            action=action,
            state=state,
            winner=winner
        ))
    if print_state:
        ttt.render()

    tape.append(dict(
        player=None,
        action_space=set(),
        action=None,
        state=ttt.get_state(),
        winner=ttt.get_winner()
    ))
    
    return tape

In [8]:
def swap(state: np.ndarray, option: Literal[1, -1]):
    return option * state

def flip(state: np.ndarray, option: Literal[True, False]):
    return np.fliplr(state) if option else state

def rotate(state: np.ndarray, option: Literal[0, 1, 2, 3]):
    return np.rot90(state, k=option)
    

In [9]:
def get_param_options(fn, param='option'):
    return get_type_hints(fn)['option'].__args__

In [10]:
def transform_state(raw_state, fns, opts):
    state = raw_state.reshape([3, 3])
    for fn, opt in zip(fns, opts):
        state = fn(state, opt)
    return state.flatten()

def transform_actions(raw_actions, fns, opts):
    actions = np.zeros(9)
    actions[raw_actions] = 1
    actions = transform_state(actions, fns, opts)
    return np.nonzero(actions)[0]

In [11]:
def bellman_equation(policy, reward, state, player, actions):
    if not actions:
        return reward
    return reward + max(policy.get_Q(state, player, action) for action in actions)

def replay_episode(tape, policy):
    state_list, player_list, action_list, q_list = [], [], [], []
    transformations = [swap, flip, rotate]

    for pre, cur, nxt in mit.windowed(tape, 3):
        raw_start_state = np.array(list(pre['state'])).astype(int)
        raw_end_state = np.array(list(nxt['state'])).astype(int)
        raw_action = np.array([cur['action']]).astype(int)
        raw_action_space = np.array(list(nxt['action_space'])).astype(int)
        player = cur['player']
        reward = cur['player'] * nxt['winner']
        
        for opts in it.product(*map(get_param_options, transformations)):
            start_state = tuple(transform_state(raw_start_state, transformations, opts).tolist())
            end_state = tuple(transform_state(raw_end_state, transformations, opts).tolist())
            action = transform_actions(raw_action, transformations, opts).item()
            action_space = set(transform_actions(raw_action_space, transformations, opts).tolist())

            new_q = bellman_equation(policy, reward, end_state, player, action_space)
            state_list.append(start_state)
            player_list.append(player)
            action_list.append(action)
            q_list.append(new_q)

    return state_list, player_list, action_list, q_list

In [13]:
def train(policy, episodes, epsilon=0.1):
    eps_greedy_policy = EpsilonGreedyPolicy(policy, epsilon=epsilon)
    for episode in trange(episodes):
        ttt.reset(start_player=random.choice([1, -1]))
        tape = play(ttt, policy=eps_greedy_policy)
        state_list, player_list, action_list, q_list = replay_episode(tape, policy)
        loss = policy.batch_update_Q(state_list, player_list, action_list, q_list)


In [35]:
train(q_table_policy_stochastic, episodes=100000)

  0%|          | 0/100000 [00:00<?, ?it/s]

## Manual Play

In [12]:
q_table_policy.save('q_table')

In [13]:
ttt.reset()
tape = play(ttt, policy={1: human_policy, -1: mcts_policy}, print_state=True)


+---+---+---+
|   |   |   |
+---+---+---+
|   |   |   |
+---+---+---+
|   |   |   |
+---+---+---+

+---+---+---+
|   |   |   |
+---+---+---+
|   | X |   |
+---+---+---+
|   |   |   |
+---+---+---+

+---+---+---+
|   |   |   |
+---+---+---+
|   | X |   |
+---+---+---+
|   |   | O |
+---+---+---+

+---+---+---+
|   |   |   |
+---+---+---+
|   | X | X |
+---+---+---+
|   |   | O |
+---+---+---+

+---+---+---+
|   |   |   |
+---+---+---+
| O | X | X |
+---+---+---+
|   |   | O |
+---+---+---+

+---+---+---+
|   | X |   |
+---+---+---+
| O | X | X |
+---+---+---+
|   |   | O |
+---+---+---+

+---+---+---+
|   | X |   |
+---+---+---+
| O | X | X |
+---+---+---+
|   | O | O |
+---+---+---+

+---+---+---+
|   | X | X |
+---+---+---+
| O | X | X |
+---+---+---+
|   | O | O |
+---+---+---+

+---+---+---+
|   | X | X |
+---+---+---+
| O | X | X |
+---+---+---+
| O | O | O |
+---+---+---+


In [17]:

1pd.DataFrame(tape)

SyntaxError: invalid decimal literal (66848535.py, line 2)

# Run Simulations

In [16]:
def simulate(round, player_policy_map, start_player=None):
    player_policy_map[0] = DummyPolicy(name='Tie')
    win_count = {policy.get_name(): 0 for policy in player_policy_map.values()}
    for episode in trange(round):
        ttt.reset(start_player=start_player or random.choice([-1, 1]))
        tape = play(ttt, policy=player_policy_map)
        win_count[player_policy_map[ttt.get_winner()].get_name()] += 1
    return win_count

In [19]:
# policy_1 = RandomPolicy(name='Random Policy')
policy_1 = q_table_policy
policy_2 = min_max_policy

result_list = []
for p1_role, start_player in [[1, 1], [1, -1], [-1, 1], [-1, -1]]:
    policy_map = {1*p1_role: policy_1, -1*p1_role: policy_2}
    result = simulate(10, policy_map, start_player=start_player)
    result['Player -1'] = policy_map[-1].get_name()
    result['Player 1'] = policy_map[1].get_name()
    result['Start Player'] = policy_map.get(start_player).get_name()
    result_list.append(result)
pd.DataFrame(result_list)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Unnamed: 0,Deterministic Q Table,MinMaxPolicy,Tie,Player -1,Player 1,Start Player
0,0,0,10,MinMaxPolicy,Deterministic Q Table,Deterministic Q Table
1,0,0,10,MinMaxPolicy,Deterministic Q Table,MinMaxPolicy
2,0,0,10,Deterministic Q Table,MinMaxPolicy,MinMaxPolicy
3,0,0,10,Deterministic Q Table,MinMaxPolicy,Deterministic Q Table
