# Flappy Bird Game
The aim of the game is to remain alive as long as possible. The game ends when the bird touches the floor or a pipe. So, the bird needs to flap its wings at the right times to get through the random pipes and to avoid falling to the ground. Possible actions include flapping and not flapping. In the game environment, the reward is +0.1 for every step, with the following two exceptions:


1.   -1 when a collision occurs
2.   +1 when the bird gets through the gap between two pipes. The original Flappy
Bird game is scored based on the number of gaps passed through.



In [1]:
!pip install pygame

Collecting pygame
[?25l  Downloading https://files.pythonhosted.org/packages/8e/24/ede6428359f913ed9cd1643dd5533aefeb5a2699cc95bea089de50ead586/pygame-1.9.6-cp36-cp36m-manylinux1_x86_64.whl (11.4MB)
[K     |████████████████████████████████| 11.4MB 4.6MB/s 
[?25hInstalling collected packages: pygame
Successfully installed pygame-1.9.6


# Utility Function

In [2]:
import cv2
import numpy as np
from pygame.image import load
from pygame.surfarray import pixels_alpha
from pygame.transform import rotate

pygame 1.9.6
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
def load_images(sprites_path):
    base_image = load(sprites_path + 'base.png').convert_alpha()
    background_image = load(sprites_path + 'background-black.png').convert()
    pipe_images = [rotate(load(sprites_path + 'pipe-green.png').convert_alpha(), 180),
                   load(sprites_path + 'pipe-green.png').convert_alpha()]
    bird_images = [load(sprites_path + 'redbird-upflap.png').convert_alpha(),
                   load(sprites_path + 'redbird-midflap.png').convert_alpha(),
                   load(sprites_path + 'redbird-downflap.png').convert_alpha()]
    bird_hitmask = [pixels_alpha(image).astype(bool) for image in bird_images]
    pipe_hitmask = [pixels_alpha(image).astype(bool) for image in pipe_images]
    return base_image, background_image, pipe_images, bird_images, bird_hitmask, pipe_hitmask


In [4]:
def pre_processing(image, width, height):
    image = cv2.cvtColor(cv2.resize(image, (width, height)), cv2.COLOR_BGR2GRAY)
    _, image = cv2.threshold(image, 1, 255, cv2.THRESH_BINARY)
    return image[None, :, :].astype(np.float32)

# Flappy bird env

In [7]:
from itertools import cycle
from random import randint
import pygame

In [9]:
import os
os.environ["SDL_VIDEODRIVER"] = "dummy"

In [11]:
!mkdir sprites

In [12]:
pygame.init()
fps = 30
fps_clock = pygame.time.Clock()

screen_width = 288
screen_height = 512
screen = pygame.display.set_mode((screen_width, screen_height))

pygame.display.set_caption('Flappy Bird')

base_image, background_image, pipe_images, bird_images, bird_hitmask, pipe_hitmask = load_images('sprites/')

bird_width = bird_images[0].get_width()
bird_height = bird_images[0].get_height()
pipe_width = pipe_images[0].get_width()
pipe_height = pipe_images[0].get_height()



pipe_gap_size = 100
bird_index_gen = cycle([0, 1, 2, 1])

In [13]:
class FlappyBird(object):
    def __init__(self):
        self.pipe_vel_x = -4
        self.min_velocity_y = -8
        self.max_velocity_y = 10
        self.downward_speed = 1
        self.upward_speed = -9
        self.cur_velocity_y = 0
        self.iter = self.bird_index = self.score = 0
        self.bird_x = int(screen_width / 5)
        self.bird_y = int((screen_height - bird_height) / 2)
        self.base_x = 0
        self.base_y = screen_height * 0.79
        self.base_shift = base_image.get_width() - background_image.get_width()
        self.pipes = [self.gen_random_pipe(screen_width), self.gen_random_pipe(screen_width * 1.5)]
        self.is_flapped = False


    def gen_random_pipe(self, x):
        gap_y = randint(2, 10) * 10 + int(self.base_y * 0.2)
        return {"x_upper": x,
                "y_upper": gap_y - pipe_height,
                "x_lower": x,
                "y_lower": gap_y + pipe_gap_size}

    def check_collision(self):
        if bird_height + self.bird_y >= self.base_y - 1:
            return True
        bird_rect = pygame.Rect(self.bird_x, self.bird_y, bird_width, bird_height)
        for pipe in self.pipes:
            pipe_boxes = [pygame.Rect(pipe["x_upper"], pipe["y_upper"], pipe_width, pipe_height),
                          pygame.Rect(pipe["x_lower"], pipe["y_lower"], pipe_width, pipe_height)]
            # Check if the bird's bounding box overlaps to the bounding box of any pipe
            if bird_rect.collidelist(pipe_boxes) == -1:
                return False
            for i in range(2):
                cropped_bbox = bird_rect.clip(pipe_boxes[i])
                x1 = cropped_bbox.x - bird_rect.x
                y1 = cropped_bbox.y - bird_rect.y
                x2 = cropped_bbox.x - pipe_boxes[i].x
                y2 = cropped_bbox.y - pipe_boxes[i].y
                for x in range(cropped_bbox.width):
                    for y in range(cropped_bbox.height):
                        if bird_hitmask[self.bird_index][x1+x, y1+y] and pipe_hitmask[i][x2+x, y2+y]:
                            return True
        return False

    def next_step(self, action):
        pygame.event.pump()
        reward = 0.1
        if action == 1:
            self.cur_velocity_y = self.upward_speed
            self.is_flapped = True
        # Update score
        bird_center_x = self.bird_x + bird_width / 2
        for pipe in self.pipes:
            pipe_center_x = pipe["x_upper"] + pipe_width / 2
            if pipe_center_x < bird_center_x < pipe_center_x + 5:
                self.score += 1
                reward = 1
                break
        # Update index and iteration
        if (self.iter + 1) % 3 == 0:
            self.bird_index = next(bird_index_gen)
        self.iter = (self.iter + 1) % fps
        self.base_x = -((-self.base_x + 100) % self.base_shift)
        # Update bird's position
        if self.cur_velocity_y < self.max_velocity_y and not self.is_flapped:
            self.cur_velocity_y += self.downward_speed
        self.is_flapped = False
        self.bird_y += min(self.cur_velocity_y, self.bird_y - self.cur_velocity_y - bird_height)
        if self.bird_y < 0:
            self.bird_y = 0
        # Update pipe position
        for pipe in self.pipes:
            pipe["x_upper"] += self.pipe_vel_x
            pipe["x_lower"] += self.pipe_vel_x
        #  Add new pipe when first pipe is about to touch left of screen
        if 0 < self.pipes[0]["x_lower"] < 5:
            self.pipes.append(self.gen_random_pipe(screen_width + 10))
        # remove first pipe if its out of the screen
        if self.pipes[0]["x_lower"] < -pipe_width:
            self.pipes.pop(0)
        if self.check_collision():
            is_done = True
            reward = -1
            self.__init__()
        else:
            is_done = False
        # Draw sprites
        screen.blit(background_image, (0, 0))
        screen.blit(base_image, (self.base_x, self.base_y))
        screen.blit(bird_images[self.bird_index], (self.bird_x, self.bird_y))
        for pipe in self.pipes:
            screen.blit(pipe_images[0], (pipe["x_upper"], pipe["y_upper"]))
            screen.blit(pipe_images[1], (pipe["x_lower"], pipe["y_lower"]))
        image = pygame.surfarray.array3d(pygame.display.get_surface())
        pygame.display.update()
        fps_clock.tick(fps)
        return image, reward, is_done

# Deep Q-network

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random

In [15]:
class DQNModel(nn.Module):
    def __init__(self, n_action=2):
        super(DQNModel, self).__init__()
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, 3, stride=1)
        self.fc = nn.Linear(7 * 7 * 64, 512)
        self.out = nn.Linear(512, n_action)
        self._create_weights()

    def _create_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.uniform(m.weight, -0.01, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc(x))
        output = self.out(x)
        return output

In [16]:
class DQN():
    def __init__(self, n_action, lr=1e-6):
        self.criterion = nn.MSELoss()
        self.model = DQNModel(n_action)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr)

    def update(self, y_predict, y_target):
        """
        Update the weights of the DQN given a training sample
        @param y_predict:
        @param y_target:
        @return:
        """
        loss = self.criterion(y_predict, y_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss

    def predict(self, s):
        """
        Compute the Q values of the state for all actions using the learning model
        @param s: input state
        @return: Q values of the state for all actions
        """
        return self.model(torch.Tensor(s))

    def replay(self, memory, replay_size, gamma):
        """
        Experience replay
        @param memory: a list of experience
        @param replay_size: the number of samples we use to update the model each time
        @param gamma: the discount factor
        @return: the loss
        """
        if len(memory) >= replay_size:
            replay_data = random.sample(memory, replay_size)
            state_batch, action_batch, next_state_batch, reward_batch, done_batch = zip(*replay_data)

            state_batch = torch.cat(tuple(state for state in state_batch))
            next_state_batch = torch.cat(tuple(state for state in next_state_batch))
            q_values_batch = self.predict(state_batch)
            q_values_next_batch = self.predict(next_state_batch)

            reward_batch = torch.from_numpy(np.array(reward_batch, dtype=np.float32)[:, None])

            action_batch = torch.from_numpy(
                np.array([[1, 0] if action == 0 else [0, 1] for action in action_batch], dtype=np.float32))

            q_value = torch.sum(q_values_batch * action_batch, dim=1)

            td_targets = torch.cat(
                tuple(reward if terminal else reward + gamma * torch.max(prediction) for reward, terminal, prediction
                    in zip(reward_batch, done_batch, q_values_next_batch)))

            loss = self.update(q_value, td_targets)
            return loss

# Training the network

In [None]:
import random
import torch
from collections import deque





def gen_epsilon_greedy_policy(estimator, epsilon, n_action):
    def policy_function(state):
        if random.random() < epsilon:
            return random.randint(0, n_action - 1)
        else:
            q_values = estimator.predict(state)
            return torch.argmax(q_values).item()
    return policy_function


image_size = 84
batch_size = 32
lr = 1e-6
gamma = 0.99
init_epsilon = 0.1
final_epsilon = 1e-4
n_iter = 2000000
memory_size = 50000
saved_path = 'trained_models'
n_action = 2


torch.manual_seed(123)

estimator = DQN(n_action)

memory = deque(maxlen=memory_size)

env = FlappyBird()
image, reward, is_done = env.next_step(0)
image = pre_processing(image[:screen_width, :int(env.base_y)], image_size, image_size)
image = torch.from_numpy(image)
state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]

for iter in range(n_iter):
    epsilon = final_epsilon + (n_iter - iter) * (init_epsilon - final_epsilon) / n_iter
    policy = gen_epsilon_greedy_policy(estimator, epsilon, n_action)
    action = policy(state)
    next_image, reward, is_done = env.next_step(action)
    next_image = pre_processing(next_image[:screen_width, :int(env.base_y)], image_size, image_size)
    next_image = torch.from_numpy(next_image)
    next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]
    memory.append([state, action, next_state, reward, is_done])
    loss = estimator.replay(memory, batch_size, gamma)
    state = next_state
    print("Iteration: {}/{}, Action: {}, Loss: {}, Epsilon {}, Reward: {}".format(
            iter + 1, n_iter, action, loss, epsilon, reward))
    if iter+1 % 10000 == 0:
        torch.save(estimator.model, "{}/{}".format(saved_path, iter+1))

torch.save(estimator.model, "{}/final".format(saved_path))

  


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Iteration: 9613/2000000, Action: 0, Loss: 0.040831658989191055, Epsilon 0.09951988060000001, Reward: 0.1
Iteration: 9614/2000000, Action: 1, Loss: 0.010003555566072464, Epsilon 0.09951983065, Reward: 0.1
Iteration: 9615/2000000, Action: 1, Loss: 0.040831584483385086, Epsilon 0.0995197807, Reward: 0.1
Iteration: 9616/2000000, Action: 0, Loss: 0.07165955752134323, Epsilon 0.09951973075000001, Reward: 0.1
Iteration: 9617/2000000, Action: 0, Loss: 0.040831536054611206, Epsilon 0.09951968080000001, Reward: 0.1
Iteration: 9618/2000000, Action: 0, Loss: 0.07165954262018204, Epsilon 0.09951963085, Reward: 0.1
Iteration: 9619/2000000, Action: 0, Loss: 0.07165947556495667, Epsilon 0.0995195809, Reward: 0.1
Iteration: 9620/2000000, Action: 0, Loss: 0.010003520175814629, Epsilon 0.09951953095, Reward: 0.1
Iteration: 9621/2000000, Action: 1, Loss: 0.04083149507641792, Epsilon 0.099519481, Reward: 0.1
Iteration: 9622/2000000, Action: 1

In [None]:
import torch

saved_path = 'trained_models'
model = torch.load("{}/final".format(saved_path))


image_size = 84
n_episode = 100



for episode in range(n_episode):
    env = FlappyBird()
    image, reward, is_done = env.next_step(0)
    image = pre_processing(image[:screen_width, :int(env.base_y)], image_size, image_size)
    image = torch.from_numpy(image)
    state = torch.cat(tuple(image for _ in range(4)))[None, :, :, :]

    while True:
        prediction = model(state)[0]
        action = torch.argmax(prediction).item()

        next_image, reward, is_done = env.next_step(action)

        if is_done:
            break

        next_image = pre_processing(next_image[:screen_width, :int(env.base_y)], image_size, image_size)
        next_image = torch.from_numpy(next_image)
        next_state = torch.cat((state[0, 1:, :, :], next_image))[None, :, :, :]

        state = next_state