# Checkers AI training through self-play

This notebook uses the checkers environment (`checkers_v3`) from the PettingZoo package.

The Monte Carlo tree search (MCTS) and the self-play training loop are based on https://web.stanford.edu/~surag/posts/alphazero.html

In [34]:
from pettingzoo.utils.env import AECEnv
from pettingzoo.classic import checkers_v3
from copy import deepcopy
import numpy as np

## Utils
Let's start by defining some utility classes that are used later in the notebook.

In [35]:
# Game env wrapper for MCTS search
class State:
    """Wrapper around a game environment that MCTS can query, copy and hash.

    Two states compare equal when the same agent is to move and the current
    agent's observations are element-wise identical; the hash is derived from
    a 2-D rendering of that observation (see `toStr`).
    """

    def __init__(self, env : AECEnv):
        self.env = env

    def _last(self):
        """Return (observation, reward, done) as reported by `env.last()`."""
        # `env.last()` must yield exactly four values; the fourth is unused.
        observation, reward, done, _ = self.env.last()
        return observation, reward, done

    def gameEnded(self):
        """Return the `done` flag reported by `env.last()`."""
        return self._last()[2]

    def gameReward(self):
        """Return the reward reported by `env.last()`."""
        return self._last()[1]

    def getActionMask(self):
        """Return the `"action_mask"` entry of the last observation."""
        return self._last()[0]["action_mask"]

    def getValidActions(self):
        """Return the indices of the non-zero entries of the action mask."""
        flat_mask = np.ravel(self.getActionMask())
        return np.nonzero(flat_mask)[0]

    def nextState(self, action):
        """Apply `action` to a deep copy of the environment.

        Returns a `(State, player_changed)` pair; this state is left untouched.
        """
        successor = deepcopy(self.env)
        successor.step(action)
        player_changed = successor.agent_selection != self.env.agent_selection
        return State(successor), player_changed

    def getObservation(self):
        """Return the `"observation"` entry as seen by the agent to move."""
        agent = self.currentAgent()
        return self.env.observe(agent)["observation"]

    def currentAgent(self):
        """Return the environment's currently selected agent."""
        return self.env.agent_selection

    def show(self, wait=False):
        """Render the environment; optionally block until the user hits enter."""
        self.env.render()
        if wait:
            input("press any key to continue")

    def __eq__(self, x):
        if not isinstance(x, State):
            return False
        # Same player to move + identical observation is treated as sufficient.
        boards_equal = (self.getObservation() == x.getObservation()).all()
        if self.env.agent_selection != x.env.agent_selection:
            return False
        return boards_equal

    def toStr(self):
        """Render the observation as the string form of a 2-D array."""
        planes = self.getObservation()
        # Collapse the last axis: (sum over it) * (index of its largest entry + 1).
        occupancy = np.sum(planes, axis = 2)
        plane_code = np.argmax(planes, axis = 2) + 1
        return str(occupancy * plane_code)

    def __hash__(self):
        return hash(self.toStr())


In [36]:
class TrainingExample:
    """Plain record bundling a state, a policy vector and a reward."""

    def __init__(self, state : State, pi, reward):
        # `reward` may be a placeholder that is filled in later by assignment.
        self.state, self.pi, self.reward = state, pi, reward


## The neural net

In [52]:
from typing import List
import random
from tensorflow.keras import layers, Model, Input, metrics

def applyActionMaskToPolicy(p, action_mask):
    """Zero out illegal actions in a policy and renormalize it to sum to 1.

    Args:
        p: policy values, either one per action or a scalar that broadcasts
            against `action_mask`.
        action_mask: array with non-zero entries for legal actions.

    Returns:
        Array shaped like `action_mask` that sums to 1 and is zero on every
        masked-out action. If the policy assigns zero mass to all legal
        actions, the result is uniform over the legal actions.
        If the mask itself has no legal action the division yields NaNs.
    """
    action_mask = np.asarray(action_mask, dtype=float)
    p_masked = p * action_mask
    # The policy zeroed every legal action: fall back to the mask itself,
    # i.e. a uniform distribution over legal actions after normalization.
    if np.sum(p_masked) == 0:
        p_masked = action_mask

    return p_masked / np.sum(p_masked) # renormalize


class NNet:
    """Small convolutional network with a policy head and a value head.

    The policy head (`pi`) is a softmax over `action_size` actions; the value
    head (`v`) is a single linear unit.
    """

    def __init__(self, action_size, input_shape=(8, 8, 4)):
        """Build and compile the model.

        Args:
            action_size: number of units of the policy head.
            input_shape: shape of one observation (defaults to the 8x8x4
                observation used in this notebook).
        """
        x = Input(shape=input_shape)
        y = layers.Conv2D(16, 3, activation='relu')(x)
        y = layers.Conv2D(16, 3, activation='relu')(y)
        y = layers.Conv2D(16, 3, strides=2, activation='relu')(y)
        y = layers.Flatten()(y)
        y = layers.Dropout(0.5)(y)
        p = layers.Dense(action_size, activation='softmax', name="pi")(y)
        v = layers.Dense(1, name="v")(y)
        self.nnet = Model(x, [p,v])
        # summary() prints by itself and returns None, so it is not wrapped in print().
        self.nnet.summary()
        self.nnet.compile(
                optimizer='rmsprop',
                loss=["categorical_crossentropy","mean_squared_error"],
                # One metric list per output, in output order (pi, v), so that
                # each head is only scored with the metric that matches its loss.
                metrics=[[metrics.CategoricalCrossentropy()], [metrics.MeanSquaredError()]]
        )

    def predict(self, state : State):
        """Return `(policy, value)` for a single state.

        The policy is masked with the state's action mask and renormalized;
        the value is the scalar output of the value head.
        """
        x = state.getObservation()
        x = np.expand_dims(x,0)
        # verbose=0: this is called once per tree-search expansion, a progress
        # bar per call would flood the notebook output.
        p, v = self.nnet.predict(x, batch_size=1, verbose=0)

        p = p[0]
        p = applyActionMaskToPolicy(p, state.getActionMask())
        return p, v[0][0]

    @staticmethod
    def _prepare_examples(examples: List[TrainingExample]):
        """Turn examples into `(X, [pi, v])` arrays in a random order.

        The caller's sequence is not modified; a shuffled copy is used.
        """
        shuffled = list(examples)
        random.shuffle(shuffled)

        X = [e.state.getObservation() for e in shuffled]
        pi = [e.pi for e in shuffled]
        v = [e.reward for e in shuffled]

        return np.array(X), [np.array(pi), np.array(v)]

    def train(self, examples):
        """Fit the model on `examples` for one pass and return `self`.

        Training happens in place: the returned object is this very instance,
        not a copy. With no examples the call is a no-op.
        """
        if len(examples) == 0:
            return self
        X, y = self._prepare_examples(examples)
        self.nnet.fit(X, y, batch_size=32)
        return self

In [38]:
class RandomPlayer:
    """Baseline player exposing the same `predict` interface as the network."""

    def predict(self, state : State):
        """Return a uniform policy over the legal actions and a value of 0.

        Sampling an action from the returned policy picks a legal move
        uniformly at random. The policy length follows the state's action
        mask instead of a hard-coded action count.
        """
        # The previous `np.random.uniform(256)` produced a single scalar
        # (256 was taken as `low`, not as a size), which after masking and
        # normalization amounted to exactly this uniform distribution.
        mask = np.asarray(state.getActionMask(), dtype=float)
        return mask / np.sum(mask), 0


Check that it works: create an environment, wrap it in a `State`, and run the network's `predict`.

In [39]:
# Create a checkers environment, reset it and print the starting board.
env = checkers_v3.env()
env.reset()
env.render()

  M   M   M   M 
M   M   M   M   
  M   M   M   M 
_   _   _   _   
  _   _   _   _ 
m   m   m   m   
  m   m   m   m 
m   m   m   m   


In [40]:
# Wrap the environment and look at the shape of the current agent's observation.
state = State(env)
state.getObservation().shape

(8, 8, 4)

In [41]:
# 256 is the number of units of the policy head (one per action slot).
nnet = NNet(256)

Model: "model_3"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_4 (InputLayer)           [(None, 8, 8, 4)]    0           []                               
                                                                                                  
 conv2d_9 (Conv2D)              (None, 6, 6, 16)     592         ['input_4[0][0]']                
                                                                                                  
 conv2d_10 (Conv2D)             (None, 4, 4, 16)     2320        ['conv2d_9[0][0]']               
                                                                                                  
 conv2d_11 (Conv2D)             (None, 1, 1, 16)     2320        ['conv2d_10[0][0]']              
                                                                                            

In [42]:
# Run the network on the wrapped state: masked policy vector and scalar value.
policy, value = nnet.predict(state)
print(policy.shape)
print(value)

(256,)
-0.08509848


In [43]:
# Same check for the baseline player, which always reports a value of 0.
policy, value = RandomPlayer().predict(state)
print(policy.shape)
print(value)

(256,)
0


## MCTS

In [44]:
class MCTS:
    """Monte Carlo tree search guided by a policy/value predictor.

    `nnet` only needs a `predict(state)` method returning `(policy, value)`.
    Statistics are stored per state in dictionaries keyed by the state, so
    states must be hashable and comparable.
    """

    # Added under the square root of the visit count so that the exploration
    # term is non-zero before any action of a state has been tried; without
    # it every score is 0 on the first selection and the prior is ignored.
    _EPS = 1e-8

    def __init__(self, nnet, num_mcts_sims, max_depth = 10, c_puct = 1.0):
        """
        Args:
            nnet: predictor used to expand unseen states.
            num_mcts_sims: number of simulations run by each `search` call.
            max_depth: maximum number of moves followed in one simulation.
            c_puct: weight of the exploration term in the selection score.
        """
        self.nnet = nnet
        # number of times each action has been tried in a given state
        self.N = {}
        # policy predicted for each state
        self.P = {}
        # running mean of the values backed up through each (state, action)
        self.Q = {}
        # value predicted for each state when it was first expanded
        self.predicted_v = {}
        # set of states that have already been expanded
        self.visited = set()
        # exploration constant of the selection score
        self.c_puct = c_puct

        self.num_mcts_sims = num_mcts_sims
        self.max_depth = max_depth

    def search(self, s : State):
        """Run `num_mcts_sims` simulations starting from `s`."""
        for _ in range(self.num_mcts_sims):
            self._search(s, self.max_depth)

    def _search(self, s : State, max_depth):
        """Run one simulation from `s` and return the value backed up for `s`.

        Terminal states return their game reward, unseen states are expanded
        and return the predicted value, and already seen states at depth 0
        return the value predicted when they were expanded.
        """
        if s.gameEnded(): return s.gameReward()

        if s not in self.visited:
            # Expansion: query the predictor once and initialise the statistics.
            self.visited.add(s)
            pi, v = self.nnet.predict(s)
            pi = np.asarray(pi)
            self.predicted_v[s] = v
            self.P[s] = pi
            self.N[s] = np.zeros(len(pi))
            self.Q[s] = np.zeros(len(pi))
            return v

        if max_depth == 0:
            print(f"max depth reached!, a heuristic value of this state {self.predicted_v[s]}")
            return self.predicted_v[s]

        # Selection: score all legal actions at once; argmax keeps the first
        # action among equal scores.
        valid = np.asarray(s.getValidActions())
        total_visits = np.sum(self.N[s])
        u = self.Q[s][valid] + self.c_puct * self.P[s][valid] * np.sqrt(total_visits + self._EPS) / (1 + self.N[s][valid])
        a = valid[np.argmax(u)]

        sp, player_changed = s.nextState(a)
        v = self._search(sp, max_depth - 1)
        # The child's value is from the point of view of the player to move
        # there; flip the sign when that is not the player to move in `s`.
        if player_changed:
            v = -v

        # Backup: incremental mean of the values seen through (s, a).
        self.Q[s][a] = (self.N[s][a] * self.Q[s][a] + v) / (self.N[s][a] + 1)
        self.N[s][a] += 1
        return v

    # improved policy
    def pi(self, s : State):
        """Return the visit-count distribution for `s`.

        Falls back to the predicted policy when no action of `s` has been
        tried yet. `s` must have been expanded by a previous search.
        """
        summed_n = np.sum(self.N[s])
        if summed_n == 0:
            return self.P[s]

        return self.N[s] / summed_n


In [30]:
# Tiny search budget (2 simulations per search call) for a quick smoke test.
mcts = MCTS(nnet, 2)

In [31]:
# Run the simulations from the wrapped starting state.
mcts.search(state)

In [32]:
# Shape of the policy derived from the search statistics.
mcts.pi(state).shape

(256,)

In [60]:
def pit(new_nnet : NNet, nnet : NNet, games_played = 10):
    """Play `new_nnet` against `nnet` and return the win rate of `new_nnet`.

    Both players only need a `predict(state)` method returning a policy and
    a value; moves are sampled from the policy. Sides are swapped before
    every game.

    Returns:
        Wins of `new_nnet` divided by the number of games that did not end
        in a tie. When there is no such game (all ties, or no game played)
        the neutral value 0.5 is returned instead of dividing by zero.
    """
    new_nnet_tag = "player_0"
    nnet_tag = "player_1"
    wins = 0
    ties = 0

    for _game in range(games_played):
        env = checkers_v3.env()
        env.reset()
        s = State(env)
        # swap players before each round
        new_nnet_tag, nnet_tag = nnet_tag, new_nnet_tag
        agents = {new_nnet_tag : new_nnet, nnet_tag : nnet}

        while not s.gameEnded():
            agent = agents[s.currentAgent()]
            p, _ = agent.predict(s)
            action = np.random.choice(len(p), p=p)
            s.env.step(action)

        # Read the final reward once, together with the agent it is reported for.
        reward = s.gameReward()
        new_net_is_current = s.currentAgent() == new_nnet_tag
        if reward == 0:
            ties += 1
        elif (reward == 1 and new_net_is_current) or (reward == -1 and not new_net_is_current):
            wins += 1

    decisive_games = games_played - ties
    if decisive_games <= 0:
        # No decisive game: neither player showed an advantage.
        return 0.5
    return wins / decisive_games

# training
def policyIterSP(env : AECEnv, num_iters = 4, num_eps = 2, frac_win_thresh = 0.55):
    """Train a network through self-play policy iteration.

    Each iteration collects `num_eps` self-play episodes, trains a candidate
    network on the accumulated examples and keeps the candidate only if its
    win rate against the current network exceeds `frac_win_thresh`.

    Args:
        env: environment used for the self-play episodes.
        num_iters: number of train/evaluate iterations.
        num_eps: self-play episodes collected per iteration.
        frac_win_thresh: win rate the candidate must exceed to be accepted.

    Returns:
        The last accepted network.
    """
    # hard coded action space size
    nnet = NNet(256)
    frac_win = pit(nnet, RandomPlayer())                              # compare new net with a random player
    print("frac_wins against a random player", frac_win)
    examples = []
    for _ in range(num_iters):
        for _ in range(num_eps):
            examples += executeSelfPlayEpisode(env, nnet)             # collect examples from this game
            print("episode done")
        # Train a separate copy: `train` works in place and returns the same
        # object, so without the copy the candidate and the current net would
        # be one network (pitted against itself, and "rejecting" it would
        # still keep the trained weights).
        # NOTE(review): relies on the wrapped Keras model being deep-copyable,
        # which recent TF/Keras releases support - confirm for the pinned version.
        new_nnet = deepcopy(nnet)
        new_nnet.train(examples)
        frac_win = pit(new_nnet, nnet)                                # compare new net with previous net
        print("frac_win", frac_win)
        if frac_win > frac_win_thresh:
            print("new net is better!")
            nnet = new_nnet                                           # replace with new net
            frac_win = pit(nnet, RandomPlayer())                      # compare new net with a random player
            print("frac_wins against a random player", frac_win)
        # keep a random half of the examples for the next iteration
        examples = random.sample(examples, len(examples) // 2)
    return nnet

def executeSelfPlayEpisode(env : AECEnv, nnet, num_mcts_sims = 10):
    """Play one game against itself and return the collected training examples.

    The environment is reset first. At every move a tree search is run from
    the current state, the resulting policy is recorded and an action is
    sampled from it. Rewards are filled in once the game has ended.
    """
    env.reset()
    tree = MCTS(nnet, num_mcts_sims)
    history = []
    current = State(env)

    while True:
        tree.search(current)
        policy = tree.pi(current)
        # The reward is unknown until the game is over, hence the None placeholder.
        history.append(TrainingExample(deepcopy(current), policy, None))
        # Sample the move from the search policy.
        action = np.random.choice(len(policy), p=policy)
        current, _ = current.nextState(action)
        if not current.gameEnded():
            continue
        return assignRewards(history, current.gameReward(), current.currentAgent())

def assignRewards(examples, reward, player_w_reward):
    """Fill in the reward of every example, in place, and return the list.

    Examples whose state has `player_w_reward` as current agent receive
    `reward`; all other examples receive `-reward`.
    """
    for example in examples:
        if example.state.currentAgent() == player_w_reward:
            example.reward = reward
        else:
            example.reward = -reward

    return examples

## Run training!

In [61]:
# Train with the default settings (4 iterations of 2 self-play episodes each).
env = checkers_v3.env()
nnet = policyIterSP(env)

Model: "model_9"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_10 (InputLayer)          [(None, 8, 8, 4)]    0           []                               
                                                                                                  
 conv2d_27 (Conv2D)             (None, 6, 6, 16)     592         ['input_10[0][0]']               
                                                                                                  
 conv2d_28 (Conv2D)             (None, 4, 4, 16)     2320        ['conv2d_27[0][0]']              
                                                                                                  
 conv2d_29 (Conv2D)             (None, 1, 1, 16)     2320        ['conv2d_28[0][0]']              
                                                                                            

KeyboardInterrupt: 