In [None]:
import numpy as np
import pandas as pd
import os
import random
import math
import copy
from tqdm.notebook import tqdm
import time

In [6]:
class State:
    """Snapshot of an ultimate tic-tac-toe position.

    The board is a 2-D array indexed as ``board_state[sub_board, cell]``;
    a flat move number ``m`` addresses ``board_state[m // 9, m % 9]``.
    Cell values are 0 (empty), +1 or -1 (the two players).

    Attributes:
        board_state: the array described above.
        won_flag: per sub-board result (0 undecided, +1/-1 winner, -2 draw).
        moves_left: per sub-board boolean mask, True where a cell is empty.
        no_of_moves_left: per sub-board count of empty cells.
        previous_move: flat index of the last move played (-1 if none given).
        first: True while the whole board is still empty.
        finish: result of the overall game, same encoding as ``won_flag``.
    """

    # The eight winning lines of a 3x3 grid, as indices into a flat 9-vector.
    won_check = np.array([[0,1,2],[3,4,5],[6,7,8],[0,3,6],[1,4,7],[2,5,8],[0,4,8],[2,4,6]])

    def __init__(self, board_state, previous_move=-1):
        """Derive all per-sub-board bookkeeping from ``board_state``."""
        self.board_state = board_state  # state of the board
        self.won_flag = []  # who has won each sub-board
        self.moves_left = []  # boolean mask of empty cells in each sub-board
        self.no_of_moves_left = []  # number of empty cells in each sub-board
        self.previous_move = previous_move
        self.first = np.all(self.board_state == 0)  # True until the first move is played

        for i in range(self.board_state.shape[0]):  # traverse every sub-board
            sub_board = self.board_state[i]
            self.won_flag.append(self.check_result_board(sub_board))
            self.moves_left.append(sub_board == 0)
            self.no_of_moves_left.append(np.count_nonzero(sub_board == 0))

        # The overall game is scored by treating the sub-board results as one board.
        self.finish = self.check_result_board(self.won_flag)

    def check_result_board(self, arr):
        """Score a single 3x3 board given as a flat sequence of 9 values.

        Returns the winning player's mark if any line is filled with the same
        player mark, -2 if there is no winner and no 0 entry is left (draw),
        and 0 otherwise (still undecided).
        """
        arr = np.array(arr)
        for triplet in State.won_check:
            values = arr[triplet]
            head = values[0]
            # A line only counts as a win when it holds real player marks:
            # three empty cells (0) or three drawn sub-boards (-2) are not a win,
            # and must not stop the scan of the remaining lines.
            if head != 0 and head != -2 and np.all(values == head):
                return head
        if np.count_nonzero(arr == 0) == 0:
            return -2  # -2 signifies a draw at any sub-board or the final board
        return 0

    def update_state(self, move, player):  # player should be either -1 or +1
        """Return a new State with ``player`` placed at flat index ``move``.

        The current object is left untouched; the board array is copied.
        """
        new_board = self.board_state.copy()
        new_board[move // 9, move % 9] = player
        # __init__ recomputes won_flag, moves_left, first and finish.
        return State(new_board, previous_move=move)

    def format_board(self):
        """Return a 9x9 float array laid out as the board appears on screen.

        ``board_state[i, j]`` (sub-board i, cell j) is placed at display row
        ``(i // 3) * 3 + j // 3`` and display column ``3 * (i % 3) + j % 3``.
        """
        display_board = np.zeros((9, 9))
        for i in range(9):
            for j in range(9):
                display_board[(i // 3) * 3 + (j // 3), 3 * (i % 3) + (j % 3)] = self.board_state[i, j]
        return display_board

    def print_board(self):
        """Print the full board with dividers between the 3x3 sub-boards."""
        display_board = self.format_board()

        for i in range(9):
            # Horizontal divider after every 3 rows
            if i % 3 == 0 and i != 0:
                print("-" * 31)

            row = ""
            for j in range(9):
                # Vertical divider after every 3 columns
                if j % 3 == 0 and j != 0:
                    row += "| "
                # "-1" is two characters wide, so every other value gets a
                # leading space to keep the columns aligned.
                if display_board[i, j] != -1:
                    row += f" {int(display_board[i,j])} "
                else:
                    row += f"{int(display_board[i, j])} "
            print(row)

    def legal_moves(self):
        """Return the flat indices of all moves available in this position.

        * Before the first move every one of the cells is offered.
        * If the game is finished the list is empty.
        * Otherwise the cell index of the previous move selects the target
          sub-board; if that sub-board is already decided, any empty cell in
          any undecided sub-board may be played.
        """
        if self.first:
            return [
                board * 9 + cell
                for board in range(len(self.moves_left))
                for cell in range(len(self.moves_left[0]))
            ]
        if self.finish != 0:
            return []

        target = self.previous_move % 9
        if self.won_flag[target] == 0:
            candidate_boards = [target]
        else:
            candidate_boards = [
                board for board in range(len(self.moves_left))
                if self.won_flag[board] == 0
            ]

        return [
            board * 9 + cell
            for board in candidate_boards
            for cell in range(len(self.moves_left[0]))
            if self.moves_left[board][cell]
        ]

    def inspect_board(self):
        """Print the raw ``board_state`` array."""
        print(self.board_state)

    def print_won_flag(self):
        """Print the sub-board results as a 3x3 grid."""
        for i in range(3):  # 3 rows
            row = ""
            for j in range(3):  # 3 columns
                value = self.won_flag[i * 3 + j]
                # 0 and 1 are one character wide; pad them to match -1 / -2
                if value in [0, 1]:
                    row += f"  {int(value)}"
                else:
                    row += f" {int(value)}"
                # Column separator if not the last column
                if j < 2:
                    row += " |"
            print(row)
            # Row separator if not the last row
            if i < 2:
                print("----+----+----")

In [None]:
class MCTSNode:
    """One node of the Monte-Carlo search tree.

    Attributes:
        state: game state held by this node (must provide ``legal_moves()``).
        total_score: sum of all values backed up through this node.
        visit_count: number of times this node was backed up through.
        children: child nodes, in insertion order.
        action: the action leading to each child; ``action[k]`` pairs with
            ``children[k]``.
        parent: the parent node, or None for the root.
    """

    def __init__(self, state=None, total_score=0, visit_count=0):
        self.state = state
        self.total_score = total_score
        self.visit_count = visit_count
        self.children = []
        self.action = []
        self.parent = None

    def add_child(self, child_node, action):
        """Attach ``child_node`` under this node for ``action``; returns self."""
        child_node.parent = self
        self.children.append(child_node)
        self.action.append(action)
        return self

    def prior_probability(self):
        """Uniform prior: 1 / number of legal moves in this node's state.

        Returns 0.0 when the state has no legal moves (e.g. a finished game)
        instead of raising ZeroDivisionError.
        """
        # NOTE(review): this counts the legal moves of this node's own state,
        # not of its parent's state — confirm that is the intended prior.
        move_count = len(self.state.legal_moves())
        if move_count == 0:
            return 0.0
        return 1 / move_count

    def calculate_UCT_score(self, c=1.25):
        """Mean value plus an exploration bonus scaled by ``c``.

        Unvisited nodes get the sentinel score 10000. For visited nodes the
        parent's visit count is read, so the node must have a parent.
        """
        if self.visit_count == 0:
            return 10000
        mean_value = self.total_score / self.visit_count
        bonus = c * (self.prior_probability()) * (math.sqrt(self.parent.visit_count)) / (1 + self.visit_count)
        return mean_value + bonus

    def normalize_visit_count(self):
        """Return a length-81 softmax over the children's visit counts.

        Position ``a`` holds the weight of action ``a``; actions without a
        child are treated as having a visit count of 0.
        """
        # NOTE(review): because this is a softmax over raw counts, actions
        # without a child still receive a (small) non-zero probability.
        visits_by_action = {}
        for act, child in zip(self.action, self.children):
            # Keep the first child per action, matching list.index semantics.
            visits_by_action.setdefault(act, child.visit_count)

        counts = [visits_by_action.get(a, 0) for a in range(81)]

        # Subtract the maximum before exponentiating for numerical stability.
        peak = max(counts)
        exp_values = [math.exp(count - peak) for count in counts]
        total = sum(exp_values)
        return [value / total for value in exp_values]

    def max_visit_count(self):
        """Return the action of the most visited child (first one on ties)."""
        visit_count_of_children = [child.visit_count for child in self.children]
        action_index = visit_count_of_children.index(max(visit_count_of_children))
        return self.action[action_index]

    def is_root(self):
        """True when this node has no parent."""
        return self.parent is None

    def is_leaf(self):
        """True when this node has no children."""
        return len(self.children) == 0

In [None]:
def selection(current_node,current_player):
    """Walk down the tree from ``current_node`` to a leaf.

    At every level the child with the best score *for the player to move* is
    chosen, and the player then alternates.

    Node statistics are accumulated from player +1's point of view (see
    ``backpropagation``), so for player -1 the mean value is negated before
    the exploration bonus is added. Taking ``min`` of the raw UCT score
    instead would also minimise the exploration bonus and would never pick
    an unvisited child.

    Args:
        current_node: node to start from; must expose ``is_leaf()``,
            ``children`` and, per child, ``state.finish``, ``visit_count``,
            ``total_score`` and ``calculate_UCT_score()``.
        current_player: +1 or -1, the player to move at ``current_node``.

    Returns:
        ``(leaf_node, player_to_move_at_leaf)``.
    """
    while not current_node.is_leaf():
        best_child = None
        best_score = None
        for child in current_node.children:
            outcome = child.state.finish
            if outcome == 1 or outcome == -1:
                # Decided game: +1 if the mover wins, -1 if the mover loses.
                score = current_player * outcome
            elif outcome == -2:
                score = 0  # drawn game
            else:
                score = child.calculate_UCT_score()
                if current_player < 0 and child.visit_count != 0:
                    # score == mean + bonus; flip the sign of the mean only,
                    # giving (-mean) + bonus for the minimising player.
                    # Unvisited children keep their sentinel score untouched.
                    score -= 2 * (child.total_score / child.visit_count)

            # Strict '>' keeps the first child on ties.
            if best_score is None or score > best_score:
                best_score = score
                best_child = child

        current_node = best_child
        current_player = current_player * (-1)

    return current_node,current_player

In [None]:
def expansion(current_node,current_player):
    """Grow the tree below ``current_node`` and pick one new child at random.

    A node whose game is already finished is returned as-is. Otherwise one
    child is created per legal move (played by ``current_player``) and a
    uniformly random child is returned.
    """
    if current_node.state.finish != 0:
        # Terminal position: nothing to expand.
        return current_node

    for move in current_node.state.legal_moves():
        successor_state = current_node.state.update_state(move=move,player=current_player)
        # add_child hands back the parent, so the walk stays on the same node.
        current_node = current_node.add_child(
            child_node=MCTSNode(state =successor_state),
            action = move,
        )

    return random.choice(current_node.children)

In [None]:
def _infer_player_to_move(state):
    """Return the player (+1/-1) whose turn it is in ``state``.

    The mark sitting on ``state.previous_move`` belongs to whoever moved
    last, so the opponent moves next. Falls back to +1 when there is no
    usable previous move (e.g. an untouched board).
    """
    last_move = state.previous_move
    if last_move is not None and last_move >= 0:
        last_mark = state.board_state[last_move // 9, last_move % 9]
        if last_mark != 0:
            return -1 if last_mark > 0 else 1
    return 1


def simulate_from_state(state, player=None):
    """Play uniformly random moves from ``state`` until the game ends.

    Args:
        state: starting position; must expose ``finish``, ``legal_moves()``,
            ``update_state(move=, player=)``, ``previous_move`` and
            ``board_state``.
        player: the player (+1 or -1) who moves first in the rollout. When
            omitted it is inferred from the last move on the board, so the
            rollout continues with the correct side instead of always
            starting with +1.

    Returns:
        ``finish`` of the final state (the winner's mark, or -2 for a draw).
    """
    if player is None:
        player = _infer_player_to_move(state)

    while state.finish == 0:
        state = state.update_state(move=random.choice(state.legal_moves()),player=player)
        player = player * (-1)

    return state.finish

In [None]:
def backpropagation(current_node, value):
    """Add ``value`` and one visit to every node from ``current_node`` to the root.

    Returns the root node (the first node on the path whose ``is_root()``
    is true), after it has been updated as well.
    """
    node = current_node
    while True:
        node.visit_count += 1
        node.total_score += value
        if node.is_root():
            # The root itself has been credited; the climb is over.
            return node
        node = node.parent

In [None]:
def create_MCTS(state,simulation_limiter,player,policy_network_X,policy_network_y):
    """Run a fresh Monte-Carlo tree search from ``state`` and pick a move.

    Each of the ``simulation_limiter`` iterations performs selection,
    expansion, a random rollout and backpropagation. Drawn rollouts (-2)
    are backed up as 0.

    Side effects: appends the searched node's board to ``policy_network_X``
    and its normalised visit distribution to ``policy_network_y``.

    Returns:
        The action of the most visited child of the searched node.
    """
    node = MCTSNode(state=state)
    for _ in range(simulation_limiter):
        node, mover = selection(current_node=node,current_player=player)
        node = expansion(current_node=node,current_player=mover)

        rollout_value = simulate_from_state(node.state)
        rollout_value = 0 if rollout_value == -2 else rollout_value

        # backpropagation hands back the top of the tree for the next pass.
        node = backpropagation(current_node=node,value=rollout_value)

    policy_network_X.append(node.state.board_state)
    policy_network_y.append(node.normalize_visit_count())
    return node.max_visit_count()

In [None]:
def selfplay(iteration, simulation_limiter=800):
    """Play one full game against itself with MCTS and save the training data.

    Args:
        iteration: identifier used in the output file name
            ``mcts_data{iteration}.csv``.
        simulation_limiter: number of MCTS iterations spent on every move
            (defaults to the previously hard-coded 800).

    The CSV has one row per move with four columns:
        value_network_X: board after the move was played.
        value_network_y: final game result (+1 / -1, 0 for a draw).
        policy_network_X: board the search was started from.
        policy_network_y: normalised visit distribution of that search.
    """
    # NOTE(review): the X/y cells hold arrays and lists, which end up as text
    # in a CSV — confirm the downstream loader parses them as expected.
    policy_network_X = []
    policy_network_y = []
    value_network_X = []
    state = State(board_state=np.zeros((9,9)))
    player=1
    move_count = 0

    while state.finish == 0:
        move_to_play = create_MCTS(
            state=state,
            simulation_limiter=simulation_limiter,
            player=player,
            policy_network_X=policy_network_X,
            policy_network_y=policy_network_y,
        )
        state = state.update_state(move=move_to_play,player=player)
        value_network_X.append(state.board_state)

        player=player*(-1)
        move_count=move_count+1

    # A drawn game (-2) is labelled 0; otherwise the winner's mark is the label.
    if state.finish == -2:
        result = 0
    else:
        result = state.finish

    # Every position of the game gets the same final result as its label.
    value_network_y = [result] * move_count

    data = {
        'value_network_X': value_network_X,
        'value_network_y': value_network_y,
        'policy_network_X': policy_network_X,
        'policy_network_y': policy_network_y
    }
    df = pd.DataFrame(data)

    df.to_csv(f'mcts_data{iteration}.csv', index=False)

    # from google.colab import files
    # files.download(f'mcts_data{iteration}.csv')

In [None]:
# Generate 100 self-play games; output files are numbered starting from 1.
for i in range(100):
    selfplay(iteration=i + 1)