In [1]:
import os
import random

import pygame

pygame 2.2.0 (SDL 2.32.50, Python 3.8.20)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [33]:
# Game / training constants
SCREEN_WIDTH = 800   # window width in pixels
SCREEN_HEIGHT = 600  # window height in pixels
POP_SIZE = 50        # number of training episodes played by Game.run()
BLOCK_SIZE = 20      # side length in pixels of one grid cell (snake segment / food)

# The playfield is a grid of BLOCK_SIZE cells: food coordinates are drawn from
# randrange(0, SCREEN_*, BLOCK_SIZE), so the screen must divide evenly into cells.
assert SCREEN_WIDTH % BLOCK_SIZE == 0 and SCREEN_HEIGHT % BLOCK_SIZE == 0, (
    "SCREEN_WIDTH and SCREEN_HEIGHT must be multiples of BLOCK_SIZE"
)

# Colours (RGB)
BLACK = (0, 0, 0)
WHITE = (255, 255, 255)
GREEN = (0, 255, 0)
RED = (255, 0, 0)

### 实现 Snake 类

In [4]:
class Snake:
    """The player snake: a list of grid-aligned (x, y) pixel cells, head first."""

    # Allowed start directions. The initial body extends to the LEFT of the
    # head, so starting with (-1, 0) would drive the head straight into
    # positions[1] on the very first move.
    _START_DIRECTIONS = [(0, 1), (0, -1), (1, 0)]

    def __init__(self):
        """Create a length-3 snake at the screen centre, body extending to the left."""
        self.length = 3
        # Snap the centre onto the BLOCK_SIZE grid so the snake shares cell
        # coordinates with the food (which is always placed on the grid).
        center_x = (SCREEN_WIDTH // 2) // BLOCK_SIZE * BLOCK_SIZE
        center_y = (SCREEN_HEIGHT // 2) // BLOCK_SIZE * BLOCK_SIZE
        self.positions = [
            (center_x, center_y),
            (center_x - BLOCK_SIZE, center_y),
            (center_x - 2 * BLOCK_SIZE, center_y),
        ]
        self.direction = random.choice(self._START_DIRECTIONS)
        self.color = GREEN
        self.is_alive = True

    def get_head_position(self):
        """Return the (x, y) cell of the head."""
        return self.positions[0]

    def turn(self, point):
        """Set the movement direction to `point` (a (dx, dy) tuple).

        A direction exactly opposite the current one is ignored, so the
        snake can never reverse 180 degrees into its own neck.
        """
        if (point[0] * -1, point[1] * -1) == self.direction:
            return
        self.direction = point

    def move(self):
        """Advance the head one cell in the current direction.

        If the new head would leave the screen or land on the body,
        `is_alive` is set to False and `positions` is left unchanged.
        Otherwise the new head is prepended and the tail dropped, unless the
        snake is still growing towards `length`.
        """
        head_x, head_y = self.get_head_position()
        dx, dy = self.direction
        new = (head_x + dx * BLOCK_SIZE, head_y + dy * BLOCK_SIZE)

        # Wall collision.
        if not (0 <= new[0] < SCREEN_WIDTH and 0 <= new[1] < SCREEN_HEIGHT):
            self.is_alive = False
            return

        # Self collision: test the NEW head (not the current one) against the
        # segments still occupied after this move. The tail cell is vacated
        # unless the snake is growing, which is exactly positions[:length - 1].
        if new in self.positions[:self.length - 1]:
            self.is_alive = False
            return

        self.positions.insert(0, new)
        if len(self.positions) > self.length:
            self.positions.pop()

    def reset(self):
        """Restore the snake to its freshly-constructed state."""
        self.__init__()

    def draw(self, surface):
        """Draw every segment as a filled block with a black 1-px border."""
        for p in self.positions:
            r = pygame.Rect((p[0], p[1]), (BLOCK_SIZE, BLOCK_SIZE))
            pygame.draw.rect(surface, self.color, r)
            pygame.draw.rect(surface, BLACK, r, 1)  # border

    def grow(self):
        """Lengthen the snake by one; the extra segment appears on the next move."""
        self.length += 1

    def check_is_alive(self):
        """Return False once the snake has hit a wall or itself."""
        return self.is_alive

### 实现 Food 类

In [5]:
class Food:
    """A single food item placed on a free cell of the BLOCK_SIZE grid."""

    def __init__(self, snake_positions=None):
        """Create food at a random grid cell.

        Args:
            snake_positions: optional iterable of (x, y) cells occupied by the
                snake; the food is never placed on one of them.
        """
        self.color = RED
        self._generate_position(snake_positions)

    def _generate_position(self, snake_positions, max_tries=100):
        """Choose a random free grid cell and store it in `self.position`.

        Rejection sampling is tried first (cheap while the board is mostly
        empty). If it fails `max_tries` times, a cell is chosen uniformly from
        the explicit list of free cells, so a nearly full board cannot loop
        forever.

        Raises:
            RuntimeError: if every grid cell is occupied by the snake.
        """
        occupied = set(snake_positions) if snake_positions is not None else set()

        for _ in range(max_tries):
            x = random.randrange(0, SCREEN_WIDTH, BLOCK_SIZE)
            y = random.randrange(0, SCREEN_HEIGHT, BLOCK_SIZE)
            if (x, y) not in occupied:
                self.position = (x, y)
                return

        free_cells = [
            (x, y)
            for x in range(0, SCREEN_WIDTH, BLOCK_SIZE)
            for y in range(0, SCREEN_HEIGHT, BLOCK_SIZE)
            if (x, y) not in occupied
        ]
        if not free_cells:
            raise RuntimeError("no free cell left for food: the snake fills the board")
        self.position = random.choice(free_cells)

    def get_position(self):
        """Return the food's (x, y) cell."""
        return self.position

    def draw(self, surface):
        """Draw the food as a filled block with a black 1-px border."""
        r = pygame.Rect((self.position[0], self.position[1]), (BLOCK_SIZE, BLOCK_SIZE))
        pygame.draw.rect(surface, self.color, r)
        pygame.draw.rect(surface, BLACK, r, 1)  # border

    def respawn(self, snake_positions=None):
        """Move the food to a new random cell not occupied by `snake_positions`."""
        self._generate_position(snake_positions)

### 构建训练蛇的神经网络

#### Q 值是什么？
Q 值和 return 是强化学习中的两个重要概念，但它们并不相同。
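- Q 值 $Q(s, a)$ 表示在状态 $s$ 下采取动作 $a$、之后继续按当前策略行动时，未来**折扣**奖励总和的期望值（由模型预测得到）。
- Return $G_t = \sum_{k=0}^{\infty} \gamma^k r_{t+k+1}$ 是某一次实际交互中，从时刻 $t$ 开始实际获得的折扣奖励总和（由经验数据得到），其中 $\gamma$ 是折扣因子（代码中的 `self.gamma`）。
- 二者的关系：$Q(s, a) = \mathbb{E}[G_t \mid s_t = s, a_t = a]$，即 Q 值是 return 的期望。Q 值用于指导策略（选择 Q 值最大的动作），而 return 用于评估策略的实际表现。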


#### 目标策略网络与当前策略模型的区别与联系，为什么要用两个模型？
在Deep Q Learning（DQN）中，使用两个神经网络模型是为了解决训练不稳定的问题。  
- 当前策略网络 (self.model)：用于预测当前状态的 Q 值，并通过优化器更新权重。
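- 目标网络 (self.target_model)：用于计算训练目标 $y = r + \gamma \max_{a'} Q_{\text{target}}(s', a')$ 中下一个状态 $s'$ 的 Q 值（若 $s'$ 是终止状态，则 $y = r$）；它本身不参与梯度更新。
- 目标网络的权重每隔 `update_freq` 次训练步才从当前策略网络复制一次。因此在两次同步之间，训练目标保持固定，避免了“用正在变化的网络去追自己”造成的震荡，有助于模型收敛。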

In [31]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque

class SnakeAI:
    """Deep Q-Network agent for the snake game.

    Holds an online Q-network (`model`) that is trained and used to act, a
    target network (`target_model`) that supplies bootstrapped targets and
    is overwritten with the online weights every `update_freq` training
    steps, and a bounded experience-replay buffer.
    """

    def __init__(self, buffer_size=1000, batch_size=32, gamma=0.99, lr=0.001,
                 hidden_size=100, update_freq=100):
        """Create the networks, optimizer and replay buffer.

        Args:
            buffer_size: maximum number of transitions kept in the replay
                buffer; the oldest are discarded first.
            batch_size: number of transitions sampled per training step.
            gamma: discount factor applied to the next-state Q-value.
            lr: Adam learning rate.
            hidden_size: width of each hidden layer.
            update_freq: copy the online weights into the target network
                every this many training steps.
        """
        self.gamma = gamma
        self.input_size = 12    # length of the feature vector built by Game.get_state()
        self.output_size = 4    # one Q-value per move: up, down, left, right
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.update_freq = update_freq
        self.train_steps = 0    # number of gradient updates performed so far

        self.model = self.build_model()         # online network: acts and is trained
        self.target_model = self.build_model()  # frozen copy used for training targets
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

        # deque(maxlen=...) silently drops the oldest transition once full.
        self.buffer = deque(maxlen=buffer_size)

        # Start with identical weights so the first targets are consistent.
        self.update_target_model()

    def build_model(self):
        """Return an MLP: input -> 3 x (Linear + ReLU) hidden layers -> one Q-value per action."""
        model = nn.Sequential(
            nn.Linear(self.input_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.output_size),
        )
        return model

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

    def get_action(self, state, epsilon=0.01):
        """Choose an action index with an epsilon-greedy policy.

        Args:
            state: feature vector for the current position (array-like of
                length `input_size`).
            epsilon: probability of taking a uniformly random action instead
                of the greedy one. The default 0.01 means the agent explores
                only ~1% of the time; pass a larger (e.g. decaying) value
                during training if more exploration is wanted.

        Returns:
            An int in [0, output_size).
        """
        if random.random() < epsilon:
            # Explore: uniformly random action.
            return random.randint(0, self.output_size - 1)
        # Exploit: action with the highest predicted Q-value.
        state_t = torch.as_tensor(np.asarray(state, dtype=np.float32)).unsqueeze(0)
        with torch.no_grad():
            q_values = self.model(state_t)
        return int(torch.argmax(q_values).item())

    def train_model(self):
        """Run one DQN gradient step on a random minibatch from the replay buffer.

        Does nothing until the buffer holds at least `batch_size` transitions.
        Every `update_freq` steps the target network is re-synced.
        """
        if len(self.buffer) < self.batch_size:
            return

        batch = random.sample(self.buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Stack into contiguous numpy arrays first: building a tensor from a
        # list of ndarrays element by element is slow (PyTorch warns about it).
        states = torch.as_tensor(np.asarray(states, dtype=np.float32))
        actions = torch.as_tensor(np.asarray(actions, dtype=np.int64))
        rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32))
        next_states = torch.as_tensor(np.asarray(next_states, dtype=np.float32))
        dones = torch.as_tensor(np.asarray(dones, dtype=np.float32))

        # Q(s, a) for the actions actually taken. squeeze(1), not squeeze(),
        # so a batch of one keeps shape (1,) and matches the targets.
        current_q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Bootstrapped target r + gamma * max_a' Q_target(s', a'); the future
        # term is masked out for terminal transitions.
        with torch.no_grad():
            next_q_values = self.target_model(next_states).max(dim=1).values
            target_q_values = rewards + self.gamma * next_q_values * (1.0 - dones)

        loss = self.criterion(current_q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.train_steps += 1
        if self.train_steps % self.update_freq == 0:
            self.update_target_model()

    def add_experience(self, state, action, reward, next_state, done):
        """Append one (state, action, reward, next_state, done) transition to the replay buffer."""
        self.buffer.append((state, action, reward, next_state, done))

### 主游戏逻辑

In [35]:
class Game:
    """Pygame front-end that plays training episodes for the DQN snake agent."""

    # Action index -> (dx, dy) grid step, in the order up, down, left, right.
    _DIRECTIONS = ((0, -1), (0, 1), (-1, 0), (1, 0))

    def __init__(self, buffer_size=10000, batch_size=64,
                 model_path='../models/5_RL_Snake/best_weights.pth'):
        """Open the game window and create the snake, food and agent.

        Args:
            buffer_size: replay-buffer capacity passed to SnakeAI.
            batch_size: training batch size passed to SnakeAI.
            model_path: file that run() writes the model weights to whenever
                an episode sets a new best score.
        """
        pygame.init()
        self.screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        pygame.display.set_caption("Snake AI Training")
        self.clock = pygame.time.Clock()
        self.snake = Snake()
        self.food = Food(self.snake.positions)  # never spawn food on the snake
        # Forward the caller's replay settings to the agent.
        self.ai_player = SnakeAI(buffer_size=buffer_size, batch_size=batch_size)
        self.model_path = model_path
        self.score = 0
        self.best_score = 0
        self.scores = []  # final snake length of every episode that ended in a crash
        self.steps = 0    # moves since the last food was eaten
        self.font = pygame.font.SysFont("Arial", 24)

    def get_direction(self, action):
        """Map an action index (0-3) to a (dx, dy) direction: up, down, left, right."""
        return self._DIRECTIONS[action]

    def get_state(self):
        """Return the 12-feature float32 observation fed to the agent.

        Features, in order: danger left/right/up/down (wall or body in the
        adjacent cell), food left/right/above/below the head, and a one-hot
        of the current direction (up, down, left, right).
        """
        head_x, head_y = self.snake.get_head_position()
        food_x, food_y = self.food.position

        # Body cells that will still be occupied after the next move. The
        # tail cell is vacated unless the snake is growing, so only segments
        # before index length - 1 block the head (index 0 is the head itself).
        body = self.snake.positions[1:self.snake.length - 1]

        left = (head_x - BLOCK_SIZE, head_y)
        right = (head_x + BLOCK_SIZE, head_y)
        up = (head_x, head_y - BLOCK_SIZE)
        down = (head_x, head_y + BLOCK_SIZE)

        direction = self.snake.direction
        state = [
            left[0] < 0 or left in body,
            right[0] >= SCREEN_WIDTH or right in body,
            up[1] < 0 or up in body,
            down[1] >= SCREEN_HEIGHT or down in body,
            food_x < head_x, food_x > head_x,
            food_y < head_y, food_y > head_y,
            direction == (0, -1),
            direction == (0, 1),
            direction == (-1, 0),
            direction == (1, 0),
        ]
        return np.array(state, dtype=np.float32)

    def _food_distance(self):
        """Euclidean distance in pixels from the snake's head to the food."""
        head_x, head_y = self.snake.get_head_position()
        food_x, food_y = self.food.position
        return float(np.hypot(head_x - food_x, head_y - food_y))

    def update(self):
        """Play one step: act, move, compute the reward, store and learn.

        Rewards: +10 for eating, -20 for dying, otherwise +0.2 for moving
        closer to the food / -0.1 for not, with an extra -0.1 once more than
        10 steps have passed without eating.

        Returns:
            True if the snake died on this step, else False.
        """
        state = self.get_state()
        action = self.ai_player.get_action(state)
        self.snake.turn(self.get_direction(action))

        old_distance = self._food_distance()
        self.snake.move()

        done = False
        reward = 0

        if self.snake.get_head_position() == self.food.position:
            # Ate the food.
            self.steps = 0
            self.score += 1
            self.snake.grow()
            self.food.respawn(self.snake.positions)
            reward += 10
        elif not self.snake.check_is_alive():
            # Hit a wall or itself.
            self.scores.append(self.snake.length)
            done = True
            reward -= 20
        else:
            # Shaping reward for approaching the food.
            new_distance = self._food_distance()
            reward += 0.2 if new_distance < old_distance else -0.1
            if self.steps > 10:  # penalise wandering without eating
                reward -= 0.1

        next_state = self.get_state()
        self.ai_player.add_experience(state, action, reward, next_state, done)
        self.ai_player.train_model()

        return done

    def _render(self):
        """Draw one complete frame (snake, food, score) and present it with a single flip."""
        self.screen.fill(BLACK)
        self.snake.draw(self.screen)
        self.food.draw(self.screen)
        score_text = self.font.render(f"Score: {self.score}", True, WHITE)
        self.screen.blit(score_text, (10, 10))  # top-left corner
        pygame.display.flip()

    def _save_model(self):
        """Write the online network's weights to `model_path`, creating its directory if needed."""
        directory = os.path.dirname(self.model_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.ai_player.model.state_dict(), self.model_path)

    def run(self, fps=80):
        """Play POP_SIZE training episodes, rendering each step.

        Closing the window quits pygame and returns immediately. After every
        episode that beats the best score so far, the model weights are saved.

        Args:
            fps: frame-rate cap passed to the pygame clock.
        """
        for episode in range(POP_SIZE):
            self.snake.reset()
            self.food = Food(self.snake.positions)
            self.score = 0
            self.steps = 0
            done = False

            while not done:
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        pygame.quit()
                        return

                self.steps += 1
                done = self.update()
                self._render()
                self.clock.tick(fps)

            if self.score > self.best_score:
                self.best_score = self.score
                self._save_model()
            print(f"Episode {episode + 1}/{POP_SIZE}, Score: {self.score}, Best: {self.best_score}")

        pygame.quit()

if __name__ == "__main__":
    # Seed the RNGs the training draws from (`random` drives the start
    # direction, food placement, epsilon-greedy and replay sampling; torch
    # drives network initialisation) so a run can be reproduced.
    SEED = 42
    random.seed(SEED)
    torch.manual_seed(SEED)

    game = Game(buffer_size=10000, batch_size=64)
    game.run()

Episode 1/50, Score: 0, Best: 0
Episode 2/50, Score: 0, Best: 0
Episode 3/50, Score: 0, Best: 0
Episode 4/50, Score: 0, Best: 0
Episode 5/50, Score: 0, Best: 0
Episode 6/50, Score: 0, Best: 0
Episode 7/50, Score: 0, Best: 0
Episode 8/50, Score: 0, Best: 0
Episode 9/50, Score: 0, Best: 0
Episode 10/50, Score: 1, Best: 1
Episode 11/50, Score: 0, Best: 1
Episode 12/50, Score: 0, Best: 1
Episode 13/50, Score: 0, Best: 1
Episode 14/50, Score: 0, Best: 1
Episode 15/50, Score: 0, Best: 1
Episode 16/50, Score: 0, Best: 1
Episode 17/50, Score: 0, Best: 1
Episode 18/50, Score: 0, Best: 1
Episode 19/50, Score: 0, Best: 1
Episode 20/50, Score: 1, Best: 1
Episode 21/50, Score: 0, Best: 1
Episode 22/50, Score: 0, Best: 1
Episode 23/50, Score: 2, Best: 2
Episode 24/50, Score: 1, Best: 2
Episode 25/50, Score: 2, Best: 2
Episode 26/50, Score: 0, Best: 2
Episode 27/50, Score: 1, Best: 2
Episode 28/50, Score: 1, Best: 2
Episode 29/50, Score: 1, Best: 2
Episode 30/50, Score: 0, Best: 2
Episode 31/50, Scor