In [1]:
# 安裝必要套件
!pip install timm einops matplotlib pillow scikit-image

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->timm)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->timm)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->timm)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch->tim

In [None]:
# 導入必要的庫
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import math
import random
from einops import rearrange
import warnings
warnings.filterwarnings('ignore')

# Pick the compute device once; later cells move the model and batches onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'使用設備: {device}')


def set_seed(seed=42):
    """Seed every RNG the notebook uses and make cuDNN deterministic.

    Args:
        seed: integer seed shared by Python's ``random``, NumPy and PyTorch
            (CPU and, when available, all CUDA devices).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)

In [2]:
class MDTA(nn.Module):
    """Multi-Dconv Head Transposed Attention (Restormer).

    Attention is taken across channels instead of spatial positions: each head
    builds a (C/heads) x (C/heads) attention map.

    Args:
        channels: input/output channel count; split evenly across heads in
            ``forward`` (so it should be divisible by ``num_heads``).
        num_heads: number of attention heads.
    """
    def __init__(self, channels, num_heads):
        super().__init__()
        self.num_heads = num_heads
        # Learnable per-head scale applied to the normalised Q K^T logits.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        triple = channels * 3
        # Point-wise projection to Q, K, V followed by a depth-wise 3x3 conv.
        self.qkv = nn.Conv2d(channels, triple, kernel_size=1, bias=False)
        self.qkv_dwconv = nn.Conv2d(triple, triple, kernel_size=3, stride=1,
                                    padding=1, groups=triple, bias=False)
        self.project_out = nn.Conv2d(channels, channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Apply channel attention to a (B, C, H, W) tensor; returns (B, C, H, W)."""
        _, _, height, width = x.shape
        split_heads = 'b (head c) h w -> b head c (h w)'

        query, key, value = (
            rearrange(part, split_heads, head=self.num_heads)
            for part in self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1)
        )

        # L2-normalise along the flattened spatial axis before the dot product.
        query = torch.nn.functional.normalize(query, dim=-1)
        key = torch.nn.functional.normalize(key, dim=-1)

        weights = torch.matmul(query, key.transpose(-2, -1)) * self.temperature
        weights = weights.softmax(dim=-1)

        merged = rearrange(torch.matmul(weights, value),
                           'b head c (h w) -> b (head c) h w',
                           head=self.num_heads, h=height, w=width)
        return self.project_out(merged)

class GDFN(nn.Module):
    """Gated-Dconv Feed-Forward Network (Restormer).

    Expands the features, splits them into two halves and gates the second
    half with ``GELU`` of the first before projecting back.

    Args:
        channels: input/output channel count.
        expansion_factor: hidden-width multiplier; hidden size is
            ``int(channels * expansion_factor)``.
    """
    def __init__(self, channels, expansion_factor):
        super().__init__()
        hidden = int(channels * expansion_factor)
        doubled = 2 * hidden
        self.project_in = nn.Conv2d(channels, doubled, kernel_size=1, bias=False)
        self.dwconv = nn.Conv2d(doubled, doubled, kernel_size=3, stride=1,
                                padding=1, groups=doubled, bias=False)
        self.project_out = nn.Conv2d(hidden, channels, kernel_size=1, bias=False)

    def forward(self, x):
        """Map (B, C, H, W) features to the same shape through the gated MLP."""
        gate, content = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(torch.nn.functional.gelu(gate) * content)

class TransformerBlock(nn.Module):
    """Restormer block: pre-norm MDTA then pre-norm GDFN, each with a residual add.

    Args:
        channels: feature channel count C.
        num_heads: attention heads passed to MDTA.
        expansion_factor: hidden-width multiplier passed to GDFN.
    """
    def __init__(self, channels, num_heads, expansion_factor):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        self.attn = MDTA(channels, num_heads)
        self.norm2 = nn.LayerNorm(channels)
        self.ffn = GDFN(channels, expansion_factor)

    @staticmethod
    def _norm_over_channels(norm, x):
        """Apply a LayerNorm over the channel axis of a (B, C, H, W) tensor."""
        b, c, h, w = x.shape
        tokens = norm(x.flatten(2).transpose(1, 2))  # (B, H*W, C)
        return tokens.transpose(1, 2).view(b, c, h, w)

    def forward(self, x):
        """Run both residual sub-layers; input and output are (B, C, H, W)."""
        x = x + self.attn(self._norm_over_channels(self.norm1, x))
        x = x + self.ffn(self._norm_over_channels(self.norm2, x))
        return x

class Restormer(nn.Module):
    """Restormer: a 4-level encoder-decoder Transformer for image restoration.

    The network predicts a residual that is added to the input image.

    Any spatial size is accepted: the three 2x down-sampling stages need H and W
    to be multiples of 8, so ``forward`` pads the input up to the next multiple
    of 8 and crops the prediction back to the original size. Inputs whose sides
    are already multiples of 8 are processed exactly as without padding.

    Args:
        inp_channels: channels of the input image.
        out_channels: channels of the prediction; must be compatible with
            ``inp_channels`` because of the global residual ``+ inp_img``.
        dim: channel width of level 1; levels 2, 3 and 4 use 2x, 4x and 8x.
        num_blocks: TransformerBlocks at levels 1-4 (the decoder reuses the
            counts of levels 1-3).
        num_refinement_blocks: TransformerBlocks in the refinement stage.
        heads: attention heads for levels 1-4.
        ffn_expansion_factor: GDFN hidden-width multiplier.
        bias: bias flag for the 1x1 channel-reduction convs and the output conv.
        LayerNorm_type: accepted for signature compatibility but unused here
            (TransformerBlock always uses ``nn.LayerNorm``).
    """

    # Three 2x down-sampling stages => spatial size must be divisible by 2**3.
    size_multiple = 8

    def __init__(self,
                 inp_channels=3,
                 out_channels=3,
                 dim=32,
                 num_blocks=(4, 6, 6, 8),
                 num_refinement_blocks=4,
                 heads=(1, 2, 4, 8),
                 ffn_expansion_factor=2.66,
                 bias=False,
                 LayerNorm_type='WithBias'):
        super().__init__()

        def blocks(channels, n_heads, depth):
            # A stack of `depth` TransformerBlocks at a fixed width.
            return nn.Sequential(*[
                TransformerBlock(channels=channels, num_heads=n_heads,
                                 expansion_factor=ffn_expansion_factor)
                for _ in range(depth)
            ])

        width2 = int(dim * 2 ** 1)
        width3 = int(dim * 2 ** 2)
        width4 = int(dim * 2 ** 3)

        # Submodules are created in the same order as before so that seeded
        # initialisation and state_dict keys are unchanged.
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = blocks(dim, heads[0], num_blocks[0])
        self.down1_2 = Downsample(dim)  # level 1 -> 2
        self.encoder_level2 = blocks(width2, heads[1], num_blocks[1])
        self.down2_3 = Downsample(width2)  # level 2 -> 3
        self.encoder_level3 = blocks(width3, heads[2], num_blocks[2])
        self.down3_4 = Downsample(width3)  # level 3 -> 4
        self.latent = blocks(width4, heads[3], num_blocks[3])

        self.up4_3 = Upsample(width4)  # level 4 -> 3
        self.reduce_chan_level3 = nn.Conv2d(width4, width3, kernel_size=1, bias=bias)
        self.decoder_level3 = blocks(width3, heads[2], num_blocks[2])

        self.up3_2 = Upsample(width3)  # level 3 -> 2
        self.reduce_chan_level2 = nn.Conv2d(width3, width2, kernel_size=1, bias=bias)
        self.decoder_level2 = blocks(width2, heads[1], num_blocks[1])

        # Level 2 -> 1 has no 1x1 reduction: after concatenating the level-1 skip,
        # the decoder and refinement stages run at 2 * dim channels.
        self.up2_1 = Upsample(width2)
        self.decoder_level1 = blocks(width2, heads[0], num_blocks[0])
        self.refinement = blocks(width2, heads[0], num_refinement_blocks)

        self.output = nn.Conv2d(width2, out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def _pad_to_multiple(self, img):
        """Pad a (B, C, H, W) tensor on the bottom/right so H and W are multiples of ``size_multiple``."""
        height, width = img.shape[-2:]
        pad_h = -height % self.size_multiple
        pad_w = -width % self.size_multiple
        if pad_h == 0 and pad_w == 0:
            return img
        # Reflect padding requires pad < size along each padded axis; fall back otherwise.
        mode = 'reflect' if pad_h < height and pad_w < width else 'replicate'
        return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=mode)

    def forward(self, inp_img):
        """Restore ``inp_img`` (B, C, H, W); returns a tensor of the same spatial size."""
        height, width = inp_img.shape[-2:]
        padded = self._pad_to_multiple(inp_img)

        out_enc_level1 = self.encoder_level1(self.patch_embed(padded))
        out_enc_level2 = self.encoder_level2(self.down1_2(out_enc_level1))
        out_enc_level3 = self.encoder_level3(self.down2_3(out_enc_level2))
        latent = self.latent(self.down3_4(out_enc_level3))

        inp_dec_level3 = torch.cat([self.up4_3(latent), out_enc_level3], 1)
        out_dec_level3 = self.decoder_level3(self.reduce_chan_level3(inp_dec_level3))

        inp_dec_level2 = torch.cat([self.up3_2(out_dec_level3), out_enc_level2], 1)
        out_dec_level2 = self.decoder_level2(self.reduce_chan_level2(inp_dec_level2))

        inp_dec_level1 = torch.cat([self.up2_1(out_dec_level2), out_enc_level1], 1)
        out_dec_level1 = self.refinement(self.decoder_level1(inp_dec_level1))

        # Crop any padding away before the global residual connection.
        return self.output(out_dec_level1)[..., :height, :width] + inp_img

class OverlapPatchEmbed(nn.Module):
    """Overlapping patch embedding: one 3x3, stride-1, padding-1 convolution.

    Args:
        in_c: input image channels.
        embed_dim: output feature channels.
        bias: whether the convolution has a bias term.
    """
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super().__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, 3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        """Embed (B, in_c, H, W) into (B, embed_dim, H, W)."""
        return self.proj(x)

class Downsample(nn.Module):
    """Halve the spatial resolution.

    A 3x3 conv reduces the channels to ``n_feat // 2``; PixelUnshuffle(2) then
    folds each 2x2 neighbourhood into channels, giving ``4 * (n_feat // 2)``
    channels at H/2 x W/2.
    """
    def __init__(self, n_feat):
        super().__init__()
        reduce = nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(reduce, nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)

class Upsample(nn.Module):
    """Double the spatial resolution.

    A 3x3 conv expands the channels to ``2 * n_feat``; PixelShuffle(2) then
    divides them by 4, giving half of ``n_feat`` channels at 2H x 2W.
    """
    def __init__(self, n_feat):
        super().__init__()
        expand = nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(expand, nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)


In [None]:
class ImageRestorationDataset(Dataset):
    """Paired (degraded, clean) images for training, or degraded-only images for testing.

    Args:
        degraded_dir: directory holding the degraded images.
        clean_dir: directory holding the clean targets (not read when
            ``is_train`` is False).
        transform: optional callable applied to each RGB PIL image. In training
            mode the degraded image and its clean target are transformed with the
            *same* random draws, so random crops/flips stay aligned.
        is_train: if True, ``__getitem__`` also loads the matching clean image.

    Items are dicts with 'degraded', 'filename' and, in training mode, 'clean'.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
    # Degraded-file prefix -> clean-file prefix (e.g. 'rain-1.png' -> 'rain_clean-1.png').
    CLEAN_PREFIXES = {'rain-': 'rain_clean-', 'snow-': 'snow_clean-'}

    def __init__(self, degraded_dir, clean_dir, transform=None, is_train=True):
        self.degraded_dir = degraded_dir
        self.clean_dir = clean_dir
        self.transform = transform
        self.is_train = is_train

        # Both modes list files the same way; the case-insensitive match keeps
        # e.g. 'x.PNG' or 'y.JPG' from being silently skipped.
        self.degraded_images = sorted(
            name for name in os.listdir(degraded_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.degraded_images)

    @staticmethod
    def _load_rgb(path):
        """Open an image, convert it to RGB and close the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    @classmethod
    def _clean_name(cls, degraded_name):
        """Map a degraded file name to its clean counterpart (same name if no known prefix)."""
        for prefix, clean_prefix in cls.CLEAN_PREFIXES.items():
            if degraded_name.startswith(prefix):
                # Replace only the prefix, never later occurrences in the name.
                return clean_prefix + degraded_name[len(prefix):]
        return degraded_name

    @staticmethod
    def _paired_transform(transform, degraded, clean):
        """Apply ``transform`` to both inputs with identical draws from torch's global RNG.

        torchvision's random transforms (RandomCrop, RandomHorizontalFlip) sample
        their parameters from torch's global RNG, so replaying the RNG state before
        the second call reproduces the same crop window and flip decision.
        """
        rng_state = torch.get_rng_state()
        degraded = transform(degraded)
        torch.set_rng_state(rng_state)
        clean = transform(clean)
        return degraded, clean

    def __getitem__(self, idx):
        name = self.degraded_images[idx]
        degraded_img = self._load_rgb(os.path.join(self.degraded_dir, name))

        if not self.is_train:
            # Test mode: only the degraded image is available.
            if self.transform:
                degraded_img = self.transform(degraded_img)
            return {
                'degraded': degraded_img,
                'filename': name
            }

        clean_img = self._load_rgb(os.path.join(self.clean_dir, self._clean_name(name)))
        if self.transform:
            degraded_img, clean_img = self._paired_transform(self.transform, degraded_img, clean_img)

        return {
            'degraded': degraded_img,
            'clean': clean_img,
            'filename': name
        }

# Data transforms
def get_transforms(is_train=True):
    """Build the image transform pipeline.

    Args:
        is_train: if True, random-crop to 256x256 and flip horizontally with
            probability 0.5 before tensor conversion; otherwise only convert to
            a tensor.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    steps = []
    if is_train:
        steps += [transforms.RandomCrop(256), transforms.RandomHorizontalFlip(p=0.5)]
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)
# Mount Google Drive (Colab) so the dataset paths below are reachable.
# NOTE(review): `dataclass` is not used anywhere in the visible notebook.
from dataclasses import dataclass
from google.colab import drive
drive.mount('/content/drive')

# Dataset locations on the mounted Drive.
train_degraded_dir = '/content/drive/MyDrive/DL_HW/hw4_realse_dataset/train/degraded'
train_clean_dir = '/content/drive/MyDrive/DL_HW/hw4_realse_dataset/train/clean'
test_degraded_dir = '/content/drive/MyDrive/DL_HW/hw4_realse_dataset/test/degraded'

# Random crop/flip for training; tensor conversion only for testing.
train_transform, test_transform = get_transforms(is_train=True), get_transforms(is_train=False)

# Paired training data; the test split has degraded images only (no clean dir).
train_dataset = ImageRestorationDataset(
    degraded_dir=train_degraded_dir,
    clean_dir=train_clean_dir,
    transform=train_transform,
    is_train=True,
)
test_dataset = ImageRestorationDataset(
    degraded_dir=test_degraded_dir,
    clean_dir=None,
    transform=test_transform,
    is_train=False,
)

print(f'訓練數據集大小: {len(train_dataset)}')
print(f'測試數據集大小: {len(test_dataset)}')

# Mini-batch size for the train/validation loaders built later. Test images are
# not cropped by the test transform, so they are loaded one per batch.
batch_size = 4
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=6)

In [None]:
def calculate_psnr(img1, img2, max_val=1.0):
    """Peak signal-to-noise ratio in dB between two tensors.

    The MSE is taken over all elements, so for a batch this is the PSNR of the
    pooled batch rather than the mean of per-image PSNRs.

    Args:
        img1, img2: tensors of matching (broadcastable) shape.
        max_val: peak signal value (1.0 for images scaled to [0, 1]).

    Returns:
        The PSNR as a Python float, or ``float('inf')`` for identical inputs.
    """
    mse = torch.mean((img1 - img2) ** 2)
    if mse != 0:
        rmse = torch.sqrt(mse)
        return (20 * torch.log10(max_val / rmse)).item()
    return float('inf')

class CharbonnierLoss(nn.Module):
    """Charbonnier (smooth L1) loss: ``mean(sqrt((x - y)^2 + eps^2))``.

    Args:
        eps: smoothing constant; each element contributes at least ``eps``.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        delta = x - y
        eps_sq = self.eps * self.eps
        return (delta * delta + eps_sq).sqrt().mean()

# Build the Restormer (base width 48) and move it onto the selected device.
model = Restormer(
    inp_channels=3, out_channels=3, dim=48,
    num_blocks=[4, 6, 6, 8], num_refinement_blocks=4,
    heads=[1, 2, 4, 8], ffn_expansion_factor=2.66, bias=False,
).to(device)

# Parameter counts: all parameters vs. those that receive gradients.
total_params = trainable_params = 0
for p in model.parameters():
    total_params += p.numel()
    if p.requires_grad:
        trainable_params += p.numel()
print(f'總參數數量: {total_params:,}')
print(f'可訓練參數數量: {trainable_params:,}')

# Loss and optimizer. NOTE: the LR given here is overwritten by the
# WarmupCosineAnnealingLR scheduler created in a later cell.
criterion = CharbonnierLoss()
optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)



In [None]:
# %%
# 確保所有必要的類都已定義
import math
import os
import torch

# WarmupCosineAnnealingLR 類定義
class WarmupCosineAnnealingLR:
    """Epoch-level learning-rate schedule: linear warmup, then cosine annealing.

    With ``e`` = number of ``step()`` calls so far, the prescribed LR is:

    * if ``warmup_epochs > 0`` and ``e <= warmup_epochs``: a linear ramp from
      ``warmup_start_lr`` (at ``e = 0``) to ``base_lr`` (at ``e = warmup_epochs``);
    * otherwise ``eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T)) / 2`` with
      ``T = total_epochs - warmup_epochs`` and ``t = e - warmup_epochs`` clamped to
      ``[0, T]``. The clamp holds the LR at ``eta_min`` after ``total_epochs``
      instead of letting the cosine climb back towards ``base_lr``.

    With ``warmup_epochs == 0`` the schedule starts at ``base_lr``.

    Args:
        optimizer: optimizer whose param-group LRs are written, or None to use
            the object only as an LR calculator.
        warmup_epochs: length of the linear warmup (0 disables it).
        total_epochs: total schedule length in epochs (warmup + cosine).
        warmup_start_lr: LR at the start of warmup.
        base_lr: peak LR reached at the end of warmup.
        eta_min: LR floor at the end of the cosine phase.
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs, warmup_start_lr=1e-6, base_lr=3e-4, eta_min=1e-6):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.warmup_start_lr = warmup_start_lr
        self.base_lr = base_lr
        self.eta_min = eta_min
        self.current_epoch = 0

        # LR for the first epoch (written only when an optimizer is attached).
        self._set_optimizer_lr(self._lr_at(0))

    def _lr_at(self, epoch):
        """LR prescribed after ``epoch`` calls to ``step()``."""
        if self.warmup_epochs > 0 and epoch <= self.warmup_epochs:
            progress = epoch / self.warmup_epochs
            return self.warmup_start_lr + (self.base_lr - self.warmup_start_lr) * progress
        # max(1, ...) avoids a division by zero when total_epochs == warmup_epochs.
        cosine_epochs = max(1, self.total_epochs - self.warmup_epochs)
        t = min(max(epoch - self.warmup_epochs, 0), cosine_epochs)
        return self.eta_min + (self.base_lr - self.eta_min) * (
            1 + math.cos(math.pi * t / cosine_epochs)
        ) / 2

    def _set_optimizer_lr(self, lr):
        """Write ``lr`` into every param group (no-op without an optimizer)."""
        if self.optimizer is not None:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

    def step(self):
        """Advance one epoch, apply the new LR and return it."""
        self.current_epoch += 1
        lr = self._lr_at(self.current_epoch)
        self._set_optimizer_lr(lr)
        return lr

    def get_lr(self):
        """Current LR: read from the optimizer if attached, otherwise computed."""
        if self.optimizer is not None:
            return self.optimizer.param_groups[0]['lr']
        return self._lr_at(self.current_epoch)

    def state_dict(self):
        """Picklable schedule state: every attribute except the optimizer reference."""
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state):
        """Restore a state produced by ``state_dict`` (or a raw ``__dict__`` copy) and re-apply its LR."""
        self.__dict__.update({key: value for key, value in state.items() if key != 'optimizer'})
        self._set_optimizer_lr(self._lr_at(self.current_epoch))

# EarlyStopping
class EarlyStopping:
    """Signal a stop when a monitored score stops improving.

    Call the instance once per epoch with the validation score (and optionally
    the model). It returns True once ``patience`` consecutive calls fail to beat
    the best score by more than ``min_delta``.
    """

    def __init__(self, patience=7, min_delta=0.001, restore_best_weights=True, mode='max'):
        """
        Args:
            patience (int): number of non-improving calls tolerated before stopping.
            min_delta (float): minimum change that counts as an improvement.
            restore_best_weights (bool): when stopping, load the best snapshot back
                into the model passed to ``__call__``.
            mode (str): 'max' to maximise the score (e.g. PSNR); any other value
                minimises it (e.g. loss).
        """
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.mode = mode

        # Internal state.
        self.wait = 0
        self.stopped_epoch = 0  # set to the wait count (not an epoch index) when stopping
        self.best_score = None
        self.best_weights = None
        self.early_stop = False

        # Improvement test for the chosen direction.
        if mode == 'max':
            self.monitor_op = lambda current, best: current > best + min_delta
        else:  # mode == 'min'
            self.monitor_op = lambda current, best: current < best - min_delta

        self.reset()

    def reset(self):
        """Clear the wait counter and stop flag and reset the best score to the worst value."""
        self.wait = 0
        self.early_stop = False
        if self.mode == 'max':
            self.best_score = float('-inf')
        else:
            self.best_score = float('inf')

    @staticmethod
    def _snapshot(model):
        """Copy the model's full state (parameters *and* buffers) to CPU memory."""
        # Keeping the copy on CPU avoids doubling the model's GPU memory footprint.
        return {
            name: tensor.detach().to('cpu', copy=True)
            for name, tensor in model.state_dict().items()
        }

    def __call__(self, score, model=None):
        """Record ``score``; return True if training should stop."""
        if self.monitor_op(score, self.best_score):
            # Improvement.
            self.best_score = score
            self.wait = 0

            if self.restore_best_weights and model is not None:
                self.best_weights = self._snapshot(model)
        else:
            # No improvement.
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = self.wait
                self.early_stop = True

                if self.restore_best_weights and model is not None and self.best_weights is not None:
                    print(f"恢復最佳權重（PSNR: {self.best_score:.2f}dB）")
                    # load_state_dict copies into the existing tensors on their own devices.
                    model.load_state_dict(self.best_weights)

        return self.early_stop

    def get_status(self):
        """Return a summary dict of the current state."""
        return {
            'wait': self.wait,
            'patience': self.patience,
            'best_score': self.best_score,
            'early_stop': self.early_stop
        }



# ---- Learning-rate schedule ------------------------------------------------
# NOTE(review): the training cell runs num_epochs = 30 while this schedule spans
# total_epochs = 10; align the two if the whole run is meant to anneal.
warmup_epochs = 0      # linear-warmup length in epochs (0 = no warmup)
total_epochs = 10      # schedule length in epochs (warmup + cosine)
warmup_start_lr = 1e-6 # LR at the start of warmup
base_lr = 1e-5         # peak LR reached after warmup
eta_min = 1e-6         # LR floor at the end of the cosine phase

# The scheduler constructor writes the initial LR into the optimizer's param
# groups itself, so no separate manual reset is needed.
scheduler = WarmupCosineAnnealingLR(
    optimizer=optimizer,
    warmup_epochs=warmup_epochs,
    total_epochs=total_epochs,
    warmup_start_lr=warmup_start_lr,
    base_lr=base_lr,
    eta_min=eta_min
)

print(f"Warmup 調度器設置完成：")
print(f"  Warmup epochs: {warmup_epochs}")
print(f"  起始學習率: {warmup_start_lr:.2e}")
print(f"  基礎學習率: {base_lr:.2e}")
print(f"  最小學習率: {eta_min:.2e}")
print(f"  當前學習率: {scheduler.get_lr():.2e}")

# ---- Early stopping ---------------------------------------------------------
patience = 10          # epochs without improvement before stopping
min_delta = 0.01       # minimum PSNR gain (dB) that counts as an improvement
early_stopping = EarlyStopping(
    patience=patience,
    min_delta=min_delta,
    restore_best_weights=True,
    mode='max'  # monitor validation PSNR (higher is better)
)

print(f"\nEarly Stopping 設置完成：")
print(f"  Patience: {patience} epochs")
print(f"  最小改善幅度: {min_delta} dB")
print(f"  監控指標: 驗證 PSNR (max mode)")
print(f"  自動恢復最佳權重: 是")

In [None]:
def visualize_lr_schedule(num_epochs=None):
    """Plot and print the learning rate in effect during each epoch.

    The values come from a throw-away ``WarmupCosineAnnealingLR`` built from the
    global schedule settings (``warmup_epochs``, ``total_epochs``,
    ``warmup_start_lr``, ``base_lr``, ``eta_min``) and attached to a dummy
    optimizer, so the plot follows the scheduler class rather than a hand-copied
    formula. The training loop calls ``scheduler.step()`` after each epoch, so
    epoch ``e`` trains with the LR set after ``e - 1`` steps.

    Args:
        num_epochs: number of epochs to show; defaults to ``total_epochs``.
    """
    if num_epochs is None:
        num_epochs = total_epochs
    epochs = list(range(1, num_epochs + 1))

    probe = optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.0)
    simulator = WarmupCosineAnnealingLR(
        optimizer=probe,
        warmup_epochs=warmup_epochs,
        total_epochs=total_epochs,
        warmup_start_lr=warmup_start_lr,
        base_lr=base_lr,
        eta_min=eta_min,
    )
    lrs = []
    for _ in epochs:
        lrs.append(simulator.get_lr())  # LR used while this epoch trains
        simulator.step()

    plt.figure(figsize=(10, 6))
    plt.plot(epochs, lrs, 'b-', linewidth=2, marker='o', markersize=4)
    if warmup_epochs > 0:
        # Only meaningful with a warmup phase (otherwise it marked "Epoch 0").
        plt.axvline(x=warmup_epochs, color='red', linestyle='--', alpha=0.7,
                    label=f'Warmup 結束 (Epoch {warmup_epochs})')
        plt.legend()
    plt.title('Warmup + Cosine Annealing 學習率調度')
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.grid(True, alpha=0.3)
    plt.show()

    print("前10個epoch的學習率變化：")
    for epoch, lr in zip(epochs[:10], lrs[:10]):
        status = "Warmup" if epoch <= warmup_epochs else "Cosine"
        print(f"  Epoch {epoch}: {lr:.2e} ({status})")


visualize_lr_schedule()

In [6]:
# %%

def train_epoch_stable(model, train_loader, criterion, optimizer, device, epoch, scaler=None):
    """Train ``model`` for one epoch with mixed precision and gradient clipping.

    Batches whose loss is NaN/Inf or above 10.0 are skipped before the backward
    pass. Batches whose PSNR falls outside [0, 100] dB are left out of the
    running averages (their optimizer step has already been taken).

    Args:
        model: network to train (switched to ``train()`` mode).
        train_loader: yields dicts with 'degraded' and 'clean' tensors.
        criterion: loss called as ``criterion(restored, clean)``.
        optimizer: optimizer stepped once per accepted batch.
        device: device the batches are moved to; AMP is enabled only on CUDA.
        epoch: epoch number, used in the progress-bar label.
        scaler: ``GradScaler`` to reuse across epochs. If None, a fresh one is
            created for this call, which restarts the dynamic loss scale.

    Returns:
        ``(avg_loss, avg_psnr)`` averaged over the batches that were counted
        (skipped batches no longer pull the averages towards zero), or
        ``(0.0, 0.0)`` if none were.
    """
    model.train()
    device_type = torch.device(device).type
    use_amp = device_type == 'cuda'

    if scaler is None:
        from torch.cuda.amp import GradScaler
        scaler = GradScaler(enabled=use_amp)

    total_loss = 0.0
    total_psnr = 0.0
    counted_batches = 0

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch}')

    for batch_idx, batch in enumerate(progress_bar):
        degraded = batch['degraded'].to(device, non_blocking=True)
        clean = batch['clean'].to(device, non_blocking=True)

        optimizer.zero_grad()

        with torch.autocast(device_type=device_type, enabled=use_amp):
            restored = model(degraded)
            loss = criterion(restored, clean)

        loss_value = loss.item()
        # Reject non-finite or implausibly large losses before they reach the weights.
        if not math.isfinite(loss_value) or loss_value > 10.0:
            print(f"警告：檢測到異常loss值 {loss_value:.4f}，跳過此批次")
            continue

        scaler.scale(loss).backward()

        # Unscale first so clipping sees the true gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()

        with torch.no_grad():
            psnr = calculate_psnr(restored, clean)

        if psnr < 0 or psnr > 100:
            print(f"警告：異常PSNR值 {psnr:.2f}dB")
            continue

        total_loss += loss_value
        total_psnr += psnr
        counted_batches += 1

        progress_bar.set_postfix({
            'Loss': f'{loss_value:.4f}',
            'PSNR': f'{psnr:.2f}dB',
            'Scale': f'{scaler.get_scale():.0f}'
        })

        if batch_idx % 100 == 0:
            torch.cuda.empty_cache()

    denominator = max(counted_batches, 1)
    return total_loss / denominator, total_psnr / denominator

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: network to evaluate (switched to ``eval()`` mode).
        val_loader: yields dicts with 'degraded' and 'clean' tensors.
        criterion: loss called as ``criterion(restored, clean)``.
        device: device the batches are moved to.

    Returns:
        ``(avg_loss, avg_psnr)``. The loss is weighted by batch size, so a short
        final batch is not over-weighted. PSNR is averaged over individual images
        and measured on the prediction clamped to [0, 1], the range that
        ``calculate_psnr``'s default ``max_val=1.0`` assumes. Both are 0.0 for an
        empty loader.
    """
    model.eval()
    total_loss = 0.0
    total_psnr = 0.0
    num_images = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='Validation'):
            degraded = batch['degraded'].to(device)
            clean = batch['clean'].to(device)

            restored = model(degraded)
            batch_size = degraded.size(0)
            total_loss += criterion(restored, clean).item() * batch_size

            restored = restored.clamp(0.0, 1.0)
            for prediction, target in zip(restored, clean):
                total_psnr += calculate_psnr(prediction, target)
            num_images += batch_size

    if num_images == 0:
        return 0.0, 0.0
    return total_loss / num_images, total_psnr / num_images


In [None]:
from torch.utils.data import Subset, random_split

# 80/20 train/validation split. A fixed generator makes the split identical across
# sessions, so a resumed run never validates on images it trained on earlier.
# NOTE(review): this split differs from one drawn from the global RNG; checkpoints
# trained on an older split should not be resumed against it.
split_generator = torch.Generator().manual_seed(42)
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_subset, val_subset = random_split(train_dataset, [train_size, val_size],
                                        generator=split_generator)

# Validation must be deterministic: reuse the held-out indices, but with a centre
# crop instead of the training set's random crop + flip.
val_dataset = ImageRestorationDataset(
    train_degraded_dir, train_clean_dir,
    transform=transforms.Compose([transforms.CenterCrop(256), transforms.ToTensor()]),
    is_train=True,
)
val_subset = Subset(val_dataset, val_subset.indices)

# More workers than CPU cores only adds overhead (Colab VMs have few cores).
loader_workers = min(10, os.cpu_count() or 1)
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=loader_workers)
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=loader_workers)

print(f'訓練集大小: {len(train_subset)}')
print(f'驗證集大小: {len(val_subset)}')



# ---- Resume settings ----------------------------------------------------------
RESUME_TRAINING = True  # True: continue from MODEL_PATH when it exists
MODEL_PATH = 'best_restormer_model.pth'
# Non-interactive replacement for the former input() prompt, which blocked "Run All".
# When the checkpoint came from an early stop: True resets the early-stopping
# counter (the old "y" answer); False keeps the current in-memory state ("n").
RESET_EARLY_STOPPING = True

# Training-run state (overwritten below when a checkpoint is resumed).
num_epochs = 30
start_epoch = 1
best_psnr = 0.0
train_losses = []
train_psnrs = []
val_losses = []
val_psnrs = []

if RESUME_TRAINING and os.path.exists(MODEL_PATH):
    try:
        print(f"嘗試載入現有模型進行接續訓練...")
        # Trusted checkpoint written by this notebook. Older checkpoints pickle the
        # scheduler's __dict__ (including the optimizer object), which torch.load
        # rejects under weights_only=True (the default since PyTorch 2.6). Such a
        # failure would be caught below and training would silently restart.
        checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)

        model.load_state_dict(checkpoint['model_state_dict'])

        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print("優化器狀態已載入")

        # Continue from the epoch after the one stored in the checkpoint.
        start_epoch = checkpoint.get('epoch', 1) + 1
        best_psnr = checkpoint.get('best_psnr', 0.0)

        if 'train_losses' in checkpoint:
            train_losses = checkpoint['train_losses']
            train_psnrs = checkpoint['train_psnrs']
            val_losses = checkpoint['val_losses']
            val_psnrs = checkpoint['val_psnrs']
            print("歷史記錄已載入")

        if checkpoint.get('early_stopped', False):
            print(f"  注意：這是一個 Early Stopping 的模型")
            print(f"   原停止 epoch: {checkpoint.get('stopped_epoch', 'N/A')}")

            if RESET_EARLY_STOPPING:
                early_stopping.reset()
                print(" Early Stopping 狀態已重置")
            else:
                print("  保持原有 Early Stopping 狀態")

        print(f" 接續訓練設置完成:")
        print(f"  - 起始 epoch: {start_epoch}")
        print(f"  - 目標 epoch: {num_epochs}")
        print(f"  - 當前最佳 PSNR: {best_psnr:.2f}dB")
        print(f"  - 已有訓練記錄: {len(train_losses)} epochs")

        # Replay the schedule from epoch 0 so that re-running this cell does not
        # advance an already-stepped scheduler a second time.
        scheduler.current_epoch = 0
        for _ in range(start_epoch - 1):
            scheduler.step()
        print(f"  - 當前學習率: {scheduler.get_lr():.6f}")

    except Exception as e:
        print(f" 載入模型失敗: {e}")
        print(f" 將從頭開始訓練")
        start_epoch = 1
        best_psnr = 0.0
        train_losses = []
        train_psnrs = []
        val_losses = []
        val_psnrs = []

else:
    if RESUME_TRAINING:
        print(f"ℹ  未找到現有模型，將從頭開始訓練")
    else:
        print(f" 從頭開始訓練")

# One-epoch training function used by the loop (AMP + gradient clipping).
train_epoch = train_epoch_stable

print(f"\n開始訓練...")
print(f"使用訓練函數: {train_epoch.__name__}")
if start_epoch > 1:
    print(f"接續從 Epoch {start_epoch} 開始")

# A single GradScaler for the whole run. Letting each epoch create its own (the
# default when none is passed) resets the dynamic loss scale every epoch.
from torch.cuda.amp import GradScaler
scaler = GradScaler(enabled=(device.type == 'cuda'))


def _training_checkpoint(epoch, best_score, early_stopping_state, **extra):
    """Assemble the checkpoint dict written by this loop; ``extra`` keys are merged in."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Plain schedule values only: the raw scheduler.__dict__ also holds the
        # optimizer object, which bloats the file and blocks weights_only loading.
        'scheduler_state_dict': {k: v for k, v in scheduler.__dict__.items() if k != 'optimizer'},
        'best_psnr': best_score,
        'train_losses': train_losses,
        'train_psnrs': train_psnrs,
        'val_losses': val_losses,
        'val_psnrs': val_psnrs,
        'early_stopping_state': early_stopping_state,
        'resume_info': {
            'total_epochs_trained': len(train_losses),
            'original_start_epoch': start_epoch,
            'training_resumed': start_epoch > 1
        }
    }
    checkpoint.update(extra)
    return checkpoint


for epoch in range(start_epoch, num_epochs + 1):
    print(f"\n=== Epoch {epoch}/{num_epochs} ===")

    train_loss, train_psnr = train_epoch(model, train_loader, criterion, optimizer, device, epoch, scaler=scaler)
    val_loss, val_psnr = validate(model, val_loader, criterion, device)

    # The schedule advances once per epoch, after training and validation.
    scheduler.step()

    train_losses.append(train_loss)
    train_psnrs.append(train_psnr)
    val_losses.append(val_loss)
    val_psnrs.append(val_psnr)

    print(f'訓練 - Loss: {train_loss:.4f}, PSNR: {train_psnr:.2f}dB')
    print(f'驗證 - Loss: {val_loss:.4f}, PSNR: {val_psnr:.2f}dB')

    current_lr = scheduler.get_lr()
    if epoch <= warmup_epochs:
        warmup_progress = epoch / warmup_epochs * 100
        print(f'學習率: {current_lr:.6f} (Warmup 階段: {warmup_progress:.1f}%)')
    else:
        print(f'學習率: {current_lr:.6f} (Cosine Annealing 階段)')

    # Save the best model (with full training history) whenever validation PSNR improves.
    if val_psnr > best_psnr:
        best_psnr = val_psnr
        torch.save(_training_checkpoint(epoch, best_psnr, early_stopping.get_status()), MODEL_PATH)
        print(f"新的最佳模型已保存！PSNR: {best_psnr:.2f}dB")

    if early_stopping(val_psnr, model):
        status = early_stopping.get_status()
        print(f"\n Early Stopping 觸發！")
        print(f"  停止 epoch: {epoch}")
        print(f"  最佳 PSNR: {status['best_score']:.2f}dB")
        print(f"  已等待: {status['wait']}/{status['patience']} epochs")
        print(f"  最佳權重已自動恢復")

        # Early stopping only counts gains larger than min_delta, whereas the
        # checkpoint above is refreshed on any gain, so the weights on disk can be
        # newer and better than the ones early stopping just restored. Reload them
        # instead of overwriting the best checkpoint with weaker weights.
        best_on_disk = os.path.exists(MODEL_PATH)
        if best_on_disk:
            # Trusted file written by this notebook (may predate the picklable format).
            saved = torch.load(MODEL_PATH, map_location=device, weights_only=False)
            model.load_state_dict(saved['model_state_dict'])
        torch.save(
            _training_checkpoint(
                epoch,
                best_psnr if best_on_disk else status['best_score'],
                status,
                early_stopped=True,
                stopped_epoch=epoch,
            ),
            MODEL_PATH,
        )
        break

    status = early_stopping.get_status()
    if status['wait'] > 0:
        print(f" Early Stopping: {status['wait']}/{patience} epochs (最佳 PSNR: {status['best_score']:.2f}dB)")
    else:
        print(f" 驗證指標改善！(目前最佳 PSNR: {status['best_score']:.2f}dB)")

# Summary of the finished (or early-stopped) run.
total_epochs_trained = len(train_losses)

if not early_stopping.early_stop:
    print(f"\n 訓練正常完成")
    print(f"總訓練 epochs: {num_epochs}")
    print(f"最佳驗證 PSNR: {best_psnr:.2f}dB")
else:
    print(f"\n 訓練因 Early Stopping 提前結束")
    print(f"實際訓練 epochs: {epoch}")
    print(f"最佳驗證 PSNR: {early_stopping.best_score:.2f}dB")

if start_epoch > 1:
    epochs_before_resume = start_epoch - 1
    print(f" 接續訓練統計:")
    print(f"  - 原有訓練: {epochs_before_resume} epochs")
    print(f"  - 本次訓練: {total_epochs_trained - epochs_before_resume} epochs")
    print(f"  - 總計訓練: {total_epochs_trained} epochs")

In [None]:
# Training curves: loss, PSNR and learning rate against 1-based epoch numbers.
actual_epochs = len(train_losses)  # epochs recorded in the history
# x = 1..N so points and the early-stop marker line up with the "Epoch N" logs.
epoch_axis = list(range(1, actual_epochs + 1))


def _mark_early_stop(ax, color):
    """Draw the early-stop marker at the last recorded epoch if early stopping fired."""
    if early_stopping.early_stop:
        ax.axvline(x=actual_epochs, color=color, linestyle='--', alpha=0.7,
                   label=f'Early Stop (Epoch {actual_epochs})')


fig, (ax_loss, ax_psnr, ax_lr) = plt.subplots(1, 3, figsize=(15, 5))

ax_loss.plot(epoch_axis, train_losses, label='train loss')
ax_loss.plot(epoch_axis, val_losses, label='val loss')
_mark_early_stop(ax_loss, 'red')
ax_loss.set_title('train&val loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()
ax_loss.grid(True)

ax_psnr.plot(epoch_axis, train_psnrs, label='train PSNR')
ax_psnr.plot(epoch_axis, val_psnrs, label='val PSNR')
_mark_early_stop(ax_psnr, 'red')
ax_psnr.set_title('train&val PSNR')
ax_psnr.set_xlabel('Epoch')
ax_psnr.set_ylabel('PSNR (dB)')
ax_psnr.legend()
ax_psnr.grid(True)

# LR in effect during each epoch, replayed through the scheduler class on a dummy
# optimizer instead of a hand-copied formula. The loop steps the scheduler after
# each epoch, so epoch e trains with the LR set after e - 1 steps.
lr_probe = optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.0)
lr_replay = WarmupCosineAnnealingLR(
    optimizer=lr_probe,
    warmup_epochs=warmup_epochs,
    total_epochs=total_epochs,
    warmup_start_lr=warmup_start_lr,
    base_lr=base_lr,
    eta_min=eta_min,
)
learning_rates = []
for _ in epoch_axis:
    learning_rates.append(lr_replay.get_lr())
    lr_replay.step()

ax_lr.plot(epoch_axis, learning_rates, 'b-', linewidth=2)
if warmup_epochs > 0:
    ax_lr.axvline(x=warmup_epochs, color='red', linestyle='--', alpha=0.7,
                  label=f'Warmup 結束 (Epoch {warmup_epochs})')
_mark_early_stop(ax_lr, 'orange')
ax_lr.set_title('學習率變化 (Warmup + Cosine Annealing)')
ax_lr.set_xlabel('Epoch')
ax_lr.set_ylabel('Learning Rate')
if ax_lr.get_legend_handles_labels()[0]:
    ax_lr.legend()
ax_lr.grid(True)

fig.tight_layout()
plt.show()

print(f"\n 訓練統計信息：")
print(f"  實際訓練 epochs: {actual_epochs}")
if val_psnrs:
    print(f"  最佳驗證 PSNR: {max(val_psnrs):.2f}dB (Epoch {np.argmax(val_psnrs) + 1})")
    print(f"  最終驗證 PSNR: {val_psnrs[-1]:.2f}dB")
if early_stopping.early_stop:
    print(f"  Early Stopping: 是 (patience={patience})")
    print(f"  最佳權重已恢復: 是")
else:
    print(f"  Early Stopping: 否")

In [None]:

import os
import torch

MODEL_PATH = 'best_restormer_model.pth'

# True only once weights from MODEL_PATH have actually been loaded into
# `model`. The test cells read this flag to report which weights they
# evaluated. It is reset on every run of this cell, so re-running stays consistent.
MODEL_LOADED = False


def _format_psnr(value):
    """Return `value` formatted with two decimals, or as plain text if not numeric.

    Checkpoints without a 'best_psnr' entry yield the 'N/A' placeholder.
    Formatting a str with ':.2f' raises ValueError, which would otherwise push
    an already-loaded checkpoint into the fallback path below.
    """
    try:
        return f"{value:.2f}"
    except (TypeError, ValueError):
        return str(value)


if os.path.exists(MODEL_PATH):
    print(f" 找到已存在的模型: {MODEL_PATH}")

    try:
        # Option 1: load the full checkpoint dict.
        # SECURITY: weights_only=False unpickles arbitrary objects and can run
        # code; only use it for checkpoints produced by this notebook or
        # otherwise trusted sources.
        print(" 嘗試載入完整檢查點...")
        checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)

        model.load_state_dict(checkpoint['model_state_dict'])
        MODEL_LOADED = True

        # Report checkpoint metadata (missing keys fall back to 'N/A').
        print(f" 模型載入成功！")
        print(f"  - 最佳 PSNR: {_format_psnr(checkpoint.get('best_psnr', 'N/A'))}dB")
        print(f"  - 訓練 Epoch: {checkpoint.get('epoch', 'N/A')}")

        if checkpoint.get('early_stopped', False):
            print(f"  - Early Stopping: 是 (停止於 Epoch {checkpoint.get('stopped_epoch', 'N/A')})")
        else:
            print(f"  - Early Stopping: 否")

        # Mention the stored training history, if any.
        if 'train_losses' in checkpoint:
            print(f"  - 訓練歷史: {len(checkpoint['train_losses'])} epochs")

    except Exception as e:
        print(f" 載入失敗: {e}")

        # Option 2: restricted, tensor-only load (weights_only=True). Accepts
        # either a wrapped {'model_state_dict': ...} checkpoint or a bare state_dict.
        try:
            print(" 嘗試只載入模型權重...")
            state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)

            if 'model_state_dict' in state_dict:
                model.load_state_dict(state_dict['model_state_dict'])
            else:
                model.load_state_dict(state_dict)
            MODEL_LOADED = True

            print(" 模型權重載入成功！")
            print("  注意：只載入了模型權重，無法獲取訓練歷史信息")

        except Exception as e2:
            print(f" 權重載入也失敗: {e2}")
            # NOTE(review): a load_state_dict call that raised may already have
            # copied some tensors, so "freshly initialised" is best-effort.
            print(" 將使用新初始化的模型")

else:
    print(f"ℹ  未找到模型檔案: {MODEL_PATH}")
    print(" 將使用新初始化的模型")

print("\n模型準備完成，可以開始推論或繼續訓練")

def test_and_save_results(model, test_loader, device, save_path='pred.npz'):
    """測試模型並保存結果為 npz 格式"""
    model.eval()
    results = {}

    print("開始測試...")
    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Testing'):
            degraded = batch['degraded'].to(device)
            filename = batch['filename'][0]  # batch_size=1 for test

            # 模型推理
            restored = model(degraded)

            # 轉換為 numpy 陣列
            restored_np = restored.cpu().squeeze(0).numpy()  # (3, H, W)

            # 確保值在 [0, 1] 範圍內
            restored_np = np.clip(restored_np, 0, 1)

            # 轉換為 uint8 格式 (0-255)
            restored_np = (restored_np * 255).astype(np.uint8)

            # 使用原始檔名作為鍵值
            results[filename] = restored_np

    # 保存結果
    np.savez_compressed(save_path, **results)
    print(f"測試結果已保存到 {save_path}")
    print(f"共處理 {len(results)} 張圖像")

    return results

# Run the model on the test set and write pred.npz.
# MODEL_LOADED is set by the checkpoint-loading cell. Default to False so a
# fresh kernel (or a notebook where that flag was never set) does not raise
# NameError here.
if globals().get('MODEL_LOADED', False):
    print(f" 使用已載入的預訓練模型進行測試")
else:
    print(f" 使用當前模型進行測試")

test_results = test_and_save_results(model, test_loader, device, 'pred.npz')

# 顯示一些測試結果
def show_test_samples(results, test_dataset, num_samples=4):
    """Plot degraded inputs (top row) above their restored outputs (bottom row).

    Args:
        results: dict filename -> uint8 image array laid out channels-first
            (as produced by `test_and_save_results`). The first `num_samples`
            keys are shown.
        test_dataset: indexable dataset supplying the degraded inputs. Each
            item is assumed to be a dict with 'filename' (str) and 'degraded'
            (channels-first image, tensor or array), i.e. the items the test
            DataLoader collates. TODO confirm if a custom collate_fn is used.
        num_samples: maximum number of columns to draw.
    """
    sample_keys = list(results.keys())[:num_samples]

    if len(sample_keys) == 0:
        print(" 沒有測試結果可以顯示")
        return

    # Make one pass over the dataset, fetching only the inputs we need. The
    # old version ignored `test_dataset` and rescanned the global test_loader
    # once per sample.
    wanted = set(sample_keys)
    degraded_by_name = {}
    for idx in range(len(test_dataset)):
        item = test_dataset[idx]
        name = item['filename']
        if name in wanted and name not in degraded_by_name:
            image = item['degraded']
            if torch.is_tensor(image):
                image = image.cpu().numpy()
            degraded_by_name[name] = np.asarray(image)
            if len(degraded_by_name) == len(wanted):
                break

    n_cols = len(sample_keys)
    fig_size = (8, 8) if n_cols == 1 else (4 * n_cols, 8)
    # squeeze=False always returns a 2-D grid, so axes[row, col] also works
    # for a single column.
    fig, axes = plt.subplots(2, n_cols, figsize=fig_size, squeeze=False)

    for col, key in enumerate(sample_keys):
        degraded = degraded_by_name.get(key)
        if degraded is None:
            print(f"  無法找到原始圖像: {key}")
            continue

        # Restored images were stored as uint8; map back to [0, 1] for imshow.
        restored = results[key] / 255.0

        # Titles are in English: matplotlib's default font has no CJK glyphs
        # (unless one is configured), so Chinese titles render as boxes.
        axes[0, col].imshow(np.transpose(degraded, (1, 2, 0)))
        axes[0, col].set_title(f'Degraded: {key}')
        axes[0, col].axis('off')

        axes[1, col].imshow(np.transpose(restored, (1, 2, 0)))
        axes[1, col].set_title(f'Restored: {key}')
        axes[1, col].axis('off')

    plt.tight_layout()
    plt.show()

show_test_samples(test_results, test_dataset)

# Final summary. MODEL_LOADED comes from the checkpoint-loading cell; default
# to False so this cell doesn't raise NameError if the flag was never set.
print(f"\n 測試完成總結:")
print(f"  - 使用模型: {'預訓練模型' if globals().get('MODEL_LOADED', False) else '當前模型'}")
print(f"  - 測試圖像數量: {len(test_results)}")
print(f"  - 輸出檔案: pred.npz")