# Solving the game of Nim using Monte Carlo / Monte Carlo Tree Search

Nim (misère variant)
- You start with N matchsticks
- Each turn, the current player takes between 1 and K matchsticks
- The player who takes the last matchstick loses

# Game Coding

In [1]:
import numpy as np
class Nim:
    ''' Misère Nim environment.

    The game starts with N matchsticks; players 1 and 2 alternately remove
    between 1 and K matchsticks, and the player who takes the last matchstick
    loses.

    State attributes:
    matchstick - matchsticks remaining
    turn       - player to move (1 or 2)
    done       - 1 once the game is over, else 0
    reward     - 1 if player 1 has won, -1 if player 2 has won, 0 while playing
    '''
    def __init__(self, N=21, K=3):
        self.N = N  # initial number of matchsticks
        self.K = K  # maximum matchsticks removable in one move
        self.reset()

    def move(self, num):
        ''' Make the current player take `num` matchsticks.

        Returns (matchstick, turn, reward, done) after the move, or prints a
        message and returns (None, None, None, None) if the move is illegal:
        fewer than 1, more than K, or more than the matchsticks remaining
        (which also rejects every move once the game is over).
        '''
        if not 1 <= num <= min(self.K, self.matchstick):
            print(f"{num}: Invalid move")
            return None, None, None, None

        # remove matchsticks, then pass the turn (1 <-> 2)
        self.matchstick -= num
        self.turn = 3 - self.turn

        # The player who just moved took the last matchstick and lost,
        # so the player now on turn is the winner.
        if self.lose():
            self.done = 1
            self.reward = 1 if self.turn == 1 else -1

        return (self.matchstick, self.turn, self.reward, self.done)

    def lose(self):
        ''' True when no matchsticks remain: the last mover took the last matchstick and lost. '''
        return self.matchstick == 0

    def validmoves(self):
        ''' Legal move sizes, 1 up to min(K, matchsticks remaining); empty once the game is over. '''
        return list(range(1, min(self.K, self.matchstick)+1))

    def sample(self):
        ''' Returns a random legal move (np.random.choice raises if there is none). '''
        return np.random.choice(self.validmoves())

    def reset(self):
        ''' Restore the starting position: N matchsticks, player 1 to move, no result yet. '''
        self.matchstick = self.N
        self.turn = 1
        self.done = 0
        self.reward = 0
        

In [2]:
def Game(Agent1, Agent2, N, K, verbose = True):
    ''' This plays a game with Agent1 and Agent 2
    Inputs: Agent 1 - Program to make move for player 1
            Agent 2 - Program to make move for player 2
            N - number of matchsticks at beginning
            K - number of matchsticks (upper limit) to remove each turn
            verbose - whether to print out the game
    Agents are called as Agent(matchsticks_left, K, validmoves) and must
    return one of validmoves.
    Returns: 1 if player 1 wins, -1 if player 2 wins
    Raises: ValueError if an agent returns an illegal move
    '''
    env = Nim(N, K)
    agents = {1: Agent1, 2: Agent2}
    while not env.done:
        player = env.turn
        move = agents[player](env.matchstick, K, env.validmoves())
        matchstick, turn, reward, done = env.move(move)
        if done is None:
            # Nim.move rejected the move and left the state unchanged; stop here
            # instead of re-asking the same agent (an endless loop for a
            # deterministic agent) or formatting None in the message below.
            raise ValueError(f'Player {player} made an invalid move: {move}')
        if verbose:
            print(f'Player {player} takes {move} matchstick(s). There are {matchstick} matchstick(s) remaining.')

    playerwin = 1 if env.reward == 1 else 2
    if verbose:
        print(f'Player {playerwin} wins!')
    return env.reward

In [3]:
# Play many silent games between two agents and report the win counts
def Tournament(numgames, Agent1, Agent2, N, K):
    ''' Play `numgames` games of Nim with Agent1 as player 1 and Agent2 as
    player 2 (without per-move output) and print how many games each won.
    Inputs: numgames - Number of games for the tournament
            Agent1 - Program to make move for player 1
            Agent2 - Program to make move for player 2
            N - number of matchsticks at beginning
            K - number of matchsticks (upper limit) to remove each turn
    Any game result other than 1 is counted as a win for player 2.
    '''
    outcomes = [Game(Agent1, Agent2, N, K, verbose = False) for _ in range(numgames)]
    p1win = outcomes.count(1)
    p2win = len(outcomes) - p1win
    print(f'Player 1 wins {p1win} games. Player 2 wins {p2win} games.')

# Sample Agents

In [4]:
def RandomAgent(N, K, validmoves):
    ''' Pick one of the legal moves uniformly at random.

    N and K are accepted to match the agent interface but are not used.
    '''
    choice = np.random.choice(a=validmoves)
    return choice

In [5]:
def PerfectAgent(N, K, validmoves):
    ''' Optimal strategy for misère Nim.

    Leaving a pile that is 1 mod (K+1) is winning: whatever number t the
    opponent takes, we answer with K+1-t, so every round removes exactly K+1
    matchsticks and the opponent is finally left with the single last one.
    If no legal move reaches such a pile, fall back to a random legal move.
    '''
    period = K + 1
    winning = [m for m in validmoves if (N - m) % period == 1]
    if winning:
        return winning[0]
    return np.random.choice(validmoves)

# Run Some Trial Games
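- RandomAgent vs RandomAgent: each player should win about the same number of games
- PerfectAgent should win almost all games against RandomAgent; when both players are perfect, N alone decides the outcome (with K=3, player 1 loses when N mod 4 == 1, e.g. N=21, and wins otherwise, e.g. N=22)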

In [6]:
Game(RandomAgent, PerfectAgent, 21, 3)

Player 1 takes 3 matchstick(s). There are 18 matchstick(s) remaining.
Player 2 takes 1 matchstick(s). There are 17 matchstick(s) remaining.
Player 1 takes 2 matchstick(s). There are 15 matchstick(s) remaining.
Player 2 takes 2 matchstick(s). There are 13 matchstick(s) remaining.
Player 1 takes 3 matchstick(s). There are 10 matchstick(s) remaining.
Player 2 takes 1 matchstick(s). There are 9 matchstick(s) remaining.
Player 1 takes 2 matchstick(s). There are 7 matchstick(s) remaining.
Player 2 takes 2 matchstick(s). There are 5 matchstick(s) remaining.
Player 1 takes 1 matchstick(s). There are 4 matchstick(s) remaining.
Player 2 takes 3 matchstick(s). There are 1 matchstick(s) remaining.
Player 1 takes 1 matchstick(s). There are 0 matchstick(s) remaining.
Player 2 wins!


-1

In [7]:
Game(PerfectAgent, RandomAgent, 21, 3)

Player 1 takes 3 matchstick(s). There are 18 matchstick(s) remaining.
Player 2 takes 1 matchstick(s). There are 17 matchstick(s) remaining.
Player 1 takes 3 matchstick(s). There are 14 matchstick(s) remaining.
Player 2 takes 2 matchstick(s). There are 12 matchstick(s) remaining.
Player 1 takes 3 matchstick(s). There are 9 matchstick(s) remaining.
Player 2 takes 2 matchstick(s). There are 7 matchstick(s) remaining.
Player 1 takes 2 matchstick(s). There are 5 matchstick(s) remaining.
Player 2 takes 2 matchstick(s). There are 3 matchstick(s) remaining.
Player 1 takes 2 matchstick(s). There are 1 matchstick(s) remaining.
Player 2 takes 1 matchstick(s). There are 0 matchstick(s) remaining.
Player 1 wins!


1

In [8]:
Tournament(100, RandomAgent, RandomAgent, 21, 3)

Player 1 wins 47 games. Player 2 wins 53 games.


In [9]:
Tournament(100, RandomAgent, PerfectAgent, 21, 3)

Player 1 wins 0 games. Player 2 wins 100 games.


In [10]:
Tournament(100, PerfectAgent, RandomAgent, 21, 3)

Player 1 wins 99 games. Player 2 wins 1 games.


In [11]:
Tournament(100, PerfectAgent, PerfectAgent, 21, 3)

Player 1 wins 0 games. Player 2 wins 100 games.


In [12]:
Tournament(100, PerfectAgent, PerfectAgent, 22, 3)

Player 1 wins 100 games. Player 2 wins 0 games.


# Monte Carlo

- For each possible move, this agent plays out the rest of the game randomly many times (1000 rollouts per move)
- The move with the highest average reward is chosen

In [13]:
def randomrollout(N, K):
    ''' Returns the reward of a random playthrough from a position with N
    matchsticks left, from the point of view of the player to move:
    1 if that player wins, -1 if they lose. Both players move at random.
    '''
    # Positions decided without playing: with 0 left the previous player took
    # the last matchstick (so the player to move has won); with 1 left the
    # player to move must take it and lose.
    decided = {0: 1, 1: -1}
    if N in decided:
        return decided[N]
    # Game reports the result for player 1, who moves first, i.e. the player to move here
    return Game(RandomAgent, RandomAgent, N, K, verbose = False)

In [14]:
def MonteCarloAgent(N, K, validmoves, rollouts=1000, rollout=None):
    ''' Pick the move with the best average outcome over playouts.

    For each legal move, play `rollouts` games from the resulting position and
    average the reward from this agent's point of view.

    Inputs: N - matchsticks currently remaining
            K - maximum matchsticks per move
            validmoves - legal moves in this position
            rollouts - playouts per candidate move (default 1000)
            rollout - function (N, K) -> reward for the player to move with N
                      matchsticks left; defaults to randomrollout
    Returns: the move with the highest average value (ties go to the earliest
             move in validmoves); 0 if validmoves is empty
    '''
    if rollout is None:
        rollout = randomrollout
    bestmove = 0
    bestvalue = -np.inf
    for move in validmoves:
        # The rollout scores the position for the opponent (who moves next),
        # so negate it to get the value for this agent.
        value = -sum(rollout(N - move, K) for _ in range(rollouts)) / rollouts
        if value > bestvalue:
            bestvalue = value
            bestmove = move
    return bestmove

In [15]:
Tournament(4, MonteCarloAgent, MonteCarloAgent, 21, 3)

Player 1 wins 3 games. Player 2 wins 1 games.


In [16]:
Tournament(10, RandomAgent, MonteCarloAgent, 21, 3)

Player 1 wins 1 games. Player 2 wins 9 games.


In [17]:
Tournament(10, MonteCarloAgent, RandomAgent, 21, 3)

Player 1 wins 9 games. Player 2 wins 1 games.


In [18]:
Tournament(10, PerfectAgent, MonteCarloAgent, 21, 3)

Player 1 wins 10 games. Player 2 wins 0 games.


In [19]:
Tournament(10, MonteCarloAgent, PerfectAgent, 21, 3)

Player 1 wins 0 games. Player 2 wins 10 games.


# Monte Carlo with Perfect Play by opponent

We can see that Monte Carlo alone may not win a lot of games, especially against Perfect Agent.
This is because Monte Carlo can actually predict the wrong game outcome if the opponent is playing randomly.
E.g. 
- With 6 matchsticks left, the ideal move is to take 1 matchstick and leave 5 (1 mod 4).
- A random opponent often replies badly, so taking 2 or 3 matchsticks (leaving 4 or 3 respectively) can also look like a good move in the rollouts, even though a perfect opponent would then win.
- This makes the value estimates of states noisy and can rank moves incorrectly.

We can improve naive Monte Carlo by having the opponent play perfectly during the rollouts (having both players play perfectly would be better still, but that would defeat the point of Monte Carlo).

The game then becomes closer to a single-player problem, in which the environment includes the opponent and always responds perfectly.

In [20]:
def perfectrollout(N, K):
    ''' Returns the reward of a playthrough from a position with N matchsticks
    left, from the point of view of the player to move (1 = win, -1 = loss).

    The player to move plays perfectly (PerfectAgent) and the other player
    moves at random (RandomAgent). When called on the position after our own
    move, the player to move is the opponent, so this simulates a perfect
    opponent.
    '''
    if N==0:
        # the previous player took the last matchstick, so the player to move has won
        return 1
    elif N==1:
        # the player to move is forced to take the last matchstick
        return -1
    else:
        # Game reports the result for player 1, who moves first, i.e. the player to move here
        return Game(PerfectAgent, RandomAgent, N, K, verbose = False)

In [21]:
def MonteCarloPerfectOppAgent(N, K, validmoves, rollouts=1000, rollout=None):
    ''' Monte Carlo agent whose rollouts assume a perfect opponent.

    For each legal move, play `rollouts` games from the resulting position with
    `rollout` (by default perfectrollout, where the opponent plays perfectly
    and this agent plays randomly) and average the reward from this agent's
    point of view.

    Inputs: N - matchsticks currently remaining
            K - maximum matchsticks per move
            validmoves - legal moves in this position
            rollouts - playouts per candidate move (default 1000)
            rollout - function (N, K) -> reward for the player to move with N
                      matchsticks left; defaults to perfectrollout
    Returns: the move with the highest average value (ties go to the earliest
             move in validmoves); 0 if validmoves is empty
    '''
    if rollout is None:
        rollout = perfectrollout
    bestmove = 0
    bestvalue = -np.inf
    for move in validmoves:
        # The rollout scores the position for the opponent (who moves next),
        # so negate it to get the value for this agent.
        value = -sum(rollout(N - move, K) for _ in range(rollouts)) / rollouts
        if value > bestvalue:
            bestvalue = value
            bestmove = move
    return bestmove

In [145]:
Tournament(10, RandomAgent, MonteCarloPerfectOppAgent, 21, 3)

Player 1 wins 0 games. Player 2 wins 10 games.


In [146]:
Tournament(10, MonteCarloPerfectOppAgent, RandomAgent, 21, 3)

Player 1 wins 10 games. Player 2 wins 0 games.


In [147]:
Tournament(10, PerfectAgent, MonteCarloPerfectOppAgent, 21, 3)

Player 1 wins 0 games. Player 2 wins 10 games.


In [148]:
Tournament(10, MonteCarloPerfectOppAgent, PerfectAgent, 21, 3)

Player 1 wins 0 games. Player 2 wins 10 games.


# Monte Carlo Tree Search

- One way to make the opponent play closer to perfectly is to evaluate the tree from each player's perspective
- Each player chooses moves to maximize their own reward (equivalently, P1 maximizes and P2 minimizes P1's reward)
- Which moves get explored is balanced by the explore-exploit tradeoff (UCB)
- The move selected is the root move which was explored the most

In [22]:
from IPython.display import Image
Image(url = "https://media.geeksforgeeks.org/wp-content/uploads/mcts_own.png")

In [23]:
print("UCB algorithm:")
Image(url = "https://www.cs.swarthmore.edu/~mitchell/classes/cs63/f20/reading/ucb-1.png")

UCB algorithm:


In [44]:
class Node:
    ''' One position in the Monte Carlo search tree (N matchsticks left, at most K per move). '''
    def __init__(self, parent, N, K, move):
        # links in the tree
        self.parent = parent     # None at the root
        self.child = []          # expanded children, in expansion order

        # the position this node stands for
        self.N = N
        self.K = K
        self.move = move         # the move that led here from parent (None at the root)

        # moves not yet expanded into children; taken from the Nim environment,
        # so a different game only needs a different environment here
        self.validmoves = Nim(N, K).validmoves()

        # search statistics
        self.numselected = 0     # number of times this node has been visited
        self.value = 0           # accumulated backed-up reward

In [45]:
class MonteCarloTreeSearch:
    ''' Monte Carlo Tree Search (UCT) over a tree of Node objects.

    Each call to run() performs one iteration: walk down from the root with
    UCB1 until a node is expanded (or a terminal node is reached), score it
    with one rollout, and backpropagate the result to the root.

    A node's value is accumulated from the point of view of the player who
    made the move into that node, so every node maximizes over its children.
    '''
    def __init__(self, rootnode, rollout, c=1):
        ''' rootnode - Node for the position to search from
            rollout  - function (N, K) -> reward (1 / -1) for the player to move
                       with N matchsticks left
            c        - exploration constant in the UCB1 formula
        '''
        self.root = rootnode
        self.rollout = rollout
        # honour c (it used to be ignored, with the constant hard-coded to 1)
        self.explore = c

    def run(self):
        ''' Perform one select/expand -> rollout -> backpropagate iteration. '''
        node = self.root
        done = False
        while not done:
            node, done = self.select(node)
        # The rollout scores the position for the player to move at `node`;
        # negate it to get the value for the player who moved into `node`.
        value = -self.rollout(node.N, node.K)
        # Backpropagate to the root, flipping the sign at every level because
        # the two players alternate.
        while node is not None:
            node.numselected += 1
            node.value += value
            value = -value
            node = node.parent

    def select(self, curnode):
        ''' Choose the next node to visit from curnode.

        Returns (node, stop): stop is True when node is a freshly expanded
        child or a terminal node (where the rollout should start), False when
        node was picked by UCB1 and the descent should continue.
        '''
        # Expand untried moves first, in order.
        if curnode.validmoves:
            nextmove = curnode.validmoves.pop(0)
            return self.expand(curnode, nextmove), True

        # Nothing left to try and no children: the game is over at this node.
        if not curnode.child:
            return curnode, True

        # Fully expanded: descend using the explore-exploit rule.
        return self.ucb(curnode), False

    def ucb(self, node):
        ''' Return the child of node with the highest UCB1 score
        value/visits + c * sqrt(ln(parent visits) / visits).
        An unvisited child is returned immediately. '''
        maxvalue = -np.inf
        maxnode = None
        for child in node.child:
            if child.numselected == 0:
                return child
            # UCB1 uses the visit count of the node being chosen from, not the root's.
            score = (child.value / child.numselected
                     + self.explore * np.sqrt(np.log(node.numselected) / child.numselected))
            if score > maxvalue:
                maxvalue = score
                maxnode = child
        return maxnode

    def expand(self, curnode, nextmove):
        ''' Expands the node with the next move: create, attach and return the
        child reached by taking nextmove matchsticks. '''
        newchild = Node(curnode, curnode.N - nextmove, curnode.K, nextmove)
        curnode.child.append(newchild)
        return newchild

    def getbestmove(self):
        ''' Return the move of the most-visited root child (ties go to the
        earliest expanded child). Raises ValueError if nothing was explored. '''
        bestchild = max(self.root.child, key=lambda child: child.numselected, default=None)
        if bestchild is None:
            raise ValueError('No moves have been explored; call run() first')
        return bestchild.move

In [46]:
def MCTSAgent(N, K, validmoves, iterations=500, rollout=None):
    ''' Returns the best move determined by MCTS
    Inputs: N - matchsticks currently remaining
            K - maximum matchsticks per move
            validmoves - legal moves in this position
            iterations - number of MCTS iterations to run (default 500)
            rollout - function (N, K) -> reward for the player to move with N
                      matchsticks left; defaults to randomrollout
    Returns: the root move that the search visited the most
    '''
    # With a single legal move there is nothing to search for.
    if len(validmoves) == 1:
        return validmoves[0]
    if rollout is None:
        rollout = randomrollout

    rootnode = Node(None, N, K, None)
    mcts = MonteCarloTreeSearch(rootnode, rollout)
    # Run the MCTS search
    for _ in range(iterations):
        mcts.run()
    # Select the best move
    return mcts.getbestmove()

In [47]:
Tournament(10, RandomAgent, MCTSAgent, 21, 3)

Player 1 wins 0 games. Player 2 wins 10 games.


In [48]:
Tournament(10, MCTSAgent, RandomAgent, 21, 3)

Player 1 wins 10 games. Player 2 wins 0 games.


In [49]:
Tournament(10, MCTSAgent, MCTSAgent, 21, 3)

Player 1 wins 5 games. Player 2 wins 5 games.


In [50]:
Tournament(10, PerfectAgent, MCTSAgent, 21, 3)

Player 1 wins 9 games. Player 2 wins 1 games.


In [231]:
Tournament(10, MCTSAgent, PerfectAgent, 21, 3)

Player 1 wins 0 games. Player 2 wins 10 games.


In [232]:
Tournament(10, MonteCarloAgent, MCTSAgent, 21, 3)

Player 1 wins 0 games. Player 2 wins 10 games.


In [233]:
Tournament(10, MCTSAgent, MonteCarloAgent, 21, 3)

Player 1 wins 10 games. Player 2 wins 0 games.


In [234]:
Tournament(10, MonteCarloPerfectOppAgent, MCTSAgent, 21, 3)

Player 1 wins 9 games. Player 2 wins 1 games.


In [239]:
Game(MonteCarloPerfectOppAgent, MCTSAgent, 21, 3)

Player 1 takes 1 matchstick(s). There are 20 matchstick(s) remaining.
Player 2 takes 3 matchstick(s). There are 17 matchstick(s) remaining.
Player 1 takes 1 matchstick(s). There are 16 matchstick(s) remaining.
Player 2 takes 2 matchstick(s). There are 14 matchstick(s) remaining.
Player 1 takes 1 matchstick(s). There are 13 matchstick(s) remaining.
Player 2 takes 2 matchstick(s). There are 11 matchstick(s) remaining.
Player 1 takes 2 matchstick(s). There are 9 matchstick(s) remaining.
Player 2 takes 2 matchstick(s). There are 7 matchstick(s) remaining.
Player 1 takes 2 matchstick(s). There are 5 matchstick(s) remaining.
Player 2 takes 1 matchstick(s). There are 4 matchstick(s) remaining.
Player 1 takes 3 matchstick(s). There are 1 matchstick(s) remaining.
Player 2 takes 1 matchstick(s). There are 0 matchstick(s) remaining.
Player 1 wins!


1

In [236]:
Tournament(10, MCTSAgent, MonteCarloPerfectOppAgent, 21, 3)

Player 1 wins 0 games. Player 2 wins 10 games.


# Using Value estimates of a state for planning

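- Give each state (number of matchsticks remaining) a value for the player to move
- The value of a state is the best value attainable with a one-step move from that state, i.e. the maximum over moves of the negated value of the resulting state
- Start from the end state (0 matchsticks), which is worth 1 point because the opponent took the last matchstick, and the state with 1 matchstick, which is worth -1 point
- In each state, we choose the move whose next state has the lowest value for the opponent, i.e. the highest negated value (the other player's reward is the negation of ours)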

In [177]:
def train(N, K):
    ''' Build a table of state values for misère Nim by dynamic programming.
    Inputs:
    N - total number of matchsticks
    K - maximum number of matchsticks that can be taken in one move
    Returns:
    V - list of length N+1; V[n] is the value for the player to move with n
        matchsticks left: 1 if they win under perfect play, -1 if they lose
        (-10, the placeholder, if no move is possible, i.e. K < 1)
    '''
    V = [-10] * (N + 1)

    # Base cases, from the point of view of the player to move:
    # 0 matchsticks -> the opponent took the last one, so this is a winning state
    V[0] = 1
    # 1 matchstick -> this player is forced to take the last one: a losing state
    if N > 0: V[1] = -1

    # Fill in larger piles from smaller ones. Taking `move` leaves the opponent
    # at j-move, whose value is the negation of ours; keep the best option.
    for j in range(2, N + 1):
        V[j] = max([V[j]] + [-V[j - move] for move in range(1, min(K, j) + 1)])

    return V

In [178]:
def ValueAgent(N, K, validmoves):
    ''' One-step lookahead on the value table built by train(N, K).

    Picks the move whose resulting position is worst for the opponent (the
    highest negated table value). Ties go to the earliest move in validmoves;
    returns None if there are no valid moves.
    '''
    table = train(N, K)
    # table[n] is the value for the player to move with n matchsticks left,
    # which after our move is the opponent, hence the negation.
    return max(validmoves, key=lambda m: -table[N - m], default=None)

In [179]:
Tournament(10, RandomAgent, ValueAgent, 21, 3)

Player 1 wins 0 games. Player 2 wins 10 games.


In [180]:
Tournament(10, RandomAgent, ValueAgent, 21, 3)

Player 1 wins 0 games. Player 2 wins 10 games.


In [181]:
Tournament(10, PerfectAgent, ValueAgent, 21, 3)

Player 1 wins 0 games. Player 2 wins 10 games.


In [182]:
Tournament(10, ValueAgent, PerfectAgent, 21, 3)

Player 1 wins 0 games. Player 2 wins 10 games.
