# Setup Pacman-world

### 1. import all the required packages

In [None]:
import gymnasium as gym
import gymnasium_env
import matplotlib.pyplot as plt
import os
import numpy as np
import warnings
import shutil
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd

### 2. Define Wrapper
* `PacmanEnvWrapper` preprocesses every frame (resize, grayscale, binarize) and stacks `k` consecutive frames into the observation that the DQN consumes.

In [None]:

class PacmanEnvWrapper(gym.Wrapper):
    """Frame-repeating, frame-stacking wrapper that feeds binarized images to the DQN.

    Each call to :meth:`step` repeats the chosen action up to ``k`` times and
    returns the ``k`` preprocessed frames stacked along the first axis.

    Args:
        env: Environment to wrap. Its observations must be image arrays that
            ``PIL.Image.fromarray`` accepts.
        k: Number of frames per stacked observation (and action repeats).
        img_size: Target ``(height, width)`` of every preprocessed frame.
    """

    def __init__(self, env, k, img_size=(84, 84)):
        # gym.Wrapper stores ``env`` as ``self.env``. Use the environment the
        # caller passed in instead of rebuilding one from a module-level global.
        super().__init__(env)
        self.k = k
        self.img_size = img_size
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=(k, img_size[0], img_size[1]),
            dtype=np.float32,
        )

    def _preprocess(self, state, th=0.182):
        """Resize a raw frame, convert it to grayscale and binarize it.

        Args:
            state: Raw image array as returned by the wrapped environment.
            th: Threshold in ``[0, 1]``; pixels strictly above it become 1.0,
                all others 0.0.

        Returns:
            A ``float32`` array of shape ``img_size`` holding only 0.0 and 1.0.
        """
        height, width = self.img_size
        # PIL expects (width, height); the resulting array is (height, width),
        # which matches the observation space declared in __init__.
        frame = Image.fromarray(state).resize((width, height), Image.BILINEAR)
        frame = np.asarray(frame, dtype=np.float32)
        if frame.ndim == 3:
            # Average the colour channels to get a single grayscale plane.
            frame = frame.mean(axis=2)
        frame = frame / 255.
        return (frame > th).astype(np.float32)

    def reset(self, **kwargs):
        """Reset the wrapped environment and return the initial stacked state.

        Args:
            **kwargs: Forwarded unchanged to the wrapped environment's
                ``reset`` (for example ``seed`` or ``options``).

        Returns:
            Array of shape ``(k, *img_size)`` made of ``k`` copies of the first
            preprocessed frame. Only the observation is returned (no info).
        """
        result = self.env.reset(**kwargs)
        # reset() may return ``(obs, info)``; only the image is needed here.
        frame = result[0] if isinstance(result, tuple) else result
        frame = self._preprocess(frame)
        return np.repeat(frame[np.newaxis, ...], self.k, axis=0)

    def step(self, action):
        """Repeat ``action`` for up to ``k`` frames and stack the results.

        Stepping stops early once the episode ends; the last frame is then
        repeated so the returned stack always holds ``k`` frames.

        Args:
            action: Action forwarded to the wrapped environment.

        Returns:
            Tuple ``(state_next, reward, terminated, info)`` where
            ``state_next`` has shape ``(k, *img_size)``, ``reward`` is the sum
            of the per-frame rewards, ``terminated`` tells whether the episode
            ended, and ``info`` is the list of per-frame info objects.
        """
        frames = []
        infos = []
        total_reward = 0
        terminated = False
        frame = None

        for _ in range(self.k):
            if not terminated:
                result = self.env.step(action)
                if len(result) == 5:
                    # Five-element API: (obs, reward, terminated, truncated, info).
                    # Callers of this wrapper expect a single "episode over" flag.
                    raw, reward, term, truncated, info = result
                    terminated = bool(term or truncated)
                else:
                    # Four-element API: (obs, reward, done, info).
                    raw, reward, terminated, info = result
                frame = self._preprocess(raw)
                total_reward += reward
                infos.append(info)
            frames.append(frame)

        return np.stack(frames, axis=0), total_reward, terminated, infos


# Implement DQN

### 1. Define QNet (convolutional Q-network)

In [None]:

class QNet(nn.Module):
    """Convolutional Q-network: a stack of frames in, one Q-value per action out.

    Args:
        input_shape: Shape of a single (unbatched) input; its first entry is
            used as the number of input channels.
        n_actions: Number of output Q-values.
    """

    # TODO(Lab-4): Q-Network architecture.
    def __init__(self, input_shape, n_actions):
        super().__init__()
        in_channels = input_shape[0]
        feature_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*feature_layers)

        # The first linear layer needs the flattened size of the conv output.
        n_features = self._get_conv_out(input_shape)

        head_layers = [
            nn.Flatten(),
            nn.Linear(n_features, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        ]
        self.fc = nn.Sequential(*head_layers)

    def _get_conv_out(self, shape):
        """Return the number of elements the conv stack emits for one input of ``shape``."""
        probe = torch.zeros(1, *shape)
        features = self.conv(probe)
        return int(features.numel())

    def forward(self, x):
        """Compute Q-values for a batch ``x`` by applying ``conv`` and then ``fc``."""
        return self.fc(self.conv(x))

### 2 Define DQN

In [None]:
class DeepQNetwork():
    """Double-DQN agent with a fixed-size replay memory and a target network.

    Args:
        n_actions: Number of discrete actions.
        input_shape: Shape of one state, e.g. ``[frames, height, width]``.
        qnet: Callable ``qnet(input_shape, n_actions)`` returning an
            ``nn.Module``; called twice (online and target network).
        device: Torch device the networks and training batches live on.
        learning_rate: Learning rate of the RMSprop optimizer.
        reward_decay: Discount factor (gamma).
        replace_target_iter: Copy the online network into the target network
            every this many learning steps.
        memory_size: Capacity of the replay memory, in transitions.
        batch_size: Number of transitions sampled per learning step.
    """

    def __init__(
        self,
        n_actions,
        input_shape,
        qnet,
        device,
        learning_rate=2e-4,
        reward_decay=0.99,
        replace_target_iter=1000,
        memory_size=10000,
        batch_size=32,
    ):
        # Hyper-parameters.
        self.n_actions = n_actions
        self.input_shape = input_shape
        self.lr = learning_rate
        self.gamma = reward_decay
        self.replace_target_iter = replace_target_iter
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.device = device
        self.learn_step_counter = 0
        self.init_memory()

        # Networks: qnet_eval is trained, qnet_target provides bootstrap targets.
        self.qnet_eval = qnet(self.input_shape, self.n_actions).to(self.device)
        self.qnet_target = qnet(self.input_shape, self.n_actions).to(self.device)
        # Start both networks from the same weights.
        self.qnet_target.load_state_dict(self.qnet_eval.state_dict())
        self.qnet_target.eval()
        self.optimizer = optim.RMSprop(self.qnet_eval.parameters(), lr=self.lr)

        # Keep track of episodes.
        self.episode = 0

    def choose_action(self, state, epsilon=0):
        """Pick an action epsilon-greedily.

        Args:
            state: One unbatched state of shape ``input_shape``.
            epsilon: Probability of picking a uniformly random action.

        Returns:
            The chosen action index as a Python ``int``.
        """
        if np.random.uniform() > epsilon:  # greedy
            # Add the batch dimension; no gradients are needed for acting.
            state_t = torch.as_tensor(
                np.asarray(state), dtype=torch.float32, device=self.device
            ).unsqueeze(0)
            with torch.no_grad():
                actions_value = self.qnet_eval(state_t)
            return int(torch.max(actions_value, 1)[1].item())
        # random
        return int(np.random.randint(0, self.n_actions))

    def learn(self):
        """Run one Double-DQN update on a random batch from the replay memory.

        Returns:
            The smooth-L1 loss of this update as a 0-d NumPy array.

        Raises:
            ValueError: If no transition has been stored yet.
        """
        # Only sample slots that actually hold a transition.
        n_stored = min(self.memory_counter, self.memory_size)
        if n_stored == 0:
            raise ValueError(
                "Replay memory is empty; call store_transition() before learn()."
            )

        # Periodically refresh the target network.
        if self.learn_step_counter % self.replace_target_iter == 0:
            self.qnet_target.load_state_dict(self.qnet_eval.state_dict())

        # Sample a batch (with replacement) from the replay memory.
        sample_index = np.random.choice(n_stored, size=self.batch_size)

        def batch(key, dtype):
            return torch.as_tensor(
                self.memory[key][sample_index], dtype=dtype, device=self.device
            )

        b_s = batch("s", torch.float32)
        b_a = batch("a", torch.long)
        b_r = batch("r", torch.float32)
        b_s_ = batch("s_", torch.float32)
        b_d = batch("done", torch.float32)

        q_curr_eval = self.qnet_eval(b_s).gather(1, b_a)

        # Double DQN: the online network selects the next action, the target
        # network evaluates it. Targets are constants, so skip autograd.
        with torch.no_grad():
            best_next_action = self.qnet_eval(b_s_).max(1)[1].unsqueeze(1)
            next_state_values = self.qnet_target(b_s_).gather(1, best_next_action)
            q_curr_recur = b_r + (1 - b_d) * self.gamma * next_state_values

        # Loss.
        self.loss = F.smooth_l1_loss(q_curr_eval, q_curr_recur)

        # Back-propagation and optimizer step.
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()
        self.learn_step_counter += 1

        return self.loss.detach().cpu().numpy()

    def init_memory(self):
        """Allocate an empty replay memory and reset its write counter."""
        state_shape = (self.memory_size, *self.input_shape)
        # float32 is what the networks consume; it halves the buffer size
        # compared with NumPy's default float64.
        self.memory = {
            "s": np.zeros(state_shape, dtype=np.float32),
            "a": np.zeros((self.memory_size, 1), dtype=np.int64),
            "r": np.zeros((self.memory_size, 1), dtype=np.float32),
            "s_": np.zeros(state_shape, dtype=np.float32),
            "done": np.zeros((self.memory_size, 1), dtype=np.float32),
        }
        # Total number of transitions ever stored (not capped at memory_size).
        self.memory_counter = 0

    def store_transition(self, s, a, r, s_, d):
        """Store one transition, overwriting the oldest one when full.

        Args:
            s: State before the action.
            a: Action index that was taken.
            r: Reward received.
            s_: State after the action.
            d: Whether the episode ended with this transition.
        """
        index = self.memory_counter % self.memory_size
        self.memory["s"][index] = s
        self.memory["a"][index] = np.asarray(a).reshape(-1)
        self.memory["r"][index] = np.asarray(r).reshape(-1)
        self.memory["s_"][index] = s_
        self.memory["done"][index] = np.asarray(d).reshape(-1)
        self.memory_counter += 1

    def save_load_model(self, op, path="save", fname="qnet.pt"):
        """Save a checkpoint to, or load one from, ``path/fname``.

        Args:
            op: ``"save"`` or ``"load"``.
            path: Directory of the checkpoint (created when saving).
            fname: File name of the checkpoint.

        Returns:
            A dict with ``learn_step_counter`` and ``episode``: the agent's
            current values after a save or a successful load, zeros when
            loading failed.

        Raises:
            ValueError: If ``op`` is neither ``"save"`` nor ``"load"``.
        """
        file_path = os.path.join(path, fname)

        if op == "save":
            os.makedirs(path, exist_ok=True)
            # Network weights, optimizer state and progress counters.
            checkpoint = {
                'qnet_eval_state_dict': self.qnet_eval.state_dict(),
                'qnet_target_state_dict': self.qnet_target.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'learn_step_counter': self.learn_step_counter,
                'memory_counter': self.memory_counter,
                'episode': self.episode,
            }
            torch.save(checkpoint, file_path)
            print(f"Model saved successfully at {file_path}")
            return {'learn_step_counter': self.learn_step_counter, 'episode': self.episode}

        if op != "load":
            raise ValueError(f"Unknown op {op!r}; expected 'save' or 'load'.")

        try:
            checkpoint = torch.load(file_path, map_location=self.device)

            # Make sure every mandatory entry is present.
            required_keys = ['qnet_eval_state_dict', 'qnet_target_state_dict', 'optimizer_state_dict']
            missing_keys = [key for key in required_keys if key not in checkpoint]
            if missing_keys:
                raise KeyError(f"Missing keys in checkpoint: {missing_keys}")

            self.qnet_eval.load_state_dict(checkpoint['qnet_eval_state_dict'])
            self.qnet_target.load_state_dict(checkpoint['qnet_target_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            # Optional progress counters.
            self.learn_step_counter = checkpoint.get('learn_step_counter', 0)
            self.episode = checkpoint.get('episode', 0)
            # memory_counter is deliberately NOT restored: the replay memory
            # itself is not part of the checkpoint, so restoring the counter
            # would make learn() sample slots that hold no real transition.

            print("Model loaded successfully from", file_path)
            return {'learn_step_counter': self.learn_step_counter, 'episode': self.episode}

        except FileNotFoundError:
            print(f"No saved model found at {file_path}, starting fresh.")
        except KeyError as e:
            print(f"Error loading model: {e}")

        # Loading failed: report a fresh start.
        return {'learn_step_counter': 0, 'episode': 0}

# Helper Functions

### 1. Define Epsilon Function

In [None]:
def epsilon_compute(frame_id, epsilon_max=1, epsilon_min=0.05, epsilon_decay=100000):
    """Return the exploration rate for ``frame_id`` under exponential decay.

    The value starts at ``epsilon_max`` for ``frame_id == 0`` and decays
    towards ``epsilon_min`` with time constant ``epsilon_decay``.
    """
    span = epsilon_max - epsilon_min
    decay_factor = np.exp(-frame_id / epsilon_decay)
    return epsilon_min + span * decay_factor

### 2. Output to GIF
* Remember to run this cell: `save_gif` is used later by `train()` and by the evaluation cell.

In [None]:
# Directory the notebook was started from; output folders are created under it.
project_root = os.getcwd()


def save_gif(img_buffer, fname, gif_path=os.path.join(project_root, "GIF"), duration=3):
    """Write a sequence of frames to an animated, endlessly looping GIF.

    Args:
        img_buffer: Non-empty sequence of image objects; the first one's
            ``save`` method is called with the remaining ones as
            ``append_images``.
        fname: File name of the GIF inside ``gif_path``.
        gif_path: Output directory; created if it does not exist.
        duration: Value forwarded as the ``duration`` argument of ``save``.

    Raises:
        ValueError: If ``img_buffer`` is empty.
    """
    frames = list(img_buffer)
    if not frames:
        raise ValueError("img_buffer is empty; nothing to write to the GIF.")

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(gif_path, exist_ok=True)
    frames[0].save(
        os.path.join(gif_path, fname),
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0,
    )

# Trainer Functions

### 1. Define `Play()`

In [None]:

import gymnasium as gym
import gymnasium_env
def play(env, agent, stack_frames, img_size, randomized_ratio=0.005):
    """Run one episode with ``agent`` and collect frames for a GIF.

    Args:
        env: Environment whose ``reset()`` returns a stacked state and whose
            ``step()`` returns ``(state, reward, done, info)``.
        agent: Object providing ``choose_action(state, epsilon)``.
        stack_frames: Not used by this function.
        img_size: Not used by this function.
        randomized_ratio: Epsilon passed to ``choose_action``.

    Returns:
        List of PIL images: the first frame of the initial state followed by
        the first frame of every second subsequent state.
    """
    def to_image(stacked_state):
        # Only the first frame of the stack is rendered, scaled by 255.
        return Image.fromarray(stacked_state[0]*255)

    status_line = '\rStep: {:3d} | Reward: {:.3f} / {:.3f} | Action: {:.3f} | Info: {}'

    # Reset environment.
    observation = env.reset()
    frames = [to_image(observation)]

    steps_taken = 0
    episode_return = 0
    finished = False

    # One episode, capped once more than 1000 steps have been taken.
    while not finished:
        action = agent.choose_action(observation, randomized_ratio)

        next_observation, reward, done, info = env.step(action)
        if steps_taken % 2 == 0:
            frames.append(to_image(next_observation))

        observation = next_observation.copy()
        steps_taken += 1
        episode_return += reward
        print(status_line.format(steps_taken, reward, episode_return, action, info[0]), end="")

        finished = done or steps_taken > 1000

    # Terminate the carriage-return status line.
    print()
    return frames


### CSV Helper

In [None]:
def write_to_csv(episode: int, total_steps: int, loss_arr: list, total_reward: float, episode_score: int, epsilon: float, csv_filepath: str):
    """Append one row of episode metrics to a CSV file.

    The header row is written only when the file does not exist yet.

    Args:
        episode: Episode index.
        total_steps: Number of environment steps taken so far.
        loss_arr: Loss values collected during the episode; their mean is
            stored in the ``Loss`` column (NaN when the list is empty).
        total_reward: Sum of rewards of the episode.
        episode_score: Score reported for the episode.
        epsilon: Exploration rate at the end of the episode.
        csv_filepath: Path of the CSV file to append to.
    """
    # Average loss of the episode; an empty list must not divide by zero.
    losses = np.asarray(loss_arr, dtype=np.float64)
    avg_loss = float(losses.mean()) if losses.size else float("nan")

    # Store data into csv
    df = pd.DataFrame([{
        "Episodes": episode,
        "Total_Steps": total_steps,
        "Loss": avg_loss,
        "Total_Reward": total_reward,
        "Score": episode_score,
        "Epsilon": epsilon,
    }])

    write_header = not os.path.exists(csv_filepath)
    df.to_csv(csv_filepath, mode='a', header=write_header, index=False)
    print(f"Metrics of episode {episode} appended to {csv_filepath}!")

    return

### 2. Define `train()`

In [None]:
def _save_training_artifacts(env, agent, stack_frames, img_size, save_path, model_filename):
    """Checkpoint the agent, render a preview GIF and back up every 400 episodes."""
    print("\nSave Model ...")
    agent.save_load_model(op="save", path=save_path, fname=model_filename)

    gif_name = "train_ep" + str(agent.episode).zfill(5) + ".gif"
    print(f"Generate GIF <{gif_name}>...")
    img_buffer = play(env, agent, stack_frames, img_size, 0.50)
    save_gif(img_buffer, gif_name)
    print("Done !!")

    # Back up model
    if agent.episode % 400 == 0:
        print("Doing backup...")
        backup_filename = f"{model_filename}.ep{agent.episode}.qnet.bak"
        orig_path = os.path.join(save_path, model_filename)
        backup_path = os.path.join(save_path, backup_filename)
        shutil.copy(orig_path, backup_path)
        print(f"Backup done, file path: {backup_path}")


def train(env, agent, stack_frames, img_size, save_path="save", max_steps=1000000, session_name="default", max_episodes=10000, max_episode_steps=30):
    """Train ``agent`` on ``env``, logging metrics and checkpointing as it goes.

    Per episode one row is appended to ``training_metrics_<session_name>.csv``
    in ``save_path``. Every 20 episodes the model is saved as
    ``qnet_<session_name>.pt`` and a preview GIF is rendered; every 400
    episodes the checkpoint is additionally copied to a backup file.

    Args:
        env: Environment whose ``reset()`` returns a stacked state and whose
            ``step()`` returns ``(state, reward, terminated, info)`` with
            ``info`` being a list of dicts containing ``'total_score'``.
        agent: Agent providing ``choose_action``, ``store_transition``,
            ``learn``, ``save_load_model``, ``batch_size``, ``memory_counter``
            and ``episode``.
        stack_frames: Forwarded to ``play``.
        img_size: Forwarded to ``play``.
        save_path: Directory for the checkpoint and the metrics CSV.
        max_steps: Training stops once more than this many steps were taken.
        session_name: Suffix used in the checkpoint and CSV file names.
        max_episodes: Training stops once ``agent.episode`` exceeds this.
        max_episode_steps: An episode is cut off once it has taken more than
            this many steps. NOTE(review): an earlier comment said the cap
            was 1000 while the code used 30; the default keeps 30 — confirm
            which value is intended.
    """
    total_step = 0

    # Make sure the output directory exists.
    os.makedirs(save_path, exist_ok=True)
    csv_filename = f"training_metrics_{session_name}.csv"
    csv_path = os.path.join(save_path, csv_filename)
    model_filename = f"qnet_{session_name}.pt"

    # Try to resume from an existing checkpoint.
    try:
        print("Loading model and training status...")
        status = agent.save_load_model(op="load", path=save_path, fname=model_filename)
        total_step = status["learn_step_counter"]
        print(f"Resuming training from total_step={total_step}, episode={status['episode']}")
    except FileNotFoundError:
        print("No previous model found. Starting training from scratch.")
    except KeyError as e:
        print(f"Missing key in checkpoint: {e}")

    # Learning starts only after this many steps / stored transitions.
    warmup = 4 * agent.batch_size

    while total_step <= max_steps:
        # Make sure agent episode does not exceed the limit
        if agent.episode > max_episodes:
            break

        # Reset environment.
        state = env.reset()

        # Initialize per-episode information.
        step = 0
        total_reward = 0
        loss = 0
        # Collected once per step and averaged at the end of the episode, so
        # it must be reset per episode, not per step.
        loss_values = []

        # One episode.
        while True:
            # TODO(Lab-6): Select action.
            epsilon = epsilon_compute(total_step)
            action = agent.choose_action(state, epsilon)

            # Get next observation.
            obs, reward, terminated, info = env.step(action)

            # If obs is a tuple, extract the image.
            if isinstance(obs, tuple):
                obs = obs[0]

            done = terminated

            # Store transition and learn.
            agent.store_transition(state, action, reward, obs, done)

            # The second condition keeps the warm-up meaningful when the step
            # counter was restored from a checkpoint but the memory is fresh.
            if total_step > warmup and agent.memory_counter > warmup:
                loss = agent.learn()

            state = obs.copy()
            step += 1
            total_step += 1
            total_reward += reward

            # Print status
            if total_step % 10 == 0 or done:
                print('\rEpisode: {:3d} | Step: {:3d} / {:3d} | Reward: {:.3f} / {:.3f} | Loss: {:.3f} | Epsilon: {:.3f}'.format(agent.episode, step, total_step, reward, total_reward, loss, epsilon), end="")
            loss_values.append(loss)

            # End of episode: log its metrics, then checkpoint periodically.
            if done or step > max_episode_steps:
                write_to_csv(
                    episode=agent.episode,
                    total_steps=total_step,
                    loss_arr=loss_values,
                    total_reward=total_reward,
                    episode_score=info[0]['total_score'],
                    csv_filepath=csv_path,
                    epsilon=epsilon
                )

                agent.episode += 1

                if agent.episode % 20 == 0:
                    _save_training_artifacts(env, agent, stack_frames, img_size, save_path, model_filename)

                break

            if total_step > max_steps:
                break

# Start Training

### 10. Train the model

* Throughput: roughly 600 steps/min on Colab (T4 GPU) and 400 steps/min on an RTX 3070 laptop, so training is slow.

In [None]:
# Environment and preprocessing configuration.
env_name = 'gymnasium_env/PacmanGymEnv'
stack_frames = 4
img_size = (84, 84)

# Raw environment plus the frame-stacking wrapper the agent trains on.
env = gym.make(env_name, speedup=4.0)
env_pacman = PacmanEnvWrapper(env=env, k=stack_frames, img_size=img_size)

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

agent = DeepQNetwork(
    n_actions=env.action_space.n,
    input_shape=[stack_frames, *img_size],
    qnet=QNet,
    device=device,
    learning_rate=2e-5,
    reward_decay=0.95,
    replace_target_iter=1000,
    memory_size=10000,
    batch_size=32,
)

In [None]:
# Start (or resume) the training session; outputs go to <project_root>/save.
train(
    env=env_pacman,
    agent=agent,
    stack_frames=stack_frames,
    img_size=img_size,
    save_path=os.path.join(project_root, "save"),
    max_steps=600000,
    session_name="DDQN_NoGrid_Advanced",
)

# Evaluate Model

In [None]:
# Load the checkpoint written by the training cell. train() names the file
# f"qnet_{session_name}.pt", so the name must carry the training session name
# ("DDQN_NoGrid_Advanced"); plain "qnet.pt" is never written by that cell.
agent.save_load_model(op="load", path=os.path.join(project_root, "save"), fname="qnet_DDQN_NoGrid_Advanced.pt")
env_pacman = PacmanEnvWrapper(env, k=4, img_size=(84,84))
# Play one episode with the loaded agent and store it as a GIF.
img_buffer = play(env_pacman, agent, stack_frames, img_size)
save_gif(img_buffer, "eval.gif")