In [2]:
%load_ext autoreload
%autoreload 2

import os
import time
import numpy as np
import torch
from tqdm import trange

from game.flappy_bird import FlappyBirdEnv
from agent.dqn_agent import DQNAgent, DQNConfig

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
pygame 2.6.1 (SDL 2.28.4, Python 3.12.6)
Hello from the pygame community. https://www.pygame.org/contribute.html


  from pkg_resources import resource_stream, resource_exists


In [None]:

env = FlappyBirdEnv(seed=42)
state_dim = 5
action_dim = 2

cfg = DQNConfig(
	state_dim=state_dim,
	action_dim=action_dim,
	gamma=0.99,
	lr=1e-3,
	batch_size=64,
	replay_size=50_000,
	start_learning_after=5_000,
	target_update_freq=1_000,
	eps_start=1.0,
	eps_end=0.05,
	eps_decay_steps=50_000,
	gradient_clip_norm=5.0,
)

agent = DQNAgent(cfg)
print("Using device:", agent.cfg.device)

Using device: cpu


In [4]:
num_episodes = 500  
max_steps_per_ep = 10_000

moving_avg_window = 50
ep_returns = []
losses = []

for ep in trange(num_episodes, desc="Training"):
	state = env.reset()
	ep_return = 0.0

	for t in range(max_steps_per_ep):
		action = agent.select_action(state)
		next_state, reward, done, info = env.step(action)

		agent.store(state, action, reward, next_state, done)
		loss = agent.train_step()
		if loss is not None:
			losses.append(loss)

		ep_return += reward
		state = next_state
		if done:
			break

	ep_returns.append(ep_return)

	# Display simple stats every few episodes
	if (ep + 1) % 10 == 0:
		recent = ep_returns[-moving_avg_window:]
		mavg = np.mean(recent) if recent else 0.0
		print(f"Ep {ep+1}/{num_episodes} | Return: {ep_return:.1f} | 50-ep avg: {mavg:.1f} | Eps: {agent.epsilon:.3f}")

print("Training finished.")

Training:  15%|█▍        | 73/500 [00:00<00:00, 728.84it/s]

Ep 10/500 | Return: -67.0 | 50-ep avg: -63.8 | Eps: 0.993
Ep 20/500 | Return: -66.0 | 50-ep avg: -64.1 | Eps: 0.986
Ep 30/500 | Return: -67.0 | 50-ep avg: -64.5 | Eps: 0.979
Ep 40/500 | Return: -66.0 | 50-ep avg: -64.7 | Eps: 0.972
Ep 50/500 | Return: -61.0 | 50-ep avg: -64.4 | Eps: 0.965
Ep 60/500 | Return: -66.0 | 50-ep avg: -64.6 | Eps: 0.958
Ep 70/500 | Return: -66.0 | 50-ep avg: -64.7 | Eps: 0.951
Ep 80/500 | Return: -62.0 | 50-ep avg: -64.4 | Eps: 0.944
Ep 90/500 | Return: -59.0 | 50-ep avg: -64.1 | Eps: 0.937
Ep 100/500 | Return: -64.0 | 50-ep avg: -64.3 | Eps: 0.930
Ep 110/500 | Return: -62.0 | 50-ep avg: -64.1 | Eps: 0.923
Ep 120/500 | Return: -66.0 | 50-ep avg: -63.7 | Eps: 0.916
Ep 130/500 | Return: -62.0 | 50-ep avg: -63.8 | Eps: 0.909
Ep 140/500 | Return: -65.0 | 50-ep avg: -64.0 | Eps: 0.902


Training:  29%|██▉       | 146/500 [00:01<00:04, 84.26it/s]

Ep 150/500 | Return: -64.0 | 50-ep avg: -64.1 | Eps: 0.895
Ep 160/500 | Return: -65.0 | 50-ep avg: -64.0 | Eps: 0.888
Ep 170/500 | Return: -65.0 | 50-ep avg: -64.3 | Eps: 0.881


Training:  36%|███▌      | 180/500 [00:05<00:13, 24.01it/s]

Ep 180/500 | Return: -65.0 | 50-ep avg: -64.4 | Eps: 0.874
Ep 190/500 | Return: -64.0 | 50-ep avg: -64.6 | Eps: 0.868


Training:  40%|███▉      | 199/500 [00:07<00:16, 17.98it/s]

Ep 200/500 | Return: -66.0 | 50-ep avg: -64.7 | Eps: 0.861


Training:  42%|████▏     | 211/500 [00:09<00:18, 15.78it/s]

Ep 210/500 | Return: -66.0 | 50-ep avg: -65.1 | Eps: 0.854


Training:  44%|████▍     | 219/500 [00:10<00:19, 14.15it/s]

Ep 220/500 | Return: -64.0 | 50-ep avg: -65.1 | Eps: 0.847


Training:  46%|████▋     | 232/500 [00:11<00:22, 12.04it/s]

Ep 230/500 | Return: -59.0 | 50-ep avg: -64.9 | Eps: 0.840


Training:  48%|████▊     | 241/500 [00:13<00:40,  6.43it/s]

Ep 240/500 | Return: -62.0 | 50-ep avg: -64.8 | Eps: 0.833


Training:  50%|█████     | 250/500 [00:15<00:39,  6.28it/s]

Ep 250/500 | Return: -64.0 | 50-ep avg: -64.8 | Eps: 0.826


Training:  52%|█████▏    | 261/500 [00:17<00:32,  7.32it/s]

Ep 260/500 | Return: -57.0 | 50-ep avg: -64.3 | Eps: 0.819


Training:  54%|█████▍    | 270/500 [00:20<01:16,  3.00it/s]

Ep 270/500 | Return: -65.0 | 50-ep avg: -63.9 | Eps: 0.812


Training:  56%|█████▌    | 280/500 [00:23<01:19,  2.76it/s]

Ep 280/500 | Return: -63.0 | 50-ep avg: -63.3 | Eps: 0.804


Training:  58%|█████▊    | 290/500 [00:27<01:08,  3.09it/s]

Ep 290/500 | Return: -66.0 | 50-ep avg: -63.1 | Eps: 0.797


Training:  60%|██████    | 300/500 [00:30<01:10,  2.82it/s]

Ep 300/500 | Return: -61.0 | 50-ep avg: -62.6 | Eps: 0.790


Training:  62%|██████▏   | 310/500 [00:34<01:11,  2.66it/s]

Ep 310/500 | Return: -64.0 | 50-ep avg: -62.5 | Eps: 0.783


Training:  64%|██████▍   | 320/500 [00:38<01:05,  2.75it/s]

Ep 320/500 | Return: -66.0 | 50-ep avg: -62.5 | Eps: 0.775


Training:  66%|██████▌   | 330/500 [00:41<01:03,  2.67it/s]

Ep 330/500 | Return: -66.0 | 50-ep avg: -63.2 | Eps: 0.768


Training:  68%|██████▊   | 340/500 [00:45<00:58,  2.76it/s]

Ep 340/500 | Return: -64.0 | 50-ep avg: -63.2 | Eps: 0.761


Training:  70%|███████   | 351/500 [00:47<00:22,  6.63it/s]

Ep 350/500 | Return: -60.0 | 50-ep avg: -63.7 | Eps: 0.754


Training:  72%|███████▏  | 361/500 [00:48<00:18,  7.43it/s]

Ep 360/500 | Return: -61.0 | 50-ep avg: -63.4 | Eps: 0.747


Training:  74%|███████▍  | 370/500 [00:50<00:18,  7.01it/s]

Ep 370/500 | Return: -66.0 | 50-ep avg: -63.5 | Eps: 0.740


Training:  76%|███████▌  | 381/500 [00:52<00:24,  4.86it/s]

Ep 380/500 | Return: -63.0 | 50-ep avg: -63.2 | Eps: 0.732


Training:  78%|███████▊  | 390/500 [00:54<00:25,  4.30it/s]

Ep 390/500 | Return: -54.0 | 50-ep avg: -62.3 | Eps: 0.725


Training:  80%|████████  | 400/500 [00:57<00:34,  2.94it/s]

Ep 400/500 | Return: -63.0 | 50-ep avg: -61.8 | Eps: 0.717


Training:  82%|████████▏ | 410/500 [01:01<00:34,  2.58it/s]

Ep 410/500 | Return: -66.0 | 50-ep avg: -61.8 | Eps: 0.710


Training:  84%|████████▍ | 420/500 [01:05<00:31,  2.55it/s]

Ep 420/500 | Return: -65.0 | 50-ep avg: -61.6 | Eps: 0.702


Training:  86%|████████▌ | 430/500 [01:09<00:26,  2.63it/s]

Ep 430/500 | Return: -64.0 | 50-ep avg: -61.7 | Eps: 0.695


Training:  88%|████████▊ | 440/500 [01:13<00:23,  2.57it/s]

Ep 440/500 | Return: -59.0 | 50-ep avg: -62.3 | Eps: 0.688


Training:  90%|█████████ | 450/500 [01:17<00:19,  2.54it/s]

Ep 450/500 | Return: -64.0 | 50-ep avg: -62.4 | Eps: 0.680


Training:  92%|█████████▏| 460/500 [01:21<00:15,  2.53it/s]

Ep 460/500 | Return: -63.0 | 50-ep avg: -62.4 | Eps: 0.673


Training:  94%|█████████▍| 470/500 [01:25<00:12,  2.44it/s]

Ep 470/500 | Return: -65.0 | 50-ep avg: -62.6 | Eps: 0.666


Training:  96%|█████████▌| 480/500 [01:29<00:08,  2.38it/s]

Ep 480/500 | Return: -56.0 | 50-ep avg: -62.0 | Eps: 0.658


Training:  98%|█████████▊| 490/500 [01:33<00:04,  2.35it/s]

Ep 490/500 | Return: -62.0 | 50-ep avg: -61.7 | Eps: 0.651


Training: 100%|██████████| 500/500 [01:38<00:00,  5.10it/s]

Ep 500/500 | Return: -60.0 | 50-ep avg: -61.3 | Eps: 0.643
Training finished.





In [5]:
save_path = "agent/flappy_dqn.pth"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
agent.save(save_path)
print("Saved:", save_path)

Saved: agent/flappy_dqn.pth


In [6]:

env.close()
state = env.reset()

total_reward = 0.0
for _ in range(5000):
	# Greedy action for evaluation (no epsilon)
	with torch.no_grad():
		state_t = torch.from_numpy(state).float().unsqueeze(0).to(agent.cfg.device)
		q_vals = agent.q_net(state_t)
		action = int(torch.argmax(q_vals, dim=1).item())

	next_state, reward, done, info = env.step(action)
	env.render(fps=60)
	total_reward += reward
	state = next_state
	if done:
		break

print("Eval episode reward:", total_reward)
env.close()

Eval episode reward: -72.0
