# MCTS for Tic Tac Toe

Building a simple MCTS for Tic-tac-toe. No ML only hard computation.

In [46]:
# Building your own game is a time-consuming process, good if you want to practice, not otherwise.
# So I am using the code from below link
# this is a 3D tic-tac-toe meaning that for a normal 2D TTT the number of possible
# states = 3 ** 9 = 19683
# In the 3D TTT, 3 ** 27 = 7625597484987 (76.25Bn)

!mkdir envs
!wget https://raw.githubusercontent.com/shkreza/gym-tictactoe3d/master/gym_tictactoe/envs/tictactoe_env.py -O envs/env3d.py
!wget https://raw.githubusercontent.com/haje01/gym-tictactoe/master/gym_tictactoe/env.py -O envs/env2d.py

--2021-04-05 11:28:06--  https://raw.githubusercontent.com/shkreza/gym-tictactoe3d/master/gym_tictactoe/envs/tictactoe_env.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5754 (5.6K) [text/plain]
Saving to: ‘envs/env3d.py’


2021-04-05 11:28:06 (2.67 MB/s) - ‘envs/env3d.py’ saved [5754/5754]

--2021-04-05 11:28:06--  https://raw.githubusercontent.com/haje01/gym-tictactoe/master/gym_tictactoe/env.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5945 (5.8K) [text/plain]
Saving to: ‘envs/env2d.py’


2021-04-05 11:28:06 (16.0 MB/

In [69]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import random
from envs import env2d
from tqdm import trange
from copy import deepcopy
import numpy as np

In [2]:
# we also need to provide mask for legal moves
legal_moves_2d = lambda board: np.arange(len(board))[np.asarray(board) == 0].tolist()

In [16]:
# test 2D TTT environment
env = env2d.TicTacToeEnv()
env.reset()
env.render()
done = False
for i in range(10):
  l = legal_moves_2d(env.board)
  a = random.choice(l)
  print("-"*30 + f" {i} {a} {l} {done} " + "-"*30)
  observation, _, done, _ = env.step(a)
  env.render()
  if done:
    break

   | | 
  -----
   | | 
  -----
   | | 

------------------------------ 0 7 [0, 1, 2, 3, 4, 5, 6, 7, 8] False ------------------------------
   | | 
  -----
   | | 
  -----
   |O| 

------------------------------ 1 6 [0, 1, 2, 3, 4, 5, 6, 8] False ------------------------------
   | | 
  -----
   | | 
  -----
  X|O| 

------------------------------ 2 3 [0, 1, 2, 3, 4, 5, 8] False ------------------------------
   | | 
  -----
  O| | 
  -----
  X|O| 

------------------------------ 3 0 [0, 1, 2, 4, 5, 8] False ------------------------------
  X| | 
  -----
  O| | 
  -----
  X|O| 

------------------------------ 4 8 [1, 2, 4, 5, 8] False ------------------------------
  X| | 
  -----
  O| | 
  -----
  X|O|O

------------------------------ 5 5 [1, 2, 4, 5] False ------------------------------
  X| | 
  -----
  O| |X
  -----
  X|O|O

------------------------------ 6 2 [1, 2, 4] False ------------------------------
  X| |O
  -----
  O| |X
  -----
  X|O|O

------------------------------ 7 4 

In [17]:
l

[1]

In [12]:
env.done

False

In [113]:
class Node:
  def __init__(self, state, move, value = 0, root = False, exploration_constant = 5.0):
    self.n = 0 # number of times this node was visited, initialised with 0
    self.children = []
    
    self.state = state
    self.move = move
    self.q_value = value
    self.is_root = root
    
    self.c = exploration_constant
    self.wins = 0
    
  def __eq__(self, n):
    return self.move == n.move and len(self) == len(n)
    
  @property
  def terminal(self):
    return len(self.children) == 0
    
  @property
  def total_nodes(self):
    n = 1
    for c in self.children:
      n += c.total_nodes
    return n

  def all_children(self):
    c = []
    if self.children:
      for child in self.children:
        c.extend(child.all_children())
    else:
      c = [self]
    return c
  
  def __len__(self):
    return len(self.children)
  
  def __repr__(self):
    return f"<Move '{self.move}'; q={self.q_value:.3f} c={len(self)}; n={self.n}; s={self.state}>"

  def __str__(self, level=0):
    ret = "  "*level+repr(self)+"\n"
    for child in self.children:
      ret += child.__str__(level+1)
    ret = ret[:-1] if ret[-1] == "\n" else ret
    return ret
  
  def get_uct_value(self, total_n):
    n = max(self.n, 1e-4) # when 0 use 1e-4 as the value of n
    exploitation_value = self.q_value / n
    exploration_value = np.sqrt(2 * np.log(total_n) / n)
    value = exploitation_value + self.c * exploration_value
    return value

In [114]:
def play_game(env, root_node, n_steps = 4):
  env.reset()
  child = root_node
  for i in range(n_steps):
    l = legal_moves_2d(env.board)
    a = random.choice(l)

    # create node and append at the last
    cnode = Node(env.board.copy(), a)
    child.children.append(cnode)
    child = cnode
    
    env.step(a)
  return env, child

In [115]:
def selection(env, node, steps_taken):
  # function takes in the node and current environment and takes the greedy action on state_value
  legal_moves = legal_moves_2d(env.board)
  # init_p = np.ones_like(legal_moves) / len(legal_moves) # uniform initial probability
  if node.terminal:
    # this is the leaf node and nothing is present as of now
    return node, -1

  # else: find the node with maximum UCT value
  maximum_uct_value = -1000 # start with negative for least value
  best_node = None
  for c in node.children:
    action_value = c.get_uct_value(steps_taken)
    if action_value > maximum_uct_value:
      maximum_uct_value = action_value
      best_node = c
      
  return best_node, maximum_uct_value

def expansion(env, node):
  if env.done:
    # this is the leaf node for this search
    return node

  # else: take all possible actions
  legal_moves = legal_moves_2d(env.board)
  root_node = node
  for mv in legal_moves:
    env2 = deepcopy(env)
    _,_,done,_ = env2.step(mv)
    child_node = Node(env2.board, mv)
    if child_node in root_node.children:
      already_present_child = list(filter(lambda x: x == child_node, root_node.children))[0]
      already_present_child.n += 1
    else:
      root_node.children.append(child_node)

  # return any random node
  return root_node.children[random.randint(0, len(root_node) - 1)]


def simulation(env, node, steps_taken):
  # run a random policy till game ends from this node
  steps_taken += 1
  env2 = deepcopy(env)
  done = env2.done
  _steps = 0
  while not done:
    _steps += 1 # increment the step counter
    l = legal_moves_2d(env2.board)
    a = l[random.randint(0, len(l) - 1)]
    _,_,done,_ = env2.step(a)
    
  # if _steps is even means the player at node won else opponent
  # won. Another check is to see whether this game was draw or not.
  l = legal_moves_2d(env2.board)
  if len(l) == 0:
    # there was no move left
    result = "draw"
  else:
    if _steps % 2 == 0:
      result = "win"
    else:
      result = "loss"

  return env2.board, result


def backprop(node, result):
  reward = {"draw": 0, "loss": -1, "win": +1}[result]
  for c in node.children:
    c.n += 1 # increment node count visit
    c.wins += reward
    c.q_value = c.wins/c.n

In [147]:
# let's start writing code for MCTS

# define values
N_SIMS = 100 # number of simulations to run for each step
steps_taken = 1
N_STEPS = 10

# we first reinit the environment for consistency
env.reset()
root_node = Node(env.board, "[S]", root = True) # yes this is root
# play a few steps
env, root_node = play_game(env, root_node, 2)
env.render()
for i in range(N_STEPS):
  print("Starting board:", env.board)

  # perform tree search
  for i in range(N_SIMS):
    child_node, maximum_uct_value = selection(env, root_node, steps_taken)
    child_node = expansion(env, child_node)
    board, result = simulation(env, node, steps_taken)
    backprop(root_node, result)


  # select the best action
  l = legal_moves_2d(env.board)
  print(l)
  best_q = -100
  for mv in root_node.children:
    q = mv.q_value
    if q > best_q:
      best_q = q
      best_action = mv

  print("Best Action:", mv, "\n\n\n")
  root_node = mv
  _, _, done, _ = env.step(mv.move)
  if done:
    break

  

   | |O
  -----
   | | 
  -----
  X| | 

Starting board: [0, 0, 1, 0, 0, 0, 2, 0, 0]
[0, 1, 3, 4, 5, 7, 8]
Best Action: <Move '8'; q=-0.130 c=0; n=100; s=[0, 0, 1, 0, 0, 0, 2, 0, 1]> 



Starting board: [0, 0, 1, 0, 0, 0, 2, 0, 1]
[0, 1, 3, 4, 5, 7]
Best Action: <Move '7'; q=0.170 c=0; n=100; s=[0, 0, 1, 0, 0, 0, 2, 2, 1]> 



Starting board: [0, 0, 1, 0, 0, 0, 2, 2, 1]
[0, 1, 3, 4, 5]
Best Action: <Move '5'; q=-0.530 c=0; n=100; s=[0, 0, 1, 0, 0, 1, 2, 2, 1]> 





In [148]:
env.render()

   | |O
  -----
   | |O
  -----
  X|X|O



In [163]:
# converting the code above to a function
def play_compute_move(env, last_action, n_sims = 100, verbose = False):
  root_node = Node(env.board, last_action)
  if verbose: env.render()
  if verbose: print("Starting board:", env.board)

  # perform tree search
  for i in range(n_sims):
    child_node, maximum_uct_value = selection(env, root_node, steps_taken)
    child_node = expansion(env, child_node)
    board, result = simulation(env, node, steps_taken)
    backprop(root_node, result)

  # select the best action
  l = legal_moves_2d(env.board)
  if verbose: print(l)
  best_q = -100
  for mv in root_node.children:
    q = mv.q_value
    if q > best_q:
      best_q = q
      best_action = mv

  if verbose: print("Best Action:", best_action, "\n\n\n")
  root_node = best_action
  _, _, done, _ = env.step(best_action.move)
  return env, done, best_action

In [165]:
env.reset()
done = False
while True:
  print(legal_moves_2d(env.board))
  action = int(input("Your move human >>> "))
  _, _, done, _ = env.step(action)
  print("----", action, "----")
  env.render()
  
  _draw = len(legal_moves_2d(env.board)) == 0
  if _draw:
    print("Game is a draw")
    break
  
  if done:
    print("This time you win human! >:(")
    break
  
  env, done, best_action = play_compute_move(env, action, 1000)
  print("----", best_action.move, "----")
  env.render()
  
  _draw = len(legal_moves_2d(env.board)) == 0
  if _draw:
    print("Game is a draw")
    break
   
  if done:
    print("Fucking Loser >:)")
    break

[0, 1, 2, 3, 4, 5, 6, 7, 8]
Your move human >>> 4
---- 4 ----
   | | 
  -----
   |O| 
  -----
   | | 

---- 0 ----
  X| | 
  -----
   |O| 
  -----
   | | 

[1, 2, 3, 5, 6, 7, 8]
Your move human >>> 1
---- 1 ----
  X|O| 
  -----
   |O| 
  -----
   | | 

---- 2 ----
  X|O|X
  -----
   |O| 
  -----
   | | 

[3, 5, 6, 7, 8]
Your move human >>> 7
---- 7 ----
  X|O|X
  -----
   |O| 
  -----
   |O| 

This time you win human! >:(
