In [None]:

from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import Model, load_model, clone_model
from numpy import power, arctanh, array, float32
import numpy as np
from time import time
from random import sample, choice, random
from numpy import array, float32, rot90

from math import ceil
from numpy import flip


In [None]:


# Written to channel 1 of every state cell that lies outside the board.
WALL = 1.0
# Written to all three channels of the cell holding the observing snake's head.
MY_HEAD = -1.0
# multipliers
# Food channel: scales (101 - health) of the observing snake.
HEALTH_m = 0.01
# Body channel: scales a segment's distance from the tail (the tail itself is 1).
SNAKE_m = 0.02
# Head channel: scales (that snake's length - observer's length + 0.5).
HEAD_m = 0.04

class Game:
    
    def __init__(self, ID, height = 11, width = 11, snake_cnt = 4, health_dec = 1, food_spawn_chance = 0.15):
        """Set up a fresh game: snakes on the standard start cells plus initial food.

        Args:
            ID: identifier stored as self.id
            height: number of board rows
            width: number of board columns
            snake_cnt: number of snakes placed on the board
            health_dec: health every snake loses per turn
            food_spawn_chance: per-turn chance of spawning one extra food
        """
        self.id = ID
        self.height, self.width = height, width
        self.snake_cnt = snake_cnt
        self.health_dec = health_dec
        self.food_spawn_chance = food_spawn_chance
        # one reward slot per snake, filled in once the snake dies or wins
        self.rewards = [None for _ in range(snake_cnt)]

        # standard starting board positions (in order) for 7x7, 11x11, and 19x19
        # battlesnake uses random positions for any non-standard board size
        # https://github.com/BattlesnakeOfficial/engine/blob/master/rules/create.go
        near = 1
        far_y, far_x = height - 2, width - 2
        mid_y, mid_x = height//2, width//2
        start_cells = ((near, near), (far_y, far_x),
                       (far_y, near), (near, far_x),
                       (near, mid_x), (mid_y, far_x),
                       (far_y, mid_x), (mid_y, near))
        positions = sample(start_cells, snake_cnt)
        # every snake starts with a random "previous" direction
        self.last_moves = {}
        for snake_id in range(snake_cnt):
            self.last_moves[snake_id] = choice((0, 1, 2, 3))

        # all cells start out empty; food spawning draws from this set
        self.empty_positions = {(y, x) for y in range(height) for x in range(width)}

        # place snakes: all three segments stacked on the start cell
        self.snakes = [Snake(snake_id, 100, [start] * 3) for snake_id, start in enumerate(positions)]
        for snake in self.snakes:
            self.empty_positions.remove(snake.head.position)

        # place food: one at the centre plus one diagonal neighbour of each head
        self.food = {(mid_y, mid_x)}
        for snake in self.snakes:
            head_y, head_x = snake.head.position
            diagonals = [(head_y + dy, head_x + dx) for dy in (-1, 1) for dx in (-1, 1)]
            self.food.add(choice(diagonals))
        for food in self.food:
            self.empty_positions.remove(food)

        # head and body lookups are kept as separate structures to reduce run time
        self.heads = {}
        self.bodies = set()
        for snake in self.snakes:
            self.heads[snake.head.position] = {snake}
            self.bodies.update(snake)

        # log
        self.wall_collision = self.body_collision = self.head_collision = 0
        self.starvation = self.food_eaten = self.game_length = 0

    """
    Get states for all snakes
    Return:
        a list of states for each snake
    """
    def get_states(self):
        return [self.make_state(snake, self.last_moves[snake.id]) for snake in self.snakes]
    
    """
    Get game and snake ids
    Return:
        a list of pairs (this game's id, snake's id)
    """
    def get_ids(self):
        return [(self.id, snake.id) for snake in self.snakes]
    
    """
    Move the game to the next turn
    Args:
        moves: a list of moves the snakes will make
        show: whether to draw the game board in "replay.rep"
    Return:
        0 if the game continues or the rewards list if the game ends
    """
    def tic(self, moves, show = False):
        """Advance the game by one turn.

        Args:
            moves: one move per snake, in the order of self.snakes; each is
                combined with that snake's previous direction as
                (move + last_move - 1) % 4
            show: if true, draw() is called both before and after dead
                snakes are removed

        Returns:
            0 while more than one snake is alive; otherwise self.rewards,
            where removed snakes hold -1.0 and a sole survivor holds 1.0
        """
        snakes = self.snakes
        # execute moves
        for i in range(len(snakes)):
            snake = snakes[i]
            # the move is relative to the previous direction:
            # 0 turns one way, 1 keeps the direction, 2 turns the other way
            move = (moves[i] + self.last_moves[snake.id] - 1) % 4
            self.last_moves[snake.id] = move
            # make the move
            new_head, old_head, tail = snake.move(move)
            # update board sets
            try:
                # several heads might come to the same cell
                self.heads[new_head].add(snake)
            except KeyError:
                self.heads[new_head] = {snake}
                # if it goes into an empty cell
                if new_head in self.empty_positions:
                    self.empty_positions.remove(new_head)
            if len(self.heads[old_head]) == 1:
                del self.heads[old_head]
            else:
                self.heads[old_head].remove(snake)
            # the previous head cell is now a body cell
            self.bodies.add(old_head)
            # tail is None when the removed segment was stacked on the new tail
            if tail:
                self.bodies.remove(tail)
                # no one enters this cell
                if tail not in self.heads:
                    self.empty_positions.add(tail)
        
        # reduce health
        for snake in snakes:
            snake.health -= self.health_dec
        
        # check for food eaten
        for snake in snakes:
            if snake.head.position in self.food:
                food = snake.head.position
                self.food.remove(food)
                snake.health = 100
                snake.grow()
                self.food_eaten += 1
        
        # spawn food
        if self.food_spawn_chance > 0.0:
            # always spawn when the board has no food left
            if len(self.food) == 0 or random() <= self.food_spawn_chance:
                try:
                    food = choice(tuple(self.empty_positions))
                    self.food.add(food)
                    self.empty_positions.remove(food)
                except IndexError:
                    # Cannot choose from an empty set
                    pass
        
        if show:
            self.draw()
        
        # remove dead snakes
        kills = set()
        for snake in snakes:
            head = snake.head.position
            # check for wall collisions
            if head[0] < 0 or head[0] >= self.height or head[1] < 0 or head[1] >= self.width:
                kills.add(snake)
                self.wall_collision += 1
            # check for body collisions
            elif head in self.bodies:
                kills.add(snake)
                self.body_collision += 1
            # check for head on collisions
            # NOTE(review): this is an elif chain, so a snake that shares its
            # head cell and survives is not checked for starvation this turn
            # - confirm this is intended
            elif len(self.heads[head]) > 1:
                for s in self.heads[head]:
                    # dies against any other snake of equal or greater length
                    if snake.length <= s.length and s != snake:
                        kills.add(snake)
                        self.head_collision += 1
                        break
            # check for starvation
            elif snake.health <= 0:
                kills.add(snake)
                self.starvation += 1
        # remove from snakes set
        for snake in kills:
            # update board sets
            head = snake.head.position
            if len(self.heads[head]) == 1:
                del self.heads[head]
                # it might die due to starvation or equal-length head on collision
                # only in those two cases, the head position should become an empty space
                # not out of bound and not into a body and not into a food
                if head[0] >= 0 and head[0] < self.height and head[1] >= 0 and head[1] < self.width:
                    # head is in range
                    if head not in self.bodies and head not in self.food:
                        self.empty_positions.add(head)
            else:
                self.heads[head].remove(snake)
            for body in snake:
                # it is possible that a snake has eaten on its first move and then die on its second move
                # in that case the snake will have a repeated tail
                # removing it from bodies twice causes an error
                # tried to debug this one for 5 hours and finally got it
                try:
                    self.bodies.remove(body)
                    self.empty_positions.add(body)
                except KeyError:
                    pass
            snakes.remove(snake)
            self.rewards[snake.id] = -1.0
        
        if show:
            self.draw()
        
        self.game_length += 1
        # return rewards if the game ends
        if len(snakes) <= 1:
            if snakes:
                self.rewards[snakes[0].id] = 1.0
            return self.rewards
        # return 0 if the game continues
        else:
            return 0
    
    """
    Process the game data and translate them into a game state
    Args:
        you: a Snake object that represents this snake
        last_move: the last move you made; one of {0, 1, 2, 3}
    Return:
        a grid that represents the game state for a snake
    """
    def make_state(self, you, last_move):
        # gotta do the math to recenter the grid
        width = self.width * 2 - 1
        height = self.height * 2 - 1
        grid = [[[0.0, WALL, 0.0] for col in range(width)] for row in range(height)]
        center_y = height//2
        center_x = width//2
        # the original game board
        # it's easier to work on the original board then transfer it onto the grid
        board = [[[0.0, 0.0, 0.0] for col in range(self.width)] for row in range(self.height)]
        
        # positions are (y, x) not (x, y)
        # because you read the grid row by row, i.e. (row number, column number)
        # otherwise the board is transposed
        length_minus_half = you.length - 0.5
        for snake in self.snakes:
            # get the head
            board[snake.head.position[0]][snake.head.position[1]][0] = (snake.length - length_minus_half)*HEAD_m
            # get the body
            # the head is also counted as a body for the making of the state because it will be a body next turn
            # going backwards because there could be a repeated tail when snake eats food
            body = snake.tail
            dist = 1
            while body:
                board[body.position[0]][body.position[1]][1] = dist*SNAKE_m
                body = body.prev_node
                dist += 1
        
        for food in self.food:
            board[food[0]][food[1]][2] = (101 - you.health)*HEALTH_m
        
        # from this point, all positions are measured relative to our head
        head_y, head_x = you.head.position
        board[head_y][head_x] = [MY_HEAD]*3
        for y in range(self.height):
            for x in range(self.width):
                grid[y - head_y + center_y][x - head_x + center_x] = board[y][x]
        
        # k = 0 => identity
        # k = 1 => rotate left
        # k = 2 => rotate 180
        # k = 3 => rotate right
        return rot90(array(grid, dtype = float32), k = last_move)

    """
    Copy the game
    Args:
        subgame_id: the game id of the created copy
    Return:
        a deep copy of the game
    """
    def subgame(self, subgame_id):
        """Clone this game under a new id, with food spawning switched off.

        Args:
            subgame_id: the game id of the created copy

        Returns:
            a Game whose snakes, food, board sets, last moves and rewards are
            copies of this game's
        """
        # a spawn chance of 0.0 means subgames never spawn food
        clone = Game(subgame_id, self.height, self.width, self.snake_cnt, self.health_dec, 0.0)
        clone.last_moves = {}
        for snake_id in range(self.snake_cnt):
            clone.last_moves[snake_id] = self.last_moves[snake_id]
        clone.empty_positions = {(y, x) for y, x in self.empty_positions}
        clone.snakes = [snake.copy() for snake in self.snakes]
        clone.food = {(y, x) for y, x in self.food}
        # rebuild the board sets from the copied snakes
        clone.heads = {}
        clone.bodies = set()
        for snake in clone.snakes:
            clone.heads[snake.head.position] = {snake}
            clone.bodies.update(snake)
        clone.rewards = list(self.rewards)
        return clone

    """
    Draw the game board in "replay.rep"
    """
    def draw(self):
        """Append a text rendering of the current board to "replay.rep".

        Cell values: 0 is empty, 9 is food, -(id + 1) is a snake's head and
        id + 1 is one of its body segments. One row is written per line,
        followed by a blank line.
        """
        board = [[0] * self.width for _ in range(self.height)]

        for food in self.food:
            board[food[0]][food[1]] = 9

        # draw heads in order of increasing length, so that on a shared cell
        # the longest snake's head is the one left visible
        for snake in sorted(self.snakes, key = lambda s: s.length):
            # head might go out of bound
            head_y, head_x = snake.head.position
            if 0 <= head_y < self.height and 0 <= head_x < self.width:
                board[head_y][head_x] = -(snake.id + 1)
        for snake in self.snakes:
            for body in snake:
                board[body[0]][body[1]] = snake.id + 1

        # the context manager closes the file even if a write fails
        with open("replay.rep", 'a') as f:
            for row in board:
                f.write(str(row) + '\n')
            f.write('\n')

class Snake:
    """A snake stored as a doubly linked list of Nodes running from head to tail."""

    def __init__(self, ID, health, head_and_body):
        """Build the linked list from a head-first sequence of (y, x) positions."""
        self.id = ID
        self.health = health
        self.length = len(head_and_body)
        self.head = Node(head_and_body[0])
        self.tail = self.head
        for position in head_and_body[1:]:
            segment = Node(position)
            segment.prev_node = self.tail
            self.tail.next_node = segment
            self.tail = segment

    # iterating a snake yields its body positions, skipping the head
    def __iter__(self):
        self.curr = self.head.next_node
        return self

    def __next__(self):
        node = self.curr
        if not node:
            raise StopIteration
        self.curr = node.next_node
        return node.position

    def move(self, direction):
        """Step one cell: 0 is up, 1 right, 2 down, anything else left.

        Returns:
            (new head position, old head position, removed tail position);
            the last item is None when the removed segment was stacked on
            the segment that is now the tail
        """
        if direction == 0:
            dy, dx = -1, 0
        elif direction == 1:
            dy, dx = 0, 1
        elif direction == 2:
            dy, dx = 1, 0
        else:
            dy, dx = 0, -1

        old_head = self.head
        old_tail = self.tail
        head_y, head_x = old_head.position

        # push a new head onto the front of the list
        new_head = Node((head_y + dy, head_x + dx))
        new_head.next_node = old_head
        old_head.prev_node = new_head
        self.head = new_head

        # drop the last segment
        self.tail = old_tail.prev_node
        self.tail.next_node = None

        # the returned triple tells the Game how to update the board sets;
        # a tail lying on top of another segment must not be cleared
        removed = old_tail.position
        if removed == self.tail.position:
            old_tail.position = None
            removed = None

        return (new_head.position, old_head.position, removed)

    def grow(self):
        """Lengthen the snake by one segment stacked on the current tail."""
        extra = Node(self.tail.position)
        extra.prev_node = self.tail
        self.tail.next_node = extra
        self.tail = extra
        self.length += 1

    def copy(self):
        """Return a new Snake with the same id, health, length and positions."""
        twin = Snake(self.id, self.health, [(0, 0)])
        twin.length = self.length
        # replace the placeholder node with a copy of the real list
        twin.head = Node(self.head.position)
        twin.tail = twin.head
        source = self.head.next_node
        while source:
            segment = Node(source.position)
            segment.prev_node = twin.tail
            twin.tail.next_node = segment
            twin.tail = segment
            source = source.next_node
        return twin

class Node:
    """One segment of a snake: a (y, x) position plus links to both neighbours."""

    def __init__(self, yx):
        # both links start unset; the owning Snake wires them up
        self.prev_node = self.next_node = None
        self.position = yx

In [None]:


class MPGameRunner:
    
    def __init__(self, height = 11, width = 11, snake_cnt = 4, health_dec = 1, game_cnt = 1):
        self.height = height
        self.width = width
        self.snake_cnt = snake_cnt
        self.health_dec = health_dec
        self.game_cnt = game_cnt
        self.games = {ID: Game(ID, height, width, snake_cnt, health_dec) for ID in range(game_cnt)}
        # log
        self.wall_collision = 0
        self.body_collision = 0
        self.head_collision = 0
        self.starvation = 0
        self.food_eaten = 0
        self.game_length = 0
    
    # Alice is the agent
    def run(self, Alice):
        t0 = time()
        games = self.games
        show = self.game_cnt == 1
        rewards = [None]*self.game_cnt
        
        # run all the games in parallel
        turn = 0
        while games:
            turn += 1
            # print information
            if len(games) == 1:
                print("Running the root game. On turn", str(turn) + "...")
            else:
                print("Concurrently running", len(games), "root games. On turn", str(turn) + "...")
            
            # ask for moves from the Agent
            ids = []
            for game_id in games:
                ids += games[game_id].get_ids()
            moves = Alice.make_moves(games, ids)
            moves_for_game = {game_id: [] for game_id in games}
            for i in range(len(moves)):
                moves_for_game[ids[i][0]].append(moves[i])
            
            # tic all games
            kills = set()
            print(kills)
            for game_id in games:
                game = games[game_id]
                result = game.tic(moves_for_game[game_id], show)
                # if game ended
                if result != 0:
                    # log
                    self.wall_collision += game.wall_collision
                    self.body_collision += game.body_collision
                    self.head_collision += game.head_collision
                    self.starvation += game.starvation
                    self.food_eaten += game.food_eaten
                    self.game_length += game.game_length
                    rewards[game_id] = result
                    kills.add(game_id)
            # remove games that ended
            for game_id in kills:
                del games[game_id]
            
            print("Root game turn", str(turn), "finished. Total time spent:", time() - t0, end = "\n\n")
        
        # log
        self.wall_collision /= self.game_cnt
        self.body_collision /= self.game_cnt
        self.head_collision /= self.game_cnt
        self.starvation /= self.game_cnt
        self.food_eaten /= self.game_cnt
        self.game_length /= self.game_cnt
        return rewards

class MCTSMPGameRunner(MPGameRunner):
    """Runs a batch of MCTS subgames in lockstep, each up to its own depth limit."""

    def __init__(self, games):
        # games: dict mapping subgame id -> game; run() removes entries as they finish
        self.games = games

    # MCTSAlice is the agent
    def run(self, MCTSAlice, MCTS_depth):
        """Play all subgames until each one ends or reaches its depth limit.

        Args:
            MCTSAlice: the agent; make_moves(games, ids) returns one move per id
            MCTS_depth: maps subgame id -> maximum number of turns to play

        Returns:
            a dict mapping subgame id -> that game's rewards list at the
            moment it was stopped
        """
        games = self.games
        rewards = dict.fromkeys(games)

        # run all the games in parallel
        turn = 0
        while games:
            turn += 1
            # ask for moves from the Agent
            ids = [pair for game in games.values() for pair in game.get_ids()]
            moves = MCTSAlice.make_moves(games, ids)
            moves_for_game = {game_id: [] for game_id in games}
            for i, move in enumerate(moves):
                moves_for_game[ids[i][0]].append(move)

            # tic all games
            finished = []
            for game_id, game in games.items():
                result = game.tic(moves_for_game[game_id])
                # if game ended or MCTS subgame max length reached
                if result != 0 or turn >= MCTS_depth[game_id]:
                    rewards[game_id] = game.rewards
                    finished.append(game_id)
            # remove games that ended
            for game_id in finished:
                del games[game_id]

        return rewards

In [None]:



class Agent:
    
    def __init__(self, nnet, softmax_base = 100, training = False,
                 max_MCTS_depth = 8, max_MCTS_breadth = 128):
        self.nnet = nnet
        self.softmax_base = softmax_base
        self.training = training
        self.max_MCTS_depth = max_MCTS_depth
        self.max_MCTS_breadth = max_MCTS_breadth
        self.cached_values = {}
        self.total_rewards = {}
        self.visit_cnts = {}
        self.cache_hit = {}
        # record data for training
        if training:
            self.records = []
            self.values = []
    
    def make_moves(self, games, ids):
        """Choose a move for every snake after running MCTS subgames.

        The search runs max_MCTS_breadth//parallel rounds; each round plays
        `parallel` subgames per game, where parallel is 8 capped at
        max_MCTS_breadth. The values cached for each snake's root state are
        then turned into moves.

        Args:
            games: dict mapping game id -> game
            ids: list of (game id, snake id) pairs, one per snake to move

        Returns:
            a list with one move from {0, 1, 2} per entry of ids; sampled
            from softermax of the values when training, else the argmax

        When training, the states and their values are also appended to
        self.records and self.values.
        """
        cached_values = self.cached_values
        total_rewards = self.total_rewards
        visit_cnts = self.visit_cnts
        cache_hit = self.cache_hit
        # age every cached state by one; the MCTS agent resets the counter
        # of each state it sees
        for key in cache_hit:
            cache_hit[key] += 1
        parallel = 8
        if self.max_MCTS_breadth < parallel:
            parallel = self.max_MCTS_breadth
        
        # MCTS
        for _ in range(self.max_MCTS_breadth//parallel):
            # make a subgame for each game
            MCTS_depth = {}
            parent_games = {}
            subgames = {}
            subgame_id = 0
            for game_id in games:
                # calculate the max MCTS depth for each game
                # (two turns fewer for every snake beyond the second)
                depth = self.max_MCTS_depth - 2*(len(games[game_id].snakes) - 2)
                for _ in range(parallel):
                    MCTS_depth[subgame_id] = depth
                    parent_games[subgame_id] = game_id
                    subgames[subgame_id] = games[game_id].subgame(subgame_id)
                    subgame_id += 1
            # run MCTS subgames
            MCTSAlice = MCTSAgent(self.nnet, self.softmax_base, subgames,
                                  cached_values, total_rewards, visit_cnts, cache_hit)
            MCTS = MCTSMPGameRunner(subgames)
            t0 = time()
            rewards = MCTS.run(MCTSAlice, MCTS_depth)
            if self.training:
                print("MCTS epoch finished. Time spent:", time() - t0)
            
            # update edge stats
            for subgame_id in MCTSAlice.keys:
                for snake_id in MCTSAlice.keys[subgame_id]:
                    my_keys = MCTSAlice.keys[subgame_id][snake_id]
                    my_moves = MCTSAlice.moves[subgame_id][snake_id]
                    # a reward of None means no final outcome was recorded
                    # for this snake in the subgame
                    if not rewards[subgame_id][snake_id] is None:
                        # back up
                        for i in range(len(my_keys) - 1, -1, -1):
                            key = my_keys[i]
                            move = my_moves[i]
                            visit_cnts[key][move] += 1.0
                            total_rewards[key][move] += rewards[subgame_id][snake_id]
                            cached_values[key][move] = total_rewards[key][move]/visit_cnts[key][move]
        
        V = [None]*len(ids)
        # store the index of the value in V a (game_id, snake_id) corresponds to
        value_index = {}
        for i in range(len(ids)):
            try:
                value_index[ids[i][0]][ids[i][1]] = i
            except KeyError:
                value_index[ids[i][0]] = {ids[i][1]: i}
        # set Q values based on the subgames' stats
        # NOTE(review): MCTSAlice and parent_games are only bound inside the
        # MCTS loop above, so this relies on that loop running at least once
        for subgame_id in MCTSAlice.keys:
            game_id = parent_games[subgame_id]
            for snake_id in MCTSAlice.keys[subgame_id]:
                # the first recorded key belongs to the subgame's starting state
                first_key = MCTSAlice.keys[subgame_id][snake_id][0]
                V[value_index[game_id][snake_id]] = cached_values[first_key]
        
        # make moves
        if self.training:
            pmfs = [self.softermax(v) for v in V]
            moves = [np.random.choice([0, 1, 2], p = pmf) for pmf in pmfs]
            states = []
            for game_id in games:
                states += games[game_id].get_states()
            self.records += states
            self.values += V
        else:
            moves = self.argmaxs(V)
        
        # RAM recycle
        # drop states that have gone more than max_MCTS_depth calls unseen
        kills = set()
        for key in cache_hit:
            if cache_hit[key] > self.max_MCTS_depth:
                kills.add(key)
        for key in kills:
            del cached_values[key]
            del total_rewards[key]
            del visit_cnts[key]
            del cache_hit[key]
        return moves
    
    # a softmax function with customized base
    def softermax(self, z):
        # the higher the base is, the more it highlights the higher ones
        normalized = power(self.softmax_base, arctanh(z))
        # in case all three cells are obstacles (-1.0), softmax will fail on 0.0/0.0
        sigma = sum(normalized)
        if sigma == 0.0:
            return array([1.0/3.0]*3, dtype = float32)
        else:
            return normalized/sigma
    
    def argmaxs(self, Z):
        """Return, for each triple in Z, the index of its largest entry.

        Ties are resolved as follows: index 2 wins any tie it takes part in,
        and index 1 wins a tie with index 0.
        """
        picks = []
        for i in range(len(Z)):
            z = Z[i]
            # best of the first two; index 1 unless index 0 is strictly larger
            leader = 0 if z[0] > z[1] else 1
            # index 2 replaces the leader unless the leader is strictly larger
            picks.append(leader if z[leader] > z[2] else 2)
        return picks
    
    # clear memory
    def clear(self):
        self.cached_values = {}
        self.total_rewards = {}
        self.visit_cnts = {}
        self.cache_hit = {}
        if self.training:
            self.records = []
            self.values = []

class MCTSAgent(Agent):
    
    def __init__(self, nnet, softmax_base, games, cached_values, total_rewards, visit_cnts, cache_hit):
        """Set up an agent for one round of MCTS subgames.

        The four statistics dicts are stored as passed in, not copied, so
        updates made here are visible to whoever supplied them.

        Args:
            nnet: the network object used to evaluate unseen states
            softmax_base: base used by softermax
            games: dict mapping subgame id -> game
            cached_values, total_rewards, visit_cnts, cache_hit: per-state
                statistics dicts
        """
        self.nnet = nnet
        self.softmax_base = softmax_base
        self.cached_values = cached_values
        self.total_rewards = total_rewards
        self.visit_cnts = visit_cnts
        self.cache_hit = cache_hit
        # per subgame and snake: the state keys visited and the moves made there
        self.keys = {}
        self.moves = {}
        for game_id in games:
            snakes = games[game_id].snakes
            self.keys[game_id] = {s.id: [] for s in snakes}
            self.moves[game_id] = {s.id: [] for s in snakes}
    
    def make_moves(self, games, ids):
        """Pick a randomized move for every snake in every subgame.

        States are looked up in the shared cache by their raw bytes; states
        not seen before are evaluated by the network in one batch, and the
        result is stored as a prior with a visit count of 1. The expected
        value of each new state is also backed up along the path that the
        snake has taken so far in its subgame.

        Args:
            games: dict mapping subgame id -> game
            ids: list of (subgame id, snake id) pairs, in the same order as
                the states returned by the games' get_states()

        Returns:
            a list with one move from {0, 1, 2} per entry of ids
        """
        cached_values = self.cached_values
        total_rewards = self.total_rewards
        visit_cnts = self.visit_cnts
        cache_hit = self.cache_hit
        V = [None]*len(ids)
        keys = [None]*len(ids)
        all_states = []

        # get states without duplicates
        i = 0
        for game_id in games:
            for state in games[game_id].get_states():
                # tobytes() returns the same bytes as the deprecated tostring()
                key = state.tobytes()
                keys[i] = key
                try:
                    cache = cached_values[key]
                    # None marks a state already queued for evaluation below
                    if cache is not None:
                        V[i] = cache
                except KeyError:
                    all_states.append(state)
                    # a new state to be stored
                    cached_values[key] = None
                # mark the state as just seen
                cache_hit[key] = 0
                i += 1

        # calculate values using the net
        if all_states:
            calculated_V = self.nnet.v(all_states)
            # assign values calculated by the net and store them into the cache
            j = 0
            for i in range(len(V)):
                if V[i] is None:
                    # only the first occurrence of a new state consumes a
                    # network output; later duplicates reuse the cached entry
                    if cached_values[keys[i]] is None:
                        # the calculated Q values will be a prior
                        total_rewards[keys[i]] = calculated_V[j]
                        visit_cnts[keys[i]] = array([1.0]*3, dtype = float32)
                        cached_values[keys[i]] = total_rewards[keys[i]]/visit_cnts[keys[i]]
                        j += 1
                    V[i] = cached_values[keys[i]]

        # make randomized moves
        pmfs = [self.softermax(v) for v in V]
        moves = [np.random.choice([0, 1, 2], p = pmf) for pmf in pmfs]

        # update MCTS edge stats and record state keys and moves
        for i in range(len(ids)):
            game_id = ids[i][0]
            snake_id = ids[i][1]
            my_keys = self.keys[game_id][snake_id]
            my_moves = self.moves[game_id][snake_id]
            # back up the expected value of this state along the path so far
            estimated_reward = pmfs[i]@V[i]
            for j in range(len(my_keys) - 1, -1, -1):
                key = my_keys[j]
                move = my_moves[j]
                visit_cnts[key][move] += 1.0
                total_rewards[key][move] += estimated_reward
                cached_values[key][move] = total_rewards[key][move]/visit_cnts[key][move]
            my_keys.append(keys[i])
            my_moves.append(moves[i])
        return moves

In [None]:


class AlphaNNet:
    
    def __init__(self, model_name = None, input_shape = None):
        """Load a saved model, or build a new residual value network.

        Args:
            model_name: if given, the model is loaded from here with load_model
            input_shape: otherwise, if given, the input shape of a new network

        When neither argument is given, no network is created and v_net is
        left unset.
        """
        if model_name:
            self.v_net = load_model(model_name)
            return
        if not input_shape:
            return

        # regularization constant
        c = 1e-5
        # number of filters
        k = 128

        def conv(filters, kernel, **kwargs):
            # a bias-free, L2-regularized convolution
            return Conv2D(filters, kernel, use_bias = False, kernel_regularizer = l2(c), **kwargs)

        def bn_relu(tensor):
            # batch normalization followed by ReLU
            return Activation('relu')(BatchNormalization(axis = 3)(tensor))

        def residual(tensor):
            # two 3x3 convolutions with a skip connection around them
            shortcut = tensor
            tensor = bn_relu(conv(k, (3, 3), padding = "same")(tensor))
            tensor = conv(k, (3, 3), padding = "same")(tensor)
            return Activation('relu')(Add()([BatchNormalization(axis = 3)(tensor), shortcut]))

        X = Input(input_shape)

        H = bn_relu(conv(k, (3, 3), padding = "same")(X))

        # three residual blocks
        for _ in range(3):
            H = residual(H)

        # squeeze down to a single filter before the dense layers
        H = bn_relu(conv(1, (1, 1))(H))

        H = Activation('relu')(Dense(k, kernel_regularizer = l2(c))(Flatten()(H)))

        # one tanh output per move
        Y = Activation('tanh')(Dense(3, kernel_regularizer = l2(c))(H))

        self.v_net = Model(inputs = X, outputs = Y)
    
    def train(self, X, Y, epochs = 10, batch_size = 2048):
        self.v_net.fit(array(X), array(Y), epochs = epochs, batch_size = batch_size)
    
    def v(self, X):
        V = self.v_net.predict(array(X))
        center_y = len(X[0])//2
        center_x = len(X[0][0])//2
        for i in range(len(X)):
            # assign -1.0 to known obstacles
            if self.is_obstacle(X[i][center_y][center_x - 1][1]):
                V[i][0] = -1.0
            if self.is_obstacle(X[i][center_y - 1][center_x][1]):
                V[i][1] = -1.0
            if self.is_obstacle(X[i][center_y][center_x + 1][1]):
                V[i][2] = -1.0
        return V
    
    def is_obstacle(self, value):
        """Tell whether a body-channel value marks a cell that must not be entered.

        The body channel holds a multiple of 0.02 for snake segments (0.02 for
        the tail, 0.04 for the segment before it, and so on) and 1.0 for
        walls. Everything from 0.04 upwards is an obstacle.

        The threshold sits halfway between 0.02 and 0.04 rather than at 0.04:
        states are float32, and the float32 nearest to 0.04 is slightly below
        the Python float 0.04. A ">= 0.04" test therefore rejects that
        segment whenever the comparison is carried out in float64.

        Args:
            value: a single body-channel value

        Returns:
            a true value if the cell is an obstacle, a false value otherwise
        """
        return value > 0.03
    
    def copy_and_compile(self, learning_rate = 0.0001, TPU = None):
        """Clone the network, copy its weights and compile the clone.

        The clone uses Adam with a piecewise-constant learning rate and a
        mean-squared-error loss.

        Args:
            learning_rate: the rate used up to the first boundary; each later
                interval uses a quarter of the previous rate
            TPU: if given, the clone is created and compiled inside TPU.scope()

        Returns:
            a new AlphaNNet wrapping the compiled clone
        """
        boundaries = [20, 40, 60, 80, 100]
        # one rate per boundary, plus a final slot that stays at 0.0
        # NOTE(review): the rate is therefore 0.0 past the last boundary
        # - confirm this is intended
        values = []
        rate = learning_rate
        for _ in boundaries:
            values.append(rate)
            rate *= 0.25
        values.append(0.0)

        def build_copy():
            # the same steps are used with and without a TPU scope
            twin = AlphaNNet()
            twin.v_net = clone_model(self.v_net)
            twin.v_net.build(self.v_net.layers[0].input_shape)
            twin.v_net.set_weights(self.v_net.get_weights())
            lr = schedules.PiecewiseConstantDecay(boundaries, values)
            twin.v_net.compile(
                optimizer = Adam(learning_rate = lr),
                loss = 'mean_squared_error'
            )
            return twin

        if TPU:
            with TPU.scope():
                return build_copy()
        return build_copy()
    
    def save(self, name):
        self.v_net.save('/content/drive/MyDrive/Train_res/' + name + '.h5')

In [None]:



class AlphaSnakeZeroTrainer:
    """Endless self-play training loop: play, sample positions, fit, save."""
    
    def __init__(self,
                 self_play_games,
                 max_MCTS_depth,
                 max_MCTS_breadth,
                 learning_rate,
                 learning_rate_decay,
                 height = 11,
                 width = 11,
                 snake_cnt = 4,
                 TPU = None):
        """Store the training configuration.

        Args:
            self_play_games: forwarded to MPGameRunner for each iteration.
            max_MCTS_depth: forwarded to the self-play Agent.
            max_MCTS_breadth: forwarded to the self-play Agent.
            learning_rate: rate handed to ``copy_and_compile`` before fitting.
            learning_rate_decay: factor applied to the rate after every iteration.
            height: board height forwarded to MPGameRunner.
            width: board width forwarded to MPGameRunner.
            snake_cnt: number of snakes forwarded to MPGameRunner.
            TPU: optional strategy object forwarded to ``copy_and_compile``.
        """
        self.self_play_games = self_play_games
        self.max_MCTS_depth = max_MCTS_depth
        self.max_MCTS_breadth = max_MCTS_breadth
        self.lr = learning_rate
        self.lr_decay = learning_rate_decay
        self.height = height
        self.width = width
        self.snake_cnt = snake_cnt
        self.TPU = TPU
    
    def train(self, nnet, name = "AlphaSnake", iteration = 0):
        """Run self-play/fit/save iterations forever, starting at ``iteration``.

        Each pass plays games with a single agent, appends the runner's
        statistics to ``log.csv``, samples up to five batches of 2048 recorded
        positions (all of them when fewer than 2048 exist), dumps the sample to
        ``x.csv`` / ``v.csv``, doubles it with mirrored copies, fits the
        network and saves it as ``name + str(iteration)``.

        Args:
            nnet: network object providing ``copy_and_compile``, ``train`` and ``save``.
            name: model name used in the log header and for saved files.
            iteration: generation to resume from; 0 also writes the log header.

        This method only stops when an exception propagates out of the loop.
        """
        nnet = nnet.copy_and_compile()
        # log
        if iteration == 0:
            with open("log.csv", 'a') as f:
                f.write("new model " + name + "\n")
                f.write("iteration, wall_collision, body_collision, head_collision, "
                        + "starvation, food_eaten, game_length\n")
        # health decrement handed to the game runner: 9 at first,
        # 3 once iteration > 8, 1 once iteration > 32
        health_dec = 9
        while True:
            if iteration > 32:
                health_dec = 1
            elif iteration > 8:
                health_dec = 3
            # self play
            # for training, all snakes are played by the same agent
            print("\nSelf playing games...")
            # the second arg is the softmax base (a snake with lower base is more explorative)
            Alice = Agent(nnet, 2 + iteration, True, self.max_MCTS_depth, self.max_MCTS_breadth)
            gr = MPGameRunner(self.height, self.width, self.snake_cnt, health_dec, self.self_play_games)
            gr.run(Alice)
            # log
            log_list = [gr.wall_collision, gr.body_collision, gr.head_collision,
                        gr.starvation, gr.food_eaten, gr.game_length]
            log = str(iteration) + ', ' + ', '.join(map(str, log_list)) + '\n'
            with open("log.csv", 'a') as f:
                f.write(log)
            # collect training examples
            batch_size = 2048
            record_cnt = len(Alice.records)
            batches = min(5, record_cnt//batch_size)
            samples = batch_size*batches
            if samples == 0:
                # fewer records than one full batch:
                # train on all of them as a single batch instead of on nothing
                batch_size = record_cnt
                samples = record_cnt
            indexs = sample(range(record_cnt), samples)
            print("indexes" , len(indexs))
            X = [Alice.records[index] for index in indexs]
            V = [Alice.values[index] for index in indexs]
            Alice.clear()
            print("x", len(X))
            print("v", len(V))
            # dump the sampled (un-mirrored) examples, one flattened sample per line
            self._append_rows("x.csv", X)
            self._append_rows("v.csv", V)
            X += self.mirror_states(X)
            V += self.mirror_values(V)
            # training
            nnet = nnet.copy_and_compile(learning_rate = self.lr, TPU = self.TPU)
            t0 = time()
            nnet.train(X, V, batch_size = batch_size)
            print("Training time", time() - t0)
            nnet = nnet.copy_and_compile()
            # learning rate decay
            self.lr *= self.lr_decay
            # drop the references so the sample can be freed before the next self play
            X = None
            V = None
            # save the model
            iteration += 1
            print("\nSaving the model " + name + str(iteration) + "...")
            nnet.save( name + str(iteration))
    
    @staticmethod
    def _append_rows(path, rows):
        """Append every entry of ``rows`` to ``path`` as one comma-separated line."""
        with open(path, 'a') as f:
            for row in rows:
                # file.write() only accepts str, so flatten and stringify each sample
                f.write(','.join(map(str, array(row).ravel())) + '\n')
    
    def mirror_states(self, states):
        """Return the states reversed along axis 2, as a list."""
        # flip return a numpy.ndarray
        # need to return a list
        # otherwise X += does vector addition
        return list(flip(states, axis = 2))
    
    def mirror_values(self, values):
        """Return the value vectors reversed along axis 1, as a list."""
        return list(flip(values, axis = 1))

In [None]:
import tensorflow as tf


# ---- training configuration ----
game_board_height = 11
game_board_width = 11
number_of_snakes = 4
self_play_games = 64
max_MCTS_depth = 4
max_MCTS_breadth = 32
initial_learning_rate = 0.0001
learning_rate_decay = 0.98

# ---- accelerator detection (best effort) ----
try:
    # when running on Google Cloud
    # the TPU must be located in the same area as the CPU
    # (a TPU name can be passed to TPUClusterResolver if auto-detection is not enough)
    Resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(Resolver)
    tf.tpu.experimental.initialize_tpu_system(Resolver)
    TPU = tf.distribute.experimental.TPUStrategy(Resolver)
    print("Google Cloud TPU online.")

except Exception:
    # Any ordinary failure while locating/initialising the TPU falls back to
    # the CPU. KeyboardInterrupt and SystemExit are deliberately not caught,
    # so interrupting the cell still stops it.
    print("Cannot find the Google Cloud TPU. Using the CPU.")
    TPU = None

# ---- model selection ----
name = input("Enter the model name (not including the generation number nor \".h5\"):\n")
start = int(input("Enter the starting generation (0 for creating a new model):\n"))
if start == 0:
    # The input shape is (2*height - 1, 2*width - 1, 3).
    ANNet = AlphaNNet(input_shape = (game_board_height*2 - 1, game_board_width*2 - 1, 3))
    ANNet.save(name + "0")
    print("train from scratch")
else:
    ANNet = AlphaNNet(model_name = "/content/drive/MyDrive/Train_res/" + name + str(start) + ".h5")
    # resume with the learning rate the trainer would have reached after `start` decays
    initial_learning_rate *= learning_rate_decay**start

# ---- run (loops until interrupted) ----
Trainer = AlphaSnakeZeroTrainer(self_play_games, max_MCTS_depth, max_MCTS_breadth,
                                initial_learning_rate, learning_rate_decay,
                                game_board_height, game_board_width, number_of_snakes, TPU)
Trainer.train(ANNet, name = name, iteration = start)

Cannot find the Google Cloud TPU. Using the CPU.
