In [13]:
import math

import numpy as np

# last expression of the cell, so the version is actually displayed
np.__version__

In [None]:
class TicTacToe:
    """Tic-tac-toe on a square ``size`` x ``size`` board.

    Cells hold ``1`` / ``-1`` for the two players and ``0`` for empty.
    Actions are flat, row-major cell indices in ``range(size * size)``.
    A player wins by filling a complete row, column or diagonal.
    """

    def __init__(self, size=3):
        # The board is always square; row_count and column_count are kept as
        # separate names only to make the indexing code read naturally.
        self.row_count    = size
        self.column_count = size

        self.action_size = self.row_count * self.column_count

    def get_initial_state(self):
        """Return an empty board (all zeros, dtype int8)."""
        return np.zeros((self.row_count, self.column_count), dtype=np.int8)

    def action_to_row_col(self, action):
        """Turn a flat action index into a ``(row, column)`` pair (row-major)."""
        row    = action // self.column_count
        column = action  % self.column_count

        return row, column

    def get_next_state(self, state, action, player):
        """Write ``player``'s mark at ``action`` and return the board.

        Note: ``state`` is modified in place (and also returned); pass a copy
        if the original board must be preserved.
        """
        row, column = self.action_to_row_col(action)

        state[row, column] = player
        return state

    def get_valid_moves(self, state):
        """Return a flat uint8 mask: 1 for every empty cell, 0 otherwise."""
        return (state.reshape(-1) == 0).astype(np.uint8)

    def check_win(self, state, action):
        """Return True if the mark at ``action`` completes a row, column or diagonal.

        The player is inferred from the mark stored at ``action``. Returns
        False when ``action`` is None (no move made yet) or refers to an
        empty cell.
        """
        if action is None:
            return False

        row, column = self.action_to_row_col(action)
        # infer the player from the state; Python int avoids int8 arithmetic
        player = int(state[row, column])
        if player == 0:
            # Empty cell: without this guard any line summing to 0
            # (e.g. [1, -1, 0]) would be reported as a win for "player 0".
            return False

        # Cells are -1/0/1, so a line of `size` cells sums to player * size
        # exactly when every cell belongs to `player`.
        size = self.row_count
        target = player * size

        if np.sum(state[row, :]) == target or np.sum(state[:, column]) == target:
            return True
        # Only the diagonals that pass through the move can have been completed by it.
        if row == column and np.sum(np.diag(state)) == target:
            return True
        if row + column == size - 1 and np.sum(np.diag(np.fliplr(state))) == target:
            return True
        return False

    def check_value_and_terminated(self, state, action):
        """Return ``(value, terminated)`` after ``action`` was played.

        ``(1, True)`` if the move won the game for the player who made it,
        ``(0, True)`` if the board is full without a winner, else ``(0, False)``.
        """
        if self.check_win(state, action):
            return 1, True
        elif np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        else:
            return 0, False

    def get_value_and_terminated(self, state, action):
        """Alias of ``check_value_and_terminated`` under the ``get_*`` naming."""
        return self.check_value_and_terminated(state, action)

    def get_opponent(self, player):
        """Return the other player (players are encoded as +1 / -1)."""
        return -player

    def get_opponent_value(self, value):
        """Convert a value to the opponent's point of view (zero-sum)."""
        return -value


In [6]:
tictactoe = TicTacToe()
state = tictactoe.get_initial_state()  # fresh all-zero board

player = 1  # players are 1 and -1; start with 1

In [7]:
state  # display the board created by get_initial_state() above

array([[0, 0, 0],
       [0, 0, 0],
       [0, 0, 0]], dtype=int8)

In [8]:
tictactoe.get_valid_moves(state)  # flat row-major mask: 1 = empty (playable) cell

array([1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=uint8)

In [11]:
tictactoe = TicTacToe()
player = 1

state = tictactoe.get_initial_state()

# Interactive two-player game: each turn reads an action index from stdin and
# re-prompts on anything that is not a legal move.
while True:
    print(state)
    valid_moves = tictactoe.get_valid_moves(state)
    print("Valid moves: ", [i for i in range(tictactoe.action_size) if valid_moves[i] == 1])
    print("Player", player)
    try:
        action = int(input(f"{player}: "))
    except ValueError:
        # non-numeric input would otherwise crash the loop
        print("Invalid move")
        continue
    # Reject out-of-range indices explicitly: a negative index would silently
    # wrap around (valid_moves[-1], state[-1, ...]) and play a different square,
    # and an index >= action_size would raise IndexError.
    if not 0 <= action < tictactoe.action_size or valid_moves[action] == 0:
        print("Invalid move")
        continue
    state = tictactoe.get_next_state(state, action, player)
    value, is_terminal = tictactoe.check_value_and_terminated(state, action)
    if is_terminal:
        if value == 1:
            print(f"Player {player} won.")
        else:
            print("Draw!")
        print(state)
        break
    player = tictactoe.get_opponent(player)


[[0 0 0]
 [0 0 0]
 [0 0 0]]
Valid moves:  [0, 1, 2, 3, 4, 5, 6, 7, 8]
Player 1
[[0 0 0]
 [0 1 0]
 [0 0 0]]
Valid moves:  [0, 1, 2, 3, 5, 6, 7, 8]
Player -1
[[-1  0  0]
 [ 0  1  0]
 [ 0  0  0]]
Valid moves:  [1, 2, 3, 5, 6, 7, 8]
Player 1
[[-1  1  0]
 [ 0  1  0]
 [ 0  0  0]]
Valid moves:  [2, 3, 5, 6, 7, 8]
Player -1
[[-1  1  0]
 [ 0  1  0]
 [ 0  0 -1]]
Valid moves:  [2, 3, 5, 6, 7]
Player 1
Invalid move
[[-1  1  0]
 [ 0  1  0]
 [ 0  0 -1]]
Valid moves:  [2, 3, 5, 6, 7]
Player 1
Player 1 won.
[[-1  1  0]
 [ 0  1  0]
 [ 0  1 -1]]


In [None]:
class Node:
    """A node of the Monte Carlo search tree.

    Stores a game ``state`` plus the statistics MCTS accumulates for it.

    Args:
        game: game object providing ``get_valid_moves`` and ``get_next_state``.
        args: dict of search hyper-parameters; ``args['C']`` is the UCB
            exploration constant.
        state: board at this node.
        parent: parent ``Node``, or None for the root.
        action_taken: action that led from ``parent`` to this node (None for the root).
    """
    def __init__(self, game, args, state, parent=None, action_taken=None):
        self.game = game
        self.args = args
        self.state = state
        self.parent = parent
        self.action_taken = action_taken

        self.children = []
        # mask of legal moves not yet turned into children; expand() clears entries
        self.expandable_moves = game.get_valid_moves(state)
        self.visit_count = 0
        self.value_sum = 0

    def is_fully_expanded(self):
        """True when no unexpanded moves remain and at least one child exists.

        A node without legal moves never gets children, so it is never
        reported as fully expanded and selection stops there.
        """
        return np.sum(self.expandable_moves) == 0 and len(self.children) > 0

    def select(self):
        """Pick the child with the highest UCB score (None if there are no children)."""
        best_child = None
        best_ucb = -np.inf

        for child in self.children:
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb
        return best_child

    def get_ucb(self, child):
        """UCB score of ``child`` from this node's point of view.

        The child's mean value (assumed to lie in [-1, 1]) is mapped to [0, 1]
        and flipped, then ``args['C'] * sqrt(ln(N) / n)`` is added as the
        exploration bonus. Requires ``child.visit_count > 0`` (division) and
        ``self.visit_count > 0`` (``math.log``).
        """
        # "1 -" because the child node represents the other player
        q_value = 1 - ((child.value_sum / child.visit_count) + 1) / 2
        return q_value + self.args['C'] * math.sqrt(math.log(self.visit_count) / child.visit_count)

    def expand(self):
        """Sample a random move from expandable_moves and expand the node.

        The chosen move is removed from ``expandable_moves`` and played as
        player ``1`` on a copy of this node's state. The resulting board is then
        negated, so the child is seen from the opponent's perspective and the
        player to move is always ``1`` (the child represents the other player,
        matching the ``1 -`` flip in ``get_ucb``). Assumes players are encoded
        as +1 / -1, as in TicTacToe.

        Returns:
            The newly created child ``Node``.

        Raises:
            ValueError: if there are no expandable moves left.
        """
        candidates = np.flatnonzero(self.expandable_moves)
        if candidates.size == 0:
            raise ValueError("no expandable moves left")
        action = int(np.random.choice(candidates))
        self.expandable_moves[action] = 0

        # copy first: TicTacToe.get_next_state writes into the array it is given
        child_state = self.game.get_next_state(self.state.copy(), action, 1)
        # change perspective so the opponent becomes "player 1" in the child
        child_state = -child_state

        child = Node(self.game, self.args, child_state, self, action)
        self.children.append(child)
        return child


class MCTS:
    """Monte Carlo Tree Search driver (work in progress).

    Args:
        game: game object; ``search`` calls ``check_value_and_terminated`` and
            ``get_opponent_value`` on it and passes it on to ``Node``.
        args: dict of hyper-parameters; ``search`` reads ``args['num_searches']``
            and ``Node`` reads ``args['C']``.
    """
    def __init__(self, game, args):
        self.game = game
        self.args = args

    def search(self, state):
        """Run ``args['num_searches']`` search iterations starting from ``state``.

        Simulation, backpropagation and the final result are not implemented
        yet, so this currently returns None.
        """
        # define root node
        root = Node(self.game, self.args, state)

        # `_` rather than `search`, so the loop variable does not shadow this method's name
        for _ in range(self.args['num_searches']):
            node = root

            # selection
            while node.is_fully_expanded():
                node = node.select()

            # TicTacToe exposes this as `check_value_and_terminated`
            value, is_terminal = self.game.check_value_and_terminated(node.state, node.action_taken)
            # value is from the point of view of the player who made
            # `action_taken`; flip it to the player to move at `node`
            value = self.game.get_opponent_value(value)

            if not is_terminal:
                # expansion
                node = node.expand()
                # simulation
                # TODO: not implemented yet
            # backpropagation
            # TODO: not implemented yet

        # TODO: return visit_counts
