In [27]:
import numpy as np
import copy
from typing import Callable, Optional
import time
from src import *
import pandas as pd
import random
from tqdm.auto import trange
from typing import Literal, get_type_hints
import more_itertools as mit
import itertools as it

In [28]:
class TicTacToe_4:
    def __init__(self,
                 start_player: int = 1,
                 default_state_formatter: Callable[[tuple[int, ...]], str] = str,
                 size: int = 3):
        self.board = np.zeros(size * size, dtype=int)
        self.player = start_player
        self.default_state_formatter = default_state_formatter
        self.action_history = []
        self.size = size

    @classmethod
    def from_state(cls, state: tuple[int, ...], player: int):
        obj = TicTacToe()
        obj.update_state(state, player)
        return obj

    def __str__(self):
        format_str = '\n+---+---+---+\n|{:^size}|{:^size}|{:^size}|' * self.size + '\n+---+---+---+\n'
        return format_str.format(*np.array([' ', 'X', 'O'])[self.board].tolist())

    # TODO: use __str__
    def render(self, state_formatter: Optional[Callable[[tuple[int, ...]], str]] = None):
        formatter = state_formatter or self.default_state_formatter
        print(formatter(self.get_state()), flush=True)

    # TODO: optimize
    # iterate through all possible lines
    def __iter__(self):
        # for i in range(self.size):
        #     yield self.board[i * self.size: (i + 1) * self.size]
        # for i in range(self.size):
        #     yield self.board[i::self.size]
        # 
        # yield self.board[::self.size + 1]
        # yield self.board[self.size - 1: self.size ** 2 - 1: self.size - 1]
        for i in (0, 1, 2, 3):
            yield self.board[i], self.board[i + 4], self.board[i + 8], self.board[i + 12]
        for i in (0, 4, 8, 10):
            yield self.board[i], self.board[i + 1], self.board[i + 2], self.board[i + 3]
        for i in (0,):
            yield self.board[i], self.board[i + 5], self.board[i + 10], self.board[i + 15]
        for i in (3,):
            yield self.board[i], self.board[i + 3], self.board[i + 6], self.board[i + 9]

    def update_state(self, state: tuple[int, ...], player: int):
        self.board = np.array(state, dtype=int)
        self.player = player

    def get_state(self) -> tuple[int, ...]:
        return tuple(self.board.astype(int).tolist())

    def get_actions(self):
        return set(np.where(self.board == 0)[0].tolist())

    def get_winner(self):
        start_time = time.time()
        for first, second, third, fourth in self:
            if first == second == third == fourth != 0:
                end_time = time.time()
                return first
        end_time = time.time()
        return 0

    def get_player(self):
        return self.player

    def get_last_player(self):
        return -self.player

    def is_terminated(self):
        return not self.get_actions() or self.get_winner()
    def clone(self):
        return copy.deepcopy(self)

    def move(self, action):
        if self.board[action] != 0:
            raise Exception("Invalid move")
        self.board[action] = self.player
        self.player = -self.player
        self.action_history.append(action)

    def agent_move(self, policy):
        best_action = policy(self.get_state(), self.player, self.get_actions())
        self.move(best_action)
        return best_action

    def reset(self, start_player=1):
        self.board *= 0
        self.player = start_player

    def spawn(self, action):
        clone = self.clone()
        clone.move(action)
        return clone

    # TODO: value function
    def utility(self, player):
        if self.is_terminated():
            if self.get_winner() == 0:
                return 0
            return 10 if self.get_winner() == player else -10
        return 0

    def last_player(self):
        return -self.player

    def apply_action(self, action):
        if self.board[action] != 0:
            raise Exception("Invalid move")
        self.board[action] = self.player
        self.player = -self.player
        self.action_history.append(action)

    def undo_action(self):
        if not self.action_history:
            raise Exception("No actions to undo")
        last_action = self.action_history.pop()
        self.board[last_action] = 0
        self.player = -self.player


In [29]:
def player_formatter(player: int) -> str:
    symbol_mapping = {0: " ", 1: "X", -1: "O"}
    return symbol_mapping[player]

def state_formatter(state: tuple[int, ...]) -> str:
    size = int(len(state) ** 0.5)
    formatted_state = "\n"
    for i in range(size):
        formatted_state += "+" + "---+" * size + "\n"
        row = state[i*size:(i+1)*size]
        formatted_state += "| " + " | ".join(player_formatter(cell)for cell in row) + " |\n"
    formatted_state += "+" + "---+" * size + "\n"
    return formatted_state
ttt=TicTacToe_4(default_state_formatter=state_formatter, size=4)

In [30]:
from src.policies.q_policies.base_q_policy import BaseQPolicySingle
from src.value_functions import BaseValueFunction
from typing import Optional


class MinMaxPolicy(BaseQPolicySingle):
    def __init__(self,
                 game_cls,
                 heuristic: Optional[BaseValueFunction] = None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.game_cls = game_cls
        self.heuristic = heuristic

    def min_max(self, game: TicTacToe, is_max_player, cloned_game, depth):
        if game.is_terminated():
            winner = game.get_winner()
            if winner == cloned_game.player:
                return 1
            elif winner == 0:
                return 0
            else:
                return -1
        if depth == 0:
            return self.heuristic.get_V(game.get_state(), cloned_game.player, game.get_actions())

        best_score = -float('inf') if is_max_player else float('inf')
        for action in game.get_actions():
            game.apply_action(action)
            score = self.min_max(game, not is_max_player, cloned_game, depth - 1)
            game.undo_action()
            if (is_max_player and score > best_score) or (not is_max_player and score < best_score):
                best_score = score
        return best_score

    def get_all_Qs(self, state: tuple[int, ...], player: int, action_space: set[int]) -> dict[int, float]:
        q_values = {}
        for action in action_space:
            new_game = self.game_cls.from_state(state, player)
            cloned_game = new_game.clone()
            new_game.move(action)
            score = self.min_max(new_game, False, cloned_game, depth=3)
            q_values[action] = score
        return q_values

    def update_Q(self, state: tuple[int, ...], player: int, action: int, Q: float) -> None:
        raise NotImplementedError

In [31]:
# from src.value_functions.heuristic_4 import Heuristic_4
from src.value_functions.heuristic_4_test import Heuristic_4
min_max_policy_test = MinMaxPolicy(
    game_cls=TicTacToe,
    heuristic=Heuristic_4(

    )
)

In [32]:
def play(ttt, policy=None, print_state=False):
    if policy is None:
        policy = RandomPolicy()
    if not isinstance(policy, dict):
        policy = {-1: policy, 1: policy}

    tape = [dict(
        player=None,
        action_space=None,
        action=None,
        state=ttt.get_state(),
        winner=0
    )]

    while (action_space := ttt.get_actions()) and not ttt.get_winner():
        if print_state:
            ttt.render()
        player = ttt.player
        action = ttt.agent_move(policy[player])
        state = ttt.get_state()
        winner = ttt.get_winner()
        tape.append(dict(
            player=player,
            action_space=action_space,
            action=action,
            state=state,
            winner=winner
        ))
    if print_state:
        ttt.render()

    tape.append(dict(
        player=None,
        action_space=set(),
        action=None,
        state=ttt.get_state(),
        winner=ttt.get_winner()
    ))

    return tape

In [38]:
def simulate(round, player_policy_map, start_player=None):
    player_policy_map[0] = DummyPolicy(name='Tie')
    win_count = {policy.get_name(): 0 for policy in player_policy_map.values()}
    for episode in trange(round):
        ttt.reset(start_player=start_player or random.choice([-1, 1]))
        tape = play(ttt, policy=player_policy_map)
        win_count[player_policy_map[ttt.get_winner()].get_name()] += 1
    return win_count

In [39]:
policy_1 = RandomPolicy(name='Random Policy')
# policy_2 = q_table_policy
policy_2 = min_max_policy_test
result_list = []
for p1_role, start_player in [[1, 1], [1, -1], [-1, 1], [-1, -1]]:
    policy_map = {1*p1_role: policy_1, -1*p1_role: policy_2}
    result = simulate(10, policy_map, start_player=start_player)
    result['Player -1'] = policy_map[-1].get_name()
    result['Player 1'] = policy_map[1].get_name()
    result['Start Player'] = policy_map.get(start_player).get_name()
    result_list.append(result)
pd.DataFrame(result_list)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Unnamed: 0,Random Policy,MinMaxPolicy,Tie,Player -1,Player 1,Start Player
0,3,3,4,MinMaxPolicy,Random Policy,Random Policy
1,1,5,4,MinMaxPolicy,Random Policy,MinMaxPolicy
2,2,4,4,Random Policy,MinMaxPolicy,MinMaxPolicy
3,7,1,2,Random Policy,MinMaxPolicy,Random Policy


In [40]:
human_policy = PromptPolicy(player_formatter)
ttt.reset()
tape = play(ttt, policy={-1: human_policy, 1: min_max_policy_test}, print_state=True)


+---+---+---+---+
|   |   |   |   |
+---+---+---+---+
|   |   |   |   |
+---+---+---+---+
|   |   |   |   |
+---+---+---+---+
|   |   |   |   |
+---+---+---+---+

+---+---+---+---+
|   |   |   |   |
+---+---+---+---+
|   |   |   |   |
+---+---+---+---+
| X |   |   |   |
+---+---+---+---+
|   |   |   |   |
+---+---+---+---+


ValueError: invalid literal for int() with base 10: ''