In [3]:
import gymnasium as gym
from gymnasium import Env
from gymnasium.spaces import Discrete, Box
import numpy as np
import random
import pygame
import sys
import os

from stable_baselines3 import DQN
from sb3_contrib import QRDQN



  from pkg_resources import resource_stream, resource_exists


In [4]:
class FlappyBirdEnv(Env):
    """Minimal single-pipe Flappy Bird environment for Gymnasium.

    Observation (float32 vector of length 4):
        ``[bird_y, bird_velocity, pipe_x, pipe_gap_center]``
    Actions (``Discrete(2)``):
        0 = do nothing (gravity accelerates the bird downward), 1 = flap.
    Reward:
        +1 for every step survived, -100 on the step that terminates.

    An episode terminates when the bird leaves the screen vertically, or when
    it is outside the gap while the pipe overlaps it horizontally. The env only
    truncates if the user closes the render window; wrap it in
    ``gymnasium.wrappers.TimeLimit`` to bound episode length.

    NOTE: hit-testing treats the bird as a single point at
    ``(self.bird_x, bird_y)``; the 34x24 sprite drawn by ``render`` is not
    used for collisions.

    Args:
        render_mode: any truthy value opens a pygame window, loads ``bird.png``
            from the working directory and draws every step (capped at 30 FPS);
            a falsy value runs headless.
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, render_mode=False):
        super().__init__()

        # Layout and physics constants (pixels and pixels per step).
        self.gravity = 1
        self.flap_velocity = -8
        self.gap_height = 150
        self.pipe_width = 80
        self.pipe_speed = 4
        self.screen_width = 400
        self.screen_height = 512
        self.bird_x = 50  # fixed horizontal position of the bird

        self.action_space = Discrete(2)
        # The bounds must contain every observation step()/reset() can emit:
        # - bird_y leaves [0, screen_height] on the terminating step;
        # - velocity accumulates gravity with no cap;
        # so both are left unbounded. pipe_x slides down to -pipe_width before
        # it wraps back to screen_width.
        self.observation_space = Box(
            low=np.array([-np.inf, -np.inf, -self.pipe_width, 0], dtype=np.float32),
            high=np.array([np.inf, np.inf, self.screen_width, self.screen_height], dtype=np.float32),
            dtype=np.float32,
        )

        self.render_mode = render_mode
        self._window_open = False

        if self.render_mode:
            pygame.init()
            self.screen = pygame.display.set_mode((self.screen_width, self.screen_height))
            pygame.display.set_caption("Flappy Bird")
            self.clock = pygame.time.Clock()
            self.font = pygame.font.SysFont(None, 32)
            self.bird_img = pygame.image.load("bird.png").convert_alpha()
            self.bird_img = pygame.transform.scale(self.bird_img, (34, 24))
            self._window_open = True

        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``."""
        super().reset(seed=seed)
        self.bird_y = 250
        self.bird_velocity = 0
        self.pipe_x = self.screen_width
        self.pipe_gap_center = 200  # fixed: every pipe has the same gap
        self.score = 0
        self.done = False
        return self._get_obs(), {}

    def step(self, action):
        """Advance the game by one frame.

        Args:
            action: 1 to flap; any other value lets gravity act.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``truncated`` is
            True only when rendering and the user has closed the window.
        """
        if action == 1:
            self.bird_velocity = self.flap_velocity
        else:
            self.bird_velocity += self.gravity

        self.bird_y += self.bird_velocity
        self.pipe_x -= self.pipe_speed

        # Recycle the pipe once it has fully left the screen; passing it scores a point.
        if self.pipe_x < -self.pipe_width:
            self.pipe_x = self.screen_width
            self.pipe_gap_center = 200
            self.score += 1

        terminated = self._collided()
        reward = -100 if terminated else 1

        truncated = False
        if self.render_mode:
            self.render()
            # A closed window ends the episode instead of killing the process.
            truncated = not self._window_open

        self.done = terminated or truncated
        return self._get_obs(), reward, terminated, truncated, {}

    def _get_obs(self):
        """Pack the current state into the float32 observation vector."""
        return np.array(
            [self.bird_y, self.bird_velocity, self.pipe_x, self.pipe_gap_center],
            dtype=np.float32,
        )

    def _gap_bounds(self):
        """Return the ``(top, bottom)`` y-coordinates of the pipe gap."""
        half_gap = self.gap_height / 2
        return self.pipe_gap_center - half_gap, self.pipe_gap_center + half_gap

    def _collided(self):
        """True if the bird is off-screen vertically or hits the pipe outside the gap."""
        if self.bird_y < 0 or self.bird_y > self.screen_height:
            return True
        if self.pipe_x < self.bird_x < self.pipe_x + self.pipe_width:
            gap_top, gap_bottom = self._gap_bounds()
            return not (gap_top < self.bird_y < gap_bottom)
        return False

    def render(self):
        """Draw the current frame; no-op when headless or after the window was closed."""
        if not self.render_mode or not self._window_open:
            return

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                # Don't sys.exit(): inside a notebook that aborts the running cell.
                pygame.quit()
                self._window_open = False
                return

        self.screen.fill((135, 206, 235))
        self.screen.blit(self.bird_img, (self.bird_x, int(self.bird_y)))

        # Use the same gap geometry as the collision check.
        gap_top, gap_bottom = (int(y) for y in self._gap_bounds())
        pipe_color = (34, 139, 34)
        top_pipe = pygame.Rect(self.pipe_x, 0, self.pipe_width, gap_top)
        bottom_pipe = pygame.Rect(self.pipe_x, gap_bottom, self.pipe_width, self.screen_height - gap_bottom)
        pygame.draw.rect(self.screen, pipe_color, top_pipe)
        pygame.draw.rect(self.screen, pipe_color, bottom_pipe)

        score_surface = self.font.render(f"Score: {self.score}", True, (0, 0, 0))
        self.screen.blit(score_surface, (10, 10))

        pygame.display.flip()
        self.clock.tick(30)

    def close(self):
        """Shut down pygame if this instance still has an open window."""
        if self.render_mode and self._window_open:
            pygame.quit()
            self._window_open = False

In [None]:
# Train a QR-DQN agent on the headless environment and save it to disk.
SEED = 42                 # passed to QRDQN so re-running this cell is reproducible
TOTAL_TIMESTEPS = 50_000

# Distinct name: the evaluation cell builds its own (rendering) env.
train_env = FlappyBirdEnv()

model = QRDQN(
    policy="MlpPolicy",
    env=train_env,
    learning_rate=5e-4,
    buffer_size=50_000,
    learning_starts=1_000,
    batch_size=64,
    tau=1.0,
    gamma=0.995,
    train_freq=1,
    target_update_interval=500,
    exploration_fraction=0.05,
    exploration_final_eps=0.01,
    max_grad_norm=10,
    seed=SEED,
    verbose=1,
)

model.learn(total_timesteps=TOTAL_TIMESTEPS)
model.save("FlappyDQN")  # file name kept so the evaluation cell can load it

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 36.2     |
|    ep_rew_mean      | -64.8    |
|    exploration_rate | 0.943    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 7989     |
|    time_elapsed     | 0        |
|    total_timesteps  | 145      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 36.9     |
|    ep_rew_mean      | -64.1    |
|    exploration_rate | 0.883    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 9482     |
|    time_elapsed     | 0        |
|    total_timesteps  | 295      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 37.4     |
|    ep_rew_mean      | -63.6  

In [7]:
# Watch the trained agent play one episode in a pygame window.
# The env never truncates on its own, so a good policy would loop forever;
# TimeLimit caps the demo (2,000 steps is about a minute at 30 FPS).
MAX_EVAL_STEPS = 2_000

model = QRDQN.load("FlappyDQN")
eval_env = gym.wrappers.TimeLimit(FlappyBirdEnv(render_mode=True), max_episode_steps=MAX_EVAL_STEPS)

try:
    obs, _ = eval_env.reset()
    episode_return, n_steps = 0, 0
    done = False
    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, _ = eval_env.step(action)
        episode_return += reward
        n_steps += 1
        done = terminated or truncated
finally:
    # Release the pygame window even on KeyboardInterrupt or an error.
    eval_env.close()

{"steps": n_steps, "return": episode_return, "pipes_passed": eval_env.unwrapped.score}


KeyboardInterrupt: 