In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from src import *

In [3]:
import numpy as np
import random
import pandas as pd
import torch
from torch import nn
from collections import defaultdict
from tqdm.auto import tqdm, trange
import pickle as pkl
from typing import Literal, get_type_hints

import itertools as it
import more_itertools as mit

In [4]:
def player_formatter(player: int) -> str:
    symbol_mapping = {0: " ", 1: "X", -1: "O"}
    return symbol_mapping[player]
    
def state_formatter(state: tuple[int, ...]) -> str:
    size = int(len(state) ** 0.5)
    formatted_state = "\n"
    for i in range(size):
        formatted_state += "+---+---+---+\n"
        row = state[i*size:(i+1)*size]
        formatted_state += "| " + " | ".join(player_formatter(cell)for cell in row) + " |\n"
    formatted_state += "+---+---+---+"
    return formatted_state

In [5]:
ttt = TicTacToe(default_state_formatter=state_formatter)
ttt.render()


+---+---+---+
|   |   |   |
+---+---+---+
|   |   |   |
+---+---+---+
|   |   |   |
+---+---+---+


In [7]:
q_table_policy = QTablePolicy.load('ckpts/q_table', lr=0.1, name='Deterministic Q Table')
q_table_policy_stochastic = QTablePolicy(q_table_policy.q_table, 
                                         lr=0.1,
                                         stochastic=True,
                                         temperature=0.5,
                                         name='Stochastic Q Table')
human_policy = PromptPolicy(player_formatter)
min_max_policy = MinMaxPolicy(game_cls=TicTacToe)
mcts_policy = MCTS(game_cls=TicTacToe)
mlp_policy = MLPPolicy(
    mlp=MLP(hidden_layers=6, hidden_width=256), 
    lr=5e-5, 
    train_device='cpu',
    inference_device='cpu',
    name='MLP Policy'
)
# mlp_policy = MLPPolicy.load(
#     ckpt_name='ckpts/mlp_6x256',
#     mlp=MLP(hidden_layers=6, hidden_width=256),
#     lr=5e-5,
#     train_device='cpu',
#     inference_device='cpu',
#     name='MLP Policy'
# )

In [8]:
def play(ttt, policy=None, print_state=False):
    if policy is None:
        policy = RandomPolicy()
    if not isinstance(policy, dict):
        policy = {-1: policy, 1: policy}

    tape = [dict(
        player=None,
        action_space=None,
        action=None,
        state=ttt.get_state(),
        winner=0
    )]

    while (action_space := ttt.get_actions()) and not ttt.get_winner():
        if print_state:
            ttt.render()
        player = ttt.player
        action = ttt.agent_move(policy[player])
        state = ttt.get_state()
        winner = ttt.get_winner()
        tape.append(dict(
            player=player,
            action_space=action_space,
            action=action,
            state=state,
            winner=winner
        ))
    if print_state:
        ttt.render()

    tape.append(dict(
        player=None,
        action_space=set(),
        action=None,
        state=ttt.get_state(),
        winner=ttt.get_winner()
    ))
    
    return tape

In [9]:
def swap(state: np.ndarray, option: Literal[1, -1]):
    return option * state

def flip(state: np.ndarray, option: Literal[True, False]):
    return np.fliplr(state) if option else state

def rotate(state: np.ndarray, option: Literal[0, 1, 2, 3]):
    return np.rot90(state, k=option)
    

In [10]:
def get_param_options(fn, param='option'):
    return get_type_hints(fn)['option'].__args__

In [11]:
def transform_state(raw_state, fns, opts):
    state = raw_state.reshape([3, 3])
    for fn, opt in zip(fns, opts):
        state = fn(state, opt)
    return state.flatten()

def transform_actions(raw_actions, fns, opts):
    actions = np.zeros(9)
    actions[raw_actions] = 1
    actions = transform_state(actions, fns, opts)
    return np.nonzero(actions)[0]

In [12]:
def bellman_equation_batch(policy, rewards, states, players, action_spaces):
    pred_qs = policy.batch_get_all_Qs(states, players, action_spaces)
    for reward, state, player, action_space, pred_q in zip(rewards, states, players, action_spaces, pred_qs):
        if action_space:
            yield reward + max(pred_q.values())
        else:
            yield reward

def replay_episodes(tapes: list, policy):
    start_state_list, end_state_list, player_list, action_list, action_space_list, reward_list = [], [], [], [], [], []
    transformations = [swap, flip, rotate]

    for tape in tapes:
        for pre, cur, nxt in mit.windowed(tape, 3):
            raw_start_state = np.array(list(pre['state'])).astype(int)
            raw_end_state = np.array(list(nxt['state'])).astype(int)
            raw_action = np.array([cur['action']]).astype(int)
            raw_action_space = np.array(list(nxt['action_space'])).astype(int)
            player = cur['player']
            reward = cur['player'] * nxt['winner']
            
            for opts in it.product(*map(get_param_options, transformations)):
                start_state_list.append(tuple(transform_state(raw_start_state, transformations, opts).tolist()))
                end_state_list.append(tuple(transform_state(raw_end_state, transformations, opts).tolist()))
                action_list.append(transform_actions(raw_action, transformations, opts).item())
                action_space_list.append(set(transform_actions(raw_action_space, transformations, opts).tolist()))
                player_list.append(player)
                reward_list.append(reward)

    q_list = list(bellman_equation_batch(policy, reward_list, end_state_list, player_list, action_space_list))

    return start_state_list, player_list, action_list, q_list

In [13]:
def train(policy, episodes, epsilon=0.2, update_interval=1):
    eps_greedy_policy = EpsilonGreedyPolicy(policy, epsilon=epsilon)
    tape_list = []
    for episode in trange(episodes):
        ttt.reset(start_player=random.choice([1, -1]))
        tape_list.append(play(ttt, policy=eps_greedy_policy))
        if episode > 0 and episode % update_interval == 0:
            data = replay_episodes(tape_list, policy)
            loss = policy.batch_update_Q(*data)
            print(loss)
            tape_list.clear()

In [15]:
train(mlp_policy, episodes=1_000_000, update_interval=32)

  0%|          | 0/1000000 [00:00<?, ?it/s]

0.07205723226070404
0.06988868117332458
0.09837668389081955
0.07804964482784271
0.07805898785591125
0.06501083821058273
0.1315041035413742
0.08804813772439957
0.09284887462854385
0.09534108638763428
0.09893138706684113
0.07833267748355865
0.0887734442949295
0.08296443521976471
0.07477708160877228
0.08953355252742767
0.08302787691354752
0.10171451419591904
0.1035476103425026
0.11062180250883102
0.08625424653291702
0.1128174290060997
0.14527535438537598
0.13865171372890472
0.13269652426242828
0.09977356344461441
0.13649475574493408
0.11116776615381241
0.10689717531204224
0.10064954310655594
0.081812784075737
0.11697231978178024
0.09519228339195251
0.07534226775169373
0.07489848881959915
0.09123846143484116
0.0511411689221859
0.07252182066440582
0.09893420338630676
0.11515916138887405
0.10715433210134506
0.10219378024339676
0.09015550464391708
0.11758848279714584
0.10075381398200989
0.09582898020744324
0.076482854783535
0.07488463819026947
0.07270494848489761
0.08575166761875153
0.0642396

In [16]:
mlp_policy.save('ckpts/mlp_6x256')

In [None]:
train(q_table_policy_stochastic, episodes=100000)

In [None]:
q_table_policy.save('ckpts/q_table')

## Manual Play

In [13]:
ttt.reset()
tape = play(ttt, policy={1: human_policy, -1: q_table_policy}, print_state=True)

In [14]:
pd.DataFrame(tape)

Unnamed: 0,player,action_space,action,state,winner
0,,,,"(0, 0, 0, 0, 0, 0, 0, 0, 0)",0
1,1.0,"{0, 1, 2, 3, 4, 5, 6, 7, 8}",4.0,"(0, 0, 0, 0, 1, 0, 0, 0, 0)",0
2,-1.0,"{0, 1, 2, 3, 5, 6, 7, 8}",0.0,"(-1, 0, 0, 0, 1, 0, 0, 0, 0)",0
3,1.0,"{1, 2, 3, 5, 6, 7, 8}",5.0,"(-1, 0, 0, 0, 1, 1, 0, 0, 0)",0
4,-1.0,"{1, 2, 3, 6, 7, 8}",3.0,"(-1, 0, 0, -1, 1, 1, 0, 0, 0)",0
5,1.0,"{1, 2, 6, 7, 8}",6.0,"(-1, 0, 0, -1, 1, 1, 1, 0, 0)",0
6,-1.0,"{8, 1, 2, 7}",2.0,"(-1, 0, -1, -1, 1, 1, 1, 0, 0)",0
7,1.0,"{8, 1, 7}",1.0,"(-1, 1, -1, -1, 1, 1, 1, 0, 0)",0
8,-1.0,"{8, 7}",7.0,"(-1, 1, -1, -1, 1, 1, 1, -1, 0)",0
9,1.0,{8},8.0,"(-1, 1, -1, -1, 1, 1, 1, -1, 1)",0


# Run Simulations

In [18]:
def simulate(round, player_policy_map, start_player=None):
    player_policy_map[0] = DummyPolicy(name='Tie')
    win_count = {policy.get_name(): 0 for policy in player_policy_map.values()}
    for episode in trange(round):
        ttt.reset(start_player=start_player or random.choice([-1, 1]))
        tape = play(ttt, policy=player_policy_map)
        win_count[player_policy_map[ttt.get_winner()].get_name()] += 1
    return win_count

In [20]:
# policy_1 = RandomPolicy(name='Random Policy')
policy_1 = q_table_policy
policy_2 = mlp_policy

result_list = []
for p1_role, start_player in [[1, 1], [1, -1], [-1, 1], [-1, -1]]:
    policy_map = {1*p1_role: policy_1, -1*p1_role: policy_2}
    result = simulate(100, policy_map, start_player=start_player)
    result['Player -1'] = policy_map[-1].get_name()
    result['Player 1'] = policy_map[1].get_name()
    result['Start Player'] = policy_map.get(start_player).get_name()
    result_list.append(result)
pd.DataFrame(result_list)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Unnamed: 0,Deterministic Q Table,MLP Policy,Tie,Player -1,Player 1,Start Player
0,100,0,0,MLP Policy,Deterministic Q Table,Deterministic Q Table
1,0,0,100,MLP Policy,Deterministic Q Table,MLP Policy
2,0,0,100,Deterministic Q Table,MLP Policy,MLP Policy
3,100,0,0,Deterministic Q Table,MLP Policy,Deterministic Q Table
