In [0]:
#!/usr/bin/env python

"""SnakeGame: A simple and fun exploration, meant to be used by Human and AI. - THIS
"""

import sys  # To close the window when the game is over
from array import array  # Efficient numeric arrays
from os import environ, path  # To center the game window the best possible
import random  # Random numbers used for the food
import logging  # Logging function for movements and errors
import json # For file handling (leaderboards)
from itertools import tee  # For the color gradient on snake

import numpy as np # Used in calculations and math

__author__ = "Victor Neves"
__license__ = "MIT"
__version__ = "1.0"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"
__status__ = "Production"

# Menu options, the two action encodings, and the reversals a snake may not do.
OPTIONS = dict(QUIT=0, PLAY=1, BENCHMARK=2, LEADERBOARDS=3, MENU=4,
               ADD_TO_LEADERBOARDS=5)
RELATIVE_ACTIONS = dict(LEFT=0, FORWARD=1, RIGHT=2)
ABSOLUTE_ACTIONS = dict(LEFT=0, RIGHT=1, UP=2, DOWN=3, IDLE=4)
# (requested action, previous action) pairs that would turn the snake back onto
# itself: LEFT/RIGHT and UP/DOWN in both orders.
FORBIDDEN_MOVES = [(0, 1), (1, 0), (2, 3), (3, 2)]

# Reward values available to a reward function.
REWARDS = dict(MOVE=-0.005, GAME_OVER=-1, SCORED=1)

# Codes written into the state matrix for each kind of cell.
POINT_TYPE = dict(EMPTY=0, FOOD=1, BODY=2, HEAD=3, DANGEROUS=4)

# Speed levels offered to human players. Each value is the wait between two
# snake moves, compared against the elapsed clock time in Game.single_player,
# so a LOWER value is a FASTER game. MEGA_HARDCORE starts at 65 (slightly
# slower than MEDIUM) and its wait shrinks as the snake grows.
LEVELS = [" EASY ", " MEDIUM ", " HARD ", " MEGA HARDCORE "]
SPEEDS = dict(EASY=80, MEDIUM=60, HARD=40, MEGA_HARDCORE=65)

# Constant FPS limit for the game loop. Smoothness depends on this.
GAME_FPS = 100


class GlobalVariables:
    """Global variables to be used while drawing and moving the snake game.

    Attributes
    ----------
    board_size: int, optional, default = 30
        The size of the board (number of blocks per side).
    block_size: int, optional, default = 20
        The size in pixels of a block.
    head_color: tuple of 3 * int, optional, default = (42, 42, 42)
        Color of the head. Start of the body color gradient.
    tail_color: tuple of 3 * int, optional, default = (152, 152, 152)
        Color of the tail. End of the body color gradient.
    food_color: tuple of 3 * int, optional, default = (200, 0, 0)
        Color of the food.
    game_speed: int, optional, default = 80
        Wait between two snake moves, compared against the elapsed clock time
        in the game loop. The lower the value, the faster the game.
    benchmark: int, optional, default = 10
        Amount of matches to benchmark and possibly go to leaderboards.
    """
    def __init__(self, board_size = 30, block_size = 20,
                 head_color = (42, 42, 42), tail_color = (152, 152, 152),
                 food_color = (200, 0, 0), game_speed = 80, benchmark = 10):
        """Initialize all global variables and warn about very large boards."""
        self.board_size = board_size
        self.block_size = block_size
        self.head_color = head_color
        self.tail_color = tail_color
        self.food_color = food_color
        self.game_speed = game_speed
        self.benchmark = benchmark

        if self.board_size > 50: # Warn the user about performance
            # Fetch the logger by name instead of relying on a module-level
            # variable: this class is instantiated at import time, possibly
            # before that variable has been assigned.
            logging.getLogger(__name__).warning(
                'WARNING: BOARD IS TOO BIG, IT MAY RUN SLOWER.')

    @property
    def canvas_size(self):
        """Side of the square canvas in pixels: board_size * block_size.

        Computed on every access, so it follows later changes to either
        attribute.
        """
        return self.board_size * self.block_size

class TextBlock:
    """A piece of text drawn on a pygame surface, either a label or a menu entry.

    Attributes:
    ----------
    text: string
        The text to be displayed.
    pos: tuple of 2 * int
        Where the center of the text rectangle is placed on the screen.
    screen: pygame window object
        The screen where the text is drawn.
    scale: float, optional, default = 1 / 12
        Fraction of the canvas size used as font size, so the text follows the
        board size.
    block_type: string, optional, default = "text"
        "text" for a plain label, "menu" for an option that reacts to hovering.
    hovered: boolean
        Whether the pointer is over the block; only affects "menu" blocks.
    """
    def __init__(self, text, pos, screen, scale = (1 / 12), block_type = "text"):
        """Store the settings, place the rectangle and draw the block once."""
        self.text, self.pos = text, pos
        self.screen, self.scale = screen, scale
        self.block_type = block_type
        self.hovered = False
        self.set_rect()
        self.draw()

    def draw(self):
        """Render the text again and blit it at the stored rectangle."""
        self.set_rend()
        self.screen.blit(self.rend, self.rect)

    def set_rend(self):
        """Render the text with the current font size and colors."""
        font_file = resource_path("resources/fonts/freesansbold.ttf")
        font_size = int((VAR.canvas_size) * self.scale)
        font = pygame.font.Font(font_file, font_size)
        foreground = self.get_color()
        background = self.get_background()
        self.rend = font.render(self.text, True, foreground, background)

    def get_color(self):
        """Pick the text color.

        Return
        ----------
        color: pygame.Color
            Light gray for a menu entry that is not hovered, dark gray for
            everything else.
        """
        idle_menu_entry = self.block_type == "menu" and not self.hovered
        shade = 152 if idle_menu_entry else 42

        return pygame.Color(shade, shade, shade)

    def get_background(self):
        """Pick the background color.

        Return
        ----------
        color: pygame.Color or None
            Light gray for a hovered menu entry, None (no background) otherwise.
        """
        if self.block_type == "menu" and self.hovered:
            return pygame.Color(152, 152, 152)

        return None

    def set_rect(self):
        """Render the text and center its rectangle on `pos`."""
        self.set_rend()
        rect = self.rend.get_rect()
        rect.center = self.pos
        self.rect = rect


class Snake:
    """The player: keeps the head, the body and the last action taken.

    `body` is a list of [x, y] positions with the head copy at index 0. Every
    move inserts the new head position at the front and drops the last element
    unless food was eaten, which is how the snake grows.

    Attributes
    ----------
    head: list of 2 * int, default = [board_size / 4, board_size / 4]
        The head of the snake, located according to the board size.
    body: list of lists of 2 * int
        Starts with 3 parts and grows when food is eaten.
    previous_action: int, default = 1
        Last absolute action which the snake took (1 is RIGHT).
    length: int, default = 3
        Length of the snake, refreshed whenever food is eaten.
    """
    def __init__(self):
        """Create a 3-part snake laid out horizontally and heading right."""
        start = int(VAR.board_size / 4)
        self.head = [start, start]
        # Head first, the remaining parts trail to the left of it.
        self.body = [[start - offset, start] for offset in range(3)]
        self.previous_action = 1
        self.length = 3

    def is_movement_invalid(self, action):
        """Tell whether `action` would reverse the previous action.

        Return
        ----------
        invalid: boolean
            True when (action, previous_action) is listed in FORBIDDEN_MOVES.
        """
        return (action, self.previous_action) in FORBIDDEN_MOVES

    def move(self, action, food_pos):
        """Advance the head by one block and update the body.

        IDLE and forbidden actions are replaced by the previous action; any
        other action becomes the new previous action.

        Return
        ----------
        ate_food: boolean
            Flag which represents whether the snake ate or not food.
        """
        keep_heading = (action == ABSOLUTE_ACTIONS['IDLE']
                        or self.is_movement_invalid(action))

        if keep_heading:
            action = self.previous_action
        else:
            self.previous_action = action

        # (action name, head axis, displacement); y grows downwards.
        displacements = (('LEFT', 0, -1), ('RIGHT', 0, 1),
                         ('UP', 1, -1), ('DOWN', 1, 1))

        for name, axis, delta in displacements:
            if action == ABSOLUTE_ACTIONS[name]:
                self.head[axis] += delta
                break

        self.body.insert(0, list(self.head))

        if self.head != food_pos:
            self.body.pop()  # Nothing eaten: the tail follows the head.
            return False

        LOGGER.info('EVENT: FOOD EATEN')
        self.length = len(self.body)

        return True


class FoodGenerator:
    """Generate and keep track of food.

    Attributes
    ----------
    pos: list of 2 * int
        Current position of food, as [x, y].
    is_food_on_screen: boolean
        Flag for existence of food. Callers reset it to False once the food
        has been eaten so that the next call places a new piece.
    """
    def __init__(self, body):
        """Initialize a food piece and set existence flag."""
        self.is_food_on_screen = False
        self.pos = self.generate_food(body)

    def generate_food(self, body):
        """Place a new piece of food if there is none on the screen.

        The position is drawn uniformly from every board cell that is not
        covered by `body`, including the last row and column.

        Parameters
        ----------
        body: list of lists of 2 * int
            Cells currently occupied by the snake.

        Return
        ----------
        pos: list of 2 * int
            Position of the food. It is kept as a list because it is compared
            with the snake head, which is a list too.

        Raises
        ----------
        RuntimeError
            If the body covers the whole board, so no cell is left for food.
        """
        if not self.is_food_on_screen:
            occupied = {tuple(part) for part in body}
            free_cells = [[col, row]
                          for col in range(VAR.board_size)
                          for row in range(VAR.board_size)
                          if (col, row) not in occupied]

            if not free_cells:
                raise RuntimeError('No free cell left on the board to place '
                                   'food.')

            self.pos = random.choice(free_cells)

            LOGGER.info('EVENT: FOOD APPEARED')
            self.is_food_on_screen = True

        return self.pos


class Game:
    """Hold the game window and functions.

    Attributes
    ----------
    window: pygame display
        Pygame window to show the game.
    fps: pygame time clock
        Define Clock and ticks in which the game will be displayed.
    snake: object
        The actual snake who is going to be played.
    food_generator: object
        Generator of food which responds to the snake.
    food_pos: tuple of 2 * int
        Position of the food on the board.
    game_over: boolean
        Flag for game_over.
    player: string
        Define if human or robots are playing the game.
    board_size: int, optional, default = 30
        The size of the board.
    local_state: boolean, optional, default = False
        Whether to use or not game expertise (used mostly by robots players).
    relative_pos: boolean, optional, default = False
        Whether to use or not relative position of the snake head. Instead of
        actions, use relative_actions.
    screen_rect: tuple of 2 * int
        The screen rectangle, used to draw relatively positioned blocks.
    """
    def __init__(self, player, board_size = 30, local_state = False, relative_pos = False):
        """Initialize window, fps and score. Change nb_actions if relative_pos"""
        VAR.board_size = board_size
        self.local_state = local_state
        self.relative_pos = relative_pos
        self.player = player

        if player == "ROBOT":
            if self.relative_pos:
                self.nb_actions = 3
            else:
                self.nb_actions = 5

            self.reset_game()

    def reset_game(self):
        """Reset the game environment."""
        self.step = 0
        self.snake = Snake()
        self.food_generator = FoodGenerator(self.snake.body)
        self.food_pos = self.food_generator.pos
        self.scored = False
        self.game_over = False

    def create_window(self):
        """Initialize pygame and open a square window of canvas_size pixels.

        Also stores the window rectangle (used to position text relative to
        the screen) and the clock used by the game loop.
        """
        pygame.init()
        side = VAR.canvas_size
        window = pygame.display.set_mode((side, side),
                                         pygame.DOUBLEBUF | pygame.HWSURFACE)
        window.set_alpha(None)

        self.window = window
        self.screen_rect = window.get_rect()
        self.fps = pygame.time.Clock()

    def cycle_menu(self, menu_options, list_menu, dictionary, img = None,
                   img_rect = None):
        """Cycle through a given menu, waiting for an option to be clicked."""
        selected = False
        selected_option = None

        while not selected:
            pygame.event.pump()
            events = pygame.event.get()

            self.window.fill(pygame.Color(225, 225, 225))

            for i, option in enumerate(menu_options):
                if option is not None:
                    option.draw()
                    option.hovered = False

                    if option.rect.collidepoint(pygame.mouse.get_pos()):
                        option.hovered = True

                        for event in events:
                            if event.type == pygame.MOUSEBUTTONUP:
                                selected_option = dictionary[list_menu[i]]

            if selected_option is not None:
                selected = True
            if img is not None:
                self.window.blit(img, img_rect.bottomleft)

            pygame.display.update()

        return selected_option

    def cycle_matches(self, n_matches, mega_hardcore = False):
        """Cycle through matches until the end."""
        score = array('i')

        for _ in range(n_matches):
            self.reset_game()
            self.start_match(wait = 3)
            score.append(self.single_player(mega_hardcore))

        return score

    def menu(self):
        """Show the main menu with the logo and wait for a choice.

        Return
        ----------
        selected_option: int
            The OPTIONS value of the entry that was clicked.
        """
        pygame.display.set_caption("SNAKE GAME  | PLAY NOW!")

        logo = pygame.image.load(
            resource_path("resources/images/snake_logo.png")).convert()
        logo = pygame.transform.scale(logo, (VAR.canvas_size,
                                             int(VAR.canvas_size / 3)))
        logo_rect = logo.get_rect()
        logo_rect.center = self.screen_rect.center

        # (key in OPTIONS, label, vertical position in tenths of centery)
        entries = (('PLAY', ' PLAY GAME ', 4),
                   ('BENCHMARK', ' BENCHMARK ', 6),
                   ('LEADERBOARDS', ' LEADERBOARDS ', 8),
                   ('QUIT', ' QUIT ', 10))
        list_menu = [key for key, _, _ in entries]
        menu_options = [TextBlock(label,
                                  (self.screen_rect.centerx,
                                   tenths * self.screen_rect.centery / 10),
                                  self.window, (1 / 12), "menu")
                        for _, label, tenths in entries]

        return self.cycle_menu(menu_options, list_menu, OPTIONS,
                               logo, logo_rect)

    def start_match(self, wait):
        """Show a countdown of `wait` steps, one second each, before a match."""
        for remaining in range(wait, 0, -1):
            countdown = str(remaining)
            self.window.fill(pygame.Color(225, 225, 225))

            # "Game starts in" on top, the big number below it.
            title = TextBlock('Game starts in',
                              (self.screen_rect.centerx,
                               4 * self.screen_rect.centery / 10),
                              self.window, (1 / 10), "text")
            number = TextBlock(countdown,
                               (self.screen_rect.centerx,
                                12 * self.screen_rect.centery / 10),
                               self.window, (1 / 1.5), "text")

            title.draw()
            number.draw()

            pygame.display.update()
            pygame.display.set_caption("SNAKE GAME  |  Game starts in "
                                       + countdown + " second(s) ...")
            pygame.time.wait(1000)

        LOGGER.info('EVENT: GAME START')

    def start(self):
        """Use menu to select the option/game mode."""
        opt = self.menu()

        while True:
            if opt == OPTIONS['QUIT']:
                pygame.quit()
                sys.exit()
            elif opt == OPTIONS['PLAY']:
                VAR.game_speed, mega_hardcore = self.select_speed()
                score = self.cycle_matches(n_matches = 1,
                                           mega_hardcore = mega_hardcore)
                opt = self.over(score)
            elif opt == OPTIONS['BENCHMARK']:
                VAR.game_speed, mega_hardcore = self.select_speed()
                score = self.cycle_matches(n_matches = VAR.benchmark,
                                           mega_hardcore = mega_hardcore)
                opt = self.over(score)
            elif opt == OPTIONS['LEADERBOARDS']:
                self.view_leaderboards()
            elif opt == OPTIONS['MENU']:
                opt = self.menu()
            if opt == OPTIONS['ADD_TO_LEADERBOARDS']:
                self.add_to_leaderboards(score, None) # Gotta improve this logic.
                self.view_leaderboards()

    def over(self, score):
        """If collision with wall or body, end the game and open options.

        Return
        ----------
        selected_option: int
            The selected option in the main loop.
        """
        score_option = None

        if len(score) == VAR.benchmark:
            score_option = TextBlock(' ADD TO LEADERBOARDS ',
                                     (self.screen_rect.centerx,
                                      8 * self.screen_rect.centery / 10),
                                     self.window, (1 / 15), "menu")

        text_score = 'SCORE: ' + str(int(np.mean(score)))
        list_menu = ['PLAY', 'MENU', 'ADD_TO_LEADERBOARDS', 'QUIT']
        menu_options = [TextBlock(' PLAY AGAIN ', (self.screen_rect.centerx,
                                                   4 * self.screen_rect.centery / 10),
                                  self.window, (1 / 15), "menu"),
                        TextBlock(' GO TO MENU ', (self.screen_rect.centerx,
                                                   6 * self.screen_rect.centery / 10),
                                  self.window, (1 / 15), "menu"),
                        score_option,
                        TextBlock(' QUIT ', (self.screen_rect.centerx,
                                             10 * self.screen_rect.centery / 10),
                                  self.window, (1 / 15), "menu"),
                        TextBlock(text_score, (self.screen_rect.centerx,
                                               15 * self.screen_rect.centery / 10),
                                  self.window, (1 / 10), "text")]
        pygame.display.set_caption("SNAKE GAME  | " + text_score
                                   + "  |  GAME OVER...")
        LOGGER.info('EVENT: GAME OVER | FINAL %s', text_score)
        selected_option = self.cycle_menu(menu_options, list_menu, OPTIONS)

        return selected_option

    def select_speed(self):
        """Speed menu, right before calling start_match.

        Return
        ----------
        speed: int
            The selected speed in the main loop.
        """
        list_menu = ['EASY', 'MEDIUM', 'HARD', 'MEGA_HARDCORE']
        menu_options = [TextBlock(LEVELS[0], (self.screen_rect.centerx,
                                              4 * self.screen_rect.centery / 10),
                                  self.window, (1 / 10), "menu"),
                        TextBlock(LEVELS[1], (self.screen_rect.centerx,
                                              8 * self.screen_rect.centery / 10),
                                  self.window, (1 / 10), "menu"),
                        TextBlock(LEVELS[2], (self.screen_rect.centerx,
                                              12 * self.screen_rect.centery / 10),
                                  self.window, (1 / 10), "menu"),
                        TextBlock(LEVELS[3], (self.screen_rect.centerx,
                                              16 * self.screen_rect.centery / 10),
                                  self.window, (1 / 10), "menu")]

        speed = self.cycle_menu(menu_options, list_menu, SPEEDS)
        mega_hardcore = False

        if speed == SPEEDS['MEGA_HARDCORE']:
            mega_hardcore = True

        return speed, mega_hardcore

    def single_player(self, mega_hardcore = False):
        """Game loop for single_player (HUMANS).

        Runs until `self.game_over` is set. Input is polled every frame, but
        the snake only moves (and the board is only redrawn) once the
        accumulated clock time reaches the current wait between moves.

        Parameters
        ----------
        mega_hardcore: boolean, optional, default = False
            When True, the wait between moves shrinks as the snake grows.

        Return
        ----------
        score: int
            The final score for the match (discounted of initial length).
        """
        # Body colors are only recomputed when the snake grows, so remember
        # the size the current gradient was built for.
        previous_size = self.snake.length # Initial size of the snake
        current_size = previous_size # Initial size
        color_list = self.gradient([(42, 42, 42), (152, 152, 152)],
                                   previous_size)

        # Main loop, where the snake moves once the elapsed time reaches the
        # move_wait time. The last valid key pressed is recorded between
        # moves to make the game smoother for human players.
        elapsed = 0
        last_key = self.snake.previous_action
        move_wait = VAR.game_speed

        while not self.game_over:
            elapsed += self.fps.get_time()  # Time between the last two ticks.

            if mega_hardcore:  # Progressive speed increments, the hardest.
                # NOTE(review): the wait drops by 2 per point scored and is
                # not clamped, so it can reach zero or below for long snakes;
                # the snake then moves on every frame.
                move_wait = VAR.game_speed - (2 * (self.snake.length - 3))

            key_input = self.handle_input()  # Receive inputs with tick.
            invalid_key = self.snake.is_movement_invalid(key_input)

            # Ignore "no key" and keys that would reverse the snake.
            if key_input is not None and not invalid_key:
                last_key = key_input

            if elapsed >= move_wait:  # Move and redraw
                elapsed = 0
                self.play(last_key)
                current_size = self.snake.length  # Update the body size

                if current_size > previous_size:
                    # The snake grew: rebuild the gradient for the new size.
                    color_list = self.gradient([(42, 42, 42), (152, 152, 152)],
                                               current_size)

                    previous_size = current_size

                self.draw(color_list)

            pygame.display.update()
            self.fps.tick(GAME_FPS)  # Limit FPS to GAME_FPS

        score = current_size - 3  # After the game is over, record score

        return score

    def check_collision(self):
        """Check wether any collisions happened with the wall or body.

        Return
        ----------
        collided: boolean
            Whether the snake collided or not.
        """
        collided = False

        if self.snake.head[0] > (VAR.board_size - 1) or self.snake.head[0] < 0:
            LOGGER.info('EVENT: WALL COLLISION')
            collided = True
        elif self.snake.head[1] > (VAR.board_size - 1) or self.snake.head[1] < 0:
            LOGGER.info('EVENT: WALL COLLISION')
            collided = True
        elif self.snake.head in self.snake.body[1:]:
            LOGGER.info('EVENT: BODY COLLISION')
            collided = True

        return collided

    def is_won(self):
        """Verify if the score is greater than 0.

        Return
        ----------
        won: boolean
            Whether the score is greater than 0.
        """
        return self.snake.length > 3

    def generate_food(self):
        """Generate new food if needed.

        Return
        ----------
        food_pos: tuple of 2 * int
            Current position of the food.
        """
        food_pos = self.food_generator.generate_food(self.snake.body)

        return food_pos

    def handle_input(self):
        """After getting current pressed keys, handle important cases.

        Return
        ----------
        action: int
            Handle human input to assess the next action.
        """
        pygame.event.set_allowed([pygame.QUIT, pygame.KEYDOWN])
        keys = pygame.key.get_pressed()
        pygame.event.pump()
        action = None

        if keys[pygame.K_ESCAPE] or keys[pygame.K_q]:
            LOGGER.info('ACTION: KEY PRESSED: ESCAPE or Q')
            self.over(self.snake.length - 3)
        elif keys[pygame.K_LEFT]:
            LOGGER.info('ACTION: KEY PRESSED: LEFT')
            action = ABSOLUTE_ACTIONS['LEFT']
        elif keys[pygame.K_RIGHT]:
            LOGGER.info('ACTION: KEY PRESSED: RIGHT')
            action = ABSOLUTE_ACTIONS['RIGHT']
        elif keys[pygame.K_UP]:
            LOGGER.info('ACTION: KEY PRESSED: UP')
            action = ABSOLUTE_ACTIONS['UP']
        elif keys[pygame.K_DOWN]:
            LOGGER.info('ACTION: KEY PRESSED: DOWN')
            action = ABSOLUTE_ACTIONS['DOWN']

        return action

    def state(self):
        """Create a matrix of the current state of the game.

        Return
        ----------
        canvas: np.array of size board_size**2
            Return the current state of the game in a matrix.
        """
        canvas = np.zeros((VAR.board_size, VAR.board_size))

        if self.game_over:
            pass
        else:
            body = self.snake.body

            for part in body:
                canvas[part[0], part[1]] = POINT_TYPE['BODY']

            canvas[body[0][0], body[0][1]] = POINT_TYPE['HEAD']

            if self.local_state:
                canvas = self.eval_local_safety(canvas, body)

            canvas[self.food_pos[0], self.food_pos[1]] = POINT_TYPE['FOOD']

        return canvas

    def relative_to_absolute(self, action):
        """Translate relative actions to absolute.

        Return
        ----------
        action: int
            Translated action from relative to absolute.
        """
        if action == RELATIVE_ACTIONS['FORWARD']:
            action = self.snake.previous_action
        elif action == RELATIVE_ACTIONS['LEFT']:
            if self.snake.previous_action == ABSOLUTE_ACTIONS['LEFT']:
                action = ABSOLUTE_ACTIONS['DOWN']
            elif self.snake.previous_action == ABSOLUTE_ACTIONS['RIGHT']:
                action = ABSOLUTE_ACTIONS['UP']
            elif self.snake.previous_action == ABSOLUTE_ACTIONS['UP']:
                action = ABSOLUTE_ACTIONS['LEFT']
            else:
                action = ABSOLUTE_ACTIONS['RIGHT']
        else:
            if self.snake.previous_action == ABSOLUTE_ACTIONS['LEFT']:
                action = ABSOLUTE_ACTIONS['UP']
            elif self.snake.previous_action == ABSOLUTE_ACTIONS['RIGHT']:
                action = ABSOLUTE_ACTIONS['DOWN']
            elif self.snake.previous_action == ABSOLUTE_ACTIONS['UP']:
                action = ABSOLUTE_ACTIONS['RIGHT']
            else:
                action = ABSOLUTE_ACTIONS['LEFT']

        return action

    def play(self, action):
        """Move the snake to the direction, eat and check collision."""
        self.scored = False
        self.step += 1
        self.food_pos = self.generate_food()

        if self.relative_pos:
            action = self.relative_to_absolute(action)

        if self.snake.move(action, self.food_pos):
            self.scored = True
            self.food_generator.is_food_on_screen = False

        if self.player == "HUMAN":
            if self.check_collision():
                self.game_over = True
        elif self.check_collision() or self.step > 50 * self.snake.length:
            self.game_over = True

    def get_reward(self):
        """Return the current reward. Can be used as the reward function.

        Return
        ----------
        reward: float
            Current reward of the game.
        """
        reward = REWARDS['MOVE']

        if self.game_over:
            reward = REWARDS['GAME_OVER']
        elif self.scored:
            reward = self.snake.length

        return reward

    def draw(self, color_list):
        """Draw the background, the snake and the food, and update the caption.

        Body parts are paired with `color_list` in order (head first); parts
        beyond the end of the shorter of the two are not drawn.
        """
        side = VAR.block_size
        self.window.fill(pygame.Color(225, 225, 225))

        for part, color in zip(self.snake.body, color_list):
            part_rect = pygame.Rect((part[0] * side), part[1] * side,
                                    side, side)
            pygame.draw.rect(self.window, color, part_rect)

        food_rect = pygame.Rect(self.food_pos[0] * side,
                                self.food_pos[1] * side, side, side)
        pygame.draw.rect(self.window, VAR.food_color, food_rect)

        current_score = self.snake.length - 3
        pygame.display.set_caption("SNAKE GAME  |  Score: "
                                   + str(current_score))

    def get_name(self):
        """See test.py in my desktop, for a textbox input in pygame"""
        return None

    def add_to_leaderboards(self, score, step):
        file_path = resource_path("resources/scores.json")

        name = self.get_name()
        new_score = {'name': 'test',
                     'ranking_data': {'score': score,
                                      'step': step}}

        with open(file_path, 'w') as leaderboards_file:
            json.dump(new_score, leaderboards_file)

    def view_leaderboards(self):
        list_menu = ['MENU']
        menu_options = [TextBlock('LEADERBOARDS',
                                  (self.screen_rect.centerx,
                                   2 * self.screen_rect.centery / 10),
                                  self.window, (1 / 12), "text")]

        file_path = resource_path("resources/scores.json")

        with open(file_path, 'r') as leaderboards_file:
            scores_data = json.loads(leaderboards_file.read())

        scores_data.sort(key = operator.itemgetter('score'))

#        for score in formatted_scores:
#            menu_options.append(TextBlock(person_ranked,
#                                (self.screen_rect.centerx,
#                                10 * self.screen_rect.centery / 10),
#                                self.window, (1 / 12), "text"))

        menu_options.append(TextBlock('MENU',
                            (self.screen_rect.centerx,
                            10 * self.screen_rect.centery / 10),
                            self.window, (1 / 12), "menu"))
        selected_option = self.cycle_menu(menu_options, list_menu, OPTIONS)

    @staticmethod
    def format_scores(scores, ammount):
        scores = scores[-ammount:]



    @staticmethod
    def eval_local_safety(canvas, body):
        """Evaluate the safety of the head's possible next movements.

        Return
        ----------
        canvas: np.array of size board_size**2
            After using game expertise, change canvas values to DANGEROUS if true.
        """
        if ((body[0][0] + 1) > (VAR.board_size - 1)
            or ([body[0][0] + 1, body[0][1]]) in body[1:]):
            canvas[VAR.board_size - 1, 0] = POINT_TYPE['DANGEROUS']
        if (body[0][0] - 1) < 0 or ([body[0][0] - 1, body[0][1]]) in body[1:]:
            canvas[VAR.board_size - 1, 1] = POINT_TYPE['DANGEROUS']
        if (body[0][1] - 1) < 0 or ([body[0][0], body[0][1] - 1]) in body[1:]:
            canvas[VAR.board_size - 1, 2] = POINT_TYPE['DANGEROUS']
        if ((body[0][1] + 1) > (VAR.board_size - 1)
            or ([body[0][0], body[0][1] + 1]) in body[1:]):
            canvas[VAR.board_size - 1, 3] = POINT_TYPE['DANGEROUS']

        return canvas

    @staticmethod
    def gradient(colors, steps, components = 3):
        """Function to create RGB gradients given 2 colors and steps. If
        component is changed to 4, it does the same to RGBA colors.

        Return
        ----------
        result: list of steps length of tuple of 3 * int (if RGBA, 4 * int)
            List of colors of calculated gradient from start to end.
        """
        def linear_gradient(start, finish, substeps):
            yield start

            for substep in range(1, substeps):
                yield tuple([(start[component]
                              + (float(substep) / (substeps - 1))
                              * (finish[component] - start[component]))
                             for component in range(components)])

        def pairs(seq):
            first_color, second_color = tee(seq)
            next(second_color, None)

            return zip(first_color, second_color)

        result = []
        substeps = int(float(steps) / (len(colors) - 1))

        for first_color, second_color in pairs(colors):
            for gradient_color in linear_gradient(first_color, second_color,
                                                  substeps):
                result.append(gradient_color)

        return result


def resource_path(relative_path):
    """Return an absolute path for a resource shipped with the game.

    Lookup order for the base directory:
      1. sys._MEIPASS, when the attribute exists (used while creating the
         .exe file);
      2. the directory of this source file, when __file__ is defined;
      3. the current working directory, e.g. inside a notebook kernel where
         __file__ does not exist.

    Parameters
    ----------
    relative_path: str
        Path of the resource relative to the base directory.

    Return
    ----------
    str
        The base directory joined with relative_path.
    """
    bundle_dir = getattr(sys, '_MEIPASS', None)

    if bundle_dir is not None:
        return path.join(bundle_dir, relative_path)

    try:
        base_dir = path.dirname(path.realpath(__file__))
    except NameError:  # __file__ is not defined in interactive sessions
        base_dir = path.abspath('.')

    return path.join(base_dir, relative_path)

# Shared GlobalVariables instance read by the game code (e.g. VAR.board_size).
VAR = GlobalVariables()
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
# Centering the window: SDL reads this environment variable.
environ['SDL_VIDEO_CENTERED'] = '1'


"""THIS"""

import numpy as np

from random import sample, uniform

class ExperienceReplay:
    """The class that handles memory and experiences replay.

    Without PER the memory is a plain list that is sampled uniformly. With
    PER the memory is a SumTree (insert / update / retrieve / sum / min_leaf /
    max_leaf) and samples are drawn proportionally to their priority.

    Attributes:
        memory: list of flat experiences, or a SumTree when per is set.
        memory_size: the amount of experiences to be stored in the memory.
        input_shape: the shape of one stored state (set on first remember).
        per: flag for PER usage.
        exp: number of experiences inserted so far (PER only).
        per_epsilon: used to replace "0" probabilities cases.
        per_alpha: how much prioritization to use.
        per_beta: importance sampling exponent (IS_weights).
        schedule: LinearSchedule going from beta to 1.0 (PER only).
    """
    def __init__(self, memory_size = 100, per = False, alpha = 0.6,
                 epsilon = 0.001, beta = 0.4, nb_epoch = 10000, decay = 0.5):
        """Initialize parameters and the memory array."""
        self.per = per
        self.memory_size = memory_size
        self.reset_memory() # Initiate the memory

        if self.per:
            self.per_epsilon = epsilon
            self.per_alpha = alpha
            self.per_beta = beta
            # Anneals from beta up to 1.0 over nb_epoch * decay steps.
            self.schedule = LinearSchedule(nb_epoch * decay, 1.0, beta)

    def exp_size(self):
        """Returns how much memory is stored."""
        if self.per:
            return self.exp

        return len(self.memory)

    def get_priority(self, errors):
        """Returns priority based on how much prioritization to use."""
        return (errors + self.per_epsilon) ** self.per_alpha

    def update(self, tree_indices, errors):
        """Update a list of nodes, based on their errors."""
        priorities = self.get_priority(errors)

        for index, priority in zip(tree_indices, priorities):
            self.memory.update(index, priority)

    def remember(self, s, a, r, s_prime, game_over):
        """Remember SARS' experiences, with the game_over parameter (done).

        The experience is stored flat, in this order:
        [s..., a, r, s_prime..., game_over].
        """
        if not hasattr(self, 'input_shape'):
            self.input_shape = s.shape[1:] # set attribute only once

        experience = np.concatenate([s.flatten(),
                                     np.array(a).flatten(),
                                     np.array(r).flatten(),
                                     s_prime.flatten(),
                                     1 * np.array(game_over).flatten()])

        if self.per: # If using PER, insert in the max_priority.
            max_priority = self.memory.max_leaf()

            if max_priority == 0:
                max_priority = self.get_priority(0)

            self.memory.insert(experience, max_priority)
            self.exp += 1
        else: # Else, just append the experience to the list.
            self.memory.append(experience)

            # A non-positive memory_size means the list is unbounded.
            if self.memory_size > 0 and self.exp_size() > self.memory_size:
                self.memory.pop(0)

    def get_samples(self, batch_size):
        """Sample the memory according to PER flag.

        Returns:
            (batch, IS_weights, tree_indices): batch is an array with one
            flat experience per row; tree_indices is None without PER.
        """
        if not self.per:
            IS_weights = np.ones((batch_size, ))
            batch = sample(self.memory, batch_size)
            return np.array(batch), IS_weights, None

        batch = [None] * batch_size
        IS_weights = np.zeros((batch_size, ))
        tree_indices = [0] * batch_size

        # The tree is only read while sampling, so the total is computed once.
        memory_sum = self.memory.sum()
        len_seg = memory_sum / batch_size
        min_prob = self.memory.min_leaf() / memory_sum

        for i in range(batch_size):
            # One draw per equally sized segment of the priority mass.
            val = uniform(len_seg * i, len_seg * (i + 1))
            tree_indices[i], priority, batch[i] = self.memory.retrieve(val)
            prob = priority / memory_sum
            IS_weights[i] = np.power(prob / min_prob, -self.per_beta)

        return np.array(batch), IS_weights, tree_indices

    def get_targets(self, target, model, batch_size, nb_actions, gamma = 0.9):
        """Sample a batch and build the Q-learning targets for it.

        Args:
            target: target network used for Double DQN, or None.
            model: online network; must expose predict(X).
            batch_size: number of experiences to sample.
            nb_actions: number of actions (width of the model output).
            gamma: discount factor.

        Returns:
            None while fewer than batch_size experiences are stored, else
            (S, targets, IS_weights).
        """
        if self.exp_size() < batch_size:
            return None

        samples, IS_weights, tree_indices = self.get_samples(batch_size)
        input_dim = int(np.prod(self.input_shape)) # Flattened state length

        S = samples[:, 0 : input_dim] # Separate the states
        a = samples[:, input_dim] # Separate the actions
        r = samples[:, input_dim + 1] # Separate the rewards
        S_prime = samples[:, input_dim + 2 : 2 * input_dim + 2] # Next states
        game_over = samples[:, 2 * input_dim + 2] # Separate terminal flags

        # Reshape the arrays to make them usable by the model.
        r = r.repeat(nb_actions).reshape((batch_size, nb_actions))
        game_over = game_over.repeat(nb_actions)\
                             .reshape((batch_size, nb_actions))
        S = S.reshape((batch_size, ) + self.input_shape)
        S_prime = S_prime.reshape((batch_size, ) + self.input_shape)

        # One forward pass: first half are Q(S, .), second half Q(S', .).
        X = np.concatenate([S, S_prime], axis = 0)
        Y = model.predict(X)

        if target is not None: # Use Double DQN logic:
            # The online model picks the action, the target model rates it.
            actions = np.argmax(Y[batch_size:], axis = 1)
            Y_target = np.asarray(target.predict(X[batch_size:]))
            Qsa = Y_target[np.arange(batch_size), actions]
        else:
            Qsa = np.max(Y[batch_size:], axis = 1)

        Qsa = Qsa.repeat(nb_actions).reshape((batch_size, nb_actions))

        # Only the column of the action taken is replaced by the new target;
        # the other columns keep the current prediction (zero error).
        delta = np.zeros((batch_size, nb_actions))
        a = a.astype(int)
        delta[np.arange(batch_size), a] = 1
        targets = ((1 - delta) * Y[:batch_size]
                  + delta * (r + gamma * (1 - game_over) * Qsa))

        if self.per: # Update the Sum Tree with the absolute error.
            # abs() must come before max(): the untouched columns are 0, so
            # a negative error would otherwise be reported as 0.
            errors = np.abs(targets - Y[:batch_size]).max(axis = 1)\
                                                      .clip(max = 1.)
            self.update(tree_indices, errors)

        return S, targets, IS_weights

    def reset_memory(self):
        """Set the memory as a blank list (or a blank SumTree with PER)."""
        if self.per:
            if self.memory_size <= 0:
                self.memory_size = 150000 # The tree needs a fixed capacity

            self.memory = SumTree(self.memory_size)
            self.exp = 0
        else:
            self.memory = []


#!/usr/bin/env python

"""dqn: First try to create an AI for SnakeGame. Is it good enough?

This algorithm is an implementation of DQN, Double DQN logic (using a target
network to have fixed Q-targets), Dueling DQN logic (Q(s,a) = Advantage + Value),
PER (Prioritized Experience Replay, using Sum Trees) and Multi-step returns. You
can read more about these on https://goo.gl/MctLzp

Implemented algorithms
----------
    * Simple Deep Q-network (DQN with ExperienceReplay);
        Paper: https://arxiv.org/abs/1312.5602
    * Double Deep Q-network (Double DQN);
        Paper: https://arxiv.org/abs/1509.06461
    * Dueling Deep Q-network (Dueling DQN);
        Paper: https://arxiv.org/abs/1511.06581
    * Prioritized Experience Replay (PER);
        Paper: https://arxiv.org/abs/1511.05952
    * Multi-step returns (n-steps);
        Paper: https://arxiv.org/pdf/1703.01327
    * Noisy nets.
        Paper: https://arxiv.org/abs/1706.10295

Arguments
----------
--load: 'file.h5'
    Load a previously trained model in '.h5' format.
--board_size: int, optional, default = 10
    Assign the size of the board.
--nb_frames: int, optional, default = 4
    Assign the number of frames per stack, default = 4.
--nb_actions: int, optional, default = 5
    Assign the number of actions possible.
--update_freq: int or float, optional, default = 0.001
    Whether to soft or hard update the target. If >= 1, number of epochs
    between hard updates; if < 1, amount (rate) of the soft update.
--visual: boolean, optional, default = False
    Select whether or not to draw the game in pygame.
--double: boolean, optional, default = False
    Use a target network with double DQN logic.
--dueling: boolean, optional, default = False
    Whether to use dueling network logic, Q(s,a) = A + V.
--per: boolean, optional, default = False
    Use Prioritized Experience Replay (based on Sum Trees).
--local_state: boolean, optional, default = True
    Verify if possible next moves are dangerous (field expertise).
    THIS
"""

import numpy as np
from array import array
import random

__author__ = "Victor Neves"
__license__ = "MIT"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"
__status__ = "Production"


class Agent:
    """Agent based in a simple DQN that can read states, remember and play.

    Attributes
    ----------
    memory: object
        Memory used in training. ExperienceReplay or PrioritizedExperienceReplay
    memory_size: int, optional, default = -1
        Capacity of the memory used.
    model: keras model
        The input model in Keras.
    target: keras model, optional, default = None
        The target model, used to calculate the fixed Q-targets.
    sess: session object
        Used to run the initializer of the noise variables (NoisyNet).
    nb_frames: int, optional, default = 4
        Amount of frames for each experience (sars).
    board_size: int, optional, default = 10
        Size of the board used.
    frames: list of experiences
        The buffer of frames, store sars experiences.
    per: boolean, optional, default = False
        Flag for PER usage.
    update_target_freq: int or float, default = 0.001
        Whether soft or hard updates occur. If < 1, soft updated target model.
    n_steps: int, optional, default = 1
        Size of the rewards buffer, to use Multi-step returns.
    """
    def __init__(self, model, sess, target = None, memory_size = -1, nb_frames = 4,
                 board_size = 10, per = False, update_target_freq = 0.001):
        """Initialize the agent with given attributes."""
        if per:
            # NOTE(review): PrioritizedExperienceReplay is not defined in the
            # visible part of this file - confirm it exists, or switch to
            # ExperienceReplay(memory_size = memory_size, per = True).
            self.memory = PrioritizedExperienceReplay(memory_size = memory_size)
        else:
            self.memory = ExperienceReplay(memory_size = memory_size)

        self.per = per
        self.model = model
        self.target = target
        self.nb_frames = nb_frames
        self.board_size = board_size
        self.update_target_freq = update_target_freq
        self.sess = sess
        self.set_noise_list()
        self.clear_frames()

    def reset_memory(self):
        """Reset memory if necessary."""
        self.memory.reset_memory()

    def set_noise_list(self):
        """Set a list of noise variables if NoisyNet is involved."""
        self.noise_list = []

        for layer in self.model.layers:
            if type(layer) in {NoisyDenseFG}:
                self.noise_list.extend(layer.noise_list)

    def sample_noise(self):
        """Resample noise variables in NoisyNet."""
        for noise in self.noise_list:
            self.sess.run(noise.initializer)

    def get_game_data(self, game):
        """Create a list with nb_frames frames and append/pop them each frame.

        Return
        ----------
        expanded_frames: list of experiences
            The buffer of frames, shape = (1, nb_frames) + frame shape.
        """
        frame = game.state()

        if self.frames is None:
            # First frame of an episode: fill the whole stack with it.
            self.frames = [frame] * self.nb_frames
        else:
            self.frames.append(frame)
            self.frames.pop(0)

        expanded_frames = np.expand_dims(self.frames, 0)
        # expanded_frames = np.transpose(expanded_frames, [0, 3, 2, 1])

        return expanded_frames

    def clear_frames(self):
        """Reset frames to restart appending."""
        self.frames = None

    def update_target_model_hard(self):
        """Update the target model with the main model's weights."""
        self.target.set_weights(self.model.get_weights())

    def transfer_weights(self):
        """Transfer Weights from Model to Target at rate update_target_freq.

        Soft update: target = rate * model + (1 - rate) * target, applied to
        every weight array.
        """
        rate = self.update_target_freq
        blended = [rate * model_weight + (1 - rate) * target_weight
                   for model_weight, target_weight
                   in zip(self.model.get_weights(), self.target.get_weights())]

        self.target.set_weights(blended)

    def print_metrics(self, epoch, nb_epoch, history_size, policy, value,
                      win_count, history_step, history_reward,
                      history_loss = None, verbose = 1):
        """Function to print metrics of training steps.

        verbose = 0 prints nothing, verbose = 1 prints a one-line summary of
        the last 10 episodes, any other value prints the detailed training,
        game and policy metrics of the last episode.
        """
        if verbose == 0:
            return

        if verbose == 1:
            text_epoch = ('Epoch: {:03d}/{:03d} | Mean size 10: {:.1f} | '
                           + 'Longest 10: {:03d} | Mean steps 10: {:.1f} | '
                           + 'Wins: {:d} | Win percentage: {:.1f}%')
            print(text_epoch.format(epoch + 1, nb_epoch,
                                    np.mean(history_size[-10:]),
                                    max(history_size[-10:]),
                                    np.mean(history_step[-10:]),
                                    win_count, 100 * win_count/(epoch + 1)))
            return

        text_epoch = 'Epoch: {:03d}/{:03d}'  # Print epoch info
        print(text_epoch.format(epoch + 1, nb_epoch))

        if history_loss is not None:  # Print training performance
            text_train = ('\t\x1b[0;30;47m' + ' Training metrics ' + '\x1b[0m'
                          + '\tTotal loss: {:.4f} | Loss per step: {:.4f} | '
                          + 'Mean loss - 100 episodes: {:.4f}')
            print(text_train.format(history_loss[-1],
                                    history_loss[-1] / history_step[-1],
                                    np.mean(history_loss[-100:])))

        text_game = ('\t\x1b[0;30;47m' + ' Game metrics ' + '\x1b[0m'
                     + '\t\tSize: {:d} | Ammount of steps: {:d} | '
                     + 'Steps per food eaten: {:.1f} | '
                     + 'Mean size - 100 episodes: {:.1f}')
        # The steps are stored as floats (array('f')); {:d} needs an int.
        # NOTE(review): the third value is size / steps while the label says
        # "Steps per food eaten" - looks inverted, confirm the intended ratio.
        print(text_game.format(history_size[-1], int(history_step[-1]),
                               history_size[-1] / history_step[-1],
                               np.mean(history_size[-100:])))

        # Print policy metrics; only the label of `value` depends on policy.
        if policy == "BoltzmannQPolicy":
            text_value = '\tBoltzmann Temperature: {:.2f} | '
        elif policy == "BoltzmannGumbelQPolicy":
            text_value = '\tNumber of actions: {:.0f} | '
        else:
            text_value = '\tEpsilon: {:.2f} | '

        text_policy = ('\t\x1b[0;30;47m' + ' Policy metrics ' + '\x1b[0m'
                       + text_value
                       + 'Episode reward: {:.1f} | Wins: {:d} | '
                       + 'Win percentage: {:.1f}%')
        print(text_policy.format(value, history_reward[-1], win_count,
                                 100 * win_count/(epoch + 1)))

    def train_model(self, model, target, batch_size, gamma, nb_actions, epoch = 0):
        """Function to train the model on a batch of the data.

        The model, target and epoch arguments are accepted for compatibility
        but not used: the agent's own model and target are always trained.

        Return
        ----------
        loss: float
            Training loss of given batch, 0. when no batch was available.
        """
        loss = 0.
        batch = self.memory.get_targets(model = self.model,
                                        target = self.target,
                                        batch_size = batch_size,
                                        gamma = gamma,
                                        nb_actions = nb_actions)

        if batch:  # None while the memory holds fewer than batch_size items
            inputs, targets, IS_weights = batch

            if inputs is not None and targets is not None:
                loss = float(self.model.train_on_batch(inputs,
                                                       targets,
                                                       IS_weights))

        return loss

    def train(self, game, nb_epoch = 10000, batch_size = 64, gamma = 0.95,
              eps = [1., .01], temp = [1., 0.01], learning_rate = 0.5,
              observe = 0, optim_rounds = 1, policy = "EpsGreedyQPolicy",
              verbose = 1, n_steps = 1):
        """The main training function, loops the game, remember and choose best
        action given game state (frames).

        eps and temp are [start, end] pairs for the exploration schedule,
        which is annealed over nb_epoch * learning_rate epochs. Training only
        starts once epoch >= observe. Metrics are printed every 10 epochs.
        """
        if not hasattr(self, 'n_steps'):
            self.n_steps = n_steps  # Set attribute only once

        history_size = array('i')  # Holds all the sizes
        history_step = array('f')  # Holds all the steps
        history_loss = array('f')  # Holds all the losses
        history_reward = array('f')  # Holds all the rewards

        # Select exploration policy. EpsGreedyQPolicy runs faster, but takes
        # longer to converge. BoltzmannGumbelQPolicy is the slowest, but
        # converge really fast (0.1 * nb_epoch used in EpsGreedyQPolicy).
        # BoltzmannQPolicy is in the middle.
        if policy == "BoltzmannQPolicy":
            q_policy = BoltzmannQPolicy(temp[0], temp[1], nb_epoch * learning_rate)
        elif policy == "BoltzmannGumbelQPolicy":
            q_policy = BoltzmannGumbelQPolicy()
        elif policy == 'GreedyQPolicy':
            q_policy = GreedyQPolicy()
        else:
            q_policy = EpsGreedyQPolicy(eps[0], eps[1], nb_epoch * learning_rate)

        nb_actions = game.nb_actions
        win_count = 0

        # If optim_rounds is bigger than one, the model will keep optimizing
        # after the exploration, in turns of nb_epoch size.
        for turn in range(optim_rounds):
            if turn > 0:
                for epoch in range(nb_epoch):
                    loss = self.train_model(model = self.model,
                                            epoch = epoch,
                                            target = self.target,
                                            batch_size = batch_size,
                                            gamma = gamma,
                                            nb_actions = nb_actions)
                    text_optim = ('Optimizer turn: {:2d} | Epoch: {:03d}/{:03d}'
                                  + '| Loss: {:.4f}')
                    print(text_optim.format(turn, epoch + 1, nb_epoch, loss))
            else:  # Exploration and training
                for epoch in range(nb_epoch):
                    loss = 0.
                    total_reward = 0.
                    game.reset_game()
                    self.clear_frames()
                    S = self.get_game_data(game)

                    if n_steps > 1:  # Create multi-step returns buffer.
                        n_step_buffer = array('f')

                    while not game.game_over:  # Main loop, until game_over
                        game.food_pos = game.generate_food()
                        self.sample_noise()
                        action, value = q_policy.select_action(self.model,
                                                               S, epoch,
                                                               nb_actions)
                        game.play(action)
                        r = game.get_reward()
                        total_reward += r

                        if n_steps > 1:
                            n_step_buffer.append(r)

                            if len(n_step_buffer) < n_steps:
                                R = r
                            else:
                                # Discounted sum of the last n_steps rewards.
                                R = sum([n_step_buffer[i] * (gamma ** i)\
                                        for i in range(n_steps)])

                                n_step_buffer.pop(0)
                        else:
                            R = r

                        S_prime = self.get_game_data(game)
                        experience = [S, action, R, S_prime, game.game_over]
                        self.memory.remember(*experience)  # Add to the memory
                        S = S_prime  # Advance to the next state (stack of S)

                        if epoch >= observe:  # Get the batchs and train
                            loss += self.train_model(model = self.model,
                                                     target = self.target,
                                                     batch_size = batch_size,
                                                     gamma = gamma,
                                                     nb_actions = nb_actions)

                    if game.is_won():
                        win_count += 1  # Counter of wins for metrics

                    if self.per:  # Advance beta, used in PER
                        beta = self.memory.schedule.value(epoch)
                        self.memory.beta = beta

                        # ExperienceReplay reads the exponent from per_beta.
                        if hasattr(self.memory, 'per_beta'):
                            self.memory.per_beta = beta

                    if self.target is not None:  # Update the target model
                        if self.update_target_freq >= 1: # Hard updates
                            if epoch % self.update_target_freq == 0:
                                self.update_target_model_hard()
                        elif self.update_target_freq < 1.:  # Soft updates
                            self.transfer_weights()

                    history_size.append(game.snake.length)
                    history_step.append(game.step)
                    history_loss.append(loss)
                    history_reward.append(total_reward)

                    if (epoch + 1) % 10 == 0:
                        self.print_metrics(epoch = epoch, nb_epoch = nb_epoch,
                                           history_size = history_size,
                                           history_loss = history_loss,
                                           history_step = history_step,
                                           history_reward = history_reward,
                                           policy = policy, value = value,
                                           win_count = win_count,
                                           verbose = verbose)

    def test(self, game, nb_epoch = 1000, eps = 0.01, temp = 0.01,
             visual = False, policy = "GreedyQPolicy"):
        """Play the game with the trained agent. Can use the visual tag to draw
            in pygame. Prints win rate, size, step and reward statistics."""
        win_count = 0

        history_size = array('i')  # Holds all the sizes
        history_step = array('f')  # Holds all the steps
        history_reward = array('f')  # Holds all the rewards

        if policy == "BoltzmannQPolicy":
            q_policy = BoltzmannQPolicy(temp, temp, nb_epoch)
        elif policy == "EpsGreedyQPolicy":
            q_policy = EpsGreedyQPolicy(eps, eps, nb_epoch)
        else:
            q_policy = GreedyQPolicy()

        for epoch in range(nb_epoch):
            game.reset_game()
            self.clear_frames()

            if visual:
                game.create_window()
                previous_size = game.snake.length  # Initial size of the snake
                color_list = game.gradient([(42, 42, 42), (152, 152, 152)],\
                                               previous_size)
                elapsed = 0

            while not game.game_over:
                if visual:
                    elapsed += game.fps.get_time()  # Get elapsed time since last call.

                    if elapsed >= 60:  # Play one move each time 60 elapsed
                        elapsed = 0
                        S = self.get_game_data(game)
                        action, value = q_policy.select_action(self.model, S, epoch, game.nb_actions)
                        game.play(action)
                        current_size = game.snake.length  # Update the body size

                        if current_size > previous_size:
                            color_list = game.gradient([(42, 42, 42), (152, 152, 152)],
                                                       game.snake.length)

                            previous_size = current_size

                        game.draw(color_list)

                    # NOTE(review): pygame is not imported in the visible part
                    # of this file - confirm it is imported elsewhere.
                    pygame.display.update()
                    game.fps.tick(120)  # Limit FPS to 120
                else:
                    S = self.get_game_data(game)
                    action, value = q_policy.select_action(self.model, S, epoch, game.nb_actions)
                    game.play(action)
                    current_size = game.snake.length  # Update the body size

                if game.game_over:
                    history_size.append(current_size)
                    history_step.append(game.step)
                    history_reward.append(game.get_reward())

            if game.is_won():
                win_count += 1

        print("Accuracy: {} %".format(100. * win_count / nb_epoch))
        print("Mean size: {} | Biggest size: {} | Smallest size: {}"\
              .format(np.mean(history_size), np.max(history_size),
                      np.min(history_size)))
        print("Mean steps: {} | Biggest step: {} | Smallest step: {}"\
              .format(np.mean(history_step), np.max(history_step),
                      np.min(history_step)))
        print("Mean rewards: {} | Biggest reward: {} | Smallest reward: {}"\
              .format(np.mean(history_reward), np.max(history_reward),
                      np.min(history_reward)))

"""THIS"""
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base
from tensorflow.python.ops.init_ops import Constant

class NoisyDense(tf.keras.layers.Dense):
    """Dense layer whose kernel and bias carry a learned amount of noise.

    Each parameter is built as quiet + noise_scale * noise, where quiet and
    noise_scale are trainable variables and noise is provided by a subclass
    through make_kernel_noise / make_bias_noise.
    """

    def build(self, input_shape):
        """Create the variables of the layer and compose kernel and bias."""
        input_shape = tensor_shape.TensorShape(input_shape)
        fan_in = input_shape[-1].value

        if fan_in is None:
            raise ValueError('The last dimension of the inputs to `Dense` '
                             'should be defined. Found `None`.')

        self.input_spec = base.InputSpec(min_ndim=2, axes={-1: fan_in})

        # Kernel: quiet part first, then the noise scale, then the noise.
        kernel_shape = [fan_in, self.units]
        quiet_kernel = self.add_variable('kernel_quiet',
                                         shape=kernel_shape,
                                         initializer=self.kernel_initializer,
                                         regularizer=self.kernel_regularizer,
                                         constraint=self.kernel_constraint,
                                         dtype=self.dtype,
                                         trainable=True)
        # Every noise scale starts at 0.5 / sqrt(number of inputs).
        scale_init = Constant(value=(0.5 / np.sqrt(fan_in)))
        kernel_scale = self.add_variable('kernel_noise_scale',
                                         shape=kernel_shape,
                                         initializer=scale_init,
                                         dtype=self.dtype,
                                         trainable=True)
        self.kernel = (quiet_kernel
                       + kernel_scale * self.make_kernel_noise(shape=kernel_shape))

        if not self.use_bias:
            self.bias = None
        else:
            bias_shape = [self.units,]
            quiet_bias = self.add_variable('bias_quiet',
                                           shape=bias_shape,
                                           initializer=self.bias_initializer,
                                           regularizer=self.bias_regularizer,
                                           constraint=self.bias_constraint,
                                           dtype=self.dtype,
                                           trainable=True)
            bias_scale = self.add_variable(name='bias_noise_scale',
                                           shape=bias_shape,
                                           initializer=scale_init,
                                           dtype=self.dtype,
                                           trainable=True)
            self.bias = (quiet_bias
                         + bias_scale * self.make_bias_noise(shape=bias_shape))

        self.built = True

    def make_kernel_noise(self, shape):
        """Return the noise for the kernel; provided by subclasses."""
        raise NotImplementedError

    def make_bias_noise(self, shape):
        """Return the noise for the bias; provided by subclasses."""
        raise NotImplementedError


class NoisyDenseIG(NoisyDense):
    '''
    Noisy dense layer with independent Gaussian noise.

    Kernel and bias each get their own non-trainable noise variable; both are
    collected in noise_list (kernel first, bias second).
    '''
    def make_kernel_noise(self, shape):
        """Create the kernel noise variable and start noise_list with it."""
        kernel_noise = tf.Variable(tf.random_normal(shape, dtype=self.dtype),
                                   trainable=False,
                                   dtype=self.dtype)
        self.noise_list = [kernel_noise]
        return kernel_noise

    def make_bias_noise(self, shape):
        """Create the bias noise variable and add it to noise_list."""
        bias_noise = tf.Variable(tf.random_normal(shape, dtype=self.dtype),
                                 trainable=False,
                                 dtype=self.dtype)
        self.noise_list.append(bias_noise)
        return bias_noise


class NoisyDenseFG(NoisyDense):
    '''
    Noisy dense layer with factorized Gaussian noise.

    One noise vector is created per input and one per output; the kernel
    noise is their outer product and the bias reuses the output vector.
    '''
    def make_kernel_noise(self, shape):
        """Create both noise vectors and return their outer product."""
        input_noise = self.make_fg_noise(shape=[shape[0]])
        output_noise = self.make_fg_noise(shape=[shape[1]])
        self.noise_list = [input_noise, output_noise]
        return input_noise[:, tf.newaxis] * output_noise

    def make_bias_noise(self, shape):
        """Return the output noise vector created by make_kernel_noise."""
        _, output_noise = self.noise_list
        return output_noise

    def make_fg_noise(self, shape):
        """Return a non-trainable variable holding sign(x) * sqrt(|x|) noise."""
        gaussian = tf.random_normal(shape, dtype=self.dtype)
        return tf.Variable(tf.sign(gaussian) * tf.sqrt(tf.abs(gaussian)),
                           trainable=False,
                           dtype=self.dtype)
        
        
import random
import numpy as np

class LinearSchedule(object):
    def __init__(self, schedule_timesteps, final_p, initial_p):
        """Linear interpolation between initial_p and final_p over
        schedule_timesteps. After this many timesteps pass final_p is
        returned.
        Parameters
        ----------
        schedule_timesteps: int or float
            Number of timesteps for which to linearly anneal initial_p
            to final_p
        final_p: float
            final output value
        initial_p: float
            initial output value
        """
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def value(self, t):
        """Return the scheduled value at timestep t.

        A schedule with no (or negative) duration is already finished, so
        final_p is returned instead of dividing by zero.
        """
        if self.schedule_timesteps <= 0:
            return self.final_p

        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)


import random
import numpy as np

class GreedyQPolicy:
    """Implement the greedy policy

    Greedy policy always takes current best action.
    """
    def __init__(self):
        super(GreedyQPolicy, self).__init__()

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: object exposing predict(state), returning the Q estimates
            state: input given to model.predict
            epoch: unused, kept for a common policy signature
            nb_actions: unused, kept for a common policy signature
        # Returns
            Tuple (selected action, 0)
        """
        q = model.predict(state)
        action = int(np.argmax(q[0]))

        return action, 0

    def get_config(self):
        """Return configurations of GreedyQPolicy
        # Returns
            Dict of config (empty, the policy has no parameters)
        """
        # The class has no policy base class, so there is no parent
        # get_config to extend.
        return {}


class EpsGreedyQPolicy:
    """Implement the epsilon greedy policy

    Eps Greedy policy either:

    - takes a random action with probability epsilon
    - takes current best action with prob (1 - epsilon)

    Epsilon is annealed linearly from `max_eps` to `min_eps` over
    `nb_epoch` epochs.
    """
    def __init__(self, max_eps=1., min_eps = .01, nb_epoch = 10000):
        super(EpsGreedyQPolicy, self).__init__()
        self.schedule = LinearSchedule(nb_epoch, min_eps, max_eps)
        # Defined up front so get_config() works before the first action.
        self.eps = max_eps

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: Object exposing `predict(state)`; the first row of its
                output is read as the Q estimate of each action.
            state: Input forwarded unchanged to `model.predict`.
            epoch: Position in the epsilon schedule.
            nb_actions: Number of available actions.
        # Returns
            Tuple `(action, eps)`: the chosen action index as a Python int
            and the epsilon used for this decision.
        """
        self.eps = self.schedule.value(epoch)

        if random.random() < self.eps:
            # Draw the exploratory action independently of the coin flip.
            # Reusing the flip (already known to be < eps) would always
            # select action 0 once eps drops below 1 / nb_actions.
            action = random.randrange(nb_actions)
        else:
            q = model.predict(state)
            action = int(np.argmax(q[0]))

        return action, self.eps

    def get_config(self):
        """Return configurations of EpsGreedyQPolicy
        # Returns
            Dict of config
        """
        # No policy base class exists here, so `object` has no
        # get_config(); only delegate when a parent really provides one.
        parent_get_config = getattr(super(EpsGreedyQPolicy, self),
                                    'get_config', None)
        config = {} if parent_get_config is None else dict(parent_get_config())
        config['eps'] = self.eps
        return config


class BoltzmannQPolicy:
    """Implement the Boltzmann Q Policy
    Boltzmann Q Policy builds a probability law on q values and returns
    an action selected randomly according to this law.

    The temperature is annealed linearly from `max_temp` to `min_temp`
    over `nb_epoch` epochs. `clip` is stored and reported by
    `get_config()` but is not applied in `select_action`, which subtracts
    the maximum logit before exponentiating instead.
    """
    def __init__(self, max_temp = 1., min_temp = .01, nb_epoch = 10000, clip = (-500., 500.)):
        super(BoltzmannQPolicy, self).__init__()
        self.schedule = LinearSchedule(nb_epoch, min_temp, max_temp)
        self.clip = clip
        # Defined up front so get_config() works before the first action.
        self.temp = max_temp

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: Object exposing `predict(state)`; the first row of its
                output is read as the Q estimate of each action.
            state: Input forwarded unchanged to `model.predict`.
            epoch: Position in the temperature schedule.
            nb_actions: Number of available actions.
        # Returns
            Tuple `(action, temp)`: the sampled action index as a Python
            int and the temperature used for this decision.
        """
        q = model.predict(state)[0]
        self.temp = self.schedule.value(epoch)
        arg = q / self.temp

        # Shift by the maximum so the largest exponent is exp(0) = 1.
        exp_values = np.exp(arg - arg.max())
        probs = exp_values / exp_values.sum()
        # Cast to a plain int so every policy returns the same action type.
        action = int(np.random.choice(range(nb_actions), p = probs))

        return action, self.temp

    def get_config(self):
        """Return configurations of BoltzmannQPolicy
        # Returns
            Dict of config
        """
        # No policy base class exists here, so `object` has no
        # get_config(); only delegate when a parent really provides one.
        parent_get_config = getattr(super(BoltzmannQPolicy, self),
                                    'get_config', None)
        config = {} if parent_get_config is None else dict(parent_get_config())
        config['temp'] = self.temp
        config['clip'] = self.clip
        return config


class BoltzmannGumbelQPolicy:
    """Implements Boltzmann-Gumbel exploration (BGE) adapted for Q learning
    based on the paper Boltzmann Exploration Done Right
    (https://arxiv.org/pdf/1705.10257.pdf).
    BGE is invariant with respect to the mean of the rewards but not their
    variance. The parameter C, which defaults to 1, can be used to correct for
    this, and should be set to the least upper bound on the standard deviation
    of the rewards.
    BGE is only available for training, not testing. For testing purposes, you
    can achieve approximately the same result as BGE after training for N steps
    on K actions with parameter C by using the BoltzmannQPolicy and setting
    tau = C/sqrt(N/K)."""

    def __init__(self, C = 1.0):
        super(BoltzmannGumbelQPolicy, self).__init__()
        self.C = C
        self.action_counts = None

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: Object exposing `predict(state)`; the first row of its
                output is read as the Q estimate of each action.
            state: Input forwarded unchanged to `model.predict`.
            epoch: Training epoch; the per-action counters are reset
                whenever it equals 0.
            nb_actions: Unused; the number of actions is taken from the
                shape of the predicted Q values.
        # Returns
            Tuple `(action, total)`: the chosen action index as a Python
            int and the sum of the per-action counters after this step.
        """
        q = model.predict(state)[0]
        q = q.astype('float64')

        # Reset the counters at the start of training. Also create them
        # lazily, so a first call with epoch != 0 (e.g. resumed training)
        # does not fail on the `None` placeholder set in __init__.
        if epoch == 0 or self.action_counts is None:
            self.action_counts = np.ones(q.shape)

        beta = self.C/np.sqrt(self.action_counts)
        Z = np.random.gumbel(size = q.shape)

        perturbation = beta * Z
        perturbed_q_values = q + perturbation
        # Cast to a plain int so every policy returns the same action type.
        action = int(np.argmax(perturbed_q_values))

        self.action_counts[action] += 1
        return action, np.sum(self.action_counts)

    def get_config(self):
        """Return configurations of BoltzmannGumbelQPolicy
        # Returns
            Dict of config
        """
        # No policy base class exists here, so `object` has no
        # get_config(); only delegate when a parent really provides one.
        parent_get_config = getattr(super(BoltzmannGumbelQPolicy, self),
                                    'get_config', None)
        config = {} if parent_get_config is None else dict(parent_get_config())
        config['C'] = self.C
        return config

from keras.engine import Layer, InputSpec
from keras import initializers
from keras import regularizers
from keras import constraints
from keras import backend as K

from keras.utils.generic_utils import get_custom_objects

from keras.layers import BatchNormalization


class SwitchNormalization(Layer):
    """Switchable Normalization layer
    Switch Normalization performs Instance Normalization, Layer Normalization and Batch
    Normalization using its parameters, and then weighs them using learned parameters to
    allow different levels of interaction of the 3 normalization schemes for each layer.
    Only supports the moving average variant from the paper, since the `batch average`
    scheme requires dynamic graph execution to compute the mean and variance of several
    batches at runtime.
    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `BatchNormalization`.
        momentum: Momentum for the moving mean and the moving variance. The original
            implementation suggests a default momentum of `0.997`, however it is highly
            unstable and training can fail after a few epochs. To stabilise training, use
            lower values of momentum such as `0.99` or `0.98`.
        epsilon: Small float added to variance to avoid dividing by zero.
        final_gamma: Bool value to determine if this layer is the final
            normalization layer for the residual block.  Overrides the initialization
            of the scaling weights to be `zeros`. Only used for Residual Networks,
            to make the forward/backward signal initially propagated through an
            identity shortcut.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        mean_weights_initializer: Initializer for the mean weights.
        variance_weights_initializer: Initializer for the variance weights.
        moving_mean_initializer: Initializer for the moving mean.
        moving_variance_initializer: Initializer for the moving variance.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        mean_weights_regularizer: Optional regularizer for the mean weights.
        variance_weights_regularizer: Optional regularizer for the variance weights.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.
        mean_weights_constraints: Optional constraint for the mean weights.
        variance_weights_constraints: Optional constraint for the variance weights.
    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.
    # Output shape
        Same shape as input.
    # References
        - [Differentiable Learning-to-Normalize via Switchable Normalization](https://arxiv.org/abs/1806.10779)
    """

    def __init__(self,
                 axis=-1,
                 momentum=0.99,
                 epsilon=1e-3,
                 final_gamma=False,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 mean_weights_initializer='ones',
                 variance_weights_initializer='ones',
                 moving_mean_initializer='ones',
                 moving_variance_initializer='zeros',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 mean_weights_regularizer=None,
                 variance_weights_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 mean_weights_constraints=None,
                 variance_weights_constraints=None,
                 **kwargs):
        """Store the hyper-parameters and resolve every initializer,
        regularizer and constraint identifier through the matching Keras
        `get()` helper. Remaining keyword arguments go to `Layer.__init__`.
        """
        super(SwitchNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale

        self.beta_initializer = initializers.get(beta_initializer)
        # `final_gamma` is not stored: it only replaces the gamma
        # initializer with 'zeros', ignoring `gamma_initializer`.
        if final_gamma:
            self.gamma_initializer = initializers.get('zeros')
        else:
            self.gamma_initializer = initializers.get(gamma_initializer)
        self.mean_weights_initializer = initializers.get(mean_weights_initializer)
        self.variance_weights_initializer = initializers.get(variance_weights_initializer)
        self.moving_mean_initializer = initializers.get(moving_mean_initializer)
        self.moving_variance_initializer = initializers.get(moving_variance_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.mean_weights_regularizer = regularizers.get(mean_weights_regularizer)
        self.variance_weights_regularizer = regularizers.get(variance_weights_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)
        self.mean_weights_constraints = constraints.get(mean_weights_constraints)
        self.variance_weights_constraints = constraints.get(variance_weights_constraints)

    def build(self, input_shape):
        """Create the layer weights for the given input shape.

        Raises `ValueError` when the size of `input_shape[self.axis]` is
        unknown. `gamma` and `beta` (when enabled) and the two moving
        statistics have shape `(dim,)`; the two sets of mixing weights
        have shape `(3,)`.
        """
        dim = input_shape[self.axis]

        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')

        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(
                shape=shape,
                name='gamma',
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(
                shape=shape,
                name='beta',
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint)
        else:
            self.beta = None

        # Running statistics: refreshed through add_update() in call(),
        # never by the optimizer.
        self.moving_mean = self.add_weight(
            shape=shape,
            name='moving_mean',
            initializer=self.moving_mean_initializer,
            trainable=False)

        self.moving_variance = self.add_weight(
            shape=shape,
            name='moving_variance',
            initializer=self.moving_variance_initializer,
            trainable=False)

        # One logit per statistic source; call() indexes them as
        # 0 = instance, 1 = layer, 2 = batch.
        self.mean_weights = self.add_weight(
            shape=(3,),
            name='mean_weights',
            initializer=self.mean_weights_initializer,
            regularizer=self.mean_weights_regularizer,
            constraint=self.mean_weights_constraints)

        self.variance_weights = self.add_weight(
            shape=(3,),
            name='variance_weights',
            initializer=self.variance_weights_initializer,
            regularizer=self.variance_weights_regularizer,
            constraint=self.variance_weights_constraints)

        self.built = True

    def call(self, inputs, training=None):
        """Normalize `inputs` with a learned blend of three statistics.

        Instance, layer and batch means/variances are combined using the
        softmax of `mean_weights` / `variance_weights`; the result is
        scaled by `gamma` and shifted by `beta` when those are enabled.
        When `training` is 0 or False the stored moving statistics are
        used as the batch term; otherwise `K.in_train_phase` selects
        between the training and the inference branch.
        """
        input_shape = K.int_shape(inputs)

        # Prepare broadcasting shape.
        # Start from every axis and drop the normalized one ...
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]

        # ... then, unless the normalized axis is 0, also drop the first
        # remaining entry.
        if self.axis != 0:
            del reduction_axes[0]

        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Instance statistics: reduced over `reduction_axes`, dims kept.
        mean_instance = K.mean(inputs, reduction_axes, keepdims=True)
        variance_instance = K.var(inputs, reduction_axes, keepdims=True)

        # Layer statistics are derived from the instance ones along
        # `self.axis`, using var = E[x^2] - E[x]^2 with
        # `temp` = instance variance + squared instance mean.
        mean_layer = K.mean(mean_instance, self.axis, keepdims=True)
        temp = variance_instance + K.square(mean_instance)
        variance_layer = K.mean(temp, self.axis, keepdims=True) - K.square(mean_layer)

        def training_phase():
            """Use statistics of the current batch and queue the
            moving-average updates."""
            # Batch statistics: the same quantities averaged over axis 0.
            mean_batch = K.mean(mean_instance, axis=0, keepdims=True)
            variance_batch = K.mean(temp, axis=0, keepdims=True) - K.square(mean_batch)

            # Flattened copies feed the moving averages.
            mean_batch_reshaped = K.flatten(mean_batch)
            variance_batch_reshaped = K.flatten(variance_batch)

            if K.backend() != 'cntk':
                sample_size = K.prod([K.shape(inputs)[axis]
                                      for axis in reduction_axes])
                sample_size = K.cast(sample_size, dtype=K.dtype(inputs))

                # sample variance - unbiased estimator of population variance
                # Only the copy used for the moving average is rescaled;
                # `variance_batch` passed to normalize_func is left as is.
                variance_batch_reshaped *= sample_size / (sample_size - (1.0 + self.epsilon))

            self.add_update([K.moving_average_update(self.moving_mean,
                                                     mean_batch_reshaped,
                                                     self.momentum),
                             K.moving_average_update(self.moving_variance,
                                                     variance_batch_reshaped,
                                                     self.momentum)],
                            inputs)

            return normalize_func(mean_batch, variance_batch)

        def inference_phase():
            """Use the stored moving statistics as the batch term."""
            mean_batch = self.moving_mean
            variance_batch = self.moving_variance

            return normalize_func(mean_batch, variance_batch)

        # Defined after the two phase functions but before either of them
        # is called, so the name is bound by the time they run.
        def normalize_func(mean_batch, variance_batch):
            """Blend the three statistics and normalize `inputs`."""
            mean_batch = K.reshape(mean_batch, broadcast_shape)
            variance_batch = K.reshape(variance_batch, broadcast_shape)

            # Softmax over the three logits gives the blending weights.
            mean_weights = K.softmax(self.mean_weights, axis=0)
            variance_weights = K.softmax(self.variance_weights, axis=0)

            mean = (mean_weights[0] * mean_instance +
                    mean_weights[1] * mean_layer +
                    mean_weights[2] * mean_batch)

            variance = (variance_weights[0] * variance_instance +
                        variance_weights[1] * variance_layer +
                        variance_weights[2] * variance_batch)

            outputs = (inputs - mean) / (K.sqrt(variance + self.epsilon))

            if self.scale:
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                outputs = outputs * broadcast_gamma

            if self.center:
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                outputs = outputs + broadcast_beta

            return outputs

        # An explicit inference request bypasses the learning-phase switch.
        if training in {0, False}:
            return inference_phase()

        return K.in_train_phase(training_phase,
                                inference_phase,
                                training=training)

    def get_config(self):
        """Return the layer configuration: the base `Layer` config plus
        this layer's hyper-parameters and serialized initializers,
        regularizers and constraints. `final_gamma` is not listed
        separately; its effect is already contained in the serialized
        `gamma_initializer`.
        """
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'momentum': self.momentum,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'mean_weights_initializer': initializers.serialize(self.mean_weights_initializer),
            'variance_weights_initializer': initializers.serialize(self.variance_weights_initializer),
            'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer),
            'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'mean_weights_regularizer': regularizers.serialize(self.mean_weights_regularizer),
            'variance_weights_regularizer': regularizers.serialize(self.variance_weights_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint),
            'mean_weights_constraints': constraints.serialize(self.mean_weights_constraints),
            'variance_weights_constraints': constraints.serialize(self.variance_weights_constraints),
        }
        base_config = super(SwitchNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        """Return `input_shape` unchanged."""
        return input_shape
    
# Add the layer to Keras' custom-object mapping under its class name
# (presumably so saved models that use it can be deserialized by name).
get_custom_objects().update({'SwitchNormalization': SwitchNormalization})
    
"""THIS"""
#!/usr/bin/env python

""" Needs update!
"""

import numpy as np
import tensorflow as tf
try:
    from keras.optimizers import RMSprop, Nadam
    from keras.models import Sequential, load_model, Model
    from keras.layers import Conv2D, Dense, Flatten, MaxPooling2D, Flatten,\
                             Input, Lambda, Add
    from keras import backend as K

    K.set_image_dim_ordering('th')
except ImportError:
    from tensorflow.keras.optimizers import RMSprop, Nadam
    from tensorflow.keras.models import Sequential, load_model, Model
    from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D,\
                                        Flatten, Input, Lambda, Add


__author__ = "Victor Neves"
__license__ = "MIT"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"

# Maps the `dense_type` string accepted by create_model() to the layer
# class used for the fully connected part of the network.
# NoisyDenseFG / NoisyDenseIG are defined elsewhere in this file.
DENSES = {'dense': Dense,
          'noisy_dense_fg': NoisyDenseFG,
          'noisy_dense_ig': NoisyDenseIG}

def select_optimizer(optimizer):
    """Turn an optimizer name into a default-configured optimizer instance.

    Only 'Nadam' and 'RMSprop' are accepted; any other value fails the
    assertion. The matching class is instantiated without arguments.
    """
    assert optimizer in {'Nadam', 'RMSprop'}, "Optimizer should be RMSprop or Nadam."

    return Nadam() if optimizer == 'Nadam' else RMSprop()

def select_error(error):
    """Resolve a loss identifier.

    The string 'clipped_error' is replaced by the `clipped_error` object
    from this module; every other string is returned unchanged. The
    argument must be exactly of type `str`, otherwise the assertion fails.
    """
    assert type(error) is str, "Should use string to select error."

    return clipped_error if error == 'clipped_error' else error

def CNN1(inputs):
    """Smallest convolutional feature extractor.

    Applies two 3x3 relu convolutions (16 then 32 filters) to `inputs`
    and flattens the result, like CNN2 and CNN3, so the dense heads built
    in create_model() receive a rank-2 tensor.

    Returns the flattened feature tensor. The previous version returned
    the undefined name `model` and never flattened its output.
    """
    net = Conv2D(16, (3, 3), activation = 'relu')(inputs)
    net = Conv2D(32, (3, 3), activation = 'relu')(net)
    net = Flatten()(net)

    return net

def CNN2(inputs):
    """Medium convolutional feature extractor.

    Applies three 3x3 relu convolutions (16, 32 and 32 filters) to
    `inputs` and flattens the result.

    Returns the flattened feature tensor. The previous version returned
    the undefined name `model` instead of the tensor it had just built.
    """
    net = Conv2D(16, (3, 3), activation = 'relu')(inputs)
    net = Conv2D(32, (3, 3), activation = 'relu')(net)
    net = Conv2D(32, (3, 3), activation = 'relu')(net)
    net = Flatten()(net)

    return net

def CNN3(inputs):
    """From @Kaixhin implementation's of the Rainbow paper.

    Stacks three 3x3 relu convolutions with 32, 64 and 64 filters on top
    of `inputs` and returns the flattened result.
    """
    net = inputs
    for filters in (32, 64, 64):
        net = Conv2D(filters, (3, 3), activation = 'relu')(net)

    return Flatten()(net)

def create_cnn(cnn, inputs):
    """Build the convolutional part of the network selected by `cnn`.

    "CNN1" and "CNN2" pick the builders of the same name; every other
    value falls back to CNN3. Returns whatever the chosen builder returns
    for `inputs`.
    """
    if cnn == "CNN1":
        return CNN1(inputs)
    if cnn == "CNN2":
        return CNN2(inputs)

    return CNN3(inputs)

def create_model(optimizer, loss, stack, input_size, output_size,
                 dueling = False, cnn = "CNN3", dense_type = "dense"):
    """Build and compile the Q-network.

    # Arguments
        optimizer: Passed unchanged to `model.compile`.
        loss: Passed unchanged to `model.compile`.
        stack: First dimension of the input shape.
        input_size: Second and third dimension of the input shape.
        output_size: Number of units of the output layer.
        dueling: If True, build separate advantage and value streams and
            add them; otherwise use a single dense stream.
        cnn: Name handed to `create_cnn` to pick the feature extractor.
        dense_type: Key into `DENSES` selecting the dense layer class.
    # Returns
        The compiled Keras `Model`.
    """
    frames = Input(shape = (stack, input_size, input_size))
    features = create_cnn(cnn, frames)

    def dense_stream(units):
        """3136-unit relu layer on the CNN features, then a linear layer
        with `units` outputs."""
        hidden = DENSES[dense_type](3136, activation = 'relu')(features)
        return DENSES[dense_type](units)(hidden)

    if not dueling:
        q_values = dense_stream(output_size)
    else:
        # Layers are created in the order advantage stream, value stream,
        # advantage Lambda, value Lambda, Add.
        advantage = dense_stream(output_size)
        state_value = dense_stream(1)

        # Combine the two streams: subtract the mean advantage, repeat the
        # scalar value once per action, then add them up.
        advantage = Lambda(lambda advt: advt - tf.reduce_mean(advt, axis = -1,
                                                              keepdims = True))(advantage)
        state_value = Lambda(lambda value: tf.tile(value, [1, output_size]))(state_value)
        q_values = Add()([state_value, advantage])

    network = Model(inputs = frames, outputs = q_values)
    network.compile(optimizer = optimizer, loss = loss)

    return network


from __future__ import absolute_import

from keras import backend as K
from keras.optimizers import Optimizer

from tensorflow.python.ops import state_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export

if K.backend() == 'tensorflow':
    import tensorflow as tf

class COCOB(Optimizer):
    """COCOB-Backprop optimizer.
    It is recommended to leave the parameters of this optimizer
    at their default values.
    This optimizer, unlike other stochastic gradient based optimizers, optimizes the function by
    finding individual learning rates in a coin-betting way; it has no learning-rate argument.
    # Arguments
        alpha: float >= 0. Multiple of the largest absolute gradient magnitude
            used as a lower bound in the update's denominator.
        epsilon: float >= 0. Fuzz factor.
    # References
        - [Training Deep Networks without Learning Rates Through Coin Betting](https://arxiv.org/pdf/1705.07795.pdf)
    """

    def __init__(self, alpha=100, epsilon=1e-8, **kwargs):
        super(COCOB, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.alpha = K.variable(alpha, name='alpha')
            self.iterations = K.variable(0., name='iterations')
        self.epsilon = epsilon

    def get_updates(self, loss, params):
        """Build the list of update ops for one optimization step.

        Per parameter the optimizer keeps four accumulators: the largest
        absolute gradient seen (`L`), the sum of absolute gradients (`M`),
        the clipped reward (`Reward`) and the sum of gradients
        (`grad_sum`).

        # Returns
            `self.updates`, the list of backend update ops.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        L = [K.zeros(K.get_variable_shape(p), dtype=K.dtype(p)) for p in params]
        M = [K.zeros(K.get_variable_shape(p), dtype=K.dtype(p)) for p in params]
        Reward = [K.zeros(K.get_variable_shape(p), dtype=K.dtype(p)) for p in params]
        grad_sum = [K.zeros(K.get_variable_shape(p), dtype=K.dtype(p)) for p in params]

        # Snapshot the parameter values the bets are measured from.
        # This used to happen only when `iterations` evaluated to 0, which
        # left `old_params` unbound (UnboundLocalError below) whenever the
        # updates were built with a non-zero iteration counter. All the
        # accumulators above start from zero on every call anyway, so the
        # anchor is always taken at the same moment.
        old_params = [K.constant(K.eval(p)) for p in params]

        self.weights = [self.iterations] + L + M + Reward + grad_sum

        for old_p, p, g, gs, l, m, r in zip(old_params, params, grads, grad_sum, L, M, Reward):
            # Largest absolute gradient seen so far.
            new_l = K.maximum(l, K.abs(g))
            self.updates.append(K.update(l, new_l))

            # Running sum of absolute gradients.
            new_m = m + K.abs(g)
            self.updates.append(K.update(m, new_m))

            # Reward, clipped at zero.
            new_r = K.maximum(r - (p - old_p)*g, 0)
            self.updates.append(K.update(r, new_r))

            # Running sum of gradients.
            new_gs = gs + g
            self.updates.append(K.update(gs, new_gs))

            new_p = old_p - (new_gs/(self.epsilon + new_l*K.maximum(new_m+new_l, self.alpha*new_l)))*(new_l + new_r)

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates

    def get_config(self):
        """Return the optimizer configuration (`alpha`, `epsilon` and the
        base optimizer settings)."""
        config = {'alpha': float(K.get_value(self.alpha)),
                  'epsilon': self.epsilon}
        base_config = super(COCOB, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class SMORMS3(Optimizer):
    '''SMORMS3 optimizer.
    Implemented based on http://sifter.org/~simon/journal/20150420.html
    # Arguments
        lr: float >= 0. Learning rate.
        epsilon: float >= 0. Fuzz factor.
        decay: float >= 0. Learning rate decay over each update.
    '''

    def __init__(self, lr=0.001, epsilon=1e-16, decay=0.,
                 **kwargs):
        super(SMORMS3, self).__init__(**kwargs)
        # The former `self.__dict__.update(locals())` is gone: it stored a
        # reference to the instance on itself plus the raw `kwargs`, and
        # every useful attribute is assigned explicitly below.
        with K.name_scope(self.__class__.__name__):
            self.lr = K.variable(lr)
            self.decay = K.variable(decay)
            self.iterations = K.variable(0.)
        self.epsilon = epsilon
        self.initial_decay = decay
        # Misspelled alias kept so existing code reading it keeps working.
        self.inital_decay = decay

    def get_updates(self, loss, params):
        """Build the list of update ops for one optimization step.

        Per parameter the optimizer keeps running averages of the gradient
        (`g1`) and of the squared gradient (`g2`), plus a memory term
        (`mem`) that sets the averaging rate `1 / (mem + 1)`.

        # Returns
            `self.updates`, the list of backend update ops.
        """
        grads = self.get_gradients(loss, params)
        shapes = [K.get_variable_shape(p) for p in params]
        # Start a fresh list (as the other optimizers in this file do)
        # instead of appending, so repeated calls do not pile up updates.
        self.updates = [K.update_add(self.iterations, 1)]

        g1s = [K.zeros(shape) for shape in shapes]
        g2s = [K.zeros(shape) for shape in shapes]
        mems = [K.ones(shape) for shape in shapes]

        lr = self.lr
        if self.initial_decay > 0:
            lr = lr * (1. / (1. + self.decay * self.iterations))

        self.weights = [self.iterations] + g1s + g2s + mems

        for p, g, g1, g2, m in zip(params, grads, g1s, g2s, mems):
            r = 1. / (m + 1)
            new_g1 = (1. - r) * g1 + r * g
            new_g2 = (1. - r) * g2 + r * K.square(g)
            # update accumulators
            self.updates.append(K.update(g1, new_g1))
            self.updates.append(K.update(g2, new_g2))
            # Step size is capped by both `lr` and g1^2 / g2.
            new_p = p - g * K.minimum(lr, K.square(new_g1) / (new_g2 + self.epsilon)) / (
                K.sqrt(new_g2) + self.epsilon)
            new_m = 1 + m * (1 - K.square(new_g1) / (new_g2 + self.epsilon))
            # update the memory term
            self.updates.append(K.update(m, new_m))
            # apply constraints
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, new_p))
        return self.updates

    def get_config(self):
        """Return the optimizer configuration (`lr`, `decay`, `epsilon`
        and the base optimizer settings)."""
        config = {'lr': float(K.get_value(self.lr)),
                  'decay': float(K.get_value(self.decay)),
                  'epsilon': self.epsilon}
        base_config = super(SMORMS3, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

class Yogi(Optimizer):
    """Yogi optimizer.
    Default parameters follow those provided in the original paper.
    Arguments:
      lr: float >= 0. Learning rate.
      beta_1: float, 0 < beta < 1. Generally close to 1.
      beta_2: float, 0 < beta < 1. Generally close to 1.
      epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
      decay: float >= 0. Learning rate decay over each update.
      amsgrad: boolean. Whether to apply the AMSGrad variant of this
          algorithm from the paper "On the Convergence of Adam and
          Beyond": the step is then divided by the running maximum of the
          second-moment estimate instead of its current value.
    """

    def __init__(self,
                 lr=0.001,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=None,
                 decay=0.00000001,
                 amsgrad=False,
                 **kwargs):
        super(Yogi, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.beta_1 = K.variable(beta_1, name='beta_1')
            self.beta_2 = K.variable(beta_2, name='beta_2')
            self.decay = K.variable(decay, name='decay')
        if epsilon is None:
            epsilon = K.epsilon()
        self.epsilon = epsilon
        self.initial_decay = decay
        self.amsgrad = amsgrad

    def get_updates(self, loss, params):
        """Build the list of update ops for one optimization step.

        # Returns
            `self.updates`, the list of backend update ops.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [state_ops.assign_add(self.iterations, 1)]

        lr = self.lr
        if self.initial_decay > 0:
            lr = lr * (
                1. / (1. + self.decay * math_ops.cast(self.iterations,
                                                      K.dtype(self.decay))))

        # Bias-corrected step size, as in Adam.
        t = math_ops.cast(self.iterations, K.floatx()) + 1
        lr_t = lr * (
            K.sqrt(1. - math_ops.pow(self.beta_2, t)) /
            (1. - math_ops.pow(self.beta_1, t)))

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        if self.amsgrad:
            vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        else:
            # Placeholders keep the layout of `self.weights` identical
            # whether or not AMSGrad is enabled.
            vhats = [K.zeros(1) for _ in params]
        self.weights = [self.iterations] + ms + vs + vhats

        for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            # Yogi replaces Adam's exponential average of g^2 by an
            # additive, sign-controlled update of the second moment.
            v_t = v - (1-self.beta_2)*K.sign(v-math_ops.square(g))*math_ops.square(g)

            if self.amsgrad:
                # The flag used to be accepted but the `vhats` slots were
                # never read; apply it the way Keras' Adam does.
                vhat_t = K.maximum(vhat, v_t)
                p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
                self.updates.append(state_ops.assign(vhat, vhat_t))
            else:
                p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)

            self.updates.append(state_ops.assign(m, m_t))
            self.updates.append(state_ops.assign(v, v_t))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(state_ops.assign(p, new_p))
        return self.updates

    def get_config(self):
        """Return the optimizer configuration, including the base
        optimizer settings."""
        config = {
            'lr': float(K.get_value(self.lr)),
            'beta_1': float(K.get_value(self.beta_1)),
            'beta_2': float(K.get_value(self.beta_2)),
            'decay': float(K.get_value(self.decay)),
            'epsilon': self.epsilon,
            'amsgrad': self.amsgrad
            }
        base_config = super(Yogi, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

class SGD2(Optimizer):
    """Stochastic gradient descent using second-order information.

    Ye, C., Yang, Y., Fermuller, C., & Aloimonos, Y. (2017).
    On the Importance of Consistency in Training Deep Neural Networks.
    arXiv preprint arXiv:1708.00631.

    `set_layerinput()` must be called before `get_updates()`, which reads
    `self.layer_inputs`.

    # Arguments
        lr: float >= 0. Learning rate.
        momentum: float >= 0. Parameter updates momentum.
        decay: float >= 0. Learning rate decay over each update.
        nesterov: boolean. Whether to apply Nesterov momentum.
            NOTE(review): accepted but never stored or read in this
            class, so it currently has no effect.
    """
    def __init__(self, lr=0.01, momentum=0., decay=0.,
                 nesterov=False, **kwargs):
        super(SGD2, self).__init__(**kwargs)

        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.momentum = K.variable(momentum, name='momentum')
            self.decay = K.variable(decay, name='decay')
        self.initial_decay = decay

    def get_updates(self, loss, params):
        """Build the list of update ops for one optimization step.

        Gradients of rank > 1 are multiplied by the inverse of a damped
        matrix built from the matching entry of `self.layer_inputs`
        before the momentum update; other gradients are used as they are.
        Parameter constraints are not applied here.

        # Returns
            `self.updates`, the list of backend update ops.
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]
        lr = self.lr
        if self.initial_decay > 0:
            lr *= (1. / (1. + self.decay * K.cast(self.iterations, K.dtype(self.decay))))
        # Index of the next entry of `self.layer_inputs` to consume; it
        # advances once per gradient of rank > 1.
        layer_count = 0
        # Damping added to the diagonal before the matrix inversion.
        lambda_value = 0.01
        # momentum
        shapes = [K.int_shape(p) for p in params ]
        moments = [K.zeros(shape) for shape in shapes]
        self.weights = [self.iterations] + moments

        for p, g, m in zip(params, grads, moments):
            # gradients correction by second order information
            if len(K.int_shape(g)) > 1:
                x = self.layer_inputs[layer_count]
                shape_g = K.int_shape(g)
                # First permute the x, then compute transpose of x
                layer_count = layer_count + 1
                # For 3 channel image
                # (rank-4 input: the last axis is moved to the front, the
                # rest is flattened, and x @ x^T is formed; the gradient
                # gets its third axis moved to the front and is flattened
                # the same way)
                if len(K.int_shape(x)) == 4:
                    x = tf.transpose(x, perm=[3, 0, 1, 2])
                    x = tf.reshape(x, [K.int_shape(x)[0], -1])
                    x = tf.matmul(x, tf.transpose(x))
                    g = tf.transpose(g, perm=[2, 0, 1, 3])
                    g = tf.reshape(g, [K.int_shape(g)[0], -1])

                elif len(K.int_shape(x)) == 2:
                    # NOTE(review): `xt` is produced with a reshape, which
                    # is not a transpose - confirm this is intended.
                    xt = tf.reshape(x, [K.int_shape(x)[1], -1])
                    x = tf.matmul(xt, x)
                    x = tf.reshape(x, [K.int_shape(x)[0], K.int_shape(x)[0]])

                # Precondition the gradient with (x + lambda * I)^-1 and
                # restore its original shape.
                lambda_eye =  tf.eye(K.int_shape(x)[0])
                corr_term = tf.matrix_inverse(tf.add(x, tf.multiply(lambda_value, lambda_eye)))
                g = tf.matmul(corr_term, g)
                g = tf.reshape(g, shape_g)

            v = self.momentum * m - lr * g  # velocity
            self.updates.append(K.update(m, v))
            new_p = p + v
            self.updates.append(K.update(p, new_p))
        return self.updates
    def set_layerinput(self, layer_inputs):
        """Store the sequence that `get_updates()` indexes to obtain the
        input of each layer whose gradient has rank > 1."""
        self.layer_inputs = layer_inputs

    def get_config(self):
        """Return the optimizer configuration (`lr`, `momentum`, `decay`
        and the base optimizer settings)."""
        config = {'lr': float(K.get_value(self.lr)),
                  'momentum': float(K.get_value(self.momentum)),
                  'decay': float(K.get_value(self.decay))}
        base_config = super(SGD2, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

import inspect
import keras
from keras import backend as K


def _get_shape(x):
    """Return the shape of ``x``.

    Objects exposing a ``dense_shape`` attribute report that attribute;
    everything else is handed to ``K.shape``.
    """
    return x.dense_shape if hasattr(x, 'dense_shape') else K.shape(x)


def add_gradient_noise(BaseOptimizer):
    """Build a subclass of a Keras optimizer that adds noise to its gradients.

    Zero-mean Gaussian noise is added to every gradient. Its variance decays
    with the iteration count ``t`` as ``noise_eta / (1 + t) ** noise_gamma``,
    following equation 1 of https://arxiv.org/abs/1511.06807.

    Parameters
    ----------
    BaseOptimizer: class
        A subclass of ``keras.optimizers.Optimizer``.

    Return
    ----------
    NoisyOptimizer: class
        Subclass of ``BaseOptimizer`` named ``'Noisy' + BaseOptimizer.__name__``.
        It accepts the extra keyword arguments ``noise_eta`` (default 0.3) and
        ``noise_gamma`` (default 0.55); all other keyword arguments go to
        ``BaseOptimizer``.

    Raises
    ----------
    ValueError
        If ``BaseOptimizer`` is not a Keras optimizer class.
    """
    is_optimizer_class = (inspect.isclass(BaseOptimizer) and
                          issubclass(BaseOptimizer, keras.optimizers.Optimizer))

    if not is_optimizer_class:
        raise ValueError('add_gradient_noise() expects a valid Keras optimizer')

    class NoisyOptimizer(BaseOptimizer):
        # Same optimizer as BaseOptimizer, except that get_gradients returns
        # the gradients plus decaying Gaussian noise.
        def __init__(self, noise_eta=0.3, noise_gamma=0.55, **kwargs):
            super(NoisyOptimizer, self).__init__(**kwargs)
            scope_name = self.__class__.__name__
            with K.name_scope(scope_name):
                self.noise_eta = K.variable(noise_eta, name='noise_eta')
                self.noise_gamma = K.variable(noise_gamma, name='noise_gamma')

        def get_gradients(self, loss, params):
            clean_grads = super(NoisyOptimizer, self).get_gradients(loss,
                                                                    params)

            # The dtype of the first gradient is used for the step counter
            # and for every noise tensor.
            noise_dtype = K.dtype(clean_grads[0])
            step = K.cast(self.iterations, noise_dtype)
            variance = self.noise_eta / ((1 + step) ** self.noise_gamma)

            noisy_grads = []
            for grad in clean_grads:
                noise = K.random_normal(_get_shape(grad),
                                        mean=0.0,
                                        stddev=K.sqrt(variance),
                                        dtype=noise_dtype)
                noisy_grads.append(grad + noise)

            return noisy_grads

        def get_config(self):
            noise_config = {
                'noise_eta': float(K.get_value(self.noise_eta)),
                'noise_gamma': float(K.get_value(self.noise_gamma)),
            }
            merged = dict(super(NoisyOptimizer, self).get_config())
            merged.update(noise_config)
            return merged

    NoisyOptimizer.__name__ = 'Noisy{}'.format(BaseOptimizer.__name__)

    return NoisyOptimizer

from keras.optimizers import RMSprop, Nadam
from keras.models import Sequential, load_model, Model
from keras.layers import *
from keras.losses import *

# Use the Theano-style ('th', i.e. channels-first) image dimension ordering.
K.set_image_dim_ordering('th')  
  
# Board side (in blocks) and number of stacked frames; both are passed to
# create_model and Agent below.
board_size = 10
nb_frames = 4
  
# Robot player with absolute actions (relative_pos = False). With the Game
# class as defined in this notebook that means game.nb_actions is 5.
game = Game(player = "ROBOT", board_size = board_size,
                        local_state = True, relative_pos = False)

# RMSprop wrapped by add_gradient_noise: noise_eta is lowered to 0.01 while
# noise_gamma keeps its default of 0.55.
optimizer = add_gradient_noise(RMSprop)(noise_eta = 0.01)
error = tf.losses.huber_loss

# NOTE(review): `tf`, `create_model` and `Agent` are not defined in this part
# of the notebook; this cell relies on earlier cells having been executed.
with tf.Session() as sess:
    model = create_model(optimizer = optimizer, loss = error,
                                 stack = nb_frames, input_size = board_size,
                                 output_size = game.nb_actions,
                                 dueling = False, cnn = 'CNN3',
                                 dense_type = 'dense')
    # No target network is passed, although update_target_freq is still set
    # below. NOTE(review): confirm how Agent treats target = None.
    target = None
    sess.run(tf.global_variables_initializer())
    agent = Agent(model = model, sess = sess, target = target, memory_size = -1,
                      nb_frames = nb_frames, board_size = board_size,
                      per = False, update_target_freq = 500)
    # Train, then evaluate, inside the same session.
    agent.train(game, batch_size = 64, nb_epoch = 10000,
                    gamma = 0.95, n_steps = 1, policy = 'GreedyQPolicy')
    agent.test(game)



Epoch: 010/10000 | Mean size 10: 3.0 | Longest 10: 003 | Mean steps 10: 9.2 | Wins: 0 | Win percentage: 0.0%
Epoch: 020/10000 | Mean size 10: 3.4 | Longest 10: 005 | Mean steps 10: 10.2 | Wins: 3 | Win percentage: 15.0%
Epoch: 030/10000 | Mean size 10: 3.0 | Longest 10: 003 | Mean steps 10: 9.1 | Wins: 3 | Win percentage: 10.0%
Epoch: 040/10000 | Mean size 10: 3.1 | Longest 10: 004 | Mean steps 10: 13.1 | Wins: 4 | Win percentage: 10.0%
Epoch: 050/10000 | Mean size 10: 3.1 | Longest 10: 004 | Mean steps 10: 11.0 | Wins: 5 | Win percentage: 10.0%
Epoch: 060/10000 | Mean size 10: 3.2 | Longest 10: 004 | Mean steps 10: 12.1 | Wins: 7 | Win percentage: 11.7%
Epoch: 070/10000 | Mean size 10: 3.1 | Longest 10: 004 | Mean steps 10: 16.8 | Wins: 8 | Win percentage: 11.4%
Epoch: 080/10000 | Mean size 10: 3.2 | Longest 10: 004 | Mean steps 10: 15.6 | Wins: 10 | Win percentage: 12.5%
Epoch: 090/10000 | Mean size 10: 3.1 | Longest 10: 004 | Mean steps 10: 23.3 | Wins: 11 | Win percentage: 12.2%
Ep

In [0]:
#!/usr/bin/env python

"""SnakeGame: A simple and fun exploration, meant to be used by Human and AI. - THIS
"""

import sys  # To close the window when the game is over
from array import array  # Efficient numeric arrays
from os import environ, path  # To center the game window the best possible
import random  # Random numbers used for the food
import logging  # Logging function for movements and errors
import json # For file handling (leaderboards)
from itertools import tee  # For the color gradient on snake

import numpy as np # Used in calculations and math

__author__ = "Victor Neves"
__license__ = "MIT"
__version__ = "1.0"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"
__status__ = "Production"

# Actions, options and forbidden moves.
# OPTIONS: codes returned by the menus and dispatched in Game.start.
OPTIONS = {'QUIT': 0,
           'PLAY': 1,
           'BENCHMARK': 2,
           'LEADERBOARDS': 3,
           'MENU': 4,
           'ADD_TO_LEADERBOARDS': 5}
# RELATIVE_ACTIONS: turns relative to the current heading; they are converted
# to absolute actions by Game.relative_to_absolute.
RELATIVE_ACTIONS = {'LEFT': 0,
                    'FORWARD': 1,
                    'RIGHT': 2}
# ABSOLUTE_ACTIONS: directions on the board. IDLE makes Snake.move repeat the
# previous action.
ABSOLUTE_ACTIONS = {'LEFT': 0,
                    'RIGHT': 1,
                    'UP': 2,
                    'DOWN': 3,
                    'IDLE': 4}
# FORBIDDEN_MOVES: (action, previous_action) pairs that would reverse the
# snake onto itself (LEFT <-> RIGHT and UP <-> DOWN).
FORBIDDEN_MOVES = [(0, 1), (1, 0), (2, 3), (3, 2)]

# Possible rewards in the game. Game.get_reward uses MOVE and GAME_OVER; when
# the snake scores it returns the snake length instead of REWARDS['SCORED'].
REWARDS = {'MOVE': -0.005,
           'GAME_OVER': -1,
           'SCORED': 1}

# Types of point on the board, written into the matrix built by Game.state.
# NOTE(review): DANGEROUS is not written by any code visible here; presumably
# it is set by Game.eval_local_safety - confirm.
POINT_TYPE = {'EMPTY': 0,
              'FOOD': 1,
              'BODY': 2,
              'HEAD': 3,
              'DANGEROUS': 4}

# Speed levels available to human players. Each value is the wait between two
# snake moves, compared in Game.single_player with the time accumulated from
# the pygame clock (milliseconds), so a lower value is faster.
# MEGA HARDCORE starts slightly slower than MEDIUM (65 vs 60) and the wait
# shrinks by 2 for every body part gained. Its value must stay different from
# the other levels because Game.select_speed detects the mode by comparing
# the selected speed with SPEEDS['MEGA_HARDCORE'].
LEVELS = [" EASY ", " MEDIUM ", " HARD ", " MEGA HARDCORE "]
SPEEDS = {'EASY': 80,
          'MEDIUM': 60,
          'HARD': 40,
          'MEGA_HARDCORE': 65}

# Constant FPS limit passed to the clock tick in Game.single_player. The
# smoothness of the game depends on this.
GAME_FPS = 100


class GlobalVariables:
    """Settings shared by the drawing and movement code of the snake game.

    Attributes
    ----------
    board_size: int, optional, default = 30
        Number of blocks along each side of the board.
    block_size: int, optional, default = 20
        Side of one block, in pixels.
    head_color: tuple of 3 * int, optional, default = (42, 42, 42)
        Color of the head. Start of the body color gradient.
    tail_color: tuple of 3 * int, optional, default = (152, 152, 152)
        Color of the tail. End of the body color gradient.
    food_color: tuple of 3 * int, optional, default = (200, 0, 0)
        Color of the food.
    game_speed: int, optional, default = 80
        Pace of the game. The game loop uses it as the wait between two
        snake moves, so a lower value is faster.
    benchmark: int, optional, default = 10
        Amount of matches to benchmark and possibly go to leaderboards.
    """
    def __init__(self, board_size = 30, block_size = 20,
                 head_color = (42, 42, 42), tail_color = (152, 152, 152),
                 food_color = (200, 0, 0), game_speed = 80, benchmark = 10):
        """Store every setting on the instance and warn about big boards."""
        self.board_size, self.block_size = board_size, block_size
        self.head_color, self.tail_color = head_color, tail_color
        self.food_color = food_color
        self.game_speed, self.benchmark = game_speed, benchmark

        board_is_huge = self.board_size > 50
        if board_is_huge: # Warn the user about performance
            LOGGER.warning('WARNING: BOARD IS TOO BIG, IT MAY RUN SLOWER.')

    @property
    def canvas_size(self):
        """Side of the canvas in pixels: board_size times block_size."""
        blocks_per_side = self.board_size
        return blocks_per_side * self.block_size

class TextBlock:
    """A piece of text drawn with pygame, either plain text or a menu option.

    Attributes:
    ----------
    text: string
        The text to be displayed.
    pos: tuple of 2 * int
        Where the center of the text rectangle is placed.
    screen: pygame window object
        The screen where the text is drawn.
    scale: float, optional, default = 1 / 12
        Fraction of the canvas size used as font size, so the text resizes
        with the board.
    block_type: string, optional, default = "text"
        Either "text" or "menu". Menu blocks change colors when hovered.
    hovered: boolean
        Whether the mouse is over the block. Starts as False.
    """
    def __init__(self, text, pos, screen, scale = (1 / 12), block_type = "text"):
        """Store the settings, place the rectangle and draw the block once."""
        self.block_type, self.hovered = block_type, False
        self.text, self.pos = text, pos
        self.screen, self.scale = screen, scale
        self.set_rect()
        self.draw()

    def draw(self):
        """Render the text again and blit it on the pygame screen."""
        self.set_rend()
        self.screen.blit(self.rend, self.rect)

    def set_rend(self):
        """Render the text with the current font size and colors."""
        font_file = resource_path("resources/fonts/freesansbold.ttf")
        font_size = int((VAR.canvas_size) * self.scale)
        font = pygame.font.Font(font_file, font_size)
        self.rend = font.render(self.text, True, self.get_color(),
                                self.get_background())

    def get_color(self):
        """Get the text color: light for idle menu options, dark otherwise.

        Return
        ----------
        color: pygame Color
            The color that will be rendered for the text block.
        """
        idle_menu_option = self.block_type == "menu" and not self.hovered

        if idle_menu_option:
            return pygame.Color(152, 152, 152)

        return pygame.Color(42, 42, 42)

    def get_background(self):
        """Get the background color: only hovered menu options have one.

        Return
        ----------
        color: pygame Color or None
            The color that will be rendered behind the text, if any.
        """
        if self.block_type == "menu" and self.hovered:
            return pygame.Color(152, 152, 152)

        return None

    def set_rect(self):
        """Render the text and center its rectangle on pos."""
        self.set_rend()
        rect = self.rend.get_rect()
        rect.center = self.pos
        self.rect = rect


class Snake:
    """The player: keeps the head, the body and the last action taken.

    The body is a list of positions with the head at index [0]. Every move
    inserts the new head position there and, unless food was eaten, drops the
    last position. A collision happens when any other element of the body
    shares the position of the head.

    Attributes
    ----------
    head: list of 2 * int, default = [board_size / 4, board_size / 4]
        The head of the snake, located according to the board size.
    body: list of lists of 2 * int
        Starts with 3 parts and grows when food is eaten.
    previous_action: int, default = 1
        Last action which the snake took.
    length: int, default = 3
        Variable length of the snake, can increase when food is eaten.
    """
    def __init__(self):
        """Create a snake of 3 parts (head included), pointing right."""
        start = int(VAR.board_size / 4)
        self.head = [start, start]
        self.body = [[start - offset, start] for offset in range(3)]
        self.previous_action = 1
        self.length = 3

    def is_movement_invalid(self, action):
        """Tell whether action, after the previous one, is in FORBIDDEN_MOVES."""
        return (action, self.previous_action) in FORBIDDEN_MOVES

    def move(self, action, food_pos):
        """Move the head 1 block and update the body.

        IDLE and forbidden actions are replaced by the previous action; any
        other action becomes the new previous action. The tail is dropped
        unless the head landed on the food.

        Return
        ----------
        ate_food: boolean
            Flag which represents whether the snake ate or not food.
        """
        keep_heading = (action == ABSOLUTE_ACTIONS['IDLE'] or
                        self.is_movement_invalid(action))

        if keep_heading:
            action = self.previous_action
        else:
            self.previous_action = action

        # (action name, head index to change, amount to add)
        steps = (('LEFT', 0, -1), ('RIGHT', 0, 1), ('UP', 1, -1), ('DOWN', 1, 1))

        for name, axis, delta in steps:
            if action == ABSOLUTE_ACTIONS[name]:
                self.head[axis] += delta
                break

        self.body.insert(0, list(self.head))

        if self.head == food_pos:
            LOGGER.info('EVENT: FOOD EATEN')
            self.length = len(self.body)
            return True

        self.body.pop()
        return False


class FoodGenerator:
    """Generate and keep track of food.

    Attributes
    ----------
    pos: list of 2 * int
        Current position of food.
    is_food_on_screen: boolean
        Flag for existence of food. It is reset from outside once the food is
        eaten, which makes the next generate_food call place a new piece.
    """
    def __init__(self, body):
        """Initialize a food piece and set existence flag."""
        self.is_food_on_screen = False
        self.pos = self.generate_food(body)

    def generate_food(self, body):
        """Place a new food piece if there is none and return its position.

        Positions are drawn uniformly over the whole board, from 0 up to and
        including board_size - 1 on both axes, and drawn again while they
        fall on the body. The previous implementation scaled random.random()
        by (board_size - 1) and truncated, so it could never reach the last
        row or column.

        The body must leave at least one free position, otherwise the search
        does not end.

        Parameters
        ----------
        body: list of lists of 2 * int
            Positions occupied by the snake.

        Return
        ----------
        pos: list of 2 * int
            Position of the food on the board. It can't be in the body.
        """
        if not self.is_food_on_screen:
            while True:
                food = [random.randrange(VAR.board_size),
                        random.randrange(VAR.board_size)]

                if food not in body:
                    self.pos = food
                    break

            LOGGER.info('EVENT: FOOD APPEARED')
            self.is_food_on_screen = True

        return self.pos


class Game:
    """Hold the game window and functions.

    Attributes
    ----------
    window: pygame display
        Pygame window to show the game.
    fps: pygame time clock
        Define Clock and ticks in which the game will be displayed.
    snake: object
        The actual snake who is going to be played.
    food_generator: object
        Generator of food which responds to the snake.
    food_pos: tuple of 2 * int
        Position of the food on the board.
    game_over: boolean
        Flag for game_over.
    player: string
        Define if human or robots are playing the game.
    board_size: int, optional, default = 30
        The size of the board.
    local_state: boolean, optional, default = False
        Whether to use or not game expertise (used mostly by robots players).
    relative_pos: boolean, optional, default = False
        Whether to use or not relative position of the snake head. Instead of
        actions, use relative_actions.
    screen_rect: tuple of 2 * int
        The screen rectangle, used to draw relatively positioned blocks.
    """
    def __init__(self, player, board_size = 30, local_state = False, relative_pos = False):
        """Store the settings and, for robot players, prepare the first game.

        The board size is written to the shared VAR settings. Only when the
        player is "ROBOT" the number of actions is set (3 with relative_pos,
        5 otherwise) and the game is reset.
        """
        VAR.board_size = board_size
        self.player = player
        self.local_state = local_state
        self.relative_pos = relative_pos

        if player != "ROBOT":
            return

        self.nb_actions = 3 if self.relative_pos else 5
        self.reset_game()

    def reset_game(self):
        """Start over: zero the counters and create a new snake and food."""
        self.step = 0
        self.scored = False
        self.game_over = False

        snake = Snake()
        generator = FoodGenerator(snake.body)
        self.snake = snake
        self.food_generator = generator
        self.food_pos = generator.pos

    def create_window(self):
        """Start pygame and open a square window of canvas_size pixels.

        Besides the window, the window rectangle (screen_rect) and the clock
        (fps) are stored on the game.
        """
        pygame.init()
        flags = pygame.DOUBLEBUF | pygame.HWSURFACE
        canvas = (VAR.canvas_size, VAR.canvas_size)
        self.window = pygame.display.set_mode(canvas, flags)
        self.window.set_alpha(None)

        self.fps = pygame.time.Clock()
        self.screen_rect = self.window.get_rect()

    def cycle_menu(self, menu_options, list_menu, dictionary, img = None,
                   img_rect = None):
        """Cycle through a given menu, waiting for an option to be clicked.

        Every frame all blocks are redrawn and their hovered flag is updated.
        A block is selected when the mouse button is released over it, but
        only if it has a matching entry in list_menu. Blocks beyond the end
        of list_menu (for instance a score text appended after the options)
        are drawn and never selected; before this guard, releasing the mouse
        over such a block raised IndexError.

        Parameters
        ----------
        menu_options: list of TextBlock or None
            Blocks to draw. None entries are skipped but keep their index.
        list_menu: list of string
            Keys of dictionary, aligned by index with menu_options.
        dictionary: dict
            Maps the keys in list_menu to the values to be returned.
        img: pygame surface, optional, default = None
            Image drawn on every frame.
        img_rect: pygame rectangle, optional, default = None
            Rectangle of img; the image is blitted at its bottomleft.

        Return
        ----------
        selected_option:
            Value of dictionary for the clicked option.
        """
        selected_option = None

        while selected_option is None:
            pygame.event.pump()
            events = pygame.event.get()
            clicked = any(event.type == pygame.MOUSEBUTTONUP
                          for event in events)

            self.window.fill(pygame.Color(225, 225, 225))

            for i, option in enumerate(menu_options):
                if option is None:
                    continue

                option.draw()
                option.hovered = bool(
                    option.rect.collidepoint(pygame.mouse.get_pos()))

                # Only blocks with an entry in list_menu can be selected.
                if option.hovered and clicked and i < len(list_menu):
                    selected_option = dictionary[list_menu[i]]

            if img is not None:
                self.window.blit(img, img_rect.bottomleft)

            pygame.display.update()

        return selected_option

    def cycle_matches(self, n_matches, mega_hardcore = False):
        """Play n_matches matches in a row and collect their scores.

        Each match resets the game, shows the 3 second countdown and runs the
        single player loop.

        Return
        ----------
        score: array of int
            The score of every match, in the order they were played.
        """
        def run_match():
            self.reset_game()
            self.start_match(wait = 3)
            return self.single_player(mega_hardcore)

        return array('i', (run_match() for _ in range(n_matches)))

    def menu(self):
        """Show the main menu, with the logo, and wait for a selection.

        Return
        ----------
        selected_option: int
            The selected option in the main loop.
        """
        pygame.display.set_caption("SNAKE GAME  | PLAY NOW!")

        logo_path = resource_path("resources/images/snake_logo.png")
        logo = pygame.image.load(logo_path).convert()
        logo = pygame.transform.scale(logo, (VAR.canvas_size,
                                             int(VAR.canvas_size / 3)))
        logo_rect = logo.get_rect()
        logo_rect.center = self.screen_rect.center

        list_menu = ['PLAY', 'BENCHMARK', 'LEADERBOARDS', 'QUIT']
        labels = (' PLAY GAME ', ' BENCHMARK ', ' LEADERBOARDS ', ' QUIT ')
        # Vertical positions, in tenths of the vertical center of the screen.
        heights = (4, 6, 8, 10)
        menu_options = [TextBlock(label,
                                  (self.screen_rect.centerx,
                                   height * self.screen_rect.centery / 10),
                                  self.window, (1 / 12), "menu")
                        for label, height in zip(labels, heights)]

        return self.cycle_menu(menu_options, list_menu, OPTIONS,
                               logo, logo_rect)

    def start_match(self, wait):
        """Show a countdown of wait seconds before the match is drawn."""
        for remaining in range(wait, 0, -1):
            countdown = str(remaining)
            self.window.fill(pygame.Color(225, 225, 225))

            # Game starts in 3, 2, 1
            title = TextBlock('Game starts in',
                              (self.screen_rect.centerx,
                               4 * self.screen_rect.centery / 10),
                              self.window, (1 / 10), "text")
            number = TextBlock(countdown,
                               (self.screen_rect.centerx,
                                12 * self.screen_rect.centery / 10),
                               self.window, (1 / 1.5), "text")
            title.draw()
            number.draw()

            pygame.display.update()
            pygame.display.set_caption("SNAKE GAME  |  Game starts in "
                                       + countdown + " second(s) ...")
            pygame.time.wait(1000)

        LOGGER.info('EVENT: GAME START')

    def start(self):
        """Run the main loop: show the menu and dispatch the selected option.

        PLAY runs one match and BENCHMARK runs VAR.benchmark matches; both
        end on the game over menu, whose selection is dispatched next. After
        the leaderboards are viewed (directly or right after adding a score)
        the loop goes back to the main menu. Previously the option was left
        unchanged, so the same branch ran again forever and, for
        ADD_TO_LEADERBOARDS, the score was added again on every pass.

        The loop only ends through QUIT, which closes pygame and exits.

        Raises
        ----------
        SystemExit
            When the QUIT option is selected.
        """
        opt = self.menu()

        while True:
            if opt == OPTIONS['QUIT']:
                pygame.quit()
                sys.exit()
            elif opt in (OPTIONS['PLAY'], OPTIONS['BENCHMARK']):
                VAR.game_speed, mega_hardcore = self.select_speed()
                n_matches = 1 if opt == OPTIONS['PLAY'] else VAR.benchmark
                score = self.cycle_matches(n_matches = n_matches,
                                           mega_hardcore = mega_hardcore)
                opt = self.over(score)
            elif opt == OPTIONS['LEADERBOARDS']:
                self.view_leaderboards()
                opt = OPTIONS['MENU']
            elif opt == OPTIONS['MENU']:
                opt = self.menu()

            # Checked on its own so a selection coming from the game over
            # menu is handled in the same pass. Only that menu offers this
            # option, so score is always set here.
            if opt == OPTIONS['ADD_TO_LEADERBOARDS']:
                self.add_to_leaderboards(score, None)
                self.view_leaderboards()
                opt = OPTIONS['MENU']

    def over(self, score):
        """Show the game over menu with the mean score and wait for a choice.

        The option to add the score to the leaderboards is only offered when
        the amount of scores equals VAR.benchmark.

        Return
        ----------
        selected_option: int
            The selected option in the main loop.
        """
        def block(label, height, scale, block_type = "menu"):
            # height is given in tenths of the vertical center of the screen.
            return TextBlock(label, (self.screen_rect.centerx,
                                     height * self.screen_rect.centery / 10),
                             self.window, scale, block_type)

        if len(score) == VAR.benchmark:
            score_option = block(' ADD TO LEADERBOARDS ', 8, (1 / 15))
        else:
            score_option = None

        text_score = 'SCORE: ' + str(int(np.mean(score)))
        list_menu = ['PLAY', 'MENU', 'ADD_TO_LEADERBOARDS', 'QUIT']
        menu_options = [block(' PLAY AGAIN ', 4, (1 / 15)),
                        block(' GO TO MENU ', 6, (1 / 15)),
                        score_option,
                        block(' QUIT ', 10, (1 / 15)),
                        block(text_score, 15, (1 / 10), "text")]

        caption = "SNAKE GAME  | " + text_score + "  |  GAME OVER..."
        pygame.display.set_caption(caption)
        LOGGER.info('EVENT: GAME OVER | FINAL %s', text_score)

        return self.cycle_menu(menu_options, list_menu, OPTIONS)

    def select_speed(self):
        """Show the speed menu, right before the matches start.

        Return
        ----------
        speed: int
            The selected speed in the main loop.
        mega_hardcore: boolean
            Whether the selected speed is the MEGA_HARDCORE one.
        """
        list_menu = ['EASY', 'MEDIUM', 'HARD', 'MEGA_HARDCORE']
        menu_options = []

        # One option per level; heights are in tenths of the vertical center.
        for level, height in enumerate((4, 8, 12, 16)):
            position = (self.screen_rect.centerx,
                        height * self.screen_rect.centery / 10)
            menu_options.append(TextBlock(LEVELS[level], position,
                                          self.window, (1 / 10), "menu"))

        speed = self.cycle_menu(menu_options, list_menu, SPEEDS)
        mega_hardcore = True if speed == SPEEDS['MEGA_HARDCORE'] else False

        return speed, mega_hardcore

    def single_player(self, mega_hardcore = False):
        """Game loop for single_player (HUMANS).

        Inputs are read on every frame, but the snake only moves (and the
        board is only redrawn) once the time accumulated from the clock
        reaches the wait between moves. The last valid key is remembered in
        between, which makes the game smoother for human players.

        Return
        ----------
        score: int
            The final score for the match (discounted of initial length).
        """
        drawn_size = self.snake.length # Size the color gradient was built for
        final_size = drawn_size
        color_list = self.gradient([(42, 42, 42), (152, 152, 152)],
                                   drawn_size)

        elapsed = 0
        pending_action = self.snake.previous_action
        move_wait = VAR.game_speed

        while not self.game_over:
            elapsed += self.fps.get_time()  # Time since the last tick.

            if mega_hardcore:  # The wait shrinks as the snake grows.
                move_wait = VAR.game_speed - (2 * (self.snake.length - 3))

            key_input = self.handle_input()
            invalid_key = self.snake.is_movement_invalid(key_input)

            if not (key_input is None or invalid_key):
                pending_action = key_input

            if elapsed >= move_wait:  # Move and redraw
                elapsed = 0
                self.play(pending_action)
                final_size = self.snake.length

                if final_size > drawn_size:  # New gradient for the new size
                    drawn_size = final_size
                    color_list = self.gradient([(42, 42, 42), (152, 152, 152)],
                                               final_size)

                self.draw(color_list)

            pygame.display.update()
            self.fps.tick(GAME_FPS)  # Limit FPS to 100

        return final_size - 3  # Score is the amount of parts gained

    def check_collision(self):
        """Check whether the head is outside the board or over the body.

        Return
        ----------
        collided: boolean
            Whether the snake collided or not.
        """
        head_x, head_y = self.snake.head[0], self.snake.head[1]
        last_index = VAR.board_size - 1

        inside_board = (0 <= head_x <= last_index and
                        0 <= head_y <= last_index)

        if not inside_board:
            LOGGER.info('EVENT: WALL COLLISION')
            return True

        if self.snake.head in self.snake.body[1:]:
            LOGGER.info('EVENT: BODY COLLISION')
            return True

        return False

    def is_won(self):
        """Tell whether the snake grew beyond its 3 initial parts.

        Return
        ----------
        won: boolean
            Whether the score is greater than 0.
        """
        parts_gained = self.snake.length - 3
        return parts_gained > 0

    def generate_food(self):
        """Ask the food generator for food, given the current snake body.

        Return
        ----------
        food_pos: list of 2 * int
            Current position of the food.
        """
        occupied = self.snake.body
        return self.food_generator.generate_food(occupied)

    def handle_input(self):
        """After getting current pressed keys, handle important cases.

        ESCAPE or Q end the current match by setting game_over, so the game
        loop stops and the game over menu is reached through the regular
        flow. Previously the game over menu was opened from here with an int
        score, which failed because that menu needs a sequence of scores, and
        the option selected there would have been discarded anyway.

        When several keys are pressed the first of ESCAPE/Q, LEFT, RIGHT, UP
        and DOWN wins.

        Return
        ----------
        action: int or None
            The absolute action for the pressed arrow key, or None when no
            arrow key is pressed or the match is being ended.
        """
        pygame.event.set_allowed([pygame.QUIT, pygame.KEYDOWN])
        keys = pygame.key.get_pressed()
        pygame.event.pump()

        if keys[pygame.K_ESCAPE] or keys[pygame.K_q]:
            LOGGER.info('ACTION: KEY PRESSED: ESCAPE or Q')
            self.game_over = True
            return None

        arrow_keys = ((pygame.K_LEFT, 'LEFT'), (pygame.K_RIGHT, 'RIGHT'),
                      (pygame.K_UP, 'UP'), (pygame.K_DOWN, 'DOWN'))

        for key, name in arrow_keys:
            if keys[key]:
                LOGGER.info('ACTION: KEY PRESSED: ' + name)
                return ABSOLUTE_ACTIONS[name]

        return None

    def state(self):
        """Create a matrix of the current state of the game.

        The matrix is indexed with the two coordinates of each position.
        Body parts, the head and the food are marked with their POINT_TYPE;
        with local_state the matrix also goes through eval_local_safety
        before the food is marked. When the game is over all points are
        empty.

        Return
        ----------
        canvas: np.array of size board_size**2
            Return the current state of the game in a matrix.
        """
        canvas = np.zeros((VAR.board_size, VAR.board_size))

        if self.game_over:
            return canvas

        body = self.snake.body
        head = body[0]

        for part in body[1:]:
            canvas[part[0], part[1]] = POINT_TYPE['BODY']

        canvas[head[0], head[1]] = POINT_TYPE['HEAD']

        if self.local_state:
            canvas = self.eval_local_safety(canvas, body)

        food_x, food_y = self.food_pos[0], self.food_pos[1]
        canvas[food_x, food_y] = POINT_TYPE['FOOD']

        return canvas

    def relative_to_absolute(self, action):
        """Translate relative actions to absolute.

        FORWARD keeps the previous action of the snake. LEFT and every other
        value (RIGHT) turn relative to the previous action; a previous action
        that is not LEFT, RIGHT or UP is treated as DOWN.

        Return
        ----------
        action: int
            Translated action from relative to absolute.
        """
        heading = self.snake.previous_action

        if action == RELATIVE_ACTIONS['FORWARD']:
            return heading

        # (heading, new heading) pairs followed by the result when heading
        # DOWN.
        if action == RELATIVE_ACTIONS['LEFT']:
            turns = (('LEFT', 'DOWN'), ('RIGHT', 'UP'), ('UP', 'LEFT'))
            facing_down = 'RIGHT'
        else:
            turns = (('LEFT', 'UP'), ('RIGHT', 'DOWN'), ('UP', 'RIGHT'))
            facing_down = 'LEFT'

        for facing, turned in turns:
            if heading == ABSOLUTE_ACTIONS[facing]:
                return ABSOLUTE_ACTIONS[turned]

        return ABSOLUTE_ACTIONS[facing_down]

    def play(self, action):
        """Advance the game one step: place food, move, eat and check the end.

        The game ends on a collision. For players other than "HUMAN" it also
        ends once the step count exceeds 50 times the snake length.
        """
        self.scored = False
        self.step += 1
        self.food_pos = self.generate_food()

        absolute_action = (self.relative_to_absolute(action)
                           if self.relative_pos else action)

        if self.snake.move(absolute_action, self.food_pos):
            self.scored = True
            self.food_generator.is_food_on_screen = False

        collided = self.check_collision()
        out_of_steps = (self.player != "HUMAN" and
                        self.step > 50 * self.snake.length)

        if collided or out_of_steps:
            self.game_over = True

    def get_reward(self):
        """Compute the reward for the step that was just played.

        Game over takes precedence over scoring. Scoring is rewarded with
        the current snake length, any other step with the MOVE reward.

        Return
        ----------
        reward: float
            Current reward of the game.
        """
        step_reward = REWARDS['MOVE']

        if self.game_over:
            return REWARDS['GAME_OVER']

        if self.scored:
            return self.snake.length

        return step_reward

    def draw(self, color_list):
        """Render one frame: background, snake body, food and window caption.

        Body parts are paired with the entries of color_list in order; the
        caption shows the snake length minus 3 as the score.
        """
        self.window.fill(pygame.Color(225, 225, 225))

        for segment, segment_color in zip(self.snake.body, color_list):
            segment_rect = pygame.Rect(segment[0] * VAR.block_size,
                                       segment[1] * VAR.block_size,
                                       VAR.block_size, VAR.block_size)
            pygame.draw.rect(self.window, segment_color, segment_rect)

        food_rect = pygame.Rect(self.food_pos[0] * VAR.block_size,
                                self.food_pos[1] * VAR.block_size,
                                VAR.block_size, VAR.block_size)
        pygame.draw.rect(self.window, VAR.food_color, food_rect)

        score_text = str(self.snake.length - 3)
        pygame.display.set_caption("SNAKE GAME  |  Score: " + score_text)

    def get_name(self):
        """Placeholder for asking the player's name; nothing is collected yet.

        The original note points at a pygame text box prototype that is
        not part of this file.

        Return
        ----------
        None
            Always, until the text input is implemented.
        """
        return None

    def add_to_leaderboards(self, score, step):
        """Append a result to the JSON leaderboards file.

        The file holds a JSON list of entries shaped like
        {'name': ..., 'ranking_data': {'score': ..., 'step': ...}}.
        Previously every call overwrote the file with a single dict, so
        earlier results were lost and the stored value was not a list.

        Parameters
        ----------
        score: int
            Score to store.
        step: int
            Amount of steps to store with the score.
        """
        file_path = resource_path("resources/scores.json")

        # get_name() is still a stub returning None, so keep the placeholder
        # name until a real input exists.
        name = self.get_name() or 'test'
        new_score = {'name': name,
                     'ranking_data': {'score': score,
                                      'step': step}}

        try:
            with open(file_path, 'r') as leaderboards_file:
                leaderboards = json.load(leaderboards_file)
        except (IOError, ValueError):
            # Missing, unreadable or corrupt file: start a new leaderboard.
            leaderboards = []

        if isinstance(leaderboards, dict):
            # File written by the old single-entry format.
            leaderboards = [leaderboards]
        elif not isinstance(leaderboards, list):
            leaderboards = []

        leaderboards.append(new_score)

        with open(file_path, 'w') as leaderboards_file:
            json.dump(leaderboards, leaderboards_file)

    def view_leaderboards(self):
        """Show the leaderboards screen and return the option picked on it.

        Loads the stored scores and sorts them by score (ascending). The
        scores themselves are not rendered yet; only the title and the
        MENU entry are shown.

        Return
        ----------
        selected_option
            Whatever cycle_menu returns for the chosen entry.
        """
        list_menu = ['MENU']
        menu_options = [TextBlock('LEADERBOARDS',
                                  (self.screen_rect.centerx,
                                   2 * self.screen_rect.centery / 10),
                                  self.window, (1 / 12), "text")]

        file_path = resource_path("resources/scores.json")

        try:
            with open(file_path, 'r') as leaderboards_file:
                scores_data = json.load(leaderboards_file)
        except (IOError, ValueError):
            # No leaderboard yet (or corrupt file): show an empty screen.
            scores_data = []

        if isinstance(scores_data, dict):
            # A single entry stored as a bare dict instead of a list.
            scores_data = [scores_data]
        elif not isinstance(scores_data, list):
            scores_data = []

        # Entries keep their score nested under 'ranking_data'; sorting on
        # a top-level 'score' key raised KeyError.
        scores_data.sort(key = lambda entry: entry['ranking_data']['score'])

        # TODO: render the best entries once format_scores produces the
        # rows to show, e.g.
        #        for score in formatted_scores:
        #            menu_options.append(TextBlock(person_ranked,
        #                                (self.screen_rect.centerx,
        #                                10 * self.screen_rect.centery / 10),
        #                                self.window, (1 / 12), "text"))

        menu_options.append(TextBlock('MENU',
                            (self.screen_rect.centerx,
                            10 * self.screen_rect.centery / 10),
                            self.window, (1 / 12), "menu"))
        selected_option = self.cycle_menu(menu_options, list_menu, OPTIONS)

        return selected_option

    @staticmethod
    def format_scores(scores, ammount):
        """Keep only the last `ammount` entries of `scores`.

        The slice used to be computed and discarded, so the method always
        returned None.

        Parameters
        ----------
        scores: list
            Score entries, expected to be already sorted by the caller.
        ammount: int
            How many entries to keep, counted from the end.

        Return
        ----------
        scores: list
            The last `ammount` entries (all of them if there are fewer),
            or an empty slice when `ammount` is zero or negative.
        """
        if ammount <= 0:
            # scores[-0:] would return every entry instead of none.
            return scores[:0]

        return scores[-ammount:]



    @staticmethod
    def eval_local_safety(canvas, body):
        """Flag which cells next to the head would end the game.

        A neighbour is unsafe when it lies outside the board or is taken
        by a body part other than the head. The flags are written to the
        first four columns of canvas row VAR.board_size - 1, in the order
        x + 1, x - 1, y - 1, y + 1.

        Return
        ----------
        canvas: np.array
            The same canvas, with unsafe directions set to DANGEROUS.
        """
        head_x, head_y = body[0][0], body[0][1]
        last_index = VAR.board_size - 1
        rest_of_body = body[1:]

        # (neighbour is off the board?, neighbour cell), ordered by flag column
        neighbours = (((head_x + 1) > last_index, [head_x + 1, head_y]),
                      ((head_x - 1) < 0, [head_x - 1, head_y]),
                      ((head_y - 1) < 0, [head_x, head_y - 1]),
                      ((head_y + 1) > last_index, [head_x, head_y + 1]))

        for column, (off_board, cell) in enumerate(neighbours):
            if off_board or cell in rest_of_body:
                canvas[last_index, column] = POINT_TYPE['DANGEROUS']

        return canvas

    @staticmethod
    def gradient(colors, steps, components = 3):
        """Build a colour gradient that passes through every entry of colors.

        Each consecutive pair of colours gets int(steps / (len(colors) - 1))
        entries: the first colour of the pair as given, followed by
        linearly interpolated tuples of floats that end on the second
        colour. With components = 4 the same is done for RGBA colours.

        Return
        ----------
        result: list of colours
            Entries per pair times the number of pairs, which can be fewer
            than steps when steps is not a multiple of the pair count.
        """
        per_pair = int(float(steps) / (len(colors) - 1))
        palette = []

        color_iter = iter(colors)
        origin = next(color_iter, None)

        for target in color_iter:
            palette.append(origin)  # Pair start is emitted untouched

            for position in range(1, per_pair):
                ratio = float(position) / (per_pair - 1)
                palette.append(tuple([(origin[channel]
                                       + ratio
                                       * (target[channel] - origin[channel]))
                                      for channel in range(components)]))

            origin = target

        return palette


def resource_path(relative_path):
    """Resolve relative_path against the directory holding the resources.

    When sys._MEIPASS exists (the original note ties this to building the
    .exe) it is used as the base directory; otherwise the directory of this
    file is used.
    """
    try:
        base_dir = sys._MEIPASS
    except AttributeError:
        base_dir = path.dirname(path.realpath(__file__))

    return path.join(base_dir, relative_path)

# Module-level objects and environment setup shared by the game code.
VAR = GlobalVariables() # Shared settings object; block size, board size and food color are read from it
LOGGER = logging.getLogger(__name__) # Module logger, named after this module
environ['SDL_VIDEO_CENTERED'] = '1' # Environment hint intended to center the game window


"""THIS"""

import numpy as np
from random import sample, uniform, random
from array import array  # Efficient numeric arrays


class ExperienceReplay:
    """Uniformly sampled replay memory of flattened SARS' experiences.

    Attributes
    ----------
    memory: list of experiences
        Each experience is a 1-D array laid out as
        [state, action, reward, next state, game_over].
    memory_size: int, optional, default = 150000
        The amount of experiences to be stored in the memory. Values of
        100 or less are replaced by 150000 when the memory is reset.
    input_shape: tuple of int
        Shape of one stored state (without the batch axis). Set by the
        first call to remember.
    """
    def __init__(self, memory_size = 150000):
        """Initialize parameters and the memory list."""
        self.memory_size = memory_size
        self.reset_memory() # Initiate the memory

    def exp_size(self):
        """Returns how many experiences are stored."""
        return len(self.memory)

    def remember(self, s, a, r, s_prime, game_over):
        """Remember SARS' experiences, with the game_over parameter (done).

        The oldest experience is dropped once memory_size is exceeded.
        """
        if not hasattr(self, 'input_shape'):
            self.input_shape = s.shape[1:] # set attribute only once

        experience = np.concatenate([s.flatten(),
                                     np.array(a).flatten(),
                                     np.array(r).flatten(),
                                     s_prime.flatten(),
                                     1 * np.array(game_over).flatten()])

        self.memory.append(experience)

        if self.memory_size > 0 and self.exp_size() > self.memory_size:
            self.memory.pop(0)

    def get_samples(self, batch_size):
        """Draw batch_size experiences uniformly, without replacement.

        Return
        ----------
        batch: np.array of batch_size experiences
            The batched experiences from memory.
        IS_weights: np.array of batch_size of the weights
            As it's used only in PER, is an array of ones in this case.
        Indexes: None
            As it's used only in PER, return None.
        """
        IS_weights = np.ones((batch_size, ))
        batch = np.array(sample(self.memory, batch_size))

        return batch, IS_weights, None

    def get_targets(self, target, model, batch_size, nb_actions, gamma = 0.9,
                    n_steps = 1):
        """Sample a batch and build the Q-learning targets for it.

        Parameters
        ----------
        target: model or None
            When given, Double DQN is used: `model` picks the next action
            and `target` evaluates it.
        model: model
            Online network; needs a predict method returning one row of
            Q-values per input state.
        batch_size: int
            Amount of experiences to sample.
        nb_actions: int
            Not used by this implementation.
        gamma: float, optional, default = 0.9
            Discount factor.
        n_steps: int, optional, default = 1
            Exponent applied to gamma (multi-step returns).

        Return
        ----------
        None when fewer than batch_size experiences are stored, otherwise
        the tuple (S, targets, IS_weights).
        """
        if self.exp_size() < batch_size:
            return None

        samples, IS_weights, _ = self.get_samples(batch_size)
        input_dim = int(np.prod(self.input_shape)) # Flattened state length

        S = samples[:, 0 : input_dim] # Separate the states
        a = samples[:, input_dim] # Separate the actions
        r = samples[:, input_dim + 1] # Separate the rewards
        S_prime = samples[:, input_dim + 2 : 2 * input_dim + 2] # Next states
        game_over = samples[:, 2 * input_dim + 2] # Separate terminal flags

        # Reshape the arrays to make them usable by the model.
        S = S.reshape((batch_size, ) + self.input_shape)
        S_prime = S_prime.reshape((batch_size, ) + self.input_shape)

        # One forward pass for both states and next states.
        X = np.concatenate([S, S_prime], axis = 0)
        Y = model.predict(X)
        rows = np.arange(batch_size)

        if target is not None: # Use Double DQN logic:
            # The online network chooses the action, the target network
            # scores it. Indexing by batch row works for any batch_size
            # (a fixed 64-slot buffer was used before).
            actions = np.argmax(Y[batch_size:], axis = 1)
            Y_target = np.asarray(target.predict(X[batch_size:]))
            Qsa = Y_target[rows, actions]
        else:
            Qsa = np.max(Y[batch_size:], axis = 1)

        # Where the action happened, replace with the Q values of S_prime
        targets = np.array(Y[:batch_size])
        value = r + (gamma ** n_steps) * (1 - game_over) * Qsa
        targets[rows, a.astype(int)] = value

        return S, targets, IS_weights

    def reset_memory(self):
        """Set the memory as a blank list."""
        if self.memory_size <= 100:
            # Stored on the instance; it was assigned to a local before and
            # therefore had no effect.
            self.memory_size = 150000

        self.memory = []


class PrioritizedExperienceReplayNaive:
    """Prioritized replay memory backed by a SumTree.

    Attributes
    ----------
    memory: SumTree
        Tree holding the experiences and their priorities. The methods
        used here are insert, update, retrieve, sum, min_leaf and max_leaf.
    memory_size: int, optional, default = 150000
        Capacity handed to the SumTree. Values of 100 or less are replaced
        by 150000 when the memory is reset.
    input_shape: tuple of int
        Shape of one stored state (without the batch axis). Set by the
        first call to remember.
    epsilon: float
        Added to the errors so that no priority is zero.
    alpha: float
        How much prioritization to use.
    beta: float
        Exponent of the importance sampling weights (IS_weights).
    schedule: LinearSchedule
        Built from nb_epoch * decay, 1.0 and beta; not read by this class.
    exp: int
        Number of experiences inserted since the last reset.
    """
    def __init__(self, memory_size = 150000, alpha = 0.6, epsilon = 0.001,
                 beta = 0.4, nb_epoch = 10000, decay = 0.5):
        """Initialize parameters and the memory tree."""
        self.memory_size = memory_size
        self.epsilon = epsilon
        self.alpha = alpha
        self.beta = beta
        self.schedule = LinearSchedule(nb_epoch * decay, 1.0, beta)
        self.reset_memory() # Initiate the memory

    def exp_size(self):
        """Returns how many experiences were inserted since the last reset."""
        return self.exp

    def get_priority(self, errors):
        """Returns priority based on how much prioritization to use."""
        return (errors + self.epsilon) ** self.alpha

    def update(self, tree_indices, errors):
        """Update a list of nodes, based on their errors."""
        priorities = self.get_priority(errors)

        for index, priority in zip(tree_indices, priorities):
            self.memory.update(index, priority)

    def remember(self, s, a, r, s_prime, game_over):
        """Remember SARS' experiences, with the game_over parameter (done).

        New experiences get the largest priority currently in the tree, or
        the priority of a zero error when the tree reports 0.
        """
        if not hasattr(self, 'input_shape'):
            self.input_shape = s.shape[1:] # set attribute only once

        experience = np.concatenate([s.flatten(),
                                     np.array(a).flatten(),
                                     np.array(r).flatten(),
                                     s_prime.flatten(),
                                     1 * np.array(game_over).flatten()])

        max_priority = self.memory.max_leaf()

        if max_priority == 0:
            max_priority = self.get_priority(0)

        self.memory.insert(experience, max_priority)
        self.exp += 1

    def get_samples(self, batch_size):
        """Sample one experience from each of batch_size equal priority segments.

        Return
        ----------
        batch: np.array of batch_size experiences
            The batched experiences from memory.
        IS_weights: np.array of batch_size of the weights
            Importance sampling weights, (prob / min_prob) ** -beta.
        tree_indices: list of batch_size indices
            Tree positions of the sampled experiences, used to update
            their priorities.
        """
        batch = [None] * batch_size
        IS_weights = np.zeros((batch_size, ))
        tree_indices = [0] * batch_size

        # The tree is not modified while sampling, so its total is read once.
        memory_sum = self.memory.sum()
        len_seg = memory_sum / batch_size
        min_prob = self.memory.min_leaf() / memory_sum

        for i in range(batch_size):
            val = uniform(len_seg * i, len_seg * (i + 1))
            tree_indices[i], priority, batch[i] = self.memory.retrieve(val)
            prob = priority / memory_sum
            IS_weights[i] = np.power(prob / min_prob, -self.beta)

        return np.array(batch), IS_weights, tree_indices

    def get_targets(self, target, model, batch_size, nb_actions, gamma = 0.9,
                    n_steps = 1):
        """Sample a batch, build its Q-learning targets and refresh priorities.

        Parameters
        ----------
        target: model or None
            When given, Double DQN is used: `model` picks the next action
            and `target` evaluates it.
        model: model
            Online network; needs a predict method returning one row of
            Q-values per input state.
        batch_size: int
            Amount of experiences to sample.
        nb_actions: int
            Not used by this implementation.
        gamma: float, optional, default = 0.9
            Discount factor.
        n_steps: int, optional, default = 1
            Exponent applied to gamma (multi-step returns).

        Return
        ----------
        None when fewer than batch_size experiences are stored, otherwise
        the tuple (S, targets, IS_weights).
        """
        if self.exp_size() < batch_size:
            return None

        samples, IS_weights, tree_indices = self.get_samples(batch_size)
        input_dim = int(np.prod(self.input_shape)) # Flattened state length

        S = samples[:, 0 : input_dim] # Separate the states
        a = samples[:, input_dim] # Separate the actions
        r = samples[:, input_dim + 1] # Separate the rewards
        S_prime = samples[:, input_dim + 2 : 2 * input_dim + 2] # Next states
        game_over = samples[:, 2 * input_dim + 2] # Separate terminal flags

        # Reshape the arrays to make them usable by the model.
        S = S.reshape((batch_size, ) + self.input_shape)
        S_prime = S_prime.reshape((batch_size, ) + self.input_shape)

        # One forward pass for both states and next states.
        X = np.concatenate([S, S_prime], axis = 0)
        Y = model.predict(X)
        rows = np.arange(batch_size)

        if target is not None: # Use Double DQN logic:
            # The online network chooses the action, the target network
            # scores it. Indexing by batch row works for any batch_size
            # (a fixed 64-slot buffer was used before).
            actions = np.argmax(Y[batch_size:], axis = 1)
            Y_target = np.asarray(target.predict(X[batch_size:]))
            Qsa = Y_target[rows, actions]
        else:
            Qsa = np.max(Y[batch_size:], axis = 1)

        # Where the action happened, replace with the Q values of S_prime
        targets = np.array(Y[:batch_size])
        value = r + (gamma ** n_steps) * (1 - game_over) * Qsa
        targets[rows, a.astype(int)] = value

        # Priorities follow the (clipped) distance between the new value
        # and the current best Q-value of each sampled state.
        errors = np.abs(value - Y[:batch_size].max(axis = 1)).clip(max = 1.)
        self.update(tree_indices, errors)

        return S, targets, IS_weights

    def reset_memory(self):
        """Replace the memory with an empty tree and reset the counter."""
        if self.memory_size <= 100:
            self.memory_size = 150000

        self.memory = SumTree(self.memory_size)
        self.exp = 0


class PrioritizedExperienceReplay:
    """Prioritized replay memory backed by sum and min segment trees.

    Attributes
    ----------
    memory: list of experiences
        Ring buffer; each experience is a 1-D array laid out as
        [state, action, reward, next state, game_over].
    memory_size: int
        Capacity of the ring buffer. Values of 100 or less are replaced by
        150000 when the memory is reset.
    pos: int
        Index of the slot the next experience is written to.
    input_shape: tuple of int
        Shape of one stored state (without the batch axis). Set by the
        first call to remember.
    epsilon: float
        Added to the errors so that no priority is zero.
    alpha: float
        How much prioritization to use.
    beta: float
        Exponent of the importance sampling weights.
    max_priority: float
        Largest raw (not yet alpha-weighted) priority seen so far; given
        to new experiences.
    schedule: LinearSchedule
        Built from nb_epoch * decay, 1.0 and beta; not read by this class.
    """
    def __init__(self, memory_size, nb_epoch = 10000, epsilon = 0.001,
                 alpha = 0.6, beta = 0.4, decay = 0.5):
        """Initialize parameters, the memory list and the segment trees."""
        self.memory_size = memory_size
        self.alpha = alpha
        self.epsilon = epsilon
        self.beta = beta
        self.schedule = LinearSchedule(nb_epoch * decay, 1.0, beta)
        self.max_priority = 1.0
        self.reset_memory()

    def exp_size(self):
        """Returns how many experiences are stored."""
        return len(self.memory)

    def remember(self, s, a, r, s_prime, game_over):
        """Remember SARS' experiences, with the game_over parameter (done).

        The experience is written to slot `pos` (overwriting the oldest
        entry once the buffer is full) and that same slot receives the
        maximum priority in both trees.
        """
        if not hasattr(self, 'input_shape'):
            self.input_shape = s.shape[1:] # set attribute only once

        experience = np.concatenate([s.flatten(),
                                     np.array(a).flatten(),
                                     np.array(r).flatten(),
                                     s_prime.flatten(),
                                     1 * np.array(game_over).flatten()])
        slot = self.pos

        if self.exp_size() < self.memory_size:
            self.memory.append(experience)
        else:
            self.memory[slot] = experience

        # The priority belongs to the slot just written. It used to be
        # stored one slot ahead, and `pos` was not wrapped while filling,
        # so the first overwrite indexed past the end of the list.
        self._it_sum[slot] = self.max_priority ** self.alpha
        self._it_min[slot] = self.max_priority ** self.alpha
        self.pos = (slot + 1) % self.memory_size

    def _sample_proportional(self, batch_size):
        """Draw batch_size indices with probability proportional to priority."""
        res = array('i')
        # Sampling does not modify the tree, so the total is read once.
        total = self._it_sum.sum(0, self.exp_size() - 1)

        for _ in range(batch_size):
            # uniform(0, 1) is used because the bare name `random` is
            # rebound to the module by a later `import random` in this file.
            mass = uniform(0.0, 1.0) * total
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)

        return res

    def get_priority(self, errors):
        """Returns priority based on how much prioritization to use."""
        return (errors + self.epsilon) ** self.alpha

    def get_samples(self, batch_size):
        """Sample experiences proportionally to their priority.

        Return
        ----------
        batch: np.array of batch_size experiences
            The batched experiences from memory.
        weights: np.array of batch_size float32
            Importance sampling weights, normalized by the largest
            possible weight.
        idxes: array of batch_size int
            Memory positions of the sampled experiences.
        """
        idxes = self._sample_proportional(batch_size)

        total = self._it_sum.sum()
        stored = self.exp_size()
        p_min = self._it_min.min() / total
        max_weight = (p_min * stored) ** (-self.beta)

        weights = array('f')

        for idx in idxes:
            p_sample = self._it_sum[idx] / total
            weight = (p_sample * stored) ** (-self.beta)
            weights.append(weight / max_weight)

        weights = np.array(weights, dtype=np.float32)
        samples = [self.memory[idx] for idx in idxes]

        return np.array(samples), weights, idxes

    def get_targets(self, target, model, batch_size, nb_actions, gamma = 0.9,
                    n_steps = 1):
        """Sample a batch, build its Q-learning targets and refresh priorities.

        Parameters
        ----------
        target: model or None
            When given, Double DQN is used: `model` picks the next action
            and `target` evaluates it.
        model: model
            Online network; needs a predict method returning one row of
            Q-values per input state.
        batch_size: int
            Amount of experiences to sample.
        nb_actions: int
            Not used by this implementation.
        gamma: float, optional, default = 0.9
            Discount factor.
        n_steps: int, optional, default = 1
            Exponent applied to gamma (multi-step returns).

        Return
        ----------
        None when fewer than batch_size experiences are stored, otherwise
        the tuple (S, targets, IS_weights).
        """
        if self.exp_size() < batch_size:
            return None

        samples, IS_weights, tree_indices = self.get_samples(batch_size)
        input_dim = int(np.prod(self.input_shape)) # Flattened state length

        S = samples[:, 0 : input_dim] # Separate the states
        a = samples[:, input_dim] # Separate the actions
        r = samples[:, input_dim + 1] # Separate the rewards
        S_prime = samples[:, input_dim + 2 : 2 * input_dim + 2] # Next states
        game_over = samples[:, 2 * input_dim + 2] # Separate terminal flags

        # Reshape the arrays to make them usable by the model.
        S = S.reshape((batch_size, ) + self.input_shape)
        S_prime = S_prime.reshape((batch_size, ) + self.input_shape)

        # One forward pass for both states and next states.
        X = np.concatenate([S, S_prime], axis = 0)
        Y = model.predict(X)
        rows = np.arange(batch_size)

        if target is not None: # Use Double DQN logic:
            # The online network chooses the action, the target network
            # scores it. Indexing by batch row works for any batch_size
            # (a fixed 64-slot buffer was used before).
            actions = np.argmax(Y[batch_size:], axis = 1)
            Y_target = np.asarray(target.predict(X[batch_size:]))
            Qsa = Y_target[rows, actions]
        else:
            Qsa = np.max(Y[batch_size:], axis = 1)

        # Where the action happened, replace with the Q values of S_prime
        targets = np.array(Y[:batch_size])
        value = r + (gamma ** n_steps) * (1 - game_over) * Qsa
        targets[rows, a.astype(int)] = value

        # Priorities follow the (clipped) distance between the new value
        # and the current best Q-value of each sampled state.
        errors = np.abs(value - Y[:batch_size].max(axis = 1)).clip(max = 1.)
        self.update_priorities(tree_indices, errors)

        return S, targets, IS_weights

    def update_priorities(self, idxes, errors):
        """Store new priorities for the given memory positions.

        The trees receive (error + epsilon) ** alpha, while max_priority
        tracks the raw error + epsilon because remember applies alpha when
        it inserts. Alpha used to be applied twice here.
        """
        raw_priorities = np.atleast_1d(errors) + self.epsilon

        for idx, raw_priority in zip(idxes, raw_priorities):
            weighted = raw_priority ** self.alpha
            self._it_sum[idx] = weighted
            self._it_min[idx] = weighted

            self.max_priority = max(self.max_priority, float(raw_priority))

    def reset_memory(self):
        """Empty the memory and rebuild both segment trees."""
        if self.memory_size <= 100:
            self.memory_size = 150000

        self.memory = []
        self.pos = 0

        # Capacity is the smallest power of two that fits memory_size.
        it_capacity = 1

        while it_capacity < self.memory_size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)


#!/usr/bin/env python

"""dqn: First try to create an AI for SnakeGame. Is it good enough?

This algorithm is an implementation of DQN, Double DQN logic (using a target
network to have fixed Q-targets), Dueling DQN logic (Q(s,a) = Advantage + Value),
PER (Prioritized Experience Replay, using Sum Trees) and Multi-step returns. You
can read more about these at https://goo.gl/MctLzp

Implemented algorithms
----------
    * Simple Deep Q-network (DQN with ExperienceReplay);
        Paper: https://arxiv.org/abs/1312.5602
    * Double Deep Q-network (Double DQN);
        Paper: https://arxiv.org/abs/1509.06461
    * Dueling Deep Q-network (Dueling DQN);
        Paper: https://arxiv.org/abs/1511.06581
    * Prioritized Experience Replay (PER);
        Paper: https://arxiv.org/abs/1511.05952
    * Multi-step returns (n-steps);
        Paper: https://arxiv.org/pdf/1703.01327
    * Noisy nets.
        Paper: https://arxiv.org/abs/1706.10295

Arguments
----------
--load: 'file.h5'
    Load a previously trained model in '.h5' format.
--board_size: int, optional, default = 10
    Assign the size of the board.
--nb_frames: int, optional, default = 4
    Assign the number of frames per stack, default = 4.
--nb_actions: int, optional, default = 5
    Assign the number of actions possible.
--update_freq: int or float, optional, default = 0.001
    Whether to soft or hard update the target: the amount of each soft update,
    or the number of epochs between hard updates.
--visual: boolean, optional, default = False
    Select whether or not to draw the game in pygame.
--double: boolean, optional, default = False
    Use a target network with double DQN logic.
--dueling: boolean, optional, default = False
    Whether to use dueling network logic, Q(s,a) = A + V.
--per: boolean, optional, default = False
    Use Prioritized Experience Replay (based on Sum Trees).
--local_state: boolean, optional, default = True
    Verify if the possible next moves are dangerous (field expertise).
    THIS
"""

import numpy as np
from array import array
import random

__author__ = "Victor Neves"
__license__ = "MIT"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"
__status__ = "Production"


class Agent:
    """Agent based in a simple DQN that can read states, remember and play.

    Attributes
    ----------
    memory: object
        Memory used in training. ExperienceReplay or PrioritizedExperienceReplay
    memory_size: int, optional, default = -1
        Capacity of the memory used.
    model: keras model
        The input model in Keras.
    target: keras model, optional, default = None
        The target model, used to calculade the fixed Q-targets.
    nb_frames: int, optional, default = 4
        Ammount of frames for each experience (sars).
    board_size: int, optional, default = 10
        Size of the board used.
    frames: list of experiences
        The buffer of frames, store sars experiences.
    per: boolean, optional, default = False
        Flag for PER usage.
    update_target_freq: int or float, default = 0.001
        Whether soft or hard updates occur. If < 1, soft updated target model.
    n_steps: int, optional, default = 1
        Size of the rewards buffer, to use Multi-step returns.
    """
    def __init__(self, model, sess, target = None, memory_size = -1, nb_frames = 4,
                 board_size = 10, per = False, update_target_freq = 0.001):
        """Build the replay memory and store the networks and settings.

        The memory is a PrioritizedExperienceReplay when per is true and an
        ExperienceReplay otherwise. The noise list is collected from the
        model and the frame buffer starts empty.
        """
        memory_class = PrioritizedExperienceReplay if per else ExperienceReplay
        self.memory = memory_class(memory_size = memory_size)

        # Networks and the session used to resample noise
        self.model = model
        self.target = target
        self.sess = sess

        # Settings
        self.per = per
        self.nb_frames = nb_frames
        self.board_size = board_size
        self.update_target_freq = update_target_freq

        self.set_noise_list()
        self.clear_frames()

    def reset_memory(self):
        """Reset memory if necessary."""
        self.memory.reset_memory()

    def set_noise_list(self):
        """Set a list of noise variables if NoisyNet is involved."""
        self.noise_list = []
        for layer in self.model.layers:
            if type(layer) in {NoisyDenseFG}:
                self.noise_list.extend(layer.noise_list)

    def sample_noise(self):
        """Resample noise variables in NoisyNet."""
        for noise in self.noise_list:
            self.sess.run(noise.initializer)

    def get_game_data(self, game):
        """Create a list with 4 frames and append/pop them each frame.

        Return
        ----------
        expanded_frames: list of experiences
            The buffer of frames, shape = (nb_frames, board_size, board_size)
        """
        frame = game.state()

        if self.frames is None:
            self.frames = [frame] * self.nb_frames
        else:
            self.frames.append(frame)
            self.frames.pop(0)

        expanded_frames = np.expand_dims(self.frames, 0)
        # expanded_frames = np.transpose(expanded_frames, [0, 3, 2, 1])

        return expanded_frames

    def clear_frames(self):
        """Reset frames to restart appending."""
        self.frames = None

    def update_target_model_hard(self):
        """Update the target model with the main model's weights."""
        self.target.set_weights(self.model.get_weights())

    def transfer_weights(self):
        """Transfer Weights from Model to Target at rate update_target_freq."""
        model_weights = self.model.get_weights()
        target_weights = self.target.get_weights()

        for i in range(len(W)):
            target_weights[i] = (self.update_target_freq * model_weights[i]
                                 + ((1 - self.update_target_frequency)
                                    * target_weights[i]))

        self.target.set_weights(target_weights)

    def print_metrics(self, epoch, nb_epoch, history_size, policy, value,
                      win_count, history_step, history_reward,
                      history_loss = None, verbose = 1):
        """Function to print metrics of training steps."""
        if verbose == 0:
            pass
        elif verbose == 1:
            text_epoch = ('Epoch: {:03d}/{:03d} | Mean size 10: {:.1f} | '
                           + 'Longest 10: {:03d} | Mean steps 10: {:.1f} | '
                           + 'Wins: {:d} | Win percentage: {:.1f}%')
            print(text_epoch.format(epoch + 1, nb_epoch,
                                    np.mean(history_size[-10:]),
                                    max(history_size[-10:]),
                                    np.mean(history_step[-10:]),
                                    win_count, 100 * win_count/(epoch + 1)))
        else:
            text_epoch = 'Epoch: {:03d}/{:03d}'  # Print epoch info
            print(text_epoch.format(epoch + 1, nb_epoch))

            if loss is not None:  # Print training performance
                text_train = ('\t\x1b[0;30;47m' + ' Training metrics ' + '\x1b[0m'
                              + '\tTotal loss: {:.4f} | Loss per step: {:.4f} | '
                              + 'Mean loss - 100 episodes: {:.4f}')
                print(text_perf.format(history_loss[-1],
                                       history_loss[-1] / history_step[-1],
                                       np.mean(history_loss[-100:])))

            text_game = ('\t\x1b[0;30;47m' + ' Game metrics ' + '\x1b[0m'
                         + '\t\tSize: {:d} | Ammount of steps: {:d} | '
                         + 'Steps per food eaten: {:.1f} | '
                         + 'Mean size - 100 episodes: {:.1f}')
            print(text_game.format(history_size[-1], history_step[-1],
                                   history_size[-1] / history_step[-1],
                                   np.mean(history_step[-100:])))

            # Print policy metrics
            if policy == "BoltzmannQPolicy":
                text_policy = ('\t\x1b[0;30;47m' + ' Policy metrics ' + '\x1b[0m'
                               + '\tBoltzmann Temperature: {:.2f} | '
                               + 'Episode reward: {:.1f} | Wins: {:d} | '
                               + 'Win percentage: {:.1f}%')
                print(text_policy.format(value, history_reward[-1], win_count,
                                         100 * win_count/(epoch + 1)))
            elif policy == "BoltzmannGumbelQPolicy":
                text_policy = ('\t\x1b[0;30;47m' + ' Policy metrics ' + '\x1b[0m'
                               + '\tNumber of actions: {:.0f} | '
                               + 'Episode reward: {:.1f} | Wins: {:d} | '
                               + 'Win percentage: {:.1f}%')
                print(text_policy.format(value, history_reward[-1], win_count,
                                         100 * win_count/(epoch + 1)))
            else:
                text_policy = ('\t\x1b[0;30;47m' + ' Policy metrics ' + '\x1b[0m'
                               + '\tEpsilon: {:.2f} | Episode reward: {:.1f} | '
                               + 'Wins: {:d} | Win percentage: {:.1f}%')
                print(text_policy.format(value, history_reward[-1], win_count,
                                         100 * win_count/(epoch + 1)))

    def train_model(self, model, target, batch_size, gamma, nb_actions, epoch = 0):
        """Function to train the model on a batch of the data. The optimization
        flag is used when we are not playing, just batching and optimizing.

        Return
        ----------
        loss: float
            Training loss of given batch.
        """
        loss = 0.
        batch = self.memory.get_targets(model = self.model,
                                        target = self.target,
                                        batch_size = batch_size,
                                        gamma = gamma,
                                        nb_actions = nb_actions,
                                        n_steps = self.n_steps)

        if batch:
            inputs, targets, IS_weights = batch

            if inputs is not None and targets is not None:
                loss = float(self.model.train_on_batch(inputs,
                                                       targets,
                                                       IS_weights))

        return loss

    def train(self, game, nb_epoch = 10000, batch_size = 64, gamma = 0.95,
              eps = [1., .01], temp = [1., 0.01], learning_rate = 0.5,
              observe = 0, optim_rounds = 1, policy = "EpsGreedyQPolicy",
              verbose = 1, n_steps = 1):
        """Play ``nb_epoch`` games, remember every transition and fit the model.

        Round 0 explores: each game is played to completion with the selected
        exploration policy, every transition is stored in ``self.memory`` and
        ``train_model`` is called after each move. Any further round
        (``optim_rounds > 1``) does not play; it only calls ``train_model``
        ``nb_epoch`` more times.

        Parameters
        ----------
        game : object
            Game to play. Must expose ``nb_actions``, ``game_over``, ``step``,
            ``snake.length``, ``food_pos`` and the methods ``reset_game``,
            ``generate_food``, ``play``, ``get_reward`` and ``is_won``.
        nb_epoch : int
            Number of games in the exploration round, and number of
            ``train_model`` calls in each additional round.
        batch_size : int
            Forwarded to ``train_model``.
        gamma : float
            Discount factor; forwarded to ``train_model`` and used to discount
            the multi-step return.
        eps : sequence of two floats
            ``(start, end)`` epsilon given to ``EpsGreedyQPolicy``. Only read,
            never mutated.
        temp : sequence of two floats
            ``(start, end)`` temperature given to ``BoltzmannQPolicy``. Only
            read, never mutated.
        learning_rate : float
            Fraction of ``nb_epoch`` over which the policy schedule anneals
            (the schedule length is ``nb_epoch * learning_rate``). It is not
            an optimizer learning rate.
        observe : int
            Number of initial epochs in which transitions are stored but
            ``train_model`` is not called.
        optim_rounds : int
            Total number of rounds (exploration round included).
        policy : str
            ``"BoltzmannQPolicy"``, ``"BoltzmannGumbelQPolicy"`` or anything
            else for ``EpsGreedyQPolicy``.
        verbose : int
            Forwarded to ``print_metrics``.
        n_steps : int
            Window of the multi-step return; values above 1 enable it.
        """
        if not hasattr(self, 'n_steps'):
            self.n_steps = n_steps  # Set attribute only once

        history_size = array('i')  # Holds all the sizes
        history_step = array('f')  # Holds all the steps
        history_loss = array('f')  # Holds all the losses
        history_reward = array('f')  # Holds all the rewards

        # Select exploration policy. EpsGreedyQPolicy runs faster, but takes
        # longer to converge. BoltzmannGumbelQPolicy is the slowest, but
        # converge really fast (0.1 * nb_epoch used in EpsGreedyQPolicy).
        # BoltzmannQPolicy is in the middle.
        if policy == "BoltzmannQPolicy":
            q_policy = BoltzmannQPolicy(temp[0], temp[1], nb_epoch * learning_rate)
        elif policy == "BoltzmannGumbelQPolicy":
            q_policy = BoltzmannGumbelQPolicy()
        else:
            q_policy = EpsGreedyQPolicy(eps[0], eps[1], nb_epoch * learning_rate)

        nb_actions = game.nb_actions
        win_count = 0

        # If optim_rounds is bigger than one, the model will keep optimizing
        # after the exploration, in turns of nb_epoch size.
        for turn in range(optim_rounds):
            if turn > 0:
                for epoch in range(nb_epoch):
                    loss = self.train_model(model = self.model,
                                            epoch = epoch,
                                            target = self.target,
                                            batch_size = batch_size,
                                            gamma = gamma,
                                            nb_actions = nb_actions)
                    text_optim = ('Optimizer turn: {:2d} | Epoch: {:03d}/{:03d}'
                                  + '| Loss: {:.4f}')
                    print(text_optim.format(turn, epoch + 1, nb_epoch, loss))
            else:  # Exploration and training
                for epoch in range(nb_epoch):
                    loss = 0.
                    total_reward = 0.
                    game.reset_game()
                    self.clear_frames()
                    S = self.get_game_data(game)

                    if n_steps > 1:  # Create multi-step returns buffer.
                        n_step_buffer = array('f')

                    while not game.game_over:  # Main loop, until game_over
                        game.food_pos = game.generate_food()
                        self.sample_noise()
                        action, value = q_policy.select_action(self.model,
                                                               S, epoch,
                                                               nb_actions)
                        game.play(action)
                        r = game.get_reward()
                        total_reward += r

                        if n_steps > 1:
                            n_step_buffer.append(r)

                            if len(n_step_buffer) < n_steps:
                                # Window not full yet: fall back to the
                                # single-step reward.
                                R = r
                            else:
                                # Discounted sum over the window, then slide
                                # the window by dropping the oldest reward.
                                R = sum([n_step_buffer[i] * (gamma ** i)\
                                        for i in range(n_steps)])

                                n_step_buffer.pop(0)
                        else:
                            R = r

                        S_prime = self.get_game_data(game)
                        experience = [S, action, R, S_prime, game.game_over]
                        self.memory.remember(*experience)  # Add to the memory
                        S = S_prime  # Advance to the next state (stack of S)

                        if epoch >= observe:  # Get the batchs and train
                            loss += self.train_model(model = self.model,
                                                     target = self.target,
                                                     batch_size = batch_size,
                                                     gamma = gamma,
                                                     nb_actions = nb_actions)

                    if game.is_won():
                        win_count += 1  # Counter of wins for metrics

                    if self.per:  # Advance beta, used in PER
                        self.memory.beta = self.memory.schedule.value(epoch)

                    if self.target is not None:  # Update the target model
                        # The frequency lives on the agent; the bare name
                        # `update_target_freq` is not defined in this scope
                        # and raised NameError whenever a target was set.
                        if self.update_target_freq >= 1: # Hard updates
                            if epoch % self.update_target_freq == 0:
                                self.update_target_model_hard()
                        elif self.update_target_freq < 1.:  # Soft updates
                            self.transfer_weights()

                    history_size.append(game.snake.length)
                    history_step.append(game.step)
                    history_loss.append(loss)
                    history_reward.append(total_reward)

                    if (epoch + 1) % 10 == 0:
                        self.print_metrics(epoch = epoch, nb_epoch = nb_epoch,
                                           history_size = history_size,
                                           history_loss = history_loss,
                                           history_step = history_step,
                                           history_reward = history_reward,
                                           policy = policy, value = value,
                                           win_count = win_count,
                                           verbose = verbose)

    def test(self, game, nb_epoch = 1000, eps = 0.01, temp = 0.01,
             visual = False, policy = "GreedyQPolicy"):
        """Let the trained agent play ``nb_epoch`` games and print statistics.

        Parameters
        ----------
        game : object
            Game to play; same interface as the one used for training.
        nb_epoch : int
            Number of games to play.
        eps : float
            Constant epsilon used when ``policy`` is ``"EpsGreedyQPolicy"``.
        temp : float
            Constant temperature used when ``policy`` is
            ``"BoltzmannQPolicy"``.
        visual : bool
            When true, a window is created for each game and the board is
            drawn through pygame while the agent plays.
        policy : str
            ``"BoltzmannQPolicy"``, ``"EpsGreedyQPolicy"`` or anything else
            for ``GreedyQPolicy``.
        """
        # Start and end of each schedule are equal, so the value is constant.
        if policy == "BoltzmannQPolicy":
            q_policy = BoltzmannQPolicy(temp, temp, nb_epoch)
        elif policy == "EpsGreedyQPolicy":
            q_policy = EpsGreedyQPolicy(eps, eps, nb_epoch)
        else:
            q_policy = GreedyQPolicy()

        sizes = array('i')  # Snake length at the end of every game
        steps = array('f')  # Step counter at the end of every game
        rewards = array('f')  # Reward of the last move of every game
        wins = 0

        def play_one_move(episode):
            """Ask the policy for an action, play it, return the snake size."""
            state = self.get_game_data(game)
            chosen, _ = q_policy.select_action(self.model, state, episode,
                                               game.nb_actions)
            game.play(chosen)

            return game.snake.length

        for epoch in range(nb_epoch):
            game.reset_game()
            self.clear_frames()

            if visual:
                game.create_window()
                previous_size = game.snake.length  # Size right after reset
                color_list = game.gradient([(42, 42, 42), (152, 152, 152)],
                                           previous_size)
                elapsed = 0

            while not game.game_over:
                if not visual:
                    current_size = play_one_move(epoch)
                else:
                    # Accumulate the time reported by the game clock and only
                    # move the snake once 60 units of it have passed.
                    elapsed += game.fps.get_time()

                    if elapsed >= 60:
                        elapsed = 0
                        current_size = play_one_move(epoch)

                        if current_size > previous_size:
                            # The snake grew: rebuild the body gradient.
                            color_list = game.gradient(
                                [(42, 42, 42), (152, 152, 152)],
                                game.snake.length)
                            previous_size = current_size

                        game.draw(color_list)

                    pygame.display.update()
                    game.fps.tick(120)  # Cap the drawing loop at 120 FPS

                if game.game_over:
                    sizes.append(current_size)
                    steps.append(game.step)
                    rewards.append(game.get_reward())

            if game.is_won():
                wins += 1

        print("Accuracy: {} %".format(100. * wins / nb_epoch))

        summaries = (
            ("Mean size: {} | Biggest size: {} | Smallest size: {}", sizes),
            ("Mean steps: {} | Biggest step: {} | Smallest step: {}", steps),
            ("Mean rewards: {} | Biggest reward: {} | Smallest reward: {}",
             rewards),
        )

        for template, history in summaries:
            print(template.format(np.mean(history), np.max(history),
                                  np.min(history)))

"""THIS"""
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base
from tensorflow.python.ops.init_ops import Constant

class NoisyDense(tf.keras.layers.Dense):
    """Dense layer whose kernel and bias carry a learnable noise term.

    Every parameter is assembled as ``quiet + noise_scale * noise``. The
    ``quiet`` and ``noise_scale`` parts are trainable variables created here;
    the ``noise`` tensors come from ``make_kernel_noise`` and
    ``make_bias_noise``, which subclasses must implement.
    """

    def build(self, input_shape):
        """Create the layer variables and assemble ``kernel`` and ``bias``.

        Raises ValueError when the last input dimension is undefined.
        """
        input_shape = tensor_shape.TensorShape(input_shape)
        fan_in = input_shape[-1].value
        if fan_in is None:
            raise ValueError('The last dimension of the inputs to `Dense` '
                             'should be defined. Found `None`.')
        self.input_spec = base.InputSpec(min_ndim=2, axes={-1: fan_in})

        kernel_shape = [fan_in, self.units]
        # Both noise scales start at 0.5 / sqrt(number of inputs).
        scale_init = Constant(value=(0.5 / np.sqrt(kernel_shape[0])))

        quiet_kernel = self.add_variable('kernel_quiet',
                                         shape=kernel_shape,
                                         initializer=self.kernel_initializer,
                                         regularizer=self.kernel_regularizer,
                                         constraint=self.kernel_constraint,
                                         dtype=self.dtype,
                                         trainable=True)
        kernel_scale = self.add_variable('kernel_noise_scale',
                                         shape=kernel_shape,
                                         initializer=scale_init,
                                         dtype=self.dtype,
                                         trainable=True)
        self.kernel = (quiet_kernel
                       + kernel_scale * self.make_kernel_noise(shape=kernel_shape))

        if not self.use_bias:
            self.bias = None
        else:
            bias_shape = [self.units,]
            quiet_bias = self.add_variable('bias_quiet',
                                           shape=bias_shape,
                                           initializer=self.bias_initializer,
                                           regularizer=self.bias_regularizer,
                                           constraint=self.bias_constraint,
                                           dtype=self.dtype,
                                           trainable=True)
            bias_scale = self.add_variable(name='bias_noise_scale',
                                           shape=bias_shape,
                                           initializer=scale_init,
                                           dtype=self.dtype,
                                           trainable=True)
            self.bias = (quiet_bias
                         + bias_scale * self.make_bias_noise(shape=bias_shape))

        self.built = True

    def make_kernel_noise(self, shape):
        """Return the noise tensor for the kernel; provided by subclasses."""
        raise NotImplementedError

    def make_bias_noise(self, shape):
        """Return the noise tensor for the bias; provided by subclasses."""
        raise NotImplementedError


class NoisyDenseIG(NoisyDense):
    '''
    Noisy dense layer with independent Gaussian noise: the kernel and the
    bias each get their own normally distributed, non-trainable variable.
    '''
    def _normal_variable(self, shape):
        """Build a non-trainable variable initialised from a normal sample."""
        sample = tf.random_normal(shape, dtype=self.dtype)
        return tf.Variable(sample, trainable=False, dtype=self.dtype)

    def make_kernel_noise(self, shape):
        """Create the kernel noise and start ``noise_list`` with it."""
        kernel_noise = self._normal_variable(shape)
        self.noise_list = [kernel_noise]
        return kernel_noise

    def make_bias_noise(self, shape):
        """Create the bias noise and append it to ``noise_list``.

        Relies on ``make_kernel_noise`` having created ``noise_list`` first.
        """
        bias_noise = self._normal_variable(shape)
        self.noise_list.append(bias_noise)
        return bias_noise


class NoisyDenseFG(NoisyDense):
    '''
    Noisy dense layer with factorized Gaussian noise: one noise vector per
    input and one per output; the kernel noise is their outer product and
    the bias reuses the output vector.
    '''
    def make_kernel_noise(self, shape):
        """Create the two noise vectors and return their outer product."""
        input_noise = self.make_fg_noise(shape=[shape[0]])
        output_noise = self.make_fg_noise(shape=[shape[1]])
        self.noise_list = [input_noise, output_noise]
        return input_noise[:, tf.newaxis] * output_noise

    def make_bias_noise(self, shape):
        """Return the output noise vector built by ``make_kernel_noise``.

        ``shape`` is ignored; ``make_kernel_noise`` must have run before.
        """
        _, output_noise = self.noise_list[0], self.noise_list[1]
        return output_noise

    def make_fg_noise(self, shape):
        """Return a non-trainable variable holding sign(x) * sqrt(|x|) of a
        normal sample ``x`` with the given shape."""
        sample = tf.random_normal(shape, dtype=self.dtype)
        magnitude = tf.sqrt(tf.abs(sample))
        return tf.Variable(tf.sign(sample) * magnitude, trainable=False,
                           dtype=self.dtype)
        
        
import random
import numpy as np

class LinearSchedule(object):
    """Linearly anneal a value from ``initial_p`` to ``final_p``."""

    def __init__(self, schedule_timesteps, final_p, initial_p):
        """Linear interpolation between initial_p and final_p over
        schedule_timesteps. After this many timesteps pass final_p is
        returned.

        Note the argument order: ``final_p`` comes BEFORE ``initial_p``.

        Parameters
        ----------
        schedule_timesteps: int or float
            Number of timesteps for which to linearly anneal initial_p
            to final_p
        final_p: float
            final output value
        initial_p: float
            initial output value
        """
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def value(self, t):
        """Return the scheduled value at timestep ``t``.

        The value moves linearly from ``initial_p`` (at ``t == 0``) to
        ``final_p`` (at ``t >= schedule_timesteps``) and stays there.
        """
        if self.schedule_timesteps <= 0:
            # Nothing to anneal over: the schedule is already finished.
            # Without this guard a zero length raised ZeroDivisionError.
            return self.final_p

        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)


import random
import numpy as np

class GreedyQPolicy:
    """Implement the greedy policy

    Greedy policy always takes current best action.
    """
    def __init__(self):
        super(GreedyQPolicy, self).__init__()

    def select_action(self, model, state, epoch, nb_actions):
        """Return the action with the highest predicted Q-value
        # Arguments
            model: object whose `predict(state)` returns the Q estimations;
                the first row of the prediction is used
            state: input handed to `model.predict`
            epoch: unused, kept so every policy shares one signature
            nb_actions: unused, kept so every policy shares one signature
        # Returns
            Tuple `(action, 0)`; the second item is a placeholder where the
            other policies report their exploration parameter
        """
        q = model.predict(state)
        action = int(np.argmax(q[0]))

        return action, 0

    def get_config(self):
        """Return configurations of GreedyQPolicy
        # Returns
            Dict of config (empty: this policy has no parameters)
        """
        # The class has no base that defines get_config, so delegating to
        # super() raised AttributeError.
        return {}


class EpsGreedyQPolicy:
    """Implement the epsilon greedy policy

    Eps Greedy policy either:

    - takes a random action with probability epsilon
    - takes current best action with prob (1 - epsilon)

    Epsilon is annealed linearly from `max_eps` to `min_eps` over `nb_epoch`
    epochs.
    """
    def __init__(self, max_eps=1., min_eps = .01, nb_epoch = 10000):
        super(EpsGreedyQPolicy, self).__init__()
        self.schedule = LinearSchedule(nb_epoch, min_eps, max_eps)
        # Defined up front so get_config works before the first action.
        self.eps = max_eps

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: object whose `predict(state)` returns the Q estimations;
                only queried when the greedy branch is taken
            state: input handed to `model.predict`
            epoch: current epoch, used to look up epsilon in the schedule
            nb_actions: number of actions to sample from when exploring
        # Returns
            Tuple `(action, eps)` with the epsilon used for this decision
        """
        self.eps = self.schedule.value(epoch)

        if random.random() < self.eps:
            # Draw the exploratory action independently of the test above.
            # Reusing that sample (which is < eps in this branch) confined
            # exploration to the lowest action indices.
            action = random.randrange(int(nb_actions))
        else:
            q = model.predict(state)
            action = int(np.argmax(q[0]))

        return action, self.eps

    def get_config(self):
        """Return configurations of EpsGreedyQPolicy
        # Returns
            Dict of config
        """
        # No base class defines get_config, so the dict is built here.
        return {'eps': self.eps}


class BoltzmannQPolicy:
    """Implement the Boltzmann Q Policy
    Boltzmann Q Policy builds a probability law on q values and returns
    an action selected randomly according to this law.

    The temperature is annealed linearly from `max_temp` to `min_temp` over
    `nb_epoch` epochs.
    """
    def __init__(self, max_temp = 1., min_temp = .01, nb_epoch = 10000, clip = (-500., 500.)):
        super(BoltzmannQPolicy, self).__init__()
        self.schedule = LinearSchedule(nb_epoch, min_temp, max_temp)
        # Stored and reported by get_config; select_action does not apply it.
        self.clip = clip
        # Defined up front so get_config works before the first action.
        self.temp = max_temp

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: object whose `predict(state)` returns the Q estimations;
                the first row of the prediction is used
            state: input handed to `model.predict`
            epoch: current epoch, used to look up the temperature
            nb_actions: number of actions to sample from
        # Returns
            Tuple `(action, temp)` with the temperature used
        """
        # Work in float64, as BoltzmannGumbelQPolicy does, so the
        # probabilities handed to np.random.choice sum to 1 accurately.
        q = np.asarray(model.predict(state)[0], dtype = 'float64')
        self.temp = self.schedule.value(epoch)
        arg = q / self.temp

        # Subtracting the maximum keeps every exponent <= 0 (no overflow)
        # without changing the resulting probabilities.
        exp_values = np.exp(arg - arg.max())
        probs = exp_values / exp_values.sum()
        action = int(np.random.choice(range(nb_actions), p = probs))

        return action, self.temp

    def get_config(self):
        """Return configurations of BoltzmannQPolicy
        # Returns
            Dict of config
        """
        # No base class defines get_config, so the dict is built here.
        return {'temp': self.temp, 'clip': self.clip}


class BoltzmannGumbelQPolicy:
    """Implements Boltzmann-Gumbel exploration (BGE) adapted for Q learning
    based on the paper Boltzmann Exploration Done Right
    (https://arxiv.org/pdf/1705.10257.pdf).
    BGE is invariant with respect to the mean of the rewards but not their
    variance. The parameter C, which defaults to 1, can be used to correct for
    this, and should be set to the least upper bound on the standard deviation
    of the rewards.
    BGE is only available for training, not testing. For testing purposes, you
    can achieve approximately the same result as BGE after training for N steps
    on K actions with parameter C by using the BoltzmannQPolicy and setting
    tau = C/sqrt(N/K)."""

    def __init__(self, C = 1.0):
        super(BoltzmannGumbelQPolicy, self).__init__()
        self.C = C
        self.action_counts = None

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: object whose `predict(state)` returns the Q estimations;
                the first row of the prediction is used
            state: input handed to `model.predict`
            epoch: unused, kept so every policy shares one signature
            nb_actions: unused, kept so every policy shares one signature
        # Returns
            Tuple `(action, total)` where `total` is the sum of the per-action
            counters after this selection
        """
        q = np.asarray(model.predict(state)[0], dtype = 'float64')

        # Create the counters on first use (or when the number of actions
        # changes). They used to be reset whenever epoch == 0, which wiped
        # them on every step of the first game and left them as None when
        # the first call came with a later epoch.
        if self.action_counts is None or self.action_counts.shape != q.shape:
            self.action_counts = np.ones(q.shape)

        # Gumbel noise scaled by C / sqrt(count): actions picked less often
        # receive a larger perturbation.
        beta = self.C/np.sqrt(self.action_counts)
        Z = np.random.gumbel(size = q.shape)

        perturbation = beta * Z
        perturbed_q_values = q + perturbation
        action = int(np.argmax(perturbed_q_values))

        self.action_counts[action] += 1
        return action, np.sum(self.action_counts)

    def get_config(self):
        """Return configurations of BoltzmannGumbelQPolicy
        # Returns
            Dict of config
        """
        # No base class defines get_config, so the dict is built here.
        return {'C': self.C}

#!/usr/bin/env python

"""clipped_error: L1 for errors < clip_value else L2 error.

Functions:
    huber_loss: Return the squared (L2) error where the absolute error is less
                than clip_value, else return the linear (L1) error.
    clipped_error: Mean over the last axis of huber_loss with clip_value
                   fixed at 1.0.
"""

import numpy as np
from keras import backend as K
import tensorflow as tf

__author__ = "Victor Neves"
__license__ = "MIT"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"

def huber_loss(y_true, y_pred, clip_value):
    """Element-wise Huber loss between ``y_true`` and ``y_pred``.

    The loss is ``0.5 * x**2`` where ``|x| < clip_value`` and
    ``clip_value * (|x| - 0.5 * clip_value)`` elsewhere, with
    ``x = y_true - y_pred``. An infinite ``clip_value`` gives the plain
    squared loss. See https://en.wikipedia.org/wiki/Huber_loss and
    https://medium.com/@karpathy/yes-you-should-understand-backprop-e2f06eab496b
    for details.
    """
    assert clip_value > 0.

    error = y_true - y_pred
    quadratic = .5 * tf.square(error)

    if np.isinf(clip_value):
        # Special case for infinity, since comparing `abs(x) < np.inf`
        # is problematic in Tensorflow.
        return quadratic

    magnitude = tf.abs(error)
    inside_clip = magnitude < clip_value
    linear = clip_value * (magnitude - .5 * clip_value)

    # Use `tf.select` when this Tensorflow build has it, else `tf.where`.
    # Both take (condition, value if true, value if false).
    chooser = tf.select if hasattr(tf, 'select') else tf.where

    return chooser(inside_clip, quadratic, linear)

def clipped_error(y_true, y_pred):
    """Mean over the last axis of the Huber loss with ``clip_value`` 1."""
    element_loss = huber_loss(y_true, y_pred, clip_value = 1.)

    return tf.keras.backend.mean(element_loss, axis = -1)

#def CNN1(optimizer, loss, stack, input_size, output_size):
 #   model = Sequential()
  #  model.add(Conv2D(32, (3, 3), activation = 'relu', input_shape = (stack,
   #                                                                  input_size,
    #                                                                 input_size)))
#    model.add(Conv2D(64, (3, 3), activation = 'relu'))
 #   model.add(Conv2D(128, (3, 3), activation = 'relu'))
  #  model.add(Conv2D(256, (3, 3), activation = 'relu'))
   # model.add(Flatten())
    #model.add(Dense(1024, activation = 'relu'))
    #model.add(Dense(output_size))
    #model.compile(optimizer = optimizer, loss = loss)

    #return model
    
def CNN4(optimizer, loss, stack, input_size, output_size):
    """Build and compile a sequential CNN: three ReLU convolutions, a 3136
    unit ReLU dense layer and a linear output with ``output_size`` units.

    The input shape is ``(stack, input_size, input_size)``. Architecture
    from @Kaixhin implementation's of the Rainbow paper.
    """
    frames_shape = (stack, input_size, input_size)
    stacked_layers = (
        Conv2D(32, (4, 4), activation = 'relu', input_shape = frames_shape),
        Conv2D(64, (2, 2), activation = 'relu'),
        Conv2D(64, (2, 2), activation = 'relu'),
        Flatten(),
        Dense(3136, activation = 'relu'),
        Dense(output_size),
    )

    model = Sequential()

    for layer in stacked_layers:
        model.add(layer)

    model.compile(optimizer = optimizer, loss = loss)

    return model

from keras.optimizers import RMSprop, Nadam
from keras.models import Sequential, load_model, Model
from keras.layers import Conv2D, Dense, Flatten, MaxPooling2D, Flatten,\
                             Input, Lambda, Add
K.set_image_dim_ordering('th')  # Matches the (stack, size, size) input shape used by CNN4
  
# Board side length and number of stacked frames; both are given to the
# model (as input_size / stack) and to the agent.
board_size = 10
nb_frames = 4
  
game = Game(player = "ROBOT", board_size = board_size,
                        local_state = True, relative_pos = False)


# Build the model, wrap it in an agent and train, all inside one session.
with tf.Session() as sess:
    model = CNN4(optimizer = RMSprop(), loss = clipped_error,
                                stack = nb_frames, input_size = board_size,
                                output_size = game.nb_actions)
    # No target network: Agent.train skips its target-update step when the
    # target is None.
    target = None

    sess.run(tf.global_variables_initializer())
    # NOTE(review): memory_size = -1 presumably means an unbounded replay
    # memory - confirm against the Agent / memory implementation.
    agent = Agent(model = model, sess = sess, target = target, memory_size = -1,
                      nb_frames = nb_frames, board_size = board_size,
                      per = False, update_target_freq = 500)
    # n_steps = 1 keeps single-step rewards (no multi-step return buffer).
    agent.train(game, batch_size = 64, nb_epoch = 10000,
                    gamma = 0.95, n_steps = 1)


Using TensorFlow backend.


Epoch: 010/10000 | Mean size 10: 3.6 | Longest 10: 006 | Mean steps 10: 10.7 | Wins: 3 | Win percentage: 30.0%
Epoch: 020/10000 | Mean size 10: 3.2 | Longest 10: 004 | Mean steps 10: 14.0 | Wins: 5 | Win percentage: 25.0%
Epoch: 030/10000 | Mean size 10: 3.0 | Longest 10: 003 | Mean steps 10: 17.9 | Wins: 5 | Win percentage: 16.7%
Epoch: 040/10000 | Mean size 10: 3.3 | Longest 10: 004 | Mean steps 10: 9.7 | Wins: 8 | Win percentage: 20.0%
Epoch: 050/10000 | Mean size 10: 3.2 | Longest 10: 004 | Mean steps 10: 11.1 | Wins: 10 | Win percentage: 20.0%
Epoch: 060/10000 | Mean size 10: 3.1 | Longest 10: 004 | Mean steps 10: 13.5 | Wins: 11 | Win percentage: 18.3%
Epoch: 070/10000 | Mean size 10: 3.0 | Longest 10: 003 | Mean steps 10: 14.5 | Wins: 11 | Win percentage: 15.7%
Epoch: 080/10000 | Mean size 10: 3.0 | Longest 10: 003 | Mean steps 10: 12.0 | Wins: 11 | Win percentage: 13.8%
Epoch: 090/10000 | Mean size 10: 3.0 | Longest 10: 003 | Mean steps 10: 10.1 | Wins: 11 | Win percentage: 12.

In [0]:
#model.save('keras.h5')

#!zip -r model-epsgreedy-bench.zip keras.h5 
#from google.colab import files
#files.download('model-epsgreedy-bench.zip')
#model = load_model('keras.h5', custom_objects={'clipped_error': clipped_error})

# Evaluation cell. It reuses `model` and `game` from the training cell, so
# that cell must have run first in the same kernel.
board_size = 10
nb_frames = 4
nb_actions = 5  # Not used by the statements in this cell

target = None

# NOTE(review): unlike the training cell, no `sess` is passed to Agent here -
# confirm that Agent works without it.
agent = Agent(model = model, target = target, memory_size = -1,
                          nb_frames = nb_frames, board_size = board_size,
                          per = False)
#%lprun -f agent.train agent.train(game, batch_size = 64, nb_epoch = 10, gamma = 0.95, update_target_freq = 500, policy = "EpsGreedyQPolicy")

# Play 1000 games without drawing; test() falls back to its default greedy
# policy and prints the summary statistics.
agent.test(game, visual = False, nb_epoch = 1000)

In [0]:
#!/usr/bin/env python
import numpy as np
from argparse import ArgumentParser

from keras import backend as K
import keras.optimizers as optimizers
import tensorflow as tf

__author__ = "Victor Neves"
__license__ = "MIT"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"

class HandleArguments:
    """Handle arguments provided in the command line when executing the model.

    Attributes:
        parser: the ArgumentParser holding the option definitions.
        args: arguments parsed in the command line.
        model: the model loaded through --load, or None when --load is unused.
        status_load: a flag for usage of --load argument.
        status_visual: a flag for usage of --visual argument.
        local_state: a flag for usage of --local_state argument.
        dueling: a flag for usage of --dueling argument.
        double: a flag for usage of --double argument.
        per: a flag for usage of --per argument.
    """
    def __init__(self):
        self.parser = ArgumentParser() # Receive arguments
        add_option = self.parser.add_argument

        add_option("-l", "--load", default = "",
                   help = "load a previously trained model. the argument is the filename")
        add_option("-v", "--visual", action = 'store_true',
                   help = "draw the game in a window while the agent plays")
        add_option("-du", "--dueling", action = 'store_true',
                   help = "use dueling DQN")
        add_option("-do", "--double", action = 'store_true',
                   help = "use double DQN")
        add_option("-p", "--per", action = 'store_true',
                   help = "use Prioritized Experience Replay")
        add_option("-ls", "--local_state", action = 'store_true',
                   help = "enable the local_state option of the game")
        add_option("-g", "--board_size", default = 10, type = int,
                   help = "define board size")
        add_option("-nf", "--nb_frames", default = 4, type = int,
                   help = "define the number of stacked frames")
        add_option("-na", "--nb_actions", default = 5, type = int,
                   help = "define the number of actions")
        add_option("-uf", "--update_freq", default = 500, type = int,
                   help = "frequency to update target")

        self.args = self.parser.parse_args()

        # Mirror the store_true options as plain boolean attributes.
        self.status_visual = bool(self.args.visual)
        self.local_state = bool(self.args.local_state)
        self.dueling = bool(self.args.dueling)
        self.double = bool(self.args.double)
        self.per = bool(self.args.per)

        self.model = None
        self.status_load = False

        if self.args.load:
            # The filename is resolved relative to the script's directory.
            script_dir = path.dirname(__file__) # Absolute dir the script is in
            abs_file_path = path.join(script_dir, self.args.load)
            # Keep the loaded model on the instance; it used to be bound to
            # a local variable and thrown away.
            # NOTE(review): a model saved with a custom loss probably needs
            # `custom_objects` here - confirm.
            self.model = load_model(abs_file_path)

            self.status_load = True

def huber_loss(y_true, y_pred, clip_value):
    """Element-wise Huber loss between ``y_true`` and ``y_pred``.

    The loss is ``0.5 * x**2`` where ``|x| < clip_value`` and
    ``clip_value * (|x| - 0.5 * clip_value)`` elsewhere, with
    ``x = y_true - y_pred``. An infinite ``clip_value`` gives the plain
    squared loss. Raises RuntimeError for a Keras backend other than
    tensorflow or theano. See https://en.wikipedia.org/wiki/Huber_loss and
    https://medium.com/@karpathy/yes-you-should-understand-backprop-e2f06eab496b
    for details.
    """
    assert clip_value > 0.

    error = y_true - y_pred
    quadratic = .5 * K.square(error)

    if np.isinf(clip_value):
        # Special case for infinity, since comparing `K.abs(x) < np.inf`
        # is problematic in Tensorflow.
        return quadratic

    magnitude = K.abs(error)
    inside_clip = magnitude < clip_value
    linear = clip_value * (magnitude - .5 * clip_value)

    backend_name = K.backend()

    if backend_name == 'theano':
        from theano import tensor as T
        return T.switch(inside_clip, quadratic, linear)

    if backend_name != 'tensorflow':
        raise RuntimeError('Unknown backend "{}".'.format(backend_name))

    # Use `tf.select` when this Tensorflow build has it, else `tf.where`.
    # Both take (condition, value if true, value if false).
    chooser = tf.select if hasattr(tf, 'select') else tf.where

    return chooser(inside_clip, quadratic, linear)

def clipped_error(y_true, y_pred):
    """Mean over the last axis of the Huber loss with ``clip_value`` 1."""
    element_loss = huber_loss(y_true, y_pred, clip_value = 1.)

    return K.mean(element_loss, axis = -1)

#!/usr/bin/env python

""" Needs update!
"""

import numpy as np
from keras.models import Sequential, load_model, Model
from keras.layers import *
import tensorflow as tf

__author__ = "Victor Neves"
__license__ = "MIT"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"

def weird_CNN(optimizer, loss, stack, input_size, output_size, min_neurons = 16,
         max_neurons = 128, kernel_size = (3,3), layers = 4):
    """Build and compile a sequential CNN with a growing number of filters.

    INPUTS
    optimizer, loss - forwarded to ``model.compile``
    stack           - number of stacked frames (first input dimension)
    input_size      - side length of the input frames
    output_size     - number of units of the output layer
    min_neurons     - filters of the first convolutional layer
    max_neurons     - exclusive upper bound for the filter counts; the dense
                      layer has ``max_neurons * 4`` units
    kernel_size     - kernel of every convolutional layer
    layers          - number of convolutional layers
    OUTPUTS
    model           - compiled CNN
    """
    # Filter counts run from min_neurons towards max_neurons in steps of
    # floor(max_neurons / (layers + 1)); the first `layers` values are used.
    stride = np.floor(max_neurons / (layers + 1))
    filter_counts = np.arange(min_neurons, max_neurons, stride).astype(np.int32)

    model = Sequential()

    # Convolutional stack; only the first layer declares the input shape.
    for depth in range(layers):
        extra = {}

        if depth == 0:
            extra['input_shape'] = (stack, input_size, input_size)

        model.add(Conv2D(filter_counts[depth], kernel_size, **extra))
        model.add(Activation('relu'))

    # Pooling, dense hidden layer and sigmoid output layer.
    head = (
        MaxPooling2D(pool_size = (2, 2)),
        Flatten(),
        Dense(max_neurons * 4),
        Activation('relu'),
        Dense(output_size),
        Activation('sigmoid'),
    )

    for layer in head:
        model.add(layer)

    model.compile(loss = loss, optimizer = optimizer)

    return model

def CNN1(inputs, stack, input_size):
    """Two ReLU convolutions (16 and 32 filters, 3x3) applied to ``inputs``.

    Returns the flattened feature tensor. ``stack`` and ``input_size`` are
    unused; they are accepted so every CNN builder shares one signature.
    """
    net = Conv2D(16, (3, 3), activation = 'relu')(inputs)
    net = Conv2D(32, (3, 3), activation = 'relu')(net)
    # Flatten like CNN2 and CNN3 do, so the dense head built on top of the
    # features receives a 2D tensor.
    net = Flatten()(net)

    # The function used to return the undefined name `model`.
    return net

def CNN2(inputs, stack, input_size):
    """Three ReLU convolutions (16, 32 and 32 filters, 3x3) followed by 2x2
    max pooling, applied to ``inputs``.

    Returns the flattened feature tensor. ``stack`` and ``input_size`` are
    unused; they are accepted so every CNN builder shares one signature.
    """
    net = Conv2D(16, (3, 3), activation = 'relu')(inputs)
    net = Conv2D(32, (3, 3), activation = 'relu')(net)
    net = Conv2D(32, (3, 3), activation = 'relu')(net)
    net = MaxPooling2D(pool_size = (2, 2))(net)
    net = Flatten()(net)

    # The function used to return the undefined name `model`.
    return net

def CNN3(inputs, stack, input_size):
    """Three ReLU convolutions (32 filters 4x4, then 64 filters 2x2 twice)
    applied to ``inputs``; returns the flattened features.

    From @Kaixhin implementation's of the Rainbow paper. ``stack`` and
    ``input_size`` are not used.
    """
    conv_specs = ((32, (4, 4)), (64, (2, 2)), (64, (2, 2)))
    features = inputs

    for filters, kernel in conv_specs:
        features = Conv2D(filters, kernel, activation = 'relu')(features)

    return Flatten()(features)

def create_cnn(cnn, inputs, stack, input_size):
    """Apply the convolutional stack named by ``cnn`` to ``inputs``.

    ``"CNN1"`` and ``"CNN2"`` select the matching builder; any other value
    selects CNN3. Returns whatever the chosen builder returns.
    """
    if cnn == "CNN1":
        return CNN1(inputs, stack, input_size)

    if cnn == "CNN2":
        return CNN2(inputs, stack, input_size)

    return CNN3(inputs, stack, input_size)

def create_model(optimizer, loss, stack, input_size, output_size,
                  dueling = False, cnn = "CNN3"):
    """Build and compile a functional Q-network on top of a CNN.

    The input has shape ``(stack, input_size, input_size)`` and is processed
    by the convolutional stack chosen through ``cnn``. Without ``dueling``
    the head is a 3136 unit ReLU dense layer followed by a linear layer with
    ``output_size`` units. With ``dueling`` an advantage stream and a value
    stream are built, the advantages are centred on their mean and added to
    the value. Returns the compiled model.
    """
    inputs = Input(shape = (stack, input_size, input_size))
    features = create_cnn(cnn, inputs, stack, input_size)

    if not dueling:
        hidden = Dense(3136, activation = 'relu')(features)
        q_values = Dense(output_size)(hidden)
    else:
        # Advantage stream: one output per action.
        advantage = Dense(3136, activation = 'relu')(features)
        advantage = Dense(output_size)(advantage)
        # Value stream: a single output.
        state_value = Dense(3136, activation = 'relu')(features)
        state_value = Dense(1)(state_value)

        # Centre the advantages, repeat the value once per action so both
        # tensors share a shape, then add the two streams.
        centred = Lambda(lambda a: a - tf.reduce_mean(a, axis = -1,
                                                      keepdims = True))(advantage)
        repeated = Lambda(lambda v: tf.tile(v, [1, output_size]))(state_value)
        q_values = Add()([repeated, centred])

    model = Model(inputs = inputs, outputs = q_values)
    model.compile(optimizer = optimizer, loss = loss)

    return model
import random
import numpy as np

class LinearSchedule(object):
    """Linearly anneal a value from ``initial_p`` to ``final_p``."""

    def __init__(self, schedule_timesteps, final_p, initial_p):
        """Linear interpolation between initial_p and final_p over
        schedule_timesteps. After this many timesteps pass final_p is
        returned.

        Note the argument order: ``final_p`` comes BEFORE ``initial_p``.

        Parameters
        ----------
        schedule_timesteps: int or float
            Number of timesteps for which to linearly anneal initial_p
            to final_p
        final_p: float
            final output value
        initial_p: float
            initial output value
        """
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def value(self, t):
        """Return the scheduled value at timestep ``t``.

        The value moves linearly from ``initial_p`` (at ``t == 0``) to
        ``final_p`` (at ``t >= schedule_timesteps``) and stays there.
        """
        if self.schedule_timesteps <= 0:
            # Nothing to anneal over: the schedule is already finished.
            # Without this guard a zero length raised ZeroDivisionError.
            return self.final_p

        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)


class GreedyQPolicy:
    """Implement the greedy policy

    Greedy policy always takes current best action.
    """
    def __init__(self):
        super(GreedyQPolicy, self).__init__()

    def select_action(self, model, state, epoch, nb_actions):
        """Return the action with the highest predicted Q value.

        # Arguments
            model: object exposing ``predict(state)``; the first row of the
                prediction is used as the Q values.
            state: input forwarded unchanged to ``model.predict``.
            epoch: unused, kept so all policies share one signature.
            nb_actions: unused, kept so all policies share one signature.
        # Returns
            Tuple ``(action, 0)``: the selected action as an int and a constant
            0 in the slot where other policies report their exploration value.
        """
        q = model.predict(state)
        action = int(np.argmax(q[0]))

        return action, 0

    def get_config(self):
        """Return configurations of GreedyQPolicy
        # Returns
            Dict of config (empty: this policy has no parameters)
        """
        # The class has no base policy, so there is no parent get_config() to
        # extend; calling super().get_config() raised AttributeError.
        return {}


class EpsGreedyQPolicy:
    """Implement the epsilon greedy policy

    Eps Greedy policy either:

    - takes a random action with probability epsilon
    - takes current best action with prob (1 - epsilon)

    Epsilon is annealed linearly from max_eps to min_eps over nb_epoch epochs.
    """
    def __init__(self, max_eps=1., min_eps = .01, nb_epoch = 10000):
        super(EpsGreedyQPolicy, self).__init__()
        self.schedule = LinearSchedule(nb_epoch, min_eps, max_eps)
        # Defined up front so get_config() works before the first action.
        self.eps = max_eps

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action

        # Arguments
            model: object exposing ``predict(state)``; only queried on the
                greedy branch.
            state: input forwarded unchanged to ``model.predict``.
            epoch: position in the epsilon schedule.
            nb_actions: number of actions to draw from when exploring.
        # Returns
            Tuple ``(action, eps)``: the selected action as an int and the
            epsilon used for this decision.
        """
        self.eps = self.schedule.value(epoch)

        if random.random() < self.eps:
            # Draw the exploratory action independently of the coin flip.
            # Reusing the flip (already known to be < eps) would squeeze the
            # action into [0, nb_actions * eps) and always yield 0 once eps is
            # below 1 / nb_actions.
            action = random.randrange(int(nb_actions))
        else:
            q = model.predict(state)
            action = int(np.argmax(q[0]))

        return action, self.eps

    def get_config(self):
        """Return configurations of EpsGreedyQPolicy
        # Returns
            Dict of config
        """
        # No parent policy class exists, so the config is built from scratch
        # (super().get_config() raised AttributeError).
        return {'eps': self.eps}


class BoltzmannQPolicy:
    """Implement the Boltzmann Q Policy
    Boltzmann Q Policy builds a probability law on q values and returns
    an action selected randomly according to this law.

    The temperature is annealed linearly from max_temp to min_temp over
    nb_epoch epochs.
    """
    def __init__(self, max_temp = 1., min_temp = .01, nb_epoch = 10000, clip = (-500., 500.)):
        super(BoltzmannQPolicy, self).__init__()
        self.schedule = LinearSchedule(nb_epoch, min_temp, max_temp)
        # Stored and reported by get_config(); the softmax below is kept
        # numerically stable by subtracting the maximum instead of clipping.
        self.clip = clip
        # Defined up front so get_config() works before the first action.
        self.temp = max_temp

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action

        # Arguments
            model: object exposing ``predict(state)``; the first row of the
                prediction is used as the Q values.
            state: input forwarded unchanged to ``model.predict``.
            epoch: position in the temperature schedule.
            nb_actions: number of actions; must match the number of Q values.
        # Returns
            Tuple ``(action, temp)``: the sampled action as an int and the
            temperature used for this decision.
        """
        q = model.predict(state)[0]
        self.temp = self.schedule.value(epoch)

        if self.temp <= 0:
            # Zero temperature is the greedy limit of the softmax; dividing by
            # it would produce NaN probabilities and make the sampling fail.
            return int(np.argmax(q)), self.temp

        arg = q / self.temp

        # Shift by the maximum so the largest exponent is exp(0) = 1.
        exp_values = np.exp(arg - arg.max())
        probs = exp_values / exp_values.sum()
        action = int(np.random.choice(range(nb_actions), p = probs))

        return action, self.temp

    def get_config(self):
        """Return configurations of BoltzmannQPolicy
        # Returns
            Dict of config
        """
        # No parent policy class exists, so the config is built from scratch
        # (super().get_config() raised AttributeError).
        return {'temp': self.temp, 'clip': self.clip}


class BoltzmannGumbelQPolicy:
    """Implements Boltzmann-Gumbel exploration (BGE) adapted for Q learning
    based on the paper Boltzmann Exploration Done Right
    (https://arxiv.org/pdf/1705.10257.pdf).
    BGE is invariant with respect to the mean of the rewards but not their
    variance. The parameter C, which defaults to 1, can be used to correct for
    this, and should be set to the least upper bound on the standard deviation
    of the rewards.
    BGE is only available for training, not testing. For testing purposes, you
    can achieve approximately the same result as BGE after training for N steps
    on K actions with parameter C by using the BoltzmannQPolicy and setting
    tau = C/sqrt(N/K)."""

    def __init__(self, C = 1.0):
        super(BoltzmannGumbelQPolicy, self).__init__()
        self.C = C
        self.action_counts = None

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action

        # Arguments
            model: object exposing ``predict(state)``; the first row of the
                prediction is used as the Q values.
            state: input forwarded unchanged to ``model.predict``.
            epoch: training epoch; the per-action counters are reset whenever
                it is 0.
            nb_actions: unused, the number of actions is taken from the
                prediction's shape.
        # Returns
            Tuple ``(action, total)``: the selected action as an int and the
            sum of the per-action counters after this selection.
        """
        q = model.predict(state)[0]
        q = q.astype('float64')

        # Counters are reset on every call made during epoch 0, as before.
        # They are also created on demand, so a first call with a non-zero
        # epoch (e.g. resumed training) no longer fails on a None counter.
        if (epoch == 0 or self.action_counts is None
                or self.action_counts.shape != q.shape):
            self.action_counts = np.ones(q.shape)

        # Gumbel noise scaled by C / sqrt(count): rarely chosen actions get
        # larger perturbations.
        beta = self.C/np.sqrt(self.action_counts)
        Z = np.random.gumbel(size = q.shape)

        perturbation = beta * Z
        perturbed_q_values = q + perturbation
        action = int(np.argmax(perturbed_q_values))

        self.action_counts[action] += 1
        return action, np.sum(self.action_counts)

    def get_config(self):
        """Return configurations of BoltzmannGumbelQPolicy
        # Returns
            Dict of config
        """
        # No parent policy class exists, so the config is built from scratch
        # (super().get_config() raised AttributeError).
        return {'C': self.C}

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# pylint: disable=C0111

import numpy as np
import sys
import time
import operator
from datetime import timedelta
import collections

class SegmentTree(object):
    def __init__(self, capacity, operation, neutral_element):
        """Build a Segment Tree data structure.
        https://en.wikipedia.org/wiki/Segment_tree
        Can be used as regular array, but with two
        important differences:
            a) setting item's value is slightly slower.
               It is O(lg capacity) instead of O(1).
            b) user has access to an efficient `reduce`
               operation which reduces `operation` over
               a contiguous subsequence of items in the
               array.
        Parameters
        ----------
        capacity: int
            Total size of the array - must be a power of two.
        operation: lambda obj, obj -> obj
            an operation for combining elements (eg. sum, max);
            must form a mathematical group together with the set of
            possible values for array elements.
        neutral_element: obj
            neutral element for the operation above. eg. float('-inf')
            for max and 0 for sum.
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2."
        self._capacity = capacity
        # Kept so that reducing over an empty range has a well-defined result.
        self._neutral_element = neutral_element
        # Heap layout: node 1 is the root, node i has children 2i and 2i + 1,
        # leaves live at indices capacity .. 2 * capacity - 1 (index 0 unused).
        self._value = [neutral_element for _ in range(2 * capacity)]
        self._operation = operation

    def _reduce_helper(self, start, end, node, node_start, node_end):
        """Reduce the inclusive range [start, end] inside the subtree rooted at
        `node`, which covers the inclusive range [node_start, node_end]."""
        if start == node_start and end == node_end:
            return self._value[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            # Query lies entirely in the left child.
            return self._reduce_helper(start, end, 2 * node, node_start, mid)
        else:
            if mid + 1 <= start:
                # Query lies entirely in the right child.
                return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
            else:
                # Query straddles both children: split it at `mid`.
                return self._operation(
                    self._reduce_helper(start, mid, 2 * node, node_start, mid),
                    self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
                )

    def reduce(self, start=0, end=None):
        """Returns result of applying `self.operation`
        to a contiguous subsequence of the array.
            self.operation(arr[start], operation(arr[start+1], operation(... arr[end - 1])))
        Parameters
        ----------
        start: int
            beginning of the subsequence (inclusive)
        end: int
            end of the subsequence (exclusive). None means the whole
            capacity; a negative value counts from the end of the array.
        Returns
        -------
        reduced: obj
            result of reducing self.operation over the specified range of array
            elements, or the neutral element when the range is empty.
        """
        if end is None:
            end = self._capacity
        if end < 0:
            end += self._capacity
        end -= 1  # switch to an inclusive upper bound for the helper
        if start > end:
            # Empty range (e.g. reduce(0, 0)): the helper would recurse
            # forever, so answer with the operation's neutral element.
            return self._neutral_element
        return self._reduce_helper(start, end, 1, 0, self._capacity - 1)

    def __setitem__(self, idx, val):
        # Same bounds contract as __getitem__: an out-of-range index would
        # otherwise silently overwrite an internal node.
        assert 0 <= idx < self._capacity
        # index of the leaf
        idx += self._capacity
        self._value[idx] = val
        # Recompute every ancestor up to the root.
        idx //= 2
        while idx >= 1:
            self._value[idx] = self._operation(
                self._value[2 * idx],
                self._value[2 * idx + 1]
            )
            idx //= 2

    def __getitem__(self, idx):
        assert 0 <= idx < self._capacity
        return self._value[self._capacity + idx]


class SumSegmentTree(SegmentTree):
    """Segment tree whose combining operation is addition (neutral value 0.0)."""

    def __init__(self, capacity):
        super().__init__(capacity=capacity,
                         operation=operator.add,
                         neutral_element=0.0)

    def sum(self, start=0, end=None):
        """Return the sum of a contiguous slice; arguments are passed straight
        to `reduce`."""
        return super().reduce(start, end)

    def find_prefixsum_idx(self, prefixsum):
        """Locate the leaf reached by walking the tree with `prefixsum`.

        Starting at the root, the walk goes left whenever the left subtree's
        total exceeds the remaining mass; otherwise that total is subtracted
        and the walk goes right. When the stored values are (unnormalised)
        probabilities this samples an index proportionally to its value.

        Parameters
        ----------
        prefixsum: float
            mass to locate; must lie in [0, total sum] (a tolerance of 1e-5
            above the total is accepted)
        Returns
        -------
        idx: int
            index of the leaf where the walk ends
        """
        assert 0 <= prefixsum <= self.sum() + 1e-5
        node = 1
        while node < self._capacity:  # still on an internal node
            left_child = 2 * node
            left_total = self._value[left_child]
            if left_total > prefixsum:
                node = left_child
            else:
                prefixsum -= left_total
                node = left_child + 1
        return node - self._capacity


class MinSegmentTree(SegmentTree):
    """Segment tree whose combining operation is `min` (neutral value +inf)."""

    def __init__(self, capacity):
        super().__init__(capacity=capacity,
                         operation=min,
                         neutral_element=float('inf'))

    def min(self, start=0, end=None):
        """Return the minimum of a contiguous slice; arguments are passed
        straight to `reduce`."""
        return super().reduce(start, end)


class SumTree:
    """Array-backed binary tree in which every internal node stores the sum of
    its two children.

    The tree array holds ``2 * capacity - 1`` nodes with the root at index 0;
    the last ``capacity`` entries are the leaves (priorities). A parallel
    object array keeps the payload attached to each leaf and is filled as a
    ring buffer.
    """
    def __init__(self, capacity):
        self._capacity = capacity
        self._tree = np.zeros(2 * self._capacity - 1)
        self._data = np.zeros(self._capacity, dtype = object)
        self._data_idx = 0  # next payload slot to (over)write

    @property
    def capacity(self):
        return self._capacity

    @property
    def tree(self):
        return self._tree

    @property
    def data(self):
        return self._data

    def sum(self):
        """Total of all priorities (value held by the root)."""
        return self._tree[0]

    def insert(self, data, priority):
        """Store `data` in the next slot with the given priority, wrapping
        around to slot 0 once the buffer is full."""
        slot = self._data_idx
        self._data[slot] = data
        # Leaf for payload slot k sits at tree index k + capacity - 1.
        self.update(slot + self._capacity - 1, priority)
        following = slot + 1
        self._data_idx = 0 if following >= self._capacity else following

    def update(self, tree_idx, priority):
        """Set the priority of leaf `tree_idx` and propagate the change to
        every ancestor."""
        change = priority - self._tree[tree_idx]
        self._tree[tree_idx] = priority
        node = tree_idx
        while node != 0:
            node = (node - 1) // 2  # parent
            self._tree[node] += change

    def retrieve(self, val):
        """Walk down from the root using `val` as the remaining mass.

        Returns
        -------
        tuple
            ``(tree_idx, priority, data)`` of the leaf that was reached.
        """
        node_count = len(self._tree)
        node = 0
        left = 1
        while left < node_count:  # node still has children
            left_mass = self._tree[left]
            if val <= left_mass:
                node = left
            else:
                val -= left_mass
                node = left + 1
            left = 2 * node + 1

        return node, self._tree[node], self._data[node - self._capacity + 1]

    def max_leaf(self):
        """Largest leaf priority."""
        return np.max(self.leaves())

    def min_leaf(self):
        """Smallest leaf priority (unused leaves count as 0)."""
        return np.min(self.leaves())

    def leaves(self):
        """View of the leaf section of the tree array."""
        return self._tree[-self._capacity:]

#!/usr/bin/env python

"""SnakeGame: A simple and fun exploration, meant to be used by Human and AI.
"""

import sys  # To close the window when the game is over
from array import array  # Efficient numeric arrays
from os import environ, path  # To center the game window the best possible
import random  # Random numbers used for the food
import logging  # Logging function for movements and errors
from itertools import tee  # For the color gradient on snake
import numpy as np

__author__ = "Victor Neves"
__license__ = "MIT"
__version__ = "1.0"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"
__status__ = "Production"

# Menu options, relative/absolute actions and the (action, previous action)
# pairs that would reverse the snake onto itself
options = {
    'QUIT': 0,
    'PLAY': 1,
    'BENCHMARK': 2,
    'LEADERBOARDS': 3,
    'MENU': 4,
    'ADD_LEADERBOARDS': 5,
}
relative_actions = {
    'LEFT': 0,
    'FORWARD': 1,
    'RIGHT': 2,
}
actions = {
    'LEFT': 0,
    'RIGHT': 1,
    'UP': 2,
    'DOWN': 3,
    'IDLE': 4,
}
forbidden_moves = [
    (0, 1),
    (1, 0),
    (2, 3),
    (3, 2),
]

# Reward given for each kind of game event
rewards = {
    'MOVE': -0.005,
    'GAME_OVER': -1,
    'SCORED': 1,
}

# Codes for what a board cell can contain
point_type = {
    'EMPTY': 0,
    'FOOD': 1,
    'BODY': 2,
    'HEAD': 3,
    'DANGEROUS': 4,
}

# Difficulty labels shown to human players and the speed for each fixed level;
# MEGA HARDCORE starts with MEDIUM and increases with snake size
levels = [
    " EASY ",
    " MEDIUM ",
    " HARD ",
    " MEGA HARDCORE ",
]
speeds = {
    'EASY': 80,
    'MEDIUM': 60,
    'HARD': 40,
}

class GlobalVariables:
    """Container for the settings shared by the drawing and game logic.

    Attributes
    ----------
    BOARD_SIZE: int, optional, default = 30
        Side of the board, in blocks.
    BLOCK_SIZE: int, optional, default = 20
        Side of one block, in pixels.
    HEAD_COLOR: tuple of 3 * int, optional, default = (42, 42, 42)
        Color of the head. Start of the body color gradient.
    TAIL_COLOR: tuple of 3 * int, optional, default = (152, 152, 152)
        Color of the tail. End of the body color gradient.
    FOOD_COLOR: tuple of 3 * int, optional, default = (200, 0, 0)
        Color of the food.
    GAME_SPEED: int, optional, default = 80
        Speed setting of the game.
    GAME_FPS: int, optional, default = 100
        Frames-per-second setting of the game.
    BENCHMARK: int, optional, default = 10
        Number of matches played in BENCHMARK mode.
    """
    def __init__(self, BOARD_SIZE = 30, BLOCK_SIZE = 20,
                 HEAD_COLOR = (42, 42, 42), TAIL_COLOR = (152, 152, 152),
                 FOOD_COLOR = (200, 0, 0), GAME_SPEED = 80, GAME_FPS = 100,
                 BENCHMARK = 10):
        """Store every setting on the instance and warn about very big boards.
        """
        # Board geometry
        self.BOARD_SIZE, self.BLOCK_SIZE = BOARD_SIZE, BLOCK_SIZE

        # Colors
        self.HEAD_COLOR, self.TAIL_COLOR = HEAD_COLOR, TAIL_COLOR
        self.FOOD_COLOR = FOOD_COLOR

        # Timing and benchmark length
        self.GAME_SPEED, self.GAME_FPS = GAME_SPEED, GAME_FPS
        self.BENCHMARK = BENCHMARK

        # Boards above 50 blocks per side may hurt performance
        board_is_huge = self.BOARD_SIZE > 50
        if board_is_huge:
            logger.warning('WARNING: BOARD IS TOO BIG, IT MAY RUN SLOWER.')

class TextBlock:
    """A piece of text rendered with pygame, either plain text or a menu entry.

    Attributes:
    ----------
    text: string
        The text to be displayed.
    pos: tuple of 2 * int
        Point where the center of the text rectangle is placed.
    screen: pygame window object
        The screen where the text is drawn.
    scale: int, optional, default = 1 / 12
        Font size as a fraction of the window side, so text follows the board
        size.
    type: string, optional, default = "text"
        "text" for a plain label, "menu" for an entry that reacts to hovering.
    """
    def __init__(self, text, pos, screen, scale = (1 / 12), type = "text"):
        """Store the settings, place the rectangle and draw the block once."""
        self.text = text
        self.pos = pos
        self.screen = screen
        self.scale = scale
        self.type = type
        self.hovered = False
        self.set_rect()
        self.draw()

    def draw(self):
        """Re-render the text and blit it on the screen at its rectangle."""
        self.set_rend()
        self.screen.blit(self.rend, self.rect)

    def set_rend(self):
        """Render the text surface with the current font size and colors."""
        font_size = int((var.BOARD_SIZE * var.BLOCK_SIZE) * self.scale)
        font = pygame.font.Font(resource_path("resources/fonts/freesansbold.ttf"),
                                font_size)
        self.rend = font.render(self.text, True, self.get_color(),
                                self.get_background())

    def get_color(self):
        """Pick the text color.

        Return
        ----------
        color: pygame color
            Light grey for a menu entry that is not hovered, dark grey in every
            other case.
        """
        if self.type == "menu" and not self.hovered:
            return pygame.Color(152, 152, 152)

        return pygame.Color(42, 42, 42)

    def get_background(self):
        """Pick the background color.

        Return
        ----------
        color: pygame color or None
            Light grey for a hovered menu entry, None (no background) otherwise.
        """
        if self.type == "menu" and self.hovered:
            return pygame.Color(152, 152, 152)

        return None

    def set_rect(self):
        """Compute the rectangle of the rendered text and center it on pos."""
        self.set_rend()
        self.rect = self.rend.get_rect()
        self.rect.center = self.pos


class Snake:
    """The player: a head position plus the list of body segments.

    body[0] always mirrors the head; a new segment is inserted there on every
    move and the last one is dropped unless food was eaten.

    Attributes
    ----------
    head: list of 2 * int, default = [BOARD_SIZE / 4, BOARD_SIZE / 4]
        The head of the snake, located according to the board size.
    body: list of lists of 2 * int
        Starts with 3 parts and grows when food is eaten.
    previous_action: int, default = 1
        Last action which the snake took (1 is actions['RIGHT']).
    length: int, default = 3
        Variable length of the snake, can increase when food is eaten.
    """
    def __init__(self):
        """Place a 3-segment snake laid out horizontally, head on the right."""
        quarter = int(var.BOARD_SIZE / 4)
        self.head = [quarter, quarter]
        self.body = [[quarter - offset, quarter] for offset in range(3)]
        self.previous_action = 1
        self.length = 3

    def is_movement_invalid(self, action):
        """Tell whether `action` paired with the previous action is one of the
        forbidden moves.

        Return
        ----------
        invalid: boolean
            True when the move is forbidden.
        """
        return (action, self.previous_action) in forbidden_moves

    def move(self, action, food_pos):
        """Advance one block. IDLE or forbidden actions repeat the previous
        action. The tail is dropped unless the head lands on the food.

        Return
        ----------
        ate_food: boolean
            Flag which represents whether the snake ate or not food.
        """
        keep_heading = (action == actions['IDLE']
                        or self.is_movement_invalid(action))
        if keep_heading:
            action = self.previous_action
        else:
            self.previous_action = action

        # (action name, x step, y step); y grows downwards
        steps = (('LEFT', -1, 0), ('RIGHT', 1, 0), ('UP', 0, -1), ('DOWN', 0, 1))
        for name, step_x, step_y in steps:
            if action == actions[name]:
                self.head[0] += step_x
                self.head[1] += step_y
                break

        self.body.insert(0, list(self.head))

        if self.head == food_pos:
            logger.info('EVENT: FOOD EATEN')
            self.length = len(self.body)
            return True

        self.body.pop()
        return False


class FoodGenerator:
    """Create food pieces and remember where the current one is.

    Attributes
    ----------
    pos:
        Current position of food.
    is_food_on_screen:
        Flag for existence of food.
    """
    def __init__(self, body):
        """Start without food on screen, then generate the first piece."""
        self.is_food_on_screen = False
        self.pos = self.generate_food(body)

    def generate_food(self, body):
        """Return the food position, drawing a new one only when no food is on
        screen.

        Return
        ----------
        pos: list of 2 * int
            Position of the food. A freshly drawn position is never inside
            `body`.
        """
        if self.is_food_on_screen:
            return self.pos

        # NOTE(review): each coordinate lands in [0, BOARD_SIZE - 2], so the
        # last row/column never receives food - confirm this is intended.
        while True:
            candidate = [int((var.BOARD_SIZE - 1) * random.random())
                         for _ in range(2)]
            if candidate not in body:
                break

        self.pos = candidate
        logger.info('EVENT: FOOD APPEARED')
        self.is_food_on_screen = True

        return self.pos


class Game:
    """Hold the game window and functions.

    Attributes
    ----------
    window: pygame display
        Pygame window to show the game.
    fps: pygame time clock
        Define Clock and ticks in which the game will be displayed.
    snake: object
        The actual snake who is going to be played.
    food_generator: object
        Generator of food which responds to the snake.
    food_pos: tuple of 2 * int
        Position of the food on the board.
    game_over: boolean
        Flag for game_over.
    player: string
        Define if human or robots are playing the game.
    board_size: int, optional, default = 30
        The size of the board.
    local_state: boolean, optional, default = False
        Whether to use or not game expertise (used mostly by robots players).
    relative_pos: boolean, optional, default = False
        Whether to use or not relative position of the snake head. Instead of
        actions, use relative_actions.
    screen_rect: tuple of 2 * int
        The screen rectangle, used to draw relatively positioned blocks.
    """
    def __init__(self, player, board_size = 30, local_state = False, relative_pos = False):
        """Store the settings; for robot players also set nb_actions and reset
        the game right away."""
        var.BOARD_SIZE = board_size
        self.player = player
        self.local_state = local_state
        self.relative_pos = relative_pos

        if player != "ROBOT":
            return

        # 3 relative actions or 5 absolute ones (IDLE included)
        self.nb_actions = 3 if self.relative_pos else 5
        self.reset_game()

    def reset_game(self):
        """Start a fresh match: clear the counters and flags, then create a new
        snake and a food generator bound to its body."""
        self.step = 0
        self.scored = False
        self.game_over = False

        self.snake = Snake()
        self.food_generator = FoodGenerator(self.snake.body)
        self.food_pos = self.food_generator.pos

    def create_window(self):
        """Initialize pygame and open a square, double-buffered window of
        BOARD_SIZE * BLOCK_SIZE pixels per side."""
        pygame.init()

        side = var.BOARD_SIZE * var.BLOCK_SIZE
        self.window = pygame.display.set_mode((side, side), pygame.DOUBLEBUF)
        self.window.set_alpha(None)
        self.screen_rect = self.window.get_rect()
        self.fps = pygame.time.Clock()

    def menu(self):
        """Show the main menu and block until an entry is clicked.

        Return
        ----------
        selected_option: int
            The selected option in the main loop.
        """
        pygame.display.set_caption("SNAKE GAME  | PLAY NOW!")

        side = var.BOARD_SIZE * var.BLOCK_SIZE
        img = pygame.image.load(resource_path("resources/images/snake_logo.png"))
        img = pygame.transform.scale(img, (side, int(side / 3)))

        img_rect = img.get_rect()
        img_rect.center = self.screen_rect.center

        # (label, vertical position in tenths of centery, resulting option)
        entries = ((' PLAY GAME ', 4, options['PLAY']),
                   (' BENCHMARK ', 6, options['BENCHMARK']),
                   (' LEADERBOARDS ', 8, options['LEADERBOARDS']),
                   (' QUIT ', 10, options['QUIT']))
        menu_options = [TextBlock(label,
                                  (self.screen_rect.centerx,
                                   tenths * self.screen_rect.centery / 10),
                                  self.window, (1 / 12), "menu")
                        for label, tenths, _ in entries]
        selected_option = None

        while selected_option is None:
            pygame.event.pump()
            ev = pygame.event.get()

            self.window.fill(pygame.Color(225, 225, 225))

            for block, (_, _, choice) in zip(menu_options, entries):
                block.draw()

                if block.rect.collidepoint(pygame.mouse.get_pos()):
                    block.hovered = True

                    # A mouse release while hovering selects this entry
                    if any(event.type == pygame.MOUSEBUTTONUP for event in ev):
                        selected_option = choice
                else:
                    block.hovered = False

            self.window.blit(img, img_rect.bottomleft)
            pygame.display.update()

        return selected_option

    def start_match(self):
        """Show a 3, 2, 1 countdown (one second each) before the match."""
        for remaining in ("3", "2", "1"):
            self.window.fill(pygame.Color(225, 225, 225))

            center_x = self.screen_rect.centerx
            center_y = self.screen_rect.centery

            # Heading on top, big countdown digit below it
            countdown = [TextBlock('Game starts in',
                                   (center_x, 4 * center_y / 10),
                                   self.window, (1 / 10), "text"),
                         TextBlock(remaining,
                                   (center_x, 12 * center_y / 10),
                                   self.window, (1 / 1.5), "text")]

            for block in countdown:
                block.draw()

            pygame.display.update()
            pygame.display.set_caption("SNAKE GAME  |  Game starts in "
                                       + remaining + " second(s) ...")

            pygame.time.wait(1000)

        logger.info('EVENT: GAME START')

    def start(self):
        """Run the main loop: show the menu and dispatch the selected option.

        The loop only ends through the QUIT option, which shuts pygame down
        and exits the interpreter.
        """
        opt = self.menu()

        while True:
            if opt == options['QUIT']:
                pygame.quit()
                sys.exit()
            elif opt == options['PLAY']:
                # Single match, then the game over screen picks what is next
                var.GAME_SPEED, mega_hardcore = self.select_speed()
                self.reset_game()
                self.start_match()
                score = self.single_player(mega_hardcore)
                opt = self.over(score)
            elif opt == options['BENCHMARK']:
                # var.BENCHMARK matches in a row, scores collected in an array
                var.GAME_SPEED, mega_hardcore = self.select_speed()
                score = array('i')

                for _ in range(var.BENCHMARK):
                    self.reset_game()
                    self.start_match()
                    score.append(self.single_player(mega_hardcore))

                opt = self.over(score)
            elif opt == options['MENU']:
                opt = self.menu()
            else:
                # LEADERBOARDS and ADD_LEADERBOARDS have no screen yet. Leaving
                # opt untouched would spin here forever without pumping events
                # and freeze the window, so go back to the main menu instead.
                logger.warning('EVENT: OPTION NOT IMPLEMENTED, BACK TO MENU')
                opt = self.menu()

    def over(self, score):
        """Show the game over screen and block until an entry is clicked.

        A single int score is shown as is; any other score container is shown
        as a mean over var.BENCHMARK matches and adds the leaderboards entry.
        Clicking QUIT shuts pygame down and exits instead of returning.

        Return
        ----------
        selected_option: int
            The selected option in the main loop.
        """
        center_x = self.screen_rect.centerx
        center_y = self.screen_rect.centery

        def menu_entry(label, tenths):
            # Menu entry centered horizontally, at tenths / 10 of centery
            return TextBlock(label, (center_x, tenths * center_y / 10),
                             self.window, (1 / 15), "menu")

        play_again = menu_entry(' PLAY AGAIN ', 4)
        back_to_menu = menu_entry(' GO TO MENU ', 6)
        quit_game = menu_entry(' QUIT ', 10)
        add_to_board = None

        if isinstance(score, int):
            text_score = 'SCORE: ' + str(score)
        else:
            text_score = 'MEAN SCORE: ' + str(sum(score) / var.BENCHMARK)
            add_to_board = menu_entry(' ADD TO LEADERBOARDS ', 8)

        pygame.display.set_caption("SNAKE GAME  | " + text_score
                                   + "  |  GAME OVER...")
        logger.info('EVENT: GAME OVER | FINAL ' + text_score)
        score_label = TextBlock(text_score, (center_x, 15 * center_y / 10),
                                self.window, (1 / 10), "text")

        # Drawing order and the option each entry yields when clicked; the
        # quit entry and the score label do not yield an option.
        click_targets = ((play_again, options['PLAY']),
                         (back_to_menu, options['MENU']),
                         (add_to_board, options['ADD_LEADERBOARDS']),
                         (quit_game, None),
                         (score_label, None))
        selected_option = None

        while selected_option is None:
            pygame.event.pump()
            ev = pygame.event.get()

            # Game over screen
            self.window.fill(pygame.Color(225, 225, 225))

            for block, outcome in click_targets:
                if block is None:
                    continue

                block.draw()

                if not block.rect.collidepoint(pygame.mouse.get_pos()):
                    block.hovered = False
                    continue

                block.hovered = True

                if block is score_label:
                    continue  # plain text, nothing to click

                if not any(event.type == pygame.MOUSEBUTTONUP for event in ev):
                    continue

                if block is quit_game:
                    pygame.quit()
                    sys.exit()

                selected_option = outcome

            pygame.display.update()

        return selected_option

    def select_speed(self):
        """Speed menu, right before calling start_match.

        One entry per difficulty level is drawn, and the loop runs until the
        mouse button is released while the pointer is over one of them.

        Return
        ----------
        speed: int
            The value read from ``speeds`` for the chosen entry. The fourth
            entry (levels[3]) uses the 'MEDIUM' value.
        mega_hardcore: boolean
            True only when the fourth entry was the one chosen.
        """
        # Key read from ``speeds`` for each entry, in drawing order.
        speed_keys = ('EASY', 'MEDIUM', 'HARD', 'MEDIUM')
        menu_options = [TextBlock(levels[row],
                                  (self.screen_rect.centerx,
                                   (4 * (row + 1)) * self.screen_rect.centery / 10),
                                  self.window, (1 / 10), "menu")
                        for row in range(4)]
        mega_hardcore = False
        speed = None

        while speed is None:
            pygame.event.pump()
            events = pygame.event.get()

            # Speed selection screen
            self.window.fill(pygame.Color(225, 225, 225))

            for row, option in enumerate(menu_options):
                option.draw()

                if not option.rect.collidepoint(pygame.mouse.get_pos()):
                    option.hovered = False
                    continue

                option.hovered = True

                # A button release while hovering an entry selects it.
                if any(event.type == pygame.MOUSEBUTTONUP for event in events):
                    speed = speeds[speed_keys[row]]

                    if row == 3:
                        mega_hardcore = True

            pygame.display.update()

        return speed, mega_hardcore

    def single_player(self, mega_hardcore = False):
        """Game loop for single_player (HUMANS).

        Keys are polled on every tick, but the snake only moves once the time
        accumulated since the last move reaches the current waiting time.

        Parameters
        ----------
        mega_hardcore: boolean, optional, default = False
            When True, the waiting time between moves is recomputed on every
            tick from the current snake length.

        Return
        ----------
        score: int
            The final score for the match (discounted of initial length).
        """
        body_shades = [(42, 42, 42), (152, 152, 152)]
        drawn_length = self.snake.length  # Length the gradient was built for
        current_size = drawn_length
        color_list = self.gradient(body_shades, drawn_length)

        # The last valid key is remembered between moves, which makes the
        # game smoother for human players.
        waited = 0
        last_key = self.snake.previous_action
        move_wait = var.GAME_SPEED

        while not self.game_over:
            waited += self.fps.get_time()  # Time elapsed since last tick.

            if mega_hardcore:  # The longer the snake, the shorter the wait.
                move_wait = var.GAME_SPEED - (2 * (self.snake.length - 3))

            key_input = self.handle_input()
            rejected = self.snake.is_movement_invalid(key_input)

            if not (key_input is None or rejected):
                last_key = key_input

            if waited >= move_wait:  # Time to move and redraw
                waited = 0
                self.game_over = self.play(last_key)
                current_size = self.snake.length

                if current_size > drawn_length:  # Snake grew: new gradient
                    color_list = self.gradient(body_shades, current_size)
                    drawn_length = current_size

                self.draw(color_list)

            pygame.display.update()
            self.fps.tick(100)  # Limit FPS to 100

        return current_size - 3

    def check_collision(self):
        """Check whether the head left the board or ran into the body.

        Return
        ----------
        collided: boolean
            Whether the snake collided or not.
        """
        head = self.snake.head
        last_index = var.BOARD_SIZE - 1

        # Either coordinate outside [0, BOARD_SIZE - 1] means a wall was hit.
        if head[0] > last_index or head[0] < 0 \
           or head[1] > last_index or head[1] < 0:
            logger.info('EVENT: WALL COLLISION')

            return True

        # The first body element is skipped in the self-collision test.
        if head in self.snake.body[1:]:
            logger.info('EVENT: BODY COLLISION')

            return True

        return False

    def is_won(self):
        """Tell whether the snake is longer than its starting length of 3.

        Return
        ----------
        won: boolean
            Whether the score is greater than 0.
        """
        starting_length = 3

        return self.snake.length > starting_length

    def generate_food(self):
        """Ask the food generator for the food position, given the snake body.

        Return
        ----------
        food_pos: tuple of 2 * int
            Current position of the food.
        """
        return self.food_generator.generate_food(self.snake.body)

    def handle_input(self):
        """Read the keys currently held down and translate them to an action.

        ESCAPE or Q ends the match through self.over. Otherwise the first
        arrow key found pressed, checked in the order LEFT, RIGHT, UP, DOWN,
        decides the action.

        Return
        ----------
        action: int
            The value of ``actions`` for the pressed arrow key, or None when
            no arrow key is pressed (also after ESCAPE or Q).
        """
        pygame.event.set_allowed([pygame.QUIT, pygame.KEYDOWN])
        pressed = pygame.key.get_pressed()
        pygame.event.pump()

        if pressed[pygame.K_ESCAPE] or pressed[pygame.K_q]:
            logger.info('ACTION: KEY PRESSED: ESCAPE or Q')
            self.over(self.snake.length - 3)

            return None

        # (key code, key in ``actions``, log message), in priority order.
        arrow_keys = ((pygame.K_LEFT, 'LEFT', 'ACTION: KEY PRESSED: LEFT'),
                      (pygame.K_RIGHT, 'RIGHT', 'ACTION: KEY PRESSED: RIGHT'),
                      (pygame.K_UP, 'UP', 'ACTION: KEY PRESSED: UP'),
                      (pygame.K_DOWN, 'DOWN', 'ACTION: KEY PRESSED: DOWN'))

        for key_code, direction, message in arrow_keys:
            if pressed[key_code]:
                logger.info(message)

                return actions[direction]

        return None

    def eval_local_safety(self, canvas, body):
        """Flag which of the four cells next to the head are unsafe.

        A neighbouring cell is unsafe when it is past the board edge in that
        direction or when it is listed in body[1:]. The flags are written to
        columns 0 to 3 of the last row of the canvas, for the neighbours at
        x + 1, x - 1, y - 1 and y + 1 respectively.

        Return
        ----------
        canvas: np.array of size BOARD_SIZE**2
            After using game expertise, change canvas values to DANGEROUS if true.
        """
        head_x, head_y = body[0][0], body[0][1]
        rest_of_body = body[1:]
        last_index = var.BOARD_SIZE - 1

        # (is the neighbour past the wall?, neighbour cell), one per flag slot.
        probes = ((head_x + 1 > last_index, [head_x + 1, head_y]),
                  (head_x - 1 < 0, [head_x - 1, head_y]),
                  (head_y - 1 < 0, [head_x, head_y - 1]),
                  (head_y + 1 > last_index, [head_x, head_y + 1]))

        for slot, (past_wall, cell) in enumerate(probes):
            if past_wall or cell in rest_of_body:
                canvas[last_index, slot] = point_type['DANGEROUS']

        return canvas

    def state(self):
        """Create a matrix of the current state of the game.

        When the game is over the matrix is left filled with zeros. Otherwise
        the body, the head, the optional local safety flags and the food are
        written in that order, so later writes win on shared cells.

        Return
        ----------
        canvas: np.array of size BOARD_SIZE**2
            Return the current state of the game in a matrix.
        """
        board = np.zeros((var.BOARD_SIZE, var.BOARD_SIZE))

        if self.game_over:
            return board

        segments = self.snake.body

        for segment in segments:
            board[segment[0], segment[1]] = point_type['BODY']

        # The head is the first body element and overrides its BODY mark.
        head = segments[0]
        board[head[0], head[1]] = point_type['HEAD']

        if self.local_state:
            board = self.eval_local_safety(board, segments)

        board[self.food_pos[0], self.food_pos[1]] = point_type['FOOD']

        return board

    def relative_to_absolute(self, action):
        """Translate relative actions to absolute.

        FORWARD keeps the previous absolute action. A LEFT turn, and any other
        value treated as a RIGHT turn, is resolved against the previous
        absolute action; a previous action that is not LEFT, RIGHT or UP is
        handled as DOWN.

        Return
        ----------
        action: int
            Translated action from relative to absolute.
        """
        heading = self.snake.previous_action

        if action == relative_actions['FORWARD']:
            return heading

        if action == relative_actions['LEFT']:
            turns = ((actions['LEFT'], actions['DOWN']),
                     (actions['RIGHT'], actions['UP']),
                     (actions['UP'], actions['LEFT']))
            otherwise = actions['RIGHT']
        else:
            turns = ((actions['LEFT'], actions['UP']),
                     (actions['RIGHT'], actions['DOWN']),
                     (actions['UP'], actions['RIGHT']))
            otherwise = actions['LEFT']

        for current, turned in turns:
            if heading == current:
                return turned

        return otherwise

    def play(self, action):
        """Move the snake to the direction, eat and check collision.

        One step: refresh the food position, translate the action when
        relative positions are in use, move the snake and evaluate the end of
        the match.

        Return
        ----------
        collided: boolean or None
            True when the player is "HUMAN" and a collision happened. In every
            other case nothing is returned (None); for non-human players the
            end of the match is stored in self.game_over instead, on a
            collision or when step exceeds 50 times the snake length.
        """
        self.scored = False
        self.step += 1
        self.food_pos = self.generate_food()

        if self.relative_pos:
            action = self.relative_to_absolute(action)

        ate_food = self.snake.move(action, self.food_pos)

        if ate_food:
            self.scored = True
            self.food_generator.is_food_on_screen = False

        if self.player == "HUMAN":
            return True if self.check_collision() else None

        if self.check_collision() or self.step > 50 * self.snake.length:
            self.game_over = True

        return None

    def get_reward(self):
        """Return the reward for the current step.

        A finished game takes precedence over scoring; a step that scored is
        rewarded with the current snake length; any other step gets the MOVE
        reward.

        Return
        ----------
        reward: float
            Current reward of the game.
        """
        move_reward = rewards['MOVE']

        if self.game_over:
            return rewards['GAME_OVER']

        if self.scored:
            return self.snake.length

        return move_reward

    def gradient(self, colors, steps, components = 3):
        """Function to create RGB gradients given the colors and steps. If
        component is changed to 4, it does the same to RGBA colors.

        The gradient is sampled at ``steps`` evenly spaced positions along the
        whole chain of colors, so the result always has ``steps`` entries, no
        matter how many colors are given. A single color is simply repeated.

        Parameters
        ----------
        colors: list of tuple of int
            The colors to go through, from the first to the last one.
        steps: int
            How many colors the gradient should have.
        components: int, optional, default = 3
            How many components of each color are interpolated.

        Return
        ----------
        result: list of steps length of tuple of 3 * int (if RGBA, 4 * int)
            List of colors of calculated gradient from start to end. The first
            entry is colors[0] itself; the others are tuples of floats.
        """
        colors = list(colors)
        steps = int(steps)

        if steps <= 0 or not colors:
            return []

        segments = len(colors) - 1

        # Nothing to interpolate: one color only, or a single step requested.
        if segments == 0 or steps == 1:
            return [colors[0]] * steps

        result = [colors[0]]

        for index in range(1, steps):
            # Position along the chain, from 0 (first color) to ``segments``.
            position = float(index) * segments / (steps - 1)
            # The last position belongs to the last segment, not to a new one.
            lower = min(int(position), segments - 1)
            fraction = position - lower
            start, finish = colors[lower], colors[lower + 1]

            result.append(tuple(start[j] + fraction * (finish[j] - start[j])
                                for j in range(components)))

        return result

    def draw(self, color_list):
        """Draw the game, the snake and the food using pygame.

        Each body part is painted with the color at the same position of
        color_list; the window caption shows the score (length minus 3).
        """
        self.window.fill(pygame.Color(225, 225, 225))

        for part, color in zip(self.snake.body, color_list):
            part_rect = pygame.Rect(part[0] * var.BLOCK_SIZE,
                                    part[1] * var.BLOCK_SIZE,
                                    var.BLOCK_SIZE, var.BLOCK_SIZE)
            pygame.draw.rect(self.window, color, part_rect)

        food_rect = pygame.Rect(self.food_pos[0] * var.BLOCK_SIZE,
                                self.food_pos[1] * var.BLOCK_SIZE,
                                var.BLOCK_SIZE, var.BLOCK_SIZE)
        pygame.draw.rect(self.window, var.FOOD_COLOR, food_rect)

        score = self.snake.length - 3
        pygame.display.set_caption("SNAKE GAME  |  Score: " + str(score))

def resource_path(relative_path):
    """Function to return absolute paths. Used while creating .exe file.

    The base folder is, in order of preference: ``sys._MEIPASS`` when that
    attribute exists, the folder of this source file, or the current working
    directory when ``__file__`` is not defined (as in a notebook or an
    interactive session, where the lookup used to raise NameError).

    Parameters
    ----------
    relative_path: str
        Path of the resource, relative to the base folder.

    Return
    ----------
    absolute_path: str
        The base folder joined with relative_path.
    """
    if hasattr(sys, '_MEIPASS'):
        return path.join(sys._MEIPASS, relative_path)

    try:
        base_dir = path.dirname(path.realpath(__file__))
    except NameError:  # No __file__: fall back to the working directory
        base_dir = path.abspath('.')

    return path.join(base_dir, relative_path)

# Module-level objects shared by the game code.
# `var` is the GlobalVariables instance the game methods read their settings
# from (e.g. var.BOARD_SIZE, var.BLOCK_SIZE, var.GAME_SPEED).
var = GlobalVariables() # Initializing GlobalVariables
# Logger named after this module, used for the EVENT/ACTION messages.
logger = logging.getLogger(__name__) # Setting logger
# Environment variable set at import time to center the game window.
environ['SDL_VIDEO_CENTERED'] = '1' # Centering the window

import numpy as np
from random import sample, uniform
from array import array  # Efficient numeric arrays

class ExperienceReplay:
    """The class that handles memory and experiences replay.

    Experiences are kept in a plain list and sampled uniformly.

    Attributes
    ----------
    memory: list of experiences
        Memory list to insert experiences. Each experience is the flat vector
        [s, a, r, s_prime, game_over] built by remember.
    memory_size: int, optional, default = 150000
        The amount of experiences to be stored in the memory. Values of 100
        or less are replaced by 150000 when the memory is reset.
    input_shape: tuple of int
        The shape of one stored state (s.shape[1:]), set by the first call to
        remember.
    """
    def __init__(self, memory_size = 150000):
        """Initialize parameters and the memory list."""
        self.memory_size = memory_size
        self.reset_memory() # Initiate the memory

    def exp_size(self):
        """Returns how many experiences are stored."""
        return len(self.memory)

    def remember(self, s, a, r, s_prime, game_over):
        """Remember SARS' experiences, with the game_over parameter (done).

        The oldest experience is dropped once the memory holds more than
        memory_size entries.
        """
        if not hasattr(self, 'input_shape'):
            self.input_shape = s.shape[1:] # set attribute only once

        experience = np.concatenate([s.flatten(),
                                     np.array(a).flatten(),
                                     np.array(r).flatten(),
                                     s_prime.flatten(),
                                     1 * np.array(game_over).flatten()])

        self.memory.append(experience)

        if self.memory_size > 0 and self.exp_size() > self.memory_size:
            self.memory.pop(0)

    def get_samples(self, batch_size):
        """Sample the memory uniformly, without replacement.

        Return
        ----------
        batch: np.array of batch_size experiences
            The batched experiences from memory.
        IS_weights: np.array of batch_size of the weights
            As it's used only in PER, is an array of ones in this case.
        Indexes: list of batch_size * int
            As it's used only in PER, return None.
        """
        IS_weights = np.ones((batch_size, ))
        batch = np.array(sample(self.memory, batch_size))

        return batch, IS_weights, None

    def get_targets(self, target, model, batch_size, nb_actions, gamma = 0.9,
                    n_steps = 1):
        """Sample a batch and build the Q-learning targets for it.

        Parameters
        ----------
        target: object with predict or None
            Target network. When given, Double DQN logic is used: the next
            action is chosen by model and evaluated by target.
        model: object with predict
            The network being trained.
        batch_size: int
            How many experiences to sample.
        nb_actions: int
            Not used here; kept so the signature stays the same.
        gamma: float, optional, default = 0.9
            Discount factor.
        n_steps: int, optional, default = 1
            Exponent applied to gamma (multi-step returns).

        Return
        ----------
        None when fewer than batch_size experiences are stored, otherwise:
        S: np.array
            The sampled states, shape (batch_size, ) + input_shape.
        targets: np.array
            Predictions of model for S, with the entry of the action taken
            replaced by the computed target value.
        IS_weights: np.array
            The weights returned by get_samples.
        """
        if self.exp_size() < batch_size:
            return None

        samples, IS_weights, _ = self.get_samples(batch_size)
        input_dim = int(np.prod(self.input_shape)) # Flattened state length

        S = samples[:, 0 : input_dim] # Separate the states
        a = samples[:, input_dim] # Separate the actions
        r = samples[:, input_dim + 1] # Separate the rewards
        S_prime = samples[:, input_dim + 2 : 2 * input_dim + 2] # Next states
        game_over = samples[:, 2 * input_dim + 2] # Separate terminal flags

        # Reshape the arrays to make them usable by the model.
        S = S.reshape((batch_size, ) + self.input_shape)
        S_prime = S_prime.reshape((batch_size, ) + self.input_shape)

        # One forward pass for both S (first half) and S_prime (second half).
        X = np.concatenate([S, S_prime], axis = 0)
        Y = np.asarray(model.predict(X))
        rows = np.arange(batch_size)

        if target is not None: # Use Double DQN logic:
            # model picks the next action, target gives its value. Indexing
            # with the batch rows works for any batch size.
            next_actions = np.argmax(Y[batch_size:], axis = 1)
            Y_target = np.asarray(target.predict(X[batch_size:]))
            Qsa = Y_target[rows, next_actions]
        else:
            Qsa = np.max(Y[batch_size:], axis = 1)

        # Where the action happened, replace with the Q values of S_prime
        targets = np.array(Y[:batch_size])
        value = r + (gamma ** n_steps) * (1 - game_over) * Qsa
        targets[rows, a.astype(int)] = value

        return S, targets, IS_weights

    def reset_memory(self):
        """Set the memory as a blank list.

        A memory_size of 100 or less is replaced by the default 150000.
        """
        if self.memory_size <= 100:
            self.memory_size = 150000

        self.memory = []


class PrioritizedExperienceReplayNaive:
    """The class that handles memory and prioritized experiences replay.

    Experiences live in a SumTree (defined elsewhere in this file) and are
    sampled by priority.

    Attributes
    ----------
    memory: SumTree
        Tree holding the experiences and their priorities.
    memory_size: int, optional, default = 150000
        The amount of experiences to be stored in the memory. Values of 100
        or less are replaced by 150000 when the memory is reset.
    input_shape: tuple of int
        The shape of one stored state (s.shape[1:]), set by the first call to
        remember.
    exp: int
        How many experiences were inserted since the last reset.
    epsilon: float, optional, default = 0.001
        Added to the errors, used to replace "0" probabilities cases.
    alpha: float, optional, default = 0.6
        How much prioritization to use.
    beta: float, optional, default = 0.4
        Exponent of the importance sampling weights (IS_weights).
    schedule: LinearSchedule
        Built from nb_epoch * decay, 1.0 and beta; not read by this class.
    """
    def __init__(self, memory_size = 150000, alpha = 0.6, epsilon = 0.001,
                 beta = 0.4, nb_epoch = 10000, decay = 0.5):
        """Initialize parameters and the memory tree."""
        self.memory_size = memory_size
        self.epsilon = epsilon
        self.alpha = alpha
        self.beta = beta
        self.schedule = LinearSchedule(nb_epoch * decay, 1.0, beta)
        self.reset_memory() # Initiate the memory

    def exp_size(self):
        """Returns how many experiences were inserted."""
        return self.exp

    def get_priority(self, errors):
        """Returns priority based on how much prioritization to use."""
        return (errors + self.epsilon) ** self.alpha

    def update(self, tree_indices, errors):
        """Update a list of nodes, based on their errors."""
        priorities = self.get_priority(errors)

        for index, priority in zip(tree_indices, priorities):
            self.memory.update(index, priority)

    def update_priorities(self, tree_indices, errors):
        """Same as update; the name used by PrioritizedExperienceReplay."""
        self.update(tree_indices, errors)

    def remember(self, s, a, r, s_prime, game_over):
        """Remember SARS' experiences, with the game_over parameter (done).

        New experiences get the largest priority found in the tree, or the
        priority of a zero error when that is 0.
        """
        if not hasattr(self, 'input_shape'):
            self.input_shape = s.shape[1:] # set attribute only once

        experience = np.concatenate([s.flatten(),
                                     np.array(a).flatten(),
                                     np.array(r).flatten(),
                                     s_prime.flatten(),
                                     1 * np.array(game_over).flatten()])

        max_priority = self.memory.max_leaf()

        if max_priority == 0:
            max_priority = self.get_priority(0)

        self.memory.insert(experience, max_priority)
        self.exp += 1

    def get_samples(self, batch_size):
        """Sample the memory by priority, one draw per equal-sized segment.

        Return
        ----------
        batch: np.array of batch_size experiences
            The batched experiences from memory.
        IS_weights: np.array of batch_size of the weights
            (prob / min_prob) ** -beta for each sampled experience.
        tree_indices: list of batch_size * int
            Index returned by the tree for each sampled experience.
        """
        batch = [None] * batch_size
        IS_weights = np.zeros((batch_size, ))
        tree_indices = [0] * batch_size

        memory_sum = self.memory.sum()
        len_seg = memory_sum / batch_size
        min_prob = self.memory.min_leaf() / memory_sum

        for i in range(batch_size):
            val = uniform(len_seg * i, len_seg * (i + 1))
            tree_indices[i], priority, batch[i] = self.memory.retrieve(val)
            prob = priority / self.memory.sum()
            IS_weights[i] = np.power(prob / min_prob, -self.beta)

        return np.array(batch), IS_weights, tree_indices

    def get_targets(self, target, model, batch_size, nb_actions, gamma = 0.9,
                    n_steps = 1):
        """Sample a batch, build its Q-learning targets and update priorities.

        Parameters
        ----------
        target: object with predict or None
            Target network. When given, Double DQN logic is used: the next
            action is chosen by model and evaluated by target.
        model: object with predict
            The network being trained.
        batch_size: int
            How many experiences to sample.
        nb_actions: int
            Not used here; kept so the signature stays the same.
        gamma: float, optional, default = 0.9
            Discount factor.
        n_steps: int, optional, default = 1
            Exponent applied to gamma (multi-step returns).

        Return
        ----------
        None when fewer than batch_size experiences were inserted, otherwise:
        S: np.array
            The sampled states, shape (batch_size, ) + input_shape.
        targets: np.array
            Predictions of model for S, with the entry of the action taken
            replaced by the computed target value.
        IS_weights: np.array
            The weights returned by get_samples.
        """
        if self.exp_size() < batch_size:
            return None

        samples, IS_weights, tree_indices = self.get_samples(batch_size)
        input_dim = int(np.prod(self.input_shape)) # Flattened state length

        S = samples[:, 0 : input_dim] # Separate the states
        a = samples[:, input_dim] # Separate the actions
        r = samples[:, input_dim + 1] # Separate the rewards
        S_prime = samples[:, input_dim + 2 : 2 * input_dim + 2] # Next states
        game_over = samples[:, 2 * input_dim + 2] # Separate terminal flags

        # Reshape the arrays to make them usable by the model.
        S = S.reshape((batch_size, ) + self.input_shape)
        S_prime = S_prime.reshape((batch_size, ) + self.input_shape)

        # One forward pass for both S (first half) and S_prime (second half).
        X = np.concatenate([S, S_prime], axis = 0)
        Y = np.asarray(model.predict(X))
        rows = np.arange(batch_size)
        taken = a.astype(int)

        if target is not None: # Use Double DQN logic:
            # model picks the next action, target gives its value. Indexing
            # with the batch rows works for any batch size.
            next_actions = np.argmax(Y[batch_size:], axis = 1)
            Y_target = np.asarray(target.predict(X[batch_size:]))
            Qsa = Y_target[rows, next_actions]
        else:
            Qsa = np.max(Y[batch_size:], axis = 1)

        # Where the action happened, replace with the Q values of S_prime
        targets = np.array(Y[:batch_size])
        value = r + (gamma ** n_steps) * (1 - game_over) * Qsa
        targets[rows, taken] = value

        # TD error: target value against the prediction for the action that
        # was actually taken (not the best action), clipped to 1.
        errors = np.abs(value - Y[:batch_size][rows, taken]).clip(max = 1.)
        self.update(tree_indices, errors)

        return S, targets, IS_weights

    def reset_memory(self):
        """Replace the memory with an empty tree and zero the counter.

        A memory_size of 100 or less is replaced by the default 150000.
        """
        if self.memory_size <= 100:
            self.memory_size = 150000

        self.memory = SumTree(self.memory_size)
        self.exp = 0


class PrioritizedExperienceReplay:
    """Prioritized experience replay backed by segment trees.

    Experiences are kept in a list used as a ring buffer; their priorities
    are kept, under the same index, in a SumSegmentTree and a MinSegmentTree
    (both defined elsewhere in this file).

    Attributes
    ----------
    memory: list of experiences
        Each experience is the flat vector [s, a, r, s_prime, game_over]
        built by remember.
    memory_size: int
        The amount of experiences to be stored in the memory. Values of 100
        or less are replaced by 150000 when the memory is reset.
    pos: int
        Index where the next experience will be written.
    input_shape: tuple of int
        The shape of one stored state (s.shape[1:]), set by the first call to
        remember.
    epsilon: float, optional, default = 0.001
        Added to the errors so that no priority is zero.
    alpha: float, optional, default = 0.6
        How much prioritization to use.
    beta: float, optional, default = 0.4
        Exponent of the importance sampling weights.
    max_priority: float
        Largest priority (before alpha is applied) seen so far; new
        experiences are stored with it.
    schedule: LinearSchedule
        Built from nb_epoch * decay, 1.0 and beta; not read by this class.
    """
    def __init__(self, memory_size, nb_epoch = 10000, epsilon = 0.001,
                 alpha = 0.6, beta = 0.4, decay = 0.5):
        """Initialize parameters, the memory list and the trees."""
        self.memory_size = memory_size
        self.alpha = alpha
        self.epsilon = epsilon
        self.beta = beta
        self.schedule = LinearSchedule(nb_epoch * decay, 1.0, beta)
        self.max_priority = 1.0
        self.reset_memory()

    def exp_size(self):
        """Returns how much memory is stored."""
        return len(self.memory)

    def remember(self, s, a, r, s_prime, game_over):
        """Remember SARS' experiences, with the game_over parameter (done).

        Once the memory is full the oldest slot is overwritten. The new
        experience is stored with the largest priority seen so far.
        """
        if not hasattr(self, 'input_shape'):
            self.input_shape = s.shape[1:] # set attribute only once

        experience = np.concatenate([s.flatten(),
                                     np.array(a).flatten(),
                                     np.array(r).flatten(),
                                     s_prime.flatten(),
                                     1 * np.array(game_over).flatten()])
        slot = self.pos # The slot this experience is written to

        if self.exp_size() < self.memory_size:
            self.memory.append(experience)
        else:
            self.memory[slot] = experience

        # Always wrap around, so the write position never leaves the buffer.
        self.pos = (slot + 1) % self.memory_size

        # The priority goes to the slot just written, not to the next one.
        initial_priority = self.max_priority ** self.alpha
        self._it_sum[slot] = initial_priority
        self._it_min[slot] = initial_priority

    def _sample_proportional(self, batch_size):
        """Draw batch_size indexes through the prefix sums of the sum tree."""
        res = array('i')

        for _ in range(batch_size):
            mass = random.random() * self._it_sum.sum(0, self.exp_size() - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)

        return res

    def get_priority(self, errors):
        """Returns priority based on how much prioritization to use."""
        return (errors + self.epsilon) ** self.alpha

    def get_samples(self, batch_size):
        """Sample the memory by priority.

        Return
        ----------
        batch: np.array of batch_size experiences
            The batched experiences from memory.
        weights: np.array of batch_size * float32
            Importance sampling weights, divided by the largest possible one.
        idxes: array of batch_size * int
            Index in memory of each sampled experience.
        """
        idxes = self._sample_proportional(batch_size)

        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * self.exp_size()) ** (-self.beta)
        weights = []

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * self.exp_size()) ** (-self.beta)
            weights.append(weight / max_weight)

        weights = np.array(weights, dtype=np.float32)
        samples = [self.memory[idx] for idx in idxes]

        return np.array(samples), weights, idxes

    def get_targets(self, target, model, batch_size, nb_actions, gamma = 0.9,
                    n_steps = 1):
        """Sample a batch, build its Q-learning targets and update priorities.

        Parameters
        ----------
        target: object with predict or None
            Target network. When given, Double DQN logic is used: the next
            action is chosen by model and evaluated by target.
        model: object with predict
            The network being trained.
        batch_size: int
            How many experiences to sample.
        nb_actions: int
            Not used here; kept so the signature stays the same.
        gamma: float, optional, default = 0.9
            Discount factor.
        n_steps: int, optional, default = 1
            Exponent applied to gamma (multi-step returns).

        Return
        ----------
        None when fewer than batch_size experiences are stored, otherwise:
        S: np.array
            The sampled states, shape (batch_size, ) + input_shape.
        targets: np.array
            Predictions of model for S, with the entry of the action taken
            replaced by the computed target value.
        IS_weights: np.array
            The weights returned by get_samples.
        """
        if self.exp_size() < batch_size:
            return None

        samples, IS_weights, tree_indices = self.get_samples(batch_size)
        input_dim = int(np.prod(self.input_shape)) # Flattened state length

        S = samples[:, 0 : input_dim] # Separate the states
        a = samples[:, input_dim] # Separate the actions
        r = samples[:, input_dim + 1] # Separate the rewards
        S_prime = samples[:, input_dim + 2 : 2 * input_dim + 2] # Next states
        game_over = samples[:, 2 * input_dim + 2] # Separate terminal flags

        # Reshape the arrays to make them usable by the model.
        S = S.reshape((batch_size, ) + self.input_shape)
        S_prime = S_prime.reshape((batch_size, ) + self.input_shape)

        # One forward pass for both S (first half) and S_prime (second half).
        X = np.concatenate([S, S_prime], axis = 0)
        Y = np.asarray(model.predict(X))
        rows = np.arange(batch_size)
        taken = a.astype(int)

        if target is not None: # Use Double DQN logic:
            # model picks the next action, target gives its value. Indexing
            # with the batch rows works for any batch size.
            next_actions = np.argmax(Y[batch_size:], axis = 1)
            Y_target = np.asarray(target.predict(X[batch_size:]))
            Qsa = Y_target[rows, next_actions]
        else:
            Qsa = np.max(Y[batch_size:], axis = 1)

        # Where the action happened, replace with the Q values of S_prime
        targets = np.array(Y[:batch_size])
        value = r + (gamma ** n_steps) * (1 - game_over) * Qsa
        targets[rows, taken] = value

        # TD error: target value against the prediction for the action that
        # was actually taken (not the best action), clipped to 1.
        errors = np.abs(value - Y[:batch_size][rows, taken]).clip(max = 1.)
        self.update_priorities(tree_indices, errors)

        return S, targets, IS_weights

    def update_priorities(self, idxes, errors):
        """Store new priorities for the given indexes, based on their errors.

        The trees receive (error + epsilon) ** alpha, with alpha applied only
        once; max_priority tracks the value before alpha, since remember
        applies alpha to it when storing a new experience.
        """
        raw_priorities = np.atleast_1d(errors) + self.epsilon

        for idx, raw_priority in zip(idxes, raw_priorities):
            scaled = raw_priority ** self.alpha
            self._it_sum[idx] = scaled
            self._it_min[idx] = scaled

            self.max_priority = max(self.max_priority, raw_priority)

    def reset_memory(self):
        """Set the memory as a blank list and create empty trees.

        A memory_size of 100 or less is replaced by the default 150000. The
        trees get the smallest power of two that is at least memory_size as
        their capacity.
        """
        if self.memory_size <= 100:
            self.memory_size = 150000

        self.memory = []
        self.pos = 0

        it_capacity = 1

        while it_capacity < self.memory_size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)


#!/usr/bin/env python

"""dqn: First try to create an AI for SnakeGame. Is it good enough?

This algorithm is an implementation of DQN, Double DQN logic (using a target
network to have fixed Q-targets), Dueling DQN logic (Q(s,a) = Advantage + Value),
PER (Prioritized Experience Replay, using Sum Trees) and Multi-step returns. You
can read more about these on https://goo.gl/MctLzp

Implemented algorithms
----------
    * Simple Deep Q-network (DQN with ExperienceReplay);
        Paper: https://arxiv.org/abs/1312.5602
    * Double Deep Q-network (Double DQN);
        Paper: https://arxiv.org/abs/1509.06461
    * Dueling Deep Q-network (Dueling DQN);
        Paper: https://arxiv.org/abs/1511.06581
    * Prioritized Experience Replay (PER);
        Paper: https://arxiv.org/abs/1511.05952
    * Multi-step returns.
        Paper: https://arxiv.org/pdf/1703.01327

Arguments
----------
--load: 'file.h5'
    Load a previously trained model in '.h5' format.
--board_size: int, optional, default = 10
    Assign the size of the board.
--nb_frames: int, optional, default = 4
    Assign the number of frames per stack, default = 4.
--nb_actions: int, optional, default = 5
    Assign the number of actions possible.
--update_freq: int or float, optional, default = 0.001
    Whether to soft or hard update the target. Epochs or amount of the update.
--visual: boolean, optional, default = False
    Select whether or not to draw the game in pygame.
--double: boolean, optional, default = False
    Use a target network with double DQN logic.
--dueling: boolean, optional, default = False
    Whether to use dueling network logic, Q(s,a) = A + V.
--per: boolean, optional, default = False
    Use Prioritized Experience Replay (based on Sum Trees).
--local_state: boolean, optional, default = True
    Verify if the possible next moves are dangerous (field expertise).
"""

import numpy as np
from array import array
from os import path, environ, sys
import random
import inspect

# Making relative imports from parallel folders possible: find the folder of
# the file this code is running from and put its parent folder at the front
# of sys.path.
currentdir = path.dirname(path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = path.dirname(currentdir)
sys.path.insert(0, parentdir)

from keras.optimizers import RMSprop, Nadam
from keras.models import load_model
from keras import backend as K

# Module metadata
__author__ = "Victor Neves"
__license__ = "MIT"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"
__status__ = "Production"

# Setting keras ordering, once, at import time.
# NOTE(review): 'th' looks like the legacy Theano-style (channels first)
# ordering, and this backend call may not exist in recent Keras releases -
# confirm against the Keras version this project is pinned to.
K.set_image_dim_ordering('th')  # Setting keras ordering

class Agent:
    """Agent based in a simple DQN that can read states, remember and play.

    Attributes
    ----------
    memory: object
        Memory used in training. PrioritizedExperienceReplay when `per` is
        True, ExperienceReplay otherwise.
    model: keras model
        The input model in Keras. Used to select actions and trained on batches.
    target: keras model, optional, default = None
        The target model, used to calculate the fixed Q-targets. When None, no
        target update is performed.
    nb_frames: int, optional, default = 4
        Amount of frames stacked to build each state.
    board_size: int, optional, default = 10
        Size of the board used.
    frames: list of frames or None
        Rolling buffer with the last nb_frames game states. None right after
        clear_frames, until get_game_data is called again.
    per: boolean, optional, default = False
        Flag for PER usage.
    update_target_freq: int or float, default = 0.001
        If >= 1, the target is hard updated every `update_target_freq` epochs.
        If < 1, it is the rate of the soft update done after every epoch.
    n_steps: int
        Size used for Multi-step returns. Only set by the first call to train;
        train_model falls back to 1 while it is not set.
    """
    def __init__(self, model, target = None, memory_size = -1, nb_frames = 4,
                 board_size = 10, per = False, update_target_freq = 0.001):
        """Initialize the agent with given attributes."""
        if per:
            self.memory = PrioritizedExperienceReplay(memory_size = memory_size)
        else:
            self.memory = ExperienceReplay(memory_size = memory_size)

        self.per = per
        self.model = model
        self.target = target
        self.nb_frames = nb_frames
        self.board_size = board_size
        self.update_target_freq = update_target_freq
        self.clear_frames()

    def reset_memory(self):
        """Reset memory if necessary."""
        self.memory.reset_memory()

    def get_game_data(self, game):
        """Keep a rolling buffer of nb_frames frames and return it stacked.

        On the first call after clear_frames the buffer is filled with copies
        of the current frame; afterwards the newest frame is appended and the
        oldest one is dropped.

        Return
        ----------
        expanded_frames: array
            The buffer of frames with a leading axis added by np.expand_dims,
            i.e. one sample holding nb_frames frames.
        """
        frame = game.state()

        if self.frames is None:
            self.frames = [frame] * self.nb_frames
        else:
            self.frames.append(frame)
            self.frames.pop(0)

        expanded_frames = np.expand_dims(self.frames, 0)

        return expanded_frames

    def clear_frames(self):
        """Reset frames to restart appending."""
        self.frames = None

    def update_target_model_hard(self):
        """Update the target model with the main model's weights."""
        self.target.set_weights(self.model.get_weights())

    def transfer_weights(self):
        """Transfer Weights from Model to Target at rate update_target_freq.

        Each target weight becomes
        rate * model_weight + (1 - rate) * target_weight.
        """
        rate = self.update_target_freq
        model_weights = self.model.get_weights()
        target_weights = self.target.get_weights()

        # The original loop referenced the undefined names `W` and
        # `self.update_target_frequency`, so it could never run.
        blended = [rate * model_w + (1 - rate) * target_w
                   for model_w, target_w in zip(model_weights, target_weights)]

        self.target.set_weights(blended)

    def print_metrics(self, epoch, nb_epoch, history_size, policy, value,
                      win_count, history_step, history_reward,
                      history_loss = None, verbose = 1):
        """Function to print metrics of training steps.

        Arguments
        ----------
        epoch, nb_epoch: int
            Current (0-based) epoch and total amount of epochs.
        history_size, history_step, history_reward: sequences
            Per-episode snake size, amount of steps and total reward.
        policy: string
            Name of the exploration policy, selects the label of `value`.
        value: float
            Current value of the policy (temperature, actions or epsilon).
        win_count: int
            Amount of games won so far.
        history_loss: sequence, optional, default = None
            Per-episode training loss. Training metrics are skipped if None.
        verbose: int, optional, default = 1
            0 prints nothing, 1 prints a one line summary, anything else prints
            the detailed training, game and policy metrics.
        """
        if verbose == 0:
            return

        win_rate = 100 * win_count/(epoch + 1)

        if verbose == 1:
            text_epoch = ('Epoch: {:03d}/{:03d} | Mean size 10: {:.1f} | '
                          + 'Longest 10: {:03d} | Mean steps 10: {:.1f} | '
                          + 'Wins: {:d} | Win percentage: {:.1f}%')
            print(text_epoch.format(epoch + 1, nb_epoch,
                                    np.mean(history_size[-10:]),
                                    max(history_size[-10:]),
                                    np.mean(history_step[-10:]),
                                    win_count, win_rate))
            return

        text_epoch = 'Epoch: {:03d}/{:03d}'  # Print epoch info
        print(text_epoch.format(epoch + 1, nb_epoch))

        # train() stores the steps in an array('f'), so the values are floats
        # and must be converted before being formatted with {:d}.
        last_size = int(history_size[-1])
        last_steps = int(history_step[-1])
        safe_steps = max(last_steps, 1)  # Avoid dividing by zero

        # Print training performance
        if history_loss is not None and len(history_loss) > 0:
            text_train = ('\t\x1b[0;30;47m' + ' Training metrics ' + '\x1b[0m'
                          + '\tTotal loss: {:.4f} | Loss per step: {:.4f} | '
                          + 'Mean loss - 100 episodes: {:.4f}')
            print(text_train.format(history_loss[-1],
                                    history_loss[-1] / safe_steps,
                                    np.mean(history_loss[-100:])))

        # NOTE(review): the value shown as 'Steps per food eaten' is size/steps,
        # kept as in the original - confirm which ratio is actually wanted.
        text_game = ('\t\x1b[0;30;47m' + ' Game metrics ' + '\x1b[0m'
                     + '\t\tSize: {:d} | Ammount of steps: {:d} | '
                     + 'Steps per food eaten: {:.1f} | '
                     + 'Mean size - 100 episodes: {:.1f}')
        print(text_game.format(last_size, last_steps,
                               last_size / safe_steps,
                               np.mean(history_size[-100:])))

        # Print policy metrics, only the label of `value` depends on the policy
        if policy == "BoltzmannQPolicy":
            policy_label = 'Boltzmann Temperature: {:.2f}'
        elif policy == "BoltzmannGumbelQPolicy":
            policy_label = 'Number of actions: {:.0f}'
        else:
            policy_label = 'Epsilon: {:.2f}'

        text_policy = ('\t\x1b[0;30;47m' + ' Policy metrics ' + '\x1b[0m'
                       + '\t' + policy_label + ' | Episode reward: {:.1f} | '
                       + 'Wins: {:d} | Win percentage: {:.1f}%')
        print(text_policy.format(value, history_reward[-1], win_count,
                                 win_rate))

    def train_model(self, model, target, batch_size, gamma, nb_actions, epoch = 0):
        """Function to train the model on a batch of the data.

        The `model`, `target` and `epoch` arguments are accepted for
        compatibility with existing callers; the batch is always built and
        trained with self.model and self.target.

        Return
        ----------
        loss: float
            Training loss of given batch, 0. if the memory returned no batch.
        """
        loss = 0.
        # n_steps only exists after train() was called, fall back to 1-step
        n_steps = getattr(self, 'n_steps', 1)
        batch = self.memory.get_targets(model = self.model,
                                        target = self.target,
                                        batch_size = batch_size,
                                        gamma = gamma,
                                        nb_actions = nb_actions,
                                        n_steps = n_steps)

        if batch:
            inputs, targets, IS_weights = batch

            if inputs is not None and targets is not None:
                # IS_weights goes in as the third positional argument
                # (sample_weight in the Keras train_on_batch API).
                loss = float(self.model.train_on_batch(inputs,
                                                       targets,
                                                       IS_weights))

        return loss

    def train(self, game, nb_epoch = 10000, batch_size = 64, gamma = 0.95,
              eps = [1., .01], temp = [1., 0.01], learning_rate = 0.5,
              observe = 0, optim_rounds = 1, policy = "EpsGreedyQPolicy",
              verbose = 1, n_steps = 1):
        """The main training function, loops the game, remember and choose best
        action given game state (frames).

        Arguments
        ----------
        game: object
            The game environment to be played.
        nb_epoch: int, optional, default = 10000
            Amount of games played in the exploration turn.
        batch_size, gamma: int and float
            Forwarded to the memory when building the training targets.
        eps, temp: sequences of 2 * float
            First and second arguments handed to the selected policy.
        learning_rate: float, optional, default = 0.5
            Fraction of nb_epoch handed to the policy as third argument.
        observe: int, optional, default = 0
            Amount of epochs played before training starts.
        optim_rounds: int, optional, default = 1
            Turns of nb_epoch; after the first only batch optimization happens.
        policy: string, optional, default = "EpsGreedyQPolicy"
            Name of the exploration policy.
        verbose: int, optional, default = 1
            Verbosity forwarded to print_metrics.
        n_steps: int, optional, default = 1
            Size of the rewards buffer, to use Multi-step returns.
        """
        if not hasattr(self, 'n_steps'):
            self.n_steps = n_steps  # Set attribute only once

        history_size = array('i')  # Holds all the sizes
        history_step = array('f')  # Holds all the steps
        history_loss = array('f')  # Holds all the losses
        history_reward = array('f')  # Holds all the rewards

        # Select exploration policy. EpsGreedyQPolicy runs faster, but takes
        # longer to converge. BoltzmannGumbelQPolicy is the slowest, but
        # converge really fast (0.1 * nb_epoch used in EpsGreedyQPolicy).
        # BoltzmannQPolicy is in the middle.
        if policy == "BoltzmannQPolicy":
            q_policy = BoltzmannQPolicy(temp[0], temp[1], nb_epoch * learning_rate)
        elif policy == "BoltzmannGumbelQPolicy":
            q_policy = BoltzmannGumbelQPolicy()
        else:
            q_policy = EpsGreedyQPolicy(eps[0], eps[1], nb_epoch * learning_rate)

        nb_actions = game.nb_actions
        win_count = 0

        # If optim_rounds is bigger than one, the model will keep optimizing
        # after the exploration, in turns of nb_epoch size.
        for turn in range(optim_rounds):
            if turn > 0:
                for epoch in range(nb_epoch):
                    loss = self.train_model(model = self.model,
                                            epoch = epoch,
                                            target = self.target,
                                            batch_size = batch_size,
                                            gamma = gamma,
                                            nb_actions = nb_actions)
                text_optim = ('Optimizer turn: {:2d} | Epoch: {:03d}/{:03d}'
                              + '| Loss: {:.4f}')
                print(text_optim.format(turn, epoch + 1, nb_epoch, loss))
            else:  # Exploration and training
                for epoch in range(nb_epoch):
                    loss = 0.
                    total_reward = 0.
                    game.reset_game()
                    self.clear_frames()
                    S = self.get_game_data(game)

                    if n_steps > 1:  # Create multi-step returns buffer.
                        n_step_buffer = array('f')

                    while not game.game_over:  # Main loop, until game_over
                        game.food_pos = game.generate_food()
                        action, value = q_policy.select_action(self.model,
                                                               S, epoch,
                                                               nb_actions)
                        game.play(action)
                        r = game.get_reward()
                        total_reward += r

                        if n_steps > 1:
                            n_step_buffer.append(r)

                            if len(n_step_buffer) < n_steps:
                                R = r
                            else:
                                # NOTE(review): the buffer is never trimmed, so
                                # this always discounts the FIRST n_steps
                                # rewards of the episode - confirm the intended
                                # multi-step return before relying on it.
                                R = sum([n_step_buffer[i] * (gamma ** i)
                                         for i in range(n_steps)])
                        else:
                            R = r

                        S_prime = self.get_game_data(game)
                        experience = [S, action, R, S_prime, game.game_over]
                        self.memory.remember(*experience)  # Add to the memory
                        S = S_prime  # Advance to the next state (stack of S)

                        if epoch >= observe:  # Get the batchs and train
                            loss += self.train_model(model = self.model,
                                                     target = self.target,
                                                     batch_size = batch_size,
                                                     gamma = gamma,
                                                     nb_actions = nb_actions)

                    if game.is_won():
                        win_count += 1  # Counter of wins for metrics

                    if self.per:  # Advance beta, used in PER
                        self.memory.beta = self.memory.schedule.value(epoch)

                    if self.target is not None:  # Update the target model
                        # The original read a bare `update_target_freq`, which
                        # is not defined in this method.
                        if self.update_target_freq >= 1:  # Hard updates
                            if epoch % self.update_target_freq == 0:
                                self.update_target_model_hard()
                        else:  # Soft updates
                            self.transfer_weights()

                    history_size.append(game.snake.length)
                    history_step.append(game.step)
                    history_loss.append(loss)
                    history_reward.append(total_reward)

                    if (epoch + 1) % 10 == 0:
                        self.print_metrics(epoch = epoch, nb_epoch = nb_epoch,
                                           history_size = history_size,
                                           history_loss = history_loss,
                                           history_step = history_step,
                                           history_reward = history_reward,
                                           policy = policy, value = value,
                                           win_count = win_count,
                                           verbose = verbose)

    def play(self, game, nb_epoch = 1000, eps = 0.01, temp = 0.01,
             visual = False, policy = "GreedyQPolicy"):
        """Play the game with the trained agent. Can use the visual tag to draw
        in pygame.

        Arguments
        ----------
        game: object
            The game environment to be played.
        nb_epoch: int, optional, default = 1000
            Amount of games to play.
        eps, temp: float
            Constant epsilon/temperature used by the matching policy.
        visual: boolean, optional, default = False
            Whether to create a window and draw every step.
        policy: string, optional, default = "GreedyQPolicy"
            Name of the policy used to select actions.
        """
        win_count = 0
        # Taken from the game, as train() does, instead of relying on a
        # notebook-level global named `nb_actions`.
        nb_actions = game.nb_actions

        history_size = array('i')  # Holds all the sizes
        history_step = array('f')  # Holds all the steps
        history_reward = array('f')  # Holds all the rewards

        if policy == "BoltzmannQPolicy":
            q_policy = BoltzmannQPolicy(temp, temp, nb_epoch)
        elif policy == "EpsGreedyQPolicy":
            q_policy = EpsGreedyQPolicy(eps, eps, nb_epoch)
        else:
            q_policy = GreedyQPolicy()

        for epoch in range(nb_epoch):
            game.reset_game()
            self.clear_frames()
            S = self.get_game_data(game)

            if visual:
                # Centering the window; set before the window exists so the
                # video backend can take it into account.
                environ['SDL_VIDEO_CENTERED'] = '1'
                game.create_window()
                previous_size = game.snake.length  # Initial size of the snake
                color_list = game.gradient([(42, 42, 42), (152, 152, 152)],
                                           previous_size)

            while not game.game_over:
                action, value = q_policy.select_action(self.model, S, epoch,
                                                       nb_actions)
                game.play(action)
                current_size = game.snake.length  # Update the body size

                if visual:
                    game.draw(color_list)

                    if current_size > previous_size:
                        color_list = game.gradient([(42, 42, 42), (152, 152, 152)],
                                                   game.snake.length)

                        previous_size = current_size

                S = self.get_game_data(game)

                if game.game_over:
                    # Only the reward of the last step is recorded here
                    history_size.append(current_size)
                    history_step.append(game.step)
                    history_reward.append(game.get_reward())

            if game.is_won():
                win_count += 1

        print("Accuracy: {} %".format(100. * win_count / nb_epoch))
        print("Mean size: {} | Biggest size: {} | Smallest size: {}"
              .format(np.mean(history_size), np.max(history_size),
                      np.min(history_size)))
        print("Mean steps: {} | Biggest step: {} | Smallest step: {}"
              .format(np.mean(history_step), np.max(history_step),
                      np.min(history_step)))
        print("Mean rewards: {} | Biggest reward: {} | Smallest reward: {}"
              .format(np.mean(history_reward), np.max(history_reward),
                      np.min(history_reward)))
        
# Board side and amount of stacked frames, shared by game, model and agent
board_size, nb_frames = 10, 4

# Environment played by the agent
game = Game(
    player="ROBOT",
    board_size=board_size,
    local_state=True,
    relative_pos=False,
)

# Network that receives the stacked frames and outputs one value per action
model = create_model(
    optimizer=RMSprop(),
    loss=clipped_error,
    stack=nb_frames,
    input_size=board_size,
    output_size=game.nb_actions,
    dueling=False,
    cnn="CNN3",
)

# No target network in this run
target = None

agent = Agent(
    model=model,
    target=target,
    memory_size=3000000,
    nb_frames=nb_frames,
    board_size=board_size,
    per=True,
    update_target_freq=0.001,
)

#%lprun -f agent.train agent.train(game, batch_size = 64, nb_epoch = 10, gamma = 0.95, update_target_freq = 500, policy = "EpsGreedyQPolicy")
agent.train(
    game,
    batch_size=64,
    nb_epoch=10000,
    gamma=0.95,
    policy="EpsGreedyQPolicy",
)

In [0]:
# Persist the model trained in the previous cell (`model` is not created here)
model.save('keras.h5')

# Zip the saved file and hand it to the browser; Colab-only helpers
!zip -r model-epsgreedy-per.zip keras.h5 
from google.colab import files
files.download('model-epsgreedy-per.zip')
# Reload from disk; the custom loss has to be passed in by name
model = load_model('keras.h5', custom_objects={'clipped_error': clipped_error})

board_size = 10
nb_frames = 4
# Module-level action count; Agent.play may look this name up as a global, so
# keep it defined before calling play()
nb_actions = 5

target = None

# Fresh agent wrapping the reloaded model
agent = Agent(model = model, target = target, memory_size = 1500000,
                          nb_frames = nb_frames, board_size = board_size,
                          per = True)
#%lprun -f agent.train agent.train(game, batch_size = 64, nb_epoch = 10, gamma = 0.95, update_target_freq = 500, policy = "EpsGreedyQPolicy")

# `game` is not defined in this cell, it has to exist from an earlier cell
agent.play(game, visual = False, nb_epoch = 10000)

In [0]:
#!/usr/bin/env python

"""SnakeGame: A simple and fun exploration, meant to be used by Human and AI.
"""

import sys # To close the window when the game is over
from os import environ, path # To center the game window the best possible
import random # Random numbers used for the food
import logging # Logging function for movements and errors
from itertools import tee # For the color gradient on snake
import numpy as np

__author__ = "Victor Neves"
__license__ = "MIT"
__version__ = "1.0"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"
__status__ = "Production"

# Menu options
options = {
    'QUIT': 0,
    'PLAY': 1,
    'BENCHMARK': 2,
    'LEADERBOARDS': 3,
    'MENU': 4,
    'ADD_LEADERBOARDS': 5,
}

# Actions relative to where the snake is heading
relative_actions = {
    'LEFT': 0,
    'FORWARD': 1,
    'RIGHT': 2,
}

# Absolute actions
actions = {
    'LEFT': 0,
    'RIGHT': 1,
    'UP': 2,
    'DOWN': 3,
    'IDLE': 4,
}

# (action, previous_action) pairs that Snake.move refuses: LEFT after RIGHT,
# RIGHT after LEFT, UP after DOWN and DOWN after UP
forbidden_moves = [
    (0, 1),
    (1, 0),
    (2, 3),
    (3, 2),
]

# Possible rewards in the game
rewards = {
    'MOVE': -0.005,
    'GAME_OVER': -1,
    'SCORED': 1,
}

# Types of point in the board
point_type = {
    'EMPTY': 0,
    'FOOD': 1,
    'BODY': 2,
    'HEAD': 3,
    'DANGEROUS': 4,
}

# Speed levels possible to human players
levels = [
    " EASY ",
    " MEDIUM ",
    " HARD ",
    " MEGA HARDCORE ",
]

class GlobalVariables:
    """Settings shared by the code that draws and moves the snake game.

    Attributes
    ----------
    BOARD_SIZE: int, optional, default = 30
        Side of the board.
    BLOCK_SIZE: int, optional, default = 20
        Side of one block, in pixels.
    HEAD_COLOR: tuple of 3 * int, optional, default = (42, 42, 42)
        Head color, where the body color gradient starts.
    TAIL_COLOR: tuple of 3 * int, optional, default = (152, 152, 152)
        Tail color, where the body color gradient ends.
    FOOD_COLOR: tuple of 3 * int, optional, default = (200, 0, 0)
        Food color.
    GAME_SPEED: int, optional, default = 10
        Game speed, in ticks.
    BENCHMARK: int, optional, default = 10
        Amount of matches played in a benchmark.
    """
    def __init__(self, BOARD_SIZE = 30, BLOCK_SIZE = 20,
                 HEAD_COLOR = (42, 42, 42), TAIL_COLOR = (152, 152, 152),
                 FOOD_COLOR = (200, 0, 0), GAME_SPEED = 10, BENCHMARK = 10):
        """Store every setting on the instance and warn about big boards."""
        # Geometry
        self.BOARD_SIZE, self.BLOCK_SIZE = BOARD_SIZE, BLOCK_SIZE
        # Palette
        self.HEAD_COLOR, self.TAIL_COLOR = HEAD_COLOR, TAIL_COLOR
        self.FOOD_COLOR = FOOD_COLOR
        # Pace and benchmark length
        self.GAME_SPEED, self.BENCHMARK = GAME_SPEED, BENCHMARK

        # Boards above 50 blocks only trigger a performance warning
        board_is_big = self.BOARD_SIZE > 50
        if board_is_big:
            logger.warning('WARNING: BOARD IS TOO BIG, IT MAY RUN SLOWER.')

class TextBlock:
    """Piece of text drawn with pygame, usable as plain text or menu option.

    Attributes:
    ----------
    text: string
        The text to be displayed.
    pos: tuple of 2 * int
        Point used as the center of the rendered rectangle.
    screen: pygame window object
        The screen where the text is blitted.
    scale: int, optional, default = 1 / 12
        Factor applied to BOARD_SIZE * BLOCK_SIZE to get the font size.
    type: string, optional, default = "text"
        Either "text" or "menu"; only "menu" blocks react to `hovered`.
    hovered: boolean
        Hover flag, starts as False.
    """
    def __init__(self, text, pos, screen, scale = (1 / 12), type = "text"):
        """Store the attributes, place the rectangle and draw the block."""
        self.text = text
        self.pos = pos
        self.screen = screen
        self.scale = scale
        self.type = type
        self.hovered = False
        self.set_rect()
        self.draw()

    def draw(self):
        """Render the text again and blit it on the screen."""
        self.set_rend()
        self.screen.blit(self.rend, self.rect)

    def set_rend(self):
        """Render the text with the current font size and colors."""
        font_path = resource_path("resources/fonts/freesansbold.ttf")
        font_size = int((var.BOARD_SIZE * var.BLOCK_SIZE) * self.scale)
        font = pygame.font.Font(font_path, font_size)
        self.rend = font.render(self.text, True, self.get_color(),
                                self.get_background())

    def get_color(self):
        """Get the color of the text.

        Return
        ----------
        color: pygame Color
            (152, 152, 152) for a menu option that is not hovered,
            (42, 42, 42) in every other case.
        """
        idle_menu_option = self.type == "menu" and not self.hovered

        if idle_menu_option:
            return pygame.Color(152, 152, 152)

        return pygame.Color(42, 42, 42)

    def get_background(self):
        """Get the background color of the text.

        Return
        ----------
        color: pygame Color or None
            (152, 152, 152) for a hovered menu option, None otherwise.
        """
        if self.type == "menu" and self.hovered:
            return pygame.Color(152, 152, 152)

        return None

    def set_rect(self):
        """Render the text and center its rectangle on `pos`."""
        self.set_rend()
        rect = self.rend.get_rect()
        rect.center = self.pos
        self.rect = rect


class Snake:
    """The player: a head, a body and the last action taken.

    The body is a list of positions with the head at index [0]. Moving inserts
    the new head position at the front and drops the last part, unless food was
    eaten.

    Attributes
    ----------
    head: list of 2 * int, default = [BOARD_SIZE / 4, BOARD_SIZE / 4]
        Position of the head, derived from the board size.
    body: list of lists of 2 * int
        Starts with 3 parts, the head and two parts to its left.
    previous_action: int, default = 1
        Last action which the snake took.
    length: int, default = 3
        Length of the snake, refreshed when food is eaten.
    """
    def __init__(self):
        """Create a snake with 3 body parts and previous_action = 1."""
        start_x = int(var.BOARD_SIZE / 4)
        start_y = int(var.BOARD_SIZE / 4)
        self.head = [start_x, start_y]
        # Head first, the other two parts follow towards smaller x
        self.body = [[start_x - offset, start_y] for offset in range(3)]
        self.previous_action = 1
        self.length = 3

    def move(self, action, food_pos):
        """Move the head 1 block and update the body.

        IDLE and forbidden (action, previous_action) pairs repeat the previous
        action; any other action becomes the new previous action.

        Return
        ----------
        ate_food: boolean
            True if the head landed on food_pos, in which case the body is not
            popped. False otherwise.
        """
        is_idle = action == actions['IDLE']

        if is_idle or (action, self.previous_action) in forbidden_moves:
            action = self.previous_action
        else:
            self.previous_action = action

        # (action name, head axis, step) checked in the same order as before
        displacements = (('LEFT', 0, -1),
                         ('RIGHT', 0, 1),
                         ('UP', 1, -1),
                         ('DOWN', 1, 1))

        for name, axis, step in displacements:
            if action == actions[name]:
                self.head[axis] += step
                break

        self.body.insert(0, list(self.head))

        if not (self.head == food_pos):
            self.body.pop()  # Nothing eaten, the length stays the same
            return False

        logger.info('EVENT: FOOD EATEN')
        self.length = len(self.body)

        return True


class FoodGenerator:
    """Create food and remember where it is.

    Attributes
    ----------
    pos:
        Current position of food.
    is_food_on_screen:
        Flag for existence of food.
    """
    def __init__(self, body):
        """Clear the existence flag and generate the first food piece."""
        self.is_food_on_screen = False
        self.pos = self.generate_food(body)

    def generate_food(self, body):
        """Generate food outside of the body, if there is none on screen.

        Return
        ----------
        pos: list of 2 * int
            Position of the food. The stored one if food already exists,
            otherwise a freshly drawn one that is not in the body.
        """
        if self.is_food_on_screen:
            return self.pos

        # NOTE(review): random.random() < 1, so each coordinate tops out at
        # BOARD_SIZE - 2 - confirm the last row/column is excluded on purpose.
        limit = var.BOARD_SIZE - 1
        candidate = [int(limit * random.random()) for _ in range(2)]

        # Draw again until the position is free
        while candidate in body:
            candidate = [int(limit * random.random()) for _ in range(2)]

        self.pos = candidate
        logger.info('EVENT: FOOD APPEARED')
        self.is_food_on_screen = True

        return self.pos

      
import numpy as np

class SumTree:
    """Binary tree of priorities stored in a flat array.

    The last `capacity` entries of the tree are the leaves, one per data slot;
    every other node is kept equal to the sum of the nodes below it, so the
    root holds the total priority. Data slots are reused in a circular way.
    """
    def __init__(self, capacity):
        """Allocate a zeroed tree and the data slots for `capacity` items."""
        self._capacity = capacity
        self._data_idx = 0  # Next data slot to be written
        self._data = np.zeros(self._capacity, dtype = object)
        self._tree = np.zeros(2 * self._capacity - 1)

    @property
    def capacity(self):
        """Amount of data slots."""
        return self._capacity

    @property
    def tree(self):
        """The flat array with every node of the tree."""
        return self._tree

    @property
    def data(self):
        """The array with the stored data."""
        return self._data

    def sum(self):
        """Return the value of the root, the total priority."""
        return self._tree[0]

    def insert(self, data, priority):
        """Store data in the next slot, set its priority and advance."""
        slot = self._data_idx
        self._data[slot] = data
        self.update(slot + self._capacity - 1, priority)
        # Wrap around when the last slot was just written
        self._data_idx = (slot + 1) % self._capacity

    def update(self, tree_idx, priority):
        """Set the priority of a node and add the change to its ancestors."""
        delta = priority - self._tree[tree_idx]
        self._tree[tree_idx] = priority
        node = tree_idx

        while node != 0:
            node = (node - 1) // 2  # Get parent
            self._tree[node] += delta

    def retrieve(self, val):
        """Walk down from the root to the leaf selected by `val`.

        Go left while val fits in the left child, otherwise subtract the left
        child and go right.

        Return
        ----------
        tree_idx, priority, data:
            Index of the leaf in the tree, its priority and its stored data.
        """
        size = len(self._tree)
        node = 0

        while 2 * node + 1 < size:  # Stop when the node has no children
            left = 2 * node + 1

            if val <= self._tree[left]:
                node = left
            else:
                val -= self._tree[left]
                node = left + 1

        data_idx = node - self._capacity + 1

        return node, self._tree[node], self._data[data_idx]

    def max_leaf(self):
        """Return the biggest priority among the leaves."""
        return np.max(self.leaves())

    def min_leaf(self):
        """Return the smallest priority among the leaves."""
        return np.min(self.leaves())

    def leaves(self):
        """Return the last `capacity` entries of the tree, the leaves."""
        first_leaf = len(self._tree) - self._capacity
        return self._tree[first_leaf:]

class Game:
    """Hold the game window and functions.

    Attributes
    ----------
    window: pygame display
        Pygame window to show the game.
    fps: pygame time clock
        Define Clock and ticks in which the game will be displayed.
    snake: object
        The actual snake who is going to be played.
    food_generator: object
        Generator of food which responds to the snake.
    food_pos: tuple of 2 * int
        Position of the food on the board.
    game_over: boolean
        Flag for game_over.
    player: string
        Define if human or robots are playing the game.
    board_size: int, optional, default = 30
        The size of the board.
    local_state: boolean, optional, default = False
        Whether to use or not game expertise (used mostly by robots players).
    relative_pos: boolean, optional, default = False
        Whether to use or not relative position of the snake head. Instead of
        actions, use relative_actions.
    screen_rect: tuple of 2 * int
        The screen rectangle, used to draw relatively positioned blocks.
    """
    def __init__(self, player, board_size = 30, local_state = False, relative_pos = False):
        """Initialize window, fps and score. Change nb_actions if relative_pos"""
        var.BOARD_SIZE = board_size
        self.local_state = local_state
        self.relative_pos = relative_pos
        self.player = player

        if player == "ROBOT":
            if self.relative_pos:
                self.nb_actions = 3
            else:
                self.nb_actions = 5

            self.reset_game()

    def reset_game(self):
        """Reset the game environment."""
        self.step = 0
        self.snake = Snake()
        self.food_generator = FoodGenerator(self.snake.body)
        self.food_pos = self.food_generator.pos
        self.scored = False
        self.game_over = False

    def create_window(self):
        """Create a pygame display with BOARD_SIZE * BLOCK_SIZE dimension."""
        pygame.init()

        flags = pygame.DOUBLEBUF
        self.window = pygame.display.set_mode((var.BOARD_SIZE * var.BLOCK_SIZE,\
                                               var.BOARD_SIZE * var.BLOCK_SIZE),
                                               flags)
        self.window.set_alpha(None)
        self.screen_rect = self.window.get_rect()
        self.fps = pygame.time.Clock()

    def menu(self):
        """Show the main menu and wait until one of its entries is clicked.

        An entry is selected when the mouse is over it and a MOUSEBUTTONUP
        event arrived during the same frame.

        Return
        ----------
        selected_option: int
            The selected option in the main loop.
        """
        pygame.display.set_caption("SNAKE GAME  | PLAY NOW!")

        side = var.BOARD_SIZE * var.BLOCK_SIZE
        logo = pygame.image.load(resource_path("resources/images/snake_logo.png"))
        logo = pygame.transform.scale(logo, (side, int(side / 3)))
        logo_rect = logo.get_rect()
        logo_rect.center = self.screen_rect.center

        # (label, vertical position in tenths of centery, key into options)
        entries = ((' PLAY GAME ', 4, 'PLAY'),
                   (' BENCHMARK ', 6, 'BENCHMARK'),
                   (' LEADERBOARDS ', 8, 'LEADERBOARDS'),
                   (' QUIT ', 10, 'QUIT'))
        buttons = [TextBlock(label, (self.screen_rect.centerx,
                                     tenths * self.screen_rect.centery / 10),
                             self.window, (1 / 12), "menu")
                   for label, tenths, _ in entries]
        selected_option = None

        while selected_option is None:
            pygame.event.pump()
            events = pygame.event.get()

            self.window.fill(pygame.Color(225, 225, 225))

            for button, (_, _, key) in zip(buttons, entries):
                button.draw()

                if not button.rect.collidepoint(pygame.mouse.get_pos()):
                    button.hovered = False
                    continue

                button.hovered = True

                # The option is only looked up once a click is seen
                if any(event.type == pygame.MOUSEBUTTONUP for event in events):
                    selected_option = options[key]

            self.window.blit(logo, logo_rect.bottomleft)
            pygame.display.update()

        return selected_option

    def start_match(self):
        """Show a 3, 2, 1 countdown, one second per number, then log the start."""
        for remaining in ("3", "2", "1"):
            self.window.fill(pygame.Color(225, 225, 225))

            # Game starts in 3, 2, 1
            heading = TextBlock('Game starts in',
                                (self.screen_rect.centerx,
                                 4 * self.screen_rect.centery / 10),
                                self.window, (1 / 10), "text")
            counter = TextBlock(remaining,
                                (self.screen_rect.centerx,
                                 12 * self.screen_rect.centery / 10),
                                self.window, (1 / 1.5), "text")
            heading.draw()
            counter.draw()

            pygame.display.update()
            pygame.display.set_caption("SNAKE GAME  |  Game starts in "
                                       + remaining + " second(s) ...")

            pygame.time.wait(1000)

        logger.info('EVENT: GAME START')

    def start(self):
        """Use menu to select the option/game mode."""
        opt = self.menu()
        running = True

        while running:
            if opt == options['QUIT']:
                pygame.quit()
                sys.exit()
            elif opt == options['PLAY']:
                var.GAME_SPEED = self.select_speed()
                self.reset_game()
                self.start_match()
                score = self.single_player()
                opt = self.over(score)
            elif opt == options['BENCHMARK']:
                var.GAME_SPEED = self.select_speed()
                score = []

                for i in range(var.BENCHMARK):
                    self.reset_game()
                    self.start_match()
                    score.append(self.single_player())

                opt = self.over(score)
            elif opt == options['LEADERBOARDS']:
                pass
            elif opt == options['ADD_LEADERBOARDS']:
                pass
            elif opt == options['MENU']:
                opt = self.menu()

    def over(self, score):
        """Show the game-over screen and wait for a click on one of its entries.

        score is either an int (single match) or a collection of scores; in
        the second case the mean over var.BENCHMARK is shown and an extra
        ' ADD TO LEADERBOARDS ' entry is offered. Clicking ' QUIT ' closes
        pygame and exits the interpreter instead of returning.

        Return
        ----------
        selected_option: int
            The selected option in the main loop.
        """
        def entry(label, tenths):
            """Build a menu entry placed at tenths / 10 of centery."""
            return TextBlock(label, (self.screen_rect.centerx,
                                     tenths * self.screen_rect.centery / 10),
                             self.window, (1 / 15), "menu")

        play_again = entry(' PLAY AGAIN ', 4)
        to_menu = entry(' GO TO MENU ', 6)
        quit_game = entry(' QUIT ', 10)
        add_score = None

        if isinstance(score, int):
            text_score = 'SCORE: ' + str(score)
        else:
            text_score = 'MEAN SCORE: ' + str(sum(score) / var.BENCHMARK)
            add_score = entry(' ADD TO LEADERBOARDS ', 8)

        pygame.display.set_caption("SNAKE GAME  | " + text_score
                                   + "  |  GAME OVER...")
        logger.info('EVENT: GAME OVER | FINAL ' + text_score)
        score_label = TextBlock(text_score, (self.screen_rect.centerx,
                                             15 * self.screen_rect.centery / 10),
                                self.window, (1 / 10), "text")

        # Drawing order of the screen. A None block is an absent entry and a
        # None key is a block that selects nothing when clicked.
        blocks = ((play_again, 'PLAY'),
                  (to_menu, 'MENU'),
                  (add_score, 'ADD_LEADERBOARDS'),
                  (quit_game, 'QUIT'),
                  (score_label, None))
        selected_option = None

        while selected_option is None:
            pygame.event.pump()
            events = pygame.event.get()

            # Game over screen
            self.window.fill(pygame.Color(225, 225, 225))

            for block, key in blocks:
                if block is None:
                    continue

                block.draw()

                if not block.rect.collidepoint(pygame.mouse.get_pos()):
                    block.hovered = False
                    continue

                block.hovered = True

                if key is None:
                    continue

                if not any(event.type == pygame.MOUSEBUTTONUP
                           for event in events):
                    continue

                if block is quit_game:
                    pygame.quit()
                    sys.exit()

                # The option is only looked up once a click is seen
                selected_option = options[key]

            pygame.display.update()

        return selected_option

    def select_speed(self):
        """Show the speed menu and wait until one of the levels is clicked.

        The four entries are labelled with levels[0] to levels[3] and map to
        the speeds 10, 20, 30 and 45 respectively.

        Return
        ----------
        speed: int
            The selected speed in the main loop.
        """
        # (label, vertical position in tenths of centery, resulting speed)
        entries = ((levels[0], 4, 10),
                   (levels[1], 8, 20),
                   (levels[2], 12, 30),
                   (levels[3], 16, 45))
        buttons = [TextBlock(label, (self.screen_rect.centerx,
                                     tenths * self.screen_rect.centery / 10),
                             self.window, (1 / 10), "menu")
                   for label, tenths, _ in entries]
        speed = None

        while speed is None:
            pygame.event.pump()
            events = pygame.event.get()

            # Speed selection screen
            self.window.fill(pygame.Color(225, 225, 225))

            for button, (_, _, level_speed) in zip(buttons, entries):
                button.draw()

                if not button.rect.collidepoint(pygame.mouse.get_pos()):
                    button.hovered = False
                    continue

                button.hovered = True

                if any(event.type == pygame.MOUSEBUTTONUP for event in events):
                    speed = level_speed

            pygame.display.update()

        return speed

    def single_player(self):
        """Game loop for single_player (HUMANS).

        Return
        ----------
        score: int
            The final score for the match (discounted of initial length).
        """
        # The main loop, it pump key_presses and update the board every tick.
        previous_size = self.snake.length # Initial size of the snake
        current_size = previous_size # Initial size
        color_list = self.gradient([(42, 42, 42), (152, 152, 152)],\
                                   previous_size)

        # Main loop, where the snake keeps going each tick. It generate food,
        # check collisions and draw.
        while not self.game_over:
            action = self.handle_input()
            self.game_over = self.play(action)
            self.draw(color_list)
            current_size = self.snake.length # Update the body size

            if current_size > previous_size:
                color_list = self.gradient([(42, 42, 42), (152, 152, 152)],\
                                           current_size)

                previous_size = current_size

        score = current_size - 3

        return score

    def check_collision(self):
        """Check wether any collisions happened with the wall or body.

        Return
        ----------
        collided: boolean
            Whether the snake collided or not.
        """
        collided = False

        if self.snake.head[0] > (var.BOARD_SIZE - 1) or self.snake.head[0] < 0:
            logger.info('EVENT: WALL COLLISION')
            collided = True
        elif self.snake.head[1] > (var.BOARD_SIZE - 1) or self.snake.head[1] < 0:
            logger.info('EVENT: WALL COLLISION')
            collided = True
        elif self.snake.head in self.snake.body[1:]:
            logger.info('EVENT: BODY COLLISION')
            collided = True

        return collided

    def is_won(self):
        """Verify if the score is greater than 0.

        Return
        ----------
        won: boolean
            Whether the score is greater than 0.
        """
        return self.snake.length > 3

    def generate_food(self):
        """Generate new food if needed.

        Return
        ----------
        food_pos: tuple of 2 * int
            Current position of the food.
        """
        food_pos = self.food_generator.generate_food(self.snake.body)

        return food_pos

    def handle_input(self):
        """After getting current pressed keys, handle important cases.

        Return
        ----------
        action: int
            Handle human input to assess the next action.
        """
        pygame.event.set_allowed([pygame.QUIT, pygame.KEYDOWN])
        keys = pygame.key.get_pressed()
        pygame.event.pump()
        action = self.snake.previous_action

        if keys[pygame.K_ESCAPE] or keys[pygame.K_q]:
            logger.info('ACTION: KEY PRESSED: ESCAPE or Q')
            self.over(self.snake.length - 3)
        elif keys[pygame.K_LEFT]:
            logger.info('ACTION: KEY PRESSED: LEFT')
            action = actions['LEFT']
        elif keys[pygame.K_RIGHT]:
            logger.info('ACTION: KEY PRESSED: RIGHT')
            action = actions['RIGHT']
        elif keys[pygame.K_UP]:
            logger.info('ACTION: KEY PRESSED: UP')
            action = actions['UP']
        elif keys[pygame.K_DOWN]:
            logger.info('ACTION: KEY PRESSED: DOWN')
            action = actions['DOWN']

        return action

    def eval_local_safety(self, canvas, body):
        """Evaluate the safety of the head's possible next movements.

        Return
        ----------
        canvas: np.array of size BOARD_SIZE**2
            After using game expertise, change canvas values to DANGEROUS if true.
        """
        if (body[0][0] + 1) > (var.BOARD_SIZE - 1)\
            or ([body[0][0] + 1, body[0][1]]) in body[1:]:
            canvas[var.BOARD_SIZE - 1, 0] = point_type['DANGEROUS']
        if (body[0][0] - 1) < 0 or ([body[0][0] - 1, body[0][1]]) in body[1:]:
            canvas[var.BOARD_SIZE - 1, 1] = point_type['DANGEROUS']
        if (body[0][1] - 1) < 0 or ([body[0][0], body[0][1] - 1]) in body[1:]:
            canvas[var.BOARD_SIZE - 1, 2] = point_type['DANGEROUS']
        if (body[0][1] + 1) > (var.BOARD_SIZE - 1)\
            or ([body[0][0], body[0][1] + 1]) in body[1:]:
            canvas[var.BOARD_SIZE - 1, 3] = point_type['DANGEROUS']

        return canvas

    def state(self):
        """Create a matrix of the current state of the game.

        Return
        ----------
        canvas: np.array of size BOARD_SIZE**2
            Return the current state of the game in a matrix.
        """
        canvas = np.zeros((var.BOARD_SIZE, var.BOARD_SIZE))

        if self.game_over:
            pass
        else:
            body = self.snake.body

            for part in body:
                canvas[part[0], part[1]] = point_type['BODY']

            canvas[body[0][0], body[0][1]] = point_type['HEAD']

            if self.local_state:
                canvas = self.eval_local_safety(canvas, body)

            canvas[self.food_pos[0], self.food_pos[1]] = point_type['FOOD']

        return canvas

    def relative_to_absolute(self, action):
        """Translate relative actions to absolute.

        Return
        ----------
        action: int
            Translated action from relative to absolute.
        """
        if action == relative_actions['FORWARD']:
            action = self.snake.previous_action
        elif action == relative_actions['LEFT']:
            if self.snake.previous_action == actions['LEFT']:
                action = actions['DOWN']
            elif self.snake.previous_action == actions['RIGHT']:
                action = actions['UP']
            elif self.snake.previous_action == actions['UP']:
                action = actions['LEFT']
            else:
                action = actions['RIGHT']
        else:
            if self.snake.previous_action == actions['LEFT']:
                action = actions['UP']
            elif self.snake.previous_action == actions['RIGHT']:
                action = actions['DOWN']
            elif self.snake.previous_action == actions['UP']:
                action = actions['RIGHT']
            else:
                action = actions['LEFT']

        return action

    def play(self, action):
        """Move the snake to the direction, eat and check collision."""
        self.scored = False
        self.step += 1
        self.food_pos = self.generate_food()

        if self.relative_pos:
            action = self.relative_to_absolute(action)

        if self.snake.move(action, self.food_pos):
            self.scored = True
            self.food_generator.is_food_on_screen = False

        if self.player == "HUMAN":
            if self.check_collision():
                return True
        elif self.check_collision() or self.step > 50 * self.snake.length:
            self.game_over = True

    def get_reward(self):
        """Return the current score. Can be used as the reward function.

        Return
        ----------
        reward: float
            Current reward of the game.
        """
        reward = rewards['MOVE']

        if self.game_over:
            reward = rewards['GAME_OVER']
        elif self.scored:
            reward = self.snake.length

        return reward

    def gradient(self, colors, steps, components = 3):
        """Function to create RGB gradients given 2 colors and steps. If
        component is changed to 4, it does the same to RGBA colors.

        Return
        ----------
        result: list of steps length of tuple of 3 * int (if RGBA, 4 * int)
            List of colors of calculated gradient from start to end.
        """
        def linear_gradient(start, finish, substeps):
            yield start

            for i in range(1, substeps):
                yield tuple([(start[j] + (float(i) / (substeps-1)) * (finish[j]\
                            - start[j])) for j in range(components)])

        def pairs(seq):
            a, b = tee(seq)
            next(b, None)

            return zip(a, b)

        result = []
        substeps = int(float(steps) / (len(colors) - 1))

        for a, b in pairs(colors):
            for c in linear_gradient(a, b, substeps):
                result.append(c)

        return result

    def draw(self, color_list):
        """Draw the game, the snake and the food using pygame.

        Body parts are paired with color_list in order, the caption shows the
        current score and the clock is ticked at var.GAME_SPEED.
        """
        block = var.BLOCK_SIZE
        self.window.fill(pygame.Color(225, 225, 225))

        for part, color in zip(self.snake.body, color_list):
            part_rect = pygame.Rect(part[0] * block, part[1] * block,
                                    block, block)
            pygame.draw.rect(self.window, color, part_rect)

        food_rect = pygame.Rect(self.food_pos[0] * block,
                                self.food_pos[1] * block, block, block)
        pygame.draw.rect(self.window, var.FOOD_COLOR, food_rect)

        score = self.snake.length - 3
        pygame.display.set_caption("SNAKE GAME  |  Score: " + str(score))
        pygame.display.update()
        self.fps.tick(var.GAME_SPEED)

def resource_path(relative_path):
    """Function to return absolute paths. Used while creating .exe file.

    The base directory is, in order of preference:

    * sys._MEIPASS, when that attribute exists (bundled executable);
    * the directory of this file, when __file__ is defined;
    * the current working directory otherwise. __file__ is not defined in
      interactive sessions such as a notebook kernel, where the lookup used
      to fail with a NameError.

    Return
    ----------
    path: string
        relative_path joined to the base directory.
    """
    if hasattr(sys, '_MEIPASS'):
        return path.join(sys._MEIPASS, relative_path)

    try:
        base_dir = path.dirname(path.realpath(__file__))
    except NameError:
        base_dir = path.realpath('.')

    return path.join(base_dir, relative_path)

# Module-level state shared by the classes and functions of this cell.
var = GlobalVariables() # Initializing GlobalVariables (read and written through `var.`)
logger = logging.getLogger(__name__) # Setting logger used for game events
environ['SDL_VIDEO_CENTERED'] = '1' # Centering the window

from keras.layers import Dense
from keras.models import Sequential, Model
import tensorflow as tf
from keras import backend as K
from keras.layers import Input
import numpy as np

class Net(object):
    """Policy and value networks sharing the same first two layers.

    Attributes
    ----------
    pi_model: keras Sequential
        Shared layers followed by Dense(512, relu) and a 5-unit softmax.
    v_model: keras Sequential
        Shared layers followed by Dense(512, relu) and a single linear unit.
    """
    def __init__(self, nb_frames, board_size):
        """Build both models for a flattened board of board_size**2 inputs.

        nb_frames is accepted but not used.
        """
        shared = Sequential()
        shared.add(Dense(50, activation='relu', input_dim=(board_size**2)))
        shared.add(Dense(50, activation='relu'))

        def with_head(units, activation=None):
            """Stack a 512-unit hidden layer and an output layer on `shared`."""
            model = Sequential([shared])
            model.add(Dense(512, activation='relu'))
            model.add(Dense(units, activation=activation))

            return model

        self.pi_model = with_head(5, 'softmax')
        self.v_model = with_head(1)


class PCL(object):
    """Agent trained on rollouts of an environment with a replay buffer.

    The graph is built lazily, on the first call that needs it, with
    TensorFlow placeholders and a session.

    Parameters
    ----------
    epoch: int
        Number of on-policy episodes played by train.
    env: object
        Environment; it must offer reset_game, state, generate_food, play,
        get_reward and a game_over flag.
    replay_buffer: object
        Buffer offering add, sample and the trainable property.
    sess: session object, optional
        Session used to run the graph.
    net: object, optional
        Holder of the pi_model and v_model attributes.
    pi_optimizer, v_optimizer: optimizers, optional
        Default to Adam with learning rates pi_lr and pi_lr * v_rate.
    off_policy_rate: int, optional, default = 20
        Replayed episodes optimized after each epoch once the buffer is
        trainable.
    pi_lr: float, optional, default = 7e-4
    v_rate: float, optional, default = 0.5
    entropy_tau: float, optional, default = 0.5
        Weight of the log-probability term of the loss.
    rollout_d: int, optional, default = 20
        Length of the windows cut from an episode.
    gamma: float, optional, default = 1
        Discount factor.
    """
    def __init__(self, epoch, env, replay_buffer, sess=None, net=None,
            pi_optimizer=None, v_optimizer=None, off_policy_rate=20,
            pi_lr=7e-4, v_rate=0.5, entropy_tau=0.5, rollout_d=20, gamma=1):
        self.epoch = epoch
        self.env = env
        self.replay_buffer = replay_buffer
        self.sess = sess
        self.net = net
        if pi_optimizer is None:
            self.pi_optimizer = tf.train.AdamOptimizer(pi_lr)
        else:
            self.pi_optimizer = pi_optimizer
        if v_optimizer is None:
            self.v_optimizer = tf.train.AdamOptimizer(pi_lr*v_rate)
        else:
            self.v_optimizer = v_optimizer
        self.off_policy_rate = off_policy_rate
        self.entropy_tau = entropy_tau
        self.rollout_d = rollout_d
        self.gamma = gamma
        # Sizes are fixed: 100 state values (a flattened 10 x 10 board) and
        # 5 actions, matching the placeholders created in build.
        self.state_shape = 100
        self.action_shape = [5]
        self.built = False

    def build(self):
        """Create placeholders, the loss and the update ops, then initialize
        all global variables of the session."""
        pi_model = self.net.pi_model
        v_model = self.net.v_model
        self.state = tf.placeholder(tf.float32, shape=[None, None, 100], name='state')
        self.R = tf.placeholder(tf.float32, shape=[None, None], name='R')
        self.action = tf.placeholder(tf.float32, shape=[None, None, 5], name='action')
        self.discount = tf.placeholder(tf.float32, shape=[None], name='discount')

        # Values of the first and of the last state of every window
        v_s_t = v_model(self.state[:, 0, :])
        v_s_t_d = v_model(self.state[:, -1, :])
        self.pi = pi_model(self.state)
        # NOTE(review): the value terms presumably have shape (batch, 1)
        # while the summed rewards have shape (batch,); confirm that the
        # resulting broadcast is intended.
        C = K.sum(-v_s_t + self.gamma ** self.rollout_d * v_s_t_d + \
                K.sum(self.R, axis=1) - self.entropy_tau * K.sum(self.discount * \
                K.sum(K.log(self.pi+K.epsilon()) * self.action, axis=2), axis=1), axis=0)
        self.loss = C ** 2

        self.updater = [self.pi_optimizer.minimize(self.loss, var_list=pi_model.trainable_weights),
                self.v_optimizer.minimize(self.loss, var_list=v_model.trainable_weights)]
        self.sess.run(tf.global_variables_initializer())
        self.built = True

    def optimize(self, episode):
        """Run one update on every window of rollout_d steps of the episode.

        Episodes shorter than rollout_d are used as a single window of their
        own length.
        """
        if not self.built:
            self.build()
        if len(episode['states']) < self.rollout_d:
            rollout_d = len(episode['states'])
        else:
            rollout_d = self.rollout_d
        discount = np.array([self.gamma**i for i in range(rollout_d)], dtype=np.float32)
        state = []
        action = []
        R = []
        for i in range(len(episode['states'])-rollout_d+1):
            state.append(episode['states'][i:i+rollout_d])
            a = episode['actions'][i:i+rollout_d]
            # One-hot encoding of the actions of the window
            action.append(np.eye(*self.action_shape, dtype=np.int32)[a])
            R.append(episode['rewards'][i:i+rollout_d])
        feed_in = {self.state: state, self.action: action, self.R: R, self.discount: discount}
        self.sess.run(self.updater, feed_in)

    def rollout(self):
        """Play one episode with the current policy.

        Return
        ----------
        episode: dict
            Arrays 'states', 'actions', 'rewards' and 'agent_infos', one
            entry per played step.
        """
        states = []
        actions = []
        rewards = []
        agent_infos = []
        self.env.reset_game()
        s = self.env.state().flatten()

        while not self.env.game_over:
            # Refresh the food on the agent's own environment. This used to
            # go through the notebook-level `game` variable, which only
            # worked when that global happened to be this environment.
            self.env.food_pos = self.env.generate_food()
            a, agent_info = self.get_action(s)

            self.env.play(a)

            r = self.env.get_reward()
            next_s = self.env.state().flatten()

            states.append(s)
            rewards.append(r)
            actions.append(a)
            agent_infos.append(agent_info)

            s = next_s
        return dict(
            states=np.array(states),
            actions=np.array(actions),
            rewards=np.array(rewards),
            agent_infos=np.array(agent_infos)
        )

    def get_action(self, state):
        """Sample an action from the policy output for a single state.

        Return
        ----------
        a: int
            Sampled action.
        agent_info: dict
            The action probabilities under the key 'prob'.
        """
        if not self.built:
            self.build()
        pi = self.sess.run(self.pi, {self.state: [[state]]})[0][0]
        a = np.random.choice(np.arange(self.action_shape[0]), p=pi)
        return a, dict(prob=pi)

    def train(self):
        """Play and optimize `epoch` episodes, replaying stored ones as well.

        Every 100 epochs the mean episode reward and the mean policy entropy
        of the last 100 episodes are printed.
        """
        rewards = []
        entropy = []
        for i in range(self.epoch):
            episode = self.rollout()
            self.optimize(episode)
            rewards.append(episode['rewards'].sum())
            p = np.array([agent_info['prob'] for agent_info in episode['agent_infos']])
            ent = -np.sum(p * np.log(p+K.epsilon()), axis=1)
            entropy.append(ent.mean())
            if (i + 1) % 100 == 0:
                print("Epoch: {:03d}/{:03d} | ".format(i + 1, self.epoch), end = "")
                print("Reward: {:.2f} | ".format(np.mean(rewards[-100:])), end = "")
                print("Mean ent: {:.2f}".format(np.mean(entropy[-100:])))
            self.replay_buffer.add(episode)
            if self.replay_buffer.trainable:
                for _ in range(self.off_policy_rate):
                    episode = self.replay_buffer.sample()
                    self.optimize(episode)


class ReplayBuffer(object):
    """Store of episodes sampled with reward-dependent probabilities.

    Attributes
    ----------
    max_len: int
        Maximum number of stored episodes.
    alpha: float
        Scale applied to the episode return before the exponential.
    buffer: list
        Stored episodes.
    weight: np.array
        One unnormalized weight per stored episode.
    """
    def __init__(self, max_len=3000000, alpha=1):
        self.max_len = max_len
        self.alpha = alpha
        self.buffer = []
        # weight is not normalized
        self.weight = np.array([])

    def add(self, episode):
        """Store an episode weighted by exp(alpha * sum of its rewards).

        When the buffer grows past max_len a uniformly chosen episode (maybe
        the new one) is dropped together with its weight.
        """
        self.buffer.append(episode)
        episode_return = episode['rewards'].sum()
        self.weight = np.append(self.weight, np.exp(self.alpha*episode_return))

        if len(self.buffer) <= self.max_len:
            return

        dropped = np.random.randint(len(self.buffer))
        del self.buffer[dropped]
        self.weight = np.delete(self.weight, dropped)

    def sample(self):
        """Return one stored episode, drawn proportionally to its weight."""
        probabilities = self.weight/self.weight.sum()

        return np.random.choice(self.buffer, p=probabilities)

    @property
    def trainable(self):
        """Whether more than 32 episodes are stored."""
        return len(self.buffer) > 32

# Training run: a robot player on a 10 x 10 board trained with the PCL agent.
board_size = 10
nb_frames = 4  # Not used below; Net is built with 1 as its first argument

net = Net(1, board_size)

# Absolute actions, with the local safety flags added to the state
game = Game(player = "ROBOT", board_size = board_size,
                        local_state = True, relative_pos = False)

replay_buffer = ReplayBuffer()

sess = tf.Session()

# 20000 epochs; progress is printed every 100 epochs by PCL.train
agent = PCL(20000, game, replay_buffer, sess, net)
agent.train()


Epoch: 100/20000 | Reward: -0.37 | Mean ent: 1.12
Epoch: 200/20000 | Reward: -0.42 | Mean ent: 0.83
Epoch: 300/20000 | Reward: -0.29 | Mean ent: 1.01
Epoch: 400/20000 | Reward: -0.46 | Mean ent: 0.89
Epoch: 500/20000 | Reward: -0.64 | Mean ent: 1.19
Epoch: 600/20000 | Reward: -0.32 | Mean ent: 1.11
Epoch: 700/20000 | Reward: -0.54 | Mean ent: 0.94
Epoch: 800/20000 | Reward: -0.46 | Mean ent: 0.82
Epoch: 900/20000 | Reward: -0.51 | Mean ent: 0.80
Epoch: 1000/20000 | Reward: -0.30 | Mean ent: 0.89
Epoch: 1100/20000 | Reward: -0.29 | Mean ent: 0.68
Epoch: 1200/20000 | Reward: -0.78 | Mean ent: 0.59
Epoch: 1300/20000 | Reward: -0.16 | Mean ent: 0.67
Epoch: 1400/20000 | Reward: -0.43 | Mean ent: 0.68
Epoch: 1500/20000 | Reward: 0.17 | Mean ent: 0.67
Epoch: 1600/20000 | Reward: 0.17 | Mean ent: 0.68
Epoch: 1700/20000 | Reward: -0.47 | Mean ent: 0.78
Epoch: 1800/20000 | Reward: -0.04 | Mean ent: 0.77
Epoch: 1900/20000 | Reward: 0.52 | Mean ent: 0.71
Epoch: 2000/20000 | Reward: -0.31 | Mean en

In [0]:
#!/usr/bin/env python

"""SnakeGame: A simple and fun exploration, meant to be used by AI algorithms.
"""

from sys import exit # To close the window when the game is over
from os import environ # To center the game window the best possible
import random # Random numbers used for the food
import logging # Logging function for movements and errors
from itertools import tee # For the color gradient on snake
#!pip install pygame # Jupyter Notebook
#import pygame # This is the engine used in the game
import numpy as np

__author__ = "Victor Neves"
__license__ = "MIT"
__version__ = "1.0"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"
__status__ = "Production"

# Actions and forbidden moves
# relative_actions: actions expressed from the snake's point of view.
# actions: absolute actions on the board, plus IDLE.
# forbidden_moves: pairs tested as (action, previous_action) by Snake.move;
# with the values above they are LEFT/RIGHT and UP/DOWN in both orders.
relative_actions = {'LEFT': 0, 'FORWARD': 1, 'RIGHT': 2}
actions = {'LEFT': 0, 'RIGHT': 1, 'UP': 2, 'DOWN': 3, 'IDLE': 4}
forbidden_moves = [(0, 1), (1, 0), (2, 3), (3, 2)]

# Types of point in the board
point_type = {'EMPTY': 0, 'FOOD': 1, 'BODY': 2, 'HEAD': 3, 'DANGEROUS': 4}

class GlobalVariables:
    """Container for the settings shared by the game logic and the drawing.

    Attributes:
        BOARD_SIZE: Number of blocks along each side of the board.
        BLOCK_SIZE: The size in pixels of a block.
        HEAD_COLOR: RGB color of the head.
        BODY_COLOR: RGB color of the body.
        FOOD_COLOR: RGB color of the food.
        GAME_SPEED: Speed in ticks of the game. The higher the faster.
    """
    def __init__(self):
        """Set every setting to its default value."""
        # Board geometry: blocks per side and pixels per block.
        self.BOARD_SIZE, self.BLOCK_SIZE = 30, 20

        # RGB colors used while drawing.
        self.HEAD_COLOR, self.BODY_COLOR, self.FOOD_COLOR = \
            (0, 0, 0), (0, 200, 0), (200, 0, 0)

        self.GAME_SPEED = 24

        # Large boards are allowed, but the user is warned about the cost.
        board_is_oversized = self.BOARD_SIZE > 50
        if board_is_oversized:
            logger.warning('WARNING: BOARD IS TOO BIG, IT MAY RUN SLOWER.')


class Snake:
    """The player: tracks the head, the body segments and the heading.

    ``body`` is a list of [x, y] positions with the head at index 0. Every
    move inserts the new head position at the front; the tail segment is
    dropped unless food was eaten on that move. A collision happens when the
    head leaves the board or overlaps another body segment.

    Attributes:
        head: [x, y] position of the head.
        body: List of [x, y] segments, head first. Starts with 3 parts.
        previous_action: Last absolute action actually applied (the heading).
        length: Number of segments, refreshed whenever food is eaten.
    """
    def __init__(self):
        """Place a 3-segment snake heading right, a quarter into the board."""
        start = int(var.BOARD_SIZE / 4)
        self.head = [start, start]
        # Head first, then two segments trailing to its left.
        self.body = [[start - offset, start] for offset in range(3)]
        self.previous_action = 1
        self.length = 3

    def move(self, action, food_pos):
        """Advance one block and report whether food was eaten.

        IDLE and direction-reversing requests are replaced by the previous
        action. Returns True (and keeps the tail, growing the snake) when the
        head lands on ``food_pos``; otherwise drops the tail and returns False.
        """
        keep_heading = (action == actions['IDLE']
                        or (action, self.previous_action) in forbidden_moves)
        if keep_heading:
            action = self.previous_action
        else:
            self.previous_action = action

        # (axis, delta) applied to the head for each absolute direction.
        shifts = {actions['LEFT']: (0, -1),
                  actions['RIGHT']: (0, 1),
                  actions['UP']: (1, -1),
                  actions['DOWN']: (1, 1)}
        if action in shifts:
            axis, delta = shifts[action]
            self.head[axis] += delta

        # Store a copy so later head updates do not alias body[0].
        self.body.insert(0, list(self.head))

        ate_food = self.head == food_pos
        if ate_food:
            logger.info('EVENT: FOOD EATEN')
            self.length = len(self.body)
        else:
            self.body.pop()

        return ate_food

    def check_collision(self):
        """Return True when the head is off the board or inside the body."""
        last_index = var.BOARD_SIZE - 1
        x_inside = 0 <= self.head[0] <= last_index
        y_inside = 0 <= self.head[1] <= last_index

        if not (x_inside and y_inside):
            logger.info('EVENT: WALL COLLISION')
            return True

        if self.head in self.body[1:]:
            logger.info('EVENT: BODY COLLISION')
            return True

        return False

    def return_body(self):
        """Return the list of body segments (head first)."""
        return self.body


class FoodGenerator:
    """Generate and keep track of food.

    Attributes:
        pos: Current [x, y] position of food.
        is_food_on_screen: Flag for existence of food.
    """
    def __init__(self, body):
        """Initialize a food piece and set existence flag."""
        self.is_food_on_screen = False
        self.pos = self.generate_food(body)

    def generate_food(self, body):
        """Return the food position, creating a new piece if none exists.

        When no food is on screen, a cell is drawn uniformly from the whole
        board and redrawn until it does not overlap ``body``. When food is
        already on screen the current position is returned unchanged.

        Args:
            body: List of [x, y] cells occupied by the snake.

        Returns:
            The [x, y] position of the food.
        """
        if not self.is_food_on_screen:
            size = var.BOARD_SIZE

            while True:
                # randrange(size) covers 0..size-1. The previous
                # int((size - 1) * random.random()) could never return
                # size - 1, so the last row and column never received food.
                food = [random.randrange(size), random.randrange(size)]

                if food not in body:
                    self.pos = food
                    break

            logger.info('EVENT: FOOD APPEARED')
            self.is_food_on_screen = True

        return self.pos

    def set_food_on_screen(self, bool_value):
        """Set flag for existence (or not) of food."""
        self.is_food_on_screen = bool_value


class Game:
    """Hold the game state, the step logic and the pygame drawing helpers.

    The pygame-dependent methods (create_window, start, over, handle_input,
    draw) need ``pygame`` to be importable in the running namespace; the
    import at the top of this cell is commented out.

    Attributes:
        window: pygame window to show the game (set by create_window).
        fps: pygame Clock used to pace the drawing (set by create_window).
        snake: The actual snake who is going to be played.
        food_generator: Generator of food which responds to the snake.
        food_pos: Position of the food on the board.
        game_over: Flag for game_over (only set for non-human players).
        scored: True when the last play() call ate food.
        step: Number of play() calls since the last reset().
        local_state: When True, state() also flags dangerous next moves.
        relative_pos: When True, play() receives relative actions.
        nb_actions: Number of possible actions (3 relative, 5 absolute).
    """
    def __init__(self, board_size = 30, local_state = False, relative_pos = False):
        """Store the options, size the board and reset the game state.

        Note that this overwrites BOARD_SIZE on the module-level ``var``
        object, so the board size is shared by every Game instance.
        """
        var.BOARD_SIZE = board_size
        self.local_state = local_state
        self.relative_pos = relative_pos
        if self.relative_pos:
            self.nb_actions = 3
        else:
            self.nb_actions = 5
        self.reset()

    def reset(self):
        """Start a new episode: fresh snake, fresh food, cleared flags."""
        self.step = 0
        self.snake = Snake()
        self.food_generator = FoodGenerator(self.snake.body)
        self.food_pos = self.food_generator.pos
        self.scored = False
        self.game_over = False

    def create_window(self):
        """Create the pygame window and the clock used by draw()."""
        flags = pygame.DOUBLEBUF
        # The window is a square of BOARD_SIZE x BOARD_SIZE blocks.
        self.window = pygame.display.set_mode((var.BOARD_SIZE * var.BLOCK_SIZE,\
                                               var.BOARD_SIZE * var.BLOCK_SIZE),
                                               flags)
        self.window.set_alpha(None)
        self.fps = pygame.time.Clock()

    def start(self):
        """Create some wait time before the actual drawing of the game.

        Shows a 3-second countdown in the window caption.
        """
        for i in range(3):
            pygame.display.set_caption("SNAKE GAME  |  Game starts in " +\
                                       str(3 - i) + " second(s) ...")
            pygame.time.wait(1000)
        logger.info('EVENT: GAME START')

    def over(self):
        """Show the final score, wait for Q or ESC, then quit and exit.

        This method does not return: it ends with pygame.quit() and exit().
        """
        pygame.display.set_caption("SNAKE GAME  |  Score: "
                            + str(self.snake.length - 3)
                            + "  |  GAME OVER. Press any Q or ESC to quit ...")
        logger.info('EVENT: GAME OVER')

        # Polls the keyboard in a tight loop (no wait between iterations)
        # until one of the quit keys is held down.
        while True:
            keys = pygame.key.get_pressed()
            pygame.event.pump()

            if keys[pygame.K_ESCAPE] or keys[pygame.K_q]:
                logger.info('ACTION: KEY PRESSED: ESCAPE or Q')
                break

        pygame.quit()
        exit()

    def is_won(self):
        """Return True once the snake is longer than its initial 3 parts."""
        return self.snake.length > 3

    def generate_food(self):
        """Return the food position, generating a new piece if needed."""
        return self.food_generator.generate_food(self.snake.body)

    def handle_input(self, previous_action):
        """After getting current pressed keys, handle important cases.

        Returns the absolute action for the pressed arrow key, or
        ``previous_action`` when no relevant key is pressed. ESC or Q ends
        the game through over(), which exits the process.
        """
        pygame.event.set_allowed([pygame.QUIT, pygame.KEYDOWN])
        keys = pygame.key.get_pressed()
        pygame.event.pump()

        if keys[pygame.K_ESCAPE] or keys[pygame.K_q]:
            logger.info('ACTION: KEY PRESSED: ESCAPE or Q')
            self.over()
        elif keys[pygame.K_LEFT]:
            logger.info('ACTION: KEY PRESSED: LEFT')
            return actions['LEFT']
        elif keys[pygame.K_RIGHT]:
            logger.info('ACTION: KEY PRESSED: RIGHT')
            return actions['RIGHT']
        elif keys[pygame.K_UP]:
            logger.info('ACTION: KEY PRESSED: UP')
            return actions['UP']
        elif keys[pygame.K_DOWN]:
            logger.info('ACTION: KEY PRESSED: DOWN')
            return actions['DOWN']
        else:
            return previous_action

    def eval_local_safety(self, canvas, body):
        """Evaluate the safety of the head's possible next movements.

        A move is dangerous when it would leave the board or hit the body.
        The four flags are written into ``canvas[BOARD_SIZE - 1, 0..3]`` in
        the order x+1, x-1, y-1, y+1 (right, left, up, down as applied by
        Snake.move). These are ordinary board cells, so the flags overwrite
        whatever state() had drawn there. Returns the same canvas.
        """
        # Index 0: stepping to x + 1.
        if (body[0][0] + 1) > (var.BOARD_SIZE - 1)\
            or ([body[0][0] + 1, body[0][1]]) in body[1:]:
            canvas[var.BOARD_SIZE - 1, 0] = point_type['DANGEROUS']
        # Index 1: stepping to x - 1.
        if (body[0][0] - 1) < 0 or ([body[0][0] - 1, body[0][1]]) in body[1:]:
            canvas[var.BOARD_SIZE - 1, 1] = point_type['DANGEROUS']
        # Index 2: stepping to y - 1.
        if (body[0][1] - 1) < 0 or ([body[0][0], body[0][1] - 1]) in body[1:]:
            canvas[var.BOARD_SIZE - 1, 2] = point_type['DANGEROUS']
        # Index 3: stepping to y + 1.
        if (body[0][1] + 1) > (var.BOARD_SIZE - 1)\
            or ([body[0][0], body[0][1] + 1]) in body[1:]:
            canvas[var.BOARD_SIZE - 1, 3] = point_type['DANGEROUS']

        return canvas

    def state(self):
        """Create a matrix of the current state of the game.

        Returns a (BOARD_SIZE, BOARD_SIZE) array indexed as canvas[x, y] and
        filled with ``point_type`` values. Later writes win: body, then head,
        then the optional danger flags, then food.
        """
        body = self.snake.return_body()
        canvas = np.zeros((var.BOARD_SIZE, var.BOARD_SIZE))

        for part in body:
            canvas[part[0], part[1]] = point_type['BODY']

        # body[0] is the head; it overrides the BODY value written above.
        canvas[body[0][0], body[0][1]] = point_type['HEAD']

        if self.local_state:
            canvas = self.eval_local_safety(canvas, body)

        canvas[self.food_pos[0], self.food_pos[1]] = point_type['FOOD']

        return canvas

    def relative_to_absolute(self, action):
        """Convert a relative action into an absolute one.

        FORWARD keeps the snake's previous action; LEFT and RIGHT are turned
        into the absolute direction that lies to that side of the current
        heading. Any value other than FORWARD or LEFT is treated as RIGHT.
        """
        if action == relative_actions['FORWARD']:
            action = self.snake.previous_action
        elif action == relative_actions['LEFT']:
            if self.snake.previous_action == actions['LEFT']:
                action = actions['DOWN']
            elif self.snake.previous_action == actions['RIGHT']:
                action = actions['UP']
            elif self.snake.previous_action == actions['UP']:
                action = actions['LEFT']
            else:
                action = actions['RIGHT']
        else:
            if self.snake.previous_action == actions['LEFT']:
                action = actions['UP']
            elif self.snake.previous_action == actions['RIGHT']:
                action = actions['DOWN']
            elif self.snake.previous_action == actions['UP']:
                action = actions['RIGHT']
            else:
                action = actions['LEFT']

        return action

    def play(self, action, player):
        """Move the snake to the direction, eat and check collision.

        For player "HUMAN" a collision calls over(), which exits. For any
        other player the ``game_over`` flag is set on a collision or when the
        step count exceeds 50 times the snake length.
        """
        self.scored = False
        self.step += 1
        # Creates a new food piece only if the previous one was eaten.
        self.food_pos = self.generate_food()

        if self.relative_pos:
            action = self.relative_to_absolute(action)

        if self.snake.move(action, self.food_pos):
            self.scored = True
            self.food_generator.set_food_on_screen(False)

        if player == "HUMAN":
            if self.snake.check_collision():
                self.over()
        elif self.snake.check_collision() or self.step > 50 * self.snake.length:
            self.game_over = True

    def get_reward(self):
        """Return the reward for the last play() call.

        -1 when the game is over, the current snake length when food was just
        eaten, and -0.005 for any other move.
        """
        if self.game_over:
            return -1
        elif self.scored:
            return self.snake.length

        return -0.005

    def gradient(self, colors, steps, components = 3):
        """Function to create RGB gradients given 2 colors and steps.

        If component is changed to 4, it does the same to RGBA colors.

        ``steps`` is divided evenly between the consecutive pairs of
        ``colors``; each pair contributes that many linearly interpolated
        colors. Returns them as one list. ``colors`` needs at least two
        entries, otherwise the division below is by zero.
        """
        def linear_gradient(start, finish, substeps):
            # Yields ``start`` unchanged, then interpolated tuples up to and
            # including ``finish``.
            yield start
            for i in range(1, substeps):
                yield tuple([(start[j] + (float(i) / (substeps-1)) * (finish[j]\
                            - start[j])) for j in range(components)])

        def pairs(seq):
            # Consecutive pairs: (s0, s1), (s1, s2), ...
            a, b = tee(seq)
            next(b, None)
            return zip(a, b)

        result = []
        substeps = int(float(steps) / (len(colors) - 1))

        for a, b in pairs(colors):
            for c in linear_gradient(a, b, substeps):
                result.append(c)

        return result

    def draw(self, color_list):
        """Draw the game, the snake and the food using pygame.

        Body parts are paired with ``color_list`` entries through zip, so
        parts beyond the end of ``color_list`` are not drawn. Also updates
        the caption with the score and ticks the clock at GAME_SPEED.
        """
        self.window.fill(pygame.Color(225, 225, 225))

        for part, color in zip(self.snake.body, color_list):
            pygame.draw.rect(self.window, color, pygame.Rect(part[0] *\
                        var.BLOCK_SIZE, part[1] * var.BLOCK_SIZE, \
                        var.BLOCK_SIZE, var.BLOCK_SIZE))

        pygame.draw.rect(self.window, var.FOOD_COLOR,\
                         pygame.Rect(self.food_pos[0] * var.BLOCK_SIZE,\
                         self.food_pos[1] * var.BLOCK_SIZE, var.BLOCK_SIZE,\
                         var.BLOCK_SIZE))

        pygame.display.set_caption("SNAKE GAME  |  Score: "
                                    + str(self.snake.length - 3))
        pygame.display.update()
        self.fps.tick(var.GAME_SPEED)


def main():
    """Run an interactive, human-played game until it ends."""
    # Configure how log records emitted by this module are rendered.
    logging.basicConfig(format = '%(asctime)s %(module)s %(levelname)s: %(message)s',
                        datefmt = '%m/%d/%Y %I:%M:%S %p', level = logging.INFO)

    # End colors of the gradient painted along the snake, head to tail.
    shades = [(42, 42, 42), (152, 152, 152)]

    game = Game()
    game.create_window()
    game.start()

    known_length = game.snake.length
    body_colors = game.gradient(shades, known_length)

    # One iteration per tick: read the keyboard, advance the game, redraw.
    while True:
        chosen_action = game.handle_input(game.snake.previous_action)
        game.play(chosen_action, "HUMAN")
        game.draw(body_colors)

        latest_length = game.snake.length

        # The snake grew, so the gradient needs one more color.
        if latest_length > known_length:
            body_colors = game.gradient(shades, game.snake.length)

        known_length = latest_length

# Module-level singletons used by the classes above.
# `var` holds the shared settings; Game.__init__ overwrites var.BOARD_SIZE.
var = GlobalVariables() # Initializing GlobalVariables
# Only ERROR and above are emitted, so the info-level 'EVENT'/'ACTION' records
# logged by the game are dropped by this logger.
logger = logging.getLogger(__name__) # Setting logger
logger.setLevel(logging.ERROR)
environ['SDL_VIDEO_CENTERED'] = '1' # Centering the window

class SumTree(object):
    r"""Array-backed binary sum tree holding priorities and experiences.

    This SumTree code is a modified version of Morvan Zhou's:
    https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/5.2_Prioritized_Replay_DQN/RL_brain.py

    Layout for capacity 4 (the numbers are array indexes, not priorities):

             0         -> root, stores the sum of all priorities
            / \
          1     2      -> internal nodes, each the sum of its two children
         / \   / \
        3   4 5   6    -> leaves, one priority per stored experience

    Attributes:
        capacity: Number of leaves, i.e. the maximum number of experiences.
        tree: Flat array of 2 * capacity - 1 node values (capacity - 1
            internal nodes followed by capacity leaves).
        data: Object array of length capacity holding the experiences.
        data_pointer: Index in ``data`` where the next experience is written.
    """
    data_pointer = 0

    def __init__(self, capacity):
        """Create a tree with every node and every data slot set to 0."""
        self.capacity = capacity
        self.data_pointer = 0

        # capacity leaves + (capacity - 1) internal nodes.
        self.tree = np.zeros(2 * capacity - 1)

        # Experiences, stored in the same order as the leaves.
        self.data = np.zeros(capacity, dtype=object)

    def add(self, priority, data):
        """Store ``data`` with ``priority`` in the next slot.

        Slots are filled from left to right; once ``capacity`` entries have
        been written, the oldest entry is overwritten.
        """
        # Leaf that corresponds to the data slot about to be written.
        tree_index = self.data_pointer + self.capacity - 1

        self.data[self.data_pointer] = data
        self.update(tree_index, priority)

        self.data_pointer += 1
        if self.data_pointer >= self.capacity:
            self.data_pointer = 0

    def update(self, tree_index, priority):
        """Set the priority of a leaf and propagate the change to the root.

        Args:
            tree_index: Index of the leaf inside ``tree``.
            priority: New priority for that leaf.
        """
        change = priority - self.tree[tree_index]
        self.tree[tree_index] = priority

        # Walk up through the parents, e.g. leaf 6 -> node 2 -> root 0,
        # adding the difference to each of them.
        while tree_index != 0:
            tree_index = (tree_index - 1) // 2
            self.tree[tree_index] += change

    def get_leaf(self, v):
        """Find the leaf whose cumulative priority range contains ``v``.

        Returns:
            A tuple (leaf_index, priority, data) where ``leaf_index`` is the
            index inside ``tree`` (usable with update()).
        """
        parent_index = 0

        while True:
            left_child_index = 2 * parent_index + 1
            right_child_index = left_child_index + 1

            # No children left: parent_index is a leaf.
            if left_child_index >= len(self.tree):
                leaf_index = parent_index
                break

            if v <= self.tree[left_child_index]:
                parent_index = left_child_index
            else:
                # Skip the mass of the left subtree and search on the right.
                v -= self.tree[left_child_index]
                parent_index = right_child_index

        data_index = leaf_index - self.capacity + 1

        return leaf_index, self.tree[leaf_index], self.data[data_index]

    @property
    def total_priority(self):
        """Sum of all priorities (the value of the root node)."""
        return self.tree[0]

    # ------------------------------------------------------------------
    # Adapters for the method names used by ExperienceReplay in this file
    # (insert / retrieve / sum / max_leaf / min_leaf).
    # ------------------------------------------------------------------
    def _leaves(self):
        """Return the slice of ``tree`` that holds the leaf priorities."""
        return self.tree[-self.capacity:]

    def insert(self, data, priority):
        """Same as add(), with the arguments in (data, priority) order."""
        self.add(priority, data)

    def retrieve(self, v):
        """Same as get_leaf()."""
        return self.get_leaf(v)

    def sum(self):
        """Same as the total_priority property."""
        return self.total_priority

    def max_leaf(self):
        """Return the largest leaf priority (0 for an empty tree)."""
        return np.max(self._leaves())

    def min_leaf(self):
        """Return the smallest positive leaf priority, or 0.0 if none.

        Leaves that were never written still hold 0; they are ignored so that
        the result can be used as a divisor while the tree is filling up.
        """
        leaves = self._leaves()
        positive = leaves[leaves > 0]

        if positive.size == 0:
            return 0.0

        return np.min(positive)

#!/usr/bin/env python

"""dqn: First try to create an AI for SnakeGame. Is it good enough?

This algorithm is an implementation of DQN, Double DQN logic (using a target
network to have fixed Q-targets), Dueling DQN logic (Q(s,a) = Advantage + Value)
and PER (Prioritized Experience Replay, using Sum Trees). You can read more
about these on https://medium.freecodecamp.org/improvements-in-deep-q-learning-dueling-double-dqn-prioritized-experience-replay-and-fixed-58b130cc5682

Possible usage:
    * Simple DQN;
    * DDQN;
    * DDDQN;
    * DDDQN + PER;
    * a combination of any of the above.

Arguments:
    --load FILE.h5: load a previously trained model in '.h5' format.
    --board_size INT: assign the size of the board, default = 10
    --nb_frames INT: assign the number of frames per stack, default = 4.
    --nb_actions INT: assign the number of actions possible, default = 5.
    --update_freq INT: assign how often, in epochs, to update the target,
      default = 500.
    --visual: select whether or not to draw the game in pygame.
    --double: use a target network with double DQN logic.
    --dueling: use dueling network logic, Q(s,a) = A + V.
    --per: use Prioritized Experience Replay (based on Sum Trees).
    --local_state: Verify if possible next moves are dangerous (field expertise)
"""

import numpy as np
from os import path, environ, sys
import random

import inspect # Making relative imports from parallel folders possible
currentdir = path.dirname(path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = path.dirname(currentdir)
sys.path.insert(0, parentdir)

from keras.optimizers import RMSprop, Nadam
from keras.models import load_model, Sequential
from keras.layers import *
from keras import backend as K
K.set_image_dim_ordering('th')

__author__ = "Victor Neves"
__license__ = "MIT"
__version__ = "1.0"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"
__status__ = "Production"

class ExperienceReplay:
    """The class that handles memory and experiences replay.

    Each experience is stored as one flat vector:
    ``[s.flatten(), a, r, s_prime.flatten(), game_over]``.

    Attributes:
        memory: a list of experiences, or a SumTree when PER is enabled.
        memory_size: the amount of experiences to be stored in the memory.
        input_shape: the shape of one state, set on the first remember().
        per: flag for PER usage.
        exp: number of experiences inserted so far (PER only).
        per_epsilon: used to replace "0" probabilities cases (PER only).
        per_alpha: how much prioritization to use (PER only).
        per_beta: exponent of the importance sampling weights (PER only).
        schedule: LinearSchedule built from nb_epoch, decay and beta
            (PER only).
    """
    def __init__(self, memory_size = 100, per = False, alpha = 0.6,
                 epsilon = 0.001, beta = 0.4, nb_epoch = 10000, decay = 0.5):
        """Initialize parameters and the memory array."""
        self.per = per
        self.memory_size = memory_size
        self.reset_memory() # Initiate the memory

        if self.per:
            self.per_epsilon = epsilon
            self.per_alpha = alpha
            self.per_beta = beta
            # NOTE(review): LinearSchedule is not defined in the visible part
            # of this file; it must be in scope before using per = True.
            self.schedule = LinearSchedule(nb_epoch * decay, 1.0, beta)

    def exp_size(self):
        """Returns how much memory is stored."""
        if self.per:
            return self.exp
        else:
            return len(self.memory)

    def get_priority(self, errors):
        """Returns priority based on how much prioritization to use."""
        return (errors + self.per_epsilon) ** self.per_alpha

    def update(self, tree_indices, errors):
        """Update a list of nodes, based on their errors."""
        priorities = self.get_priority(errors)

        for index, priority in zip(tree_indices, priorities):
            self.memory.update(index, priority)

    def remember(self, s, a, r, s_prime, game_over):
        """Remember SARS' experiences, with the game_over parameter (done).

        ``s`` and ``s_prime`` carry a leading batch axis; the state shape is
        taken from ``s.shape[1:]`` the first time this is called.
        """
        if not hasattr(self, 'input_shape'):
            self.input_shape = s.shape[1:] # set attribute only once

        experience = np.concatenate([s.flatten(),
                                     np.array(a).flatten(),
                                     np.array(r).flatten(),
                                     s_prime.flatten(),
                                     1 * np.array(game_over).flatten()])

        if self.per: # If using PER, insert in the max_priority.
            max_priority = self.memory.max_leaf()

            if max_priority == 0:
                max_priority = self.get_priority(0)

            self.memory.insert(experience, max_priority)
            self.exp += 1
        else: # Else, just append the experience to the list.
            self.memory.append(experience)

            # memory_size <= 0 means "unbounded"; otherwise drop the oldest.
            if self.memory_size > 0 and self.exp_size() > self.memory_size:
                self.memory.pop(0)

    def get_samples(self, batch_size):
        """Sample the memory according to PER flag.

        Returns:
            A tuple (batch, IS_weights, tree_indices). Without PER the
            weights are all ones and tree_indices is None.
        """
        if self.per:
            batch = [None] * batch_size
            IS_weights = np.zeros((batch_size, ))
            tree_indices = [0] * batch_size

            memory_sum = self.memory.sum()
            len_seg = memory_sum / batch_size
            min_prob = self.memory.min_leaf() / memory_sum

            # One draw from each of batch_size equal segments of the
            # total priority mass.
            for i in range(batch_size):
                # `uniform` was never imported by name; use random.uniform.
                val = random.uniform(len_seg * i, len_seg * (i + 1))
                tree_indices[i], priority, batch[i] = self.memory.retrieve(val)
                prob = priority / memory_sum
                IS_weights[i] = np.power(prob / min_prob, -self.per_beta)

            return np.array(batch), IS_weights, tree_indices

        else:
            IS_weights = np.ones((batch_size, ))
            batch = random.sample(self.memory, batch_size)
            return np.array(batch), IS_weights, None

    def get_targets(self, target, model, batch_size, nb_actions, gamma = 0.9):
        """Sample a batch and build the Q-learning targets for it.

        Args:
            target: target network for Double DQN, or None for plain DQN.
            model: the online network; must provide predict().
            batch_size: number of experiences to sample.
            nb_actions: number of outputs of the network.
            gamma: discount factor.

        Returns:
            None when fewer than batch_size experiences are stored, else a
            tuple (S, targets, IS_weights).
        """
        if self.exp_size() < batch_size:
            return None

        samples, IS_weights, tree_indices = self.get_samples(batch_size)
        input_dim = int(np.prod(self.input_shape)) # Flattened state size

        S = samples[:, 0 : input_dim] # Separate the states
        a = samples[:, input_dim] # Separate the actions
        r = samples[:, input_dim + 1] # Separate the rewards
        S_prime = samples[:, input_dim + 2 : 2 * input_dim + 2] # Next states
        game_over = samples[:, 2 * input_dim + 2] # Separate terminal flags

        # Reshape the arrays to make them usable by the model.
        r = r.repeat(nb_actions).reshape((batch_size, nb_actions))
        game_over = game_over.repeat(nb_actions)\
                             .reshape((batch_size, nb_actions))
        S = S.reshape((batch_size, ) + self.input_shape)
        S_prime = S_prime.reshape((batch_size, ) + self.input_shape)

        # One forward pass: rows [:batch_size] are Q(S, .), the rest are
        # Q(S', .).
        X = np.concatenate([S, S_prime], axis = 0)
        Y = model.predict(X)

        if target is not None: # Use Double DQN logic:
            # The online network picks the action, the target network
            # evaluates it. Sized by batch_size (it was hard-coded to 64).
            best_actions = np.argmax(Y[batch_size:], axis = 1)
            Y_target = np.asarray(target.predict(X[batch_size:]))
            Qsa = Y_target[np.arange(batch_size), best_actions]
            Qsa = Qsa.repeat(nb_actions).reshape((batch_size, nb_actions))

        else:
            Qsa = np.max(Y[batch_size:], axis = 1).repeat(nb_actions)\
                                                .reshape((batch_size, nb_actions))

        # Only the taken action gets the bootstrapped target; every other
        # output keeps the current prediction, so its error is zero.
        delta = np.zeros((batch_size, nb_actions))
        a = np.asarray(a).astype(int) # np.cast was removed in NumPy 2.0
        delta[np.arange(batch_size), a] = 1
        targets = ((1 - delta) * Y[:batch_size]
                  + delta * (r + gamma * (1 - game_over) * Qsa))

        if self.per: # Update the Sum Tree with the absolute error.
            # abs() must come before max(): the difference is non-zero only
            # at the taken action, so max() first would turn every negative
            # error into 0.
            errors = np.abs(targets - Y[:batch_size]).max(axis = 1)\
                                                     .clip(max = 1.)
            self.update(tree_indices, errors)

        return S, targets, IS_weights

    def reset_memory(self):
        """Set the memory as a blank list (or an empty SumTree with PER)."""
        if self.per:
            # A SumTree needs a positive capacity.
            if self.memory_size <= 0:
                self.memory_size = 150000

            self.memory = SumTree(self.memory_size)
            self.exp = 0
        else:
            self.memory = []


class Agent:
    """Agent based on a simple DQN that can read states, remember and play.

    Attributes:
    memory: memory used in the model. Input memory or ExperienceReplay.
    model: the input model, Conv2D in Keras.
    target: the target model, used to calculate the fixed Q-targets. May be
            None, in which case no target synchronisation is performed.
    nb_frames: amount of frames stacked for each sars.
    board_size: side of the square blank frame used once the game is over.
    frames: the frames in each sars.
    per: flag for PER usage.
    target_updates: how many times the target model received the weights.
    """
    def __init__(self, model, target, memory = None, memory_size = 150000,
                 nb_frames = 4, board_size = 10, per = False):
        """Initialize the agent with given attributes."""
        if memory:
            self.memory = memory
        else:
            self.memory = ExperienceReplay(memory_size = memory_size, per = per)

        self.per = per
        self.model = model
        self.target = target
        self.nb_frames = nb_frames
        self.board_size = board_size
        self.frames = None
        self.target_updates = 0

    def reset_memory(self):
        """Reset memory if necessary."""
        self.memory.reset_memory()

    def get_game_data(self, game):
        """Return the current stack of `nb_frames` frames with a batch axis.

        The first call after `clear_frames` fills the stack with copies of the
        current frame; later calls append the new frame and drop the oldest.
        When the game is over a blank (all zeros) frame is used instead of the
        game state.
        """
        if game.game_over:
            frame = np.zeros((self.board_size, self.board_size))
        else:
            frame = game.state()

        if self.frames is None:
            self.frames = [frame] * self.nb_frames
        else:
            self.frames.append(frame)
            self.frames.pop(0)

        return np.expand_dims(self.frames, 0)

    def clear_frames(self):
        """Reset frames to restart appending."""
        self.frames = None

    def update_target_model(self):
        """Update the target model with the main model's weights."""
        self.target_updates += 1
        self.target.set_weights(self.model.get_weights())

    @staticmethod
    def _tail_mean(values, window):
        """Mean of the last `window` values (of all of them when fewer exist)."""
        tail = values[-window:]
        return sum(tail) / len(tail) if tail else 0.

    def print_metrics(self, epoch, nb_epoch, history_size, history_loss,
                      history_step, history_reward, policy, value, win_count,
                      verbose = 1):
        """Print the metrics of the training steps.

        verbose == 0 prints nothing, verbose == 1 prints a one-line summary of
        the last 10 episodes and any other value prints the detailed training,
        game and policy metrics of the last episode.
        """
        if verbose == 0:
            return

        win_percentage = 100 * win_count/(epoch + 1)

        if verbose == 1:
            text_epoch = ('Epoch: {:03d}/{:03d} | Mean size 10: {:.1f} | '
                           + 'Longest 10: {:03d} | Mean steps 10: {:.1f} | '
                           + 'Wins: {:d} | Win percentage: {:.1f}%')
            print(text_epoch.format(epoch + 1, nb_epoch,
                                    self._tail_mean(history_size, 10),
                                    max(history_size[-10:]),
                                    self._tail_mean(history_step, 10),
                                    win_count, win_percentage))
            return

        text_epoch = 'Epoch: {:03d}/{:03d}' # Print epoch info
        print(text_epoch.format(epoch + 1, nb_epoch))

        # Print training performance
        text_train = ('\t\x1b[0;30;47m' + ' Training metrics ' + '\x1b[0m'
                      + '\tTotal loss: {:.4f} | Loss per step: {:.4f} | '
                      + 'Mean loss - 100 episodes: {:.4f}')
        print(text_train.format(history_loss[-1],
                                history_loss[-1] / history_step[-1],
                                self._tail_mean(history_loss, 100)))

        # NOTE(review): the third value is size / steps although the label
        # reads "Steps per food eaten"; kept as-is because the amount of food
        # eaten cannot be derived here without knowing the initial length.
        text_game = ('\t\x1b[0;30;47m' + ' Game metrics ' + '\x1b[0m'
                     + '\t\tSize: {:d} | Ammount of steps: {:d} | '
                     + 'Steps per food eaten: {:.1f} | '
                     + 'Mean size - 100 episodes: {:.1f}')
        print(text_game.format(history_size[-1], history_step[-1],
                               history_size[-1] / history_step[-1],
                               self._tail_mean(history_size, 100)))

        # Print policy metrics; only the first field depends on the policy.
        if policy == "BoltzmannQPolicy":
            text_value = '\tBoltzmann Temperature: {:.2f} | '
        elif policy == "BoltzmannGumbelQPolicy":
            text_value = '\tNumber of actions: {:.0f} | '
        else:
            text_value = '\tEpsilon: {:.2f} | '

        text_policy = ('\t\x1b[0;30;47m' + ' Policy metrics ' + '\x1b[0m'
                       + text_value
                       + 'Episode reward: {:.1f} | Wins: {:d} | '
                       + 'Win percentage: {:.1f}%')
        print(text_policy.format(value, history_reward[-1], win_count,
                                 win_percentage))

    def train_model(self, model, target, batch_size, gamma, nb_actions, epoch = 0):
        """Train the agent's model on one batch taken from the memory.

        `model`, `target` and `epoch` are kept for backward compatibility but
        are not used: the batch is always built and fitted with `self.model`
        and `self.target`.

        Return
        ----------
        loss: float
            Loss reported by `train_on_batch`, or 0. when the memory did not
            return a batch.
        """
        loss = 0.

        batch = self.memory.get_targets(model = self.model,
                                        target = self.target,
                                        batch_size = batch_size,
                                        gamma = gamma,
                                        nb_actions = nb_actions)

        if batch:
            inputs, targets, IS_weights = batch

            if inputs is not None and targets is not None:
                loss = float(self.model.train_on_batch(inputs,
                                                       targets,
                                                       IS_weights))

        return loss

    def train(self, game, nb_epoch = 10000, batch_size = 64, gamma = 0.95,
              eps = [1., .01], temp = [1., 0.01], learning_rate = 0.5,
              observe = 0, update_target_freq = 500, optim_rounds = 1,
              policy = "EpsGreedyQPolicy", verbose = 1, n_steps = None):
        """The main training function, loops the game, remember and choose best
        action given game state (frames).

        Parameters
        ----------
        game: game object exposing reset, play, state, get_reward, is_won...
        nb_epoch: number of episodes per round.
        batch_size, gamma: forwarded to the memory when building targets.
        eps, temp: [start, end] values of the epsilon / temperature schedule.
        learning_rate: fraction of nb_epoch over which the schedule anneals.
        observe: number of initial episodes played without training.
        update_target_freq: episodes between target model synchronisations.
        optim_rounds: round 0 plays the game; every further round only trains
            on batches from the memory.
        policy: "BoltzmannQPolicy", "BoltzmannGumbelQPolicy" or anything else
            for the epsilon greedy policy.
        verbose: verbosity forwarded to print_metrics.
        n_steps: when given, rewards are accumulated over a window of n_steps.
        """

        history_size = []
        history_step = []
        history_loss = []
        history_reward = []

        # One exclusive chain: with two independent `if`s the Boltzmann policy
        # would be overwritten by the epsilon greedy fallback.
        if policy == "BoltzmannQPolicy":
            q_policy = BoltzmannQPolicy(temp[0], temp[1], nb_epoch * learning_rate)
        elif policy == "BoltzmannGumbelQPolicy":
            q_policy = BoltzmannGumbelQPolicy()
        else:
            q_policy = EpsGreedyQPolicy(eps[0], eps[1], nb_epoch * learning_rate)

        nb_actions = game.nb_actions
        win_count = 0
        value = 0. # Policy metric; stays 0. if an episode ends before acting

        # Built once so `.format` applies to the whole message.
        text_optim = ('Optimizer turn: {:2d} | Epoch: {:03d}/{:03d} | '
                      + 'Loss: {:.4f}')

        for turn in range(optim_rounds):
            if turn > 0:
                for epoch in range(nb_epoch):
                    loss = self.train_model(model = self.model,
                                            epoch = epoch,
                                            target = self.target,
                                            batch_size = batch_size,
                                            gamma = gamma,
                                            nb_actions = nb_actions)

                    print(text_optim.format(turn, epoch + 1, nb_epoch, loss))
            else:
                for epoch in range(nb_epoch):
                    loss = 0.
                    total_reward = 0.
                    if n_steps is not None:
                        n_step_buffer = []
                    game.reset()
                    self.clear_frames()

                    S = self.get_game_data(game)

                    while not game.game_over:
                        game.food_pos = game.generate_food()
                        action, value = q_policy.select_action(self.model,
                                                               S, epoch,
                                                               nb_actions)

                        game.play(action, "ROBOT")

                        r = game.get_reward()
                        total_reward += r
                        if n_steps is not None:
                            n_step_buffer.append(r)

                            if len(n_step_buffer) < n_steps:
                                R = r
                            else:
                                # Slide over the most recent n_steps rewards;
                                # the buffer is never trimmed, so indexing it
                                # from 0 would freeze R on the first rewards
                                # of the episode.
                                first = len(n_step_buffer) - n_steps
                                window = n_step_buffer[first:]
                                R = sum([reward * (gamma ** i)
                                         for i, reward in enumerate(window)])
                        else:
                            R = r

                        S_prime = self.get_game_data(game)
                        experience = [S, action, R, S_prime, game.game_over]
                        self.memory.remember(*experience) # Add to the memory
                        S = S_prime # Advance to the next state (stack of S)

                        if epoch >= observe: # Get the batchs and train
                            loss += self.train_model(model = self.model,
                                                     target = self.target,
                                                     batch_size = batch_size,
                                                     gamma = gamma,
                                                     nb_actions = nb_actions)

                    if game.is_won():
                        win_count += 1 # Counter for metric purposes

                    if self.per: # Advance beta, used in PER
                        self.memory.per_beta = self.memory.schedule.value(epoch)

                    if self.target is not None: # Update the target model
                        if epoch % update_target_freq == 0:
                            self.update_target_model()

                    history_size.append(game.snake.length)
                    history_step.append(game.step)
                    history_loss.append(loss)
                    history_reward.append(total_reward)

                    if (epoch + 1) % 10 == 0:
                        self.print_metrics(epoch, nb_epoch, history_size,
                                           history_loss, history_step,
                                           history_reward, policy, value,
                                           win_count, verbose)

    def play(self, game, nb_epoch = 1000, eps = 0.01, temp = 0.01,
             visual = False, policy = "GreedyQPolicy"):
        """Play the game with the trained agent. Can use the visual tag to draw
            in pygame.

        Prints the win percentage and the mean / max / min of the final sizes
        and step counts over the `nb_epoch` games.
        """
        win_count = 0
        result_size = []
        result_step = []
        # Taken from the game itself instead of relying on a notebook global.
        nb_actions = game.nb_actions

        if policy == "BoltzmannQPolicy":
            q_policy = BoltzmannQPolicy(temp, temp, nb_epoch)
        elif policy == "EpsGreedyQPolicy":
            q_policy = EpsGreedyQPolicy(eps, eps, nb_epoch)
        else:
            q_policy = GreedyQPolicy()

        for epoch in range(nb_epoch):
            game.reset()
            self.clear_frames()
            S = self.get_game_data(game)

            if visual:
                game.create_window()
                # The main loop, it pump key_presses and update every tick.
                environ['SDL_VIDEO_CENTERED'] = '1' # Centering the window
                previous_size = game.snake.length # Initial size of the snake
                color_list = game.gradient([(42, 42, 42), (152, 152, 152)],\
                                               previous_size)

            while not game.game_over:
                action, value = q_policy.select_action(self.model, S, epoch,
                                                       nb_actions)
                game.play(action, "ROBOT")
                current_size = game.snake.length # Update the body size

                if visual:
                    game.draw(color_list)

                    if current_size > previous_size:
                        color_list = game.gradient([(42, 42, 42), (152, 152, 152)],
                                                   game.snake.length)

                        previous_size = current_size

                S = self.get_game_data(game)

                if game.game_over:
                    result_size.append(current_size)
                    result_step.append(game.step)

            if game.is_won():
                win_count += 1

        print("Accuracy: {} %".format(100. * win_count / nb_epoch))
        print("Mean size: {} | Biggest size: {} | Smallest size: {}"\
              .format(np.mean(result_size), np.max(result_size),
                      np.min(result_size)))
        print("Mean steps: {} | Biggest step: {} | Smallest step: {}"\
              .format(np.mean(result_step), np.max(result_step),\
                      np.min(result_step)))
        
import random
import numpy as np

class LinearSchedule(object):
    """Value that moves linearly from `initial_p` to `final_p`.

    The interpolation runs over `schedule_timesteps` steps; once that many
    steps have passed the value stays at `final_p`.
    """
    def __init__(self, schedule_timesteps, final_p, initial_p):
        """Store the endpoints and the length of the schedule.

        Parameters
        ----------
        schedule_timesteps: int
            Number of timesteps over which initial_p is annealed to final_p.
        final_p: float
            Output value once the schedule is finished.
        initial_p: float
            Output value at timestep 0.
        """
        self.initial_p = initial_p
        self.final_p = final_p
        self.schedule_timesteps = schedule_timesteps

    def value(self, t):
        """Return the scheduled value at timestep `t`."""
        progress = float(t) / self.schedule_timesteps
        if progress > 1.0: # Past the end of the schedule: hold final_p
            progress = 1.0

        span = self.final_p - self.initial_p
        return self.initial_p + progress * span


class GreedyQPolicy:
    """Implement the greedy policy

    Greedy policy always takes current best action.
    """
    def __init__(self):
        super(GreedyQPolicy, self).__init__()

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: object whose predict(state) returns the Q estimations.
            state: input forwarded to model.predict.
            epoch: unused, kept so every policy shares the same signature.
            nb_actions: unused, kept so every policy shares the same signature.
        # Returns
            Tuple (action, 0): index of the highest Q value of the first
            prediction and a constant placeholder for the policy metric.
        """
        q = model.predict(state)
        action = int(np.argmax(q[0]))

        return action, 0

    def get_config(self):
        """Return configurations of GreedyQPolicy
        # Returns
            Dict of config
        """
        # The class has no base policy providing get_config, so delegating to
        # super() would raise AttributeError; the greedy policy has no
        # parameters to report.
        return {}


class EpsGreedyQPolicy:
    """Implement the epsilon greedy policy

    Eps Greedy policy either:

    - takes a random action with probability epsilon
    - takes current best action with prob (1 - epsilon)

    Epsilon follows a linear schedule from max_eps to min_eps over nb_epoch.
    """
    def __init__(self, max_eps=1., min_eps = .01, nb_epoch = 10000):
        super(EpsGreedyQPolicy, self).__init__()
        self.schedule = LinearSchedule(nb_epoch, min_eps, max_eps)
        self.eps = max_eps # Refreshed from the schedule on every selection

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: object whose predict(state) returns the Q estimations.
            state: input forwarded to model.predict.
            epoch: position in the epsilon schedule.
            nb_actions: number of actions to draw from when exploring.
        # Returns
            Tuple (action, eps) with the chosen action and current epsilon.
        """
        self.eps = self.schedule.value(epoch)

        if random.random() < self.eps:
            # Draw the action independently of the exploration test. Reusing
            # the number that just passed `rand < eps` limits it to [0, eps),
            # so for eps < 1 / nb_actions the "random" action is always 0.
            action = random.randrange(nb_actions)
        else:
            q = model.predict(state)
            action = int(np.argmax(q[0]))

        return action, self.eps

    def get_config(self):
        """Return configurations of EpsGreedyQPolicy
        # Returns
            Dict of config
        """
        # No base policy provides get_config, so the dict is built here.
        return {'eps': getattr(self, 'eps', None)}


class BoltzmannQPolicy:
    """Implement the Boltzmann Q Policy
    Boltzmann Q Policy builds a probability law on q values and returns
    an action selected randomly according to this law.

    The temperature follows a linear schedule from max_temp to min_temp over
    nb_epoch. `clip` is stored and reported by get_config only; overflow is
    avoided by subtracting the maximum before exponentiating instead.
    """
    def __init__(self, max_temp = 1., min_temp = .01, nb_epoch = 10000, clip = (-500., 500.)):
        super(BoltzmannQPolicy, self).__init__()
        self.schedule = LinearSchedule(nb_epoch, min_temp, max_temp)
        self.clip = clip
        self.temp = max_temp # Refreshed from the schedule on every selection

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: object whose predict(state) returns the Q estimations.
            state: input forwarded to model.predict.
            epoch: position in the temperature schedule.
            nb_actions: number of actions to sample from.
        # Returns
            Tuple (action, temp) with the sampled action and the temperature.
        """
        q = np.asarray(model.predict(state)[0], dtype = 'float64')
        self.temp = self.schedule.value(epoch)

        if self.temp <= 0:
            # A zero temperature is the greedy limit of the distribution;
            # dividing by it would only produce inf / nan probabilities.
            return int(np.argmax(q)), self.temp

        arg = q / self.temp

        # Shifting by the maximum keeps exp() from overflowing and leaves the
        # normalised probabilities unchanged.
        exp_values = np.exp(arg - arg.max())
        probs = exp_values / exp_values.sum()
        # Plain int, like the action returned by the other policies.
        action = int(np.random.choice(nb_actions, p = probs))

        return action, self.temp

    def get_config(self):
        """Return configurations of BoltzmannQPolicy
        # Returns
            Dict of config
        """
        # No base policy provides get_config, so the dict is built here.
        return {'temp': getattr(self, 'temp', None), 'clip': self.clip}


class BoltzmannGumbelQPolicy:
    """Implements Boltzmann-Gumbel exploration (BGE) adapted for Q learning
    based on the paper Boltzmann Exploration Done Right
    (https://arxiv.org/pdf/1705.10257.pdf).
    BGE is invariant with respect to the mean of the rewards but not their
    variance. The parameter C, which defaults to 1, can be used to correct for
    this, and should be set to the least upper bound on the standard deviation
    of the rewards.
    BGE is only available for training, not testing. For testing purposes, you
    can achieve approximately the same result as BGE after training for N steps
    on K actions with parameter C by using the BoltzmannQPolicy and setting
    tau = C/sqrt(N/K)."""

    def __init__(self, C = 1.0):
        super(BoltzmannGumbelQPolicy, self).__init__()
        self.C = C
        self.action_counts = None
        self._last_epoch = None # Epoch of the previous selection, if any

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: object whose predict(state) returns the Q estimations.
            state: input forwarded to model.predict.
            epoch: current epoch; entering epoch 0 restarts the counts.
            nb_actions: unused, kept so every policy shares the same signature.
        # Returns
            Tuple (action, total) with the chosen action and the sum of the
            action counts.
        """
        q = model.predict(state)[0]
        q = q.astype('float64')

        # Counts restart when a training run enters epoch 0, not on every
        # step of epoch 0 (that would erase them all along the first game).
        # They are also created on demand so a first call with epoch != 0
        # does not fail on missing counts.
        entering_first_epoch = epoch == 0 and self._last_epoch != 0
        if (self.action_counts is None or entering_first_epoch
                or self.action_counts.shape != q.shape):
            self.action_counts = np.ones(q.shape)
        self._last_epoch = epoch

        beta = self.C/np.sqrt(self.action_counts)
        Z = np.random.gumbel(size = q.shape)

        perturbation = beta * Z
        perturbed_q_values = q + perturbation
        action = int(np.argmax(perturbed_q_values))

        self.action_counts[action] += 1
        return action, np.sum(self.action_counts)

    def get_config(self):
        """Return configurations of BoltzmannGumbelQPolicy
        # Returns
            Dict of config
        """
        # No base policy provides get_config, so the dict is built here.
        return {'C': self.C}

#!/usr/bin/env python

"""clipped_error: L2 (squared) loss for errors < clip_value, else L1 (linear) loss.

Functions:
    huber_loss: Return the L2 (squared) error if the absolute error is less
                than clip_value, else return the L1 (linear) error.
    clipped_error: Call huber_loss with default clip_value to 1.0.
"""

import numpy as np
from keras import backend as K
import tensorflow as tf

__author__ = "Victor Neves"
__license__ = "MIT"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"

def huber_loss(y_true, y_pred, clip_value):
    """Element-wise Huber loss between y_true and y_pred.

    Errors whose magnitude is below clip_value are penalised quadratically,
    the remaining ones linearly. An infinite clip_value gives the plain
    halved squared error. clip_value must be positive.

    See https://en.wikipedia.org/wiki/Huber_loss and
    https://medium.com/@karpathy/yes-you-should-understand-backprop-e2f06eab496b
    for details.
    """
    assert clip_value > 0.

    error = y_true - y_pred
    quadratic = .5 * K.square(error)

    if np.isinf(clip_value):
        # Special case for infinity since Tensorflow does have problems
        # if we compare `K.abs(x) < np.inf`.
        return quadratic

    magnitude = K.abs(error)
    is_small = magnitude < clip_value
    linear = clip_value * (magnitude - .5 * clip_value)

    backend = K.backend()
    if backend == 'tensorflow':
        # Old Tensorflow releases expose `select`, newer ones `where`;
        # both take (condition, value if true, value if false).
        pick = tf.select if hasattr(tf, 'select') else tf.where
        return pick(is_small, quadratic, linear)

    if backend == 'theano':
        from theano import tensor as T
        return T.switch(is_small, quadratic, linear)

    raise RuntimeError('Unknown backend "{}".'.format(backend))

def clipped_error(y_true, y_pred):
    """Mean over the last axis of the Huber loss with clip_value = 1."""
    element_loss = huber_loss(y_true, y_pred, clip_value = 1.)
    return K.mean(element_loss, axis = -1)

#def CNN1(optimizer, loss, stack, input_size, output_size):
 #   model = Sequential()
  #  model.add(Conv2D(32, (3, 3), activation = 'relu', input_shape = (stack,
   #                                                                  input_size,
    #                                                                 input_size)))
#    model.add(Conv2D(64, (3, 3), activation = 'relu'))
 #   model.add(Conv2D(128, (3, 3), activation = 'relu'))
  #  model.add(Conv2D(256, (3, 3), activation = 'relu'))
   # model.add(Flatten())
    #model.add(Dense(1024, activation = 'relu'))
    #model.add(Dense(output_size))
    #model.compile(optimizer = optimizer, loss = loss)

    #return model
    
def CNN4(optimizer, loss, stack, input_size, output_size):
    """From @Kaixhin implementation's of the Rainbow paper.

    Build and compile a Sequential network: three ReLU convolutions, a
    flatten, and two dense layers, each dense layer preceded by Gaussian
    noise. The input shape is (stack, input_size, input_size) and the last
    dense layer has output_size units.
    """
    frames_shape = (stack, input_size, input_size)

    # Layers in the order they are stacked.
    layers = [
        Conv2D(32, (4, 4), activation = 'relu', input_shape = frames_shape),
        Conv2D(64, (2, 2), activation = 'relu'),
        Conv2D(64, (2, 2), activation = 'relu'),
        Flatten(),
        GaussianNoise(stddev = 0.4),
        Dense(3136, activation = 'relu'),
        GaussianNoise(stddev = 0.4),
        Dense(output_size),
    ]

    network = Sequential()
    for layer in layers:
        network.add(layer)

    network.compile(optimizer = optimizer, loss = loss)

    return network
  
# Training run: build the game, the network and the agent, then train.
# The names defined here are module globals and are reused by later cells.
board_size = 10
nb_frames = 4
  
# Game created with local_state enabled and absolute (non relative) positions.
game = Game(board_size = board_size,
                        local_state = True, relative_pos = False)

# Network fed with stacks of nb_frames boards, one output per game action.
model = CNN4(optimizer = RMSprop(), loss = clipped_error,
                            stack = nb_frames, input_size = board_size,
                            output_size = game.nb_actions)
# No target network: Agent.train only updates the target when it is not None,
# so update_target_freq below has no effect in this run.
target = None

# NOTE(review): memory_size = -1 with per = False; the meaning of a negative
# size for the non-PER memory is not visible here - confirm in ExperienceReplay.
agent = Agent(model = model, target = target, memory_size = -1,
                          nb_frames = nb_frames, board_size = board_size,
                          per = False)
# Line profiling of a short run (requires the line_profiler extension):
#%lprun -f agent.train agent.train(game, batch_size = 64, nb_epoch = 10, gamma = 0.95, update_target_freq = 500, policy = "EpsGreedyQPolicy")
agent.train(game, batch_size = 64, nb_epoch = 10000, gamma = 0.95, update_target_freq = 500, policy = "BoltzmannQPolicy")

In [0]:
model.save('keras.h5')

!zip -r model-epsgreedy-n-steps.zip keras.h5 
from google.colab import files
files.download('model-epsgreedy-noise.zip')

model = load_model('keras.h5', custom_objects={'clipped_error': clipped_error})

board_size = 10
nb_frames = 4
nb_actions = 5

target = None

agent = Agent(model = model, target = target, memory_size = 1500000,
                          nb_frames = nb_frames, board_size = board_size,
                          per = False)
#%lprun -f agent.train agent.train(game, batch_size = 64, nb_epoch = 10, gamma = 0.95, update_target_freq = 500, policy = "EpsGreedyQPolicy")

agent.play(game, visual = False, nb_epoch = 1000)

# Testing what changed on the new DQN

The same source as the benchmark; the only difference is the Snake and Game classes.

In [0]:
#!/usr/bin/env python

"""SnakeGame: A simple and fun exploration, meant to be used by Human and AI.
"""

import sys  # To close the window when the game is over
from array import array  # Efficient numeric arrays
from os import environ, path  # To center the game window the best possible
import random  # Random numbers used for the food
import logging  # Logging function for movements and errors
from itertools import tee  # For the color gradient on snake
import numpy as np
import os

__author__ = "Victor Neves"
__license__ = "MIT"
__version__ = "1.0"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"
__status__ = "Production"

# Menu options
options = {
    'QUIT': 0,
    'PLAY': 1,
    'BENCHMARK': 2,
    'LEADERBOARDS': 3,
    'MENU': 4,
    'ADD_LEADERBOARDS': 5,
}

# Actions relative to the heading of the snake
relative_actions = {
    'LEFT': 0,
    'FORWARD': 1,
    'RIGHT': 2,
}

# Absolute actions
actions = {
    'LEFT': 0,
    'RIGHT': 1,
    'UP': 2,
    'DOWN': 3,
    'IDLE': 4,
}

# (action, previous action) pairs that are rejected: LEFT <-> RIGHT, UP <-> DOWN
forbidden_moves = [
    (0, 1),
    (1, 0),
    (2, 3),
    (3, 2),
]

# Possible rewards in the game
rewards = {
    'MOVE': -0.005,
    'GAME_OVER': -1,
    'SCORED': 1,
}

# Types of point in the board
point_type = {
    'EMPTY': 0,
    'FOOD': 1,
    'BODY': 2,
    'HEAD': 3,
    'DANGEROUS': 4,
}

# Speed levels possible to human players, MEGA HARDCORE starts with MEDIUM and
# increases with snake size
levels = [
    " EASY ",
    " MEDIUM ",
    " HARD ",
    " MEGA HARDCORE ",
]
speeds = {
    'EASY': 80,
    'MEDIUM': 60,
    'HARD': 40,
}

class GlobalVariables:
    """Settings shared by the drawing and movement code of the snake game.

    Attributes
    ----------
    BOARD_SIZE: int, optional, default = 30
        The size of the board.
    BLOCK_SIZE: int, optional, default = 20
        The size in pixels of a block.
    HEAD_COLOR: tuple of 3 * int, optional, default = (42, 42, 42)
        Color of the head. Start of the body color gradient.
    TAIL_COLOR: tuple of 3 * int, optional, default = (152, 152, 152)
        Color of the tail. End of the body color gradient.
    FOOD_COLOR: tuple of 3 * int, optional, default = (200, 0, 0)
        Color of the food.
    GAME_SPEED: int, optional, default = 80
        Speed setting of the game.
    GAME_FPS: int, optional, default = 100
        Frames per second setting of the game.
    BENCHMARK: int, optional, default = 10
        Amount of matches to BENCHMARK and possibly go to leaderboards.
    """
    def __init__(self, BOARD_SIZE = 30, BLOCK_SIZE = 20,
                 HEAD_COLOR = (42, 42, 42), TAIL_COLOR = (152, 152, 152),
                 FOOD_COLOR = (200, 0, 0), GAME_SPEED = 80, GAME_FPS = 100,
                 BENCHMARK = 10):
        """Store every setting. A board larger than 50 logs a warning."""
        # Geometry
        self.BOARD_SIZE, self.BLOCK_SIZE = BOARD_SIZE, BLOCK_SIZE
        # Colors
        self.HEAD_COLOR, self.TAIL_COLOR = HEAD_COLOR, TAIL_COLOR
        self.FOOD_COLOR = FOOD_COLOR
        # Timing and benchmark
        self.GAME_SPEED, self.GAME_FPS = GAME_SPEED, GAME_FPS
        self.BENCHMARK = BENCHMARK

        board_is_large = self.BOARD_SIZE > 50
        if board_is_large: # Warn the user about performance
            logger.warning('WARNING: BOARD IS TOO BIG, IT MAY RUN SLOWER.')

class TextBlock:
    """Piece of text drawn with pygame, usable as plain text or menu option.

    Attributes:
    ----------
    text: string
        The text to be displayed.
    pos: tuple of 2 * int
        Center of the rectangle where the text is drawn.
    screen: pygame window object
        The screen where the text is drawn.
    scale: int, optional, default = 1 / 12
        Fraction of the board size in pixels used as the font size.
    type: string, optional, default = "text"
        Assert whether the BlockText is a text or menu option.
    hovered: boolean
        Whether the block is hovered; only changes how menu options look.
    """
    def __init__(self, text, pos, screen, scale = (1 / 12), type = "text"):
        """Store the attributes, place the rectangle and draw the block."""
        self.text = text
        self.pos = pos
        self.screen = screen
        self.scale = scale
        self.type = type
        self.hovered = False
        self.set_rect()
        self.draw()

    def draw(self):
        """Render the text again and blit it on the pygame screen."""
        self.set_rend()
        self.screen.blit(self.rend, self.rect)

    def set_rend(self):
        """Render the text with the current font size and colors."""
        font_size = int((var.BOARD_SIZE * var.BLOCK_SIZE) * self.scale)
        font_file = resource_path("resources/fonts/freesansbold.ttf")
        font = pygame.font.Font(font_file, font_size)
        self.rend = font.render(self.text, True, self.get_color(),
                                self.get_background())

    def get_color(self):
        """Get color to render for text and menu (hovered or not).

        Return
        ----------
        color: pygame Color
            Light grey for a menu option that is not hovered, dark grey for
            everything else.
        """
        if self.type == "menu" and not self.hovered:
            return pygame.Color(152, 152, 152)

        return pygame.Color(42, 42, 42)

    def get_background(self):
        """Get background color to render for text (hovered or not) and menu.

        Return
        ----------
        color: pygame Color or None
            Light grey for a hovered menu option, None otherwise.
        """
        if self.type == "menu" and self.hovered:
            return pygame.Color(152, 152, 152)

        return None

    def set_rect(self):
        """Render the text and center its rectangle on `pos`."""
        self.set_rend()
        self.rect = self.rend.get_rect()
        self.rect.center = self.pos


class Snake:
    """Player (snake) class which initializes head, body and board.

    The body is a list of positions whose element [0] is the head; a new
    element is inserted there on every move. The previous action tells where
    the snake is heading.

    Attributes
    ----------
    head: list of 2 * int, default = [BOARD_SIZE / 4, BOARD_SIZE / 4]
        The head of the snake, located according to the board size.
    body: list of lists of 2 * int
        Starts with 3 parts and grows when food is eaten.
    previous_action: int, default = 1
        Last action which the snake took.
    length: int, default = 3
        Variable length of the snake, can increase when food is eaten.
    """
    def __init__(self):
        """Inits Snake with 3 body parts (one is the head) and pointing right"""
        start = int(var.BOARD_SIZE / 4)
        self.head = [start, start]
        # Head first, then two parts trailing to its left.
        self.body = [[self.head[0] - offset, self.head[1]]
                     for offset in range(3)]
        self.previous_action = 1
        self.length = 3

    def is_movement_invalid(self, action):
        """Tell whether `action` after the previous action is forbidden.

        Return
        ----------
        invalid: boolean
            True when (action, previous_action) is in forbidden_moves.
        """
        if (action, self.previous_action) in forbidden_moves:
            return True

        return False

    def move(self, action, food_pos):
        """Move the head 1 block and update the body.

        An IDLE or forbidden action repeats the previous action. The new head
        position is inserted in front of the body; unless it is on the food,
        the last body part is removed.

        Return
        ----------
        ate_food: boolean
            Flag which represents whether the snake ate or not food.
        """
        keep_heading = (action == actions['IDLE']
                        or self.is_movement_invalid(action))

        if keep_heading:
            action = self.previous_action
        else:
            self.previous_action = action

        # (action name, head coordinate index, displacement)
        displacements = (('LEFT', 0, -1),
                         ('RIGHT', 0, 1),
                         ('UP', 1, -1),
                         ('DOWN', 1, 1))

        for name, axis, step in displacements:
            if action == actions[name]:
                self.head[axis] += step
                break

        self.body.insert(0, list(self.head))

        if self.head != food_pos:
            self.body.pop()
            return False

        logger.info('EVENT: FOOD EATEN')
        self.length = len(self.body)

        return True


class FoodGenerator:
    """Generate and keep track of food.

    Attributes
    ----------
    pos: list of 2 * int
        Current position of food.
    is_food_on_screen: boolean
        Flag for existence of food.
    """
    def __init__(self, body):
        """Initialize a food piece and set existence flag."""
        self.is_food_on_screen = False
        self.pos = self.generate_food(body)

    def generate_food(self, body):
        """Generate food and verify if it's on a valid place.

        A new position is only drawn when no food is currently on screen;
        otherwise the stored position is returned unchanged.

        Return
        ----------
        pos: list of 2 * int
            Position of the food that was generated. It can't be in the body.
        """
        if not self.is_food_on_screen:
            while True:
                # random.random() is in [0, 1), so scaling by BOARD_SIZE
                # covers every index 0 .. BOARD_SIZE - 1. Scaling by
                # BOARD_SIZE - 1 (as before) never produced the last
                # row/column.
                food = [int(var.BOARD_SIZE * random.random()),
                        int(var.BOARD_SIZE * random.random())]

                if food not in body:
                    break

            self.pos = food
            logger.info('EVENT: FOOD APPEARED')
            self.is_food_on_screen = True

        return self.pos


class Game:
    """Hold the game window and functions.

    Attributes
    ----------
    window: pygame display
        Pygame window to show the game.
    fps: pygame time clock
        Define Clock and ticks in which the game will be displayed.
    snake: object
        The actual snake who is going to be played.
    food_generator: object
        Generator of food which responds to the snake.
    food_pos: tuple of 2 * int
        Position of the food on the board.
    game_over: boolean
        Flag for game_over.
    player: string
        Define if human or robots are playing the game.
    board_size: int, optional, default = 30
        The size of the board.
    local_state: boolean, optional, default = False
        Whether to use or not game expertise (used mostly by robots players).
    relative_pos: boolean, optional, default = False
        Whether to use or not relative position of the snake head. Instead of
        actions, use relative_actions.
    screen_rect: tuple of 2 * int
        The screen rectangle, used to draw relatively positioned blocks.
    """
    def __init__(self, player, board_size = 30, local_state = False, relative_pos = False):
        """Store the game settings and, for robot players, prepare a match.

        The board size is written to the shared `var.BOARD_SIZE`. When
        `player` is "ROBOT", `nb_actions` is set (3 with relative positions,
        5 otherwise) and the game is reset immediately.
        """
        var.BOARD_SIZE = board_size
        self.player = player
        self.local_state = local_state
        self.relative_pos = relative_pos

        if player != "ROBOT":
            return

        self.nb_actions = 3 if self.relative_pos else 5
        self.reset_game()

    def reset_game(self):
        """Start a fresh match: new snake, new food and cleared flags."""
        self.step = 0
        self.scored = False
        self.game_over = False

        # The food generator needs the new snake's body to avoid it.
        self.snake = Snake()
        self.food_generator = FoodGenerator(self.snake.body)
        self.food_pos = self.food_generator.pos

    def create_window(self):
        """Start pygame and open a square, double-buffered game window.

        The window side measures BOARD_SIZE * BLOCK_SIZE. The window
        rectangle is kept in `screen_rect` and a pygame Clock in `fps`.
        """
        pygame.init()

        side = var.BOARD_SIZE * var.BLOCK_SIZE
        self.window = pygame.display.set_mode((side, side), pygame.DOUBLEBUF)
        self.window.set_alpha(None)

        self.fps = pygame.time.Clock()
        self.screen_rect = self.window.get_rect()

    def menu(self):
        """Show the main menu and block until an entry is clicked.

        Return
        ----------
        selected_option: int
            The selected option in the main loop.
        """
        pygame.display.set_caption("SNAKE GAME  | PLAY NOW!")

        side = var.BOARD_SIZE * var.BLOCK_SIZE
        img = pygame.image.load(resource_path("resources/images/snake_logo.png"))
        img = pygame.transform.scale(img, (side, int(side / 3)))

        img_rect = img.get_rect()
        img_rect.center = self.screen_rect.center

        # (label, key in `options`), listed top to bottom. Entries are placed
        # at 4/10, 6/10, 8/10 and 10/10 of the vertical center.
        entries = ((' PLAY GAME ', 'PLAY'),
                   (' BENCHMARK ', 'BENCHMARK'),
                   (' LEADERBOARDS ', 'LEADERBOARDS'),
                   (' QUIT ', 'QUIT'))
        menu_options = [TextBlock(label,
                                  (self.screen_rect.centerx,
                                   (4 + 2 * row) * self.screen_rect.centery / 10),
                                  self.window, (1 / 12), "menu")
                        for row, (label, _) in enumerate(entries)]
        selected_option = None

        while selected_option is None:
            pygame.event.pump()
            events = pygame.event.get()
            released = any(event.type == pygame.MOUSEBUTTONUP
                           for event in events)

            self.window.fill(pygame.Color(225, 225, 225))

            for option, (_, key) in zip(menu_options, entries):
                # Drawn before the hover flag is refreshed, so the highlight
                # shows up on the following frame.
                option.draw()

                if not option.rect.collidepoint(pygame.mouse.get_pos()):
                    option.hovered = False
                    continue

                option.hovered = True

                if released:
                    selected_option = options[key]

            self.window.blit(img, img_rect.bottomleft)
            pygame.display.update()

        return selected_option

    def start_match(self):
        """Show a 3, 2, 1 countdown (one second each) before the match."""
        for remaining in range(3, 0, -1):
            countdown = str(remaining)
            self.window.fill(pygame.Color(225, 225, 225))

            # Caption line on top, big countdown number below it.
            heading = TextBlock('Game starts in',
                                (self.screen_rect.centerx,
                                 4 * self.screen_rect.centery / 10),
                                self.window, (1 / 10), "text")
            number = TextBlock(countdown,
                               (self.screen_rect.centerx,
                                12 * self.screen_rect.centery / 10),
                               self.window, (1 / 1.5), "text")

            heading.draw()
            number.draw()

            pygame.display.update()
            pygame.display.set_caption("SNAKE GAME  |  Game starts in "
                                       + countdown + " second(s) ...")

            pygame.time.wait(1000)

        logger.info('EVENT: GAME START')

    def start(self):
        """Run the application loop: show the menu and dispatch the choice.

        QUIT closes pygame and exits the interpreter. PLAY runs one match
        and BENCHMARK runs var.BENCHMARK matches; both end on the game-over
        screen, whose selection becomes the next option. Every other option
        (MENU and the leaderboard entries, which have no screen yet) shows
        the main menu again.
        """
        opt = self.menu()

        while True:
            if opt == options['QUIT']:
                pygame.quit()
                sys.exit()
            elif opt == options['PLAY']:
                var.GAME_SPEED, mega_hardcore = self.select_speed()
                self.reset_game()
                self.start_match()
                score = self.single_player(mega_hardcore)
                opt = self.over(score)
            elif opt == options['BENCHMARK']:
                var.GAME_SPEED, mega_hardcore = self.select_speed()
                score = array('i')

                for _ in range(var.BENCHMARK):
                    self.reset_game()
                    self.start_match()
                    score.append(self.single_player(mega_hardcore))

                opt = self.over(score)
            else:
                # The leaderboard options used to be `pass`, which left `opt`
                # unchanged and made this loop spin forever without pumping
                # events. Until they are implemented, return to the menu.
                opt = self.menu()

    def over(self, score):
        """Show the game-over screen and wait until an entry is clicked.

        `score` is either an int (single match) or a collection of scores;
        in the second case the mean over var.BENCHMARK is displayed and an
        ' ADD TO LEADERBOARDS ' entry is offered as well.

        Return
        ----------
        selected_option: int
            The selected option in the main loop. Clicking ' QUIT ' closes
            pygame and exits the interpreter instead of returning.
        """
        def menu_entry(label, tenths):
            # Clickable entry placed at tenths/10 of the vertical center.
            return TextBlock(label, (self.screen_rect.centerx,
                                     tenths * self.screen_rect.centery / 10),
                             self.window, (1 / 15), "menu")

        play_again = menu_entry(' PLAY AGAIN ', 4)
        go_to_menu = menu_entry(' GO TO MENU ', 6)
        quit_game = menu_entry(' QUIT ', 10)
        add_score = None

        if isinstance(score, int):
            text_score = 'SCORE: ' + str(score)
        else:
            text_score = 'MEAN SCORE: ' + str(sum(score) / var.BENCHMARK)
            add_score = menu_entry(' ADD TO LEADERBOARDS ', 8)

        pygame.display.set_caption("SNAKE GAME  | " + text_score
                                   + "  |  GAME OVER...")
        logger.info('EVENT: GAME OVER | FINAL ' + text_score)
        score_text = TextBlock(text_score, (self.screen_rect.centerx,
                                            15 * self.screen_rect.centery / 10),
                               self.window, (1 / 10), "text")

        # (entry, key in `options` or None when a click selects nothing).
        entries = ((play_again, 'PLAY'),
                   (go_to_menu, 'MENU'),
                   (add_score, 'ADD_LEADERBOARDS'),
                   (quit_game, None),
                   (score_text, None))
        selected_option = None

        while selected_option is None:
            pygame.event.pump()
            events = pygame.event.get()
            released = any(event.type == pygame.MOUSEBUTTONUP
                           for event in events)

            # Game over screen
            self.window.fill(pygame.Color(225, 225, 225))

            for entry, key in entries:
                if entry is None:
                    continue

                entry.draw()

                if not entry.rect.collidepoint(pygame.mouse.get_pos()):
                    entry.hovered = False
                    continue

                entry.hovered = True

                if not released:
                    continue

                if entry is quit_game:
                    pygame.quit()
                    sys.exit()
                elif key is not None:
                    selected_option = options[key]

            pygame.display.update()

        return selected_option

    def select_speed(self):
        """Speed menu, right before calling start_match.

        Return
        ----------
        speed: int
            The selected speed in the main loop.
        mega_hardcore: boolean
            True only when the last entry (levels[3]) was clicked; that entry
            uses the MEDIUM speed as its starting value.
        """
        # (key in `speeds`, whether the entry enables mega hardcore), in the
        # same order as `levels`.
        choices = (('EASY', False),
                   ('MEDIUM', False),
                   ('HARD', False),
                   ('MEDIUM', True))
        menu_options = [TextBlock(levels[row],
                                  (self.screen_rect.centerx,
                                   (4 + 4 * row) * self.screen_rect.centery / 10),
                                  self.window, (1 / 10), "menu")
                        for row in range(4)]
        mega_hardcore = False
        speed = None

        while speed is None:
            pygame.event.pump()
            events = pygame.event.get()
            released = any(event.type == pygame.MOUSEBUTTONUP
                           for event in events)

            # Speed selection screen
            self.window.fill(pygame.Color(225, 225, 225))

            for option, (speed_key, is_mega) in zip(menu_options, choices):
                option.draw()

                if not option.rect.collidepoint(pygame.mouse.get_pos()):
                    option.hovered = False
                    continue

                option.hovered = True

                if released:
                    speed = speeds[speed_key]

                    if is_mega:
                        mega_hardcore = True

            pygame.display.update()

        return speed, mega_hardcore

    def single_player(self, mega_hardcore = False):
        """Game loop for single_player (HUMANS).

        Input is polled on every frame, while the snake only moves once the
        accumulated frame time reaches the current move delay. With
        `mega_hardcore` the delay shrinks by 2 for every point scored.

        Return
        ----------
        score: int
            The final score for the match (discounted of initial length).
        """
        palette = [(42, 42, 42), (152, 152, 152)]  # head to tail gradient
        previous_size = self.snake.length
        current_size = previous_size
        color_list = self.gradient(palette, previous_size)

        # The last valid key is remembered between moves so that a key press
        # made between two moves is not lost.
        last_key = self.snake.previous_action
        move_wait = var.GAME_SPEED
        elapsed = 0

        while not self.game_over:
            elapsed += self.fps.get_time()  # Time spent on the last frame.

            if mega_hardcore:
                move_wait = var.GAME_SPEED - (2 * (self.snake.length - 3))

            key_input = self.handle_input()

            if key_input is not None \
                    and not self.snake.is_movement_invalid(key_input):
                last_key = key_input

            if elapsed >= move_wait:  # Time to move and redraw.
                elapsed = 0
                self.game_over = self.play(last_key)
                current_size = self.snake.length

                if current_size > previous_size:
                    # The snake grew: rebuild the gradient for the new size.
                    color_list = self.gradient(palette, current_size)
                    previous_size = current_size

                self.draw(color_list)

            pygame.display.update()
            self.fps.tick(100)  # Limit FPS to 100

        return current_size - 3

    def check_collision(self):
        """Check whether the head hit a wall or another part of the body.

        Return
        ----------
        collided: boolean
            Whether the snake collided or not.
        """
        head_x = self.snake.head[0]
        head_y = self.snake.head[1]
        last_index = var.BOARD_SIZE - 1

        inside_board = (0 <= head_x <= last_index
                        and 0 <= head_y <= last_index)

        if not inside_board:
            logger.info('EVENT: WALL COLLISION')
            return True

        # body[0] is the head itself, so only the remaining parts count.
        if self.snake.head in self.snake.body[1:]:
            logger.info('EVENT: BODY COLLISION')
            return True

        return False

    def is_won(self):
        """Verify if the score is greater than 0.

        Return
        ----------
        won: boolean
            Whether the score is greater than 0.
        """
        return self.snake.length > 3

    def generate_food(self):
        """Generate new food if needed.

        Return
        ----------
        food_pos: tuple of 2 * int
            Current position of the food.
        """
        food_pos = self.food_generator.generate_food(self.snake.body)

        return food_pos

    def handle_input(self):
        """Read the currently pressed keys and translate them to an action.

        ESCAPE or Q opens the game-over screen. Otherwise the first pressed
        arrow key, checked in the order LEFT, RIGHT, UP, DOWN, is used.

        Return
        ----------
        action: int
            Action for the pressed arrow key, or None when no arrow key is
            pressed (also None after ESCAPE / Q).
        """
        pygame.event.set_allowed([pygame.QUIT, pygame.KEYDOWN])
        keys = pygame.key.get_pressed()
        pygame.event.pump()

        if keys[pygame.K_ESCAPE] or keys[pygame.K_q]:
            logger.info('ACTION: KEY PRESSED: ESCAPE or Q')
            # NOTE(review): the option picked on the game-over screen is
            # discarded and None is returned, so the caller keeps running the
            # current match afterwards - confirm this is intended.
            self.over(self.snake.length - 3)
            return None

        # (pygame key, log message, key in `actions`) in priority order.
        arrow_keys = ((pygame.K_LEFT, 'ACTION: KEY PRESSED: LEFT', 'LEFT'),
                      (pygame.K_RIGHT, 'ACTION: KEY PRESSED: RIGHT', 'RIGHT'),
                      (pygame.K_UP, 'ACTION: KEY PRESSED: UP', 'UP'),
                      (pygame.K_DOWN, 'ACTION: KEY PRESSED: DOWN', 'DOWN'))

        for key_code, message, direction in arrow_keys:
            if keys[key_code]:
                logger.info(message)
                return actions[direction]

        return None

    def eval_local_safety(self, canvas, body):
        """Flag which of the four cells around the head are deadly.

        A neighbour is deadly when it lies outside the board or is occupied
        by a body part other than the head. The flags are written to
        canvas[BOARD_SIZE - 1, 0..3] in the order +x, -x, -y, +y.

        Return
        ----------
        canvas: np.array of size BOARD_SIZE**2
            After using game expertise, change canvas values to DANGEROUS if true.
        """
        head_x, head_y = body[0][0], body[0][1]
        rest = body[1:]
        last_index = var.BOARD_SIZE - 1

        # (neighbour leaves the board, neighbour cell), one per flag slot.
        neighbours = ((head_x + 1 > last_index, [head_x + 1, head_y]),
                      (head_x - 1 < 0, [head_x - 1, head_y]),
                      (head_y - 1 < 0, [head_x, head_y - 1]),
                      (head_y + 1 > last_index, [head_x, head_y + 1]))

        for slot, (outside, cell) in enumerate(neighbours):
            if outside or cell in rest:
                canvas[last_index, slot] = point_type['DANGEROUS']

        return canvas

    def state(self):
        """Create a matrix of the current state of the game.

        Once the game is over an all-zero matrix is returned. Otherwise the
        body, the head, the optional local-safety flags and finally the food
        are written, in that order, so later writes win on shared cells.

        Return
        ----------
        canvas: np.array of size BOARD_SIZE**2
            Return the current state of the game in a matrix.
        """
        canvas = np.zeros((var.BOARD_SIZE, var.BOARD_SIZE))

        if self.game_over:
            return canvas

        body = self.snake.body

        for part in body:
            canvas[part[0], part[1]] = point_type['BODY']

        head = body[0]
        canvas[head[0], head[1]] = point_type['HEAD']

        if self.local_state:
            canvas = self.eval_local_safety(canvas, body)

        canvas[self.food_pos[0], self.food_pos[1]] = point_type['FOOD']

        return canvas

    def relative_to_absolute(self, action):
        """Translate a relative action into an absolute one.

        FORWARD keeps the snake's previous action. LEFT and every other
        value (treated as a right turn) rotate the previous action.

        Return
        ----------
        action: int
            Translated action from relative to absolute.
        """
        heading = self.snake.previous_action

        if action == relative_actions['FORWARD']:
            return heading

        # (current heading, heading after the turn); the fallback covers any
        # heading not listed, i.e. DOWN.
        if action == relative_actions['LEFT']:
            turns = (('LEFT', 'DOWN'), ('RIGHT', 'UP'), ('UP', 'LEFT'))
            fallback = 'RIGHT'
        else:
            turns = (('LEFT', 'UP'), ('RIGHT', 'DOWN'), ('UP', 'RIGHT'))
            fallback = 'LEFT'

        for current, turned in turns:
            if heading == actions[current]:
                return actions[turned]

        return actions[fallback]

    def play(self, action):
        """Move the snake to the direction, eat and check collision."""
        self.scored = False
        self.step += 1
        self.food_pos = self.generate_food()

        if self.relative_pos:
            action = self.relative_to_absolute(action)

        if self.snake.move(action, self.food_pos):
            self.scored = True
            self.food_generator.is_food_on_screen = False

        if self.player == "HUMAN":
            if self.check_collision():
                return True
        elif self.check_collision() or self.step > 50 * self.snake.length:
            self.game_over = True

    def get_reward(self):
        """Return the reward for the latest move.

        Game over yields rewards['GAME_OVER'], a move that scored yields the
        current snake length and any other move yields rewards['MOVE'].

        Return
        ----------
        reward: float
            Current reward of the game.
        """
        move_reward = rewards['MOVE']

        if self.game_over:
            return rewards['GAME_OVER']

        if self.scored:
            return self.snake.length

        return move_reward

    def gradient(self, colors, steps, components = 3):
        """Build a colour gradient through `colors` with about `steps` entries.

        Every pair of consecutive colours gets int(steps / (len(colors) - 1))
        entries, interpolated linearly per component. Use components = 4 for
        RGBA colours.

        Return
        ----------
        result: list of steps length of tuple of 3 * int (if RGBA, 4 * int)
            List of colors of calculated gradient from start to end.
        """
        def blend(start, finish, count):
            # The first entry is `start` itself; the rest are interpolated.
            yield start

            for i in range(1, count):
                ratio = float(i) / (count - 1)
                yield tuple(start[j] + ratio * (finish[j] - start[j])
                            for j in range(components))

        per_pair = int(float(steps) / (len(colors) - 1))

        # Two iterators over the colours, the second one shifted by one, give
        # the consecutive (start, finish) pairs.
        starts, finishes = tee(colors)
        next(finishes, None)

        return [color
                for start, finish in zip(starts, finishes)
                for color in blend(start, finish, per_pair)]

    def draw(self, color_list):
        """Paint the background, the snake, the food and the score caption.

        Body parts are paired with `color_list` entry by entry; each board
        cell is drawn as a BLOCK_SIZE square.
        """
        block = var.BLOCK_SIZE
        self.window.fill(pygame.Color(225, 225, 225))

        for part, color in zip(self.snake.body, color_list):
            cell = pygame.Rect(part[0] * block, part[1] * block, block, block)
            pygame.draw.rect(self.window, color, cell)

        food_cell = pygame.Rect(self.food_pos[0] * block,
                                self.food_pos[1] * block, block, block)
        pygame.draw.rect(self.window, var.FOOD_COLOR, food_cell)

        score = self.snake.length - 3
        pygame.display.set_caption("SNAKE GAME  |  Score: " + str(score))

def resource_path(relative_path):
    """Return the absolute path of a resource shipped with the game.

    When `sys._MEIPASS` exists (used while creating the .exe file) it is the
    base directory; otherwise the directory of this file is used.
    """
    try:
        base_dir = sys._MEIPASS
    except AttributeError:
        base_dir = path.dirname(path.realpath(__file__))

    return path.join(base_dir, relative_path)

# Module-level objects used by the Snake, FoodGenerator and Game classes above.
var = GlobalVariables() # Shared settings object (board size, block size, speed, ...)
logger = logging.getLogger(__name__) # Logger used for game events and key presses
environ['SDL_VIDEO_CENTERED'] = '1' # Ask SDL to center the game window

import numpy as np

from random import sample, uniform

class ExperienceReplay:
    """The class that handles memory and experiences replay.

    Attributes:
        memory: memory array to insert experiences.
        memory_size: the amount of experiences to be stored in the memory.
        input_shape: the shape of the input which will be stored.
        batch_function: returns targets according to S.
        per: flag for PER usage.
        per_epsilon: used to replace "0" probabilities cases.
        per_alpha: how much prioritization to use.
        per_beta: importance sampling weights (IS_weights).
    """
    def __init__(self, memory_size = 100, per = False, alpha = 0.6,
                 epsilon = 0.001, beta = 0.4, nb_epoch = 10000, decay = 0.5):
        """Initialize parameters and the memory array."""
        self.per = per
        self.memory_size = memory_size
        self.reset_memory() # Initiate the memory

        if self.per:
            self.per_epsilon = epsilon
            self.per_alpha = alpha
            self.per_beta = beta
            self.schedule = LinearSchedule(nb_epoch * decay, 1.0, beta)

    def exp_size(self):
        """Returns how much memory is stored."""
        if self.per:
            return self.exp
        else:
            return len(self.memory)

    def get_priority(self, errors):
        """Returns priority based on how much prioritization to use."""
        return (errors + self.per_epsilon) ** self.per_alpha

    def update(self, tree_indices, errors):
        """Update a list of nodes, based on their errors."""
        priorities = self.get_priority(errors)

        for index, priority in zip(tree_indices, priorities):
            self.memory.update(index, priority)

    def remember(self, s, a, r, s_prime, game_over):
        """Remember SARS' experiences, with the game_over parameter (done)."""
        if not hasattr(self, 'input_shape'):
            self.input_shape = s.shape[1:] # set attribute only once

        experience = np.concatenate([s.flatten(),
                                     np.array(a).flatten(),
                                     np.array(r).flatten(),
                                     s_prime.flatten(),
                                     1 * np.array(game_over).flatten()])

        if self.per: # If using PER, insert in the max_priority.
            max_priority = self.memory.max_leaf()

            if max_priority == 0:
                max_priority = self.get_priority(0)

            self.memory.insert(experience, max_priority)
            self.exp += 1
        else: # Else, just append the experience to the list.
            self.memory.append(experience)

            if self.memory_size > 0 and self.exp_size() > self.memory_size:
                self.memory.pop(0)

    def get_samples(self, batch_size):
        """Sample the memory according to PER flag."""
        if self.per:
            batch = [None] * batch_size
            IS_weights = np.zeros((batch_size, ))
            tree_indices = [0] * batch_size

            memory_sum = self.memory.sum()
            len_seg = memory_sum / batch_size
            min_prob = self.memory.min_leaf() / memory_sum

            for i in range(batch_size):
                val = uniform(len_seg * i, len_seg * (i + 1))
                tree_indices[i], priority, batch[i] = self.memory.retrieve(val)
                prob = priority / self.memory.sum()
                IS_weights[i] = np.power(prob / min_prob, -self.per_beta)

            return np.array(batch), IS_weights, tree_indices

        else:
            IS_weights = np.ones((batch_size, ))
            batch = sample(self.memory, batch_size)
            return np.array(batch), IS_weights, None

    def get_targets(self, target, model, batch_size, nb_actions, gamma = 0.9):
        """Function to sample, set batch function and use it for targets."""
        if self.exp_size() < batch_size:
            return None

        samples, IS_weights, tree_indices = self.get_samples(batch_size)
        input_dim = np.prod(self.input_shape) # Get the input shape, multiplied

        S = samples[:, 0 : input_dim] # Seperate the states
        a = samples[:, input_dim] # Separate the actions
        r = samples[:, input_dim + 1] # Separate the rewards
        S_prime = samples[:, input_dim + 2 : 2 * input_dim + 2] # Next_actions
        game_over = samples[:, 2 * input_dim + 2] # Separate terminal flags

        # Reshape the arrays to make them usable by the model.
        r = r.repeat(nb_actions).reshape((batch_size, nb_actions))
        game_over = game_over.repeat(nb_actions)\
                             .reshape((batch_size, nb_actions))
        S = S.reshape((batch_size, ) + self.input_shape)
        S_prime = S_prime.reshape((batch_size, ) + self.input_shape)

        X = np.concatenate([S, S_prime], axis = 0)
        Y = model.predict(X)

        if target is not None: # Use Double DQN logic:
            Qsa = [None] * 64
            actions = np.argmax(Y[batch_size:], axis = 1)
            Y_target = target.predict(X[batch_size:])
            for i in range(batch_size):
                Qsa[i] = Y_target[i][actions[i]]
            Qsa = np.array(Qsa).repeat(nb_actions).reshape((batch_size, nb_actions))

        else:
            Qsa = np.max(Y[batch_size:], axis = 1).repeat(nb_actions)\
                                                .reshape((batch_size, nb_actions))

        # The targets here already take into account
        delta = np.zeros((batch_size, nb_actions))
        a = np.cast['int'](a)
        delta[np.arange(batch_size), a] = 1
        targets = ((1 - delta) * Y[:batch_size]
                  + delta * (r + gamma * (1 - game_over) * Qsa))

        if self.per: # Update the Sum Tree with the absolute error.
            errors = np.abs((targets - Y[:batch_size]).max(axis = 1)).clip(max = 1.)
            self.update(tree_indices, errors)

        return S, targets, IS_weights

    def reset_memory(self):
        """Set the memory as a blank list."""
        if self.per:
            if self.memory_size <= 0:
                self.memory_size = 150000

            self.memory = SumTree(self.memory_size)
            self.exp = 0
        else:
            self.memory = []


#!/usr/bin/env python

"""dqn: First try to create an AI for SnakeGame. Is it good enough?

This algorithm is a implementation of DQN, Double DQN logic (using a target
network to have fixed Q-targets), Dueling DQN logic (Q(s,a) = Advantage + Value),
PER (Prioritized Experience Replay, using Sum Trees) and Multi-step returns. You
can read more about these on https://goo.gl/MctLzp

Implemented algorithms
----------
    * Simple Deep Q-network (DQN with ExperienceReplay);
        Paper: https://arxiv.org/abs/1312.5602
    * Double Deep Q-network (Double DQN);
        Paper: https://arxiv.org/abs/1509.06461
    * Dueling Deep Q-network (Dueling DQN);
        Paper: https://arxiv.org/abs/1511.06581
    * Prioritized Experience Replay (PER);
        Paper: https://arxiv.org/abs/1511.05952
    * Multi-step returns.
        Paper: https://arxiv.org/pdf/1703.01327

Arguments
----------
--load: 'file.h5'
    Load a previously trained model in '.h5' format.
--board_size: int, optional, default = 10
    Assign the size of the board.
--nb_frames: int, optional, default = 4
    Assign the number of frames per stack, default = 4.
--nb_actions: int, optional, default = 5
    Assign the number of actions possible.
--update_freq: int or float, optional, default = 0.001
    Whether to soft or hard update the target. Epochs or amount of the update.
--visual: boolean, optional, default = False
    Select whether or not to draw the game in pygame.
--double: boolean, optional, default = False
    Use a target network with double DQN logic.
--dueling: boolean, optional, default = False
    Whether to use dueling network logic, Q(s,a) = A + V.
--per: boolean, optional, default = False
    Use Prioritized Experience Replay (based on Sum Trees).
--local_state: boolean, optional, default = True
    Verify if possible next moves are dangerous (field expertise).
"""

import numpy as np
from array import array
from os import path, environ, sys
import random
import inspect

# Making relative imports from parallel folders possible
currentdir = path.dirname(path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = path.dirname(currentdir)
sys.path.insert(0, parentdir)

from keras.optimizers import RMSprop, Nadam
from keras.models import load_model, Sequential
from keras.layers import *
from keras import backend as K

__author__ = "Victor Neves"
__license__ = "MIT"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"
__status__ = "Production"

K.set_image_dim_ordering('th')  # Setting keras ordering

class Agent:
    """Agent based in a simple DQN that can read states, remember and play.

    Attributes
    ----------
    memory: object
        Memory used in training. ExperienceReplay or PrioritizedExperienceReplay
    memory_size: int, optional, default = -1
        Capacity of the memory used.
    model: keras model
        The input model in Keras.
    target: keras model, optional, default = None
        The target model, used to calculate the fixed Q-targets.
    nb_frames: int, optional, default = 4
        Amount of frames for each experience (sars).
    board_size: int, optional, default = 10
        Size of the board used.
    frames: list of experiences
        The buffer of frames, store sars experiences.
    per: boolean, optional, default = False
        Flag for PER usage.
    update_target_freq: int or float, default = 0.001
        Whether soft or hard updates occur. If < 1, soft updated target model.
    n_steps: int, optional, default = 1
        Size of the rewards buffer, to use Multi-step returns.
    """
    def __init__(self, model, target = None, memory_size = -1, nb_frames = 4,
                 board_size = 10, per = False, update_target_freq = 0.001):
        """Initialize the agent with given attributes."""
        if per:
            self.memory = PrioritizedExperienceReplay(memory_size = memory_size)
        else:
            self.memory = ExperienceReplay(memory_size = memory_size)

        self.per = per
        self.model = model
        self.target = target
        self.nb_frames = nb_frames
        self.board_size = board_size
        self.update_target_freq = update_target_freq
        self.clear_frames()

    def reset_memory(self):
        """Reset memory if necessary."""
        self.memory.reset_memory()

    def get_game_data(self, game):
        """Push the current game state into the frame buffer and return it.

        On the first call after clear_frames() the buffer is filled with
        nb_frames copies of the current frame; afterwards the newest frame is
        appended and the oldest one dropped.

        Return
        ----------
        expanded_frames: numpy array
            The buffer of frames with a leading batch axis of size 1.
        """
        frame = game.state()

        if self.frames is None:
            self.frames = [frame] * self.nb_frames
        else:
            self.frames.append(frame)
            self.frames.pop(0)

        expanded_frames = np.expand_dims(self.frames, 0)

        return expanded_frames

    def clear_frames(self):
        """Reset frames to restart appending."""
        self.frames = None

    def update_target_model_hard(self):
        """Update the target model with the main model's weights."""
        self.target.set_weights(self.model.get_weights())

    def transfer_weights(self):
        """Transfer Weights from Model to Target at rate update_target_freq.

        Each target weight becomes
        rate * model_weight + (1 - rate) * target_weight.
        """
        model_weights = self.model.get_weights()
        target_weights = self.target.get_weights()
        rate = self.update_target_freq

        # Iterate over the model's weight list (the original looped over an
        # undefined name and read a misspelled attribute).
        for i in range(len(model_weights)):
            target_weights[i] = (rate * model_weights[i]
                                 + (1 - rate) * target_weights[i])

        self.target.set_weights(target_weights)

    def print_metrics(self, epoch, nb_epoch, history_size, policy, value,
                      win_count, history_step, history_reward,
                      history_loss = None, verbose = 1):
        """Function to print metrics of training steps.

        Parameters
        ----------
        epoch, nb_epoch: int
            Zero-based index of the current epoch and total number of epochs.
        history_size, history_step, history_reward, history_loss: sequences
            Per-episode snake size, step count, total reward and loss.
        policy: str
            Name of the exploration policy; selects the label shown for value.
        value: float
            Policy metric returned by the policy's select_action.
        win_count: int
            Number of episodes won so far.
        verbose: int
            0 prints nothing, 1 a one-line summary, anything else the
            detailed report.
        """
        if verbose == 0:
            return

        win_percentage = 100 * win_count/(epoch + 1)

        if verbose == 1:
            text_epoch = ('Epoch: {:03d}/{:03d} | Mean size 10: {:.1f} | '
                           + 'Longest 10: {:03d} | Mean steps 10: {:.1f} | '
                           + 'Wins: {:d} | Win percentage: {:.1f}%')
            print(text_epoch.format(epoch + 1, nb_epoch,
                                    np.mean(history_size[-10:]),
                                    max(history_size[-10:]),
                                    np.mean(history_step[-10:]),
                                    win_count, win_percentage))
            return

        text_epoch = 'Epoch: {:03d}/{:03d}'  # Print epoch info
        print(text_epoch.format(epoch + 1, nb_epoch))

        # Steps are stored as floats by train(); avoid dividing by zero when
        # an episode ended without any step.
        last_step = history_step[-1]
        safe_step = last_step if last_step else 1

        # Print training performance (only when losses were recorded)
        if history_loss is not None and len(history_loss) > 0:
            text_train = ('\t\x1b[0;30;47m' + ' Training metrics ' + '\x1b[0m'
                          + '\tTotal loss: {:.4f} | Loss per step: {:.4f} | '
                          + 'Mean loss - 100 episodes: {:.4f}')
            print(text_train.format(history_loss[-1],
                                    history_loss[-1] / safe_step,
                                    np.mean(history_loss[-100:])))

        # NOTE(review): the "Steps per food eaten" figure is size / steps, as
        # in the original code; confirm whether the inverse was intended.
        text_game = ('\t\x1b[0;30;47m' + ' Game metrics ' + '\x1b[0m'
                     + '\t\tSize: {:d} | Ammount of steps: {:d} | '
                     + 'Steps per food eaten: {:.1f} | '
                     + 'Mean size - 100 episodes: {:.1f}')
        # '{:d}' rejects floats, so the step count is converted explicitly;
        # the mean size is computed from the sizes, not from the steps.
        print(text_game.format(history_size[-1], int(last_step),
                               history_size[-1] / safe_step,
                               np.mean(history_size[-100:])))

        # Print policy metrics; only the label of `value` depends on policy
        if policy == "BoltzmannQPolicy":
            text_value = '\tBoltzmann Temperature: {:.2f} | '
        elif policy == "BoltzmannGumbelQPolicy":
            text_value = '\tNumber of actions: {:.0f} | '
        else:
            text_value = '\tEpsilon: {:.2f} | '

        text_policy = ('\t\x1b[0;30;47m' + ' Policy metrics ' + '\x1b[0m'
                       + text_value
                       + 'Episode reward: {:.1f} | Wins: {:d} | '
                       + 'Win percentage: {:.1f}%')
        print(text_policy.format(value, history_reward[-1], win_count,
                                 win_percentage))

    def train_model(self, model, target, batch_size, gamma, nb_actions, epoch = 0):
        """Function to train the model on a batch of the data.

        The batch always comes from self.memory and is built with self.model
        and self.target; the model, target and epoch arguments are accepted
        for compatibility but not used.

        Return
        ----------
        loss: float
            Training loss of given batch, 0. when no batch was available.
        """
        loss = 0.
        batch = self.memory.get_targets(model = self.model,
                                        target = self.target,
                                        batch_size = batch_size,
                                        gamma = gamma,
                                        nb_actions = nb_actions)

        if batch:
            inputs, targets, IS_weights = batch

            if inputs is not None and targets is not None:
                loss = float(self.model.train_on_batch(inputs,
                                                       targets,
                                                       IS_weights))

        return loss

    def train(self, game, nb_epoch = 10000, batch_size = 64, gamma = 0.95,
              eps = [1., .01], temp = [1., 0.01], learning_rate = 0.5,
              observe = 0, optim_rounds = 1, policy = "EpsGreedyQPolicy",
              verbose = 1, n_steps = 1):
        """The main training function, loops the game, remember and choose best
        action given game state (frames).

        Parameters
        ----------
        game: object
            The game to play; must expose the methods and attributes used
            below (reset_game, play, get_reward, game_over, ...).
        nb_epoch: int
            Number of episodes per round.
        batch_size: int
            Size of each training batch requested from the memory.
        gamma: float
            Discount factor.
        eps, temp: [max, min]
            Start and end values handed to the epsilon / temperature policy.
            The default lists are only read, never modified.
        learning_rate: float
            Fraction of nb_epoch over which the policy value is annealed.
        observe: int
            Number of initial epochs played without training.
        optim_rounds: int
            Round 0 explores and trains; each later round only optimizes.
        policy: str
            "BoltzmannQPolicy", "BoltzmannGumbelQPolicy" or anything else for
            the epsilon-greedy policy.
        verbose: int
            Forwarded to print_metrics.
        n_steps: int
            Number of rewards combined in the multi-step return.
        """
        if not hasattr(self, 'n_steps'):
            self.n_steps = n_steps  # Set attribute only once

        history_size = array('i')  # Holds all the sizes
        history_step = array('f')  # Holds all the steps
        history_loss = array('f')  # Holds all the losses
        history_reward = array('f')  # Holds all the rewards

        # Select exploration policy. EpsGreedyQPolicy runs faster, but takes
        # longer to converge. BoltzmannGumbelQPolicy is the slowest, but
        # converge really fast (0.1 * nb_epoch used in EpsGreedyQPolicy).
        # BoltzmannQPolicy is in the middle.
        if policy == "BoltzmannQPolicy":
            q_policy = BoltzmannQPolicy(temp[0], temp[1], nb_epoch * learning_rate)
        elif policy == "BoltzmannGumbelQPolicy":
            q_policy = BoltzmannGumbelQPolicy()
        else:
            q_policy = EpsGreedyQPolicy(eps[0], eps[1], nb_epoch * learning_rate)

        nb_actions = game.nb_actions
        win_count = 0
        value = 0.  # Policy metric; defined even if an episode has no step

        # If optim_rounds is bigger than one, the model will keep optimizing
        # after the exploration, in turns of nb_epoch size.
        for turn in range(optim_rounds):
            if turn > 0:
                for epoch in range(nb_epoch):
                    loss = self.train_model(model = self.model,
                                            epoch = epoch,
                                            target = self.target,
                                            batch_size = batch_size,
                                            gamma = gamma,
                                            nb_actions = nb_actions)
                text_optim = ('Optimizer turn: {:2d} | Epoch: {:03d}/{:03d}'
                              + '| Loss: {:.4f}')
                print(text_optim.format(turn, epoch + 1, nb_epoch, loss))
            else:  # Exploration and training
                for epoch in range(nb_epoch):
                    loss = 0.
                    total_reward = 0.
                    game.reset_game()
                    self.clear_frames()
                    S = self.get_game_data(game)

                    if n_steps > 1:  # Create multi-step returns buffer.
                        n_step_buffer = array('f')

                    while not game.game_over:  # Main loop, until game_over
                        game.food_pos = game.generate_food()
                        action, value = q_policy.select_action(self.model,
                                                               S, epoch,
                                                               nb_actions)
                        game.play(action)
                        r = game.get_reward()
                        total_reward += r

                        if n_steps > 1:
                            n_step_buffer.append(r)

                            # Keep a sliding window of the latest n_steps
                            # rewards; without this the sum below would use
                            # the first rewards of the episode forever.
                            # NOTE(review): the return is still stored with
                            # the current state, as in the original code.
                            if len(n_step_buffer) > n_steps:
                                n_step_buffer.pop(0)

                            if len(n_step_buffer) < n_steps:
                                R = r
                            else:
                                R = sum(n_step_buffer[i] * (gamma ** i)
                                        for i in range(n_steps))
                        else:
                            R = r

                        S_prime = self.get_game_data(game)
                        experience = [S, action, R, S_prime, game.game_over]
                        self.memory.remember(*experience)  # Add to the memory
                        S = S_prime  # Advance to the next state (stack of S)

                        if epoch >= observe:  # Get the batchs and train
                            loss += self.train_model(model = self.model,
                                                     target = self.target,
                                                     batch_size = batch_size,
                                                     gamma = gamma,
                                                     nb_actions = nb_actions)

                    if game.is_won():
                        win_count += 1  # Counter of wins for metrics

                    if self.per:  # Advance beta, used in PER
                        self.memory.beta = self.memory.schedule.value(epoch)

                    if self.target is not None:  # Update the target model
                        # The frequency lives on the instance; train() has no
                        # parameter or local with that name.
                        if self.update_target_freq >= 1: # Hard updates
                            if epoch % self.update_target_freq == 0:
                                self.update_target_model_hard()
                        elif self.update_target_freq < 1.:  # Soft updates
                            self.transfer_weights()

                    history_size.append(game.snake.length)
                    history_step.append(game.step)
                    history_loss.append(loss)
                    history_reward.append(total_reward)

                    if (epoch + 1) % 10 == 0:
                        self.print_metrics(epoch = epoch, nb_epoch = nb_epoch,
                                           history_size = history_size,
                                           history_loss = history_loss,
                                           history_step = history_step,
                                           history_reward = history_reward,
                                           policy = policy, value = value,
                                           win_count = win_count,
                                           verbose = verbose)

    def play(self, game, nb_epoch = 1000, eps = 0.01, temp = 0.01,
             visual = False, policy = "GreedyQPolicy"):
        """Play the game with the trained agent and print summary statistics.

        Parameters
        ----------
        game: object
            The game to play.
        nb_epoch: int
            Number of episodes to play.
        eps, temp: float
            Constant epsilon / temperature used by the matching policy.
        visual: boolean
            Can be used to draw the game in pygame.
        policy: str
            "BoltzmannQPolicy", "EpsGreedyQPolicy" or anything else for the
            greedy policy.
        """
        win_count = 0
        # Taken from the game, as train() does, instead of relying on a
        # global variable being defined by the caller.
        nb_actions = game.nb_actions

        history_size = array('i')  # Holds all the sizes
        history_step = array('f')  # Holds all the steps
        history_reward = array('f')  # Holds all the rewards

        if policy == "BoltzmannQPolicy":
            q_policy = BoltzmannQPolicy(temp, temp, nb_epoch)
        elif policy == "EpsGreedyQPolicy":
            q_policy = EpsGreedyQPolicy(eps, eps, nb_epoch)
        else:
            q_policy = GreedyQPolicy()

        for epoch in range(nb_epoch):
            game.reset_game()
            self.clear_frames()
            S = self.get_game_data(game)

            if visual:
                # Centering the window. Set before the window is created so
                # it can take effect (assumes create_window opens the display).
                environ['SDL_VIDEO_CENTERED'] = '1'
                game.create_window()
                previous_size = game.snake.length  # Initial size of the snake
                color_list = game.gradient([(42, 42, 42), (152, 152, 152)],
                                           previous_size)

            # The main loop: pick an action, play it and redraw if requested.
            while not game.game_over:
                action, value = q_policy.select_action(self.model, S, epoch,
                                                       nb_actions)
                game.play(action)
                current_size = game.snake.length  # Update the body size

                if visual:
                    game.draw(color_list)

                    if current_size > previous_size:
                        color_list = game.gradient([(42, 42, 42), (152, 152, 152)],
                                                   game.snake.length)

                        previous_size = current_size

                S = self.get_game_data(game)

                if game.game_over:
                    history_size.append(current_size)
                    history_step.append(game.step)
                    history_reward.append(game.get_reward())

            if game.is_won():
                win_count += 1

        print("Accuracy: {} %".format(100. * win_count / nb_epoch))
        print("Mean size: {} | Biggest size: {} | Smallest size: {}"
              .format(np.mean(history_size), np.max(history_size),
                      np.min(history_size)))
        print("Mean steps: {} | Biggest step: {} | Smallest step: {}"
              .format(np.mean(history_step), np.max(history_step),
                      np.min(history_step)))
        print("Mean rewards: {} | Biggest reward: {} | Smallest reward: {}"
              .format(np.mean(history_reward), np.max(history_reward),
                      np.min(history_reward)))
       

import random
import numpy as np

class LinearSchedule(object):
    """Linearly anneal a value from initial_p to final_p.

    The interpolation runs over schedule_timesteps; once that many timesteps
    have passed, final_p is returned.
    """
    def __init__(self, schedule_timesteps, final_p, initial_p):
        """Store the annealing horizon and both end points.

        Parameters
        ----------
        schedule_timesteps: int
            Number of timesteps for which to linearly anneal initial_p
            to final_p
        final_p: float
            final output value
        initial_p: float
            initial output value
        """
        self.initial_p = initial_p
        self.final_p = final_p
        self.schedule_timesteps = schedule_timesteps

    def value(self, t):
        """Return the scheduled value at timestep t."""
        progress = float(t) / self.schedule_timesteps
        if progress > 1.0:  # Past the horizon: stay at final_p
            progress = 1.0

        span = self.final_p - self.initial_p
        return self.initial_p + progress * span


class GreedyQPolicy:
    """Implement the greedy policy

    Greedy policy always takes current best action.
    """
    def __init__(self):
        super(GreedyQPolicy, self).__init__()

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: Model whose predict(state) gives the Q estimations
            state: Input handed to model.predict
            epoch: Unused, kept for a uniform policy interface
            nb_actions: Unused, kept for a uniform policy interface
        # Returns
            Tuple (action, 0): index of the highest Q value and a
            placeholder policy metric
        """
        q = model.predict(state)
        action = int(np.argmax(q[0]))

        return action, 0

    def get_config(self):
        """Return configurations of GreedyQPolicy
        # Returns
            Dict of config (empty, this policy has no parameters)
        """
        # The class has no policy base class, so there is no parent
        # get_config() to extend; calling it raised AttributeError.
        return {}


class EpsGreedyQPolicy:
    """Implement the epsilon greedy policy

    Eps Greedy policy either:

    - takes a random action with probability epsilon
    - takes current best action with prob (1 - epsilon)

    Epsilon is annealed linearly from max_eps to min_eps over nb_epoch epochs.
    """
    def __init__(self, max_eps=1., min_eps = .01, nb_epoch = 10000):
        super(EpsGreedyQPolicy, self).__init__()
        self.schedule = LinearSchedule(nb_epoch, min_eps, max_eps)
        self.eps = max_eps  # Defined before the first action is selected

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: Model whose predict(state) gives the Q estimations
            state: Input handed to model.predict
            epoch: Current epoch, used to look up epsilon in the schedule
            nb_actions: Number of actions to choose from when exploring
        # Returns
            Tuple (action, eps): selected action and the epsilon used
        """
        self.eps = self.schedule.value(epoch)

        if random.random() < self.eps:
            # Draw the action independently of the exploration test. Reusing
            # the same draw (already known to be below eps) always produced
            # the lowest actions once eps became small.
            action = random.randrange(nb_actions)
        else:
            q = model.predict(state)
            action = int(np.argmax(q[0]))

        return action, self.eps

    def get_config(self):
        """Return configurations of EpsGreedyQPolicy
        # Returns
            Dict of config
        """
        # No policy base class exists here, so the config starts empty
        # instead of calling a missing parent get_config().
        config = {}
        config['eps'] = self.eps
        return config


class BoltzmannQPolicy:
    """Implement the Boltzmann Q Policy
    Boltzmann Q Policy builds a probability law on q values and returns
    an action selected randomly according to this law.

    The temperature is annealed linearly from max_temp to min_temp over
    nb_epoch epochs.
    """
    def __init__(self, max_temp = 1., min_temp = .01, nb_epoch = 10000, clip = (-500., 500.)):
        super(BoltzmannQPolicy, self).__init__()
        self.schedule = LinearSchedule(nb_epoch, min_temp, max_temp)
        # Stored and reported by get_config only; select_action avoids
        # overflow by subtracting the maximum instead of clipping.
        self.clip = clip
        self.temp = max_temp  # Defined before the first action is selected

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: Model whose predict(state) gives the Q estimations
            state: Input handed to model.predict
            epoch: Current epoch, used to look up the temperature
            nb_actions: Number of actions to sample from
        # Returns
            Tuple (action, temp): sampled action and the temperature used
        """
        # Work in float64, as BoltzmannGumbelQPolicy does, so the
        # probabilities handed to np.random.choice sum to one accurately.
        q = model.predict(state)[0].astype('float64')
        self.temp = self.schedule.value(epoch)
        arg = q / self.temp

        # Softmax with the maximum subtracted for numerical stability
        exp_values = np.exp(arg - arg.max())
        probs = exp_values / exp_values.sum()
        action = int(np.random.choice(range(nb_actions), p = probs))

        return action, self.temp

    def get_config(self):
        """Return configurations of BoltzmannQPolicy
        # Returns
            Dict of config
        """
        # No policy base class exists here, so the config starts empty
        # instead of calling a missing parent get_config().
        config = {}
        config['temp'] = self.temp
        config['clip'] = self.clip
        return config


class BoltzmannGumbelQPolicy:
    """Implements Boltzmann-Gumbel exploration (BGE) adapted for Q learning
    based on the paper Boltzmann Exploration Done Right
    (https://arxiv.org/pdf/1705.10257.pdf).
    BGE is invariant with respect to the mean of the rewards but not their
    variance. The parameter C, which defaults to 1, can be used to correct for
    this, and should be set to the least upper bound on the standard deviation
    of the rewards.
    BGE is only available for training, not testing. For testing purposes, you
    can achieve approximately the same result as BGE after training for N steps
    on K actions with parameter C by using the BoltzmannQPolicy and setting
    tau = C/sqrt(N/K)."""

    def __init__(self, C = 1.0):
        super(BoltzmannGumbelQPolicy, self).__init__()
        self.C = C
        self.action_counts = None
        self._last_epoch = None  # Epoch seen by the previous select_action

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: Model whose predict(state) gives the Q estimations
            state: Input handed to model.predict
            epoch: Current epoch; a return to epoch 0 restarts the counts
            nb_actions: Unused, kept for a uniform policy interface
        # Returns
            Tuple (action, total): selected action and the sum of the
            per-action counters
        """
        q = model.predict(state)[0]
        q = q.astype('float64')

        # The counts are (re)created when they do not exist yet, when their
        # shape no longer matches, or when a new training run starts (epoch
        # is 0 again after a later epoch). Resetting on every call with
        # epoch == 0 would throw the counts away at each step of the first
        # episode, and a first call with epoch != 0 would find no counts.
        starting_new_run = epoch == 0 and self._last_epoch != 0
        if (self.action_counts is None or starting_new_run
                or self.action_counts.shape != q.shape):
            self.action_counts = np.ones(q.shape)

        self._last_epoch = epoch

        beta = self.C/np.sqrt(self.action_counts)
        Z = np.random.gumbel(size = q.shape)

        perturbation = beta * Z
        perturbed_q_values = q + perturbation
        action = int(np.argmax(perturbed_q_values))

        self.action_counts[action] += 1
        return action, np.sum(self.action_counts)

    def get_config(self):
        """Return configurations of BoltzmannGumbelQPolicy
        # Returns
            Dict of config
        """
        # No policy base class exists here, so the config starts empty
        # instead of calling a missing parent get_config().
        config = {}
        config['C'] = self.C
        return config

#!/usr/bin/env python

"""clipped_error: L2 for errors < clip_value else L1 error.

Functions:
    huber_loss: Return L2 (squared) error if absolute error is less than
                clip_value, else return L1 (linear) error.
    clipped_error: Call huber_loss with default clip_value to 1.0.
"""

import numpy as np
from keras import backend as K
import tensorflow as tf

__author__ = "Victor Neves"
__license__ = "MIT"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"

def huber_loss(y_true, y_pred, clip_value):
	"""Element-wise Huber loss between y_true and y_pred.

	The loss is quadratic (.5 * x^2) where the absolute error is below
	clip_value and linear (clip_value * (|x| - .5 * clip_value)) elsewhere.
	See https://en.wikipedia.org/wiki/Huber_loss and
	https://medium.com/@karpathy/yes-you-should-understand-backprop-e2f06eab496b
	for details.

	Raises RuntimeError when the Keras backend is neither tensorflow nor
	theano.
	"""
	assert clip_value > 0.

	error = y_true - y_pred

	# Special case for infinity since Tensorflow does have problems
	# if we compare `K.abs(x) < np.inf`: the loss is purely quadratic.
	if np.isinf(clip_value):
		return .5 * K.square(error)

	is_small = K.abs(error) < clip_value
	quadratic = .5 * K.square(error)
	linear = clip_value * (K.abs(error) - .5 * clip_value)

	if K.backend() == 'tensorflow':
		# Use tf.select when this TensorFlow build has it, tf.where otherwise
		choose = tf.select if hasattr(tf, 'select') else tf.where
		return choose(is_small, quadratic, linear)  # condition, true, false

	if K.backend() == 'theano':
		from theano import tensor as T
		return T.switch(is_small, quadratic, linear)

	raise RuntimeError('Unknown backend "{}".'.format(K.backend()))

def clipped_error(y_true, y_pred):
	"""Mean over the last axis of the Huber loss with clip_value = 1."""
	per_element = huber_loss(y_true, y_pred, clip_value = 1.)
	return K.mean(per_element, axis = -1)

#def CNN1(optimizer, loss, stack, input_size, output_size):
 #   model = Sequential()
  #  model.add(Conv2D(32, (3, 3), activation = 'relu', input_shape = (stack,
   #                                                                  input_size,
    #                                                                 input_size)))
#    model.add(Conv2D(64, (3, 3), activation = 'relu'))
 #   model.add(Conv2D(128, (3, 3), activation = 'relu'))
  #  model.add(Conv2D(256, (3, 3), activation = 'relu'))
   # model.add(Flatten())
    #model.add(Dense(1024, activation = 'relu'))
    #model.add(Dense(output_size))
    #model.compile(optimizer = optimizer, loss = loss)

    #return model
    
def CNN4(optimizer, loss, stack, input_size, output_size):
    """Build and compile the convolutional Q-network.

    From @Kaixhin implementation's of the Rainbow paper.

    The network takes an input of shape (stack, input_size, input_size), runs
    it through three convolutions and a 3136-unit dense layer, and ends in a
    dense layer with output_size units and no activation.

    Return
    ----------
    network: keras model
        The model compiled with the given optimizer and loss.
    """
    layer_stack = [
        Conv2D(32, (4, 4), activation = 'relu',
               input_shape = (stack, input_size, input_size)),
        Conv2D(64, (2, 2), activation = 'relu'),
        Conv2D(64, (2, 2), activation = 'relu'),
        Flatten(),
        Dense(3136, activation = 'relu'),
        Dense(output_size),
    ]

    network = Sequential()
    for layer in layer_stack:
        network.add(layer)

    network.compile(optimizer = optimizer, loss = loss)

    return network

# Board side length and number of stacked frames; both are forwarded to the
# network builder and to the agent below.
board_size = 10
nb_frames = 4
  
# Game is not defined in this cell; it must already exist in the kernel.
game = Game(player = "ROBOT", board_size = board_size,
                        local_state = True, relative_pos = False)

# Build the Q-network: nb_frames stacked boards in, one output per action.
model = CNN4(optimizer = RMSprop(), loss = clipped_error,
                            stack = nb_frames, input_size = board_size,
                            output_size = game.nb_actions)

# No target network is used in this run.
target = None

# Agent without prioritized replay (per = False).
# NOTE(review): memory_size = -1 presumably means "unbounded" -- confirm
# against ExperienceReplay.
agent = Agent(model = model, target = target, memory_size = -1,
                          nb_frames = nb_frames, board_size = board_size,
                          per = False, update_target_freq = 0.001)

# Profiling variant of the run below (needs the line_profiler extension).
# NOTE(review): Agent.train has no update_target_freq parameter, so this line
# would fail if uncommented as-is.
#%lprun -f agent.train agent.train(game, batch_size = 64, nb_epoch = 10, gamma = 0.95, update_target_freq = 500, policy = "EpsGreedyQPolicy")
agent.train(game, batch_size = 64, nb_epoch = 10000, gamma = 0.95, policy = "EpsGreedyQPolicy")        

Using TensorFlow backend.


Epoch: 010/10000 | Mean size 10: 3.3 | Longest 10: 005 | Mean steps 10: 14.9 | Wins: 2 | Win percentage: 20.0%
Epoch: 020/10000 | Mean size 10: 3.0 | Longest 10: 003 | Mean steps 10: 6.0 | Wins: 2 | Win percentage: 10.0%
Epoch: 030/10000 | Mean size 10: 3.1 | Longest 10: 004 | Mean steps 10: 10.8 | Wins: 3 | Win percentage: 10.0%
Epoch: 040/10000 | Mean size 10: 3.1 | Longest 10: 004 | Mean steps 10: 12.5 | Wins: 4 | Win percentage: 10.0%
Epoch: 050/10000 | Mean size 10: 3.2 | Longest 10: 004 | Mean steps 10: 11.0 | Wins: 6 | Win percentage: 12.0%
Epoch: 060/10000 | Mean size 10: 3.1 | Longest 10: 004 | Mean steps 10: 11.7 | Wins: 7 | Win percentage: 11.7%
Epoch: 070/10000 | Mean size 10: 3.3 | Longest 10: 004 | Mean steps 10: 13.9 | Wins: 10 | Win percentage: 14.3%
Epoch: 080/10000 | Mean size 10: 3.3 | Longest 10: 004 | Mean steps 10: 13.4 | Wins: 13 | Win percentage: 16.2%
Epoch: 090/10000 | Mean size 10: 3.1 | Longest 10: 004 | Mean steps 10: 12.8 | Wins: 14 | Win percentage: 15.6%

In [0]:
# Save the trained network to an HDF5 file.
model.save('keras.h5')

# Zip the saved model and download it through the Colab file helper.
!zip -r model-epsgreedy-bench.zip keras.h5 
from google.colab import files
files.download('model-epsgreedy-bench.zip')
# Reload the saved model; the custom loss is passed explicitly through
# custom_objects under the name it was saved with.
model = load_model('keras.h5', custom_objects={'clipped_error': clipped_error})

board_size = 10
nb_frames = 4
nb_actions = 5

# No target network is needed for playing.
target = None

# Fresh agent wrapping the reloaded model.
agent = Agent(model = model, target = target, memory_size = 1500000,
                          nb_frames = nb_frames, board_size = board_size,
                          per = False)
#%lprun -f agent.train agent.train(game, batch_size = 64, nb_epoch = 10, gamma = 0.95, update_target_freq = 500, policy = "EpsGreedyQPolicy")

# Evaluate without rendering. `game` is not created in this cell; it is reused
# from an earlier one.
agent.play(game, visual = False, nb_epoch = 10000)

  adding: keras.h5 (deflated 42%)
Accuracy: 100.0 %
Mean size: 18.5717 | Biggest size: 43 | Smallest size: 4
Mean steps: 143.53480529785156 | Biggest step: 491.0 | Smallest step: 6.0
Mean rewards: -1.0 | Biggest reward: -1.0 | Smallest reward: -1.0


# Testing what changed between the modules
Same setup as the benchmark, but with a changed memory file.

In [0]:
#!/usr/bin/env python

"""SnakeGame: A simple and fun exploration, meant to be used by AI algorithms.
"""

import sys # To close the window when the game is over
from os import environ, path # To center the game window the best possible
import random # Random numbers used for the food
import logging # Logging function for movements and errors
from itertools import tee # For the color gradient on snake
import numpy as np

__author__ = "Victor Neves"
__license__ = "MIT"
__version__ = "1.0"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"
__status__ = "Production"

# Actions, options and forbidden moves
# options: integer codes for the menu/game options.
options = {'QUIT': 0, 'PLAY': 1, 'BENCHMARK': 2, 'LEADERBOARDS': 3, 'MENU': 4, 'ADD_LEADERBOARDS': 5}
# relative_actions: integer codes for moves relative to the current heading.
relative_actions = {'LEFT': 0, 'FORWARD': 1, 'RIGHT': 2}
# actions: integer codes for absolute moves; Snake.move compares against these.
actions = {'LEFT': 0, 'RIGHT': 1, 'UP': 2, 'DOWN': 3, 'IDLE': 4}
# forbidden_moves: (action, previous_action) pairs that Snake.move ignores.
# Each pair combines a direction with its opposite (LEFT/RIGHT, UP/DOWN).
forbidden_moves = [(0, 1), (1, 0), (2, 3), (3, 2)]

# Types of point in the board
point_type = {'EMPTY': 0, 'FOOD': 1, 'BODY': 2, 'HEAD': 3, 'DANGEROUS': 4}

class GlobalVariables:
    """Shared settings used while drawing and moving the snake game.

    Attributes:
        BOARD_SIZE: Size of the board (presumably in blocks per side).
        BLOCK_SIZE: The size in pixels of a block.
        HEAD_COLOR: Color of the head.
        BODY_COLOR: Color of the body.
        FOOD_COLOR: Color of the food.
        GAME_SPEED: Speed in ticks of the game. The higher the faster.
        BENCHMARK: Benchmark setting; its exact meaning is not visible here.
    """
    def __init__(self):
        """Set every setting to its default value."""
        # Geometry
        self.BOARD_SIZE, self.BLOCK_SIZE = 30, 20

        # Colors
        self.HEAD_COLOR = (0, 0, 0)
        self.BODY_COLOR = (0, 200, 0)
        self.FOOD_COLOR = (200, 0, 0)

        # Timing and benchmarking
        self.GAME_SPEED, self.BENCHMARK = 10, 10

        board_too_big = self.BOARD_SIZE > 50
        if board_too_big:
            logger.warning('WARNING: BOARD IS TOO BIG, IT MAY RUN SLOWER.')

class TextBlock:
    """A piece of text rendered with pygame and centered at a position.

    Blocks of type "menu" swap their text color and gain a background when
    hovered; every other type is drawn in a single dark color.
    """
    def __init__(self, text, pos, screen, scale = (1 / 12), type = "text"):
        """Store the settings, then compute the rectangle and draw once."""
        self.text = text
        self.pos = pos
        self.screen = screen
        self.scale = scale
        self.type = type
        self.hovered = False
        self.set_rect()
        self.draw()

    def draw(self):
        """Render the text again and blit it onto the screen."""
        self.set_rend()
        self.screen.blit(self.rend, self.rect)

    def set_rend(self):
        """Render the text; the font size is the board size in pixels
        multiplied by scale."""
        font_size = int((var.BOARD_SIZE * var.BLOCK_SIZE) * self.scale)
        font = pygame.font.Font(resource_path("resources/fonts/freesansbold.ttf"),
                                font_size)
        self.rend = font.render(self.text, True, self.get_color(),
                                self.get_background())

    def get_color(self):
        """Return the text color: light for a non-hovered menu entry, dark
        in every other case."""
        if self.type == "menu" and not self.hovered:
            return pygame.Color(152, 152, 152)

        return pygame.Color(42, 42, 42)

    def get_background(self):
        """Return the background color of a hovered menu entry, else None."""
        if self.type == "menu" and self.hovered:
            return pygame.Color(152, 152, 152)

        return None

    def set_rect(self):
        """Render the text and center its rectangle at pos."""
        self.set_rend()
        self.rect = self.rend.get_rect()
        self.rect.center = self.pos


class Snake:
    """The player: a head position, an ordered list of body cells and the
    last action that was applied.

    ``body[0]`` always mirrors the head. Every move pushes the new head
    position to the front of ``body`` and drops the last cell unless food was
    eaten on that move.

    Attributes:
        head: [x, y] position of the head on the board.
        body: List of [x, y] cells, head first; starts with 3 entries.
        previous_action: Last absolute action that was actually applied.
        length: Size of the body recorded the last time food was eaten
            (3 at start).
    """
    def __init__(self):
        """Place a 3-cell horizontal snake a quarter of the way into the
        board, with the head at the largest x."""
        quarter = int(var.BOARD_SIZE / 4)
        self.head = [quarter, quarter]
        self.body = [[quarter - offset, quarter] for offset in range(3)]
        # Initial heading; presumably actions['RIGHT'] -- confirm against
        # the actions table.
        self.previous_action = 1
        self.length = 3

    def move(self, action, food_pos):
        """Advance the snake by one block.

        An IDLE action, or one listed as forbidden together with the
        previous action, is replaced by the previous action. Any other action
        becomes the new previous action.

        Args:
            action: Absolute action to apply.
            food_pos: [x, y] position of the food.

        Returns:
            True when the head landed on the food (the snake grows by one
            cell), False otherwise (the last body cell is dropped).
        """
        keep_heading = (action == actions['IDLE']
                        or (action, self.previous_action) in forbidden_moves)

        if keep_heading:
            action = self.previous_action
        else:
            self.previous_action = action

        # Displacement of the head for every action that moves it; anything
        # else leaves the head where it is.
        displacement = {actions['LEFT']: (-1, 0),
                        actions['RIGHT']: (1, 0),
                        actions['UP']: (0, -1),
                        actions['DOWN']: (0, 1)}
        shift_x, shift_y = displacement.get(action, (0, 0))
        self.head[0] += shift_x
        self.head[1] += shift_y

        # Store a copy, so later changes to the head do not alter the body.
        self.body.insert(0, list(self.head))

        if self.head != food_pos:
            self.body.pop()

            return False

        logger.info('EVENT: FOOD EATEN')
        self.length = len(self.body)

        return True

    def return_body(self):
        """Return the list of body cells (the list itself, not a copy)."""
        return self.body


class FoodGenerator:
    """Generate and keep track of food.

    Attributes:
        pos: Current [x, y] position of food.
        is_food_on_screen: Flag for existence of food.
    """
    def __init__(self, body):
        """Initialize a food piece and set existence flag.

        Args:
            body: List of [x, y] cells the food must not be placed on.
        """
        self.is_food_on_screen = False
        self.pos = self.generate_food(body)

    def generate_food(self, body):
        """Return the food position, drawing a new one if none is on screen.

        While the existence flag is set the stored position is returned
        unchanged. Otherwise random cells are drawn until one is found that
        is not part of ``body``.

        Args:
            body: List of [x, y] cells the food must not be placed on.

        Returns:
            The [x, y] position of the food.
        """
        if not self.is_food_on_screen:
            # NOTE: this keeps drawing until a free cell shows up, so it
            # never finishes if `body` covers the entire board.
            while True:
                # randrange covers every index 0 .. BOARD_SIZE - 1. The
                # previous int((BOARD_SIZE - 1) * random.random()) could not
                # produce the last index, so food never appeared on the last
                # row or column.
                food = [random.randrange(var.BOARD_SIZE),
                        random.randrange(var.BOARD_SIZE)]

                if food not in body:
                    self.pos = food
                    break

            logger.info('EVENT: FOOD APPEARED')
            self.is_food_on_screen = True

        return self.pos

    def set_food_on_screen(self, bool_value):
        """Set flag for existence (or not) of food."""
        self.is_food_on_screen = bool_value


class Game:
    """Hold the game window and functions.

    Attributes:
        window: pygame window to show the game.
        fps: Define Clock and ticks in which the game will be displayed.
        snake: The actual snake who is going to be played.
        food_generator: Generator of food which responds to the snake.
        food_pos: Position of the food on the board.
        game_over: Flag for game_over.
    """
    def __init__(self, player, board_size = 30, local_state = False, relative_pos = False):
        """Store the match configuration; robot players also get a match
        set up right away.

        Args:
            player: "ROBOT" or "HUMAN".
            board_size: Written to var.BOARD_SIZE, so it affects every user
                of the shared settings.
            local_state: When true, state() also marks the dangerous moves
                around the head.
            relative_pos: When true, actions given to play() are relative
                (left / forward / right) instead of absolute.
        """
        self.player = player
        self.relative_pos = relative_pos
        self.local_state = local_state
        var.BOARD_SIZE = board_size

        if player != "ROBOT":
            return

        # 3 relative actions, or 5 absolute ones.
        self.nb_actions = 3 if self.relative_pos else 5
        self.reset_game()

    def reset_game(self):
        """Start a new match: fresh snake, fresh food and cleared flags."""
        self.game_over = False
        self.scored = False
        self.step = 0

        # The food generator needs the body, so the snake comes first.
        self.snake = Snake()
        self.food_generator = FoodGenerator(self.snake.body)
        self.food_pos = self.food_generator.pos

    def create_window(self):
        """Initialize pygame and create the square, double-buffered game
        window together with the clock used to pace the frames."""
        pygame.init()

        side = var.BOARD_SIZE * var.BLOCK_SIZE
        self.window = pygame.display.set_mode((side, side), pygame.DOUBLEBUF)
        self.window.set_alpha(None)
        self.fps = pygame.time.Clock()

    def menu(self):
        """Show the main menu and wait until one of its entries is clicked.

        As a side effect self.screen_rect is set to the window rectangle.

        Returns:
            The value stored in ``options`` for the clicked entry: 'PLAY',
            'BENCHMARK', 'LEADERBOARDS' or 'QUIT'.
        """
        pygame.display.set_caption("SNAKE GAME  | PLAY NOW!")

        side = var.BOARD_SIZE * var.BLOCK_SIZE
        img = pygame.image.load(resource_path("resources/images/snake_logo.png"))
        img = pygame.transform.scale(img, (side, int(side / 3)))

        self.screen_rect = self.window.get_rect()
        img_rect = img.get_rect()
        img_rect.center = self.screen_rect.center

        center_x = self.screen_rect.centerx
        center_y = self.screen_rect.centery

        # (label, vertical position in tenths of center_y, key in options)
        layout = ((' PLAY GAME ', 4, 'PLAY'),
                  (' BENCHMARK ', 6, 'BENCHMARK'),
                  (' LEADERBOARDS ', 8, 'LEADERBOARDS'),
                  (' QUIT ', 10, 'QUIT'))
        menu_options = [(TextBlock(label, (center_x, tenths * center_y / 10),
                                   self.window, (1 / 12), "menu"), key)
                        for label, tenths, key in layout]

        while True:
            pygame.event.pump()
            ev = pygame.event.get()

            self.window.fill(pygame.Color(225, 225, 225))

            for option, key in menu_options:
                option.draw()
                option.hovered = bool(
                    option.rect.collidepoint(pygame.mouse.get_pos()))

                # A released mouse button while hovering selects the entry.
                if option.hovered and any(event.type == pygame.MOUSEBUTTONUP
                                          for event in ev):
                    return options[key]

            self.window.blit(img, img_rect.bottomleft)
            pygame.display.update()

    def start_match(self):
        """Show a 3, 2, 1 countdown (one second per number) before a match.

        Reads self.screen_rect, which menu() sets.
        """
        center_x = self.screen_rect.centerx
        center_y = self.screen_rect.centery

        for remaining in ("3", "2", "1"):
            self.window.fill(pygame.Color(225, 225, 225))

            # Creating a TextBlock already draws it once; the explicit draw
            # below repeats what the original code did.
            countdown = (TextBlock('Game starts in',
                                   (center_x, 4 * center_y / 10),
                                   self.window, (1 / 10), "text"),
                         TextBlock(remaining,
                                   (center_x, 12 * center_y / 10),
                                   self.window, (1 / 1.5), "text"))

            for text_block in countdown:
                text_block.draw()

            pygame.display.update()
            pygame.display.set_caption("SNAKE GAME  |  Game starts in "
                                       + remaining + " second(s) ...")

            pygame.time.wait(1000)

        logger.info('EVENT: GAME START')

    def start(self):
        """Run the menu-driven main loop of the game.

        The loop dispatches on the selected option until QUIT is chosen,
        which closes pygame and leaves through sys.exit(); this method
        therefore never returns normally.

        NOTE(review): select_speed() is called here but is not defined in
        the part of the class visible to this review -- confirm it exists.
        """
        opt = self.menu()

        while True:
            if opt == options['QUIT']:
                pygame.quit()
                sys.exit()
            elif opt == options['PLAY']:
                self.select_speed()
                self.reset_game()
                self.start_match()
                score = self.single_player()
                opt = self.over(score)
            elif opt == options['BENCHMARK']:
                self.select_speed()
                score = []

                # Play var.BENCHMARK matches in a row and collect the result
                # of each one.
                for _ in range(var.BENCHMARK):
                    self.reset_game()
                    self.start_match()
                    score.append(self.single_player())

                opt = self.over(score)
            else:
                # MENU lands here, and so do LEADERBOARDS / ADD_LEADERBOARDS,
                # which have no screen yet. Those two used to be `pass`
                # branches that left `opt` untouched, so the loop spun
                # forever with a frozen window. Going back to the menu keeps
                # the game responsive (also for any unknown option value).
                opt = self.menu()

    def over(self, score):
        """Show the game-over screen and wait until an entry is clicked.

        Args:
            score: Result to display. An int is shown as 'SCORE'; anything
                else is summed, divided by var.BENCHMARK, shown as
                'MEAN SCORE' and enables the ' ADD TO LEADERBOARDS ' entry.

        Returns:
            options['PLAY'], options['MENU'] or options['ADD_LEADERBOARDS'],
            depending on the clicked entry. Clicking ' QUIT ' closes pygame
            and exits the interpreter instead of returning.
        """
        center_x = self.screen_rect.centerx
        center_y = self.screen_rect.centery

        def entry(label, tenths, scale, kind):
            # Horizontally centered; `tenths` is the vertical position in
            # tenths of the window's center height.
            return TextBlock(label, (center_x, tenths * center_y / 10),
                             self.window, scale, kind)

        play_again = entry(' PLAY AGAIN ', 4, (1 / 15), "menu")
        go_to_menu = entry(' GO TO MENU ', 6, (1 / 15), "menu")
        quit_game = entry(' QUIT ', 10, (1 / 15), "menu")
        add_to_leaderboards = None

        if isinstance(score, int):
            text_score = 'SCORE: ' + str(score)
        else:
            text_score = 'MEAN SCORE: ' + str(sum(score) / var.BENCHMARK)
            add_to_leaderboards = entry(' ADD TO LEADERBOARDS ', 8,
                                        (1 / 15), "menu")

        pygame.display.set_caption("SNAKE GAME  | " + text_score
                                   + "  |  GAME OVER...")
        logger.info('EVENT: GAME OVER | FINAL ' + text_score)

        score_label = entry(text_score, 15, (1 / 10), "text")

        # Blocks in the order they are processed each frame, paired with the
        # key of `options` returned on a click (None: nothing to return).
        blocks = ((play_again, 'PLAY'),
                  (go_to_menu, 'MENU'),
                  (add_to_leaderboards, 'ADD_LEADERBOARDS'),
                  (quit_game, None),
                  (score_label, None))

        while True:
            pygame.event.pump()
            ev = pygame.event.get()

            # Game over screen
            self.window.fill(pygame.Color(225, 225, 225))

            for option, key in blocks:
                if option is None:
                    continue

                option.draw()
                option.hovered = bool(
                    option.rect.collidepoint(pygame.mouse.get_pos()))

                if not option.hovered:
                    continue

                # The score label is hoverable but has no click action.
                if key is None and option is not quit_game:
                    continue

                if any(event.type == pygame.MOUSEBUTTONUP for event in ev):
                    if option is quit_game:
                        pygame.quit()
                        sys.exit()

                    return options[key]

            pygame.display.update()

    def single_player(self):
        """Run one match until play() reports that it ended.

        Each tick reads the input, advances the game and draws the board.
        The gradient used to color the snake is rebuilt whenever the snake
        has grown.

        Returns:
            The snake length seen after the last drawn frame.
            NOTE(review): draw() and handle_input() report the score as
            ``length - 3`` while this returns the raw length -- confirm which
            one the game-over screen is meant to show.
        """
        shades = [(42, 42, 42), (152, 152, 152)]
        colored_size = self.snake.length # Size the gradient was built for
        latest_size = colored_size
        color_list = self.gradient(shades, colored_size)

        while not self.play(self.handle_input()):
            self.draw(color_list)
            latest_size = self.snake.length

            # The snake grew: one more shade is needed.
            if latest_size > colored_size:
                color_list = self.gradient(shades, latest_size)
                colored_size = latest_size

        return latest_size

    def check_collision(self):
        """Tell whether the head left the board or ran into the body.

        Returns:
            True on a wall or body collision (the event is logged), False
            otherwise.
        """
        head = self.snake.head
        last_index = var.BOARD_SIZE - 1
        inside_board = (0 <= head[0] <= last_index
                        and 0 <= head[1] <= last_index)

        if not inside_board:
            logger.info('EVENT: WALL COLLISION')

            return True

        # body[0] is the head itself, so only the rest is checked.
        if head in self.snake.body[1:]:
            logger.info('EVENT: BODY COLLISION')

            return True

        return False

    def is_won(self):
        """Tell whether the snake is longer than its 3-block start size."""
        starting_length = 3

        return self.snake.length > starting_length

    def generate_food(self):
        """Ask the food generator for the food position, handing it the
        snake's body so new food is not placed on the snake."""
        occupied_cells = self.snake.body

        return self.food_generator.generate_food(occupied_cells)

    def handle_input(self):
        """Translate the keys currently held down into an action.

        ESCAPE / Q open the game-over screen; otherwise the first pressed
        arrow key, checked in the order LEFT, RIGHT, UP, DOWN, decides the
        action.

        Returns:
            The matching entry of ``actions`` for an arrow key, the snake's
            previous action when no handled key is pressed, or None after
            ESCAPE / Q.
            NOTE(review): the option returned by over() is discarded in the
            ESCAPE / Q case and None is handed back as the action -- confirm
            that this is intended.
        """
        pygame.event.set_allowed([pygame.QUIT, pygame.KEYDOWN])
        keys = pygame.key.get_pressed()
        pygame.event.pump()

        if keys[pygame.K_ESCAPE] or keys[pygame.K_q]:
            logger.info('ACTION: KEY PRESSED: ESCAPE or Q')
            self.over(self.snake.length - 3)

            return None

        arrows = ((pygame.K_LEFT, 'LEFT'),
                  (pygame.K_RIGHT, 'RIGHT'),
                  (pygame.K_UP, 'UP'),
                  (pygame.K_DOWN, 'DOWN'))

        for key_code, direction in arrows:
            if keys[key_code]:
                logger.info('ACTION: KEY PRESSED: ' + direction)

                return actions[direction]

        return self.snake.previous_action

    def eval_local_safety(self, canvas, body):
        """Flag which of the four moves around the head are dangerous.

        A move is dangerous when the neighbouring cell is outside the board
        or belongs to the body (head excluded). The flags are written into
        ``canvas[var.BOARD_SIZE - 1, 0..3]`` in the order: x + 1, x - 1,
        y - 1, y + 1.

        Args:
            canvas: Board matrix, modified in place.
            body: List of [x, y] cells, head first.

        Returns:
            The same canvas object.
        """
        head_x, head_y = body[0][0], body[0][1]
        last_index = var.BOARD_SIZE - 1
        rest_of_body = body[1:]

        # (flag slot, neighbouring cell, is the cell outside the board?)
        neighbours = ((0, [head_x + 1, head_y], head_x + 1 > last_index),
                      (1, [head_x - 1, head_y], head_x - 1 < 0),
                      (2, [head_x, head_y - 1], head_y - 1 < 0),
                      (3, [head_x, head_y + 1], head_y + 1 > last_index))

        for slot, cell, off_board in neighbours:
            if off_board or cell in rest_of_body:
                canvas[last_index, slot] = point_type['DANGEROUS']

        return canvas

    def state(self):
        """Build the board matrix describing the current game.

        Cells are painted in this order, so later entries win: body, head,
        the optional danger flags (only when self.local_state is set) and
        finally the food.

        Returns:
            A (BOARD_SIZE, BOARD_SIZE) array indexed as canvas[x, y] and
            filled with ``point_type`` values.
        """
        board_side = var.BOARD_SIZE
        canvas = np.zeros((board_side, board_side))
        body = self.snake.return_body()

        for cell_x, cell_y in ((part[0], part[1]) for part in body):
            canvas[cell_x, cell_y] = point_type['BODY']

        head = body[0]
        canvas[head[0], head[1]] = point_type['HEAD']

        if self.local_state:
            canvas = self.eval_local_safety(canvas, body)

        food_x, food_y = self.food_pos[0], self.food_pos[1]
        canvas[food_x, food_y] = point_type['FOOD']

        return canvas

    def relative_to_absolute(self, action):
        """Convert a relative action into an absolute one.

        FORWARD repeats the snake's previous action. LEFT and (any other
        value, i.e.) RIGHT are resolved against the previous action through
        the lookup tables below.

        Args:
            action: One of the ``relative_actions`` values.

        Returns:
            The matching ``actions`` value.
        """
        heading = self.snake.previous_action

        if action == relative_actions['FORWARD']:
            return heading

        left, right = actions['LEFT'], actions['RIGHT']
        up, down = actions['UP'], actions['DOWN']

        # current heading -> new heading; a heading that is not listed
        # (DOWN included) uses the fallback.
        if action == relative_actions['LEFT']:
            turn, fallback = {left: down, right: up, up: left}, right
        else:
            turn, fallback = {left: up, right: down, up: right}, left

        return turn.get(heading, fallback)

    def play(self, action):
        """Advance the game by one step: place food, move, eat and check
        for the end of the match.

        Args:
            action: Action to apply; converted first when self.relative_pos
                is set.

        Returns:
            True when a "HUMAN" player collided, None in every other case.
            For other players the outcome is stored in self.game_over, which
            is set on a collision or once the step count exceeds
            50 * snake length.
        """
        self.scored = False
        self.step += 1
        self.food_pos = self.generate_food()

        if self.relative_pos:
            action = self.relative_to_absolute(action)

        ate_food = self.snake.move(action, self.food_pos)

        if ate_food:
            self.scored = True
            # New food is generated at the start of the next step.
            self.food_generator.set_food_on_screen(False)

        if self.player == "HUMAN":
            return True if self.check_collision() else None

        if self.check_collision() or self.step > 50 * self.snake.length:
            self.game_over = True

        return None

    def get_reward(self):
        """Return the reward for the last step.

        Returns:
            -1 when the game is over, the snake length when food was eaten on
            the last step, and -0.005 for a plain move.
        """
        if self.game_over:
            return -1

        return self.snake.length if self.scored else -0.005

    def gradient(self, colors, steps, components = 3):
        """Build a list of colors fading through the given color stops.

        Every pair of consecutive stops contributes
        ``int(steps / (len(colors) - 1))`` colors: the first stop as given,
        followed by linearly interpolated tuples of floats ending on the
        second stop.

        Args:
            colors: Sequence of color stops (at least 2).
            steps: Total amount of colors wanted.
            components: Channels to interpolate; 3 for RGB, 4 for RGBA.

        Returns:
            The list of colors, one tuple per step.
        """
        result = []
        substeps = int(float(steps) / (len(colors) - 1))

        stops = iter(colors)
        start = next(stops, None)

        # Walk the stops pairwise: (first, second), (second, third), ...
        for finish in stops:
            result.append(start)

            for i in range(1, substeps):
                ratio = float(i) / (substeps - 1)
                result.append(tuple([(start[j] + ratio * (finish[j]
                                      - start[j])) for j in range(components)]))

            start = finish

        return result

    def draw(self, color_list):
        """Draw one frame: background, snake, food and the score caption,
        then wait for the clock according to var.GAME_SPEED.

        Args:
            color_list: Colors for the body cells, head first. Cells beyond
                the end of the list are not drawn.
        """
        block = var.BLOCK_SIZE
        self.window.fill(pygame.Color(225, 225, 225))

        for part, color in zip(self.snake.body, color_list):
            cell = pygame.Rect(part[0] * block, part[1] * block, block, block)
            pygame.draw.rect(self.window, color, cell)

        food_cell = pygame.Rect(self.food_pos[0] * block,
                                self.food_pos[1] * block, block, block)
        pygame.draw.rect(self.window, var.FOOD_COLOR, food_cell)

        # The score is the growth over the 3 starting blocks.
        score = self.snake.length - 3
        pygame.display.set_caption("SNAKE GAME  |  Score: " + str(score))
        pygame.display.update()
        self.fps.tick(var.GAME_SPEED)

def resource_path(relative_path):
    """Build the path of a bundled resource (font, image, ...).

    Args:
        relative_path: Path of the resource relative to the base folder.

    Returns:
        ``relative_path`` joined onto ``sys._MEIPASS`` when that attribute
        exists (it is set by PyInstaller inside frozen bundles), otherwise
        onto the folder containing this file. When ``__file__`` is not
        defined, as happens when the code runs inside a notebook kernel or
        an interactive session, the current working directory is used.
    """
    if hasattr(sys, '_MEIPASS'):
        return path.join(sys._MEIPASS, relative_path)

    try:
        base_dir = path.dirname(path.realpath(__file__))
    except NameError:
        # No __file__ here (notebook / REPL): resolve against the working
        # directory instead of crashing.
        base_dir = path.abspath('.')

    return path.join(base_dir, relative_path)

# The logger is created first: GlobalVariables.__init__ refers to `logger`
# when it warns about big boards, so it has to exist before `var` is built.
logger = logging.getLogger(__name__) # Setting logger
var = GlobalVariables() # Initializing GlobalVariables
environ['SDL_VIDEO_CENTERED'] = '1' # Centering the window

import numpy as np
from random import sample, uniform
from array import array  # Efficient numeric arrays

class ExperienceReplay:
    """The class that handles memory and experiences replay.

    Experiences are sampled uniformly at random.

    Attributes
    ----------
    memory: list of experiences
        Memory list to insert experiences. Each experience is one flat
        array: state, action, reward, next state and terminal flag.
    memory_size: int, optional, default = 150000
        The amount of experiences to be stored in the memory. Once exceeded
        the oldest experience is dropped. A value <= 0 disables the limit.
    input_shape: tuple of int
        The shape of one stored state (``s.shape[1:]``). Only set after the
        first call to remember().
    """
    def __init__(self, memory_size = 150000):
        """Initialize parameters and the memory array."""
        self.memory_size = memory_size
        self.reset_memory() # Initiate the memory

    def exp_size(self):
        """Returns how much memory is stored."""
        return len(self.memory)

    def remember(self, s, a, r, s_prime, game_over):
        """Remember SARS' experiences, with the game_over parameter (done).

        Parameters
        ----------
        s: np.array
            State before the action; the first axis is dropped to obtain
            input_shape.
        a: int
            Action taken.
        r: float
            Reward received.
        s_prime: np.array
            State after the action.
        game_over: bool
            Terminal flag, stored as 0 / 1.
        """
        if not hasattr(self, 'input_shape'):
            self.input_shape = s.shape[1:] # set attribute only once

        experience = np.concatenate([s.flatten(),
                                     np.array(a).flatten(),
                                     np.array(r).flatten(),
                                     s_prime.flatten(),
                                     1 * np.array(game_over).flatten()])

        self.memory.append(experience)

        # Drop the oldest experience when the memory is over its limit.
        if self.memory_size > 0 and self.exp_size() > self.memory_size:
            self.memory.pop(0)

    def get_samples(self, batch_size):
        """Sample the memory uniformly at random, without replacement.

        Return
        ----------
        batch: np.array of batch_size experiences
            The batched experiences from memory.
        IS_weights: np.array of batch_size of the weights
            As it's used only in PER, is an array of ones in this case.
        Indexes: list of batch_size * int
            As it's used only in PER, return None.
        """
        IS_weights = np.ones((batch_size, ))
        batch = np.array(sample(self.memory, batch_size))

        return batch, IS_weights, None

    def get_targets(self, target, model, batch_size, nb_actions, gamma = 0.9,
                    n_steps = 1):
        """Sample a batch and build the Q-learning targets for it.

        Parameters
        ----------
        target: object with predict(X) or None
            Target network. When given, Double DQN is used: `model` picks
            the best next action and `target` evaluates it.
        model: object with predict(X)
            Online network.
        batch_size: int
            Amount of experiences to sample.
        nb_actions: int
            Not used by this method.
        gamma: float, optional, default = 0.9
            Discount factor.
        n_steps: int, optional, default = 1
            Exponent applied to gamma.

        Return
        ----------
        None when fewer than batch_size experiences are stored, otherwise
        the tuple (S, targets, IS_weights).
        """
        if self.exp_size() < batch_size:
            return None

        samples, IS_weights, _ = self.get_samples(batch_size)
        # Flat size of one state; int() keeps it usable as a slice bound.
        input_dim = int(np.prod(self.input_shape))

        S = samples[:, 0 : input_dim] # Separate the states
        a = samples[:, input_dim] # Separate the actions
        r = samples[:, input_dim + 1] # Separate the rewards
        S_prime = samples[:, input_dim + 2 : 2 * input_dim + 2] # Next states
        game_over = samples[:, 2 * input_dim + 2] # Separate terminal flags

        # Reshape the arrays to make them usable by the model.
        S = S.reshape((batch_size, ) + self.input_shape)
        S_prime = S_prime.reshape((batch_size, ) + self.input_shape)

        # One forward pass for both S (first half) and S_prime (second half).
        X = np.concatenate([S, S_prime], axis = 0)
        Y = model.predict(X)

        if target is not None: # Use Double DQN logic:
            best_actions = np.argmax(Y[batch_size:], axis = 1)
            Y_target = np.asarray(target.predict(X[batch_size:]))
            # One value per sampled experience, whatever the batch size
            # (this used to be a list hard-coded to 64 entries).
            Qsa = Y_target[np.arange(batch_size), best_actions]
        else:
            Qsa = np.max(Y[batch_size:], axis = 1)

        # Where the action happened, replace with the Q values of S_prime
        targets = np.array(Y[:batch_size])
        value = r + (gamma ** n_steps) * (1 - game_over) * Qsa
        targets[range(batch_size), a.astype(int)] = value

        return S, targets, IS_weights

    def reset_memory(self):
        """Set the memory as a blank list."""
        self.memory = []


class PrioritizedExperienceReplayNaive:
    """The class that handles memory and prioritized experiences replay,
    using a SumTree to store the experiences and their priorities.

    Attributes:
        memory: SumTree to insert experiences and priorities.
        memory_size: the amount of experiences to be stored in the memory.
        input_shape: the shape of one stored state; only set after the first
            call to remember().
        exp: amount of experiences inserted since the last reset.
        epsilon: added to the errors, so no priority is exactly zero.
        alpha: how much prioritization to use.
        beta: exponent of the importance sampling weights (IS_weights).
        schedule: LinearSchedule built from nb_epoch * decay, 1.0 and beta.
            It is not read anywhere in this class.
    """
    def __init__(self, memory_size = 150000, alpha = 0.6, epsilon = 0.001,
                 beta = 0.4, nb_epoch = 10000, decay = 0.5):
        """Initialize parameters and the memory array."""
        self.memory_size = memory_size
        self.epsilon = epsilon
        self.alpha = alpha
        self.beta = beta
        self.schedule = LinearSchedule(nb_epoch * decay, 1.0, beta)
        self.reset_memory() # Initiate the memory

    def exp_size(self):
        """Returns how many experiences were inserted since the last reset."""
        return self.exp

    def get_priority(self, errors):
        """Returns priority based on how much prioritization to use."""
        return (errors + self.epsilon) ** self.alpha

    def update(self, tree_indices, errors):
        """Update a list of nodes, based on their errors."""
        priorities = self.get_priority(errors)

        for index, priority in zip(tree_indices, priorities):
            self.memory.update(index, priority)

    def remember(self, s, a, r, s_prime, game_over):
        """Remember SARS' experiences, with the game_over parameter (done).

        The experience is inserted with the biggest priority currently in
        the tree, or with the priority of a zero error when that is 0.
        """
        if not hasattr(self, 'input_shape'):
            self.input_shape = s.shape[1:] # set attribute only once

        experience = np.concatenate([s.flatten(),
                                     np.array(a).flatten(),
                                     np.array(r).flatten(),
                                     s_prime.flatten(),
                                     1 * np.array(game_over).flatten()])

        max_priority = self.memory.max_leaf()

        if max_priority == 0:
            max_priority = self.get_priority(0)

        self.memory.insert(experience, max_priority)
        self.exp += 1

    def get_samples(self, batch_size):
        """Sample the memory proportionally to the stored priorities.

        The total priority is split into batch_size equal segments and one
        value is drawn uniformly from each of them.

        Returns:
            batch: np.array with the sampled experiences.
            IS_weights: np.array with one importance sampling weight per
                sampled experience.
            tree_indices: list with the tree index of every sample.
        """
        batch = [None] * batch_size
        IS_weights = np.zeros((batch_size, ))
        tree_indices = [0] * batch_size

        memory_sum = self.memory.sum()
        len_seg = memory_sum / batch_size
        min_prob = self.memory.min_leaf() / memory_sum

        for i in range(batch_size):
            val = uniform(len_seg * i, len_seg * (i + 1))
            tree_indices[i], priority, batch[i] = self.memory.retrieve(val)
            # The total does not change while sampling, so it is read once
            # before the loop instead of on every iteration.
            prob = priority / memory_sum
            IS_weights[i] = np.power(prob / min_prob, -self.beta)

        return np.array(batch), IS_weights, tree_indices

    def get_targets(self, target, model, batch_size, nb_actions, gamma = 0.9,
                    n_steps = 1):
        """Sample a batch, build its Q-learning targets and refresh the
        priorities of the sampled experiences.

        Args:
            target: Target network (object with predict) or None. When
                given, Double DQN is used: `model` picks the best next
                action and `target` evaluates it.
            model: Online network (object with predict).
            batch_size: Amount of experiences to sample.
            nb_actions: Not used by this method.
            gamma: Discount factor.
            n_steps: Exponent applied to gamma.

        Returns:
            None when fewer than batch_size experiences were inserted,
            otherwise the tuple (S, targets, IS_weights).
        """
        if self.exp_size() < batch_size:
            return None

        samples, IS_weights, tree_indices = self.get_samples(batch_size)
        # Flat size of one state; int() keeps it usable as a slice bound.
        input_dim = int(np.prod(self.input_shape))

        S = samples[:, 0 : input_dim] # Separate the states
        a = samples[:, input_dim] # Separate the actions
        r = samples[:, input_dim + 1] # Separate the rewards
        S_prime = samples[:, input_dim + 2 : 2 * input_dim + 2] # Next states
        game_over = samples[:, 2 * input_dim + 2] # Separate terminal flags

        # Reshape the arrays to make them usable by the model.
        S = S.reshape((batch_size, ) + self.input_shape)
        S_prime = S_prime.reshape((batch_size, ) + self.input_shape)

        # One forward pass for both S (first half) and S_prime (second half).
        X = np.concatenate([S, S_prime], axis = 0)
        Y = model.predict(X)

        if target is not None: # Use Double DQN logic:
            best_actions = np.argmax(Y[batch_size:], axis = 1)
            Y_target = np.asarray(target.predict(X[batch_size:]))
            # One value per sampled experience, whatever the batch size
            # (this used to be a list hard-coded to 64 entries).
            Qsa = Y_target[np.arange(batch_size), best_actions]
        else:
            Qsa = np.max(Y[batch_size:], axis = 1)

        # Where the action happened, replace with the Q values of S_prime
        targets = np.array(Y[:batch_size])
        value = r + (gamma ** n_steps) * (1 - game_over) * Qsa
        targets[range(batch_size), a.astype(int)] = value

        # NOTE(review): the error is measured against the largest Q value of
        # S, not against the Q value of the action that was taken -- confirm
        # this is intended.
        errors = np.abs(value - Y[:batch_size].max(axis = 1)).clip(max = 1.)
        # update() is the method defined above; the former call to
        # update_priorities() referred to a method that does not exist.
        self.update(tree_indices, errors)

        return S, targets, IS_weights

    def reset_memory(self):
        """Replace the memory with an empty SumTree and reset the counter.

        A memory_size of 100 or less is replaced by 150000 first.
        """
        if self.memory_size <= 100:
            self.memory_size = 150000

        self.memory = SumTree(self.memory_size)
        self.exp = 0


class PrioritizedExperienceReplay:
    def __init__(self, memory_size, nb_epoch = 10000, epsilon = 0.001,
                 alpha = 0.6, beta = 0.4, decay = 0.5):
        self.memory_size = memory_size
        self.alpha = alpha
        self.epsilon = epsilon
        self.beta = beta
        self.schedule = LinearSchedule(nb_epoch * decay, 1.0, beta)
        self.max_priority = 1.0
        self.reset_memory()

    def exp_size(self):
        """Returns how much memory is stored."""
        return len(self.memory)

    def remember(self, s, a, r, s_prime, game_over):
        if not hasattr(self, 'input_shape'):
            self.input_shape = s.shape[1:] # set attribute only once

        experience = np.concatenate([s.flatten(),
                                     np.array(a).flatten(),
                                     np.array(r).flatten(),
                                     s_prime.flatten(),
                                     1 * np.array(game_over).flatten()])
        if self.exp_size() < self.memory_size:
            self.memory.append(experience)
            self.pos += 1
        else:
            self.memory[self.pos] = experience
            self.pos = (self.pos + 1) % self.memory_size

        self._it_sum[self.pos] = self.max_priority ** self.alpha
        self._it_min[self.pos] = self.max_priority ** self.alpha

    def _sample_proportional(self, batch_size):
        res = array('i')

        for _ in range(batch_size):
            mass = random.random() * self._it_sum.sum(0, self.exp_size() - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)

        return res

    def get_priority(self, errors):
        """Returns priority based on how much prioritization to use."""
        return (errors + self.epsilon) ** self.alpha

    def get_samples(self, batch_size):
        idxes = self._sample_proportional(batch_size)

        weights = array('f')
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * self.exp_size()) ** (-self.beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * self.exp_size()) ** (-self.beta)
            weights.append(weight / max_weight)

        weights = np.array(weights, dtype=np.float32)
        samples = [self.memory[idx] for idx in idxes]

        return np.array(samples), weights, idxes

    def get_targets(self, target, model, batch_size, nb_actions, gamma = 0.9,
                    n_steps = 1):
        """Function to sample, set batch function and use it for targets."""
        if self.exp_size() < batch_size:
            return None

        samples, IS_weights, tree_indices = self.get_samples(batch_size)
        input_dim = np.prod(self.input_shape) # Get the input shape, multiplied

        S = samples[:, 0 : input_dim] # Seperate the states
        a = samples[:, input_dim] # Separate the actions
        r = samples[:, input_dim + 1] # Separate the rewards
        S_prime = samples[:, input_dim + 2 : 2 * input_dim + 2] # Next_actions
        game_over = samples[:, 2 * input_dim + 2] # Separate terminal flags

        # Reshape the arrays to make them usable by the model.
        S = S.reshape((batch_size, ) + self.input_shape)
        S_prime = S_prime.reshape((batch_size, ) + self.input_shape)

        X = np.concatenate([S, S_prime], axis = 0)
        Y = model.predict(X)

        if target is not None: # Use Double DQN logic:
            Qsa = [None] * 64
            actions = np.argmax(Y[batch_size:], axis = 1)
            Y_target = target.predict(X[batch_size:])

            for idx, target in enumerate(Y_target):
                Qsa[idx] = target[actions[idx]]

            Qsa = np.array(Qsa)
        else:
            Qsa = np.max(Y[batch_size:], axis = 1)

        # Where the action happened, replace with the Q values of S_prime
        targets = np.array(Y[:batch_size])
        value = r + (gamma ** n_steps) * (1 - game_over) * Qsa
        targets[range(batch_size), a.astype(int)] = value

        errors = np.abs(value - Y[:batch_size].max(axis = 1)).clip(max = 1.)
        self.update_priorities(tree_indices, errors)

        return S, targets, IS_weights

    def update_priorities(self, idxes, errors):
        priorities = self.get_priority(errors)

        for idx, priority in zip(idxes, priorities):
            self._it_sum[idx] = priority ** self.alpha
            self._it_min[idx] = priority ** self.alpha

            self.max_priority = max(self.max_priority, priority)

    def reset_memory(self):
        """Set the memory as a blank list."""
        if self.memory_size <= 100:
            self.memory_size = 150000

        self.memory = []
        self.pos = 0

        it_capacity = 1

        while it_capacity < self.memory_size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)


#!/usr/bin/env python

"""dqn: First try to create an AI for SnakeGame. Is it good enough?

This algorithm is an implementation of DQN, Double DQN logic (using a target
network to have fixed Q-targets), Dueling DQN logic (Q(s,a) = Advantage + Value),
PER (Prioritized Experience Replay, using Sum Trees) and Multi-step returns. You
can read more about these on https://goo.gl/MctLzp

Implemented algorithms:
    * Simple DQN (with ExperienceReplay);
        Paper: https://arxiv.org/abs/1312.5602
    * Double DQN;
        Paper: https://arxiv.org/abs/1509.06461
    * Dueling DQN;
        Paper: https://arxiv.org/abs/1511.06581
    * DQN + PER;
        Paper: https://arxiv.org/abs/1511.05952
    * Multi-step returns.
        Paper: https://arxiv.org/pdf/1703.01327

Arguments:
    --load FILE.h5: load a previously trained model in '.h5' format.
    --board_size INT: assign the size of the board, default = 10
    --nb_frames INT: assign the number of frames per stack, default = 4.
    --nb_actions INT: assign the number of actions possible, default = 5.
    --update_freq INT: assign how often, in epochs, to update the target,
      default = 500.
    --visual: select whether or not to draw the game in pygame.
    --double: use a target network with double DQN logic.
    --dueling: use dueling network logic, Q(s,a) = A + V.
    --per: use Prioritized Experience Replay (based on Sum Trees).
    --local_state: Verify if possible next moves are dangerous (field expertise)
"""

import numpy as np
from os import path, environ, sys
import random

import inspect # Making relative imports from parallel folders possible
currentdir = path.dirname(path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = path.dirname(currentdir)
sys.path.insert(0, parentdir)

from keras.optimizers import RMSprop, Nadam
from keras.models import load_model, Sequential
from keras.layers import *
from keras import backend as K
K.set_image_dim_ordering('th')

__author__ = "Victor Neves"
__license__ = "MIT"
__version__ = "1.0"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"
__status__ = "Production"

class Agent:
    """Agent based on a simple DQN that can read states, remember and play.

    Attributes:
    memory: memory used in the model. Input memory or ExperienceReplay.
    model: the input model, Conv2D in Keras.
    target: the target model, used to calculate the fixed Q-targets.
    nb_frames: amount of frames for each sars.
    board_size: side length of the (square) board frames.
    frames: the frames in each sars.
    per: flag for PER usage.
    target_updates: how many times the target model has been synchronised.
    """
    def __init__(self, model, target, memory = None, memory_size = 150000,
                 nb_frames = 4, board_size = 10, per = False):
        """Initialize the agent with given attributes."""
        # Compare against None instead of relying on truthiness, so a
        # caller-supplied memory object is never discarded just because it
        # happens to evaluate as falsy.
        if memory is not None:
            self.memory = memory
        else:
            self.memory = ExperienceReplay(memory_size = memory_size)

        self.per = per
        self.model = model
        self.target = target
        self.nb_frames = nb_frames
        self.board_size = board_size
        self.frames = None
        self.target_updates = 0

    def reset_memory(self):
        """Reset memory if necessary."""
        self.memory.reset_memory()

    def get_game_data(self, game):
        """Return the current stack of `nb_frames` frames with a batch axis.

        On the first call after `clear_frames` the stack is filled with
        copies of the current frame; afterwards the newest frame is appended
        and the oldest dropped. A finished game contributes an all-zero
        frame.
        """
        if game.game_over:
            frame = np.zeros((self.board_size, self.board_size))
        else:
            frame = game.state()

        if self.frames is None:
            self.frames = [frame] * self.nb_frames
        else:
            self.frames.append(frame)
            self.frames.pop(0)

        return np.expand_dims(self.frames, 0)

    def clear_frames(self):
        """Reset frames to restart appending."""
        self.frames = None

    def update_target_model(self):
        """Update the target model with the main model's weights."""
        self.target_updates += 1
        self.target.set_weights(self.model.get_weights())

    @staticmethod
    def _mean_last(values, window):
        """Mean of the last `window` entries (of fewer, if fewer exist)."""
        tail = values[-window:]
        return sum(tail) / len(tail) if tail else 0.

    def print_metrics(self, epoch, nb_epoch, history_size, history_loss,
                      history_step, history_reward, policy, value, win_count,
                      verbose = 1):
        """Print the training metrics for the given epoch.

        verbose == 0 prints nothing, verbose == 1 prints a one-line summary
        of the last 10 episodes, anything else prints detailed training,
        game and policy metrics. `value` is the policy's current parameter
        (epsilon, temperature or action count) as returned by
        `select_action`.
        """
        if verbose == 0:
            pass
        elif verbose == 1:
            text_epoch = ('Epoch: {:03d}/{:03d} | Mean size 10: {:.1f} | '
                           + 'Longest 10: {:03d} | Mean steps 10: {:.1f} | '
                           + 'Wins: {:d} | Win percentage: {:.1f}%')
            print(text_epoch.format(epoch + 1, nb_epoch,
                                    self._mean_last(history_size, 10),
                                    max(history_size[-10:]),
                                    self._mean_last(history_step, 10),
                                    win_count, 100 * win_count/(epoch + 1)))
        else:
            text_epoch = 'Epoch: {:03d}/{:03d}' # Print epoch info
            print(text_epoch.format(epoch + 1, nb_epoch))

            # Print training performance
            text_train = ('\t\x1b[0;30;47m' + ' Training metrics ' + '\x1b[0m'
                          + '\tTotal loss: {:.4f} | Loss per step: {:.4f} | '
                          + 'Mean loss - 100 episodes: {:.4f}')
            print(text_train.format(history_loss[-1],
                                    history_loss[-1] / history_step[-1],
                                    self._mean_last(history_loss, 100)))

            # NOTE(review): the third value is size / steps although the
            # label reads "Steps per food eaten" - confirm the intended ratio.
            text_game = ('\t\x1b[0;30;47m' + ' Game metrics ' + '\x1b[0m'
                         + '\t\tSize: {:d} | Ammount of steps: {:d} | '
                         + 'Steps per food eaten: {:.1f} | '
                         + 'Mean size - 100 episodes: {:.1f}')
            print(text_game.format(history_size[-1], history_step[-1],
                                   history_size[-1] / history_step[-1],
                                   self._mean_last(history_size, 100)))

            # Print policy metrics
            if policy == "BoltzmannQPolicy":
                text_policy = ('\t\x1b[0;30;47m' + ' Policy metrics ' + '\x1b[0m'
                               + '\tBoltzmann Temperature: {:.2f} | '
                               + 'Episode reward: {:.1f} | Wins: {:d} | '
                               + 'Win percentage: {:.1f}%')
                print(text_policy.format(value, history_reward[-1], win_count,
                                         100 * win_count/(epoch + 1)))
            elif policy == "BoltzmannGumbelQPolicy":
                text_policy = ('\t\x1b[0;30;47m' + ' Policy metrics ' + '\x1b[0m'
                               + '\tNumber of actions: {:.0f} | '
                               + 'Episode reward: {:.1f} | Wins: {:d} | '
                               + 'Win percentage: {:.1f}%')
                print(text_policy.format(value, history_reward[-1], win_count,
                                         100 * win_count/(epoch + 1)))
            else:
                text_policy = ('\t\x1b[0;30;47m' + ' Policy metrics ' + '\x1b[0m'
                               + '\tEpsilon: {:.2f} | Episode reward: {:.1f} | '
                               + 'Wins: {:d} | Win percentage: {:.1f}%')
                print(text_policy.format(value, history_reward[-1], win_count,
                                         100 * win_count/(epoch + 1)))

    def train_model(self, model, target, batch_size, gamma, nb_actions, epoch = 0):
        """Train the online model on one batch drawn from memory.

        The `model`, `target` and `epoch` arguments are accepted for
        interface compatibility; the agent's own `self.model` and
        `self.target` are the ones used. Returns the batch loss as a float,
        or 0.0 when the memory cannot provide a batch yet.
        """
        loss = 0.

        batch = self.memory.get_targets(model = self.model,
                                        target = self.target,
                                        batch_size = batch_size,
                                        gamma = gamma,
                                        nb_actions = nb_actions)

        if batch:
            inputs, targets, IS_weights = batch

            if inputs is not None and targets is not None:
                loss = float(self.model.train_on_batch(inputs,
                                                       targets,
                                                       IS_weights))

        return loss

    def train(self, game, nb_epoch = 10000, batch_size = 64, gamma = 0.95,
              eps = [1., .01], temp = [1., 0.01], learning_rate = 0.5,
              observe = 0, update_target_freq = 0.001, optim_rounds = 1,
              policy = "EpsGreedyQPolicy", verbose = 1, n_steps = None):
        """The main training function, loops the game, remember and choose best
        action given game state (frames).

        Parameters
        ----------
        game: the game environment to play.
        nb_epoch: number of episodes to play.
        batch_size: size of each training batch.
        gamma: discount factor.
        eps: [start, end] epsilon for EpsGreedyQPolicy.
        temp: [start, end] temperature for BoltzmannQPolicy.
        learning_rate: fraction of `nb_epoch` over which the policy schedule
            is annealed.
        observe: number of initial episodes played without training.
        update_target_freq: the target model is updated on episodes where
            `epoch % update_target_freq == 0`.
        optim_rounds: the first round plays the game; every further round
            only replays batches from memory.
        policy: "BoltzmannQPolicy", "BoltzmannGumbelQPolicy" or anything
            else for the epsilon-greedy policy.
        verbose: verbosity passed to `print_metrics`.
        n_steps: when not None, rewards are accumulated over several steps.
        """
        history_size = []
        history_step = []
        history_loss = []
        history_reward = []

        # A single if/elif chain: with two independent `if`s the Boltzmann
        # policy was always replaced by the epsilon-greedy fallback.
        if policy == "BoltzmannQPolicy":
            q_policy = BoltzmannQPolicy(temp[0], temp[1], nb_epoch * learning_rate)
        elif policy == "BoltzmannGumbelQPolicy":
            q_policy = BoltzmannGumbelQPolicy()
        else:
            q_policy = EpsGreedyQPolicy(eps[0], eps[1], nb_epoch * learning_rate)

        nb_actions = game.nb_actions
        win_count = 0

        for turn in range(optim_rounds):
            if turn > 0:
                # Optimisation-only rounds: no playing, just batch training.
                for epoch in range(nb_epoch):
                    loss = self.train_model(model = self.model,
                                            epoch = epoch,
                                            target = self.target,
                                            batch_size = batch_size,
                                            gamma = gamma,
                                            nb_actions = nb_actions)
                    text_optim = ('Optimizer turn: {:2d} | Epoch: {:03d}/{:03d}'
                                  + '| Loss: {:.4f}')
                    print(text_optim.format(turn, epoch + 1, nb_epoch, loss))
            else:
                for epoch in range(nb_epoch):
                    loss = 0.
                    total_reward = 0.
                    if n_steps is not None:
                        n_step_buffer = []
                    game.reset_game()
                    self.clear_frames()

                    S = self.get_game_data(game)

                    while not game.game_over:
                        game.food_pos = game.generate_food()
                        action, value = q_policy.select_action(self.model,
                                                               S, epoch,
                                                               nb_actions)

                        game.play(action)

                        r = game.get_reward()
                        total_reward += r
                        if n_steps is not None:
                            n_step_buffer.append(r)

                            # NOTE(review): the buffer is never trimmed, so
                            # this sums the first `n_steps` rewards of the
                            # episode rather than a sliding window - confirm
                            # the intended multi-step return.
                            if len(n_step_buffer) < n_steps:
                                R = r
                            else:
                                R = sum([n_step_buffer[i] * (gamma ** i)
                                        for i in range(n_steps)])
                        else:
                            R = r

                        S_prime = self.get_game_data(game)
                        experience = [S, action, R, S_prime, game.game_over]
                        self.memory.remember(*experience) # Add to the memory
                        S = S_prime # Advance to the next state (stack of S)

                        if epoch >= observe: # Get the batchs and train
                            loss += self.train_model(model = self.model,
                                                     target = self.target,
                                                     batch_size = batch_size,
                                                     gamma = gamma,
                                                     nb_actions = nb_actions)

                    if game.is_won():
                        win_count += 1 # Counter for metric purposes

                    if self.per: # Advance beta, used in PER
                        self.memory.per_beta = self.memory.schedule.value(epoch)

                    if self.target is not None: # Update the target model
                        # NOTE(review): with the float default (0.001) this
                        # modulo is only reliably zero for epoch 0 - callers
                        # presumably pass an integer number of epochs.
                        if epoch % update_target_freq == 0:
                            self.update_target_model()

                    history_size.append(game.snake.length)
                    history_step.append(game.step)
                    history_loss.append(loss)
                    history_reward.append(total_reward)

                    if (epoch + 1) % 10 == 0:
                        self.print_metrics(epoch, nb_epoch, history_size,
                                           history_loss, history_step,
                                           history_reward, policy, value,
                                           win_count, verbose)

    def play(self, game, nb_epoch = 1000, eps = 0.01, temp = 0.01,
             visual = False, policy = "GreedyQPolicy"):
        """Play the game with the trained agent and print summary statistics.

        Parameters
        ----------
        game: the game environment to play.
        nb_epoch: number of games to play.
        eps: fixed epsilon when `policy` is "EpsGreedyQPolicy".
        temp: fixed temperature when `policy` is "BoltzmannQPolicy".
        visual: draw every step in pygame when True.
        policy: "BoltzmannQPolicy", "EpsGreedyQPolicy" or anything else for
            the purely greedy policy.
        """
        win_count = 0
        result_size = []
        result_step = []
        # Taken from the game (as `train` does) instead of relying on a
        # module-level `nb_actions` happening to exist.
        nb_actions = game.nb_actions

        if policy == "BoltzmannQPolicy":
            q_policy = BoltzmannQPolicy(temp, temp, nb_epoch)
        elif policy == "EpsGreedyQPolicy":
            q_policy = EpsGreedyQPolicy(eps, eps, nb_epoch)
        else:
            q_policy = GreedyQPolicy()

        for epoch in range(nb_epoch):
            game.reset_game()
            self.clear_frames()
            S = self.get_game_data(game)

            if visual:
                # Ask SDL to center the window; set before the window is
                # created so the setting can take effect.
                environ['SDL_VIDEO_CENTERED'] = '1'
                game.create_window()
                previous_size = game.snake.length # Initial size of the snake
                color_list = game.gradient([(42, 42, 42), (152, 152, 152)],
                                           previous_size)

            while not game.game_over:
                action, value = q_policy.select_action(self.model, S, epoch, nb_actions)
                game.play(action)
                current_size = game.snake.length # Update the body size

                if visual:
                    game.draw(color_list)

                    # Rebuild the colour gradient whenever the snake grows.
                    if current_size > previous_size:
                        color_list = game.gradient([(42, 42, 42), (152, 152, 152)],
                                                   game.snake.length)

                        previous_size = current_size

                S = self.get_game_data(game)

                if game.game_over:
                    result_size.append(current_size)
                    result_step.append(game.step)

            if game.is_won():
                win_count += 1

        print("Accuracy: {} %".format(100. * win_count / nb_epoch))
        print("Mean size: {} | Biggest size: {} | Smallest size: {}"
              .format(np.mean(result_size), np.max(result_size),
                      np.min(result_size)))
        print("Mean steps: {} | Biggest step: {} | Smallest step: {}"
              .format(np.mean(result_step), np.max(result_step),
                      np.min(result_step)))

import random
import numpy as np

class LinearSchedule(object):
    """Value that moves linearly from `initial_p` to `final_p` over time."""

    def __init__(self, schedule_timesteps, final_p, initial_p):
        """Set up the interpolation.

        Parameters
        ----------
        schedule_timesteps: int
            Number of timesteps over which the value is annealed. Once this
            many timesteps have passed, `final_p` is returned.
        final_p: float
            Output value at (and after) the end of the schedule.
        initial_p: float
            Output value at timestep 0.
        """
        self.initial_p = initial_p
        self.final_p = final_p
        self.schedule_timesteps = schedule_timesteps

    def value(self, t):
        """Return the scheduled value at timestep `t`."""
        progress = float(t) / self.schedule_timesteps
        if 1.0 < progress:
            progress = 1.0 # Past the end of the schedule: hold final_p

        span = self.final_p - self.initial_p
        return self.initial_p + progress * span


class GreedyQPolicy:
    """Implement the greedy policy

    Greedy policy always takes current best action.
    """
    def __init__(self):
        super(GreedyQPolicy, self).__init__()

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: network exposing `predict`; its first output row is used
                as the Q-value estimates
            state: input passed to `model.predict`
            epoch: unused by this policy
            nb_actions: unused by this policy
        # Returns
            Tuple `(action, 0)`; the second item is a placeholder for the
            policy parameter the other policies report
        """
        q = model.predict(state)
        action = int(np.argmax(q[0]))

        return action, 0

    def get_config(self):
        """Return configurations of GreedyQPolicy
        # Returns
            Dict of config
        """
        # This class has no policy base class providing `get_config`, so the
        # configuration starts from an empty dict instead of calling super().
        config = {}
        return config


class EpsGreedyQPolicy:
    """Implement the epsilon greedy policy

    Eps Greedy policy either:

    - takes a random action with probability epsilon
    - takes current best action with prob (1 - epsilon)

    Epsilon is annealed linearly from `max_eps` to `min_eps` over `nb_epoch`
    epochs.
    """
    def __init__(self, max_eps=1., min_eps = .01, nb_epoch = 10000):
        super(EpsGreedyQPolicy, self).__init__()
        self.schedule = LinearSchedule(nb_epoch, min_eps, max_eps)
        self.eps = max_eps # Defined up front so get_config always works

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: network exposing `predict`; its first output row is used
                as the Q-value estimates
            state: input passed to `model.predict`
            epoch: current epoch, used to look up epsilon in the schedule
            nb_actions: number of actions to choose from when exploring
        # Returns
            Tuple `(action, eps)`
        """
        rand = random.random()
        self.eps = self.schedule.value(epoch)

        if rand < self.eps:
            # Draw the action independently. Reusing `rand` (already known
            # to be below eps) would only ever reach the lowest-numbered
            # actions, and always action 0 once eps < 1 / nb_actions.
            action = random.randrange(nb_actions)
        else:
            q = model.predict(state)
            action = int(np.argmax(q[0]))

        return action, self.eps

    def get_config(self):
        """Return configurations of EpsGreedyQPolicy
        # Returns
            Dict of config
        """
        # No policy base class provides `get_config`; start from empty.
        config = {}
        config['eps'] = self.eps
        return config


class BoltzmannQPolicy:
    """Implement the Boltzmann Q Policy
    Boltzmann Q Policy builds a probability law on q values and returns
    an action selected randomly according to this law.

    The temperature is annealed linearly from `max_temp` to `min_temp` over
    `nb_epoch` epochs. `clip` is stored and reported by `get_config`; the
    softmax below is stabilised by subtracting the maximum instead.
    """
    def __init__(self, max_temp = 1., min_temp = .01, nb_epoch = 10000, clip = (-500., 500.)):
        super(BoltzmannQPolicy, self).__init__()
        self.schedule = LinearSchedule(nb_epoch, min_temp, max_temp)
        self.clip = clip
        self.temp = max_temp # Defined up front so get_config always works

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: network exposing `predict`; its first output row is used
                as the Q-value estimates
            state: input passed to `model.predict`
            epoch: current epoch, used to look up the temperature
            nb_actions: number of actions to sample from
        # Returns
            Tuple `(action, temp)`
        """
        # Work in float64 (as BoltzmannGumbelQPolicy does): probabilities
        # computed in a lower precision may not sum to 1 closely enough for
        # np.random.choice to accept them.
        q = np.asarray(model.predict(state)[0], dtype = 'float64')
        self.temp = self.schedule.value(epoch)
        arg = q / self.temp

        # Subtracting the max keeps the exponentials from overflowing.
        exp_values = np.exp(arg - arg.max())
        probs = exp_values / exp_values.sum()
        # Plain int, consistent with the other policies.
        action = int(np.random.choice(range(nb_actions), p = probs))

        return action, self.temp

    def get_config(self):
        """Return configurations of BoltzmannQPolicy
        # Returns
            Dict of config
        """
        # No policy base class provides `get_config`; start from empty.
        config = {}
        config['temp'] = self.temp
        config['clip'] = self.clip
        return config


class BoltzmannGumbelQPolicy:
    """Implements Boltzmann-Gumbel exploration (BGE) adapted for Q learning
    based on the paper Boltzmann Exploration Done Right
    (https://arxiv.org/pdf/1705.10257.pdf).
    BGE is invariant with respect to the mean of the rewards but not their
    variance. The parameter C, which defaults to 1, can be used to correct for
    this, and should be set to the least upper bound on the standard deviation
    of the rewards.
    BGE is only available for training, not testing. For testing purposes, you
    can achieve approximately the same result as BGE after training for N steps
    on K actions with parameter C by using the BoltzmannQPolicy and setting
    tau = C/sqrt(N/K)."""

    def __init__(self, C = 1.0):
        super(BoltzmannGumbelQPolicy, self).__init__()
        self.C = C
        self.action_counts = None
        self._last_epoch = None # Epoch of the previous select_action call

    def select_action(self, model, state, epoch, nb_actions):
        """Return the selected action
        # Arguments
            model: network exposing `predict`; its first output row is used
                as the Q-value estimates
            state: input passed to `model.predict`
            epoch: current epoch; entering epoch 0 restarts the counts
            nb_actions: unused, the action count follows the model output
        # Returns
            Tuple `(action, total)` where `total` is the sum of the action
            counts after this selection
        """
        q = model.predict(state)[0]
        q = q.astype('float64')

        # Reset the counts once when a training run starts (the first call
        # that enters epoch 0), not on every step of epoch 0, and create
        # them lazily if the first call arrives with a later epoch.
        previous_epoch = getattr(self, '_last_epoch', None)
        entering_first_epoch = epoch == 0 and previous_epoch != 0
        if self.action_counts is None or entering_first_epoch:
            self.action_counts = np.ones(q.shape)
        self._last_epoch = epoch

        # Gumbel noise scaled down as an action gets selected more often.
        beta = self.C/np.sqrt(self.action_counts)
        Z = np.random.gumbel(size = q.shape)

        perturbation = beta * Z
        perturbed_q_values = q + perturbation
        action = int(np.argmax(perturbed_q_values))

        self.action_counts[action] += 1
        return action, np.sum(self.action_counts)

    def get_config(self):
        """Return configurations of BoltzmannGumbelQPolicy
        # Returns
            Dict of config
        """
        # No policy base class provides `get_config`; start from empty.
        config = {}
        config['C'] = self.C
        return config

#!/usr/bin/env python

"""clipped_error: L2 (squared) error for errors < clip_value, else L1 (linear).

Functions:
    huber_loss: Return the L2 (squared) error if the absolute error is less
                than clip_value, else return the L1 (linear) error.
    clipped_error: Call huber_loss with default clip_value to 1.0.
"""

import numpy as np
from keras import backend as K
import tensorflow as tf

__author__ = "Victor Neves"
__license__ = "MIT"
__maintainer__ = "Victor Neves"
__email__ = "victorneves478@gmail.com"

def huber_loss(y_true, y_pred, clip_value):
	"""Element-wise Huber loss between `y_true` and `y_pred`.

	The loss is quadratic (0.5 * x^2) where |x| < clip_value and linear
	(clip_value * (|x| - 0.5 * clip_value)) elsewhere, with x the difference
	y_true - y_pred. An infinite `clip_value` yields the pure quadratic loss.
	See https://en.wikipedia.org/wiki/Huber_loss and
	https://medium.com/@karpathy/yes-you-should-understand-backprop-e2f06eab496b
	for details.

	Raises RuntimeError when the Keras backend is neither tensorflow nor
	theano.
	"""
	assert clip_value > 0.

	delta = y_true - y_pred
	if np.isinf(clip_value):
		# Special case for infinity, since TensorFlow has problems with the
		# comparison `K.abs(x) < np.inf`.
		return .5 * K.square(delta)

	is_small = K.abs(delta) < clip_value
	quadratic = .5 * K.square(delta)
	linear = clip_value * (K.abs(delta) - .5 * clip_value)

	backend_name = K.backend()
	if backend_name == 'tensorflow':
		# Use `tf.select` where the installed TensorFlow still has it,
		# `tf.where` otherwise. Arguments: condition, if-true, if-false.
		pick = tf.select if hasattr(tf, 'select') else tf.where
		return pick(is_small, quadratic, linear)

	if backend_name == 'theano':
		from theano import tensor as T
		return T.switch(is_small, quadratic, linear)

	raise RuntimeError('Unknown backend "{}".'.format(backend_name))

def clipped_error(y_true, y_pred):
	"""Mean over the last axis of the Huber loss with clip_value 1.0."""
	per_element = huber_loss(y_true, y_pred, clip_value = 1.)
	return K.mean(per_element, axis = -1)

#def CNN1(optimizer, loss, stack, input_size, output_size):
 #   model = Sequential()
  #  model.add(Conv2D(32, (3, 3), activation = 'relu', input_shape = (stack,
   #                                                                  input_size,
    #                                                                 input_size)))
#    model.add(Conv2D(64, (3, 3), activation = 'relu'))
 #   model.add(Conv2D(128, (3, 3), activation = 'relu'))
  #  model.add(Conv2D(256, (3, 3), activation = 'relu'))
   # model.add(Flatten())
    #model.add(Dense(1024, activation = 'relu'))
    #model.add(Dense(output_size))
    #model.compile(optimizer = optimizer, loss = loss)

    #return model
    
def CNN4(optimizer, loss, stack, input_size, output_size):
    """Build and compile the convolutional Q-network.

    From @Kaixhin's implementation of the Rainbow paper. The network takes a
    `(stack, input_size, input_size)` input through three ReLU convolutions
    (32 filters of 4x4, then two layers of 64 filters of 2x2), flattens, and
    ends with a 3136-unit ReLU dense layer followed by a linear dense layer
    of `output_size` units. It is compiled with `optimizer` and `loss`.
    """
    frame_shape = (stack, input_size, input_size)
    layers = [
        Conv2D(32, (4, 4), activation = 'relu', input_shape = frame_shape),
        Conv2D(64, (2, 2), activation = 'relu'),
        Conv2D(64, (2, 2), activation = 'relu'),
        Flatten(),
        Dense(3136, activation = 'relu'),
        Dense(output_size),
    ]

    network = Sequential()
    for layer in layers:
        network.add(layer)

    network.compile(optimizer = optimizer, loss = loss)
    return network
  
# Board geometry and frame-stack depth shared by the game, model and agent.
board_size = 10
nb_frames = 4

# Environment played by the agent (absolute actions, local danger state on).
game = Game(player = "ROBOT",
            board_size = board_size,
            local_state = True,
            relative_pos = False)

# Online Q-network; no target network is used in this run.
model = CNN4(optimizer = RMSprop(),
             loss = clipped_error,
             stack = nb_frames,
             input_size = board_size,
             output_size = game.nb_actions)
target = None

agent = Agent(model = model,
              target = target,
              memory_size = -1,
              nb_frames = nb_frames,
              board_size = board_size,
              per = False)
#%lprun -f agent.train agent.train(game, batch_size = 64, nb_epoch = 10, gamma = 0.95, update_target_freq = 500, policy = "EpsGreedyQPolicy")
agent.train(game,
            batch_size = 64,
            nb_epoch = 10000,
            gamma = 0.95,
            policy = "EpsGreedyQPolicy")

Epoch: 010/10000 | Mean size 10: 3.3 | Longest 10: 005 | Mean steps 10: 11.8 | Wins: 2 | Win percentage: 20.0%
Epoch: 020/10000 | Mean size 10: 3.1 | Longest 10: 004 | Mean steps 10: 13.3 | Wins: 3 | Win percentage: 15.0%
Epoch: 030/10000 | Mean size 10: 3.0 | Longest 10: 003 | Mean steps 10: 7.5 | Wins: 3 | Win percentage: 10.0%
Epoch: 040/10000 | Mean size 10: 3.3 | Longest 10: 004 | Mean steps 10: 10.2 | Wins: 6 | Win percentage: 15.0%
Epoch: 050/10000 | Mean size 10: 3.3 | Longest 10: 004 | Mean steps 10: 14.2 | Wins: 9 | Win percentage: 18.0%
Epoch: 060/10000 | Mean size 10: 3.1 | Longest 10: 004 | Mean steps 10: 18.8 | Wins: 10 | Win percentage: 16.7%
Epoch: 070/10000 | Mean size 10: 3.3 | Longest 10: 005 | Mean steps 10: 14.4 | Wins: 12 | Win percentage: 17.1%
Epoch: 080/10000 | Mean size 10: 3.0 | Longest 10: 003 | Mean steps 10: 12.6 | Wins: 12 | Win percentage: 15.0%
Epoch: 090/10000 | Mean size 10: 3.1 | Longest 10: 004 | Mean steps 10: 10.7 | Wins: 13 | Win percentage: 14.4

In [0]:
# Persist the trained network to disk.
model.save('keras.h5')

# Zip the saved model and offer it as a download (Google Colab only).
!zip -r model-epsgreedy-bench-newmemory.zip keras.h5 
from google.colab import files
files.download('model-epsgreedy-bench-newmemory.zip')
# Reload the network; the custom loss has to be passed in by name.
model = load_model('keras.h5', custom_objects={'clipped_error': clipped_error})

board_size = 10
nb_frames = 4
nb_actions = 5

# No target network is needed for evaluation.
target = None

agent = Agent(model = model, target = target, memory_size = 1500000,
                          nb_frames = nb_frames, board_size = board_size,
                          per = False)
#%lprun -f agent.train agent.train(game, batch_size = 64, nb_epoch = 10, gamma = 0.95, update_target_freq = 500, policy = "EpsGreedyQPolicy")

# Evaluate the reloaded agent; `game` is the instance created in the training cell.
agent.play(game, visual = False, nb_epoch = 10000)

  adding: keras.h5 (deflated 43%)
Accuracy: 100.0 %
Mean size: 19.7908 | Biggest size: 45 | Smallest size: 4
Mean steps: 144.3177 | Biggest step: 801 | Smallest step: 5
