In [1]:
from collections import deque
from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D, BatchNormalization, Activation, add
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
import numpy as np
import tensorflow as tf


class AlphaZeroConfig:
    """Hyper-parameters for self-play, MCTS and network training.

    Any setting can be overridden by keyword, e.g. ``AlphaZeroConfig(num_simulations=50)``.
    Unknown keywords raise ``TypeError`` so a typo cannot silently leave a default in place.
    """

    def __init__(self, **overrides):
        # Self-Play
        self.self_play_games = 25  # default game count for AlphaZero.train_ep

        self.num_sampling_moves = 8  # moves sampled from visit counts before playing greedily
        self.num_simulations = 25  # MCTS simulations per move

        # Root prior exploration noise.
        self.root_dirichlet_alpha = 1
        self.root_exploration_fraction = 0.25

        # UCB formula
        self.pb_c_base = 19652
        self.pb_c_init = 1.25

        # Training
        self.batches_per_iter = 5
        self.epochs_per_batch = 5
        self.learning_rate = 1e-3
        self.window_size = 500  # replay-buffer capacity, in games
        self.batch_size = 4096  # positions per sampled training batch

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f'AlphaZeroConfig got an unexpected setting {name!r}')
            setattr(self, name, value)


class ReplayBuffer:
    """Bounded FIFO of finished self-play games used to draw training batches.

    Args:
        config: optional object with ``batch_size`` and ``window_size`` attributes.
            Without it, 4096 positions per batch and a 500-game window are used.
    """

    def __init__(self, config=None):
        self.batch_size = 4096
        self.buffer = deque(maxlen=500)
        if config is not None:
            self.batch_size = config.batch_size
            self.buffer = deque(maxlen=config.window_size)

    def save_game(self, game):
        """Store a finished game; once the window is full the oldest game is dropped."""
        self.buffer.append(game)

    def sample_batch(self):
        """Draw ``batch_size`` training positions.

        Games are chosen with probability proportional to their number of moves, and a
        move index is then drawn uniformly, so every stored position is equally likely.
        Each sample is mirrored left-right (image and policy target) with probability
        0.5; the value target is left unchanged.

        Returns:
            ``[images, target_values, target_policies]`` as numpy arrays stacked on axis 0.

        Raises:
            ValueError: if the buffer holds no positions to sample.
        """
        games = list(self.buffer)
        lengths = np.array([len(g.history) for g in games], dtype=np.float64)
        move_sum = lengths.sum()
        if move_sum == 0:
            raise ValueError('cannot sample a batch: the replay buffer holds no positions')

        # Sample indices rather than passing the games themselves to np.random.choice,
        # which would first have to convert the Game objects into a NumPy array.
        chosen = np.random.choice(len(games), size=self.batch_size, p=lengths / move_sum)
        game_pos = [(games[k], np.random.randint(len(games[k].history))) for k in chosen]

        images, target_vs, target_ps = [], [], []
        for g, i in game_pos:
            image = g.make_image(i)
            target_v, target_p = g.make_target(i)
            target_v = np.array(target_v, dtype=np.float64)
            target_p = np.array(target_p)

            # Augment data with the left-right mirror of the position.
            if np.random.random() < 0.5:
                image = np.fliplr(image)
                target_p = np.flip(target_p)

            images.append(image)
            target_vs.append(target_v)
            target_ps.append(target_p)
        return [np.array(images), np.array(target_vs), np.array(target_ps)]


class ResNet:
    """Small AlphaZero-style residual network for the 6x7 Connect-Four board.

    The Keras model maps a (rows, columns, 4) image to two outputs, in this order:
    a tanh value and one linear policy logit per column.

    Args:
        weights: optional path handed to ``model.load_weights`` after building.
    """

    def __init__(self, weights=None):
        self.rows = 6
        self.columns = 7
        self.model = self.create_model()
        if weights:
            self.model.load_weights(weights)

    @staticmethod
    def res_block(inputs, filters, reg=0.01, bn_eps=2e-5):
        """Two 3x3 conv + batch-norm layers with an identity skip connection.

        ``inputs`` must already have ``filters`` channels for the ``add`` to be valid.
        """
        x = Conv2D(filters=int(filters), kernel_size=3, padding="same", kernel_regularizer=l2(reg))(inputs)
        x = BatchNormalization(epsilon=bn_eps)(x)
        x = Activation('relu')(x)
        x = Conv2D(filters=int(filters), kernel_size=3, padding="same", kernel_regularizer=l2(reg))(x)
        x = BatchNormalization(epsilon=bn_eps)(x)
        x = add([x, inputs])
        x = Activation('relu')(x)
        return x

    def policy_head(self, x, bn_eps=2e-5):
        """3x3 conv -> BN -> ReLU -> Dense giving one unnormalized logit per column."""
        x = Conv2D(32, kernel_size=3, padding='same')(x)
        x = BatchNormalization(epsilon=bn_eps)(x)
        x = Activation('relu')(x)
        x = Flatten()(x)
        x = Dense(self.columns, activation='linear', name='policy_head')(x)
        return x

    def value_head(self, x, bn_eps=2e-5):
        """1x1 conv -> BN -> ReLU -> Dense(256) -> tanh scalar value."""
        x = Conv2D(32, kernel_size=1, padding='same')(x)
        x = BatchNormalization(epsilon=bn_eps)(x)
        x = Activation('relu')(x)
        x = Flatten()(x)
        x = Dense(256, activation='relu')(x)
        x = Dense(1, activation='tanh', name='value_head')(x)
        return x

    def create_model(self, num_residual_blocks=2, reg=0.01, bn_eps=2e-5):
        """Build the Keras model with outputs ``[value, policy_logits]``.

        ``reg`` and ``bn_eps`` are forwarded to the stem, every residual block and
        (for ``bn_eps``) both heads.
        """
        inputs = Input(shape=(self.rows, self.columns, 4))
        x = Conv2D(256, kernel_size=3, padding='same', kernel_regularizer=l2(reg))(inputs)
        x = BatchNormalization(epsilon=bn_eps)(x)
        x = Activation('relu')(x)

        for _ in range(num_residual_blocks):
            x = ResNet.res_block(x, 256, reg=reg, bn_eps=bn_eps)

        p = self.policy_head(x, bn_eps=bn_eps)
        v = self.value_head(x, bn_eps=bn_eps)

        model = Model(inputs, [v, p])
        return model

    def inference(self, x):
        """Evaluate one image (or a batch) and return ``(value, policy_probs)``.

        Args:
            x: a single (rows, columns, 4) image or a (batch, rows, columns, 4) batch.

        Returns:
            ``value`` as a NumPy array of shape (batch, 1) and the softmax of the policy
            logits as a ``tf.Tensor`` of shape (batch, columns).
        """
        if len(x.shape) != 4:
            x = np.expand_dims(x, axis=0)
        # Call the model directly: Model.predict sets up a data pipeline on every call
        # (and, in recent Keras versions, prints a progress bar), which dominates when
        # MCTS evaluates a single position at a time.
        value, policy_logits = self.model(x, training=False)
        prob = tf.nn.softmax(policy_logits)
        return np.asarray(value), prob

class Game:
    """Connect-Four position stored as the list of columns played so far.

    The board is rebuilt from ``history`` on demand. Player 0 moves first and places
    ``1`` tokens; player 1 places ``-1`` tokens. Pieces fall towards the highest row
    index, so row 0 is the top of the board.
    """

    def __init__(self, history=None):
        self.history = history or []
        # One visit-count distribution (over every column) per stored search.
        self.child_visits = []
        self.rows = 6
        self.columns = 7
        self.num_actions = self.columns
        self.initial_state = np.zeros((self.rows, self.columns))

    def __repr__(self):
        return str(self.state())

    def state(self, state_index=None):
        """Return the board after the first ``state_index`` moves (all moves when None)."""
        board = self.initial_state.copy()
        moves = self.history if state_index is None else self.history[:state_index]
        for move in moves:
            # An odd number of pieces on the board means player 1 (-1) is to move.
            token = -1 if np.count_nonzero(board) % 2 == 1 else 1
            board[self.lowest_position(board, move)] = token
        return board

    @classmethod
    def lowest_position(cls, board, col):
        """Return ``(row, col)`` where a piece dropped in ``col`` lands, or False.

        False is returned when the column is full or ``col`` is out of range.
        """
        try:
            empty_rows = np.flatnonzero(board[:, col] == 0)
        except IndexError:
            return False
        if empty_rows.size == 0:
            return False
        # Index a single element: int() on a 1-element array is deprecated in NumPy.
        return int(empty_rows[-1]), col

    @property
    def terminal(self):
        """True once the game has been won or the board is full."""
        return self.terminal_value(self.to_play) != 0

    def terminal_value(self, to_play, board_state=None):
        """Score a board (the current one, or ``board_state``) for ``to_play``.

        Returns:
            ``+1`` if there is a winner and it is not ``to_play``, ``-1`` if the winner is
            ``to_play``, ``1e-10`` for a full board without a winner (non-zero so that
            ``terminal`` is True, but ~0 as a value), and ``0`` if the game is not over.
        """
        reward = 1
        state = self.state() if board_state is None else board_state

        winners = [self.check_win_vert(state), self.check_win_horiz(state), self.check_win_diag(state)]
        found = [w for w in winners if w]  # tokens are 1 / -1
        if found:
            winner = 0 if found[0] == 1 else 1
            return reward if winner != int(to_play) else -reward

        # Test for a draw only after ruling out a win, and on the board being scored,
        # so a board-filling winning move is not reported as a draw.
        if not np.any(state == 0):
            return 1e-10
        return 0

    @property
    def legal_actions(self):
        """Columns that still have room, as an int array (``[]`` when the board is full)."""
        board = self.state()  # rebuild once rather than once per column
        columns = [c for c in range(self.columns) if self.lowest_position(board, c)]
        return np.array(columns) if columns else []

    def clone(self):
        """Copy of the move history only; search statistics are not copied."""
        return Game(list(self.history))

    def apply(self, action):
        """Play ``action`` (a column); raises ValueError if the column is not legal."""
        if action not in self.legal_actions:
            raise ValueError('Illegal move')
        self.history.append(action)

    def store_search_statistics(self, root):
        """Record the root's normalized child visit counts as this move's policy target."""
        sum_visits = sum(child.visit_count for child in root.children.values())
        self.child_visits.append([
            root.children[a].visit_count / sum_visits if a in root.children else 0
            for a in range(self.num_actions)
        ])

    def make_image(self, state_index=None):
        """Network input for the position after ``state_index`` moves, shape (rows, cols, 4).

        Channels: raw board, current player's pieces, opponent's pieces, and a constant
        plane of +1 (player 0 to move) or -1 (player 1 to move).
        """
        state = self.state(state_index)
        to_play = len(self.history[:state_index]) % 2
        to_play_matrix = np.ones((self.rows, self.columns))
        if to_play != 0:
            to_play_matrix = -to_play_matrix
        curr_player_binary = np.array(state == to_play_matrix, dtype='float64')
        opp_player_binary = np.array(state == -to_play_matrix, dtype='float64')
        input_image = np.dstack([state, curr_player_binary, opp_player_binary, to_play_matrix])
        return input_image

    def make_target(self, state_index: int):
        """Return ``(value, policy)`` training targets for the position at ``state_index``.

        The value is the final result scored for the player to move at that position.
        """
        value = self.terminal_value(state_index % 2)
        return value, self.child_visits[state_index]

    @property
    def to_play(self):
        # returns 0 if player 1 else returns 1
        return len(self.history) % 2

    @staticmethod
    def _winner_in_line(line, win_con=4):
        """Token of the last run of ``win_con`` equal non-zero cells in ``line``, else None."""
        winner = None
        run_token, run_length = None, 0
        for cell in line:
            if run_length and cell == run_token:
                run_length += 1
            else:
                run_token, run_length = cell, 1
            if run_length == win_con and run_token != 0:
                winner = run_token
        return winner

    @classmethod
    def _last_winner(cls, lines):
        """Winner of the last line (in iteration order) that contains a four-in-a-row."""
        winner = None
        for line in lines:
            line_winner = cls._winner_in_line(line)
            if line_winner is not None:
                winner = line_winner
        return winner

    @classmethod
    def check_win_horiz(cls, board):
        """Token with a horizontal four-in-a-row, else None."""
        return cls._last_winner(board)

    @classmethod
    def check_win_vert(cls, board):
        """Token with a vertical four-in-a-row, else None."""
        return cls._last_winner(board.T)

    @classmethod
    def check_win_diag(cls, board):
        """Token with a four-in-a-row on either diagonal direction, else None."""
        offsets = range(-board.shape[0] + 1, board.shape[1])
        board_flip = np.fliplr(board)
        diags_lr = [board[::-1, :].diagonal(i) for i in offsets]
        diags_rl = [board_flip[::-1, :].diagonal(i) for i in offsets]
        return cls._last_winner(d for d in diags_lr + diags_rl if len(d) >= 4)


In [2]:
import numpy as np
import math
import tensorflow as tf

class Node:
    """A node of the MCTS search tree.

    Args:
        prior: prior probability of selecting this node from its parent.
    """

    def __init__(self, prior: float):
        self.prior = prior
        self.visit_count = 0
        self.value_sum = 0
        # Player to move at this node; stays -1 until the node is evaluated.
        self.to_play = -1
        # action -> child Node, populated when the node is expanded.
        self.children = {}

    @property
    def is_expanded(self):
        """True once at least one child has been attached."""
        return bool(self.children)

    @property
    def value(self):
        """Mean backed-up value, or 0 for a node that has never been visited."""
        return 0 if self.visit_count == 0 else self.value_sum / self.visit_count

    def __repr__(self):
        return (f'Node Value: {self.value}\n'
                f'Node Prior: {self.prior}\n'
                f'Node Visits : {self.visit_count}')


def run_mcts(config, game, network, add_noise=True):
    """Run ``config.num_simulations`` MCTS simulations from the position in ``game``.

    Returns:
        ``(action, root)``: the move chosen by ``select_action`` and the searched root.
    """
    root = Node(0)
    evaluate(root, game, network)
    if add_noise:
        add_exploration_noise(config, root)

    for _ in range(config.num_simulations):
        scratch = game.clone()
        current = root
        path = [current]
        # Descend through expanded nodes, replaying the moves on a scratch copy.
        while current.is_expanded:
            move, current = select_child(config, current)
            scratch.apply(move)
            path.append(current)
        leaf_value = evaluate(current, scratch, network)
        backpropagate(path, leaf_value, scratch.to_play)

    chosen = select_action(config, game, root)
    return chosen, root


def select_action(config, game, root):
    """Pick the move to play from the root's visit counts.

    During the first ``config.num_sampling_moves`` moves the action is sampled via
    ``softmax_sample``; afterwards the most-visited action is played (ties go to the
    larger action).
    """
    visit_counts = [(child.visit_count, action) for action, child in root.children.items()]
    sampling = len(game.history) < config.num_sampling_moves
    picked = softmax_sample(visit_counts) if sampling else max(visit_counts)
    return picked[1]


def select_child(config, node):
    """Return ``(action, child)`` with the highest UCB score (ties go to the larger action)."""
    scored = [(ucb_score(config, node, child), action, child)
              for action, child in node.children.items()]
    best_score, best_action, best_child = max(scored)
    return best_action, best_child


def ucb_score(config, parent, child):
    """pUCT score of ``child``: its mean value plus a prior-weighted exploration bonus.

    The bonus grows with the parent's visit count and shrinks as the child is visited.
    """
    base_term = math.log((parent.visit_count + config.pb_c_base + 1) /
                         config.pb_c_base) + config.pb_c_init
    visit_scale = math.sqrt(parent.visit_count) / (child.visit_count + 1)
    exploration = base_term * visit_scale
    return exploration * child.prior + child.value


def evaluate(node, game, network):
    """Expand ``node`` for the position in ``game`` and return the leaf value.

    Finished games are not expanded. Their exact result,
    ``game.terminal_value(game.to_play)``, is returned instead of the network estimate.
    Otherwise one child per legal column is created. Its prior is the network's policy
    probability, renormalized over the legal columns.

    ``network.inference`` is expected to return ``(value, policy probabilities)``;
    ResNet.inference already applies a softmax, so the probabilities are used directly
    rather than being exponentiated a second time.
    """
    node.to_play = game.to_play
    if game.terminal:
        # A won board still has legal columns, so it must be stopped here explicitly.
        # terminal_value uses the same perspective as the value targets that
        # Game.make_target builds for training.
        return game.terminal_value(game.to_play)

    value, policy_probs = network.inference(game.make_image())
    value, policy_probs = np.squeeze(value), np.squeeze(policy_probs)
    policy = {a: float(policy_probs[a]) for a in game.legal_actions}
    policy_sum = sum(policy.values())
    for action, p in policy.items():
        # Fall back to a uniform prior if every legal probability underflowed to 0.
        node.children[action] = Node(p / policy_sum if policy_sum > 0 else 1.0 / len(policy))
    return value


def backpropagate(search_path: list, value: float, to_play):
    """Add ``value`` to every node on the path, negated for nodes whose player differs.

    ``to_play`` is the player to move at the evaluated leaf; every node's visit count is
    incremented.
    """
    for visited in search_path:
        signed_value = value if visited.to_play == to_play else -value
        visited.value_sum += signed_value
        visited.visit_count += 1


def add_exploration_noise(config, node: "Node"):
    """Mix Dirichlet noise into the priors of ``node``'s children (used at the root).

    ``prior <- (1 - frac) * prior + frac * noise``, with ``noise ~ Dir(alpha, ..., alpha)``
    over the children. Because the noise sums to 1, the mixed priors stay a probability
    distribution. Raw gamma draws, by contrast, have mean ``alpha`` each and swamp the
    network's priors.
    """
    actions = list(node.children.keys())
    if not actions:
        return
    noise = np.random.dirichlet([config.root_dirichlet_alpha] * len(actions))
    frac = config.root_exploration_fraction
    for a, n in zip(actions, noise):
        node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac


def softmax_sample(visit_counts):
    """Sample an action with probability proportional to its visit count.

    This is AlphaZero's temperature-1 move selection. A softmax over raw visit counts
    (e.g. 10 vs 5 visits) puts almost all of the mass on the most-visited move, which
    defeats the purpose of sampling. The function name is kept for callers.

    Args:
        visit_counts: non-empty sequence of ``(visit_count, action)`` pairs.

    Returns:
        ``(None, action)``.
    """
    visits, actions = zip(*visit_counts)
    visits = np.array(visits, dtype=np.float64)
    total = visits.sum()
    if total > 0:
        prob = visits / total
    else:
        # No visits at all: fall back to a uniform choice.
        prob = np.full(len(visits), 1.0 / len(visits))
    idx = np.random.choice(len(actions), p=prob)
    return None, actions[idx]

In [3]:
import copy
import random



class SmartRandomAgent:
    '''Mostly random Connect-Four opponent.

    It takes an immediate winning move when one exists. Otherwise it blocks the
    opponent's immediate winning move. Otherwise it plays a random legal column.
    With ``dumb=True`` it always plays a random legal column.
    '''

    def __init__(self, dumb=False):
        self.game = None
        self.dumb = dumb

    def _winning_column(self, token, to_play):
        """Last legal column where dropping ``token`` completes a four-in-a-row, else None."""
        move_col = None
        board = self.game.state()
        for col in self.game.legal_actions:
            trial = board.copy()
            trial[self.game.lowest_position(trial, col)] = token
            # |value| == 1 only when the trial board has a winner; 0 (and a tiny draw
            # value) means no four-in-a-row.
            if abs(self.game.terminal_value(to_play, trial)) == 1:
                move_col = col
        return move_col

    @property
    def self_winning_move(self):
        """Column that wins immediately for the player to move, or None."""
        token = 1 if self.game.to_play == 0 else -1
        return self._winning_column(token, self.game.to_play)

    @property
    def opp_winning_move(self):
        """Column where the opponent would win next turn (the one to block), or None."""
        token = 1 if self.game.to_play == 0 else -1
        return self._winning_column(-token, -self.game.to_play)

    def move(self, game):
        """Apply this agent's move to ``game`` in place and return it."""
        if self.dumb:
            game.apply(random.choice(game.legal_actions))
            return game
        self.game = game.clone()
        # Winning beats blocking: check our own immediate win first.
        for candidate in (self.self_winning_move, self.opp_winning_move):
            if candidate is not None:
                game.apply(candidate)
                return game
        game.apply(random.choice(self.game.legal_actions))
        return game
class Pit:
    """Plays an agent against an opponent and tallies the wins.

    ``int(num_sims / 2)`` games are played with the AI moving first, then the same
    number with it moving second. Both agents must expose ``move(game) -> game``.
    Draws are not counted for either side.
    """

    def __init__(self, agentZero, SmartRandomAgent, num_sims):
        self.AI = agentZero
        self.opp = SmartRandomAgent
        self.num_sims = num_sims
        self.AI_wins = 0
        self.opp_wins = 0

    @staticmethod
    def _has_winner(game):
        """True if the board holds a four-in-a-row (a full board alone is a draw)."""
        board = game.state()
        checks = (game.check_win_vert, game.check_win_horiz, game.check_win_diag)
        return any(check(board) is not None for check in checks)

    def _play_game(self, first, second):
        """Play one game; return 0 if ``first`` wins, 1 if ``second`` wins, None on a draw."""
        game = Game()
        players = (first, second)
        turn = 0
        while not game.terminal:
            game = players[turn].move(game)
            if game.terminal:
                # The side that just moved won iff a line exists; this also catches a
                # winning move that fills the board.
                return turn if self._has_winner(game) else None
            turn = 1 - turn
        return None

    def _record(self, winner_seat, ai_seat):
        """Credit the win for ``winner_seat`` (None = draw) given the AI's seat."""
        if winner_seat is None:
            return
        if winner_seat == ai_seat:
            self.AI_wins += 1
        else:
            self.opp_wins += 1

    def simulate(self):
        """Play all games and add the results to ``AI_wins`` / ``opp_wins``."""
        games_per_side = int(self.num_sims / 2)
        # AI as player 1
        for _ in range(games_per_side):
            self._record(self._play_game(self.AI, self.opp), ai_seat=0)
        # AI as player 2
        for _ in range(games_per_side):
            self._record(self._play_game(self.opp, self.AI), ai_seat=1)

    @property
    def winrate(self):
        """AI wins as a percentage of ``num_sims`` (0.0 when no games were requested)."""
        if not self.num_sims:
            return 0.0
        return self.AI_wins / self.num_sims * 100

    @property
    def winrate_opp(self):
        """Opponent wins as a percentage of ``num_sims`` (0.0 when no games were requested)."""
        if not self.num_sims:
            return 0.0
        return self.opp_wins / self.num_sims * 100

    def __repr__(self):
        return f'AI Wins : {self.AI_wins}'


class AgentZeroCompetitive:
    """Greedy (no-exploration) player driven by the network, optionally through MCTS.

    Args:
        config: AlphaZeroConfig-like object. A shallow copy is kept, with move sampling
            and root noise disabled, so the caller's config is left untouched.
        net: network wrapper exposing ``inference(image_batch) -> (value, probs)``.
        mcts: if True, moves come from ``run_mcts``; otherwise from the raw policy.
    """

    def __init__(self, config, net=None, mcts=False):
        self.config = copy.copy(config)
        self.net = net
        self.config.num_sampling_moves = 0
        self.config.root_exploration_fraction = 0
        self.mcts = mcts

    def move(self, game, eval=False):
        """Apply this agent's move to ``game`` in place and return it.

        ``eval`` is unused and kept only for call compatibility.
        """
        if self.mcts:
            action, root = run_mcts(self.config, game, self.net)
            game.store_search_statistics(root)
        else:
            img = game.make_image()
            val, prob = self.net.inference(np.expand_dims(img, axis=0))
            legal = {int(a) for a in game.legal_actions}
            # Mask illegal columns and renormalize; must be an array, since dividing a
            # list in place raises TypeError.
            prob = np.array([p if idx in legal else 0 for idx, p in enumerate(np.squeeze(prob))],
                            dtype=np.float64)
            total = prob.sum()
            if total > 0:
                prob /= total
            else:
                prob = np.array([1.0 if idx in legal else 0.0 for idx in range(len(prob))])
                prob /= prob.sum()
            print(f'{val=}')
            print(prob)

            action = np.argmax(prob)
        game.apply(action)
        return game


In [4]:
# Replay buffers pickled by earlier runs (presumably ReplayBuffer instances — they expose `.buffer`).
# NOTE(review): absolute, machine-specific paths; move them into a configurable DATA_DIR so the
# notebook runs on other machines.
bpath = r'/Users/timwu/models/AlphaZeroResNet/buffer.pkl'
bpath2 = r'/Users/timwu/Desktop/buffer_before_epsilon.pkl'

In [5]:
# NOTE(review): pickle.load can execute arbitrary code; only load files you produced yourself.
# Unpickling needs the classes from the first cell to be defined beforehand.
import pickle
with open(bpath, 'rb') as f:
    b1 = pickle.load(f)
with open(bpath2, 'rb') as f:
    b2 = pickle.load(f)

In [6]:
# Number of stored games in the first buffer.
len(b1.buffer)

3689

In [7]:
# Number of stored games in the second buffer.
len(b2.buffer)

2000

In [8]:
# Merge both buffers into a plain list. Concatenating the deques directly returns a deque
# that inherits b1.buffer's maxlen, which would silently drop the oldest games whenever
# the combined size exceeds that window.
lst = list(b1.buffer) + list(b2.buffer)

In [14]:
import random
# NOTE(review): `idx` is not used in any later visible cell, so `lst` is NOT shuffled — use
# `lst[idx]` if a shuffled order was intended. `random` is not used in this cell either.
# The permutation shown below as this cell's output is stale: the cell now ends in an assignment.
idx = np.random.permutation(len(lst))
# Presumably an object array of the merged Game objects.
lst = np.array(lst)

array([  18, 5662, 3559, ..., 4568, 5203, 5554])

In [351]:
class AlphaZero:
    """Self-play and training loop tying together MCTS, a ReplayBuffer and the network.

    Args:
        config: AlphaZeroConfig-like object with self-play, MCTS and training settings.
        replay_buffer: stores self-play games and samples training batches.
        net: wrapper exposing ``.model`` (assumed to be the Keras model built by
            ResNet.create_model, outputs ``[value, policy_logits]``) and ``inference``.
    """

    def __init__(self, config, replay_buffer: ReplayBuffer, net):
        self.config = config
        self.replay_buffer = replay_buffer
        self.net = net
        self.learning_rate = self.config.learning_rate
        self.total_epochs = 0  # epochs fitted so far across all update_weights calls

    def run_selfplay(self, verbose=False):
        """Play one game against itself with MCTS, store it in the buffer and return it."""
        game = Game()
        while not game.terminal:
            action, root = run_mcts(self.config, game, self.net)
            game.apply(action)
            game.store_search_statistics(root)
        if verbose:
            print(game)
        self.replay_buffer.save_game(game)
        return game

    def train_network(self):
        """Fit the network on ``config.batches_per_iter`` freshly sampled batches."""
        for i in range(self.config.batches_per_iter):
            batch = self.replay_buffer.sample_batch()
            self.update_weights(self.net.model, batch)

    def update_weights(self, network, batch):
        """Fit ``network`` on one ``[images, target_v, target_p]`` batch.

        The model's outputs are ordered ``[value, policy]`` (``Model(inputs, [v, p])`` in
        ResNet), so the loss list follows the same order: MSE for the tanh value head and
        softmax cross-entropy on the linear policy logits.
        """
        # Recompiled on every call so the latest self.learning_rate is used.
        # NOTE(review): this also builds a fresh Adam optimizer, so its moment estimates
        # are discarded between batches.
        network.compile(loss=['mean_squared_error', tf.nn.softmax_cross_entropy_with_logits],
                        optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate))

        images, target_v, target_p = batch

        network.fit(x=images, y=[target_v, target_p], epochs=self.config.epochs_per_batch)
        self.total_epochs += self.config.epochs_per_batch

    def train_ep(self, num_self_play_games=None):
        """Generate self-play games (default ``config.self_play_games``), then train.

        Every 10th game (starting with the first) is printed.
        """
        if num_self_play_games is None:
            num_self_play_games = self.config.self_play_games
        for i in range(num_self_play_games):
            print(f'SelfPlay Game: {i + 1}')
            if i % 10 == 0:
                self.run_selfplay(verbose=True)
            else:
                self.run_selfplay()  # self play games added to replay buffer

        self.train_network()


In [356]:
# Evaluate the network (MCTS-driven, no exploration) against an opponent playing random legal columns.
# NOTE(review): `net` is not defined in any visible cell (hidden kernel state) — build it explicitly,
# e.g. ResNet(weights=...), before running this cell.
agent = AgentZeroCompetitive(AlphaZeroConfig(),
                             net,
                             mcts=True)

In [357]:
# 20 games: 10 with the AI moving first, 10 with it moving second.
opp = SmartRandomAgent(dumb=True)
pit = Pit(agent, opp, num_sims=20)
pit.simulate()

In [354]:
# NOTE(review): execution counts are out of order (354 ran before 357/358), so this output
# belongs to an earlier `pit` object — re-run the notebook top to bottom.
pit

AI Wins : 12

In [358]:
pit

AI Wins : 11

In [325]:
# Scratch check: the upper bound is exclusive, so this always returns 0.
np.random.randint(1)

0

In [270]:
# NOTE(review): `g` is not defined in any visible cell (hidden kernel state) — presumably a
# Game from a buffer or self-play; define it explicitly before running.
g.history[:28]

[4,
 4,
 4,
 5,
 2,
 2,
 2,
 6,
 6,
 6,
 6,
 6,
 6,
 5,
 5,
 5,
 5,
 5,
 4,
 4,
 4,
 3,
 3,
 3,
 3,
 3,
 3,
 2]

In [275]:
# Channel 0 of the network input is the raw board (+1 / -1 tokens, 0 = empty) after 28 moves.
# NOTE(review): relies on the undefined `g` from the cell above.
g.make_image(28)[:,:,0]

array([[ 0.,  0.,  0.,  1.,  1., -1.,  1.],
       [ 0.,  0.,  0., -1., -1.,  1., -1.],
       [ 0.,  0., -1.,  1.,  1., -1.,  1.],
       [ 0.,  0.,  1., -1.,  1.,  1., -1.],
       [ 0.,  0., -1.,  1., -1., -1.,  1.],
       [ 0.,  0.,  1., -1.,  1., -1., -1.]])

In [17]:
import tensorflow as tf
import numpy as np
from tensorflow_addons.optimizers.weight_decay_optimizers import DecoupledWeightDecayExtension

In [23]:
class NadamW(DecoupledWeightDecayExtension, tf.keras.optimizers.Nadam):
    """Nadam with decoupled weight decay via the tensorflow_addons extension mixin.

    Args:
        lr: learning rate for Nadam (float, tensor, schedule or zero-argument callable).
        weight_decay: decoupled weight-decay factor (float, tensor or zero-argument callable).
    """

    def __init__(self, lr, weight_decay, *args, **kwargs):
        # DecoupledWeightDecayExtension.__init__ takes weight_decay as its first positional
        # argument and forwards keyword arguments to Nadam, so the learning rate must go by
        # keyword; passing (lr, weight_decay) positionally bound lr to weight_decay.
        super(NadamW, self).__init__(weight_decay, *args, learning_rate=lr, **kwargs)

In [24]:
step = tf.Variable(0, trainable=False)
# Multiplier: 1.0 until step 10000, 0.1 until step 15000, then 0.01.
schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
    [10000, 15000], [1e-0, 1e-1, 1e-2])
# lr and wd can be a function or a tensor
# NOTE(review): in eager mode `lr` is evaluated once here (a fixed tensor), while `wd` is a
# callable re-evaluated on use; also `step` is never incremented in any visible cell, so the
# schedule stays at its first value either way.
lr = 1e-1 * schedule(step)
wd = lambda: 1e-4 * schedule(step)

In [20]:
# NOTE(review): absolute, machine-specific path to a training log.
path = r'/Users/timwu/Desktop/traininglog'

In [40]:
with open(path,'r') as f:
    x = f.read()

In [41]:
# NOTE(review): `x` is a str here, so this comprehension iterates over single characters and
# yields one-element lists; probably `x.split('\nloss=')` was intended. `x` is also reused for
# different kinds of values across cells.
x = [i.split('\nloss=') for i in x]

In [48]:
# TODO(review): this loop was left unfinished (an `if` with no condition), which is a
# SyntaxError that breaks Restart & Run All. Complete the filter before re-enabling it.
# for i in x:
#     if ...

IndexError: list index out of range

In [26]:
# NOTE(review): `i` is a leftover variable from an earlier run (hidden state), and this builds a
# one-element list, so `x[6]` below always raises IndexError. The shown output is stale.
x = [i.split(' - ')]

'2.2278 - value_head_'

In [29]:
x[6]

