# Import necessary libraries

In [1]:
import pygame
import random
from enum import Enum
from collections import namedtuple
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import math
import os

from Snake_Env import Env
# Run tensors and models on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

pygame 2.6.0 (SDL 2.28.4, Python 3.10.11)
Hello from the pygame community. https://www.pygame.org/contribute.html


# Create Env

In [2]:
import gym
from gym import spaces
import numpy as np

class SnakeEnv(gym.Env):
    """Gym-style adapter around the project's ``Env`` snake game.

    Exposes three discrete actions and an 11-element observation whose
    entries are declared to lie in ``[0, 1]``.
    """

    metadata = {'render.modes': ['human']}

    def __init__(self, w=640, h=480):
        """Create the wrapped game with a board of ``w`` x ``h``."""
        super().__init__()
        self.game = Env(w=w, h=h)

        # Action space: 3 actions (go straight, turn right, turn left).
        self.action_space = spaces.Discrete(3)

        # Observation space: state of length 11, each entry 0 or 1.
        self.observation_space = spaces.Box(
            low=0,
            high=1,
            shape=(11,),
            dtype=np.int32,
        )

    def reset(self):
        """Reset the wrapped game and return whatever state it reports."""
        return self.game.reset()

    def step(self, action_idx):
        """Apply one action and return ``(next_state, reward, done, info)``.

        action_idx: 0 = go straight, 1 = turn right, 2 = turn left.
        The index is converted to a one-hot vector for SnakeGameAI.
        ``info`` carries the game's current score under the key ``"score"``.
        """
        one_hot = [0] * 3
        one_hot[action_idx] = 1

        # play_step yields a 5-tuple; only the last three entries are used.
        _prev_state, _action, reward, observation, done = self.game.play_step(one_hot)

        return observation, reward, done, {"score": self.game.score}

    def render(self, mode='human'):
        """Delegate drawing to the wrapped game (``mode`` is not consulted)."""
        self.game.render()

    def close(self):
        """Shut down pygame."""
        pygame.quit()

# Replay Buffer

In [3]:
class ReplayBuffer:
    """Fixed-capacity ring buffer with proportional prioritized sampling.

    Transitions ``(s, a, r, s_, done)`` are stored in preallocated NumPy
    arrays; once ``max_size`` entries exist, the oldest slot is overwritten.
    Each slot has a priority; sampling probability is proportional to
    ``priority ** alpha``.

    Args:
        max_size: Maximum number of stored transitions.
        alpha: Exponent applied to priorities (0 gives uniform sampling).
        state_dim: Length of a state vector. Defaults to 11, the size the
            rest of this file uses.
    """

    def __init__(self, max_size=10000, alpha=0.6, state_dim=11):
        self.max_size = max_size
        self.alpha = alpha
        self.ptr = 0    # next slot to write
        self.size = 0   # number of valid slots

        self.s = np.zeros((max_size, state_dim), dtype=np.float32)
        self.a = np.zeros((max_size,), dtype=np.int64)
        self.r = np.zeros((max_size,), dtype=np.float32)
        self.s_ = np.zeros((max_size, state_dim), dtype=np.float32)
        self.done = np.zeros((max_size,), dtype=np.float32)
        self.priorities = np.zeros((max_size,), dtype=np.float32)

    def append(self, s, a, r, s_, done):
        """Store one transition, overwriting the oldest one when full.

        The new entry receives the current maximum priority (1.0 for the
        very first entry) so it is likely to be sampled at least once.
        """
        max_prio = self.priorities.max() if self.size > 0 else 1.0

        self.s[self.ptr] = s
        self.a[self.ptr] = a
        self.r[self.ptr] = r
        self.s_[self.ptr] = s_
        self.done[self.ptr] = done
        self.priorities[self.ptr] = max_prio

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size, beta=0.4):
        """Draw ``batch_size`` transitions (with replacement) by priority.

        Args:
            batch_size: Number of transitions to draw.
            beta: Importance-sampling exponent; weights are
                ``(N * p_i) ** -beta`` normalised by the batch maximum.

        Returns:
            ``(s, a, r, s_, done, weights, indices)`` where the first six
            are torch tensors and ``indices`` is the NumPy array of sampled
            slot numbers (to be passed to ``update_priorities``).

        Raises:
            ValueError: If the buffer holds no transitions yet.
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty ReplayBuffer")

        # Valid slots are always the first ``size`` entries: before the
        # buffer wraps size == ptr, afterwards size == max_size.
        prios = self.priorities[:self.size]

        scaled = prios ** self.alpha
        probs = scaled / scaled.sum()

        indices = np.random.choice(len(probs), batch_size, p=probs)
        weights = (len(probs) * probs[indices]) ** (-beta)
        weights /= weights.max()

        return (
            torch.FloatTensor(self.s[indices]),
            torch.LongTensor(self.a[indices]),
            torch.FloatTensor(self.r[indices]),
            torch.FloatTensor(self.s_[indices]),
            torch.FloatTensor(self.done[indices]),
            torch.FloatTensor(weights),
            indices
        )

    def update_priorities(self, indices, prios):
        """Overwrite the priority of each slot in ``indices`` with ``prios``.

        Pairs are applied in order, so if a slot appears more than once the
        last value given for it is the one kept.
        """
        for idx, prio in zip(indices, prios):
            self.priorities[idx] = prio

# DQN model

In [4]:
class DQN(torch.nn.Module):
    """Dueling Q-network.

    A shared two-layer trunk feeds a scalar state-value head and a
    per-action advantage head; the Q-values are the value plus the
    mean-centred advantages.
    """

    def __init__(self, input_dim=11, output_dim=3):
        """Build the layers for ``input_dim`` features and ``output_dim`` actions."""
        super().__init__()
        nn = torch.nn

        # Shared trunk.
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)

        # State-value stream (one scalar per input).
        self.value_fc = nn.Linear(64, 32)
        self.value = nn.Linear(32, 1)

        # Advantage stream (one entry per action).
        self.adv_fc = nn.Linear(64, 32)
        self.adv = nn.Linear(32, output_dim)

    def forward(self, x):
        """Return Q-values with ``output_dim`` entries along the last axis."""
        features = F.relu(self.fc2(F.relu(self.fc1(x))))

        state_value = self.value(F.relu(self.value_fc(features)))
        advantages = self.adv(F.relu(self.adv_fc(features)))

        # Subtracting the mean advantage makes the value/advantage split unique.
        centred = advantages - advantages.mean(dim=-1, keepdim=True)
        return state_value + centred


# Agent

In [5]:
# Directory (with trailing slash) where checkpoints and logs are written.
folder = "RainbowDQN_model/"

# exist_ok=True creates the directory only when missing, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(folder, exist_ok=True)

In [None]:
class DQNAgent:
    """Double-DQN agent with a dueling network and prioritized replay.

    Owns the environment, the online and target networks, the optimizer,
    the replay buffer and a ``log`` dict with the keys ``"TrainReward"``,
    ``"TestReward"`` and ``"Loss"``.

    Args:
        gamma: Discount factor used in the TD target.
        eps_low: Floor of the epsilon-greedy exploration rate.
        lr: Adam learning rate.
    """

    def __init__(self, gamma=0.9, eps_low=0.05, lr=1e-3):
        self.env = SnakeEnv()
        self.n_act = self.env.action_space.n
        self.model = DQN(output_dim=self.n_act).to(device)
        self.target_model = DQN(output_dim=self.n_act).to(device)
        # Start the target network as an exact copy of the online network.
        self.target_model.load_state_dict(self.model.state_dict())

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        # NOTE: kept for compatibility; learn() uses an importance-weighted
        # squared TD error instead of this criterion.
        self.loss_fn = torch.nn.SmoothL1Loss()
        self.rb = ReplayBuffer()
        self.gamma = gamma
        self.eps_low = eps_low
        self.steps = 0              # environment steps taken during training
        self.beta_start = 0.4       # initial importance-sampling exponent
        self.beta_frames = 10000    # steps over which beta is annealed to 1.0
        self.log = {"TrainReward": [], "TestReward": [], "Loss": []}

    def predict(self, state):
        """Return the greedy action index for a single ``state``."""
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            return self.model(state).argmax().item()

    def select_action(self, state, eps):
        """Epsilon-greedy: a uniformly random action with probability ``eps``."""
        if random.random() < eps:
            return random.randint(0, self.n_act - 1)
        return self.predict(state)

    def save_log(self, ep):
        """Save the log dict to ``{folder}log_{ep}.npy``.

        ``np.save`` stores a dict as a pickled object array, so reading it
        back requires ``np.load(..., allow_pickle=True)``.
        """
        np.save(f"{folder}log_{ep}.npy", self.log)

    def learn(self, batch_size=64):
        """Run one gradient step on a prioritized batch from the buffer.

        Computes Double-DQN targets, refreshes the sampled transitions'
        priorities with their absolute TD errors, and minimises the
        importance-weighted squared TD error. The loss is appended to
        ``self.log["Loss"]``.
        """
        # Anneal beta linearly from beta_start to 1.0 over beta_frames steps.
        beta = min(1.0, self.beta_start + self.steps * (1.0 - self.beta_start) / self.beta_frames)
        # Pass the annealed beta through; without it the buffer would keep
        # using its default exponent and the annealing would have no effect.
        s, a, r, s_, done, weights, indices = self.rb.sample(batch_size, beta)
        s, a, r, s_, done, weights = (
            t.to(device) for t in (s, a, r, s_, done, weights)
        )

        # Double DQN
        with torch.no_grad():
            # The online model selects the next action ...
            next_actions = self.model(s_).argmax(1, keepdim=True)
            # ... and the target model evaluates it. squeeze(1) removes only
            # the gathered column, so the batch axis is never dropped.
            next_q = self.target_model(s_).gather(1, next_actions).squeeze(1)
            q_target = r + self.gamma * next_q * (1 - done)

        q_pred = self.model(s).gather(1, a.unsqueeze(1)).squeeze(1)
        td_error = q_pred - q_target
        # Small constant keeps every priority strictly positive.
        prios = (td_error.abs() + 1e-5).detach().cpu().numpy()
        self.rb.update_priorities(indices, prios)

        loss = (td_error.pow(2) * weights).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.log["Loss"].append(loss.item())

    def train(self, N_EPISODES=1000):
        """Train for ``N_EPISODES`` episodes.

        Epsilon decays per episode as
        ``eps_low + (1 - eps_low) * exp(-ep / 100)``. Learning starts once
        the buffer holds more than 500 transitions, and the target network
        is synced every 200 environment steps. Every 20 episodes (including
        episode 0) a greedy test episode is run and the model weights and
        log are written under ``folder``.
        """
        for ep in tqdm(range(N_EPISODES)):
            state = self.env.reset()
            done = False
            total_reward = 0
            eps = self.eps_low + (1.0 - self.eps_low) * math.exp(-ep / 100)

            while not done:
                action = self.select_action(state, eps)
                next_state, reward, done, _ = self.env.step(action)
                self.rb.append(state, action, reward, next_state, float(done))
                state = next_state
                total_reward += reward

                self.steps += 1
                if self.rb.size > 500:
                    self.learn()

                if self.steps % 200 == 0:
                    self.target_model.load_state_dict(self.model.state_dict())

            self.log["TrainReward"].append(total_reward)

            if ep % 20 == 0:
                test_reward = self.test()
                self.log["TestReward"].append(test_reward)
                print(f"\nEpisode {ep}: Train Reward = {total_reward}, Test Reward = {test_reward}")
                torch.save(self.model.state_dict(), f"{folder}snake_model_{ep}.pt")
                self.save_log(ep)

    def test(self, render=False):
        """Play one greedy episode and return its total reward.

        Args:
            render: When true, draw the environment before every step.
        """
        state = self.env.reset()
        done = False
        total_reward = 0
        while not done:
            if render:
                self.env.render()
            action = self.predict(state)
            state, reward, done, _ = self.env.step(action)
            total_reward += reward
        return total_reward


# train

In [None]:
# Build the agent with explicit hyper-parameters, then run the training loop.
agent = DQNAgent(
    gamma=0.9,
    eps_low=0.05,
    lr=1e-3,
)
agent.train(N_EPISODES=1000)

  0%|          | 1/1000 [00:13<3:39:21, 13.17s/it]


Episode 0: Train Reward = -10, Test Reward = -10


  0%|          | 5/1000 [00:24<1:21:55,  4.94s/it]


ValueError: not enough values to unpack (expected 7, got 5)

: 

In [None]:
# Play one greedy episode with on-screen rendering; the call returns the episode's total reward.
agent.test(render=True)