# Monte Carlo Tree Search Lab

In this lab, we'll be using the game connect four, as a vehicle for learning MinMax and Monte Carlo Tree Search.
We'll also introduce concepts, such as state, that'll stay relevant throughout the course.
Expect to lose in connect four to the algorithm at the end of the lab.

## Setup
This section you won't need to edit, but it is worth skimming through—this is where we declare the objects you'll be interacting with througout the lab.

In [1]:
# imports
import random
from typing import List, Tuple
import time
from copy import deepcopy # world -> thought

In [None]:
# world and world model
class State:
    def __init__(self, cols=7, rows=6, win_req=4):
        self.board = [['.'] * cols for _ in range(rows)]
        self.heights = [1] * cols
        self.num_moves = 0
        self.win_req = win_req

    def get_avail_actions(self) -> List[int]:
        return [i for i in range(len(self.board[0])) if self.heights[i] <= len(self.board)]
  
    def put_action(self, action, agent):
        self.board[len(self.board) - self.heights[action]][action] = agent.name
        self.heights[action] += 1
        self.num_moves += 1

    def is_over(self):
        return self.num_moves >= len(self.board) * len(self.board[0])

    def __repr__(self):
        return self.__str__()
    
    def __str__(self):
        header  = " ".join([str(i) for i in range(len(self.board[0]))])
        line    = "".join(["-" for _ in range(len(header))])
        board   = [[e for e in row] for row in self.board]
        board   = '\n'.join([' '.join(row) for row in board])
        return  '\n' + header + '\n' + line + '\n' + board + '\n'


In [14]:
t = State()
t.board

[['.', '.', '.', '.', '.', '.', '.'],
 ['.', '.', '.', '.', '.', '.', '.'],
 ['.', '.', '.', '.', '.', '.', '.'],
 ['.', '.', '.', '.', '.', '.', '.'],
 ['.', '.', '.', '.', '.', '.', '.'],
 ['.', '.', '.', '.', '.', '.', '.']]

In [15]:
# evaluate the utility of a state
def utility(state: 'State'):
    board = state.board
    n_cols = len(board[0]) - 1
    n_rows = len(board) - 1
    # print("n_cols: ", n_cols)
    # print("n_rows: ", n_rows)

    def diags_pos():
        """Get positive diagonals, going from bottom-left to top-right."""
        for di in ([(j, i - j) for j in range(n_cols)] for i in range(n_cols + n_rows - 1)):
            yield [board[i][j] for i, j in di if i >= 0 and j >= 0 and i < n_cols and j < n_rows]

    def diags_neg():
        """Get negative diagonals, going from top-left to bottom-right."""
        for di in ([(j, i - n_cols + j + 1) for j in range(n_cols)] for i in range(n_cols + n_rows - 1)):
            yield [board[i][j] for i, j in di if i >= 0 and j >= 0 and i < n_cols and j < n_rows]

    cols = list(map(list, list(zip(*board))))
    rows = board
    diags = list(diags_neg()) + list(diags_pos())
    lines = rows + cols + diags
    # lines = diags
    strings = ["".join(s) for s in lines]
    for string in strings:
        # print(string)
        if 'OOOO' in string:
            return -1
        if 'XXXX' in string:
            return 1
    return 0


In [16]:
# parrent class for mcts, minmax, human, and any other idea for an agent you have
class Agent:
    def __init__(self, name: str):
        self.name: str = name

    def get_action(self, state: State):
        return random.choice(state.get_avail_actions())

In [17]:
# connecting states and agents
class Game:
    def __init__(self, agents: Tuple[Agent]):
        self.agents = agents
        self.state = State()

    def play(self):
        while utility(self.state) == 0 and not self.state.is_over():
            for agent in agents:
                if utility(self.state) == 0 and not self.state.is_over():
                    action = agent.get_action(self.state)
                    self.state.put_action(action, agent)
                    print(self.state)

## Exercise 0: Discuss and Run game
Let's discuss if the `utility` function best belongs to the state or the agent.
Put the state, agent and game class together so that a game is run.

In [18]:
agents = (Agent('O'), Agent('X'))
game = Game(agents)
game.play()


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . O . .


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . X . .
. . . . O . .


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . X . .
. . O . O . .


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . X . .
. . . . X . .
. . O . O . .


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . X . .
. . . . X . .
. . O O O . .


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . X . .
. . . . X . .
. X O O O . .


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . X . .
. . . . X . .
O X O O O . .


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . X . .
. X . . X . .
O X O O O . .


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . X . .
. X

## Exercise 1: Human Agent
Make a child class of `Agent` called `Human`, with the `get_action` method overwritten to take input from you. *hint*: use `int(input())`

In [19]:
class Human(Agent):
    def __init__(self, name):
        super(Human, self).__init__(name)
        
    def get_action(self, state: State):
        action = int(input())
        while not (0 <= action <= 6):
            action = int(input())
        
        return action

In [20]:
agents = (Agent('O'), Human('X'))
game = Game(agents)
#game.play()

## Exercise 2: Gekko
Make a child class of `Agent` called `Gekko`, with a `get_action` that is very short sighted (greedy). You can basically do whatever you want here, as long as your output a valid action. You might want to make a `utility` function for the agent, and perhaps some helper functions. Write a two line comment explaining your Gekko's heuristic.

In [21]:
class Gekko(Agent):
    def __init__(self, name, enemy = 'O'):
        super(Gekko, self).__init__(name)
        self.enemy = enemy
        
    def get_action(self, state: State):
        #action = random.choice(state.get_avail_actions())
        #print(self.utility(state))
        action = self.utility(state)
        # board = state.board
        # print(board)
        return action 
        
    def diags_pos(self, rows, n_cols, n_rows):
        """Get positive diagonals, going from bottom-left to top-right."""
        for di in ([(j, i - j) for j in range(n_cols)] for i in range(n_cols + n_rows - 1)):
            yield [rows[i][j] for i, j in di if i >= 0 and j >= 0 and i < n_cols and j < n_rows]
        
    def diags_neg(self, rows, n_cols, n_rows):
        """Get negative diagonals, going from top-left to bottom-right."""
        for di in ([(j, i - n_cols + j + 1) for j in range(n_cols)] for i in range(n_cols + n_rows - 1)):
            yield [rows[i][j] for i, j in di if i >= 0 and j >= 0 and i < n_cols and j < n_rows]
    
    """
        Simple utility function that looks at the state of the board after each action taken by the agent and evaluates the utility of the action.
        The heuristic itself is simple and creates a value based on how close the agent is to winning vs the opponent. If the opponent
        is close to winning, the agent will aim to block the opponent, while if the agent is close to winning, it will prioritize those moves.
    """
    
    def utility(self, state: State):
        actions = state.get_avail_actions()
        n_cols = len(state.board[0]) - 1
        n_rows = len(state.board) - 1
        
        move_utility = [0]*7
        
        for a in actions:
            rows = state.board.copy()
            rows[len(rows) - state.heights[a]][a] = self.name
            cols = list(map(list, list(zip(*rows))))
            diags = list(self.diags_neg(rows, n_cols, n_rows)) + list(self.diags_pos(rows, n_cols, n_rows))
            
            lines = rows + cols + diags
            strings = ["".join(s) for s in lines]
            
            # Simple hardcoded heuristic that considers the placement of opponent and own move
            opponent_score = 0
            own_score = 0
            
            # Naive piece placement evaulation. Should ideally be more comprehensive and take more combinations into considerations.
            for string in strings:
                if str(['.',self.enemy,self.enemy,self.enemy]) in string or str([self.enemy,',', self.enemy,self.enemy]) in string  or str([self.enemy, self.enemy, ',',self.enemy]) in string  or str([self.enemy, self.enemy, self.enemy, ',']) in string:
                #if '.OOO' in string:
                    if opponent_score < 1000: 
                        opponent_score += 1000
                    print(string)
                    
                #elif '..OO' in string  or '.O.O' in string  or '.OO.' in string  or 'O..O' in string  or 'O.O.' in string  or 'OO..' in string :
                 #   if opponent_score < 5:
                 #       opponent_score = 5
            
            for string in strings:
                if str([self.name,self.name,self.name,self.name]) in string:
                    own_score += 10000 # high score to make sure the agent priotizes this move if it's very close to winning.
                    break
                elif str(['.',self.name,self.name,self.name]) in string or str([self.name,',', self.name,self.name]) in string  or str([self.name, self.name, ',',self.name]) in string  or str([self.name, self.name, self.name, ',']) in string:
                    own_score += 100
                #elif '..XX' in string or '.X.X' in string or '.XX.' in string or 'X..X' in string or 'X.X.' in string or 'XX..' in string:
                 #   own_score += 5
            
            
            move_utility[a] = own_score - opponent_score
            
            rows[len(rows) - state.heights[a]][a] = '.'
            
        return move_utility.index(max(move_utility))
        
        

In [22]:
agents = (Agent('O'), Gekko('X'))
game = Game(agents)
game.play()


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . O


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
X . . . . . O


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . O
X . . . . . O


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
X . . . . . O
X . . . . . O


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
O . . . . . .
X . . . . . O
X . . . . . O


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
X . . . . . .
O . . . . . .
X . . . . . O
X . . . . . O


0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
X . . . . . .
O . . . . . .
X . . . . . O
X . . . . O O


0 1 2 3 4 5 6
-------------
. . . . . . .
X . . . . . .
X . . . . . .
O . . . . . .
X . . . . . O
X . . . . O O


0 1 2 3 4 5 6
-------------
. . . . . . .
X . . . . . .
X . . . . . .
O . . . . . .
X .

## *Optional exercise: MinMax (useful to have done for exercise 3)*
Make a MinMax agent:

In [23]:
class MinMax(Agent):
    def __init__(self, name):
        super(MinMax, self).__init__(name)
        
    def minmax(self, depth, state: State, player): 
        val = utility(state)
        if depth == 0 or val == 1 or val == -1:
            return val
        
        actions = state.get_avail_actions()
        
        for a in actions:
            rows = state.board.copy()
            rows[len(rows) - state.heights[a]][a] = player
            cols = list(map(list, list(zip(*rows))))
            diags = list(self.diags_neg(rows, n_cols, n_rows)) + list(self.diags_pos(rows, n_cols, n_rows))
            
            lines = rows + cols + diags
            strings = ["".join(s) for s in lines]
            
            if player == 'X':
                return max(minmax(depth-1, state, 'O'))
            else:
                return min(minmax(depth-1, state, 'X'))
            
            rows[len(rows) - state.heights[a]][a] = '.'
        
    
    def utility(self, state: State):
        rows = state.board.copy()
        rows[len(rows) - state.heights[a]][a] = player
        cols = list(map(list, list(zip(*rows))))
        diags = list(self.diags_neg(rows, n_cols, n_rows)) + list(self.diags_pos(rows, n_cols, n_rows))
            
        lines = rows + cols + diags
        strings = ["".join(s) for s in lines]
        for string in strings:
            if 'OOOO' in string:
                return -1
            if 'XXXX' in string:
                return 1
        return 0
        

## Exercise 3: MCTS
Same but for Monte Carlo Tree Search. See if you can beat it with a `Human`.

In [69]:
class Node:
    def __init__(self, state: State, parent: 'Node' = None):
        self.children: List['Nodes'] = [None] * 7
        self.parent: 'Node' = parent
        self.state: State = state
        self.visits = 0
        self.quality = 0

In [89]:
import copy

class MCTS(Agent):
    def __init__(self, name, enemy = 'O'):
        super(MCTS, self).__init__(name)
        self.enemy = enemy
        self.exploration = 0.25
        
    def make_move(self, node, a, player):
        rows = node.state.board
        rows[len(rows) - node.state.heights[a]][a] = player
        node.state.heights[a] += 1
        node.state.num_moves += 1
        
        
    def undo_move(self, node, a):
        rows = node.state.board
        node.state.heights[a] -= 1
        node.state.num_moves -= 1
        rows[len(rows) - node.state.heights[a]][a] = '.'
        
        
    def mcts_search(self, state: State, n_searches):
        root_node = Node(state)
        max_depth = 100
        
        for i in range(n_searches):
            node = self.select(root_node)
            # print("node: ", node.state)
            value = self.simulation(max_depth, node, 'O')
            self.back_prop(node, value)
            
        print(root_node.quality)
        print(root_node.visits)
        actions = state.get_avail_actions()
        
        print("best actions:")
        
        for a in actions:
            c = root_node.children[a]
            print(a, ": ", c.quality, " / ",c.visits)
            
        best_child_index = self.best_child_index(root_node)
        
        #print("best child state:")
        #print(root_node.state)
        #print("parent state:")
        #print(root_node.children[best_child_index].state)
        
        
        return best_child_index
        # return random.choice(state.get_avail_actions())
    
    def select(self, parent: Node):
        actions = parent.state.get_avail_actions()
        
        # Expand child nodes that haven't been explored yet
        for a in actions:
            if parent.children[a] is None:
                c_node = Node(parent.state, parent)
                c_node.state = copy.deepcopy(parent.state)
                self.make_move(c_node, a, 'X')
                # print(c_node.state)
                parent.children[a] = c_node
                return c_node
                
        # For now, random
        node = random.choice(parent.children)
        # index = parent.children.index(node)
        
        return node
        
        # if not dead_end:
       # best_child_index = self.best_child_index(parent)
        
        #return parent.children[best_child_index]
        # return self.select(parent.children[best_child_index])
        
        # return self.best_child_index(parent)
            
    
    def best_child_index(self, parent: Node, player = ''):
        if player == '':
            player = self.name
        
        c_values = [0]*7
        for i in range(len(parent.children)):
            c = parent.children[i]
            if c is None:
                c_values[i] = -1
            else:
                # print(i)
                #c_values[i] = c.quality/c.visits
                if player is self.name:
                    c_values[i] = c.quality/c.visits
                else:
                    c_values[i] = (c.visits - c.quality)/c.visits
                # c_values[i] = random.randint(0, 1000)
                # print(c.quality)
        # print(c_values)
        
        best = max(c_values)
        if best == -1:
            return -1
        else:
            return c_values.index(max(c_values))
        
    def simulation(self, depth, node: Node, player): 
        n_cols = len(node.state.board[0]) - 1
        n_rows = len(node.state.board) - 1
        actions = node.state.get_avail_actions()
        result = self.utility(node.state)
        
        if len(actions) == 0:
            #print("draw")
            #print("move: ", node.state)
            return 0
        
        #a = random.choice(actions)
        
        if random.random() < self.exploration:
            a = random.choice(actions)
        else:
            #a = random.choice(actions)
            a = self.best_child_index(node)
            if a == -1:
                a = random.choice(actions)
            
        
        
        # a = self.best_child_index(node)
        
        rows = node.state.board
        self.make_move(node, a, player)
        # print("move: ", node.state)
        
        if result == 1:
            #print("won")
            # print("move: ", node.state)
            self.undo_move(node, a)

            return 1
        elif result == -1:
            #print("lost")
            # print("move: ", node.state)
            self.undo_move(node, a)

            return 0
            
        
        if player == self.name:
            value = self.simulation(depth-1, node, self.enemy)
            
        else:
            value = self.simulation(depth-1, node, self.name)
            
        self.undo_move(node, a)

        return value
        
    
    def back_prop(self, node: Node, value):
        node.visits += 1
        node.quality += value
        if node.parent is not None:
            self.back_prop(node.parent, value)
    
    def utility(self, state: State):
        # print("utility")
        rows = state.board
        n_cols = len(state.board[0]) - 1
        n_rows = len(state.board)
        cols = list(map(list, list(zip(*rows))))
        #print(n_cols)
        #print(n_rows)
        
        def diags_pos():
            """Get positive diagonals, going from bottom-left to top-right."""
            for di in ([(j, i - j) for j in range(n_cols)] for i in range(n_cols + n_rows - 1)):
                yield [rows[i][j] for i, j in di if i >= 0 and j >= 0 and i < n_cols and j < n_rows]

        def diags_neg():
            """Get negative diagonals, going from top-left to bottom-right."""
            for di in ([(j, i - n_cols + j + 1) for j in range(n_cols)] for i in range(n_cols + n_rows - 1)):
                yield [rows[i][j] for i, j in di if i >= 0 and j >= 0 and i < n_cols and j < n_rows]
        
        diags = list(diags_neg()) + list(diags_pos())
        #diags = list(self.diags_neg(rows, n_cols, n_rows)) + list(self.diags_pos(rows, n_cols, n_rows))
            
        lines = rows + cols + diags
        # lines = diags
        # lines = rows + cols
        strings = ["".join(s) for s in lines]
        for string in strings:
            # print(string)
            if self.enemy*4 in string:
                #print("0000")
                return -1
            if self.name*4 in string:
                #print("XXXX")
                return 1
        # print()
        # print()
        return 0
        
    
    def get_action(self, state: State):
        return self.mcts_search(state, 1000)
        # return random.choice(state.get_avail_actions())

In [91]:
m = MCTS('X', 'O')
#agents = (m, Gekko('O'))
agents = (m, Gekko('O'))
#agents = (m, Gekko('O'))
#agents = (m, Human('O'))
#agents = (Human('X'), Human('O'))
game = Game(agents)
game.play()


578
1000
best actions:
0 :  82  /  149
1 :  72  /  156
2 :  66  /  118
3 :  97  /  144
4 :  95  /  151
5 :  80  /  134
6 :  86  /  148

0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . X . . .

4

0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . X O . .

612
1000
best actions:
0 :  75  /  132
1 :  91  /  143
2 :  89  /  158
3 :  90  /  134
4 :  99  /  155
5 :  87  /  141
6 :  81  /  137

0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . . . . .
. . . X . . .
. . . X O . .

3

0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . O . . .
. . . X . . .
. . . X O . .

594
1000
best actions:
0 :  94  /  156
1 :  90  /  137
2 :  93  /  151
3 :  85  /  145
4 :  81  /  124
5 :  75  /  136
6 :  76  /  151

0 1 2 3 4 5 6
-------------
. . . . . . .
. . . . . . .
. . . . . . .
. . . O . . .
. . . X . . .
. X . X O . .

0

0 1 2 3 4 5 6
----

AttributeError: 'NoneType' object has no attribute 'state'

# *Optional exercise: Dynamic Programming*
Then use dynamic programming to make your AI more efficient. You can use the class below (or not)

In [None]:
class TranspositionTable:
    def __init__(self, size=1_000_000):
        self.size = size
        self.vals = [None] * size

    def board_str(self, state: State):
        return ''.join([''.join(c) for c in state.board])

    def put(self, state: State, utility: float):
        bstr = self.board_str(state)
        idx = hash(bstr) % self.size
        self.vals[idx] = (bstr, utility)

    def get(self, state: State):
        bstr = self.board_str(state)
        idx = hash(bstr) % self.size
        stored = self.vals[idx]
        if stored is None:
            return None
        if stored[0] == bstr:
            return stored[1]
        else:
            return None