<a href="https://colab.research.google.com/github/sugiyama404/AlphaGo/blob/main/AlphaZero_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import math
import numpy as np
from datetime import datetime
from math import sqrt

from tensorflow.keras import Model
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import LearningRateScheduler, LambdaCallback
from tensorflow.keras.layers import Activation, Add, BatchNormalization, Conv2D, Dense, GlobalAvgPool2D, Input
from tensorflow.keras.models import load_model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import plot_model, Progbar
from tensorflow.keras.optimizers import SGD

from dataclasses import dataclass, field
import random, string

In [2]:
class State:
    """Tic-tac-toe position seen from the side whose turn it is.

    ``pieces`` marks the mover's stones and ``enemy_pieces`` the opponent's;
    each is a length-9 array in which 1 means the square is occupied.
    """

    # Every index triple that forms three-in-a-row on the 3x3 board.
    _WIN_LINES = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
                           [0, 3, 6], [1, 4, 7], [2, 5, 8],
                           [0, 4, 8], [2, 4, 6]])

    def __init__(self, pieces=None, enemy_pieces=None):
        self.pieces = np.zeros(9) if pieces is None else pieces
        self.enemy_pieces = np.zeros(9) if enemy_pieces is None else enemy_pieces

    def step(self, action):
        """Place the mover's stone on ``action`` and pass the turn.

        Returns a new State in which the roles are swapped, so the opponent
        becomes the mover. ``self`` is left untouched.
        """
        placed = np.copy(self.pieces)
        placed[action] = 1
        return State(self.enemy_pieces, placed)

    def judgment_lose(self):
        """Return True when the opponent already owns a complete line."""
        owned = set(np.flatnonzero(self.enemy_pieces == 1))
        return any(set(line) <= owned for line in self._WIN_LINES)

    def terminal(self):
        """Return True when the mover has lost or the board is full."""
        stones = (np.count_nonzero(self.pieces == 1)
                  + np.count_nonzero(self.enemy_pieces == 1))
        return self.judgment_lose() or stones == 9

    def possible_actions(self):
        """Return the indices of the empty squares in ascending order."""
        return np.flatnonzero(self.pieces + self.enemy_pieces == 0)

    def first_attack(self):
        """Return True when both sides have the same number of empty slots.

        Moves alternate, so this holds exactly when the mover is the player
        who moved first.
        """
        return (np.count_nonzero(self.pieces == 0)
                == np.count_nonzero(self.enemy_pieces == 0))

In [3]:
class Brain:
    """Dual-head residual CNN (policy + value) used to guide the tree search.

    Input is a (3, 3, 2) board tensor: channel 0 holds the mover's stones and
    channel 1 the opponent's. The network outputs a softmax over the 9
    squares and a tanh value.
    """

    def __init__(self, loadmodel = False):
        """Build a fresh network, or load './alphazero.h5' when ``loadmodel``."""
        self.filters = 128
        self.hidden_layers = 16
        self.obs_shape = (3, 3, 2)
        self.nn_actions = 9
        self.kr = l2(0.0005)
        self.opt = SGD(learning_rate = 0.001, momentum = 0.9)

        if not loadmodel:
            self._main_network_layer()
        else:
            self._load()

    def _main_network_layer(self):
        """Build and compile the conv stem, residual tower and both heads."""
        inputs = Input(shape=self.obs_shape)
        x = self._conv_layer(self.filters)(inputs)
        for _ in range(self.hidden_layers):
            x = self._residual_layer()(x)
        x = GlobalAvgPool2D()(x)

        p = Dense(self.nn_actions, kernel_regularizer=self.kr, activation='softmax')(x)
        v = Dense(1, kernel_regularizer=self.kr, activation='tanh')(x)

        model = Model(inputs=inputs, outputs=[p, v])
        model.compile(loss=['categorical_crossentropy', 'mse'], optimizer=self.opt)
        self.model = model

        dot_img_file = './alphazero_model.png'
        try:
            plot_model(self.model, to_file=dot_img_file, show_shapes=True)
        except ImportError as err:
            # The diagram needs pydot + graphviz; the model itself is usable
            # without it, so don't abort construction over a picture.
            print(f'Skipping model diagram ({err})')

    def _residual_layer(self):
        """Return a function applying conv-BN-ReLU, conv-BN, skip-add, ReLU."""
        def f(input_block):
            x = self._conv_layer(self.filters)(input_block)
            x = self._conv_layer(self.filters, join_act = False)(x)
            x = Add()([x, input_block])
            x = Activation('relu')(x)
            return x
        return f

    def _conv_layer(self, filters, join_act = True):
        """Return a function applying 3x3 conv + BN (+ ReLU when ``join_act``)."""
        def f(input_block):
            x = Conv2D(filters, 3, padding='same', use_bias=False,
                       kernel_initializer='he_normal',
                       kernel_regularizer=self.kr)(input_block)
            x = BatchNormalization()(x)
            if join_act:
                x = Activation('relu')(x)
            return x
        return f

    def predict(self, state):
        """Evaluate ``state`` with the network.

        Returns ``(policies, value)``:

        - ``policies`` holds the network's move probabilities for
          ``state.possible_actions()``, renormalised to sum to 1. They are
          left as zeros if they all vanish.
        - ``value`` is the value-head output.
        """
        a, b, c = self.obs_shape
        cc = np.concatenate([state.pieces, state.enemy_pieces])
        # (2*9,) -> (channels, rows, cols) -> channels-last batch of one.
        cc = cc.reshape(c, a, b).transpose(1, 2, 0).reshape(1, a, b, c).astype(np.float32)

        # Direct __call__ avoids Model.predict's per-call setup cost, which
        # dominates for single positions evaluated inside MCTS;
        # training=False keeps BatchNormalization in inference mode.
        policy_out, value_out = self.model(cc, training=False)

        policies = np.asarray(policy_out)[0][state.possible_actions()]
        total = policies.sum()
        if total:
            policies = policies / total

        value = np.asarray(value_out)[0][0]
        return policies, value

    def save(self):
        """Write the model to './alphazero.h5'."""
        self.model.save('./alphazero.h5')

    def _load(self):
        """Load the model from './alphazero.h5'."""
        self.model = load_model('./alphazero.h5')

    def train(self, trajectory, epochs=100, batch_size=128):
        """Fit the network on the samples collected in ``trajectory``.

        ``trajectory.get_trajectory()`` must return ``(pieces, policies,
        values)``. ``pieces`` must be reshapeable to (N, 2, 3, 3). An empty
        trajectory is a no-op.
        """
        piece, policy, value = trajectory.get_trajectory()
        if len(piece) == 0:
            return  # nothing collected yet; fit() would fail on empty data
        a, b, c = self.obs_shape
        piece = piece.reshape(len(piece), c, a, b).transpose(0, 2, 3, 1)
        self.model.fit(piece, [policy, value], batch_size=batch_size,
                       epochs=epochs, verbose=0)

In [4]:
class MonteCarloTreeSearch:
    """Run PUCT tree search from a position and turn the visit counts into move probabilities."""

    def __init__(self, brain, search_num=50, temp=1.0):
        self.brain = brain
        self.search_num = search_num  # simulations per search() call
        self.temp = temp  # Boltzmann temperature; 0 selects the most-visited move

    def search(self, state):
        """Return one probability per move, in ``state.possible_actions()`` order.

        Probabilities come from the root children's visit counts after
        ``search_num`` simulations, reshaped by ``temp``.
        """
        root = Node(state, 0, self.brain)
        for _ in range(self.search_num):
            root.evaluate()

        visits = np.array([child.n for child in root.child_nodes], dtype=float)
        return self._boltzman_distribution(visits)

    def _boltzman_distribution(self, ps):
        """Normalise ``ps ** (1 / temp)`` into a probability vector.

        ``temp == 0`` returns a one-hot vector on the largest entry (first on
        ties) instead of dividing by zero. All-zero counts fall back to a
        uniform distribution instead of producing NaN.
        """
        ps = np.asarray(ps, dtype=float)
        if self.temp == 0:
            greedy = np.zeros_like(ps)
            greedy[np.argmax(ps)] = 1.0
            return greedy
        ps = ps ** (1 / self.temp)
        total = np.sum(ps)
        if total == 0:
            return np.full_like(ps, 1.0 / len(ps))
        return ps / total

In [5]:
class Node:
    """One search-tree node holding a position and its PUCT statistics.

    ``w`` and ``n`` accumulate value and visit count. The values are those
    returned by ``evaluate`` for this node's ``state``. ``p`` is the prior
    probability the parent assigned to reaching this node.
    """

    def __init__(self, state, p, brain):
        self.state = state
        self.p = p
        self.w = 0
        self.n = 0
        self.c_puct = 1.0
        self.obs_shape = (3, 3, 2)  # not used by Node; kept for compatibility
        self.child_nodes = None  # list of child Nodes once expanded
        self.brain = brain

    def evaluate(self):
        """Run one simulation through this node and return its value.

        - Terminal positions score -1 when ``state.judgment_lose()`` is true,
          and 0 otherwise.
        - An unexpanded node is expanded with the network's priors and
          scored with the network's value.
        - Otherwise the best-PUCT child is simulated and its value negated,
          because the child's value is for the other side.
        """
        if self.state.terminal():
            value = -1 if self.state.judgment_lose() else 0
        elif self.child_nodes is None:
            policies, value = self.brain.predict(self.state)
            # A plain list avoids the quadratic np.append rebuild and the
            # silent float->object dtype coercion of the previous version.
            self.child_nodes = [
                Node(self.state.step(action), policy, self.brain)
                for action, policy in zip(self.state.possible_actions(), policies)
            ]
        else:
            value = -self._move_to_leaf().evaluate()

        self.w += value
        self.n += 1
        return value

    def _move_to_leaf(self):
        """Return the child with the highest PUCT score (the first one on ties)."""
        total_visits = sum(child.n for child in self.child_nodes)
        scores = [self._puct_value(child, total_visits) for child in self.child_nodes]
        return self.child_nodes[int(np.argmax(scores))]

    def _puct_value(self, c, t):
        """PUCT score of child ``c`` given ``t`` total visits over all children.

        The exploitation term negates the child's mean value because the child
        is scored from the opponent's side. Unvisited children get 0.
        """
        return (-c.w / c.n if c.n else 0.0) + self.c_puct * c.p * sqrt(t) / (1 + c.n)

In [6]:
@dataclass
class Trajectory:
    """Training samples accumulated over one or more self-play games.

    Each sample is:

    - a position (two 9-square planes: mover, then opponent);
    - a 9-way policy target;
    - a value target;
    - the ``code`` id of the game the sample came from.
    """
    # default_factory gives every instance its own arrays. A bare ndarray
    # default is rejected by dataclasses on Python 3.11+ (unhashable default)
    # and would otherwise be shared between instances.
    piece: np.ndarray = field(default_factory=lambda: np.empty((0, 2, 9), float))
    policy: np.ndarray = field(default_factory=lambda: np.empty((0, 9), float))
    value: np.ndarray = field(default_factory=lambda: np.array([], int))
    code: np.ndarray = field(default_factory=lambda: np.array([], str))

    def reset_trajectory(self):
        """Discard every collected sample."""
        self.piece = np.empty((0, 2, 9), float)
        self.policy = np.empty((0, 9), float)
        self.value = np.array([], int)
        self.code = np.array([], str)

    def set_trajectory(self, pieces, enemy_pieces, policy, value, code):
        """Append one sample: position planes, policy target, value target, game id."""
        cc = np.concatenate([pieces, enemy_pieces]).reshape(1, 2, 9)
        self.piece = np.append(self.piece, cc, axis=0)
        self.policy = np.append(self.policy, np.reshape(policy, (1, 9)), axis=0)
        self.value = np.append(self.value, np.array(value))
        self.code = np.append(self.code, np.array(code))

    def get_trajectory(self):
        """Return ``(pieces, policies, values)`` for training."""
        return (self.piece, self.policy, self.value)

In [7]:
class Train:
    """Self-play training loop; it runs immediately on construction.

    Plays ``episodes_times`` games using MCTS guided by ``brain``, fits the
    network every ``train_interval`` games (and on any leftover games at the
    end), then saves the model.
    """

    def __init__(self, brain, episodes_times=30, train_interval=5):
        self.episodes_times = episodes_times
        self.train_interval = train_interval
        self.brain = brain

        self._train()

    def _train(self):
        """Alternate self-play and network updates, then save the model."""
        trajectory = Trajectory()
        progbar = Progbar(self.episodes_times)

        for i in range(self.episodes_times):
            self._play(self.brain, trajectory)

            # Train on every full batch of games, and after the final episode
            # so that trailing self-play games are not silently discarded.
            last_episode = i == self.episodes_times - 1
            if (i + 1) % self.train_interval == 0 or last_episode:
                if len(trajectory.value):
                    self.brain.train(trajectory)
                trajectory.reset_trajectory()

            progbar.add(1)

        self.brain.save()

    def _play(self, brain, trajectory):
        """Play one self-play game and append its positions to ``trajectory``.

        Each position is stored with the MCTS move distribution and a value
        target for the player to move at that position.
        """
        state = State()
        mcts = MonteCarloTreeSearch(brain)
        code = self._make_random_code()
        history = []

        while not state.terminal():
            actions = state.possible_actions()
            scores = mcts.search(state)
            policies = np.zeros(9)
            policies[actions] = scores
            history.append((state.pieces, state.enemy_pieces, policies))

            action = np.random.choice(actions, p=scores)
            state = state.step(action)

        # The player to move at the end has either lost (-1) or drawn (0).
        final_value = -1 if state.judgment_lose() else 0
        values = self._outcome_values(len(history), final_value)
        for (pieces, enemy_pieces, policies), value in zip(history, values):
            trajectory.set_trajectory(pieces, enemy_pieces, policies, value, code)

    @staticmethod
    def _outcome_values(n_moves, final_value):
        """Return per-position value targets for a finished game.

        ``final_value`` is the outcome for the player to move at the final
        position. Players alternate, so position ``k`` gets ``final_value``
        when ``n_moves - k`` is even and its negation otherwise.
        """
        return [final_value if (n_moves - k) % 2 == 0 else -final_value
                for k in range(n_moves)]

    def _make_random_code(self, n=10):
        """Return a random alphanumeric id of length ``n`` that tags one game."""
        return ''.join(random.choices(string.ascii_letters + string.digits, k=n))

In [8]:
if __name__ == "__main__":
    # Start self-play training; Train runs its whole loop from the constructor.
    Train(Brain())



  layer_config = serialize_layer_fn(layer)


<__main__.Train at 0x7f5d015fc790>