In [1]:
import numpy as np
np.__version__

'1.24.2'

In [2]:
class TicTacToe:
    """Rules of tic-tac-toe on a NumPy integer board.

    A cell holds 0 while empty and the marker of the player who occupies it
    otherwise. Opposing players use markers of opposite sign (see
    ``get_opponent``), e.g. 1 and -1. Actions are flat cell indices in
    row-major order: ``action == row * column_count + column``.
    """

    def __init__(self):
        self.row_count = 3
        self.column_count = 3
        # One possible action per board cell.
        self.action_size = self.row_count * self.column_count

    def get_initial_state(self):
        """Return an empty board: a ``(row_count, column_count)`` int array of zeros."""
        return np.zeros((self.row_count, self.column_count), dtype=int)

    def get_next_state(self, state, action, player):
        """Write ``player``'s marker into the cell addressed by ``action``.

        The move is NOT validated here, and ``state`` is modified in place:
        the returned array is the same object that was passed in. Pass
        ``state.copy()`` if the previous position must be preserved.
        """
        row, column = divmod(action, self.column_count)
        state[row, column] = player
        return state

    def get_valid_moves(self, state):
        """Return a flat ``uint8`` mask with 1 for every empty cell and 0 otherwise."""
        return (state.reshape(-1) == 0).astype(np.uint8)

    def check_win(self, state, action):
        """Return True if the marker placed by ``action`` completes a line.

        Args:
            state: board after the move was applied.
            action: flat index of the last move, or ``None`` when no move
                has been made yet.

        Returns:
            ``True`` when the row or column through the played cell, or
            either main diagonal, consists entirely of the marker found in
            that cell. ``False`` when ``action`` is ``None`` or points at an
            empty cell (an empty line must never count as a win).
        """
        if action is None:
            return False

        row, column = divmod(action, self.column_count)
        player = state[row, column]
        if player == 0:
            # Without this guard an all-zero line would "match" player 0.
            return False

        candidate_lines = (
            state[row, :],
            state[:, column],
            np.diag(state),
            np.diag(np.flip(state, axis=0)),
        )
        # Comparing cell-by-cell uses each line's own length, so the
        # diagonals are handled correctly whatever their length is.
        return any(bool(np.all(line == player)) for line in candidate_lines)

    def check_draw(self, state):
        """Return True when no empty cell is left on the board."""
        return bool(np.sum(self.get_valid_moves(state)) == 0)

    def get_opponent(self, player):
        """Return the other player's marker (the negation of ``player``)."""
        return -player

    def get_value_and_terminated(self, state, action):
        """Evaluate the position reached by playing ``action``.

        Returns:
            ``(1, True)`` if the move won the game, ``(0, True)`` if the
            board is full without a winner, ``(0, False)`` otherwise. The
            win is checked before the draw, so a winning move on the last
            empty cell counts as a win.
        """
        if self.check_win(state, action):
            return 1, True
        if self.check_draw(state):
            return 0, True

        return 0, False
    

In [4]:
# Interactive two-player game on stdin/stdout: players 1 and -1 alternate
# entering a flat cell index until someone wins or the board is full.
tictactoe = TicTacToe()
player = 1
state = tictactoe.get_initial_state()

while True:
    print(state)
    valid_moves = tictactoe.get_valid_moves(state)
    print("valid_moves:", [i for i in range(tictactoe.action_size) if valid_moves[i] == 1])

    # Non-numeric input must not crash the cell; ask again instead.
    try:
        action = int(input(f"{player}: "))
    except ValueError:
        print("invalid action")
        continue

    # Range check first: it prevents an IndexError for large values and stops
    # negative numbers from silently selecting cells via negative indexing.
    if not 0 <= action < tictactoe.action_size or valid_moves[action] == 0:
        print("invalid action")
        continue

    state = tictactoe.get_next_state(state, action, player)

    value, terminated = tictactoe.get_value_and_terminated(state, action)

    if terminated:
        print(state)
        if value == 1:
            print("Player ", player, " won")
        else:
            print("Game drawn")
        break

    # Hand the turn to the other player only after a legal, non-terminal move.
    player = tictactoe.get_opponent(player)

[[0 0 0]
 [0 0 0]
 [0 0 0]]
valid_moves: [0, 1, 2, 3, 4, 5, 6, 7, 8]
1: 1
[[0 1 0]
 [0 0 0]
 [0 0 0]]
valid_moves: [0, 2, 3, 4, 5, 6, 7, 8]
-1: 3
[[ 0  1  0]
 [-1  0  0]
 [ 0  0  0]]
valid_moves: [0, 2, 4, 5, 6, 7, 8]
1: 0
[[ 1  1  0]
 [-1  0  0]
 [ 0  0  0]]
valid_moves: [2, 4, 5, 6, 7, 8]
-1: 2
[[ 1  1 -1]
 [-1  0  0]
 [ 0  0  0]]
valid_moves: [4, 5, 6, 7, 8]
1: 4
[[ 1  1 -1]
 [-1  1  0]
 [ 0  0  0]]
valid_moves: [5, 6, 7, 8]
-1: 7
[[ 1  1 -1]
 [-1  1  0]
 [ 0 -1  0]]
valid_moves: [5, 6, 8]
1: 6
[[ 1  1 -1]
 [-1  1  0]
 [ 1 -1  0]]
valid_moves: [5, 8]
-1: 8
[[ 1  1 -1]
 [-1  1  0]
 [ 1 -1 -1]]
valid_moves: [5]
1: 5
[[ 1  1 -1]
 [-1  1  1]
 [ 1 -1 -1]]
Game drawn


In [None]:
class MCTS:
    """Search object that keeps a game and its settings together.

    ``game`` is presumably a rules object such as ``TicTacToe`` and ``args``
    a container of search settings -- neither is used in the code visible
    here, so confirm against the callers.
    """

    def __init__(self, game, args):
        # Both arguments are stored by reference; nothing is copied.
        self.game, self.args = game, args
        
    