In [25]:
import pygame
import random
import numpy as np
from collections import defaultdict, deque
import pickle
import matplotlib.pyplot as plt

# 初始化参数
SCREEN_WIDTH = 800
SCREEN_HEIGHT = 600
PLAYER_SPEED = 10  
BULLET_SPEED = 20  
ENEMY_SPEED = 3
LEARNING_RATE = 0.1
DISCOUNT_FACTOR = 0.95
EPSILON = 0.5  # 初始探索率
EPSILON_DECAY = 0.995  # 探索率衰减
MIN_EPSILON = 0.01  # 最小探索率
INITIAL_LIVES = 3
MAX_ENEMIES = 3
ENEMY_HEALTH = 300  # 每次被击中减少 50，需要 6 次击中消灭
FRAME_RATE = 30

# 颜色定义
WHITE = (255, 255, 255)
RED = (255, 0, 0)
GREEN = (0, 255, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)
BLACK = (0, 0, 0)
GRAY = (200, 200, 200)

# 初始化 Pygame
pygame.init()
screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
pygame.display.set_caption("智能飞机大战-Q-learning版")
clock = pygame.time.Clock()
font = pygame.font.Font(None, 36)

class GameState:
    def __init__(self):
        self.reset()

    def reset(self):
        self.player = Player()
        self.enemies = pygame.sprite.Group()
        self.bullets = pygame.sprite.Group()
        self.enemy_bullets = pygame.sprite.Group()
        self.score = 0
        self.lives = INITIAL_LIVES
        self.start_time = pygame.time.get_ticks()
        self.running = False
        self.show_start_button = True
        self.last_enemy_spawn = pygame.time.get_ticks()
        self.score_effects = []  # 存储得分特效

    def add_score_effect(self, pos):
        """在指定位置添加得分特效"""
        self.score_effects.append({
            "pos": pos,
            "timer": 30  # 显示30帧
        })

class Player(pygame.sprite.Sprite):
    def __init__(self):
        super().__init__()
        self.image = pygame.Surface((50, 30))
        self.image.fill(GREEN)
        self.rect = self.image.get_rect(center=(SCREEN_WIDTH // 2, SCREEN_HEIGHT - 100))  # 初始位置更靠下
        self.bullet_timer = 0

    def update(self, action):
        max_y = SCREEN_HEIGHT - 80  # 限制在底部80像素区域
        min_y = SCREEN_HEIGHT // 2  # 从屏幕中间开始限制

        if action == 0:  # 上
            self.rect.y = max(min_y, self.rect.y - PLAYER_SPEED)
        elif action == 1:  # 下
            self.rect.y = min(max_y, self.rect.y + PLAYER_SPEED)
        elif action == 2:  # 左
            self.rect.x = max(0, self.rect.x - PLAYER_SPEED)
        elif action == 3:  # 右
            self.rect.x = min(SCREEN_WIDTH - 50, self.rect.x + PLAYER_SPEED)
        elif action == 4:  # 发射子弹
            if self.bullet_timer >= 10:  # 加快子弹发射频率
                self.bullet_timer = 0
                return True
        self.bullet_timer += 1
        return False

class Enemy(pygame.sprite.Sprite):
    def __init__(self):
        super().__init__()
        self.image = pygame.Surface((40, 40))
        self.image.fill(RED)
        self.rect = self.image.get_rect(center=(random.randint(30, SCREEN_WIDTH - 30), 10))
        self.speed = ENEMY_SPEED
        self.health = ENEMY_HEALTH  # 敌方血量

    def update(self):
        self.rect.y += self.speed
        if self.rect.top > SCREEN_HEIGHT:
            self.kill()

class Bullet(pygame.sprite.Sprite):
    def __init__(self, x, y, is_enemy=False):
        super().__init__()
        self.image = pygame.Surface((8, 20))  # 放大子弹尺寸
        self.image.fill(YELLOW if is_enemy else BLUE)
        self.rect = self.image.get_rect(center=(x, y - 20))  # 调整子弹生成位置
        self.speed = -BULLET_SPEED if not is_enemy else BULLET_SPEED  # 加快玩家子弹速度

    def update(self):
        self.rect.y += self.speed
        if self.rect.bottom < 0 or self.rect.top > SCREEN_HEIGHT:
            self.kill()

class DQNAgent:
    def __init__(self):
        self.action_dim = 5  # 动作维度：上、下、左、右、发射子弹
        self.q_table = defaultdict(lambda: np.zeros(self.action_dim))  # 使用字典存储 Q 值
        self.memory = deque(maxlen=2000)
        self.epsilon = EPSILON

    def get_state(self, game_state):
        # 状态包括：玩家位置、敌人数量、子弹数量、最近敌人的位置
        player_x, player_y = game_state.player.rect.center
        enemy_count = len(game_state.enemies)
        bullet_count = len(game_state.bullets)
        nearest_enemy_x = SCREEN_WIDTH // 2 if not game_state.enemies else min(
            [enemy.rect.centerx for enemy in game_state.enemies])
        nearest_enemy_y = SCREEN_HEIGHT // 2 if not game_state.enemies else min(
            [enemy.rect.centery for enemy in game_state.enemies])
        state = (
            player_x // 100,  # 离散化玩家 X 位置
            player_y // 100,  # 离散化玩家 Y 位置
            enemy_count,
            bullet_count,
            nearest_enemy_x // 100,
            nearest_enemy_y // 100
        )
        return state

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_dim)
        return np.argmax(self.q_table[state])

    def learn(self):
        if len(self.memory) < 100:  # 确保有足够的经验数据
            return
        for state, action, reward, next_state, done in self.memory:
            if done:
                target = reward
            else:
                target = reward + DISCOUNT_FACTOR * np.max(self.q_table[next_state])
            self.q_table[state][action] = (1 - LEARNING_RATE) * self.q_table[state][action] + LEARNING_RATE * target

    def decay_epsilon(self):
        """衰减探索率"""
        self.epsilon = max(MIN_EPSILON, self.epsilon * EPSILON_DECAY)

def create_button(text, x, y, w, h, ic, ac):
    """
    创建按钮并检测点击事件
    :param text: 按钮文本
    :param x: 按钮左上角 x 坐标
    :param y: 按钮左上角 y 坐标
    :param w: 按钮宽度
    :param h: 按钮高度
    :param ic: 按钮默认颜色
    :param ac: 按钮激活颜色
    :return: 是否被点击
    """
    mouse = pygame.mouse.get_pos()
    click = pygame.mouse.get_pressed()
    button = pygame.Rect(x, y, w, h)

    if button.collidepoint(mouse):
        pygame.draw.rect(screen, ac, button)
        if click[0] == 1:
            return True
    else:
        pygame.draw.rect(screen, ic, button)

    text_surf = font.render(text, True, WHITE)
    text_rect = text_surf.get_rect(center=button.center)
    screen.blit(text_surf, text_rect)
    return False

def game_loop(train_mode=True, num_episodes=1000):
    game_state = GameState()
    agent = DQNAgent()
    scores = []
    hit_rates = []
    survival_times = []

    for episode in range(num_episodes):
        game_state.reset()
        state = agent.get_state(game_state)
        total_reward = 0
        done = False
        bullets_fired = 0
        bullets_hit = 0
        start_time = pygame.time.get_ticks()

        while not done:
            # 选择动作
            action = agent.act(state)

            # 执行动作
            if game_state.player.update(action):
                game_state.bullets.add(Bullet(game_state.player.rect.centerx, game_state.player.rect.top))
                bullets_fired += 1

            # 敌人生成逻辑
            if len(game_state.enemies) < MAX_ENEMIES and pygame.time.get_ticks() - game_state.last_enemy_spawn > 3000:
                game_state.enemies.add(Enemy())
                game_state.last_enemy_spawn = pygame.time.get_ticks()

            game_state.enemies.update()
            game_state.bullets.update()
            game_state.enemy_bullets.update()

            # 碰撞检测
            reward = 0
            for bullet in game_state.bullets:
                enemies_hit = pygame.sprite.spritecollide(bullet, game_state.enemies, False)
                for enemy in enemies_hit:
                    enemy.health -= 50
                    bullet.kill()
                    if enemy.health <= 0:
                        enemy.kill()
                        game_state.score += 100
                        reward = 20  # 击中敌机奖励
                        bullets_hit += 1

            for enemy in game_state.enemies:
                if enemy.rect.colliderect(game_state.player.rect):
                    game_state.lives -= 1
                    enemy.kill()
                    reward = -10  # 碰撞惩罚
                    if game_state.lives <= 0:
                        done = True

            # 获取下一个状态
            next_state = agent.get_state(game_state)

            # 存储经验并学习
            agent.remember(state, action, reward, next_state, done)
            if train_mode:
                agent.learn()

            total_reward += reward
            state = next_state

            # 渲染游戏画面（仅在非训练模式或需要可视化时）
            if not train_mode or episode % 100 == 0:
                screen.fill(BLACK)
                screen.blit(game_state.player.image, game_state.player.rect)
                game_state.enemies.draw(screen)
                game_state.bullets.draw(screen)
                pygame.display.flip()
                clock.tick(FRAME_RATE)

        # 计算命中率和存活时间
        hit_rate = bullets_hit / bullets_fired if bullets_fired > 0 else 0
        survival_time = (pygame.time.get_ticks() - start_time) / 1000  # 秒

        scores.append(game_state.score)
        hit_rates.append(hit_rate)
        survival_times.append(survival_time)

        print(f"Episode {episode + 1}, Total Reward: {total_reward}, Score: {game_state.score}, "
              f"Hit Rate: {hit_rate:.2f}, Survival Time: {survival_time:.2f}s")

        # 衰减探索率
        if train_mode:
            agent.decay_epsilon()

        # 每 100 条 Episode 保存一次 Q 表
        if train_mode and (episode + 1) % 100 == 0:
            with open("q_table.pkl", "wb") as f:
                pickle.dump(agent.q_table, f)
            print(f"Q 表已保存为 q_table.pkl（Episode {episode + 1}）")

    # 保存训练结果
    if train_mode:
        with open("q_table.pkl", "wb") as f:
            pickle.dump(agent.q_table, f)

    # 可视化训练结果
    plot_training_results(scores, hit_rates, survival_times)

def plot_training_results(scores, hit_rates, survival_times):
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 3, 1)
    plt.plot(scores)
    plt.title("Score per Episode")
    plt.xlabel("Episode")
    plt.ylabel("Score")

    plt.subplot(1, 3, 2)
    plt.plot(hit_rates)
    plt.title("Hit Rate per Episode")
    plt.xlabel("Episode")
    plt.ylabel("Hit Rate")

    plt.subplot(1, 3, 3)
    plt.plot(survival_times)
    plt.title("Survival Time per Episode")
    plt.xlabel("Episode")
    plt.ylabel("Survival Time")

    plt.tight_layout()
    plt.show()

def load_q_table(filename="q_table.pkl"):
    with open(filename, "rb") as f:
        return pickle.load(f)

if __name__ == "__main__":
    try:
        # 训练模式
        game_loop(train_mode=True, num_episodes=1000)

        # 测试模式
        agent = DQNAgent()
        agent.q_table = load_q_table()
        game_loop(train_mode=False)
    except KeyboardInterrupt:
        print("游戏已手动中断")
    finally:
        pygame.quit()

Episode 1, Total Reward: -30, Score: 0, Hit Rate: 0.00, Survival Time: 27.26s
Episode 2, Total Reward: -30, Score: 0, Hit Rate: 0.00, Survival Time: 30.96s
游戏已手动中断


In [20]:
pip install matplotlib

Defaulting to user installation because normal site-packages is not writeable
Collecting matplotlib
  Downloading matplotlib-3.9.4-cp39-cp39-macosx_10_12_x86_64.whl (7.9 MB)
[K     |████████████████████████████████| 7.9 MB 2.1 MB/s eta 0:00:01
[?25hCollecting pyparsing>=2.3.1
  Downloading pyparsing-3.2.1-py3-none-any.whl (107 kB)
[K     |████████████████████████████████| 107 kB 6.6 MB/s eta 0:00:01
[?25hCollecting importlib-resources>=3.2.0
  Downloading importlib_resources-6.5.2-py3-none-any.whl (37 kB)
Collecting pillow>=8
  Downloading pillow-11.1.0-cp39-cp39-macosx_10_10_x86_64.whl (3.2 MB)
[K     |████████████████████████████████| 3.2 MB 5.1 MB/s eta 0:00:01
[?25hCollecting fonttools>=4.22.0
  Downloading fonttools-4.56.0-cp39-cp39-macosx_10_9_x86_64.whl (2.3 MB)
[K     |████████████████████████████████| 2.3 MB 969 kB/s eta 0:00:01
Collecting contourpy>=1.0.1
  Downloading contourpy-1.3.0-cp39-cp39-macosx_10_9_x86_64.whl (265 kB)
[K     |████████████████████████████████| 

In [4]:
pip install numpy

Defaulting to user installation because normal site-packages is not writeable
Collecting numpy
  Downloading numpy-2.0.2-cp39-cp39-macosx_14_0_x86_64.whl (6.9 MB)
[K     |████████████████████████████████| 6.9 MB 555 kB/s eta 0:00:01
[?25hInstalling collected packages: numpy
Successfully installed numpy-2.0.2
You should consider upgrading via the '/Applications/Xcode.app/Contents/Developer/usr/bin/python3 -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install pygame

Defaulting to user installation because normal site-packages is not writeable
Collecting pygame
  Using cached pygame-2.6.1-cp39-cp39-macosx_10_9_x86_64.whl (13.0 MB)
Installing collected packages: pygame
Successfully installed pygame-2.6.1
You should consider upgrading via the '/Applications/Xcode.app/Contents/Developer/usr/bin/python3 -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.
