In [1]:
import sys
sys.path.append('../input/pytorch-geometric')

In [2]:
import numpy as np
import math
from IPython.display import clear_output
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import Data, Dataset, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn.conv.gin_conv import GINConv
from torch.optim import Adam
import time
# import multiprocessing as mp
from tqdm import trange
import gc

In [3]:
class TicTacToe:
    def __init__(self, size_of_board, length_to_win):
        self.size_of_board = size_of_board
        self.length_to_win = length_to_win
        self.row_count = size_of_board
        self.column_count = size_of_board
        self.action_size = self.row_count * self.column_count

    def get_initial_state(self):
        return np.zeros((self.row_count, self.column_count))

    def get_next_state(self, state, action, player):
        row = action // self.column_count
        column = action % self.column_count
        state[row, column] = player
        return state

    def get_valid_moves(self, state):
        return (state.reshape(-1) == 0).astype(np.uint8)

    def get_limited_valid_moves(self, state):
        limited_moves = (state.reshape(-1) == 0).astype(np.uint8)


        for m in range(len(limited_moves)):

            move = limited_moves[m]
            if move == 0: continue
            limited_moves[m] = 0
            r0 = m // self.size_of_board
            c0 = m % self.size_of_board
            for i in range (-1, 2, 1):
                for j in range (-1, 2, 1):
                    r1 = r0 + i
                    c1 = c0 + j
                    if r1 < 0 or r1 >= self.size_of_board or c1 < 0 or c1 >= self.size_of_board or r1 * self.size_of_board + c1 == m: continue
                    if state[r1][c1] != 0: limited_moves[m] = 1

        return limited_moves

    def get_adjacent_valid_moves(self, state, action):
        limited_moves = (state.reshape(-1) == 0).astype(np.uint8)

        for m in range(len(limited_moves)):

            move = limited_moves[m]
            if move == 0: continue
            limited_moves[m] = 0
            r0 = m // self.size_of_board
            c0 = m % self.size_of_board

            if action == None: return limited_moves
            r_act = action // self.size_of_board
            c_act = action % self.size_of_board
            for i in range (-2, 3, 1):
                for j in range (-2, 3, 1):
                    r1 = r0 + i
                    c1 = c0 + j
                    if r1 < 0 or r1 >= self.size_of_board or c1 < 0 or c1 >= self.size_of_board or r1 * self.size_of_board + c1 == m: continue
                    if not (r_act - 2 <= r1 and r1 <= r_act + 2 and c_act - 2 <= c1 and c1 <= c_act + 2): continue
                    if state[r1][c1] != 0: limited_moves[m] = 1

        return limited_moves

    def check_win(self, state, action):
        if action == None:
            return False

        row = action // self.column_count
        column = action % self.column_count
        player = state[row, column]

        row_lower_limit = - (self.length_to_win - 1) if 0 < row - (self.length_to_win - 1) else 0 - row
        row_upper_limit = (self.length_to_win - 1) if self.row_count > (row + (self.length_to_win - 1)) else (self.row_count - 1) - row

        col_lower_limit = - (self.length_to_win - 1) if 0 < column - (self.length_to_win - 1) else 0 - column
        col_upper_limit = (self.length_to_win - 1) if self.column_count > (column + (self.length_to_win - 1)) else (self.column_count - 1) - column

        count = 0
        for i in range (row_lower_limit, row_upper_limit + 1):
            if state[row + i][column] == player:
                count += 1
                if count == self.length_to_win: return True
            else: count = 0

        count = 0
        for i in range (col_lower_limit, col_upper_limit + 1):
            if state[row][column + i] == player:
                count += 1
                if count == self.length_to_win: return True
            else: count = 0

        count = 0
        for i in range (max(row_lower_limit, col_lower_limit), min(row_upper_limit, col_upper_limit) + 1):
            if state[row + i][column + i] == player:
                count += 1
                if count == self.length_to_win: return True
            else: count = 0

        count = 0
        for i in range (max(- row_upper_limit, col_lower_limit), min(- row_lower_limit, col_upper_limit) + 1):
            if state[row - i][column + i] == player:
                count += 1
                if count == self.length_to_win: return True
            else: count = 0

        return False



    def get_value_and_terminated(self, state, action):
        if self.check_win(state, action):
            return 1, True
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True
        return 0, False

    def get_opponent(self, player):
        return -player

    def get_opponent_value(self, value):
        return -value

    def change_perspective(self, state, player):
        for i in range(len(state)):
            for j in range(len(state[i])):
                state[i][j] *= player
        return state
    
    def is_initial_state(self, state):
        return np.all(state == 0)
    
    def get_neutral_state(self, state, player):
        return state * player

In [4]:
def get_edge_index(x):
    board_size = int(np.sqrt(len(x) - 1))
    edge_index = []
    for row in range(board_size):
        for col in range(board_size):
            curr_idx = row * board_size + col
            if col < board_size - 1:
                right_neighbor = row * board_size + col + 1
                edge_index.append([curr_idx, right_neighbor])
                if row > 0:
                    up_neighbor = right_neighbor - board_size
                    edge_index.append([curr_idx, up_neighbor])
                if row < board_size - 1:
                    down_neighbor = right_neighbor + board_size
                    edge_index.append([curr_idx, down_neighbor])
            if row < board_size - 1:
                bottom_neighbor = (row + 1) * board_size + col
                edge_index.append([curr_idx, bottom_neighbor])
    additional_node_idx = board_size * board_size
    for i in range(board_size * board_size):
        edge_index.append([additional_node_idx, i])
    edge_index = torch.tensor(edge_index).t().contiguous()
    return edge_index

def create_dataset(data):

    data_x, data_y = data
    data_x = np.array(data_x, dtype=int)
    data_x = np.append(data_x, np.full((data_x.shape[0], 1), 0), axis=1)
#     data_x = data_x.astype(int)
    pi, v = data_y
    data_y = [(pi[i], v[i]) for i in range(len(pi))]

    return [Data(x=g, y=label, edge_index=get_edge_index(g)) for g, label in zip(data_x, data_y)]

def create_GIN_input_data(state):

    state = np.array(state, dtype=int)
    state = np.append(state, 0)
#     state = state.astype(int)

    return Data(x=state, edge_index=get_edge_index(state))

def get_subgraph(data: Data):
    x = data.x
    board_size = int(np.sqrt(len(x) - 1))
    num_subgraphs = board_size // 2
    index_choice = [[(i, j) for i in range(k + 1) for j in range(k + 1)]\
                    for k in range(2)]
    additional_node_idx = board_size * board_size
    data = []
    for _ in range(num_subgraphs):
        while True:
            n = random.choice([board_size - 1, board_size - 2])
            if len(index_choice[board_size - n - 1]) != 0:
                break
        
        coo = random.choice(index_choice[board_size - n - 1])
        index_choice[board_size - n - 1].remove(coo)
        coo_x, coo_y = coo
        edge_index = []
        for row in range(coo_x, coo_x + n):
            for col in range(coo_y, coo_y + n):
                curr_idx = row * board_size + col
                if col >= coo_y and col < coo_y + n - 1:
                    right_neighbor = row * board_size + col + 1
                    edge_index.append([curr_idx, right_neighbor])
                if row >= coo_x and row < coo_x + n - 1:
                    bottom_neighbor = (row + 1) * board_size + col
                    edge_index.append([curr_idx, bottom_neighbor])
                edge_index.append([additional_node_idx, curr_idx])

        edge_index = torch.tensor(edge_index).t().contiguous()
        new_data = Data(x=x, edge_index=edge_index)
        data.append(new_data)
        
    return data

def train(model, optimizer, loader):
    model.train()

    total_loss = 0
    for data in loader:
        optimizer.zero_grad()
        # print(data.edge_index)
        out_pi, out_val = model(data)

        y1, y2 = torch.tensor(data.y[0]), torch.tensor(data.y[1])
        y2 = y2.to(torch.float32)

        loss = F.mse_loss(out_val, y2) + F.cross_entropy(out_pi, y1)
        total_loss += loss
        loss.backward()
        optimizer.step()
    return total_loss

In [5]:
class GIN(nn.Module):
    def __init__(self, hidden_dim, dropout_rate=0.5) -> None:
        super(GIN, self).__init__()
        self.hidden_dim = hidden_dim
        self.dropout_rate = dropout_rate
        self.conv_1 = GINConv(nn=nn.Sequential(
                nn.Linear(1, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim)
            ), train_eps=True
        )


        self.conv_2 = GINConv(nn=nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim)
            ), train_eps=True
        )

        self.conv_3 = GINConv(nn=nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim)
            ), train_eps=True
        )

        self.fc_1 = nn.Linear(3 * hidden_dim + 1, 2 * hidden_dim)
        self.fc_2 = nn.Linear(2 * hidden_dim, hidden_dim)

        self.fc_out = nn.Linear(hidden_dim, 1)

    def output_pi(self, x, mask):
        p = torch.mean(x, dim=-1)

        p[mask] = float('-inf')
        p = F.softmax(p, dim=0)
        p = p[:-1]
        return p

    def output_v(self, x):
        v = torch.mean(x)
        v = F.tanh(v)
        return v

    def forward(self, data: Data, subgraph=False):

        x, edge_index = data.x, data.edge_index
        # output_dim = len(x)
        mask = (x != 0)
        mask[-1] = True
        mask = torch.tensor(mask)

        # print("edge index:",edge_index)
        # print("old mask:", mask)
        if subgraph:

            mask_subgraph = torch.ones((len(x))).bool()
            for node in edge_index[0]:
                mask_subgraph[node] = False
            for node in edge_index[1]:
                mask_subgraph[node] = False

            # print("mask_subgraph:", mask_subgraph)
            mask += mask_subgraph
        # print("new mask", mask)
        
        x = torch.tensor(x)
        x = torch.unsqueeze(x, 1)

        x1 = self.conv_1(x, edge_index)
        x2 = self.conv_2(x1, edge_index)
        x3 = self.conv_3(x2, edge_index)

        # print("shape x3: ", x3.shape)


        x = torch.cat((x, x1, x2, x3), dim=-1)

        # print("shape x after cat: ", x.shape)

        x = F.relu(self.fc_1(x))

        # print("shape x after fc1: ", x.shape)
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.fc_2(x)

        # print("shape x after fc2: ", x.shape)

        x = self.fc_out(x)

        # print("shape x after fc_out: ", x.shape)

        out_policy = self.output_pi(x, mask)
        out_value = self.output_v(x)

        return out_policy, out_value

In [6]:

class Node:
    def __init__(self, game, args, state, parent=None, action_taken=None, p_value=0, q_value=0, n_value=0, init=True, is_root=False):
        self.game = game
        self.args = args
        if is_root: self.state = state
        else: self.state = None
        self.parent = parent
        self.action_taken = action_taken

        self.children = []
        if init == True: self.expandable_moves = game.get_valid_moves(state)
        else: self.expandable_moves = game.get_limited_valid_moves(state)
        if action_taken != None: self.expandable_moves[action_taken] = 0

        self.q_value = q_value          #Q value
        self.p_value = p_value          #probability to choose
        self.n_value = n_value          #number of visit
        self.u_value = 0



    def is_expanded_node(self):
        return np.sum(self.expandable_moves) == 0

    def have_child(self):
        return len(self.children) > 0

    def select(self, iter_state):

        best_child = self.children[0]
        best_ucb = self.get_ucb_of_child(self.children[0])

        # print(self.state)
        for child in self.children:
            ucb = self.get_ucb_of_child(child)
            # print(ucb)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb

        # print("Number of children:", len(self.children))
        # print(best_child.state)
        child_state = self.game.get_next_state(iter_state, best_child.action_taken, 1)
        child_state = self.game.change_perspective(iter_state, player=-1)
        return best_child, child_state

    def get_ucb_of_child(self, child):

        child.u_value = self.args['C'] * child.p_value * math.sqrt(self.n_value) / (child.n_value + 1)

        # print("q_value:", self.q_value)
        # print("u_value:", self.u_value)

        return child.q_value + child.u_value

    def expand(self, model, iter_state):
        
        #get p_s, v_s from model given the state
#         iter_state = create_GIN_input_data(iter_state)
        p1, v_s = model(create_GIN_input_data(iter_state))
        p1 = (p1.detach() * self.expandable_moves).numpy()
        p1 = p1 / np.sum(p1)

#         p1 = np.ones(self.game.action_size) / self.game.action_size
#         v_s = 0.5

        #get subgraphs output
        # subgraphs = get_subgraph(GIN_input)
        # all_p2 = torch.tensor([])
        # for subgraph in subgraphs:
        #     p, _ = model(subgraph, True)
        #     all_p2 = torch.cat((all_p2, p.unsqueeze(0)), dim = 0)

        # p2 = torch.mean(all_p2, dim = 0)
        ######

        for action in (np.where(self.expandable_moves == 1)[0]):
            self.expandable_moves[action] = 0

            # child_state = self.state.copy()
            # child_state = self.game.get_next_state(child_state, action, 1)
            # child_state = self.game.change_perspective(child_state, player=-1)

            # child = Node(self.game, self.args, child_state, self, action, (p1[action] + p2[action])/2, v_s, 0)
            child = Node(self.game, self.args, iter_state, self, action, p1[action], v_s, 0, False)

            self.children.append(child)
        
        del p1

    def backpropagate(self, value):
        self.q_value = (self.n_value * self.q_value + value) / (self.n_value + 1)
        self.n_value += 1

        value = self.game.get_opponent_value(value)
        if self.parent is not None:
            self.parent.backpropagate(value)


class MCTS:
    def __init__(self, game, model, args):
        self.game = game
        self.args = args
        self.model = model

    def search(self, state):
        
        is_init = self.game.is_initial_state(state)
        root = Node(self.game, self.args, state, None, None, 0, 0, 0, is_init, True)

        for search in trange(self.args['num_searches'], desc="Searching"):
            #garbage collection
#             if search % 100 ==0: gc.collect()
            ###

            node = root
            iter_state = root.state.copy()
            # print("iter state at root:", iter_state)

            # Selection
            while node.is_expanded_node() and node.have_child():
                node, iter_state = node.select(iter_state)
                # print("iter state after select:", iter_state)

            # Expansion
            value , is_terminal = self.game.get_value_and_terminated(iter_state, node.action_taken)

            if not is_terminal:

                # get value from model
                model_input = create_GIN_input_data(iter_state)
                _ , value = self.model(model_input)
                # expand
                node.expand(self.model, iter_state)


            # Backpropagation
            node.backpropagate(value)

        #evaluate time
        action_probs = np.zeros(self.game.action_size)

        # print("root's children:", len(root.children))
        for child in root.children:
            action_probs[child.action_taken] = child.n_value
        action_probs = action_probs / np.sum(action_probs)
        # print("action_probs:", action_probs.reshape(self.game.row_count, self.game.column_count))
        
        return action_probs

In [7]:
def MCTS_self_play(board_size, win_length, args):

    go_first = random.choice([-1, 1])
    first_move = random.randint(0, board_size * board_size - 1)
    player = go_first

    tictactoe = TicTacToe(board_size, win_length)
    state = tictactoe.get_initial_state()
    state = tictactoe.get_next_state(state, first_move, player)
    player = tictactoe.get_opponent(player)
    

    GIN_bot = GIN(512)
    mcts = MCTS(tictactoe, GIN_bot, args)

    all_state = []
    act_probs = []
    z_value = []

    while True:

        print_board(state)
        if player == 1:

            neutral_state = tictactoe.get_neutral_state(state, player)
            mcts_probs = mcts.search(neutral_state)
            action = np.argmax(mcts_probs)

            neutral_state = np.array(neutral_state).reshape(-1)

            all_state.append(neutral_state)
            act_probs.append(mcts_probs)
            z_value.append(player)

        else:

            neutral_state = tictactoe.get_neutral_state(state, player)
            mcts_probs = mcts.search(neutral_state)
            action = np.argmax(mcts_probs)

            neutral_state = np.array(neutral_state).reshape(-1)

            all_state.append(neutral_state)
            act_probs.append(mcts_probs)
            z_value.append(player)

        state = tictactoe.get_next_state(state, action, player)
        winner, is_terminal = tictactoe.get_value_and_terminated(state, action)

        if is_terminal:
            print_board(state)
            if winner == 1:
                # print(player, "won")
                final_z = player
            else:
                # print("draw")
                final_z = 0

            result_tuple = (all_state, (act_probs,  [i * final_z for i in z_value]))
            # print(result_tuple)
            break

        player = tictactoe.get_opponent(player)

    return result_tuple

In [8]:
def get_MCTS_self_play_examples(arguments, board_size, win_length):

    all_state = []
    act_probs = []
    z_value = []

    for i in range(3):
        result = MCTS_self_play(board_size, win_length, arguments)
        all_state += result[0]
        act_probs += result[1][0]
        z_value += result[1][1]

    bundled_results = (all_state, (act_probs,  z_value))
    # print("done")
    # print("results: ")
    # print(bundled_results)

    return bundled_results

In [9]:
from collections import deque
import random


def scalable_training(max_board_size, n_iter, history_size, args, model_link=None):
    history = deque(maxlen=history_size)
    model = GIN(hidden_dim = args['hidden_dim'])
    if model_link != None:
        model.load_state_dict(torch.load(model_link))
        print("Model " + model_link + " loaded!")
    optimizer = Adam(model.parameters(), lr=0.01, weight_decay=0)
    for i in range(n_iter):
        board_size = random.choices([max_board_size, max_board_size - 1, max_board_size - 2, max_board_size - 3],
                            weights=[0.4, 0.3, 0.2, 0.1])[0]
        print(f"Board in n_iter {i}: {board_size}x{board_size}")
        training_example = MCTS_self_play(board_size, 5, args)
        #đưa training_example về dạng Data
        training_example = create_dataset(training_example)

        print(f"#Move in n_iter {i}: {len(training_example)}")

        history.extend(training_example)
        shuffled_history = list(history)
        random.shuffle(shuffled_history)

        start_train = time.time()
        train_loss = train(model, optimizer, shuffled_history)
        print("interation train time:", time.time() - start_train)
        print(f"Loss in n_iter {i}: {train_loss}")

        torch.save(model.state_dict(), 'model_parameters_no_subgraphs.pth')

    return model

In [10]:
def print_board(state):

    clear_output()
    
    print("------------------------")
    print(end = "   ")
    for col in range (state[0].size):

        print("{0:>2}".format(col), end = "  ")
    print("")
    for row in range (state[0].size):
        print("{0:>2}".format(row), end="  ")
        for col in range (state[0].size):
            if state[row][col] == 1:
                print("\033[94mX\033[0m", end = "   ")
            elif state[row][col] == -1:
                print("\033[91mO\033[0m", end = "   ")
            else:
                print("-", end = "   ")
        print("")

In [11]:

def testModel(model, board_size, human_go_first=True):

    tictactoe = TicTacToe(board_size,5)
    if human_go_first: player = 1
    else: player = -1
    args = {
        'C': 1.41,
        'num_searches': 10000
    }

    #A good 8x8 model
    state = tictactoe.get_initial_state()
    GIN_bot = model
#     GIN_bot = GIN(512)
#     print(torch.load('model_parameters.pth'))

    while True:
        
        print_board(state)
        if player == 1:
            valid_moves = tictactoe.get_valid_moves(state)
            row_input, col_input = map(int, input("Enter ROW and COL separated by space:").split())
            action = row_input * tictactoe.column_count + col_input
            if valid_moves[action] == 0:
                print("action not valid")
                continue

          
        else:

            neutral_state = tictactoe.get_neutral_state(state, player)
            model_input = create_GIN_input_data(neutral_state)
            prob_s, v_s = GIN_bot(model_input)
            
            action = torch.argmax(prob_s)


        state = tictactoe.get_next_state(state, action, player)
        winner, is_terminal = tictactoe.get_value_and_terminated(state, action)

        if is_terminal:
            print_board(state)
            if winner == 1:
                print(player, "won")
            else:
                print("draw")
            break

        player = tictactoe.get_opponent(player)

In [12]:

max_board_size = 9
args = {
    'C': 1.41,
    'num_searches': 100,
    'hidden_dim': 512,
    'processes_num': 1,
}
model_link = '/kaggle/input/scalablealphazero/pytorch/train12hours/5/model_parameters_no_subgraphs (4).pth'
model = scalable_training(max_board_size, n_iter=400, history_size=512, args=args, model_link=model_link)

    

------------------------
    0   1   2   3   4   5   6   7  
 0  [94mX[0m   -   [91mO[0m   [91mO[0m   [94mX[0m   [94mX[0m   -   [94mX[0m   
 1  -   [94mX[0m   [91mO[0m   [94mX[0m   [94mX[0m   [91mO[0m   [91mO[0m   [91mO[0m   
 2  -   [94mX[0m   [91mO[0m   -   [94mX[0m   [91mO[0m   [94mX[0m   -   
 3  -   [94mX[0m   [91mO[0m   -   [91mO[0m   -   -   -   
 4  -   -   [91mO[0m   -   [91mO[0m   -   -   -   
 5  -   -   -   [94mX[0m   -   -   -   -   
 6  -   -   -   -   -   -   -   -   
 7  -   -   -   -   -   -   -   -   
#Move in n_iter 399: 23
interation train time: 41.97084951400757
Loss in n_iter 399: 2993.4781616986425
