# Import necessary libraries

In [1]:
import pygame
import random
from enum import Enum
from collections import namedtuple
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import math

from Snake_Env import Env
# Place models and tensors on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

pygame 2.6.0 (SDL 2.28.4, Python 3.10.11)
Hello from the pygame community. https://www.pygame.org/contribute.html


# Create the Gym environment wrapper

In [2]:
import gym
from gym import spaces
import numpy as np

class SnakeEnv(gym.Env):
    """Gym-style wrapper around the project's ``Env`` snake game.

    Exposes a 3-way discrete action space and an 11-element observation
    space whose entries lie in [0, 1]. Integer actions are translated into
    the one-hot action vector that the underlying game's ``play_step``
    receives.
    """

    metadata = {'render.modes': ['human']}

    def __init__(self, w=640, h=480):
        super().__init__()
        self.game = Env(w=w, h=h)
        # Three discrete actions: 0 = go straight, 1 = turn right, 2 = turn left.
        self.action_space = spaces.Discrete(3)
        # Observation: an 11-element state vector, each entry 0 or 1.
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(11,), dtype=np.int32
        )

    def reset(self):
        """Reset the underlying game and return the state it reports."""
        return self.game.reset()

    def step(self, action_idx):
        """Advance the game by one move.

        action_idx: 0 = go straight, 1 = turn right, 2 = turn left. It is
        converted into a one-hot list before being passed to the game.

        Returns ``(next_state, reward, done, info)``, where ``info`` carries
        the game's current score under the key ``"score"``.
        """
        one_hot = [0] * 3
        one_hot[action_idx] = 1
        # Only reward, next_state and done are used from play_step's result.
        _state, _unused, reward, next_state, done = self.game.play_step(one_hot)
        return next_state, reward, done, {"score": self.game.score}

    def render(self, mode='human'):
        """Delegate drawing to the underlying game (``mode`` is ignored)."""
        self.game.render()

    def close(self):
        """Shut down pygame."""
        pygame.quit()

# Replay Buffer

In [3]:
class ReplayBuffer:
    """Fixed-capacity circular replay buffer backed by preallocated NumPy arrays.

    Once ``max_size`` transitions are stored, each new transition overwrites
    the oldest one.

    Args:
        max_size: capacity, in transitions.
        state_dim: length of each state vector. Defaults to 11, the size of
            the Snake observation used elsewhere in this notebook.
    """

    def __init__(self, max_size=10000, state_dim=11):
        self.s = np.zeros((max_size, state_dim), dtype=np.float32)
        self.a = np.zeros((max_size,), dtype=np.int64)
        self.r = np.zeros((max_size,), dtype=np.float32)
        self.s_ = np.zeros((max_size, state_dim), dtype=np.float32)
        self.done = np.zeros((max_size,), dtype=np.float32)
        self.ptr = 0    # slot the next append writes to
        self.size = 0   # number of valid transitions currently stored
        self.max_size = max_size

    def append(self, s, a, r, s_, done):
        """Store one transition, overwriting the oldest one when full."""
        self.s[self.ptr] = s
        self.a[self.ptr] = a
        self.r[self.ptr] = r
        self.s_[self.ptr] = s_
        self.done[self.ptr] = done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            Tuple of CPU tensors ``(s, a, r, s_, done)`` with dtypes float32,
            int64, float32, float32 and float32.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored
                transitions (sampling is without replacement).
        """
        if batch_size > self.size:
            raise ValueError(
                f"cannot sample {batch_size} transitions from a buffer "
                f"holding only {self.size}"
            )
        idx = np.random.choice(self.size, batch_size, replace=False)
        # Fancy indexing returns copies, so the tensors never alias buffer storage.
        return (
            torch.from_numpy(self.s[idx]),
            torch.from_numpy(self.a[idx]),
            torch.from_numpy(self.r[idx]),
            torch.from_numpy(self.s_[idx]),
            torch.from_numpy(self.done[idx]),
        )


# DQN model

In [None]:
class DQN(torch.nn.Module):
    """Fully connected Q-network: input_dim -> 128 -> 64 -> output_dim.

    Produces one Q-value per action. Each hidden layer is followed by a ReLU
    and the output layer is linear. The layer attribute names (fc1, fc2, out)
    determine the keys of saved state dicts.
    """

    def __init__(self, input_dim=11, output_dim=3):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, 128)
        self.fc2 = torch.nn.Linear(128, 64)
        self.out = torch.nn.Linear(64, output_dim)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.out(hidden)

# Agent

In [None]:
import os
# Directory where model checkpoints are written. The trailing slash matters:
# checkpoint paths are built by direct string concatenation with this prefix.
folder = "DQN_model/"

# exist_ok=True avoids the check-then-create race of a separate exists() test.
os.makedirs(folder, exist_ok=True)

In [None]:
class DQNAgent:
    """Deep Q-Network agent that trains on ``SnakeEnv``.

    Keeps an online network (``model``) and a periodically synced target
    network (``target_model``). It learns from a uniform replay buffer and
    records per-episode train rewards, periodic test rewards and per-update
    losses in ``self.log``.
    """

    def __init__(self, gamma=0.9, eps_low=0.05, lr=1e-3):
        self.env = SnakeEnv()
        self.n_act = self.env.action_space.n
        self.model = DQN(output_dim=self.n_act).to(device)
        self.target_model = DQN(output_dim=self.n_act).to(device)
        # The target network starts as an exact copy of the online network.
        self.target_model.load_state_dict(self.model.state_dict())

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.loss_fn = torch.nn.SmoothL1Loss()
        self.rb = ReplayBuffer()
        self.gamma = gamma        # discount factor
        self.eps_low = eps_low    # floor of the epsilon-greedy schedule
        self.steps = 0            # total environment steps across all episodes
        self.log = {"TrainReward": [], "TestReward": [], "Loss": []}

    def predict(self, state):
        """Return the greedy action index (argmax over Q-values) for one state."""
        state = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            return self.model(state).argmax().item()

    def select_action(self, state, eps):
        """Epsilon-greedy: a uniformly random action with probability ``eps``, else greedy."""
        if random.random() < eps:
            return random.randint(0, self.n_act - 1)
        return self.predict(state)

    def learn(self, batch_size=64):
        """Perform one gradient step on a minibatch sampled from the replay buffer.

        Uses the standard DQN target computed with the frozen target network:
        ``r + gamma * max_a' Q_target(s', a') * (1 - done)``.
        """
        s, a, r, s_, done = (t.to(device) for t in self.rb.sample(batch_size))

        # squeeze(1) rather than squeeze(): keep the batch axis even when batch_size == 1,
        # so q_pred always has the same shape as q_target.
        q_pred = self.model(s).gather(1, a.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            q_next = self.target_model(s_).max(1)[0]
            q_target = r + self.gamma * q_next * (1 - done)

        loss = self.loss_fn(q_pred, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.log["Loss"].append(loss.item())

    def train(self, N_EPISODES=500, batch_size=64, learn_start=500,
              target_update=200, test_every=20):
        """Train for ``N_EPISODES`` episodes with an epsilon-greedy policy.

        Args:
            N_EPISODES: number of training episodes.
            batch_size: minibatch size passed to ``learn``.
            learn_start: gradient updates begin once the replay buffer holds
                more than this many transitions. It must be at least
                ``batch_size - 1`` so sampling without replacement is possible.
            target_update: copy the online weights into the target network
                every this many environment steps.
            test_every: every this many episodes, run a greedy test episode,
                log its reward and save a checkpoint.
        """
        for ep in tqdm(range(N_EPISODES)):
            state = self.env.reset()
            done = False
            total_reward = 0
            # Epsilon decays exponentially from 1.0 toward eps_low, with a
            # time constant of 100 episodes.
            eps = self.eps_low + (1.0 - self.eps_low) * math.exp(-ep / 100)

            while not done:
                action = self.select_action(state, eps)
                next_state, reward, done, _ = self.env.step(action)
                self.rb.append(state, action, reward, next_state, float(done))
                state = next_state
                total_reward += reward

                self.steps += 1
                if self.rb.size > learn_start:
                    self.learn(batch_size)

                if self.steps % target_update == 0:
                    self.target_model.load_state_dict(self.model.state_dict())

            self.log["TrainReward"].append(total_reward)

            if ep % test_every == 0:
                test_reward = self.test()
                self.log["TestReward"].append(test_reward)
                print(f"\nEpisode {ep}: Train Reward = {total_reward}, Test Reward = {test_reward}")
                torch.save(self.model.state_dict(), f"{folder}snake_model_{ep}.pt")

    def test(self, render=False, max_steps=None):
        """Play one greedy episode and return its total reward.

        Args:
            render: draw each frame before acting.
            max_steps: optional cap on the number of steps. It guards against a
                greedy policy that never ends the episode. ``None`` (the
                default) means no cap.
        """
        state = self.env.reset()
        done = False
        total_reward = 0
        n_steps = 0
        while not done and (max_steps is None or n_steps < max_steps):
            if render:
                self.env.render()
            action = self.predict(state)
            state, reward, done, _ = self.env.step(action)
            total_reward += reward
            n_steps += 1
        return total_reward


# Train the agent

In [None]:
# Build the agent with this run's hyperparameters, then train it for 1000 episodes.
agent = DQNAgent(lr=1e-3, gamma=0.9, eps_low=0.05)
agent.train(N_EPISODES=1000)

  0%|          | 1/1000 [00:05<1:38:48,  5.93s/it]


Episode 0: Train Reward = -10, Test Reward = -10


  2%|▏         | 21/1000 [01:19<2:25:11,  8.90s/it]


Episode 20: Train Reward = -10, Test Reward = 110


  4%|▍         | 41/1000 [02:28<1:48:44,  6.80s/it]


Episode 40: Train Reward = -10, Test Reward = 160


  6%|▌         | 61/1000 [03:28<1:32:00,  5.88s/it]


Episode 60: Train Reward = 0, Test Reward = 0


  7%|▋         | 71/1000 [03:53<43:22,  2.80s/it]  

In [None]:
# Play one greedy episode with on-screen rendering; the call returns the episode's total reward.
agent.test(render=True)