In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from src import *

In [3]:
import numpy as np
import random
import pandas as pd
import torch
from torch import nn
from collections import defaultdict
from tqdm.auto import tqdm, trange
import pickle as pkl
from typing import Literal, get_type_hints

import itertools as it
import more_itertools as mit

In [4]:
ttt = TicTacToe()
ttt.render()

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


In [26]:
q_table_policy = QTablePolicy.load('q_table', lr=0.1)
# q_table_policy = QTablePolicy(lr=0.1)
# mlp_policy = MLPPolicy(Model(), lr=0.001)
# mcts_policy = MCTS(ttt)
# mlp_policy = MLPPolicy.load(Model(), 'mlp.pt', lr=0.001)

In [102]:
def player_formatter(player: int):
    symbol_mapping = {0: " ", 1: "X", -1: "O"}
    return symbol_mapping[player]
    
    
def state_formatter(state: tuple[int, ...]):
    size = int(len(state) ** 0.5)
    formatted_state = "\n"
    for i in range(size):
        formatted_state += "+---+---+---+\n"
        row = state[i*size:(i+1)*size]
        formatted_state += "| " + " | ".join(player_formatter(cell)for cell in row) + " |\n"
    formatted_state += "+---+---+---+"
    return formatted_state

In [93]:
def play(ttt, policy=None, print_state=False):
    if policy is None:
        policy = RandomPolicy()
    if not isinstance(policy, dict):
        policy = {-1: policy, 1: policy}

    tape = [dict(
        player=None,
        action_space=None,
        action=None,
        state=ttt.get_state(),
        winner=0
    )]

    while (action_space := ttt.get_actions()) and not ttt.get_winner():
        if print_state:
            print(state_formatter(ttt.get_state()), flush=True)
        player = ttt.player
        action = ttt.agent_move(policy[player])
        state = ttt.get_state()
        winner = ttt.get_winner()
        tape.append(dict(
            player=player,
            action_space=action_space,
            action=action,
            state=state,
            winner=winner
        ))
    if print_state:
        print(state_formatter(ttt.get_state()), flush=True)

    tape.append(dict(
        player=None,
        action_space=set(),
        action=None,
        state=ttt.get_state(),
        winner=ttt.get_winner()
    ))
    
    return tape

In [94]:
def swap(state: np.ndarray, option: Literal[1, -1]):
    return option * state

def flip(state: np.ndarray, option: Literal[True, False]):
    return np.fliplr(state) if option else state

def rotate(state: np.ndarray, option: Literal[0, 1, 2, 3]):
    return np.rot90(state, k=option)
    

In [95]:
def get_param_options(fn, param='option'):
    return get_type_hints(fn)['option'].__args__

In [96]:
def transform_state(raw_state, fns, opts):
    state = raw_state.reshape([3, 3])
    for fn, opt in zip(fns, opts):
        state = fn(state, opt)
    return state.flatten()

def transform_actions(raw_actions, fns, opts):
    actions = np.zeros(9)
    actions[raw_actions] = 1
    actions = transform_state(actions, fns, opts)
    return np.nonzero(actions)[0]

In [97]:
def bellman_equation(policy, reward, state, player, actions):
    if not actions:
        return reward
    return reward + max(policy.get_Q(state, player, action) for action in actions)

def replay_episode(tape, policy):
    state_list, player_list, action_list, q_list = [], [], [], []
    transformations = [swap, flip, rotate]

    for pre, cur, nxt in mit.windowed(tape, 3):
        raw_start_state = np.array(list(pre['state'])).astype(int)
        raw_end_state = np.array(list(nxt['state'])).astype(int)
        raw_action = np.array([cur['action']]).astype(int)
        raw_action_space = np.array(list(nxt['action_space'])).astype(int)
        player = cur['player']
        reward = cur['player'] * nxt['winner']
        
        for opts in it.product(*map(get_param_options, transformations)):
            start_state = tuple(transform_state(raw_start_state, transformations, opts).tolist())
            end_state = tuple(transform_state(raw_end_state, transformations, opts).tolist())
            action = transform_actions(raw_action, transformations, opts).item()
            action_space = set(transform_actions(raw_action_space, transformations, opts).tolist())

            new_q = bellman_equation(policy, reward, end_state, player, action_space)
            state_list.append(start_state)
            player_list.append(player)
            action_list.append(action)
            q_list.append(new_q)

    return state_list, player_list, action_list, q_list

In [98]:
def train(policy, episodes, epsilon=0.1):
    eps_greedy_policy = EpsilonGreedyPolicy(policy, epsilon=epsilon)
    for episode in trange(episodes):
        ttt.reset(start_player=random.choice([1, -1]))
        tape = play(ttt, policy=eps_greedy_policy)
        state_list, player_list, action_list, q_list = replay_episode(tape, policy)
        loss = policy.batch_update_Q(state_list, player_list, action_list, q_list)


In [27]:
train(q_table_policy, episodes=50000)

  0%|          | 0/50000 [00:00<?, ?it/s]

In [103]:
ttt.reset()
tape = play(ttt, policy={1: PromptPolicy(player_formatter), -1: q_table_policy}, print_state=True)


+---+---+---+
|   |   |   |
+---+---+---+
|   |   |   |
+---+---+---+
|   |   |   |
+---+---+---+

+---+---+---+
| X |   |   |
+---+---+---+
|   |   |   |
+---+---+---+
|   |   |   |
+---+---+---+

+---+---+---+
| X |   |   |
+---+---+---+
|   | O |   |
+---+---+---+
|   |   |   |
+---+---+---+


Exception: Invalid move

In [82]:
pd.DataFrame(tape)

Unnamed: 0,player,action_space,action,state,winner
0,,,,"(0, 0, 0, 0, 0, 0, 0, 0, 0)",0
1,1.0,"{0, 1, 2, 3, 4, 5, 6, 7, 8}",2.0,"(0, 0, 1, 0, 0, 0, 0, 0, 0)",0
2,-1.0,"{0, 1, 3, 4, 5, 6, 7, 8}",0.0,"(-1, 0, 1, 0, 0, 0, 0, 0, 0)",0
3,1.0,"{1, 3, 4, 5, 6, 7, 8}",3.0,"(-1, 0, 1, 1, 0, 0, 0, 0, 0)",0
4,-1.0,"{1, 4, 5, 6, 7, 8}",8.0,"(-1, 0, 1, 1, 0, 0, 0, 0, -1)",0
5,1.0,"{1, 4, 5, 6, 7}",4.0,"(-1, 0, 1, 1, 1, 0, 0, 0, -1)",0
6,-1.0,"{1, 5, 6, 7}",1.0,"(-1, -1, 1, 1, 1, 0, 0, 0, -1)",0
7,1.0,"{5, 6, 7}",6.0,"(-1, -1, 1, 1, 1, 0, 1, 0, -1)",1
8,,{},,"(-1, -1, 1, 1, 1, 0, 1, 0, -1)",1


In [32]:
q_table_policy.save('q_table')

In [28]:
win_count = defaultdict(int)
for episode in trange(1000):
    ttt.reset(start_player=1)
    tape = play(ttt, policy={1: RandomPolicy(), -1: q_table_policy})
    win_count[ttt.get_winner()] += 1
win_count

  0%|          | 0/1000 [00:00<?, ?it/s]

defaultdict(int, {-1: 920, 0: 77, 1: 3})

In [29]:
win_count = defaultdict(int)
for episode in trange(1000):
    ttt.reset(start_player=1)
    tape = play(ttt, policy={-1: RandomPolicy(), 1: q_table_policy})
    win_count[ttt.get_winner()] += 1
win_count

  0%|          | 0/1000 [00:00<?, ?it/s]

defaultdict(int, {1.0: 961, 0: 39})

In [30]:
win_count = defaultdict(int)
for episode in trange(1000):
    ttt.reset(start_player=-1)
    tape = play(ttt, policy={1: RandomPolicy(), -1: q_table_policy})
    win_count[ttt.get_winner()] += 1
win_count

  0%|          | 0/1000 [00:00<?, ?it/s]

defaultdict(int, {-1: 904, 0: 96})

In [31]:
win_count = defaultdict(int)
for episode in trange(1000):
    ttt.reset(start_player=-1)
    tape = play(ttt, policy={-1: RandomPolicy(), 1: q_table_policy})
    win_count[ttt.get_winner()] += 1
win_count

  0%|          | 0/1000 [00:00<?, ?it/s]

defaultdict(int, {1.0: 864, 0: 99, -1: 37})

In [None]:
ttt.get_winner()

In [38]:
import time