In [None]:
from torch import randint
from torch import nn, optim
import torch 
import gym
import numpy as np

from copy import deepcopy
from collections import deque
import random

from scores.score_logger import ScoreLogger

import numpy as np

In [None]:
class TicTacToe:
    """Minimal two-player tic-tac-toe environment.

    The board is a ``size x size`` integer array (3x3 by default). ``0`` marks an
    empty cell, ``1`` is the first player's mark and ``-1`` the second player's.
    ``player`` holds the mark of the player to move next; ``winner`` and ``done``
    describe the outcome after the most recent ``step``.
    """

    def __init__(self, size=3):
        self.size = size
        self.reset()

    def step(self, action):
        """Place the current player's mark on ``action`` and advance the game.

        Parameters
        ----------
        action : sequence of two ints
            ``(row, col)`` of an empty cell. It is unpacked explicitly, so tuples,
            lists and length-2 numpy arrays all address a single cell.

        Returns
        -------
        state : np.ndarray
            The board after the move. This is the environment's own array, not a copy.
        winner : int or None
            ``1`` or ``-1`` if that player completed a line, otherwise ``None``.
        done : bool
            ``True`` once someone has won or no empty cell remains.

        Raises
        ------
        ValueError
            If the game is already over or the target cell is occupied.
        """
        if self.done:
            raise ValueError("game is over; call reset() first")
        row, col = action
        if self.state[row, col] != 0:
            raise ValueError(f"cell {(int(row), int(col))} is already occupied")

        self.state[row, col] = self.player
        self.player = -self.player
        self.winner = self.get_winner(self.state)
        self.done = self.winner is not None or len(self.get_legal_actions(self.state)) == 0
        return self.state, self.winner, self.done

    def get_winner(self, state):
        """Return the mark (``1`` or ``-1``) owning a complete line, else ``None``.

        Rows are checked first, then columns, then the main and anti-diagonal. A
        line is complete when ``|sum(line)| == len(line)``. This relies on cells
        holding only ``-1``, ``0`` or ``1``.
        """
        board = np.asarray(state)
        lines = list(board) + list(board.T) + [np.diag(board), np.diag(np.fliplr(board))]
        for line in lines:
            total = int(np.sum(line))
            if abs(total) == len(line):
                return 1 if total > 0 else -1
        return None

    def get_legal_actions(self, state=None):
        """Return the ``(row, col)`` indices of all empty cells as a ``(k, 2)`` array.

        ``state`` defaults to the environment's current board.
        """
        if state is None:
            state = self.state
        return np.argwhere(np.asarray(state) == 0)

    def reset(self):
        """Clear the board, give the first move to player ``1`` and return the board."""
        self.state = np.zeros((self.size, self.size), dtype=int)
        self.player = 1
        self.winner = None
        self.done = False
        return self.state


In [None]:
class MonteCarloTreeSearch:
    """UCT Monte Carlo tree search for a two-player, zero-sum board game.

    Nodes are stored in ``tree`` and addressed by tuple ids. The root is
    ``(0,)``, and the child reached from a node by action ``a`` has id
    ``node.id + (a,)``. An action is the flat (row-major) index of the board cell
    being marked.

    ``node.player`` is the mark (``1`` or ``-1``) of the player to move in
    ``node.state``. Rewards stored on a node are from the point of view of the
    player who moved *into* it (``-node.player``), so each parent simply picks
    the child that scores best for itself.

    Expected collaborators:
    - ``env`` provides ``state``, ``get_winner(state)`` (``1``/``-1`` or ``None``)
      and ``get_legal_actions(state)`` (a ``(k, 2)`` array of empty cells).
    - ``tree`` provides ``add_node(parent_id, action, state, player=...)`` and
      ``tree[node_id]``.
    """

    ROOT_ID = (0,)

    def __init__(self, env, tree, n_iterations=50, depth=15, exploration_constant=5.0):
        self.n_iterations = n_iterations
        self.depth = depth  # kept for API compatibility; not used by this class
        self.exploration_constant = exploration_constant
        self.total_simulation_count = 0
        self.env = env
        self.tree = tree
        # Search from the side to move in ``env``. ``int`` also accepts the string "1".
        root_player = int(getattr(env, "player", 1))
        # Copy the board so later in-place ``env.step`` calls cannot mutate the root.
        self.tree.add_node((), 0, deepcopy(env.state), player=root_player)

    def _uct(self, child, parent_visits):
        """UCB1 score of ``child``; unvisited children score ``inf`` so they are tried first."""
        if child.visits == 0:
            return float("inf")
        exploitation = child.reward / child.visits
        exploration = np.sqrt(np.log(max(parent_visits, 1)) / child.visits)
        return exploitation + self.exploration_constant * exploration

    def selection(self):
        """Descend from the root to a leaf, following the highest-UCT child at each level.

        Returns
        -------
        leaf_node_id : tuple
            Id of the first node reached that has no expanded children.
        depth : int
            ``len(leaf_node_id)``; the root has depth 1.
        """
        node = self.tree[self.ROOT_ID]
        while node.actions:
            parent_visits = node.visits
            children = [self.tree[node.id + (a,)] for a in node.actions]
            node = max(children, key=lambda child: self._uct(child, parent_visits))
        return node.id, len(node.id)

    def expansion(self, leaf_node_id):
        """Create every child of ``leaf_node_id`` and return the id of one chosen at random.

        If the leaf's position is already decided (a win or a full board), nothing
        is added and ``leaf_node_id`` itself is returned. If the node was expanded
        before, its existing children are reused rather than duplicated.
        """
        leaf_node = self.tree[leaf_node_id]
        if self.env.get_winner(leaf_node.state) is not None:
            return leaf_node_id
        legal = self.env.get_legal_actions(leaf_node.state)
        if len(legal) == 0:
            return leaf_node_id  # board full without a winner: draw

        if not leaf_node.actions:
            next_player = -leaf_node.player
            for row_col in legal:
                cell = tuple(int(i) for i in row_col)
                action_idx = int(np.ravel_multi_index(cell, leaf_node.state.shape))
                child_state = deepcopy(leaf_node.state)
                child_state[cell] = leaf_node.player  # mark exactly one cell
                self.tree.add_node(leaf_node_id, action_idx, child_state, player=next_player)
                leaf_node.actions.append(action_idx)

        chosen = leaf_node.actions[np.random.randint(len(leaf_node.actions))]
        return leaf_node_id + (chosen,)

    def simulation(self, child_node_id):
        """Play uniformly random moves from the node's position until the game ends.

        The node's stored state is copied, not modified.

        Returns
        -------
        The value of ``env.get_winner`` for a win (``1`` or ``-1``), or ``0`` for a draw.
        """
        self.total_simulation_count += 1
        node = self.tree[child_node_id]
        state = deepcopy(node.state)
        player = node.player
        while True:
            winner = self.env.get_winner(state)
            if winner is not None:
                return winner
            legal = self.env.get_legal_actions(state)
            if len(legal) == 0:
                return 0
            row, col = legal[np.random.randint(len(legal))]
            state[row, col] = player
            player = -player

    def backprop(self, child_node_id, winner):
        """Add one rollout result to every node from ``child_node_id`` up to the root.

        Each node on the path gains one visit and a reward from the view of the
        player who moved into it (``-node.player``):
        - ``+1`` if that player won;
        - ``-1`` if they lost;
        - ``0`` for any other ``winner`` (``0``, ``None``, ``'draw'``).
        ``q`` is refreshed as ``reward / visits``.
        """
        node_id = child_node_id
        while node_id:  # the root's parent id is the empty tuple
            node = self.tree[node_id]
            mover = -node.player
            if winner in (1, -1):
                reward = 1 if winner == mover else -1
            else:
                reward = 0
            node.visits += 1
            node.reward += reward
            node.q = node.reward / node.visits
            node_id = node.parent

    
    

In [None]:
class Node:
    """One search-tree node: a board state plus its visit statistics.

    ``id`` is ``parent_id + (action,)``, so a node's id spells out the sequence of
    actions that leads to it from the top of the tree.
    """

    def __init__(self, parent_id, action, state, player):
        self.id = parent_id + (action,)
        self.parent = parent_id
        self.state = state
        self.actions = []  # actions whose child nodes have been created
        self.player = player  # previously a no-op local assignment; now stored
        self.reward = 0
        self.visits = 0
        self.q = 0

    def __repr__(self):
        return (
            f"Node(id={self.id}, player={self.player}, visits={self.visits}, "
            f"reward={self.reward}, q={self.q:.3f})"
        )

class Tree:
    """Container of ``Node`` objects keyed by their tuple ids."""

    def __init__(self, state=None):
        # ``state`` is accepted so ``Tree(state)`` keeps working; it is not stored.
        # Nodes, including the root, are created through ``add_node``.
        self.tree = {}

    def add_node(self, parent_id, action, state, player=None):
        """Create a ``Node`` under ``parent_id`` for ``action``, store it and return it.

        The previous ``(id, state, player)`` signature passed three arguments to the
        four-argument ``Node`` constructor, so it could never succeed. This signature
        matches how the search code calls it.
        """
        node = Node(parent_id, action, state, player)
        self.tree[node.id] = node
        return node

    def __getitem__(self, node_id):
        return self.tree[node_id]

    def __contains__(self, node_id):
        return node_id in self.tree

    def __len__(self):
        return len(self.tree)


In [None]:
# Self-play driver: before every move, grow a fresh MCTS tree rooted at the
# current position and play the root child the search visited most.
N_GAMES = 100               # number of self-play games
N_ITERATIONS = 50           # MCTS iterations per move; raise for stronger play
EXPLORATION_CONSTANT = 1.4  # UCT exploration weight (about sqrt(2))
SEED = 46

# Expansion and rollouts draw from numpy's global RNG, so seed it for reproducibility.
np.random.seed(SEED)


def choose_action(env, n_iterations=N_ITERATIONS, exploration_constant=EXPLORATION_CONSTANT):
    """Run MCTS from ``env``'s current position and return the chosen ``(row, col)``.

    Call this only while the game is still running; on a finished board there is
    no move to choose and ``max`` raises ``ValueError``.
    """
    tree = Tree(env.state)
    mcts = MonteCarloTreeSearch(
        env, tree, n_iterations=n_iterations, exploration_constant=exploration_constant
    )
    for _ in range(mcts.n_iterations):
        leaf_id, _depth = mcts.selection()
        child_id = mcts.expansion(leaf_id)
        mcts.backprop(child_id, mcts.simulation(child_id))

    root = tree[(0,)]
    best = max(root.actions, key=lambda a: tree[root.id + (a,)].visits)
    best_child = tree[root.id + (best,)]
    # The move is the single cell that differs between the root and child boards.
    row, col = np.argwhere(best_child.state != root.state)[0]
    return int(row), int(col)


env = TicTacToe()
results = {1: 0, -1: 0, "draw": 0}
for _game in range(N_GAMES):
    env.reset()
    winner = None
    while not env.done:
        _, winner, _ = env.step(choose_action(env))
    results["draw" if winner is None else winner] += 1

results


In [None]:
# Scratch cell: rebinds ``env`` to a brand-new TicTacToe, discarding any game played above.
env = TicTacToe()

In [None]:
# Display the current board array (0 = empty cell, non-zero = a player's mark).
env.state