## ResNet34

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import os

In [2]:
from torch.utils.data import Dataset, DataLoader
class ImageNetMiniDataset(Dataset):
    """Image-classification dataset backed by a plain-text index file.

    Each non-blank line of ``txt_file`` holds a relative image path followed
    by an integer class label, separated by whitespace, e.g.
    ``train/n01440764/img_0001.JPEG 0``. Paths are resolved against
    ``root_dir`` when an item is loaded.
    """

    def __init__(self, txt_file, root_dir, transform=None):
        """
        Args:
            txt_file: path to the index ``.txt`` file (e.g. ``train.txt``).
            root_dir: directory the image paths in ``txt_file`` are relative
                to (e.g. ``.``).
            transform: optional callable applied to each loaded PIL image.

        Raises:
            ValueError: if a non-blank line is not ``<image_path> <int label>``.
        """
        self.root_dir = root_dir
        self.transform = transform
        # (relative_image_path, label) pairs, one per non-blank index line.
        self.data = []
        with open(txt_file, 'r', encoding='utf-8') as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                try:
                    # Split on the LAST whitespace run so that image paths
                    # containing spaces are kept intact.
                    image_path, label = line.rsplit(maxsplit=1)
                    self.data.append((image_path, int(label)))
                except ValueError as err:
                    raise ValueError(
                        f"{txt_file}:{line_no}: expected '<image_path> <label>', got {line!r}"
                    ) from err

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for index ``idx``.

        The image is read from ``root_dir``/<relative path>, converted to RGB
        and passed through ``self.transform`` when one is set.
        """
        rel_path, label = self.data[idx]
        image_path = os.path.join(self.root_dir, rel_path)
        # The context manager releases the underlying file handle right away;
        # convert() returns a new, fully loaded image that outlives it.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

In [3]:
import torch
from torchvision import transforms
# Train on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing shared by every split: resize to 224x224, convert to a
# tensor, then normalise with the ImageNet channel mean / std.
# The pipeline is deterministic (no augmentation).
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
])

# Datasets: one index file per split, all resolved against the same root.
extract_path = "dataset"
train_dataset, val_dataset, test_dataset = (
    ImageNetMiniDataset(txt_file=index_file, root_dir=extract_path, transform=transform)
    for index_file in ("dataset/train.txt", "dataset/val.txt", "dataset/test.txt")
)

# Loaders: only the training split is shuffled.
# NOTE(review): with num_workers > 0 under a "spawn" start method
# (Windows / macOS), worker processes must be able to unpickle
# ImageNetMiniDataset; classes defined inside a notebook may fail there.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader, test_loader = (
    DataLoader(split, batch_size=32, shuffle=False, num_workers=4)
    for split in (val_dataset, test_dataset)
)


In [23]:
import torch
import torch.nn as nn
import torchvision.models as models

# 定義 ResNet34 對照組模型（不修改任何結構）
class ResNet34Baseline(nn.Module):
    """Unmodified torchvision ResNet-34 trained from scratch (control model).

    Args:
        num_classes: number of output logits. The stock ResNet-34 head has
            1000 outputs. For any other value, the final ``fc`` layer is
            replaced with a freshly initialised ``nn.Linear`` of that width.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # weights=None -> random initialisation, no pretrained checkpoint.
        # (The older ``pretrained=False`` spelling is deprecated in torchvision.)
        self.model = models.resnet34(weights=None)
        # The default head already has 1000 outputs; only swap it when needed.
        if num_classes != 1000:
            self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Return the class logits produced by the wrapped ResNet-34."""
        return self.model(x)

# 測試模型
def test_model(num_classes=1000, batch_size=4):
    """Smoke-test ResNet34Baseline on a random batch and report its size.

    Runs one forward pass on a random ``(batch_size, 3, 224, 224)`` tensor.
    It prints the output shape, expected to be ``(batch_size, num_classes)``,
    and the total parameter count.

    Returns:
        ``(output_shape, total_params)``.
    """
    model = ResNet34Baseline(num_classes=num_classes)
    # A shape check needs neither autograd bookkeeping nor BatchNorm
    # running-statistic updates, so run in eval mode without gradients.
    model.eval()
    x = torch.randn(batch_size, 3, 224, 224)
    with torch.no_grad():
        output = model(x)
    print(f"輸出形狀：{output.shape}")
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型總參數量：{total_params}")
    return output.shape, total_params

In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import os
from torchsummary import summary

def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return (mean batch loss, accuracy %)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("training loader yielded no samples")
    return running_loss / len(loader), 100 * correct / total


def _evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` on ``loader``, without gradients."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluation loader yielded no samples")
    return 100 * correct / total


def main(num_classes=50, num_epochs=10, lr=0.001, complexity_input=(3, 84, 84)):
    """Train, validate and test the ResNet34 baseline, then report its size.

    Uses the module-level ``device``, ``train_loader``, ``val_loader`` and
    ``test_loader``. Each epoch prints the mean training loss, the training
    accuracy and the validation accuracy. The model from the final epoch is
    evaluated on the test split. Finally, a torchsummary table is printed,
    plus the ptflops complexity when ptflops is installed.

    Args:
        num_classes: number of output classes.
        num_epochs: number of training epochs.
        lr: Adam learning rate.
        complexity_input: ``(C, H, W)`` input used for summary / complexity.
            NOTE(review): training images are resized to 224x224, so figures
            computed at the default 84x84 do not describe the training-time
            cost. Confirm this size is intended.
    """
    model = ResNet34Baseline(num_classes=num_classes).to(device)
    print(f"模型總參數量：{sum(p.numel() for p in model.parameters())}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        epoch_loss, epoch_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Epoch [{epoch+1}/{num_epochs}] 訓練損失：{epoch_loss:.4f} 訓練準確率：{epoch_acc:.2f}%")

        val_acc = _evaluate_accuracy(model, val_loader, device)
        print(f"驗證準確率：{val_acc:.2f}%")

    test_acc = _evaluate_accuracy(model, test_loader, device)
    print(f"測試準確率：{test_acc:.2f}%")

    # Pass the device explicitly instead of relying on torchsummary's default.
    summary(model, complexity_input, device=device.type)

    # ptflops is optional; only the import itself is guarded.
    try:
        from ptflops import get_model_complexity_info
    except ImportError:
        print("Please install ptflops to compute FLOPS: pip install ptflops")
    else:
        flops, params = get_model_complexity_info(model, complexity_input, as_strings=True, print_per_layer_stat=False)
        print(f"Computational complexity: {flops}")
        print(f"Number of parameters: {params}")


if __name__ == "__main__":
    main()

模型總參數量：21310322
Epoch [1/10] 訓練損失：3.2980 訓練準確率：12.66%
驗證準確率：17.56%
Epoch [2/10] 訓練損失：2.6552 訓練準確率：25.41%
驗證準確率：28.44%
Epoch [3/10] 訓練損失：2.1607 訓練準確率：37.02%
驗證準確率：38.22%
Epoch [4/10] 訓練損失：1.7826 訓練準確率：46.44%
驗證準確率：45.56%
Epoch [5/10] 訓練損失：1.4888 訓練準確率：54.41%
驗證準確率：53.11%
Epoch [6/10] 訓練損失：1.2480 訓練準確率：60.93%
驗證準確率：54.00%
Epoch [7/10] 訓練損失：1.0300 訓練準確率：67.11%
驗證準確率：57.33%
Epoch [8/10] 訓練損失：0.8115 訓練準確率：73.26%
驗證準確率：59.78%
Epoch [9/10] 訓練損失：0.6140 訓練準確率：79.22%
驗證準確率：64.00%
Epoch [10/10] 訓練損失：0.4525 訓練準確率：84.46%
驗證準確率：61.56%
測試準確率：60.44%
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 42, 42]           9,408
       BatchNorm2d-2           [-1, 64, 42, 42]             128
              ReLU-3           [-1, 64, 42, 42]               0
         MaxPool2d-4           [-1, 64, 21, 21]               0
            Conv2d-5           [-1, 64, 21, 21]          36,864
       Batc

## MyNet: SimpleResNet

In [25]:
import torch
import torch.nn as nn

# 定義基本的殘差塊
class ResidualBlock(nn.Module):
    """Basic residual block: two 3x3 conv-BN layers plus an additive skip path.

    The first convolution may downsample via ``stride``. If ``stride`` is
    not 1 or the channel count changes, the skip path is a 1x1 conv + BN
    projection. Otherwise it is an empty ``nn.Sequential``, i.e. the identity.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path, layer 1: may change the spatial size through ``stride``.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # Main path, layer 2: keeps the spatial size.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip path: project only when the shapes of the two paths differ.
        needs_projection = stride != 1 or in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        skip = self.shortcut(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += skip
        return self.relu(y)

# 定義簡化的 ResNet 模型
class SimpleResNet(nn.Module):
    """Small ResNet: conv stem + max-pool, two residual blocks, global pool, linear head.

    The spatial sizes noted in ``forward`` assume a 224x224 RGB input.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # Stem: 7x7 conv with 64 filters and stride 2, BN, ReLU, then 3x3/2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stage: 64 -> 64 at stride 1, then 64 -> 128 at stride 2 (downsample).
        self.block1 = ResidualBlock(64, 64, stride=1)
        self.block2 = ResidualBlock(64, 128, stride=2)

        # Head: global average pool to 1x1, then a 128 -> num_classes classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        # (N, 3, 224, 224) -> conv (N, 64, 112, 112) -> pool (N, 64, 56, 56)
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        # -> block1 (N, 64, 56, 56) -> block2 (N, 128, 28, 28)
        features = self.block2(self.block1(stem))
        # -> (N, 128, 1, 1) -> flattened (N, 128)
        pooled = self.avgpool(features).flatten(1)
        # -> (N, num_classes)
        return self.fc(pooled)


In [26]:
# 測試模型
def test_model(num_classes=50, batch_size=4):
    """Smoke-test SimpleResNet on a random batch and report its size.

    Runs one forward pass on a random ``(batch_size, 3, 224, 224)`` tensor.
    It prints the output shape and the total parameter count.

    Returns:
        ``(output_shape, total_params)``.
    """
    model = SimpleResNet(num_classes=num_classes)
    # A shape check needs neither autograd bookkeeping nor BatchNorm
    # running-statistic updates, so run in eval mode without gradients.
    model.eval()
    x = torch.randn(batch_size, 3, 224, 224)
    with torch.no_grad():
        output = model(x)
    # Expected shape: (batch_size, num_classes), i.e. (4, 50) with the defaults.
    print(f"輸出形狀：{output.shape}")
    total_params = sum(p.numel() for p in model.parameters())
    print(f"模型總參數量：{total_params}")
    return output.shape, total_params


if __name__ == "__main__":
    test_model()

輸出形狀：torch.Size([4, 50])
模型總參數量：320114


In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import os
from torchsummary import summary


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return (mean batch loss, accuracy %)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("training loader yielded no samples")
    return running_loss / len(loader), 100 * correct / total


def _evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` on ``loader``, without gradients."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluation loader yielded no samples")
    return 100 * correct / total


def main(num_classes=50, num_epochs=20, lr=0.001, complexity_input=(3, 84, 84)):
    """Train, validate and test SimpleResNet, then report its size.

    Uses the module-level ``device``, ``train_loader``, ``val_loader`` and
    ``test_loader``. Each epoch prints the mean training loss, the training
    accuracy and the validation accuracy. The model from the final epoch is
    evaluated on the test split. Finally, a torchsummary table is printed,
    plus the ptflops complexity when ptflops is installed.

    Args:
        num_classes: number of output classes.
        num_epochs: number of training epochs (default 20).
        lr: Adam learning rate.
        complexity_input: ``(C, H, W)`` input used for summary / complexity.
            NOTE(review): training images are resized to 224x224, so figures
            computed at the default 84x84 do not describe the training-time
            cost. Confirm this size is intended.
    """
    model = SimpleResNet(num_classes=num_classes).to(device)
    print(f"模型總參數量：{sum(p.numel() for p in model.parameters())}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        epoch_loss, epoch_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Epoch [{epoch+1}/{num_epochs}] 訓練損失：{epoch_loss:.4f} 訓練準確率：{epoch_acc:.2f}%")

        val_acc = _evaluate_accuracy(model, val_loader, device)
        print(f"驗證準確率：{val_acc:.2f}%")

    test_acc = _evaluate_accuracy(model, test_loader, device)
    print(f"測試準確率：{test_acc:.2f}%")

    # Pass the device explicitly instead of relying on torchsummary's default.
    summary(model, complexity_input, device=device.type)

    # ptflops is optional; only the import itself is guarded.
    try:
        from ptflops import get_model_complexity_info
    except ImportError:
        print("Please install ptflops to compute FLOPS: pip install ptflops")
    else:
        flops, params = get_model_complexity_info(model, complexity_input, as_strings=True, print_per_layer_stat=False)
        print(f"Computational complexity: {flops}")
        print(f"Number of parameters: {params}")


if __name__ == "__main__":
    main()

模型總參數量：320114
Epoch [1/20] 訓練損失：3.3739 訓練準確率：12.71%
驗證準確率：14.44%
Epoch [2/20] 訓練損失：2.9674 訓練準確率：20.63%
驗證準確率：22.00%
Epoch [3/20] 訓練損失：2.6523 訓練準確率：27.28%
驗證準確率：25.33%
Epoch [4/20] 訓練損失：2.4323 訓練準確率：32.13%
驗證準確率：32.89%
Epoch [5/20] 訓練損失：2.2757 訓練準確率：35.89%
驗證準確率：30.00%
Epoch [6/20] 訓練損失：2.1553 訓練準確率：38.63%
驗證準確率：38.89%
Epoch [7/20] 訓練損失：2.0509 訓練準確率：41.14%
驗證準確率：40.22%
Epoch [8/20] 訓練損失：1.9589 訓練準確率：43.38%
驗證準確率：39.33%
Epoch [9/20] 訓練損失：1.8814 訓練準確率：45.69%
驗證準確率：41.11%
Epoch [10/20] 訓練損失：1.8032 訓練準確率：47.66%
驗證準確率：46.67%
Epoch [11/20] 訓練損失：1.7382 訓練準確率：49.35%
驗證準確率：47.11%
Epoch [12/20] 訓練損失：1.6827 訓練準確率：50.81%
驗證準確率：47.78%
Epoch [13/20] 訓練損失：1.6278 訓練準確率：52.39%
驗證準確率：48.44%
Epoch [14/20] 訓練損失：1.5771 訓練準確率：53.63%
驗證準確率：49.56%
Epoch [15/20] 訓練損失：1.5296 訓練準確率：54.91%
驗證準確率：51.78%
Epoch [16/20] 訓練損失：1.4917 訓練準確率：55.74%
驗證準確率：51.56%
Epoch [17/20] 訓練損失：1.4535 訓練準確率：56.96%
驗證準確率：51.33%
Epoch [18/20] 訓練損失：1.4178 訓練準確率：57.79%
驗證準確率：55.33%
Epoch [19/20] 訓練損失：1.3752 訓練準確率：59.16%
驗證準確率：56.67%
Epoch [

## 消融實驗

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from thop import profile  # 用於計算 FLOPs 和參數量
import os

# 定義 ResidualBlock（從你提供）
class ResidualBlock(nn.Module):
    """Basic residual block used by the ablation models.

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN. Skip path:
    the identity, or a 1x1 conv + BN projection when ``stride`` != 1 or the
    channel count changes. Output: ``relu(main(x) + skip(x))``.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if stride == 1 and in_channels == out_channels:
            # Shapes already match: an empty Sequential passes x through unchanged.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(main(x) + shortcut(x))``."""
        main = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        main += self.shortcut(x)
        return self.relu(main)

# 無 Shortcut 的 ResidualBlock（消融實驗 2）
class NoShortcutBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(NoShortcutBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        return out

# 無 BatchNorm 的 ResidualBlock（消融實驗 5）
class NoBatchNormBlock(nn.Module):
    """Ablation 5: the residual block with every BatchNorm layer removed.

    Main path: conv3x3(stride) -> ReLU -> conv3x3. Skip path: the identity,
    or a bare 1x1 conv (no BN) when the stride or channel count changes.
    Output: ``relu(main(x) + skip(x))``.

    NOTE(review): the convolutions keep ``bias=False`` even though no
    BatchNorm follows them any more. Confirm this is intended for the ablation.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        main = self.conv2(self.relu(self.conv1(x)))
        main += self.shortcut(x)
        return self.relu(main)

# MyNet 基礎模型
class MyNet(nn.Module):
    """Configurable single-block network used as the base of the ablation study.

    Layout: conv stem (stride 2) -> BN -> ReLU -> optional 3x3/2 max-pool
    -> one ``block(channels, channels, stride=1)`` -> global pool -> linear head.

    Args:
        num_classes: number of output logits.
        block: block class, called as ``block(in_ch, out_ch, stride=...)``.
        channels: width of the stem and of the block.
        kernel_size: stem convolution kernel size (padding is ``kernel_size // 2``).
        use_maxpool: if False, the max-pool after the stem becomes ``nn.Identity``.
        pool_type: ``'avg'`` for global average pooling, ``'max'`` for global
            max pooling.
        use_batchnorm: if False, the stem BatchNorm becomes ``nn.Identity``.
            Whether ``block`` itself uses BatchNorm is up to the block class.

    Raises:
        ValueError: if ``pool_type`` is neither ``'avg'`` nor ``'max'``.
            Previously, any unknown value silently selected max pooling.
    """

    def __init__(self, num_classes=50, block=ResidualBlock, channels=64, kernel_size=7,
                 use_maxpool=True, pool_type='avg', use_batchnorm=True):
        super(MyNet, self).__init__()
        if pool_type not in ('avg', 'max'):
            raise ValueError(f"pool_type must be 'avg' or 'max', got {pool_type!r}")
        self.conv1 = nn.Conv2d(3, channels, kernel_size=kernel_size, stride=2,
                               padding=kernel_size // 2, bias=False)
        self.bn1 = nn.BatchNorm2d(channels) if use_batchnorm else nn.Identity()
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) if use_maxpool else nn.Identity()
        self.block = block(channels, channels, stride=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1)) if pool_type == 'avg' else nn.AdaptiveMaxPool2d((1, 1))
        self.fc = nn.Linear(channels, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.block(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

# MyNet 增加一個殞差塊（消融實驗 7）
class MyNetExtraBlock(nn.Module):
    """Ablation 7: MyNet's default layout with one extra residual block.

    conv7x7/2 -> BN -> ReLU -> maxpool3x3/2 -> ResidualBlock -> ResidualBlock
    -> global average pool -> linear head. Both residual blocks keep
    ``channels`` channels and stride 1.
    """

    def __init__(self, num_classes=50, channels=64):
        super().__init__()
        self.conv1 = nn.Conv2d(3, channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.block1 = ResidualBlock(channels, channels, stride=1)
        self.block2 = ResidualBlock(channels, channels, stride=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(channels, num_classes)

    def forward(self, x):
        # Everything up to the classifier is a straight pipeline.
        stages = (self.conv1, self.bn1, self.relu, self.maxpool,
                  self.block1, self.block2, self.avgpool)
        for stage in stages:
            x = stage(x)
        return self.fc(torch.flatten(x, 1))

# 訓練函數
def train(model, device, train_loader, optimizer, criterion, epoch):
    """Train ``model`` for one epoch over ``train_loader`` and print its metrics.

    Args:
        model: network to optimise; it is switched to train mode.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches that supports ``len()``.
        optimizer: optimiser stepping ``model``'s parameters.
        criterion: loss function called as ``criterion(output, target)``.
        epoch: epoch number, used only in the printed message.

    Returns:
        ``(mean_loss, accuracy)``: the mean of the per-batch losses and the
        top-1 training accuracy in percent (both are also printed).

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
    if total == 0:
        raise ValueError("train_loader yielded no samples")
    mean_loss = running_loss / len(train_loader)
    accuracy = 100. * correct / total
    print(f'Train Epoch: {epoch}, Loss: {mean_loss:.4f}, Accuracy: {accuracy:.2f}%')
    return mean_loss, accuracy

# 驗證函數
def validate(model, device, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` in eval mode without gradients.

    Prints the mean per-batch loss and the top-1 accuracy.

    Returns:
        The top-1 accuracy in percent.
    """
    model.eval()
    loss_sum, hits, seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            loss_sum += criterion(logits, targets).item()
            predictions = logits.max(1).indices
            seen += targets.size(0)
            hits += predictions.eq(targets).sum().item()
    accuracy = 100. * hits / seen
    print(f'Validation Loss: {loss_sum/len(val_loader):.4f}, Accuracy: {accuracy:.2f}%')
    return accuracy

# 測試函數
def test(model, device, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    accuracy = 100. * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

# 計算 FLOPs 和參數量
def compute_flops_params(model, input_size=(1, 3, 84, 84)):
    """Profile ``model`` with thop on one random input of shape ``input_size``.

    The model is profiled in eval mode. Its previous train/eval mode is
    restored afterwards, so calling this mid-training has no lasting side effect.

    Returns:
        ``(gflops, mparams)``: thop's operation count divided by 1e9 and its
        parameter count divided by 1e6.

    NOTE(review): thop documents the first value returned by ``profile`` as
    MACs (multiply-accumulates). A MAC is commonly counted as two FLOPs, so
    the figure callers label "GFLOPs" is probably a GMAC count. Confirm.
    """
    was_training = model.training
    model.eval()
    try:
        input_tensor = torch.randn(input_size).to(next(model.parameters()).device)
        flops, params = profile(model, inputs=(input_tensor,), verbose=False)
    finally:
        model.train(was_training)
    flops = flops / 1e9  # -> "GFLOPs" (see the MACs note above)
    params = params / 1e6  # -> millions of parameters
    return flops, params

# 運行單個實驗
def run_experiment(model, device, train_loader, val_loader, test_loader, experiment_name, epochs=10, lr=0.001):
    """Train ``model``, keep its best-validation checkpoint, then test and profile it.

    The checkpoint is written to ``best_model_<experiment_name>.pth``
    whenever the validation accuracy improves. It is reloaded before testing.

    Args:
        model: network to train (already on ``device``).
        device: device used for training and evaluation.
        train_loader / val_loader / test_loader: data for each split.
        experiment_name: label used in the log header and the checkpoint filename.
        epochs: number of training epochs (must be >= 1).
        lr: Adam learning rate.

    Returns:
        ``(best_val_acc, test_acc, flops, params)``.

    Raises:
        ValueError: if ``epochs`` < 1, because no checkpoint would ever be written.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    print(f'\n=== Running Experiment: {experiment_name} ===')
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Start below any reachable accuracy so that epoch 1 always writes a
    # checkpoint. With 0.0, a run stuck at 0% would never save, and the load
    # below would fail or pick up a stale file from an earlier run.
    best_acc = float('-inf')
    best_model_path = f'best_model_{experiment_name}.pth'

    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, criterion, epoch)
        val_acc = validate(model, device, val_loader, criterion)
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), best_model_path)

    # weights_only=True loads tensors only (no arbitrary unpickling) and avoids
    # the FutureWarning recent PyTorch emits when the flag is left unset.
    # map_location places the tensors on the training device.
    model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
    test_acc = test(model, device, test_loader)
    flops, params = compute_flops_params(model)
    print(f'FLOPs: {flops:.2f} GFLOPs, Parameters: {params:.2f} M')
    return best_acc, test_acc, flops, params

# 主函數：運行所有消融實驗
def ablation_study(train_loader, val_loader, test_loader, num_classes=50, epochs=10):
    """Run every ablation variant in turn and print a summary table.

    Variants, in this order: No_Shortcut, No_MaxPool, No_BatchNorm and
    Extra_Block. Each model is built just before its run. ``epochs`` is
    forwarded to every experiment; previously it was silently ignored.

    NOTE(review): the No_BatchNorm variant only removes BatchNorm inside the
    residual block. MyNet's stem still applies ``bn1``.

    Returns:
        dict mapping experiment name -> ``(best_val_acc, test_acc, flops, params)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    experiments = [
        ('No_Shortcut', lambda: MyNet(num_classes=num_classes, block=NoShortcutBlock)),
        ('No_MaxPool', lambda: MyNet(num_classes=num_classes, use_maxpool=False)),
        ('No_BatchNorm', lambda: MyNet(num_classes=num_classes, block=NoBatchNormBlock)),
        ('Extra_Block', lambda: MyNetExtraBlock(num_classes=num_classes)),
    ]

    results = {}
    for name, build_model in experiments:
        model = build_model().to(device)
        results[name] = run_experiment(model, device, train_loader, val_loader, test_loader, name,
                                       epochs=epochs)

    print("\n=== Ablation Study Results ===")
    for exp, (val_acc, test_acc, flops, params) in results.items():
        print(f"{exp}: Validation Accuracy = {val_acc:.2f}%, Test Accuracy = {test_acc:.2f}%, FLOPs = {flops:.2f} GFLOPs, Parameters = {params:.2f} M")
    return results


if __name__ == "__main__":
    # Assumes train_loader, val_loader and test_loader were created by an earlier cell.
    ablation_study(train_loader, val_loader, test_loader, num_classes=50, epochs=10)


=== Running Experiment: No_Shortcut ===
Train Epoch: 1, Loss: 3.4925, Accuracy: 10.66%
Validation Loss: 3.2399, Accuracy: 12.00%
Train Epoch: 2, Loss: 3.2325, Accuracy: 15.61%
Validation Loss: 3.1421, Accuracy: 13.33%
Train Epoch: 3, Loss: 3.0432, Accuracy: 19.36%
Validation Loss: 2.9441, Accuracy: 16.89%
Train Epoch: 4, Loss: 2.9040, Accuracy: 22.28%
Validation Loss: 2.7958, Accuracy: 22.44%
Train Epoch: 5, Loss: 2.7971, Accuracy: 24.67%
Validation Loss: 2.6693, Accuracy: 24.89%
Train Epoch: 6, Loss: 2.7083, Accuracy: 26.44%
Validation Loss: 2.7531, Accuracy: 20.67%
Train Epoch: 7, Loss: 2.6359, Accuracy: 28.11%
Validation Loss: 2.6102, Accuracy: 26.00%
Train Epoch: 8, Loss: 2.5719, Accuracy: 29.67%
Validation Loss: 2.5026, Accuracy: 26.22%
Train Epoch: 9, Loss: 2.5206, Accuracy: 30.75%
Validation Loss: 2.5405, Accuracy: 28.67%
Train Epoch: 10, Loss: 2.4753, Accuracy: 31.40%
Validation Loss: 2.4529, Accuracy: 29.78%


  model.load_state_dict(torch.load(best_model_path))


Test Accuracy: 33.56%
FLOPs: 0.05 GFLOPs, Parameters: 0.09 M

=== Running Experiment: No_MaxPool ===
Train Epoch: 1, Loss: 3.5490, Accuracy: 9.53%
Validation Loss: 3.3288, Accuracy: 13.11%
Train Epoch: 2, Loss: 3.3627, Accuracy: 12.88%
Validation Loss: 3.2386, Accuracy: 13.11%
Train Epoch: 3, Loss: 3.2431, Accuracy: 15.38%
Validation Loss: 3.1593, Accuracy: 14.22%
Train Epoch: 4, Loss: 3.1366, Accuracy: 17.66%
Validation Loss: 2.9981, Accuracy: 16.44%
Train Epoch: 5, Loss: 3.0502, Accuracy: 19.61%
Validation Loss: 2.9576, Accuracy: 18.00%
Train Epoch: 6, Loss: 2.9699, Accuracy: 21.50%
Validation Loss: 2.8174, Accuracy: 20.89%
Train Epoch: 7, Loss: 2.9093, Accuracy: 22.49%
Validation Loss: 2.8711, Accuracy: 19.78%
Train Epoch: 8, Loss: 2.8536, Accuracy: 23.87%
Validation Loss: 2.7357, Accuracy: 24.22%
Train Epoch: 9, Loss: 2.8006, Accuracy: 25.09%
Validation Loss: 2.7178, Accuracy: 23.56%
Train Epoch: 10, Loss: 2.7575, Accuracy: 25.92%
Validation Loss: 2.9475, Accuracy: 21.56%
Test Accu