# Game

In [None]:
import enum
import random
from abc import ABC, abstractmethod
from collections import deque
from typing import Deque, Protocol, Tuple

import pygame
from pygame.color import Color
from pygame.event import Event
from pygame.surface import Surface

In [None]:
# Initialise all pygame modules up front: SnakeGame.__init__ below creates a
# font with pygame.font.SysFont, which requires the font module to be initialised.
pygame.init()


class Point(Protocol):
    """Structural type for anything exposing integer grid coordinates.

    The ``hitTest`` methods accept any object with these two attributes,
    e.g. a snake segment, the apple, or a freshly computed head block.
    """

    x: int  # horizontal cell index (column)
    y: int  # vertical cell index (row)


class DisplayObject(ABC):
    """Base class for anything that can be drawn onto the game surface.

    Subclasses must implement :meth:`draw`. Inheriting from ``ABC`` is what
    makes the ``@abstractmethod`` marker enforced: without it, an incomplete
    subclass (or ``DisplayObject`` itself) could be instantiated silently.
    """

    def __init__(self, color: Color) -> None:
        # Colour used by subclasses when they draw themselves.
        self.color = color

    @abstractmethod
    def draw(self, surface: Surface, gridSize: float) -> None:
        """Draw this object onto ``surface``; one grid cell is ``gridSize`` pixels wide."""
        raise NotImplementedError()


class Block(DisplayObject):
    """A single coloured cell located at integer grid coordinates ``(x, y)``."""

    def __init__(self, color: Color, x: int, y: int) -> None:
        super().__init__(color)
        self.x, self.y = x, y

    def hitTest(self, target: Point) -> bool:
        """Return True when ``target`` occupies the same grid cell as this block."""
        return (self.x, self.y) == (target.x, target.y)

    def draw(self, surface: Surface, gridSize: float) -> None:
        """Fill this cell on ``surface``, leaving a 1-pixel gap on every side."""
        inset = 1
        left = self.x * gridSize + inset
        top = self.y * gridSize + inset
        side = gridSize - 2 * inset
        pygame.draw.rect(surface, self.color, pygame.Rect(left, top, side, side))


class Apple(Block):
    """The food item: a :class:`Block` that makes the snake grow when eaten.

    It adds no state or behaviour of its own, so it inherits ``Block.__init__``
    directly instead of declaring a constructor that only forwards to it.
    """


class Snake(DisplayObject):
    """The player's snake: a deque of :class:`Block` segments, head first.

    ``x``/``y`` is the head cell, ``vx``/``vy`` the movement per step in cells,
    and ``length`` the maximum number of segments kept. All of them are
    declared in ``__init__`` (so the object never lacks an attribute) and get
    their real starting values from :meth:`reset`.
    """

    def __init__(self, color: Color) -> None:
        super().__init__(color)
        self.tails: Deque[Block] = deque()
        self.x = 0
        self.y = 0
        self.vx = 0
        self.vy = 0
        self.length = 0

    def reset(self, x: int, y: int, length: int) -> None:
        """Replace the snake with a stationary single segment at ``(x, y)``.

        ``length`` is the number of segments the snake grows to as it moves.
        """
        self.x = x
        self.y = y
        self.vx = 0
        self.vy = 0
        self.length = length
        self.tails.clear()
        self.tails.append(Block(self.color, self.x, self.y))

    def hitTest(self, target: Point) -> bool:
        """Return True if any segment occupies the same cell as ``target``."""
        return any(tail.hitTest(target) for tail in self.tails)

    def getNewHead(self) -> Block:
        """Return the block one step ahead of the head (not yet added to the snake)."""
        return Block(self.color, self.x + self.vx, self.y + self.vy)

    def setNewHead(self, head: Block) -> None:
        """Make ``head`` the new head and drop tail segments beyond ``length``."""
        self.x = head.x
        self.y = head.y
        self.tails.appendleft(head)
        # While the snake is shorter than `length` nothing is dropped, so it grows.
        while len(self.tails) > self.length:
            self.tails.pop()

    def draw(self, surface: Surface, gridSize: float) -> None:
        """Draw every segment onto ``surface``."""
        for tail in self.tails:
            tail.draw(surface, gridSize)


class GameState(enum.Enum):
    """Pending input or lifecycle state consumed by ``SnakeGame.update``.

    ``Ready`` and ``Dead`` make ``update`` do nothing; ``Step`` advances the
    snake without turning; each ``Turn*`` member requests a direction change
    on the next update.
    """

    # Explicit values match what enum.auto() assigns (1, 2, 3, ...).
    Ready = 1
    Dead = 2
    Step = 3
    TurnLeft = 4
    TurnRight = 5
    TurnUp = 6
    TurnDown = 7


class SnakeGame:
    """Grid-based Snake: board state, rules, keyboard handling and rendering.

    The board is ``cols`` x ``rows`` cells, each drawn ``gridSize`` pixels wide.
    ``lastState`` holds either the pending input (``Step`` or a ``Turn*``
    member) or the lifecycle state (``Ready`` after a reset, before the first
    move; ``Dead`` after game over).
    """

    def __init__(
            self,
            background: Color,
            snakeColor: Color,
            appleColor: Color,
            gridSize: float,
            size: Tuple[int, int],
            initialSnakeLength: int) -> None:
        """Create a game on a ``size = (cols, rows)`` board and reset it.

        A font is created here, so pygame must already be initialised; the
        window itself is only opened by :meth:`load`.
        """
        self.background = background
        self.snake = Snake(snakeColor)
        # Off-board placeholder; reset() below moves it onto a free cell.
        self.apple = Apple(appleColor, -1, -1)
        self.font = pygame.font.SysFont("Segoe UI", 20)
        self.gridSize = gridSize
        self.cols = size[0]
        self.rows = size[1]
        self.initialSnakeLength = initialSnakeLength
        self.reset()

    @property
    def score(self) -> int:
        """Apples eaten since the last reset (each one adds 1 to the snake's length)."""
        return self.snake.length - self.initialSnakeLength

    def load(self) -> None:
        """Open the display window, sized to the board in pixels."""
        # set_mode needs integer pixel sizes, but gridSize is typed as float.
        self.surface = pygame.display.set_mode(
            (int(self.gridSize * self.cols), int(self.gridSize * self.rows)))

    def reset(self) -> None:
        """Start a new round: a stationary snake in the centre and a fresh apple."""
        self.lastState = GameState.Ready
        self.snake.reset(self.cols // 2, self.rows // 2, self.initialSnakeLength)
        self.spawnApple()

    def spawnApple(self) -> bool:
        """Move the apple to a random cell not covered by the snake.

        Returns:
            True if the apple was placed. False if the snake covers every cell
            of the board; the apple is then left where it was. Without this
            check, the rejection-sampling loop below would never terminate.
        """
        occupied = {(tail.x, tail.y) for tail in self.snake.tails}
        if len(occupied) >= self.cols * self.rows:
            return False
        while True:
            self.apple.x = random.randint(0, self.cols - 1)
            self.apple.y = random.randint(0, self.rows - 1)
            if not self.snake.hitTest(self.apple):
                return True

    def handleEvent(self, event: Event) -> None:
        """React to a key press; all other event types are ignored.

        After game over, any key starts a new round. Otherwise, if a turn is
        still pending (not yet consumed by update()), it is applied and drawn
        first so this key press does not overwrite it. An arrow key then
        queues a new turn.
        """
        if event.type != pygame.KEYDOWN:
            return
        if self.lastState == GameState.Dead:
            self.reset()
            return
        if self.lastState not in (GameState.Ready, GameState.Step):
            self.update()
            self.draw()
        if event.key == pygame.K_LEFT:
            self.lastState = GameState.TurnLeft
        elif event.key == pygame.K_RIGHT:
            self.lastState = GameState.TurnRight
        elif event.key == pygame.K_UP:
            self.lastState = GameState.TurnUp
        elif event.key == pygame.K_DOWN:
            self.lastState = GameState.TurnDown

    def update(self) -> None:
        """Advance the game by one tick according to ``lastState``.

        Does nothing while ``Ready`` or ``Dead``. Otherwise it:

        * applies a pending turn (a turn along the axis the snake already
          moves on is ignored, so the snake cannot reverse into itself);
        * moves the snake one cell;
        * checks wall collisions, self collisions and apple eating.

        Afterwards ``lastState`` is ``Step`` or ``Dead``.
        """
        if self.lastState in (GameState.Ready, GameState.Dead):
            return
        elif self.lastState == GameState.Step:
            pass
        elif self.lastState == GameState.TurnLeft:
            if self.snake.vx == 0:
                self.snake.vx = -1
                self.snake.vy = 0
        elif self.lastState == GameState.TurnRight:
            if self.snake.vx == 0:
                self.snake.vx = 1
                self.snake.vy = 0
        elif self.lastState == GameState.TurnUp:
            if self.snake.vy == 0:
                self.snake.vx = 0
                self.snake.vy = -1
        elif self.lastState == GameState.TurnDown:
            if self.snake.vy == 0:
                self.snake.vx = 0
                self.snake.vy = 1

        head = self.snake.getNewHead()
        if not 0 <= head.x < self.cols or not 0 <= head.y < self.rows:
            self.lastState = GameState.Dead
            return
        if self._hitsBody(head):
            self.lastState = GameState.Dead
            return
        self.snake.setNewHead(head)
        if head.hitTest(self.apple):
            self.snake.length += 1
            if not self.spawnApple():
                # No free cell left for a new apple: the board is full, end the round.
                self.lastState = GameState.Dead
                return
        self.lastState = GameState.Step

    def _hitsBody(self, head: Block) -> bool:
        """Return True if moving to ``head`` would collide with the snake's body.

        When the snake already has ``length`` segments, setNewHead() pops the
        last one in the same step. That cell is therefore free, and following
        your own tail into it is allowed.
        """
        segments = list(self.snake.tails)
        if segments and len(segments) >= self.snake.length:
            segments.pop()
        return any(segment.hitTest(head) for segment in segments)

    def draw(self) -> None:
        """Render background, snake, apple and score text, then update the display.

        Requires :meth:`load` to have been called (it creates ``self.surface``).
        """
        self.surface.fill(self.background)
        self.snake.draw(self.surface, self.gridSize)
        self.apple.draw(self.surface, self.gridSize)
        text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
        self.surface.blit(text, [0, 0])
        pygame.display.update()

## Running the game for a human player

Keyboard input (the arrow keys) is read from `pygame.event`; closing the window ends the game.

In [None]:
def runGame(frameRate: int) -> None:
    """Play Snake interactively with the arrow keys until the window is closed.

    Args:
        frameRate: Maximum loop iterations per second; each iteration calls
            ``game.update()`` once, so this also caps the snake's speed.
    """
    pygame.display.set_caption("Snake")
    game = SnakeGame(
        background=Color(0, 0, 0),
        snakeColor=Color(0, 0, 255),
        appleColor=Color(255, 0, 0),
        gridSize=20,
        size=(20, 20),
        initialSnakeLength=5)
    game.load()
    clock = pygame.time.Clock()
    while True:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                # Closing the window ends the session immediately.
                pygame.display.quit()
                return
            game.handleEvent(event)
        game.update()
        game.draw()
        clock.tick(frameRate)

In [None]:
# runGame(16)

# Environment


The game wrapped as a `gym.Env`, so that a reinforcement-learning agent can be trained on it.

In [None]:
from typing import Any, Dict, NamedTuple

import gym
from gym import spaces

import numpy as np
from numpy.typing import NDArray

In [None]:
class P(NamedTuple):
    """Immutable integer ``(x, y)`` pair.

    It satisfies the ``Point`` protocol and also serves as a ``(cols, rows)``
    board size.
    """

    x: int
    y: int

In [None]:
# Type of the boolean observation grids returned by SnakeGameEnv.
# np.bool_ replaces the np.bool8 alias, which is deprecated and was removed in NumPy 2.0.
ObservationSpace = NDArray[np.bool_]
# Type of an action index: 0 = left, 1 = right, 2 = up, 3 = down (see SnakeGameEnv._performAction).
ActionSpace = np.uint8


class SnakeGameEnv(gym.Env):
    """Gym environment wrapping :class:`SnakeGame`.

    ``step`` returns the 4-tuple ``(obs, reward, done, info)``.

    * Observation: boolean array of shape ``(*Size, 3)``, indexed
      ``[x, y, channel]``. Channel 0 marks every snake segment, channel 1 the
      head, and channel 2 the apple.
    * Action: ``Discrete(4)`` — 0 = left, 1 = right, 2 = up, 3 = down.
    * Reward: ``+Stale`` for eating an apple, ``-Stale`` for dying, and ``-1``
      for any other step.

    An episode ends on death, or after ``Stale`` consecutive steps without
    eating.
    """

    Size = P(10, 10)  # board size in cells, (cols, rows)
    Stale = 100  # step budget without eating; also the magnitude of the eat/die rewards
    InitialSnakeLength = 1

    def __init__(self) -> None:
        self.game = SnakeGame(Color(0, 0, 0),
                              Color(0, 0, 255),
                              Color(255, 0, 0),
                              20,
                              self.Size,
                              self.InitialSnakeLength)
        # np.bool_ instead of np.bool8: that alias is deprecated and was removed in NumPy 2.0.
        self.observation_space = spaces.Box(0, 1, (*self.Size, 3), np.bool_)
        self.action_space = spaces.Discrete(4)
        # Steps since the last apple; defined here so the attribute always exists.
        self.staleCounter = 0

    def reset(self) -> ObservationSpace:
        """Start a new episode and return its initial observation."""
        self.game.reset()
        self.staleCounter = 0
        return self._getState()

    def step(self, action: ActionSpace) -> Tuple[ObservationSpace, float, bool, Dict[str, Any]]:
        """Apply ``action`` for one tick and return ``(obs, reward, done, info)``."""
        prevScore = self.game.score
        self._performAction(action)

        state = self._getState()
        # _calculateReward also updates staleCounter, so it must run before `done` is computed.
        reward = self._calculateReward(prevScore)
        done = self.game.lastState == GameState.Dead or self.staleCounter >= self.Stale
        return (state, reward, done, {})

    def render(self, *args: Any, **kwargs: Any) -> None:
        """Draw the current frame.

        All arguments (e.g. a render mode) are ignored. Requires :meth:`load`.
        """
        self.game.draw()

    def load(self) -> None:
        """Open the pygame window used by :meth:`render`."""
        self.game.load()

    def close(self) -> None:
        """Close the pygame window."""
        pygame.display.quit()

    def _getState(self) -> ObservationSpace:
        """Encode the board as a ``(*Size, 3)`` boolean array indexed ``[x, y, channel]``."""
        state = np.zeros((*self.Size, 3), np.bool_)
        for tail in self.game.snake.tails:
            state[tail.x, tail.y, 0] = 1
        state[self.game.snake.x, self.game.snake.y, 1] = 1
        state[self.game.apple.x, self.game.apple.y, 2] = 1
        return state

    def _performAction(self, action: ActionSpace) -> None:
        """Queue the turn selected by ``action`` (0..3) and advance the game one tick."""
        self.game.lastState = (
            GameState.TurnLeft,
            GameState.TurnRight,
            GameState.TurnUp,
            GameState.TurnDown,
        )[action]
        self.game.update()

    def _calculateReward(self, prevScore: int) -> float:
        """Return the reward for the last tick as a float and update ``staleCounter``."""
        if self.game.lastState == GameState.Dead:
            return -float(self.Stale)
        elif self.game.score > prevScore:
            self.staleCounter = 0
            return float(self.Stale)
        else:
            self.staleCounter += 1
            return -1.0

    def __enter__(self) -> "SnakeGameEnv":
        """Open the window when entering a ``with`` block."""
        self.load()
        return self

    def __exit__(self, *args: Any) -> None:
        """Close the window when leaving the ``with`` block, including on error."""
        self.close()

In [None]:
# Shared environment instance used by every cell below (checker, random
# rollout, training and evaluation).
env = SnakeGameEnv()

## Testing the environment

In [None]:
from stable_baselines3.common.env_checker import check_env
# Run Stable-Baselines3's Gym-API compliance checker on `env` before training on it.
check_env(env)

In [None]:
# Sanity-check the environment: play 5 episodes with actions sampled at random
# from the action space, render at up to 30 FPS, and print each episode's total
# reward. `with env:` opens the window on entry and closes it on exit.
with env:
    clock = pygame.time.Clock()
    for i in range(5):
        obs = env.reset()
        done = False
        score = 0
        env.render()
        # Let pygame process its internal events so the window stays responsive.
        pygame.event.pump()
        clock.tick(30)

        while not done:
            action = env.action_space.sample()
            obs, reward, done, info = env.step(action)
            score += reward
            env.render()
            pygame.event.pump()
            clock.tick(30)

        print(f"Episode:{i + 1} Score:{score}")

# Train Model

## Using DQN with MlpPolicy

In [None]:
import os

from stable_baselines3 import DQN
from stable_baselines3.dqn.policies import MlpPolicy

In [None]:
# TensorBoard log directory, and the path the trained DQN model is saved to and loaded from.
logPath = "logs"  # os.path.join with a single part is just that part
modelPath = os.path.join("saved_models", "DQN2")

In [None]:
# DQN agent with an MLP policy on `env`. verbose=1 prints training progress;
# TensorBoard logs are written under `logPath`.
model = DQN(
    policy=MlpPolicy,
    env=env,
    verbose=1,
    tensorboard_log=logPath)

In [None]:
# Initial training run: 1M environment steps.
model.learn(total_timesteps=1_000_000)

In [None]:
# Continue training for 2M more steps. reset_num_timesteps=False keeps the
# timestep counter running from the previous call instead of restarting at 0.
model.learn(total_timesteps=2_000_000, reset_num_timesteps=False)

In [None]:
# Persist the trained model to `modelPath`.
model.save(modelPath)

In [None]:
# Drop the in-memory model so the evaluation below runs on the reloaded copy.
del model

# Evaluation

In [None]:
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

In [None]:
# Reload the saved model from `modelPath`, attached to `env`.
model = DQN.load(modelPath, env)

In [None]:
# Evaluate the trained policy on 10 rendered episodes. With
# return_episode_rewards=True, evaluate_policy returns two lists:
# per-episode total rewards and per-episode lengths (steps).
# The result is named `evalResult`, not `eval`, so the builtin eval() is not
# shadowed in the notebook namespace.
with env:
    evalResult = evaluate_policy(
        model,
        Monitor(env),
        n_eval_episodes=10,
        render=True,
        return_episode_rewards=True)
episodeRewards, episodeLengths = evalResult
print(evalResult)
print((np.average(episodeRewards), np.average(episodeLengths)))

In [None]:
# Watch the trained agent: play 5 episodes with the model's deterministic
# action choice, render at up to 30 FPS, and print each episode's total reward.
with env:
    clock = pygame.time.Clock()
    for i in range(5):
        obs = env.reset()
        done = False
        score = 0
        env.render()
        # Let pygame process its internal events so the window stays responsive.
        pygame.event.pump()
        clock.tick(30)

        while not done:
            # `state` is the policy's recurrent state returned by predict(); unused here.
            action, state = model.predict(obs, deterministic=True)
            obs, reward, done, info = env.step(action)  # type: ignore
            score += reward
            env.render()
            pygame.event.pump()
            clock.tick(30)

        print(f"Episode:{i + 1} Score:{score}")
