# AlphaZero-like combination of MCTS and DNN

AlphaZero is a very effective approach, but I lack the computational resources and the implementation skills to reproduce it faithfully. This notebook therefore presents an approach that is inspired by AlphaZero but differs from it in several ways.

In [None]:
from kaggle_environments import evaluate, make, utils
from kaggle_environments.envs.connectx.connectx import play,is_win
import matplotlib.pyplot as plt
import numpy as np
import torch
import time
import random
import collections
from joblib import Parallel, delayed
import joblib

### Define Network

In [None]:
class Alpha_Net(torch.nn.Module):
    """Convolutional network with a 7-way policy head and a scalar value head.

    The input has 2 channels (one plane per player). The three un-padded
    convolutions (kernel sizes 2, 3, 3) shrink a 6x7 board to 1x2 with 120
    channels, i.e. 240 features, which is what both heads consume.

    ``forward`` returns a tensor of shape (N, 8): columns 0-6 hold the policy
    (log-probabilities while ``self.training`` is true, probabilities
    otherwise) and column 7 holds the value squashed by ``tanh``.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        # Shared convolutional trunk.
        self.block1 = nn.Sequential(
            nn.Conv2d(2, 20, kernel_size=2, padding=0, stride=1),
            nn.BatchNorm2d(20),
            nn.LeakyReLU(),
            nn.Conv2d(20, 40, kernel_size=3, padding=0, stride=1),
            nn.BatchNorm2d(40),
            nn.LeakyReLU(),
            nn.Conv2d(40, 120, kernel_size=3, padding=0, stride=1),
            nn.BatchNorm2d(120),
            nn.LeakyReLU(),
        )
        # Policy head: one logit per column.
        self.block2 = nn.Sequential(
            nn.Linear(240, 42),
            nn.LeakyReLU(),
            nn.Linear(42, 7),
        )
        # Training uses log-probabilities (consumed by the cross-entropy term
        # of the loss); evaluation uses plain probabilities.
        self.t_out = nn.LogSoftmax(dim=1)
        self.e_out = nn.Softmax(dim=1)
        # Value head: a single number in [-1, 1].
        self.block3 = nn.Sequential(
            nn.Linear(240, 16),
            nn.LeakyReLU(),
            nn.Linear(16, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return ``cat([policy, value], dim=1)`` for a batch of boards."""
        features = self.block1(x).view(-1, 240)
        logits = self.block2(features)
        normalise = self.t_out if self.training else self.e_out
        policy = normalise(logits)
        value = self.block3(features)
        return torch.cat([policy, value], dim=1)

### Monte Carlo tree search with DNN
More efficient implementations are certainly possible; improvements are welcome.

In [None]:
def _mcts_simulate(net, board, config, start_time, turn, policy, count, Q, info_each,
                   discount, C_p, expand_threshold, gamma):
    """Pick one child of the current node by its PUCT-style score and update the node in place.

    ``count``, ``Q`` and ``info_each`` are the per-column lists of the current
    node and are mutated. ``discount`` scales the value backed up from a
    non-terminal child (1 at the root, ``gamma`` below it).
    """
    visits = sum(count)
    C = np.log((1 + visits + C_p) / C_p) + 1.25
    spread = np.sqrt(visits)
    select = np.array([Q[i] + C * policy[i] * (spread / (1 + count[i])) for i in range(7)])
    child_index = int(np.argmax(select))
    if board[child_index] != 0:
        # The column is treated as full: mark the move as losing so it is avoided.
        Q[child_index] = -1
    else:
        next_board = board[:]
        if is_win(next_board, child_index, turn, config, False):
            Q[child_index] = 1
        else:
            play(next_board, child_index, turn, config)
            fillness = sum(1 for p in next_board if p != 0)
            if fillness == 42:
                # Board is full without a win: draw.
                Q[child_index] = 0
            elif count[child_index] >= expand_threshold:
                # Child already evaluated often enough: search one level deeper,
                # re-using its cached statistics. The value is negated because
                # the child is scored from the opponent's point of view.
                scc, ie = alpha_MCTS(net, next_board, config, start_time, expand=True,
                                     info=info_each[child_index], C_p=C_p,
                                     expand_threshold=expand_threshold, gamma=gamma)
                Q[child_index] = -discount * scc
                info_each[child_index] = ie
            elif count[child_index] == 0:
                # First visit: evaluate the child with the network only.
                scc, pol = alpha_MCTS(net, next_board, config, start_time, info=None)
                Q[child_index] = -discount * scc
                info_each[child_index] = {'policy': pol, 'info_each': [0, 0, 0, 0, 0, 0, 0],
                                          'count': [0, 0, 0, 0, 0, 0, 0], 'Q': [0, 0, 0, 0, 0, 0, 0],
                                          'value': scc}
    count[child_index] += 1


def alpha_MCTS(net,board,config,start_time=None,info=None,expand=False,root=False,C_p=1000,expand_threshold=1,time_lim=4.0,gamma=0.95,steps=350):
    """Network-guided Monte Carlo tree search for a 6x7, 7-column board.

    The function has three modes, selected by ``root`` and ``expand``:

    * ``root=True``: run up to ``steps`` simulations (or until ``time_lim``
      seconds have elapsed since ``start_time``) and return
      ``(visit_distribution, value)`` where ``visit_distribution`` is the
      normalised visit count per column and ``value`` is the visit-weighted
      mean of ``Q``.
    * ``expand=True`` (non-root): run a single simulation from this node and
      return ``(value, node_info)``; ``node_info`` is the dict to pass back as
      ``info`` on the next visit.
    * neither: return ``(value, policy)`` for this node without searching.

    Args:
        net: callable evaluated as ``net(tensor)`` on a (1, 2, 6, 7) float
            tensor; element 0 of its output must hold 7 policy entries
            followed by one value.
        board: flat sequence of 42 cells (0 = empty, 1 / 2 = player marks).
            A column ``c`` is treated as full when ``board[c] != 0``.
        config: game configuration forwarded to ``is_win`` and ``play``.
        start_time: reference time for ``time_lim``; when ``None`` at the root
            the current time is used.
        info: cached node statistics (keys ``policy``, ``info_each``,
            ``count``, ``Q``, ``value``); when ``None`` the node is evaluated
            with ``net`` and starts with empty statistics.
        expand: run one simulation below this (non-root) node.
        root: treat this node as the search root.
        C_p: constant in the exploration coefficient (should be tuned).
        expand_threshold: visits a child needs before it is searched deeper
            instead of only being evaluated by the network.
        time_lim: wall-clock budget in seconds for the root loop.
        gamma: discount applied to values backed up below the root.
        steps: maximum number of simulations at the root.
    """
    if info is None:
        b1 = np.array([1 if p == 1 else 0 for p in board]).reshape(1, 1, 6, 7)
        b2 = np.array([1 if p == 2 else 0 for p in board]).reshape(1, 1, 6, 7)
        board2 = torch.from_numpy(np.concatenate([b1, b2], axis=1)).float()
        pred = net(board2).detach().numpy()[0]
        policy = pred[:7]
        value = pred[7]
        count = [0, 0, 0, 0, 0, 0, 0]
        info_each = [0, 0, 0, 0, 0, 0, 0]
        Q = [0, 0, 0, 0, 0, 0, 0]
    else:
        policy = info['policy']
        info_each = info['info_each']
        count = info['count']
        Q = info['Q']
        value = info['value']

    if not root and not expand:
        return value, policy

    # Player 1 moves when an even number of pieces is on the board.
    turn = sum(1 for p in board if p != 0) % 2 + 1

    if root:
        if start_time is None:
            start_time = time.time()
        for _ in range(steps):
            _mcts_simulate(net, board, config, start_time, turn, policy, count, Q, info_each,
                           1, C_p, expand_threshold, gamma)
            if time.time() - start_time >= time_lim:
                break
        visits = np.array(count)
        return visits / visits.sum(), np.sum(np.array(Q) * visits) / sum(count)

    _mcts_simulate(net, board, config, start_time, turn, policy, count, Q, info_each,
                   gamma, C_p, expand_threshold, gamma)
    n = sum(count)
    searched = np.sum(np.array(Q) * np.array(count)) / n
    # Blend weight 2^(n-1) / (1 + 2^(n-1)), written with a float negative
    # exponent: the integer form np.power(2, n-1) wraps around in int64 for
    # n >= 65 and would make the weight 0.
    r = 1.0 / (1.0 + 2.0 ** (1 - n))
    return searched * r + value * (1 - r), {'policy': policy, 'info_each': info_each, 'count': count, 'Q': Q, 'value': value}

### Define functions for training

In [None]:
def self_play(net,search=70,gamma=0.95,random_turn=3):
    """Play one ConnectX game against itself and return training samples.

    Each move is recorded as a dict with keys ``board`` (a (2, 6, 7) array,
    one plane per player), ``policy`` (root visit distribution from
    ``alpha_MCTS``) and ``reward`` (root value estimate from ``alpha_MCTS``,
    not the final game result). The returned list holds every recorded
    position followed by its left-right mirrored copy.

    ``search`` is the number of MCTS steps per move. Moves whose turn number
    is below ``random_turn`` are replaced by a uniformly random legal move.
    ``gamma`` is accepted but not used by this function.
    """
    env = make('connectx',debug=True)
    columns = [0,1,2,3,4,5,6]
    records = []
    move_number = 0
    while True:
        move_number += 1
        flat_board = env.state[0]['observation']['board']
        planes = [
            np.array([[[1 if cell == mark else 0 for cell in flat_board]]]).reshape(1,1,6,7)
            for mark in (1, 2)
        ]
        stacked = np.concatenate(planes,axis=1)
        pol, value = alpha_MCTS(net,flat_board,env.configuration,root=True,steps=search)
        # 'reward' is the search value, which is where this differs from AlphaZero.
        records.append({'board': stacked[0], 'policy': pol, 'reward': value})
        action = random.choices(columns,weights=list(pol),k=1)[0]
        if flat_board[action] != 0 or move_number < random_turn:
            # Sampled column is full, or we are still in the random opening.
            action = random.choice([c for c in range(7) if flat_board[c] == 0])
        env.step([action,action])
        if env.state[0]['status'] == 'DONE':
            break
    # Augmentation by reflection: mirror both planes and the policy.
    mirrored = [
        {
            'board': rec['board'][:, :, ::-1].copy(),
            'policy': rec['policy'][::-1].copy(),
            'reward': rec['reward'],
        }
        for rec in records
    ]
    return records + mirrored

In [None]:
class original_loss(torch.nn.Module):
    """Policy cross-entropy and value squared-error, returned separately.

    ``pred`` and ``target`` are (N, 8) tensors: columns 0-6 are the policy
    part (``pred`` is expected to hold log-probabilities there) and column 7
    is the value. Both terms are summed over the batch, not averaged.
    """

    def __init__(self):
        super().__init__()

    def forward(self,pred,target):
        """Return the tuple ``(CEL, MSE)``; the caller adds them up."""
        policy_term = torch.neg(target[:, :7]) * pred[:, :7]
        CEL = policy_term.sum(dim=1).sum()
        value_error = target[:, 7] - pred[:, 7]
        MSE = torch.pow(value_error, 2).sum()
        return CEL, MSE

In [None]:
def fight(agent,fs,mode="random"):
    """Play one ConnectX game between ``agent`` and the opponent ``mode``.

    ``agent`` is called as ``agent(observation, configuration)`` and must
    return a column. With ``fs == 1`` the agent takes the first seat,
    otherwise the second. Returns the reward reported by the last
    ``trainer.step`` call.
    """
    env = make('connectx',debug=True)
    seats = [None,mode] if fs == 1 else [mode,None]
    trainer = env.train(seats)
    observation = trainer.reset()
    # 42 moves is an upper bound on the game length.
    for _ in range(42):
        move = agent(observation, env.configuration)
        observation, reward, done, info = trainer.step(move)
        if done:
            break
    return reward

### Define Agent to be trained

In [None]:
class AlphaN():
    """Trainable agent: network, replay buffer, optimiser and training logs."""

    def __init__(self):
        self.net = Alpha_Net()
        # Replay buffer of self-play samples; the oldest samples are dropped
        # once 100000 are stored.
        self.buff = collections.deque([],maxlen=100000)
        self.optim = torch.optim.RMSprop(self.net.parameters(),lr=0.001,weight_decay=1e-4)
        self.criterion = original_loss()
        # One entry per training loop (losses) / per evaluation (score).
        self.loss_log = []
        self.CEL_log = []
        self.MSE_log = []
        self.score_log = []
        self.env = make("connectx", debug=True)

    def evaluate(self,num,mode='random'):
        """Play ``num`` games per seat against ``mode`` and return the two mean rewards.

        The first value is for the agent in the first seat, the second for the
        agent in the second seat.
        """
        def mcts_agent(observation, configuration):
            visit_dist,_ = alpha_MCTS(self.net,observation.board,configuration,root=True,steps=70)
            move = random.choices([0,1,2,3,4,5,6],weights=list(visit_dist),k=1)[0]
            if observation.board[move] != 0:
                # Sampled a full column: fall back to a random legal one.
                move = random.choice([c for c in range(configuration.columns) if observation.board[c] == 0])
            return move

        outcomes = [
            Parallel(n_jobs=-1,verbose=0)([delayed(fight)(mcts_agent,fs=seat,mode=mode) for _ in range(num)])
            for seat in (1, 2)
        ]
        first_total, second_total = (sum(games) for games in outcomes)
        return first_total / num, second_total / num

    def train(self,num,play_num=40,batch_num=32,batch_size=64,train_loop=50,time_lim=30000):
        """Alternate self-play and network updates for up to ``num`` iterations.

        Every fifth iteration the agent is first scored against 'negamax'.
        Iteration 0 generates 1000 self-play games with 49 search steps, later
        iterations generate ``play_num`` games with the default search
        settings. Each iteration then runs ``train_loop`` loops of
        ``batch_num`` mini-batches of ``batch_size`` samples drawn with
        replacement from the buffer. Training stops early once more than
        ``time_lim`` seconds have passed; finally the score history is plotted.
        """
        began = time.time()
        for t in range(num):
            # self play
            self.net.eval()
            with torch.no_grad():
                if t % 5 == 0:
                    score1, score2 = self.evaluate(25,mode='negamax')
                    self.score_log.append((score1+score2)/2)
                    print(f'step: {t}  score1: {score1}  score2: {score2}')
                if t == 0:
                    jobs = [delayed(self_play)(self.net,search=49) for _ in range(1000)]
                else:
                    jobs = [delayed(self_play)(self.net) for _ in range(play_num)]
                for game_log in Parallel(n_jobs=-1,verbose=0)(jobs):
                    self.buff.extend(game_log)
            # Updating Network
            self.net.train()
            for _ in range(train_loop):
                run_loss = 0
                run_CEL = 0
                run_MSE = 0
                for _ in range(batch_num):
                    batch = random.choices(self.buff,k=batch_size)
                    self.optim.zero_grad()
                    boards = torch.from_numpy(np.array([sample['board'] for sample in batch])).float()
                    policies = np.array([sample['policy'] for sample in batch])
                    values = np.array([[sample['reward']] for sample in batch])
                    # Target layout matches the network output: 7 policy columns + 1 value.
                    target = torch.from_numpy(np.concatenate([policies,values],axis=1)).float()
                    pred = self.net(boards)
                    with torch.autograd.detect_anomaly():
                        CEL,MSE = self.criterion(pred,target)
                        loss = CEL + MSE
                        loss.backward()
                    self.optim.step()
                    # Losses are batch sums, so this accumulates a per-sample mean.
                    run_loss += loss.item() / batch_num / batch_size
                    run_CEL += CEL.item() / batch_num / batch_size
                    run_MSE += MSE.item() / batch_num / batch_size
                self.loss_log.append(run_loss)
                self.CEL_log.append(run_CEL)
                self.MSE_log.append(run_MSE)
            print(f'step: {t}  loss: {run_loss}')
            if time.time() - began > time_lim:
                print('time over')
                break
        plt.plot(self.score_log)
        plt.title('Score vsNegaMax')
        plt.xlabel('step')
        plt.ylabel('score')
        plt.show()

### Training

In [None]:
# Create the agent: this builds a fresh network, an empty replay buffer, the optimiser and the logs.
agent1 = AlphaN()

In [None]:
# Run up to 1000 training iterations; train() stops earlier once its own time limit is exceeded.
agent1.train(1000)

In [None]:
# Total loss (policy + value), one point per training loop, as logged by AlphaN.train.
plt.plot(agent1.loss_log)
plt.title('total loss')
plt.xlabel('loop')
plt.ylabel('loss')
plt.show()

In [None]:
# Value-head loss (squared error term), one point per training loop.
plt.plot(agent1.MSE_log)
plt.title('value loss')
plt.xlabel('loop')
plt.ylabel('loss')
plt.show()

In [None]:
# Policy-head loss (cross-entropy term), one point per training loop.
plt.plot(agent1.CEL_log)
plt.title('policy loss')
plt.xlabel('loop')
plt.ylabel('loss')
plt.show()

### Write the submission file

In [None]:
# Make NumPy print arrays in full: with an infinite threshold large arrays
# are never summarised with "..." when converted to text.
np.set_printoptions(threshold=np.inf)

In [None]:
# Build the source code of a self-contained submission as one string:
#   1. a fixed template defining the network and the search inside agent(),
#   2. one generated line per tensor of the trained network's state_dict,
#   3. a fixed tail that loads the weights and picks the most visited column.
# The weights are written with tolist()/repr() so that the text does not depend
# on NumPy's print options or scalar repr, and they are loaded with
# load_state_dict so that every entry (including the BatchNorm
# num_batches_tracked counters) actually reaches the network.
out = """
from kaggle_environments.envs.connectx.connectx import play,is_win
import time
import random
import torch
import numpy as np
def agent(observation,configuration):
    start_time = time.time()
    class Alpha_Net(torch.nn.Module):
        def __init__(self):
            super(Alpha_Net,self).__init__()
            self.block1 = torch.nn.Sequential(
            torch.nn.Conv2d(2,20,kernel_size=2,padding=0,stride=1),
            torch.nn.BatchNorm2d(20),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(20,40,kernel_size=3,padding=0,stride=1),
            torch.nn.BatchNorm2d(40),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(40,120,kernel_size=3,padding=0,stride=1),
            torch.nn.BatchNorm2d(120),
            torch.nn.LeakyReLU(),
            )
            self.block2 = torch.nn.Sequential(
            torch.nn.Linear(240, 42),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(42, 7),
            )
            self.t_out = torch.nn.LogSoftmax(dim=1)
            self.e_out = torch.nn.Softmax(dim=1)
            self.block3 = torch.nn.Sequential(
            torch.nn.Linear(240, 16),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(16, 1),
            torch.nn.Tanh()
            )

        def forward(self, x):
            x = self.block1(x)
            x = x.view(-1,240)
            x1 = self.block2(x)
            if self.training:
                x1 = self.t_out(x1)
            else:
                x1 = self.e_out(x1)
            x2 = self.block3(x)
            x = torch.cat([x1,x2],dim=1)
            return x
    def alpha_MCTS(net,board,config,start_time=None,info=None,expand=False,root=False,C_p=1000,expand_threshold=1,time_lim=2.0,gamma=0.95,steps=10000000):
        if info == None:
            b1 = np.array([[[1 if p == 1 else 0 for p in board]]]).reshape(1,1,6,7)
            b2 = np.array([[[1 if p == 2 else 0 for p in board]]]).reshape(1,1,6,7)
            board2 = np.concatenate([b1,b2],axis=1)
            board2 = torch.from_numpy(board2).float()
            pred = net(board2).detach().numpy()[0]
            policy = pred[:7]
            value = pred[7]
            count = [0,0,0,0,0,0,0]
            info_each = [0,0,0,0,0,0,0]
            Q = [0,0,0,0,0,0,0]
        else:
            policy = info['policy']
            info_each = info['info_each']
            count = info['count']
            Q = info['Q']
            value = info['value']
        turn = sum([1 if p != 0 else 0 for p in board])%2 + 1
        if root:
            for t in range(steps):
                C = np.log((1+sum(count)+C_p)/C_p) + 1.25
                select = np.array([Q[i]+C*policy[i]*(np.sqrt(sum(count))/(1+count[i])) for i in range(7)])
                child_index = int(np.argmax(select))
                if board[child_index] != 0:
                    Q[child_index] = -1
                else:
                    next_board = board[:]
                    if is_win(next_board,child_index,turn,config,False):
                        Q[child_index] = 1
                    else:
                        play(next_board,child_index,turn,config)
                        fillness = sum([1 if p != 0 else 0 for p in next_board])
                        if fillness == 42:
                            Q[child_index] = 0
                        else:
                            if count[child_index] >= expand_threshold:
                                scc, ie = alpha_MCTS(net,next_board,config,start_time,expand=True,info=info_each[child_index])
                                Q[child_index] = -scc
                                info_each[child_index] = ie
                            else:
                                if count[child_index] == 0:
                                    scc, pol = alpha_MCTS(net,next_board,config,start_time,info=None)
                                    Q[child_index] = -scc
                                    info_each[child_index] = {'policy':pol,'info_each':[0,0,0,0,0,0,0],'count':[0,0,0,0,0,0,0],'Q':[0,0,0,0,0,0,0],'value':scc}
                count[child_index] += 1
                if time.time()-start_time >= time_lim:
                    break
            return np.power(np.array(count),1) / np.power(np.array(count),1).sum()
        else:
            if expand:
                C = np.log((1+sum(count)+C_p)/C_p) + 1.25
                select = np.array([Q[i]+C*policy[i]*(np.sqrt(sum(count))/(1+count[i])) for i in range(7)])
                child_index = int(np.argmax(select))
                if board[child_index] != 0:
                    Q[child_index] = -1
                else:
                    next_board = board[:]
                    if is_win(next_board,child_index,turn,config,False):
                        Q[child_index] = 1
                    else:
                        play(next_board,child_index,turn,config)
                        fillness = sum([1 if p != 0 else 0 for p in next_board])
                        if fillness == 42:
                            Q[child_index] = 0
                        else:
                            if count[child_index] >= expand_threshold:
                                scc, ie = alpha_MCTS(net,next_board,config,start_time,expand=True,info=info_each[child_index])
                                Q[child_index] = -gamma*scc
                                info_each[child_index] = ie
                            else:
                                if count[child_index] == 0:
                                    scc, pol = alpha_MCTS(net,next_board,config,start_time,info=None)
                                    Q[child_index] = -gamma*scc
                                    info_each[child_index] = {'policy':pol,'info_each':[0,0,0,0,0,0,0],'count':[0,0,0,0,0,0,0],'Q':[0,0,0,0,0,0,0],'value':scc}
                count[child_index] += 1
                re_Q = np.array(Q) * np.array(count)
                # float form of 2^(n-1)/(1+2^(n-1)); the int64 np.power form wraps to 0 for n >= 65
                r = 1.0/(1.0+2.0**(1-sum(count)))
                return np.sum(re_Q)/sum(count)*r + value*(1-r), {'policy':policy,'info_each':info_each,'count':count,'Q':Q,'value':value}
            else:
                return value, policy
    net = Alpha_Net()
    state = {}
"""
# One "state[<key>] = torch.tensor(<nested list>)" line per state_dict entry.
# tolist() yields plain Python numbers, whose repr round-trips exactly.
weight_lines = [
    "    state[" + repr(key) + "] = torch.tensor(" + repr(tensor.detach().cpu().tolist()) + ")\n"
    for key, tensor in agent1.net.state_dict().items()
]
out += "".join(weight_lines)
out += """
    net.load_state_dict(state)
    net.eval()
    with torch.no_grad():
        pol = alpha_MCTS(net,observation.board,configuration,start_time=start_time,root=True)
    action = int(np.argmax(np.array(pol)))
    if observation.board[action] != 0:
        action = random.choice([c for c in range(configuration.columns) if observation.board[c] == 0])
    return action"""

In [None]:
# Write the generated agent source to submission.py in the working directory
# (mode 'w' replaces any existing file of that name).
with open('submission.py', 'w') as f:
    f.write(out)

### Test submission

In [None]:
from submission import agent

In [None]:
# Smoke test: play one full game with the imported submission agent against
# the built-in "negamax" opponent. "negamax" occupies the first seat; the
# None seat is the one driven through trainer.step below.
env = make("connectx", debug=True)
trainer = env.train(["negamax",None])

observation = trainer.reset()

# Keep asking the agent for a move until the environment reports the game is over.
while not env.done:
    my_action = agent(observation, env.configuration)
    print("My Action", my_action)
    observation, reward, done, info = trainer.step(my_action)
# Show the final board and the reward returned by the last step.
env.render()
print(reward)

### Complete!
This is my first public notebook and I'm not an expert. If there are any errors, please point them out to me.