# Monte Carlo Tree Search 

> Monte Carlo Tree Search (MCTS) over a board game, guided by a policy/value model (currently a dummy stand-in).

In [None]:
#| default_exp mcts

In [None]:
#| hide
from nbdev.showdoc import show_doc
import nbdev

# Regenerate the library modules from the notebooks' `#| export` cells.
nbdev.nbdev_export()

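- Step 1: Take the current game state.
- Step 2: Run multiple random game simulations starting from this game state.
- Step 3: Evaluate the final state of each simulation with a score.
- Step 4: Remember only the next move of each simulation, and accumulate the score for that move.
- Step 5: After all simulations are done, return the next move with the highest accumulated score.

(These steps describe plain Monte Carlo search. The code below refines it into MCTS: each simulation descends the tree by UCB score instead of playing randomly, and positions where the game is not yet over are scored by a model.)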

In [None]:
#| export
import math
from typing import List, Tuple, Optional, Union, Dict, Literal
from enum import Enum

import torch
import gym

from muzero.chess.game import get_init_board, place_piece, get_valid_moves, is_board_full, is_win

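In the context of Go, each board position of the game is a node, and each node records whose turn it is to play... (In the code below, a child created by `expand` stores as its `player_turn` the player whose piece was placed to reach it.)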

In [None]:
#| export
def ucb_score(parent, child, c_puct: float = 1.0):
    """Return the PUCT score used to rank ``child`` during selection.

    The score is the child's mean backed-up value (exploitation) plus a
    prior-weighted exploration bonus. The bonus grows with the parent's visit
    count and shrinks as the child itself is visited::

        value / visits + c_puct * prior_prob * sqrt(parent.visits) / (child.visits + 1)

    Args:
        parent: the node selecting among its children; only ``parent.visits`` is read.
        child: the candidate child; ``prior_prob``, ``value`` and ``visits`` are read.
        c_puct (float): exploration constant scaling the prior term. The default
            ``1.0`` reproduces the unscaled formula.

    Returns:
        The score. An unvisited child is scored on its exploration bonus alone.
    """
    prior_score = c_puct * child.prior_prob * math.sqrt(parent.visits) / (child.visits + 1)
    # An unvisited child has no value estimate yet, so it contributes nothing.
    value_score = child.value / child.visits if child.visits > 0 else 0
    return value_score + prior_score

In [None]:
#| export
class Player(Enum):
    """The two sides of the game.

    The values are the piece codes written onto the board: ``Node.expand``
    passes ``player.value`` to ``place_piece``.
    """

    # Piece code +1.
    BLACK = 1
    # Piece code -1, the negation of BLACK.
    WHITE = -1

In [None]:
#| export
class Node:
    """A node of the Monte Carlo search tree.

    A node stores a board ``state``, the statistics MCTS accumulates for it,
    and its children keyed by the action index that leads to them.
    """

    def __init__(self, prior_prob: float, player_turn: Player, state: torch.Tensor):
        """Create an unexpanded node with no visits.

        Args:
            prior_prob (float): prior probability of the action leading to this
                node, as supplied to ``expand``; weighted into ``ucb_score``.
            player_turn (Player): the player associated with this node. Nodes
                created by ``expand`` record the player whose piece was placed to
                reach ``state``. A raw ``1`` / ``-1`` is accepted and converted to
                ``Player``.
            state (torch.Tensor): board position at this node. Children receive
                whatever ``place_piece`` returns.

        Attributes:
            children (Dict[int, Node]): child nodes, indexed by action.
            value (Union[int, float]): total value backed up through this node.
            visits (int): number of times this node has been visited.

        Raises:
            ValueError: if ``player_turn`` is neither a ``Player`` nor one of its values.
        """
        self.prior_prob: float = prior_prob
        # Normalise plain ints so comparisons against Player members work. A raw -1
        # never equals Player.WHITE, so it used to "flip" to WHITE instead of BLACK.
        self.player_turn: Player = Player(player_turn)
        self.state: torch.Tensor = state

        self.children: Dict[int, Node] = {}
        self.value: Union[int, float] = 0
        self.visits: int = 0

    def get_next_player_turn(self, current_turn: Player) -> Player:
        """Return the opponent of ``current_turn`` (a ``Player`` or its int value)."""
        return Player.WHITE if Player(current_turn) is Player.BLACK else Player.BLACK

    def expand(self, action_probs: List[float]):
        """Create one child per action whose probability is strictly positive.

        Each child's state is ``self.state`` with a piece of the opponent of
        ``self.player_turn`` placed via ``place_piece``. The child records that
        opponent as its ``player_turn`` and the action's probability as its prior.

        Args:
            action_probs (List[float]): probability per action index. Entries that
                are not ``> 0`` get no child.
        """
        # The mover is the same for every child, so work it out once.
        next_player_turn = self.get_next_player_turn(self.player_turn)
        for action, prob in enumerate(action_probs):
            if prob > 0:
                next_state = place_piece(board=self.state, player=next_player_turn.value, action=action)
                self.children[action] = Node(
                    prior_prob=prob,
                    player_turn=next_player_turn,
                    state=next_state,
                )

    def select_child(self) -> Tuple[int, "Node"]:
        """Return ``(action, child)`` for the child with the highest ``ucb_score``.

        Ties go to the first child in insertion order.

        Raises:
            ValueError: if the node has not been expanded (it has no children).
        """
        if not self.children:
            raise ValueError("select_child() requires an expanded node with at least one child")
        # max() keeps the first maximal item, matching a strict `>` scan, and needs
        # no arbitrary "-9999" floor that a low enough score could fall below.
        return max(self.children.items(), key=lambda item: ucb_score(self, item[1]))

In [None]:
# go_env = gym.make('gym_go:go-v0', size=7, komi=0, reward_method='real')
# go_env.reset()

In [None]:
import numpy as np
# `render` is only used by the commented-out visualisation cell further down.
from muzero.chess.view_board import render

# Hand-made test position: 6 rows x 7 columns holding 1 / -1 (the Player values)
# or 0 (presumably an empty cell). Only columns 0 and 5 still contain a 0, which
# matches the non-zero entries of the action probabilities used below.
board = np.array(
    [[0, -1, -1, -1, 1, 0, -1],
     [0, 1, -1, 1, 1, 0, 1],
     [-1, 1, -1, 1, 1, 0, -1],
     [1, -1, 1, -1, -1, 0, -1],
     [-1, -1, 1, -1, 1, 1, -1],
     [-1, 1, 1, -1, 1, -1, 1]]
)


In [None]:
# Root of the search tree: the hand-made position, tagged with Player.BLACK.
# Pass the enum member rather than the raw int 1 so `player_turn` really is a
# `Player`, which is what `Node.get_next_player_turn` compares against.
# `expand` places the opponent's piece, so the root's children hold WHITE pieces.
root = Node(
    prior_prob=0, player_turn=Player.BLACK, state=torch.tensor(board)
)

In [None]:
# Expand the root by hand with the same policy the dummy model returns:
# equal prior on actions 0 and 5, nothing elsewhere.
root.expand(action_probs=[0.5, 0, 0, 0, 0, 0.5, 0])

In [None]:
# Children created by `expand`, keyed by action.
root.children

{0: <__main__.Node>, 5: <__main__.Node>}

In [None]:
# render(root.children[0].state)

In [None]:
# Number of select / evaluate / backup passes run from the root below.
n_simulations: int = 100

In [None]:
#| export
def dummy_model_predict(board):
    """Stand-in for the policy/value network: ignore ``board``, return fixed outputs.

    Returns:
        A ``(value, action_probs)`` pair. The value estimate is the constant
        ``0.5``. The policy splits probability equally between actions 0 and 5
        out of 7 actions.
    """
    # A fresh list on every call, so a caller mutating it cannot affect later calls.
    policy = [0.5, 0, 0, 0, 0, 0.5, 0]
    return 0.5, policy

In [None]:
for _ in range(n_simulations):
    node = root
    search_path = [node]

    # Selection: descend through expanded nodes, always taking the child with the
    # highest UCB score, until an unexpanded (leaf) node is reached.
    while node.children:
        _, node = node.select_child()
        search_path.append(node)

    # Evaluation. `value` is scored from BLACK's (+1) point of view.
    # Check for a win before a full board: the move that fills the last cell can
    # also complete a line, and that position is a win, not a draw.
    value: Optional[Union[int, float]] = None
    if is_win(board=node.state, player=1):
        value = 1
    elif is_win(board=node.state, player=-1):
        value = -1
    elif is_board_full(board=node.state):
        value = 0

    if value is None:
        # The leaf is only the frontier of the search tree, not the end of the game.
        # Rather than playing on, ask the model for a value estimate and grow the tree
        # by one level. Later simulations reach terminal positions, scored exactly above.
        # NOTE(review): the model's value perspective is not defined yet; it is treated
        # as BLACK's, like the terminal scores. Confirm when a real network is used.
        value, action_probs = dummy_model_predict(node.state)
        node.expand(action_probs)

    # Backup: store each node's value from the perspective of its `player_turn` (the
    # player whose piece produced that position). `select_child` maximises child
    # value, so each player then prefers moves that are good for themselves rather
    # than always for BLACK.
    for visited in search_path:
        visited.value += value * Player(visited.player_turn).value
        visited.visits += 1

In [None]:
# Every attribute of the child reached by action 0: prior, player, board, children and statistics.
root.children[0].__dict__

{'prior_prob': 0.5,
 'player_turn': <Player.WHITE: -1>,
 'state': array([[ 0, -1, -1, -1,  1,  0, -1],
        [-1,  1, -1,  1,  1,  0,  1],
        [-1,  1, -1,  1,  1,  0, -1],
        [ 1, -1,  1, -1, -1,  0, -1],
        [-1, -1,  1, -1,  1,  1, -1],
        [-1,  1,  1, -1,  1, -1,  1]]),
 'children': {0: <__main__.Node>,
  5: <__main__.Node>},
 'value': 79.5,
 'visits': 98}

In [None]:
# Total value backed up through the child reached by action 5.
root.children[5].value

-2

In [None]:
# The root's children after the simulations (keyed by action).
root.children


{0: <__main__.Node>, 5: <__main__.Node>}