In [None]:
import numpy as np
import seaborn as sns
from time import time

import tensorflow as tf
from tensorflow import keras
from keras import layers, models
from keras.regularizers import L1L2
import keras.backend as K

import os
import gc
from pathlib import Path

# Hide all GPUs so TensorFlow runs on CPU only.
tf.config.experimental.set_visible_devices([], 'GPU')
# Early-stopping callback: stop as soon as the monitored metric stops improving
# (patience=0) and roll back to the best weights seen.
es = keras.callbacks.EarlyStopping(patience=0, restore_best_weights=True)
# Seed both RNGs: numpy drives move selection / playouts / batch sampling,
# TensorFlow drives weight initialisation and training randomness.
np.random.seed(0)
tf.random.set_seed(0)

In [None]:
# Empty 3x3 board (0 = empty cell); every game in this notebook starts here.
init_state = np.zeros((3, 3), dtype=int)

In [None]:
class Game:
    """Tic-tac-toe position on a 3x3 board.

    Cells hold 0 (empty), 1 or 2 (player marks). Moves are addressed by a flat
    cell index ``3*row + col`` in ``range(9)``.

    Args:
        state: 3x3 board indexed as ``state[i][j]``. A numpy array is expected
            (``update`` returns numpy boards).
        FIRST: mark (1 or 2) of the player who moved first in this game.
    """

    def __init__(self, state, FIRST=1):
        self.state = state
        # Flat indices of the empty cells, i.e. the legal moves.
        self.empty = self.make_empty(state)
        self.first_player = FIRST

    def make_empty(self, state):
        """Return the flat indices (3*i + j) of all empty cells, in row-major order."""
        return [3 * i + j for i in range(3) for j in range(3) if state[i][j] == 0]

    def is_lose(self):
        """Return True if any row, column or diagonal holds three equal non-zero marks.

        A completed line can only have been made by the previous mover, so from
        the perspective of the player to move the position is lost.
        """
        s = self.state
        for i in range(3):
            if s[i][0] == s[i][1] == s[i][2] != 0:
                return True
            if s[0][i] == s[1][i] == s[2][i] != 0:
                return True
        if s[0][0] == s[1][1] == s[2][2] != 0:
            return True
        if s[0][2] == s[1][1] == s[2][0] != 0:
            return True
        return False

    def is_draw(self):
        """Return 1 if the board is full and nobody has completed a line, else 0."""
        if self.is_lose():
            return 0
        return 1 if np.all(self.state) else 0

    def is_done(self):
        """Return 1 if the game is over (a line was completed or the board is full), else 0."""
        return 1 if (self.is_lose() or self.is_draw()) else 0

    def update(self, target):
        """Return a new Game where the player to move has marked cell ``target``.

        The current position is left untouched: the board is deep-copied (also
        safe for nested-list boards). The child keeps this game's
        ``first_player`` so turn order stays correct when FIRST != 1.
        """
        state = np.array(self.state)  # copies; never aliases self.state
        row, col = divmod(target, 3)
        state[row][col] = self.next_opp()
        return Game(state, self.first_player)

    def next_opp(self):
        """Return the mark (1 or 2) of the player whose turn it is.

        When both players have placed the same number of marks the first player
        is to move; otherwise it is the other player's turn.
        """
        first_count = other_count = 0
        for row in self.state:
            for cell in row:
                if cell == self.first_player:
                    first_count += 1
                elif cell != 0:
                    other_count += 1

        if first_count == other_count:
            return self.first_player
        # Marks are 1 and 2, so the other player's mark is 3 - first.
        return 3 - self.first_player

In [None]:
class Random:
    """Baseline agent that plays a uniformly random legal move."""

    def action(self, game):
        """Draw one element of ``game.empty`` using numpy's global RNG."""
        legal_moves = game.empty
        return np.random.choice(legal_moves)

In [None]:
n_steps = 100  # random playouts per candidate move (used by action() and value())
def playout(game):
    """Play uniformly random moves from ``game`` until the game ends.

    Returns the outcome from the perspective of the player to move in ``game``:
    -1 loss, 0 draw, +1 win.
    """
    # `sign` flips each ply: +1 while it is the original mover's turn.
    sign = 1
    while True:
        if game.is_lose():
            return -sign
        if game.is_draw():
            return 0
        game = game.update(np.random.choice(game.empty))
        sign = -sign


def action(game):
    """Choose a move by flat Monte Carlo search.

    For each legal move, run ``n_steps`` (module-level) random playouts from
    the resulting position and sum the outcomes from the mover's perspective.
    The move with the highest total wins.

    Args:
        game: position to move from; must have at least one legal move.

    Returns:
        The chosen flat cell index (an element of ``game.empty``).
    """
    totals = []
    for move in game.empty:
        # Score the actual cell `move`, not its position in game.empty.
        # The child position is deterministic, so build it once per move.
        child = game.update(move)
        # playout() scores the child for the opponent (to move there); negate it.
        totals.append(sum(-playout(child) for _ in range(n_steps)))

    return game.empty[int(np.argmax(totals))]


def value(game):
    """Estimate every cell's value for the player to move by random playouts.

    Returns a 9-element list indexed by flat cell.
    - Empty cell: the mean of ``-playout`` over ``n_steps`` playouts after the
      player to move takes that cell, i.e. that player's mean outcome in [-1, 1].
    - Occupied cell: 0.
    """
    def mean_outcome(cell):
        child = game.update(cell)
        total = sum(-playout(child) for _ in range(n_steps))
        return total / n_steps

    return [mean_outcome(cell) if cell in game.empty else 0 for cell in range(9)]



def pi(a):
    """Softmax of a sequence of scores, returned as a list of floats.

    The maximum is subtracted before exponentiating so large scores do not
    overflow to inf (which previously produced nan probabilities).

    Args:
        a: sequence of numeric scores (may be empty).

    Returns:
        List of probabilities summing to 1, in the same order as ``a``. Returns
        ``[]`` for empty input. If every exponential underflows to 0, the list
        of zeros is returned unnormalised.
    """
    logits = np.asarray(a, dtype=float)
    if logits.size == 0:
        return []

    shift = logits.max()
    # A non-finite max (all -inf, or an inf entry) can't be subtracted safely.
    if not np.isfinite(shift):
        shift = 0.0
    exps = np.exp(logits - shift)

    total = exps.sum()
    if total == 0:
        return list(exps)
    return list(exps / total)


def batch_gen(a, n):
    """Return ``n`` row indices of ``a``, drawn uniformly with replacement."""
    population = range(len(a))
    return np.random.choice(population, size=n)
    

In [None]:
DN_FILTERS = 128            # conv kernels per layer (original: 256)
DN_RESIDUAL_NUM = 16        # residual blocks (original: 19)
DN_INPUT_SHAPE = (3, 3, 2)  # input shape: 3x3 board x 2 planes
DN_OUTPUT_SIZE = 3 * 3      # number of actions = number of board cells
    
def residual_block():
    """Return a callable that applies one residual block to a Keras tensor.

    Layout: conv-BN-ReLU-conv-BN, add the block input (identity shortcut), ReLU.
    Both convolutions use DN_FILTERS 3x3 kernels with 'same' padding. The
    identity ``Add`` therefore needs the input to have DN_FILTERS channels.
    """
    def conv_bn(tensor):
        # 3x3 conv (no bias, L2-regularised) followed by batch norm.
        tensor = layers.Conv2D(DN_FILTERS, 3, padding='same', use_bias=False,
                               kernel_initializer='he_normal',
                               kernel_regularizer=L1L2(l2=0.0005))(tensor)
        return layers.BatchNormalization()(tensor)

    def apply(tensor):
        shortcut = tensor
        out = layers.Activation('relu')(conv_bn(tensor))
        out = conv_bn(out)
        out = layers.Add()([out, shortcut])
        return layers.Activation('relu')(out)

    return apply
    
def dual_network():
    """Build and compile the policy network.

    Architecture:
    - conv-BN-ReLU stem;
    - DN_RESIDUAL_NUM residual blocks;
    - global average pooling;
    - a 16-unit ReLU dense layer;
    - a DN_OUTPUT_SIZE-way softmax output named 'pi'.

    Compiled with Adam and categorical cross-entropy.

    NOTE(review): despite the name, no value head is built; the model has the
    single 'pi' output. The function also always builds a new model; it does
    not check for a saved one.

    Returns:
        A compiled Keras model mapping DN_INPUT_SHAPE inputs to move probabilities.
    """
    inputs = layers.Input(shape=DN_INPUT_SHAPE)

    # Stem convolution.
    hidden = layers.Conv2D(DN_FILTERS, 3, padding='same', use_bias=False,
                           kernel_initializer='he_normal',
                           kernel_regularizer=L1L2(l2=0.0005))(inputs)
    hidden = layers.BatchNormalization()(hidden)
    hidden = layers.Activation('relu')(hidden)

    # Residual tower.
    for _ in range(DN_RESIDUAL_NUM):
        hidden = residual_block()(hidden)

    # Pooling + small dense layer.
    hidden = layers.GlobalAveragePooling2D()(hidden)
    hidden = layers.Dense(16, activation='relu',
                          kernel_regularizer=L1L2(l2=0.0001),
                          kernel_initializer='he_normal')(hidden)

    # Policy output.
    policy = layers.Dense(DN_OUTPUT_SIZE, kernel_regularizer=L1L2(l2=0.0005),
                          activation='softmax', name='pi')(hidden)

    model = models.Model(inputs=inputs, outputs=policy)
    model.compile(optimizer='adam', loss='categorical_crossentropy')
    return model

class CNN:
    """Policy-network tic-tac-toe agent trained from Monte Carlo targets and self-play.

    Attributes:
        model: compiled Keras model returned by ``dual_network()``.
        X: buffer of encoded states; ``[]`` when empty, otherwise an array of
            shape (k, 3, 3, 2).
        y: matching training targets; ``[]`` when empty, otherwise shape (k, 9).
    """

    def __init__(self):
        # Drop previously built Keras graphs/models before building a new one.
        K.clear_session()
        self.model = dual_network()
        self.X = []
        self.y = []

    def action(self, game):
        """Return the legal cell with the highest predicted probability."""
        probs = self.predict(game)
        return game.empty[np.argmax(probs)]

    def warmup(self, n=100, wn=30):
        """Play ``n`` games and train the network after each one.

        Games with index ``p < wn``:
        - are played by the Monte Carlo helper ``action``;
        - use ``pi(value(game))`` as targets;
        - fit the network on that game only.

        Later games:
        - are played greedily by the network itself;
        - use as target the network's prediction with the chosen move's entry
          bumped by a one-step return.

        Args:
            n: total number of games.
            wn: number of Monte Carlo warm-up games.
        """
        GAMMA = 0.99
        for p in range(n):
            print('epochs:', (p+1))
            game = Game(init_state)

            while True:
                state = self.make_state(game)
                if p < wn:
                    a = action(game)
                    # NOTE(review): value() leaves occupied cells at 0, so pi()
                    # still assigns them probability mass in the target.
                    values = np.reshape(pi(value(game)), (1, 9))
                    game = game.update(a)
                else:
                    a = self.action(game)
                    values = np.reshape(self.model.predict(state, verbose=0)[0], (1, 9))
                    game = game.update(a)
                    # After our move, "player to move has lost" means we won.
                    r = 1 if game.is_lose() else 0
                    # NOTE(review): value(game) is scored for the opponent
                    # (it is their turn) and is also evaluated on finished
                    # games. A conventional one-step target would be r alone
                    # when game.is_done(), and r - GAMMA*max(...) otherwise.
                    # The bumped vector is also no longer a probability
                    # distribution for categorical cross-entropy.
                    values[0][a] += (r + GAMMA*max(value(game)))

                self._append_sample(state, values)

                if game.is_done():
                    break

            # NOTE(review): game p == wn is played by the network but trained
            # like a warm-up game (full fit, buffer cleared). `p >= wn` may
            # have been intended.
            if p > wn:
                idx = batch_gen(self.X, 128)
                self.model.fit(self.X[idx], self.y[idx], verbose=0)
            else:
                self.model.fit(self.X, self.y, verbose=0, epochs=1)
                self.X = []
                self.y = []

    def _append_sample(self, state, values):
        """Append one (state, target) batch to the training buffer."""
        # len() works for both the initial [] and an ndarray. The old
        # `self.X == []` compared an ndarray elementwise against an
        # incompatible shape, which recent NumPy versions reject.
        if len(self.X) == 0:
            self.X = state
            self.y = values
        else:
            self.X = np.concatenate([self.X, state])
            self.y = np.concatenate([self.y, values])

    def make_state(self, game):
        """Encode ``game`` as a (1, 3, 3, 2) int array.

        Plane 0 marks the cells of the player to move; plane 1 marks the
        opponent's cells.
        """
        me = game.next_opp()
        opp = 3 - me
        board = game.state
        planes = np.stack([np.where(board == me, 1, 0),
                           np.where(board == opp, 1, 0)], axis=-1)
        return planes[np.newaxis]

    def predict(self, game):
        """Return the network's probabilities for the cells in ``game.empty``, in that order."""
        probs = self.model.predict(self.make_state(game), verbose=0)[0]
        return probs[game.empty]


In [None]:
# Build a fresh agent. CNN() calls K.clear_session() and builds a new network,
# so re-running this cell replaces any previously trained `dd`.
dd = CNN()

In [None]:
# Train `dd`. NOTE(review): with n == wn every game index p satisfies p < wn,
# so all 30 games are Monte Carlo warm-up games and the network's self-play
# branch never runs. Use n > wn to exercise it.
dd.warmup(wn=30, n=30)

In [None]:
# Probe the trained agent on the empty board: which opening cell (0-8) does it pick?
g = Game(np.zeros((3, 3), dtype=int))

dd.action(g)

In [None]:
# Network probabilities for each legal cell of `g`, in the order of g.empty.
dd.predict(g)

In [None]:
# Run Python's garbage collector to release memory left over from training.
gc.collect()

In [None]:
def play(game, m1, m2, tally=None):
    """Play one game from ``game`` with ``m1`` moving first, and record the result.

    Each agent must provide ``action(game)`` returning a flat cell index.
    Positions are advanced with ``game.update``.

    Args:
        game: starting position.
        m1: agent that moves first from ``game``.
        m2: agent that moves second.
        tally: 3-element list ``[m1 wins, m2 wins, draws]``, incremented in
            place. Defaults to the module-level ``score`` list.

    Returns:
        The tally slot that was incremented: 0 (m1 won), 1 (m2 won) or 2 (draw).
    """
    if tally is None:
        tally = score
    while True:
        for slot, agent in ((0, m1), (1, m2)):
            game = game.update(agent.action(game))
            # After a move, is_lose() means the player now to move has lost,
            # i.e. the agent that just moved won.
            if game.is_lose():
                tally[slot] += 1
                return slot
            if game.is_draw():
                tally[2] += 1
                return 2
        

In [None]:
game = Game(init_state, FIRST=1)  # empty starting position for every evaluation game
m1 = Random()                     # random-move baseline opponent

In [None]:
%%time
score = [0, 0, 0]
for _ in range(100):
#     print(_)
#     print(score)
    play(game, dd, m1)
print(score)

score = [0, 0, 0]
for _ in range(100):
    play(game, m1, dd)
#     print(score)
print(score)

In [None]:
# sns.barplot(x = [1, 2], y = score[:2])

In [None]:
# tanh + mse
# v1: 732vs601 // 350vs323
# v2: 738vs399 // 411vs149
# v3: 1161vs773 // 791vs170
# v4: 1305vs567 // 832vs86

In [None]:
# Persist the trained network to ./CNN.h5 (current working directory).
# Presumably saved in HDF5 format because of the .h5 extension. Verify against
# the installed Keras version.
dd.model.save('./CNN.h5')