# HW4-3: Enhanced DQN using PyTorch Lightning


本 Notebook 使用 **PyTorch Lightning** 重新構建 DQN 架構，以簡化訓練流程並融入增強技巧。

---

### ✅ 本 Notebook 包含：
- PyTorch Lightning 重構版 DQN
- 加入訓練技巧：Gradient Clipping、Learning Rate Scheduler
- Gridworld 環境應用（隨機起始模式）
- 結果與比較


## 🗺️ 環境初始化

In [None]:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Trainer
import random

class Gridworld:
    """Square grid environment with one agent and one fixed goal cell.

    The agent position is stored as ``[row, col]``. The goal is the
    top-right cell and the default start is the bottom-left cell.
    Observations are a flattened one-hot encoding of the agent position.
    """

    def __init__(self, size=5):
        """Create the environment.

        Args:
            size: Side length of the square grid (default 5).

        Raises:
            ValueError: If ``size`` is smaller than 1.
        """
        if size < 1:
            raise ValueError(f"size must be >= 1, got {size}")
        self.size = size
        self.goal = [0, size - 1]
        self.state = [size - 1, 0]
        self.actions = ["up", "down", "left", "right", "stay"]

    def reset(self, random_start=False):
        """Place the agent and return the initial observation.

        Args:
            random_start: When true, draw row and column uniformly from the
                grid; otherwise start in the bottom-left cell.

        Returns:
            Flattened one-hot observation of the new position.

        Note:
            A random start may coincide with the goal cell; ``done`` is only
            reported by ``step``.
        """
        if random_start:
            # Bound the draw by the grid size rather than a literal so the
            # environment stays consistent when ``size`` is not 5.
            last = self.size - 1
            self.state = [random.randint(0, last), random.randint(0, last)]
        else:
            # Fresh list each time: ``step`` mutates ``self.state`` in place.
            self.state = [self.size - 1, 0]
        return self.get_state()

    def step(self, action):
        """Apply one action and return ``(observation, reward, done)``.

        Actions 0-3 move up, down, left and right respectively. A move that
        would leave the grid, and any other action index (including 4,
        "stay"), leaves the agent where it is.

        Returns:
            Tuple of the new observation, the reward (1 on the goal cell,
            -0.1 otherwise) and a bool telling whether the goal was reached.
        """
        if action == 0 and self.state[0] > 0:
            self.state[0] -= 1
        elif action == 1 and self.state[0] < self.size - 1:
            self.state[0] += 1
        elif action == 2 and self.state[1] > 0:
            self.state[1] -= 1
        elif action == 3 and self.state[1] < self.size - 1:
            self.state[1] += 1
        done = self.state == self.goal
        reward = 1 if done else -0.1
        return self.get_state(), reward, done

    def get_state(self):
        """Return the agent position as a flat one-hot array of length size*size."""
        grid = np.zeros((self.size, self.size))
        grid[self.state[0], self.state[1]] = 1
        return grid.flatten()


## ⚙️ PyTorch Lightning 模型定義

In [None]:

class LightningDQN(pl.LightningModule):
    """DQN agent: a two-layer MLP Q-network plus an in-memory replay buffer."""

    def __init__(self, input_size=25, hidden_size=128, output_size=5, lr=1e-3,
                 buffer_size=1000, batch_size=32, gamma=0.99):
        """Build the Q-network and an empty replay buffer.

        Args:
            input_size: Length of the flattened observation vector.
            hidden_size: Width of the single hidden layer.
            output_size: Number of discrete actions (one Q-value each).
            lr: Learning rate for the Adam optimizer.
            buffer_size: Maximum number of transitions kept in the buffer.
            batch_size: Number of transitions drawn by ``sample_batch``.
            gamma: Discount factor used in the TD target.
        """
        super().__init__()
        self.save_hyperparameters()
        self.q_net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        )
        self.loss_fn = nn.MSELoss()
        # Plain list used as a FIFO; the oldest transition is at index 0.
        self.replay_buffer = []
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.gamma = gamma

    def forward(self, x):
        """Return the Q-values for every action given observation(s) ``x``."""
        return self.q_net(x)

    def training_step(self, batch, batch_idx):
        """Compute the one-step TD loss for a batch of transitions.

        Args:
            batch: Tuple ``(states, actions, rewards, next_states, dones)``
                of tensors, as produced by ``sample_batch``.
            batch_idx: Unused; present for the Lightning hook signature.

        Returns:
            Scalar MSE loss between the Q-values of the taken actions and
            the TD targets. This method does not run backward or step an
            optimizer itself.
        """
        states, actions, rewards, next_states, dones = batch
        # Pick Q(s, a) for the action that was actually taken.
        q_values = self.q_net(states).gather(1, actions.long().unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            # The target bootstraps from the same network (no separate
            # target network); terminal transitions drop the bootstrap term.
            next_q = self.q_net(next_states).max(1)[0]
            target_q = rewards + (1 - dones) * self.gamma * next_q
        loss = self.loss_fn(q_values, target_q)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Return Adam plus a StepLR schedule (x0.9 every 50 scheduler steps)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def add_experience(self, exp):
        """Append one ``(state, action, reward, next_state, done)`` transition.

        When the buffer exceeds ``buffer_size`` the oldest entry is dropped.
        """
        self.replay_buffer.append(exp)
        if len(self.replay_buffer) > self.buffer_size:
            self.replay_buffer.pop(0)

    def sample_batch(self):
        """Draw ``batch_size`` transitions uniformly without replacement.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` of
            tensors; actions are int64, everything else float32.

        Raises:
            ValueError: From ``random.sample`` when the buffer holds fewer
                than ``batch_size`` transitions.
        """
        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        # Stack the per-transition arrays with NumPy first: building a tensor
        # straight from a sequence of ndarrays is very slow and makes PyTorch
        # emit a UserWarning on every call.
        return (
            torch.as_tensor(np.asarray(states), dtype=torch.float32),
            torch.as_tensor(actions, dtype=torch.long),
            torch.tensor(rewards, dtype=torch.float32),
            torch.as_tensor(np.asarray(next_states), dtype=torch.float32),
            torch.tensor(dones, dtype=torch.float32)
        )


## 🏋️ 訓練迴圈（整合 Lightning）

In [None]:

env = Gridworld()
model = LightningDQN()
# NOTE: the manual loop below does not use this Trainer; it is kept so that
# any other cell referring to `trainer` keeps working.
trainer = pl.Trainer(max_epochs=1, enable_checkpointing=False, logger=False)

# The loop drives optimisation by hand, so take the optimizer and the LR
# scheduler from the module's own configuration.
optim_config = model.configure_optimizers()
optimizer = optim_config["optimizer"]
scheduler = optim_config["lr_scheduler"]

NUM_EPISODES = 200
# Cap on steps per episode: without it a policy that keeps choosing a
# non-moving action would never reach the goal and the loop would not end.
MAX_STEPS_PER_EPISODE = 100
GRAD_CLIP_NORM = 1.0
# Epsilon-greedy exploration schedule, decayed once per episode.
EPSILON_START = 1.0
EPSILON_END = 0.05
EPSILON_DECAY = 0.98

reward_history = []
epsilon = EPSILON_START

for episode in range(NUM_EPISODES):
    state = env.reset(random_start=True)
    done = False
    total_reward = 0
    stepped_optimizer = False
    for _ in range(MAX_STEPS_PER_EPISODE):
        if random.random() < epsilon:
            # Explore: uniformly random action.
            action = random.randrange(len(env.actions))
        else:
            # Exploit: greedy action; no gradient is needed for acting.
            state_tensor = torch.from_numpy(state).float().unsqueeze(0)
            with torch.no_grad():
                action = model(state_tensor).argmax().item()
        next_state, reward, done = env.step(action)
        # Hitting the step cap is not a terminal state, so `done` stays as
        # reported by the environment.
        model.add_experience((state, action, reward, next_state, float(done)))
        if len(model.replay_buffer) >= model.batch_size:
            batch = model.sample_batch()
            loss = model.training_step(batch, 0)
            # training_step only returns the loss; the actual update has to
            # be performed here.
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            optimizer.step()
            stepped_optimizer = True
        state = next_state
        total_reward += reward
        if done:
            break
    # Advance the LR schedule once per episode, and only after the optimizer
    # has stepped at least once in that episode.
    if stepped_optimizer:
        scheduler.step()
    epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
    reward_history.append(total_reward)


## 📈 訓練結果可視化

In [None]:

# Draw the per-episode total reward on the current axes.
axes = plt.gca()
axes.plot(reward_history, label='Lightning DQN')
axes.set_xlabel("Episode")
axes.set_ylabel("Reward")
axes.set_title("Enhanced Lightning DQN Training Curve")
axes.grid()
axes.legend()
plt.show()


## ✅ 小結


- 使用 PyTorch Lightning 可簡化模型訓練邏輯，提升實作效率與程式碼可讀性。
- 整合訓練技巧如 Learning Rate Scheduler、Gradient Clipping 可提升穩定性。
- 本實驗針對隨機起始位置環境，顯示強化 DQN 架構有更好學習效果。
