# Game

In [1]:
import enum
import random
from abc import ABC, abstractmethod
from collections import deque
from typing import Deque, Protocol, Tuple

import pygame
from pygame.color import Color
from pygame.event import Event
from pygame.surface import Surface

pygame 2.1.2 (SDL 2.0.18, Python 3.9.9)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Initialise pygame up front; SnakeGame creates a pygame font in its constructor.
pygame.init()

(5, 0)

In [3]:
class Point(Protocol):
    """Structural type for anything that exposes integer ``x`` and ``y`` attributes.

    Used as the target type of the ``hitTest`` methods, so blocks and plain
    coordinate tuples with ``x``/``y`` fields can be compared alike.
    """
    x: int
    y: int

In [4]:
class DisplayObject(ABC):
    """Base class for everything the game paints in a single colour.

    The class derives from ``ABC`` so that ``@abstractmethod`` is actually
    enforced: a subclass that forgets to implement ``draw`` fails at
    instantiation time instead of at the first frame.
    """

    def __init__(self, color: Color) -> None:
        """Remember the colour used when the object is drawn."""
        self.color = color

    @abstractmethod
    def draw(self, surface: Surface, gridSize: float) -> None:
        """Paint the object onto ``surface``.

        Args:
            surface: Target pygame surface.
            gridSize: Edge length of one grid cell in pixels.

        Raises:
            NotImplementedError: If a subclass delegates to this base
                implementation.
        """
        raise NotImplementedError()

In [5]:
class Block(DisplayObject):
    """A single coloured grid cell located at column ``x``, row ``y``."""

    def __init__(self, color: Color, x: int, y: int) -> None:
        super().__init__(color)
        self.x = x
        self.y = y

    def hitTest(self, target: Point) -> bool:
        """Return True when ``target`` occupies the same grid cell as this block."""
        if not self.x == target.x:
            return False
        return self.y == target.y

    def draw(self, surface: Surface, gridSize: float) -> None:
        """Fill this block's cell on ``surface``, leaving a 1-pixel gap on every side."""
        left = self.x * gridSize + 1
        top = self.y * gridSize + 1
        side = gridSize - 2
        cell = pygame.Rect(left, top, side, side)
        pygame.draw.rect(surface, self.color, cell)

In [6]:
class Apple(Block):
    """The food item: a ``Block`` with no additional state or behaviour.

    The constructor is inherited unchanged, i.e. ``Apple(color, x, y)``.
    """

In [7]:
class Snake(DisplayObject):
    """The snake: a deque of blocks whose leftmost element is the head.

    ``x``/``y`` track the head cell, ``vx``/``vy`` the per-step movement and
    ``length`` the maximum number of segments kept after each move.
    """

    def __init__(self, color: Color) -> None:
        super().__init__(color)
        self.tails: Deque[Block] = deque()

    def reset(self, x: int, y: int, length: int) -> None:
        """Place a motionless one-segment snake at ``(x, y)`` that may grow to ``length``."""
        self.x, self.y = x, y
        self.vx, self.vy = 0, 0
        self.length = length
        self.tails.clear()
        self.tails.append(Block(self.color, x, y))

    def hitTest(self, target: Point) -> bool:
        """Return True when any segment of the snake occupies ``target``'s cell."""
        for segment in self.tails:
            if segment.hitTest(target):
                return True
        return False

    def getNewHead(self) -> Block:
        """Build (without attaching) the block the head would move into next."""
        nextX = self.x + self.vx
        nextY = self.y + self.vy
        return Block(self.color, nextX, nextY)

    def setNewHead(self, head: Block) -> None:
        """Make ``head`` the new front segment and trim the tail down to ``length``."""
        self.x, self.y = head.x, head.y
        self.tails.appendleft(head)
        # Number of segments beyond the allowed length; zero or negative means nothing to drop.
        for _ in range(len(self.tails) - self.length):
            self.tails.pop()

    def draw(self, surface: Surface, gridSize: float) -> None:
        """Draw every segment onto ``surface``."""
        for segment in self.tails:
            segment.draw(surface, gridSize)

In [8]:
@enum.unique
class GameState(enum.Enum):
    """Outcome of the last update, or the turn requested for the next one."""

    # Lifecycle states.
    Ready = 1   # freshly reset; updates do nothing until a turn is requested
    Dead = 2    # the snake hit a wall or itself
    Step = 3    # the last update moved the snake; no turn is pending

    # Turn requests that the next update should apply.
    TurnLeft = 4
    TurnRight = 5
    TurnUp = 6
    TurnDown = 7

In [9]:
class SnakeGame:
    """Snake played on a ``cols`` x ``rows`` grid and rendered with pygame.

    ``lastState`` doubles as input buffer and result flag: input handlers store
    a ``GameState.Turn*`` request in it, and ``update`` replaces it with
    ``Step`` (moved) or ``Dead`` (collision).
    """

    initialSnakeLength: int = 5

    def __init__(
            self,
            background: Color,
            snakeColor: Color,
            appleColor: Color,
            gridSize: float,
            size: Tuple[int, int]) -> None:
        """Create the game objects and start a fresh round.

        Args:
            background: Fill colour of the board.
            snakeColor: Colour of the snake segments.
            appleColor: Colour of the apple.
            gridSize: Edge length of one grid cell in pixels.
            size: Board size as ``(columns, rows)``.
        """
        self.background = background
        self.snake = Snake(snakeColor)
        # Off-board placeholder; reset() below positions the apple properly.
        self.apple = Apple(appleColor, -1, -1)
        self.font = pygame.font.SysFont("Segoe UI", 20)
        self.gridSize = gridSize
        self.cols = size[0]
        self.rows = size[1]
        self.reset()

    @property
    def score(self) -> int:
        """Number of segments gained since the round started."""
        return self.snake.length - self.initialSnakeLength

    def load(self) -> None:
        """Open the display window; must be called before ``draw``."""
        self.surface = pygame.display.set_mode(
            (self.gridSize * self.cols, self.gridSize * self.rows))

    def reset(self) -> None:
        """Start a new round: snake centred and motionless, apple respawned."""
        self.lastState = GameState.Ready
        self.snake.reset(self.cols // 2, self.rows // 2, self.initialSnakeLength)
        self.spawnApple()

    def spawnApple(self) -> None:
        """Move the apple to a uniformly random cell not occupied by the snake.

        If the snake covers the whole board there is no free cell; the apple is
        then parked off-board at ``(-1, -1)`` instead of looping forever.
        """
        occupied = {(tail.x, tail.y) for tail in self.snake.tails}
        free = [
            (x, y)
            for x in range(self.cols)
            for y in range(self.rows)
            if (x, y) not in occupied
        ]
        if not free:
            self.apple.x, self.apple.y = -1, -1
            return
        self.apple.x, self.apple.y = random.choice(free)

    def handleEvent(self, event: Event) -> None:
        """Translate a key press into a restart or a pending turn request.

        Only ``KEYDOWN`` events are considered. After a game over any key
        restarts the round. Otherwise arrow keys request a turn; all other
        keys are ignored.
        """
        if event.type != pygame.KEYDOWN:
            return
        if self.lastState == GameState.Dead:
            self.reset()
            return

        if event.key == pygame.K_LEFT:
            turn = GameState.TurnLeft
        elif event.key == pygame.K_RIGHT:
            turn = GameState.TurnRight
        elif event.key == pygame.K_UP:
            turn = GameState.TurnUp
        elif event.key == pygame.K_DOWN:
            turn = GameState.TurnDown
        else:
            # Not a steering key: must not consume a pending turn or cause an extra step.
            return

        if self.lastState not in (GameState.Ready, GameState.Step):
            # An earlier turn is still waiting to be applied. Execute it now so
            # two quick key presses within one frame are both honoured.
            self.update()
            self.draw()
            if self.lastState == GameState.Dead:
                # That move killed the snake; do not overwrite Dead with a new
                # turn, which would bring the snake back to life.
                return
        self.lastState = turn

    def update(self) -> None:
        """Advance the game by one step.

        Applies a pending turn, moves the head one cell, and sets ``lastState``
        to ``Dead`` on a wall or self collision, otherwise to ``Step``. Does
        nothing while the state is ``Ready`` or ``Dead``.
        """
        if self.lastState in (GameState.Ready, GameState.Dead):
            return
        elif self.lastState == GameState.Step:
            pass
        # A turn is honoured only when the snake is not already moving along
        # that axis, which also prevents reversing straight into itself.
        elif self.lastState == GameState.TurnLeft:
            if self.snake.vx == 0:
                self.snake.vx = -1
                self.snake.vy = 0
        elif self.lastState == GameState.TurnRight:
            if self.snake.vx == 0:
                self.snake.vx = 1
                self.snake.vy = 0
        elif self.lastState == GameState.TurnUp:
            if self.snake.vy == 0:
                self.snake.vx = 0
                self.snake.vy = -1
        elif self.lastState == GameState.TurnDown:
            if self.snake.vy == 0:
                self.snake.vx = 0
                self.snake.vy = 1

        head = self.snake.getNewHead()
        if not 0 <= head.x < self.cols or not 0 <= head.y < self.rows:
            self.lastState = GameState.Dead
            return
        if self.snake.hitTest(head):
            self.lastState = GameState.Dead
            return
        self.snake.setNewHead(head)
        if head.hitTest(self.apple):
            self.snake.length += 1
            self.spawnApple()
        self.lastState = GameState.Step

    def draw(self) -> None:
        """Render board, snake, apple and score, then flip the display."""
        self.surface.fill(self.background)
        self.snake.draw(self.surface, self.gridSize)
        self.apple.draw(self.surface, self.gridSize)
        text = self.font.render(f"Score: {self.score}", True, (255, 255, 255))
        self.surface.blit(text, [0, 0])
        pygame.display.update()

## Running the game for a human player

Keyboard input is read from the `pygame.event` queue.

In [10]:
def runGame(frameRate: int) -> None:
    """Run an interactive game until the window is closed.

    Args:
        frameRate: Upper bound on game steps (and redraws) per second.
    """
    pygame.display.set_caption("Snake")
    ticker = pygame.time.Clock()
    game = SnakeGame(
        background=Color(0, 0, 0),
        snakeColor=Color(0, 0, 255),
        appleColor=Color(255, 0, 0),
        gridSize=20,
        size=(20, 20),
    )
    game.load()
    while True:
        for pending in pygame.event.get():
            if pending.type != pygame.QUIT:
                game.handleEvent(pending)
                continue
            # Window closed: tear the display down and leave immediately.
            pygame.display.quit()
            return
        game.update()
        game.draw()
        ticker.tick(frameRate)

In [11]:
# Play interactively, capped at 16 frames per second; returns once the window is closed.
runGame(16)

# Environment


The game wrapped as a `gym.Env` so that it can be used to train the AI.

In [12]:
from typing import Any, Dict, NamedTuple

import gym
from gym import spaces

import numpy as np
from numpy.typing import NDArray

In [13]:
# Observation: array of boolean flags. `np.bool_` is used instead of the
# `np.bool8` alias, which is deprecated since NumPy 1.24 and removed in 2.0.
ObservationSpace = NDArray[np.bool_]
# Action: small unsigned integer used as an index into the available turns.
ActionSpace = np.uint8

In [14]:
# Immutable (x, y) pair; it has the ``x``/``y`` fields required by the ``hitTest`` methods.
P = NamedTuple("P", [("x", int), ("y", int)])

In [15]:
class SnakeGameEnv(gym.Env):
    """``gym.Env`` wrapper around ``SnakeGame``.

    Observation (8 booleans): whether the cell left / right / above / below the
    head is unsafe (wall or snake), followed by whether the apple lies to the
    left / right / above / below the head.

    Action (0..3): request a turn left / right / up / down.

    Reward: ``+1`` for eating an apple, ``-1`` for dying, ``0`` otherwise.

    An episode ends when the snake dies or after ``Stale`` consecutive steps
    without eating an apple.

    The environment is also a context manager: entering opens the display
    window, leaving closes it.
    """

    Size = P(20, 20)
    Stale = 1_000

    def __init__(self) -> None:
        self.game = SnakeGame(Color(0, 0, 0),
                              Color(0, 0, 255),
                              Color(255, 0, 0),
                              20,
                              self.Size)
        self.observation_space = spaces.MultiBinary(8)
        self.action_space = spaces.Discrete(4)
        # Defined here so that step() does not fail with AttributeError when
        # it is called before the first reset().
        self.staleCounter = 0

    def reset(self) -> ObservationSpace:
        """Start a new episode and return its first observation."""
        self.game.reset()
        self.staleCounter = 0
        return self._getState()

    def step(self, action: ActionSpace) -> Tuple[ObservationSpace, float, bool, Dict[str, Any]]:
        """Apply ``action`` and advance the game by one step.

        Returns:
            ``(observation, reward, done, info)``; ``info`` is always empty.
        """
        prevScore = self.game.score
        self._performAction(action)

        state = self._getState()
        # Must run before `done` is evaluated: it maintains staleCounter.
        reward = self._calculateReward(prevScore)
        done = self.game.lastState == GameState.Dead or self.staleCounter >= self.Stale
        return (state, reward, done, {})

    def render(self, *args: Any, **kwargs: Any) -> None:
        """Draw the current frame; requires ``load()`` to have been called."""
        self.game.draw()

    def load(self) -> None:
        """Open the display window used by ``render``."""
        self.game.load()

    def close(self) -> None:
        """Close the display window."""
        pygame.display.quit()

    def _getState(self) -> ObservationSpace:
        """Build the 8-flag observation for the current head position."""
        def isUnsafe(x: int, y: int) -> bool:
            # Unsafe means outside the board or on a snake segment.
            return not 0 <= x < self.Size.x or not 0 <= y < self.Size.y or self.game.snake.hitTest(P(x, y))

        (x, y) = (self.game.snake.x, self.game.snake.y)
        return np.array([
            isUnsafe(x - 1, y),
            isUnsafe(x + 1, y),
            isUnsafe(x, y - 1),
            isUnsafe(x, y + 1),
            self.game.apple.x < x,
            self.game.apple.x > x,
            self.game.apple.y < y,
            self.game.apple.y > y,
        ], np.bool_)

    def _performAction(self, action: ActionSpace) -> None:
        """Translate ``action`` into a turn request and run one game update."""
        self.game.lastState = (
            GameState.TurnLeft,
            GameState.TurnRight,
            GameState.TurnUp,
            GameState.TurnDown,
        )[action]
        self.game.update()

    def _calculateReward(self, prevScore: int) -> float:
        """Return the reward for the step just taken and update ``staleCounter``.

        Args:
            prevScore: Game score before the step was performed.
        """
        if self.game.lastState == GameState.Dead:
            return -1
        elif self.game.score > prevScore:
            self.staleCounter = 0
            return 1
        else:
            self.staleCounter += 1
            return 0

    def __enter__(self) -> "SnakeGameEnv":
        self.load()
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()

In [16]:
# Single environment instance, reused by the checking, training and evaluation cells below.
env = SnakeGameEnv()

## Testing the environment

In [17]:
from stable_baselines3.common.env_checker import check_env

In [18]:
# Run stable-baselines3's environment checker against the environment.
check_env(env)

In [19]:
# Play five episodes with uniformly random actions, rendered at up to 30 steps per second.
with env:
    clock = pygame.time.Clock()
    for i in range(5):
        obs = env.reset()
        done = False
        score = 0

        while not done:
            # Nothing else in this loop reads the event queue; pumping it lets
            # pygame service window messages so the OS does not flag the
            # window as unresponsive.
            pygame.event.pump()
            env.render()
            action = env.action_space.sample()
            obs, reward, done, info = env.step(action)
            score += reward
            clock.tick(30)

        print(f"Episode:{i + 1} Score:{score}")

Episode:1 Score:-1
Episode:2 Score:-1
Episode:3 Score:-1
Episode:4 Score:-1
Episode:5 Score:-1


# Train Model

## Using DQN with MlpPolicy

In [20]:
import os

from stable_baselines3.dqn.dqn import DQN
from stable_baselines3.dqn.policies import MlpPolicy

In [21]:
# Directory handed to DQN as `tensorboard_log`.
logPath = "logs"

In [22]:
# Location the trained model is saved to and later loaded from.
modelDir, modelName = "saved_models", "DQN"
modelPath = os.path.join(modelDir, modelName)

In [38]:
# Hyperparameters gathered in one place so they are easy to inspect and tweak.
dqnConfig = dict(
    policy=MlpPolicy,
    env=env,
    learning_starts=50_000,
    learning_rate=0.000_1,
    gamma=0.99,
    verbose=1,
    tensorboard_log=logPath,
)
model = DQN(**dqnConfig)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [51]:
# Training budget, in environment steps.
totalTimesteps = 200_000
model.learn(total_timesteps=totalTimesteps)

Logging to logs\DQN_4
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 18.8     |
|    ep_rew_mean      | -1       |
|    exploration_rate | 0.996    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 3133     |
|    time_elapsed     | 0        |
|    total_timesteps  | 75       |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 20.5     |
|    ep_rew_mean      | -1       |
|    exploration_rate | 0.992    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 4216     |
|    time_elapsed     | 0        |
|    total_timesteps  | 164      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 25.2     |
|    ep_rew_mean      | -0.917   |
|    exploration_rate | 0.986    |
| time/               |          

<stable_baselines3.dqn.dqn.DQN at 0x2a2a553c7f0>

In [52]:
# Persist the trained model at modelPath.
model.save(modelPath)

In [26]:
# Drop the in-memory model; the Evaluation section loads it back from modelPath.
del model

# Evaluation

In [54]:
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

In [28]:
# Restore the saved model and attach it to the shared environment.
model = DQN.load(modelPath, env=env)

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [55]:
# Evaluate the loaded model over ten rendered episodes and print the tuple returned by evaluate_policy.
with env:
    monitoredEnv = Monitor(env)
    evaluation = evaluate_policy(
        model,
        monitoredEnv,
        n_eval_episodes=10,
        render=True)
    print(evaluation)


(33.4, 9.86103442849684)


In [64]:
# Watch the trained model play five episodes, rendered at up to 30 steps per second.
with env:
    clock = pygame.time.Clock()
    for i in range(5):
        obs = env.reset()
        done = False
        score = 0

        while not done:
            # Nothing else in this loop reads the event queue; pumping it lets
            # pygame service window messages so the OS does not flag the
            # window as unresponsive.
            pygame.event.pump()
            env.render()
            action, _state = model.predict(obs)
            obs, reward, done, info = env.step(action) # type: ignore
            score += reward
            # The clock was created but never ticked, so playback ran
            # unthrottled; pace it like the random-agent demo.
            clock.tick(30)

        print(f"Episode:{i + 1} Score:{score}")

Episode:1 Score:7
Episode:2 Score:13
Episode:3 Score:18
Episode:4 Score:31
Episode:5 Score:21
