# MCTS
Monte Carlo Tree Search for simple two-player games.
Algorithm based on https://www.youtube.com/watch?v=UXW2yZndl7U

In [None]:
import numpy as np

### Game Definitions

#### Connect 2
Game based on http://joshvarty.github.io/AlphaZero/

In [None]:
class Connect2:
    """Connect 2 on a 1x4 board: two adjacent marks in a row win.

    Players are encoded as O = 1 and X = -1. ``raw_board`` and ``raw_reward``
    are stored from an absolute point of view (not perspective-corrected);
    ``board()`` / ``reward()`` can flip them to a given player's perspective.
    ``played`` is the player who moved last, so ``-played`` is to move.
    """

    # Index pairs that form a winning line on the 1x4 board.
    _WIN_POSITIONS = np.array([[0, 1], [1, 2], [2, 3]])

    def __init__(self, board=None, played=None, terminated=None, reward=None):
        # board layout:
        # [ 0 1 2 3 ]
        self.raw_board = [0] * 4 if board is None else list(board)
        # The game starts with a change of player, so "played" defaults to
        # the opponent (-1) and O (1) makes the first move.
        self.played = -1 if played is None else played
        self.terminated = False if terminated is None else terminated
        self.raw_reward = 0 if reward is None else reward

    def state(self):
        """Return a snapshot dict that ``Connect2(**state)`` restores.

        The board list is copied so that stepping this game afterwards does
        not silently change snapshots already handed out (e.g. to MCTS nodes).
        """
        return {
            "board": list(self.raw_board),
            "played": self.played,
            "terminated": self.terminated,
            "reward": self.raw_reward,
        }

    def valid_actions(self):
        """Return the indices of empty (zero) cells, or [] once the game has ended."""
        if self.terminated:
            return []
        return [idx for idx, val in enumerate(self.raw_board) if val == 0]

    def board(self, perspective=None):
        """Return the board, optionally sign-flipped to a player's perspective.

        perspective:
          - None: the raw board (the internal list itself, not a copy)
          - "played": perspective of the player who moved last
          - "to_play": perspective of the player to move next
          - -1 / 1: perspective of that player
        """
        if perspective is None:
            return self.raw_board
        elif perspective == "played":
            return [k * self.played for k in self.raw_board]
        elif perspective == "to_play":
            return [k * -self.played for k in self.raw_board]
        return [k * perspective for k in self.raw_board]

    def reward(self, perspective=None):
        """Return the reward (1: O won, -1: X won, 0: tie or still running).

        ``perspective`` accepts the same values as ``board()`` and flips the
        sign accordingly.
        """
        if perspective is None:
            return self.raw_reward
        elif perspective == "played":
            return self.raw_reward * self.played
        elif perspective == "to_play":
            return self.raw_reward * -self.played
        return self.raw_reward * perspective

    def _eval_win(self, player):
        """Return True (numpy bool) if ``player`` occupies every cell of a winning line."""
        cells = np.asarray(self.raw_board)
        # fancy indexing with the (n_lines, line_len) index array yields one row per line
        return np.any(np.all(cells[self._WIN_POSITIONS] == player, axis=1))

    def step(self, action, debug=False):
        """Play ``action`` for the player to move and return ``self``.

        Any action that is not currently valid (an occupied cell, an index
        outside the board, or any move after the game has ended) terminates
        the game. If the game was still running, the opponent wins. This
        matches ``Connect4.step``. The turn passes in every case.
        """
        if action in self.valid_actions():
            self.raw_board[action] = -self.played
        else:
            # invalid action
            if debug:
                print("step: invalid action")
            if self.terminated is False:
                # automatic win for the other player
                self.raw_reward = self.played
            self.terminated = True

        # from here on, "played" is the player who just made this move
        self.played = -self.played

        if self._eval_win(self.played):
            if debug:
                print("step: eval_win")
            if self.terminated is False:
                self.raw_reward = self.played
            self.terminated = True

        if not self.valid_actions():
            if debug:
                print("step: no more valid actions")
            # no empty cell left: game over (a tie if nobody has won)
            self.terminated = True

        return self

    def render(self):
        """Print the board on one line, O for 1, X for -1, blank for empty."""
        print("[", end=" ")
        for j in range(4):
            val = self.raw_board[j]
            s = " "
            if val == 1:
                s = "O"
            elif val == -1:
                s = "X"
            print(s, end=" ")
        print("]")

#### TicTacToe

In [None]:
class TicTacToe:
    """Tic-tac-toe on a 3x3 board.

    Players are encoded as O = 1 and X = -1. ``raw_board`` and ``raw_reward``
    are stored from an absolute point of view (not perspective-corrected);
    ``board()`` / ``reward()`` can flip them to a given player's perspective.
    ``played`` is the player who moved last, so ``-played`` is to move.
    """

    # Rows, columns and both diagonals of the 3x3 board.
    _WIN_POSITIONS = np.array([
        [0, 1, 2], [3, 4, 5], [6, 7, 8],
        [0, 3, 6], [1, 4, 7], [2, 5, 8],
        [0, 4, 8], [2, 4, 6],
    ])

    def __init__(self, board=None, played=None, terminated=None, reward=None):
        # board layout:
        # [ 0 1 2
        #   3 4 5
        #   6 7 8 ]
        self.raw_board = [0] * 9 if board is None else list(board)
        # The game starts with a change of player, so "played" defaults to
        # the opponent (-1) and O (1) makes the first move.
        self.played = -1 if played is None else played
        self.terminated = False if terminated is None else terminated
        self.raw_reward = 0 if reward is None else reward

    def state(self):
        """Return a snapshot dict that ``TicTacToe(**state)`` restores.

        The board list is copied so that stepping this game afterwards does
        not silently change snapshots already handed out (e.g. to MCTS nodes).
        """
        return {
            "board": list(self.raw_board),
            "played": self.played,
            "terminated": self.terminated,
            "reward": self.raw_reward,
        }

    def valid_actions(self):
        """Return the indices of empty (zero) cells, or [] once the game has ended."""
        if self.terminated:
            return []
        return [idx for idx, val in enumerate(self.raw_board) if val == 0]

    def board(self, perspective=None):
        """Return the board, optionally sign-flipped to a player's perspective.

        perspective:
          - None: the raw board (the internal list itself, not a copy)
          - "played": perspective of the player who moved last
          - "to_play": perspective of the player to move next
          - -1 / 1: perspective of that player
        """
        if perspective is None:
            return self.raw_board
        elif perspective == "played":
            return [k * self.played for k in self.raw_board]
        elif perspective == "to_play":
            return [k * -self.played for k in self.raw_board]
        return [k * perspective for k in self.raw_board]

    def reward(self, perspective=None):
        """Return the reward (1: O won, -1: X won, 0: tie or still running).

        ``perspective`` accepts the same values as ``board()`` and flips the
        sign accordingly.
        """
        if perspective is None:
            return self.raw_reward
        elif perspective == "played":
            return self.raw_reward * self.played
        elif perspective == "to_play":
            return self.raw_reward * -self.played
        return self.raw_reward * perspective

    def _eval_win(self, player):
        """Return True (numpy bool) if ``player`` occupies every cell of a winning line."""
        cells = np.asarray(self.raw_board)
        # fancy indexing with the (n_lines, 3) index array yields one row per line
        return np.any(np.all(cells[self._WIN_POSITIONS] == player, axis=1))

    def step(self, action, debug=False):
        """Play ``action`` for the player to move and return ``self``.

        Any action that is not currently valid (an occupied cell, an index
        outside the board, or any move after the game has ended) terminates
        the game. If the game was still running, the opponent wins. This
        matches ``Connect4.step``. The turn passes in every case.
        """
        if action in self.valid_actions():
            self.raw_board[action] = -self.played
        else:
            # invalid action
            if debug:
                print("step: invalid action")
            if self.terminated is False:
                # automatic win for the other player
                self.raw_reward = self.played
            self.terminated = True

        # from here on, "played" is the player who just made this move
        self.played = -self.played

        if self._eval_win(self.played):
            if debug:
                print("step: eval_win")
            if self.terminated is False:
                self.raw_reward = self.played
            self.terminated = True

        if not self.valid_actions():
            if debug:
                print("step: no more valid actions")
            # no empty cell left: game over (a tie if nobody has won)
            self.terminated = True

        return self

    def render(self):
        """Print the board as three rows, O for 1, X for -1, blank for empty."""
        for j in range(3):
            print("[", end=" ")
            for i in range(3):
                val = self.raw_board[j * 3 + i]
                s = " "
                if val == 1:
                    s = "O"
                elif val == -1:
                    s = "X"
                print(s, end=" ")
            print("]")

#### Connect4

In [None]:
def get_win_positions_connect4():
    """Enumerate every 4-in-a-row line of the 6x7 Connect4 board.

    Cells are flat row-major indices (index = row * 7 + col). The result is
    a numpy int array with one row of 4 indices per line, ordered as:
    vertical lines (column by column, top start first), horizontal lines
    (row by row), top-left to bottom-right diagonals, then top-right to
    bottom-left diagonals.
    """
    def line(row, col, stride):
        # four cells starting at (row, col), advancing by `stride` flat indices
        origin = row * 7 + col
        return [origin + stride * k for k in range(4)]

    vertical = [line(r, c, 7) for c in range(7) for r in range(3)]
    horizontal = [line(r, c, 1) for r in range(6) for c in range(4)]
    down_right = [line(r, c, 8) for r in range(3) for c in range(4)]
    down_left = [line(r, c, 6) for r in range(3) for c in range(3, 7)]
    return np.array(vertical + horizontal + down_right + down_left)

def valid_action_connect4(board):
    """Return the landing cell of every non-full column of a Connect4 board.

    ``board`` is a flat row-major 6x7 sequence (index = row * 7 + col).
    Columns are scanned left to right. For each one, the empty cell with the
    highest row index is returned (the cell a dropped piece lands in). Full
    columns contribute nothing.
    """
    landing_cells = []
    for col in range(7):
        for row in range(5, -1, -1):
            cell = row * 7 + col
            if board[cell] == 0:
                landing_cells.append(cell)
                break
    return landing_cells

class Connect4:
    """Connect Four on a 6x7 board.

    Players are encoded as O = 1 and X = -1. ``raw_board`` and ``raw_reward``
    are stored from an absolute point of view (not perspective-corrected);
    ``board()`` / ``reward()`` can flip them to a given player's perspective.
    ``played`` is the player who moved last, so ``-played`` is to move.
    Actions are flat board indices as returned by ``valid_actions()``
    (the landing cell of each non-full column).
    """

    # Winning lines never change, so they are built once on first use
    # (by get_win_positions_connect4) and shared by all instances.
    _win_positions = None

    def __init__(self, board=None, played=None, terminated=None, reward=None):
        # self.raw_board (row-major; render() prints row 0 first):
        # [  0  1  2  3  4  5  6
        #    7  8  9 10 11 12 13
        #   14 15 16 17 18 19 20
        #   21 22 23 24 25 26 27
        #   28 29 30 31 32 33 34
        #   35 36 37 38 39 40 41 ]
        self.raw_board = [0] * 42 if board is None else list(board)
        # The game starts with a change of player, so "played" defaults to
        # the opponent (-1) and O (1) makes the first move.
        self.played = -1 if played is None else played
        self.terminated = False if terminated is None else terminated
        self.raw_reward = 0 if reward is None else reward

    def state(self):
        """Return a snapshot dict that ``Connect4(**state)`` restores.

        The board list is copied so that stepping this game afterwards does
        not silently change snapshots already handed out (e.g. to MCTS nodes).
        """
        return {
            "board": list(self.raw_board),
            "played": self.played,
            "terminated": self.terminated,
            "reward": self.raw_reward,
        }

    def valid_actions(self):
        """Return the landing cell of each non-full column, or [] once the game has ended."""
        if self.terminated:
            return []
        return valid_action_connect4(self.raw_board)

    def board(self, perspective=None):
        """Return the board, optionally sign-flipped to a player's perspective.

        perspective:
          - None: the raw board (the internal list itself, not a copy)
          - "played": perspective of the player who moved last
          - "to_play": perspective of the player to move next
          - -1 / 1: perspective of that player
        """
        if perspective is None:
            return self.raw_board
        elif perspective == "played":
            return [k * self.played for k in self.raw_board]
        elif perspective == "to_play":
            return [k * -self.played for k in self.raw_board]
        return [k * perspective for k in self.raw_board]

    def reward(self, perspective=None):
        """Return the reward (1: O won, -1: X won, 0: tie or still running).

        ``perspective`` accepts the same values as ``board()`` and flips the
        sign accordingly.
        """
        if perspective is None:
            return self.raw_reward
        elif perspective == "played":
            return self.raw_reward * self.played
        elif perspective == "to_play":
            return self.raw_reward * -self.played
        return self.raw_reward * perspective

    def _eval_win(self, player):
        """Return True (numpy bool) if ``player`` fills any 4-in-a-row line."""
        if Connect4._win_positions is None:
            Connect4._win_positions = get_win_positions_connect4()
        cells = np.asarray(self.raw_board)
        # fancy indexing with the (n_lines, 4) index array yields one row per line
        return np.any(np.all(cells[Connect4._win_positions] == player, axis=1))

    def step(self, action, debug=False):
        """Play ``action`` for the player to move and return ``self``.

        Any action not in ``valid_actions()`` (a non-landing cell, an index
        outside the board, or any move after the game has ended) terminates
        the game. If the game was still running, the opponent wins. The turn
        passes in every case.
        """
        # valid_actions() is [] once terminated, so this also rejects late moves
        if action in self.valid_actions():
            self.raw_board[action] = -self.played
        else:
            # invalid action
            if debug:
                print("step: invalid action")
            if self.terminated is False:
                # automatic win for the other player
                self.raw_reward = self.played
            self.terminated = True

        # from here on, "played" is the player who just made this move
        self.played = -self.played

        if self._eval_win(self.played):
            if debug:
                print("step: eval_win")
            if self.terminated is False:
                self.raw_reward = self.played
            self.terminated = True

        if not self.valid_actions():
            if debug:
                print("step: no more valid actions")
            # board full: game over (a tie if nobody has won)
            self.terminated = True

        return self

    def render(self):
        """Print each row as O/X/blank cells, followed by that row's cell indices."""
        for j in range(6):
            print("[", end=" ")
            for i in range(7):
                val = self.raw_board[j * 7 + i]
                s = " "
                if val == 1:
                    s = "O"
                elif val == -1:
                    s = "X"
                print(s, end=" ")
            print("|", end=" ")
            for i in range(7):
                idx = j * 7 + i
                print("%2d" % idx, end=" ")
            print("]")

In [None]:
# Pick which game class the rest of the notebook plays; uncomment one line.
#game_class = Connect2
#game_class = TicTacToe
game_class = Connect4

### Test Run

In [None]:
def render_game_status(g):
    """Render ``g`` and print a short status summary.

    Prints the board from the last mover's perspective, the terminated flag
    and the reward (last mover's perspective, raw value in brackets). It then
    prints the next player and their valid actions, followed by a blank line.
    """
    g.render()
    mover = g.played
    print(f"{mover:2d}", "b:     ", g.board(perspective="played"), end="")
    raw = g.reward()
    print(" | T:", g.terminated, "| R:", g.reward(perspective="played"), f"[{raw}]")
    print(f"{-mover:2d}", "next_a:", g.valid_actions())
    print()

In [None]:
# Replay a fixed action sequence and print the status after each move.
# Keep the sequence that matches the selected game_class. In the Connect4
# line, O stacks cells 38, 31, 24 and 17 (all in column 3).
g = game_class()
render_game_status(g)
#for action in [3, 2, 0, 1]: # connect2
#for action in [3, 2, 0, 6, 4, 5, 7, 8]: # tictactoe
for action in [38, 37, 31, 30, 24, 23, 17]: # connect4
    g.step(action)
    render_game_status(g)

In [None]:
def random_rollout(g, n_steps=None, debug=False):
    """Play uniformly random valid moves on ``g`` (mutating it in place).

    Stops when the game terminates or, if ``n_steps`` is given, once that
    many moves have been made. The limit is checked after each move, so at
    least one move is made on a running game. Returns the game reward from
    the perspective of the player who was to move when the rollout started.
    """
    start_player = -g.played
    if debug:
        print("current player:", start_player)
    moves_made = 0
    while g.terminated is False:
        choices = g.valid_actions()
        g.step(choices[np.random.choice(len(choices))])
        if debug:
            print("player", "%2d" % g.played, "plays  ")
            g.render()
        moves_made += 1
        if n_steps is not None and moves_made >= n_steps:
            break
    final_reward = g.reward(perspective=start_player)
    if debug:
        print("final board state")
        g.render()
        print("final reward (as current player):", final_reward)
    return g.reward(perspective=start_player)

In [None]:
# Play one random game from the start with verbose output. The cell displays
# the final reward from the first mover's perspective.
g = game_class()
random_rollout(g, debug=True)

### MCTS Implementation

In [None]:
class Node:
    """A node of the MCTS search tree.

    Holds a snapshot of the game position (``game_state``, the keyword
    arguments for ``game_class``) plus search statistics. The search code
    (``rollout`` / ``run_mcts``) adds rewards to ``value_sum`` from the
    viewpoint of the player to move at this node.
    """

    def __init__(self, prior, game_state, game_class=None, node_id=None):
        self.prior = prior  # not used for pure MCTS
        self.game_state = game_state
        self.game_class = game_class
        self.node_id = node_id

        self.children = {}  # action -> child Node
        self.visit_count = 0
        self.value_sum = 0

    def value(self):
        """Return the mean backed-up value, or 0 for a node never visited."""
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

    def is_leaf_node(self):
        """Return True when the node has not been expanded (no children)."""
        return len(self.children) == 0

    def __str__(self):
        return "%s: prior: %5.2f, n_child: %d, n_visit: %d, value_sum: %d" %\
            (str(self.node_id), self.prior, len(self.children), self.visit_count, self.value_sum)

    def restore_game(self):
        """Build a fresh game from the stored snapshot, or None without a game_class."""
        if self.game_class:
            return self.game_class(**self.game_state)
        return None

    def print_tree(self, depth=None):
        """Print this node and its descendants in depth-first pre-order.

        ``depth`` is the number of levels printed below this node (default
        4). The direct children are always printed, even for ``depth`` < 1.
        Each level is indented by two more spaces.
        """
        max_level = 4 if depth is None else max(depth, 1)
        print(self)
        self._print_children(1, max_level)

    def _print_children(self, level, max_level):
        # Print the children at `level` (2 * level spaces indent), recursing until max_level.
        for child in self.children.values():
            print(" " * (2 * level - 1), child)
            if level < max_level:
                child._print_children(level + 1, max_level)
def rollout(node, search_path, debug=False):
    """Simulate a random game from ``node`` and back the result up ``search_path``.

    The rollout reward is taken from the perspective of the player to move
    at ``node``. Its sign flips at each level on the way back to the root,
    so every node on the path accumulates value from its own player-to-move's
    viewpoint. Each node's visit count is also incremented.
    """
    if debug:
        print("rollout from", node.node_id)
    game = node.restore_game()
    if debug:
        game.render()
    player_to_move = -game.played
    outcome = random_rollout(game, debug=False)
    if debug:
        print("playing as", player_to_move, "reward:", outcome)
        game.render()

    backed_up = outcome
    for visited in search_path[::-1]:
        visited.value_sum += backed_up
        visited.visit_count += 1
        backed_up = -backed_up

def ucb1(parent, child, c=2):
    """Return the UCB1 score of ``child`` as seen by the player to move at ``parent``.

    ``child.value_sum`` is accumulated from the child's player-to-move
    perspective, which is the opponent of ``parent``'s, so it is negated.
    Unvisited children score ``inf`` so each is tried once before any child
    is revisited. ``c`` is the exploration constant (default 2).
    """
    if child.visit_count == 0:
        return np.inf
    exploitation = -child.value_sum / child.visit_count
    exploration = c * np.sqrt(np.log(parent.visit_count) / child.visit_count)
    return exploitation + exploration
    # AlphaZero-style alternative (http://joshvarty.github.io/AlphaZero/):
    #   prior_score = child.prior * sqrt(parent.visit_count) / (child.visit_count + 1)
    #   value_score = -child.value() if child.visit_count > 0 else 0
    #   score = value_score + prior_score

def expand_node(node):
    """Attach one child to ``node`` for every valid action of its position.

    Each child holds the state reached by playing its action, a uniform
    prior of ``1 / number_of_valid_actions`` and the id
    ``"<parent id>-<action>"``. Returns the list of valid actions. When the
    game reports none, the list is empty and no children are added.
    """
    actions = node.restore_game().valid_actions()
    n_actions = len(actions)
    for action in actions:
        # restore a fresh game per action because step() mutates it in place
        next_state = node.restore_game().step(action).state()
        child_id = node.node_id + "-" + str(action)
        node.children[action] = Node(1 / n_actions, next_state, node.game_class, child_id)
    return actions

def run_mcts(node, num_simulations=5, debug=False):
    """Run ``num_simulations`` MCTS iterations rooted at ``node`` and return it.

    The tree is grown in place. Each iteration:
      1. Selection: from the root, repeatedly take the child with the highest
         UCB1 score (first one on ties) until reaching a leaf.
      2. If that leaf was never visited, roll out from it directly.
      3. Otherwise expand it. A terminal leaf backs up its actual game result;
         a non-terminal one rolls out from its first child.
    """
    for sim in range(num_simulations):
        leaf = node
        path = [leaf]
        if debug:
            print()
            print("simulation:", sim)
        while not leaf.is_leaf_node():
            # max() returns the first child among equal scores
            best_action = max(leaf.children, key=lambda a: ucb1(leaf, leaf.children[a]))
            leaf = leaf.children[best_action]
            path.append(leaf)
        if debug:
            print("initial:")
            node.print_tree()

        if leaf.visit_count == 0:
            rollout(leaf, path, debug=debug)
            if debug:
                node.print_tree()
            continue

        actions = expand_node(leaf)
        if actions:
            first_child = leaf.children[actions[0]]
            path.append(first_child)
            rollout(first_child, path, debug=debug)
        else:
            # terminal position: back up the real result, negating per level
            value = leaf.restore_game().reward(perspective="to_play")
            for visited in reversed(path):
                visited.value_sum += value
                visited.visit_count += 1
                value *= -1
        if debug:
            node.print_tree()
    return node

def get_visit_counts(node):
    """Return ``(best_action, visit_counts)`` for the children of ``node``.

    ``visit_counts`` has one slot per board cell (actions are board indices),
    holding each child's visit count and 0 for actions without a child. It is
    sized from the node's own stored board, not a global game.
    ``best_action`` is the most-visited action (the first-expanded one on
    ties), or ``None`` when ``node`` has no children.
    """
    visit_counts = [0] * len(node.game_state["board"])
    for action, child in node.children.items():
        visit_counts[action] = child.visit_count
    if not node.children:
        return None, visit_counts
    best_action = max(node.children, key=lambda a: node.children[a].visit_count)
    return best_action, visit_counts

### Basic Manual Run

In [None]:
# Build a root node for the initial position, expand it and run the search.
# Keep the simulation count that matches the selected game_class; the cell
# displays the returned root node.
g = game_class()
node = Node(1, g.state(), game_class, node_id="root")
valid_actions = expand_node(node)
#run_mcts(node, num_simulations=100, debug=False) # connect2
#run_mcts(node, num_simulations=1000, debug=False) # tictactoe
run_mcts(node, num_simulations=2000, debug=False) # connect4

In [None]:
# Show the root and its direct children with their search statistics.
node.print_tree(depth=1)

In [None]:
# UCB1 score of each root action after the search.
for i in g.valid_actions():
    print(i, ucb1(node, node.children[i]))

### Self Play Episode

In [None]:
def run_self_play(game_class, num_simulations=100, run_status=True, debug=False):
    """Play one game of MCTS against itself and collect per-move records.

    Returns ``(action_sequence, stats)``. ``action_sequence`` lists the chosen
    actions in order. Each ``stats`` entry is
    ``[to_play, board, visit_prob, reward]``, where:
      - ``board`` is seen from ``to_play``'s perspective;
      - ``visit_prob`` is the root visit distribution over board cells;
      - ``reward`` is the final game reward from ``to_play``'s perspective.
    Prints one "." per turn when ``run_status`` is set.
    """
    action_sequence = []
    stats = []
    g = game_class()
    while True:
        if run_status:
            print(".", end="")
        # a finished game has nothing left to search
        if g.terminated:
            break
        node = Node(1, g.state(), game_class, node_id="root")
        expand_node(node)
        run_mcts(node, num_simulations=num_simulations, debug=debug)
        next_action, visit_counts = get_visit_counts(node)
        if next_action is None:
            break
        to_play = -g.played
        counts = np.asarray(visit_counts, dtype=float)
        stats.append([to_play, g.board(perspective=to_play), counts / counts.sum(), 0])
        action_sequence.append(next_action)
        g.step(next_action)
    for record in stats:
        record[-1] = g.reward(perspective=record[0])
    if run_status:
        print()
    return action_sequence, stats

In [None]:
# Run one self-play game; keep the simulation count that matches game_class.
#action_sequence, stats = run_self_play(game_class, 100) # connect2
#action_sequence, stats = run_self_play(game_class, 1000) # tictactoe
action_sequence, stats = run_self_play(game_class, 2000) # connect4

In [None]:
# Display the actions chosen during self-play.
np.array(action_sequence)

In [None]:
# Replay the self-play game move by move, printing the status after each one.
g = game_class()
for action in action_sequence:
    print("play", action)
    g.step(action)
    render_game_status(g)

### Interactive Play

In [None]:
# Interactive play: type a board index (one of the valid actions) for your
# move, then MCTS (2000 simulations) replies. The loop stops as soon as
# either side's move ends the game. Reward: 1 = O won, -1 = X won, 0 = tie.
action_sequence = []
while True:
    g = game_class()
    for action in action_sequence:
        g.step(action)
    g.render()
    print()
    if g.terminated:
        # the engine's last move ended the game
        print("game over | reward:", g.reward())
        break

    valid = g.valid_actions()
    try:
        action = int(input())
    except ValueError:
        print("enter one of", valid)
        continue
    if action not in valid:
        # re-prompt instead of letting the game score the move as a forfeit
        print("enter one of", valid)
        continue
    action_sequence.append(action)

    g.step(action)
    g.render()
    print()
    if g.terminated:
        # the human's move ended the game
        print("game over | reward:", g.reward())
        break
    node = Node(1, g.state(), game_class, node_id="root")
    valid_actions = expand_node(node)
    run_mcts(node, num_simulations=2000, debug=False)
    next_action, _ = get_visit_counts(node)
    if next_action is None:
        break
    action_sequence.append(next_action)