In [1]:
# 导入与设备
import os, time, random
from typing import Tuple, Dict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Report the installed PyTorch version, then pick the compute device:
# CUDA when torch reports it as available, CPU otherwise.
print("PyTorch:", torch.__version__)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)


# 实用函数
def seed_everything(seed: int = 42, deterministic: bool = False) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and configure determinism.

    Args:
        seed: Seed applied to ``random``, ``numpy.random``, and torch
            (CPU generator and every CUDA device).
        deterministic: When True, turn on ``torch.use_deterministic_algorithms``
            and disable cuDNN autotuning. When False, turn deterministic
            algorithms off again and enable cuDNN autotuning, so the setting
            does not stick from an earlier ``deterministic=True`` call.
    """
    import numpy as np

    random.seed(seed)
    # Python reads PYTHONHASHSEED at interpreter start-up, so this assignment
    # cannot change hashing in the current process; it is only inherited by
    # child processes launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic:
        # With deterministic algorithms enabled, cuBLAS-backed CUDA ops raise
        # a RuntimeError unless this workspace setting is present. Respect a
        # value the user already exported.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    # Apply both switches symmetrically so the two modes are true opposites.
    torch.use_deterministic_algorithms(deterministic)
    torch.backends.cudnn.benchmark = not deterministic


def count_params(model: nn.Module) -> Tuple[int, int]:
    """Count the parameters of ``model``.

    Returns:
        ``(total, trainable)``: the number of scalar elements across all
        parameters, and across only those with ``requires_grad`` set.
    """
    n_all = 0
    n_trainable = 0
    for param in model.parameters():
        size = param.numel()
        n_all += size
        if param.requires_grad:
            n_trainable += size
    return n_all, n_trainable


@torch.no_grad()
def evaluate_acc(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model is switched to eval mode and is left in eval mode on return.
    Gradients are disabled for the whole call by the decorator.

    Args:
        model: Network whose output has the class scores on dimension 1.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: If ``loader`` yields no samples, since accuracy is
            undefined in that case.
    """
    model.eval()
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    if total == 0:
        raise ValueError("evaluate_acc: loader yielded no samples.")
    return correct / total


def run_one_epoch(
    loader: DataLoader,
    model: nn.Module,
    criterion: nn.Module,
    # String annotation: not evaluated at definition time, so the signature
    # also works on Python versions without the ``X | None`` operator.
    optimizer: "optim.Optimizer | None" = None,
    train: bool = False,
    device: torch.device = torch.device("cpu"),
) -> Tuple[float, float]:
    """Run one full pass over ``loader``, optionally updating the model.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.
        model: Network to run; put in train or eval mode according to ``train``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch. Required when ``train``
            is True; ignored otherwise.
        train: When True, enable gradients and perform backward + step.
        device: Device every batch is moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. Each batch loss is
        weighted by its batch size before averaging (this assumes the
        criterion returns a per-batch mean -- confirm for custom criteria).

    Raises:
        ValueError: If ``train`` is True but no optimizer was supplied, or if
            ``loader`` yields no samples.
    """
    if train and optimizer is None:
        raise ValueError("run_one_epoch: an optimizer is required when train=True.")

    model.train(train)
    epoch_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        if train:
            optimizer.zero_grad(set_to_none=True)

        # Gradients are tracked only while training.
        with torch.set_grad_enabled(train):
            outputs = model(images)
            loss = criterion(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()

        batch_size = images.size(0)
        epoch_loss += loss.item() * batch_size
        correct += (outputs.argmax(1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("run_one_epoch: loader yielded no samples.")
    return epoch_loss / total, correct / total

PyTorch: 2.7.1+cu118
Device: cuda


In [None]:
# ============= 新增版本 2：BN-only（有BN，无残差） =============
class BNCNN(nn.Module):
    """
    Two-stage CNN with BatchNorm and no residual connections.

    Each stage is Conv3x3 -> BatchNorm -> ReLU -> MaxPool(2); BatchNorm sits
    after the convolution and before the activation. A single linear layer
    maps the flattened features to class scores. The original author describes
    the backbone as based on "SimpleCNN", which is not defined in the cells
    visible here.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Size of the output score vector.
        img_size: Height/width of the (square) input images; only used to
            size the linear layer.
    """
    def __init__(self, in_channels=1, num_classes=10, img_size=28):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )

        # Probe the flattened feature size with a dummy forward pass. The
        # probe runs in eval mode: in training mode the BatchNorm layers would
        # update their running statistics from this all-zero batch, even under
        # no_grad, so a fresh model would start with polluted buffers.
        self.features.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, img_size, img_size)
            flat_dim = self.features(dummy).flatten(1).size(1)
        self.features.train()

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_dim, num_classes),
        )

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)`` for images ``x``."""
        x = self.features(x)
        x = self.classifier(x)
        return x


In [3]:
# 数据加载（与你提供的实现一致）
# =========================
def make_loaders(
    name: str,
    root: str = "./data",
    batch: int = 128,
    val_ratio: float = 0.1,
    workers: int = 2,
    pin: bool = False,
) -> Dict[str, object]:
    """Build train/val/test loaders for MNIST or CIFAR10.

    Args:
        name: Dataset name, matched case-insensitively ("mnist" or "cifar10").
        root: Directory the dataset is downloaded to / read from.
        batch: Training batch size; validation and test use twice this size.
        val_ratio: Fraction of the official training set held out for validation.
        workers: ``num_workers`` passed to every DataLoader.
        pin: ``pin_memory`` passed to every DataLoader.

    Returns:
        Dict with the loaders under "train", "val", "test" and the input
        geometry under "in_channels" and "img_size".

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    key = name.lower()
    if key not in ("mnist", "cifar10"):
        raise ValueError("Unsupported dataset name.")

    # Per-dataset settings: dataset class, normalisation constants, geometry.
    if key == "mnist":
        dataset_cls = datasets.MNIST
        mean, std = (0.1307,), (0.3081,)
        in_channels, img_size = 1, 28
    else:
        dataset_cls = datasets.CIFAR10
        mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
        in_channels, img_size = 3, 32

    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    train_full = dataset_cls(root, train=True, download=True, transform=tfm)
    test_ds = dataset_cls(root, train=False, download=True, transform=tfm)

    # Hold out a validation subset; the split uses its own generator seeded
    # with 42, so it does not depend on the global RNG state.
    n_val = int(len(train_full) * val_ratio)
    n_train = len(train_full) - n_val
    split_rng = torch.Generator().manual_seed(42)
    train_ds, val_ds = random_split(train_full, [n_train, n_val], generator=split_rng)

    shared = {"num_workers": workers, "pin_memory": pin}
    eval_batch = batch * 2
    return {
        "train": DataLoader(train_ds, batch_size=batch, shuffle=True, **shared),
        "val":   DataLoader(val_ds, batch_size=eval_batch, shuffle=False, **shared),
        "test":  DataLoader(test_ds, batch_size=eval_batch, shuffle=False, **shared),
        "in_channels": in_channels,
        "img_size": img_size,
    }

In [4]:
# 训练配置与例行函数（与你提供的版本一致）
# =========================
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    epochs: int = 15,
    lr: float = 1e-3,
    weight_decay: float = 5e-4,
    ckpt_path: str = "best.pt",
    use_plateau: bool = True,
) -> Dict[str, list]:
    """Train ``model`` with Adam, checkpointing on best validation loss.

    Training stops early after 6 consecutive epochs without a validation-loss
    improvement of more than 1e-6. On return the model holds the weights of
    the best checkpoint written during this call; if no checkpoint was written
    (for example ``epochs=0`` or a validation loss that is never finite), the
    current weights are kept and nothing is loaded from ``ckpt_path``.

    Args:
        model: Network to train (already on ``device``).
        train_loader: Batches used for optimisation.
        val_loader: Batches used for model selection and LR scheduling.
        device: Device batches are moved to.
        epochs: Maximum number of epochs.
        lr: Initial Adam learning rate.
        weight_decay: Adam weight decay.
        ckpt_path: File the best state dict is saved to (under key "model").
        use_plateau: When True, halve the LR via ReduceLROnPlateau on the
            validation loss (patience 2, floor 1e-5).

    Returns:
        History dict with one entry per completed epoch under "train_loss",
        "train_acc", "val_loss", "val_acc", "lrs" and "time". "lrs" holds the
        learning rate read *after* the scheduler step of that epoch, and
        "time" covers both the training and the validation pass.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = (
        ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2, min_lr=1e-5)
        if use_plateau
        else None
    )

    # Create the checkpoint directory up front so the first save cannot fail
    # on a missing folder.
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    best_val_loss = float("inf")
    early_patience = 6
    bad_epochs = 0            # consecutive epochs without improvement
    saved_checkpoint = False  # True once this run has written ckpt_path
    history = {"train_loss": [], "train_acc": [],
               "val_loss": [], "val_acc": [], "lrs": [], "time": []}

    for ep in range(1, epochs + 1):
        ep_start = time.perf_counter()

        tr_loss, tr_acc = run_one_epoch(
            train_loader, model, criterion, optimizer,
            train=True, device=device
        )
        val_loss, val_acc = run_one_epoch(
            val_loader, model, criterion, optimizer=None,
            train=False, device=device
        )

        if scheduler is not None:
            scheduler.step(val_loss)

        history["train_loss"].append(tr_loss)
        history["train_acc"].append(tr_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        history["lrs"].append(optimizer.param_groups[0]["lr"])
        history["time"].append(time.perf_counter() - ep_start)

        print(f"[{os.path.basename(ckpt_path)}] Epoch {ep:02d}/{epochs} | "
              f"time {history['time'][-1]:.1f}s | "
              f"Train {tr_loss:.4f}/{tr_acc:.4f} | "
              f"Val {val_loss:.4f}/{val_acc:.4f} | "
              f"LR {optimizer.param_groups[0]['lr']:.1e}")

        if val_loss < best_val_loss - 1e-6:
            best_val_loss = val_loss
            bad_epochs = 0
            torch.save({"model": model.state_dict()}, ckpt_path)
            saved_checkpoint = True
        else:
            bad_epochs += 1
            if bad_epochs >= early_patience:
                print("Early stopping.")
                break

    # Restore the best weights, but only from a checkpoint written by this
    # run; a file left at ckpt_path by an earlier run would silently replace
    # the model with unrelated weights.
    if saved_checkpoint:
        state = torch.load(ckpt_path, map_location=device, weights_only=True)
        model.load_state_dict(state["model"])
    else:
        print(f"[{os.path.basename(ckpt_path)}] No checkpoint was saved in this run; "
              f"keeping the current weights.")
    return history

In [5]:
# =========================
# 实验主入口（四个模型，可逐个运行）
# =========================
if __name__ == "__main__":
    seed_everything(42)

    # Per-dataset hyperparameters, deliberately identical so the comparison is
    # controlled: MNIST and CIFAR10 both train for 10 epochs with the same
    # learning rate and weight decay.
    cfg = dict(
        mnist=dict(epochs=10, lr=1e-3, weight_decay=5e-4),
        cifar10=dict(epochs=10, lr=1e-3, weight_decay=5e-4),
    )

    # Filled in by the experiment cells below; skip any cell you do not want
    # to spend training time on.
    results = dict()

In [6]:
# ---------- MNIST ----------
# MNIST loaders, plus the input geometry the model constructor needs.
mnist = make_loaders(name="mnist", root="./data")
in_ch_m = mnist["in_channels"]
img_m = mnist["img_size"]

In [7]:
# BNCNN
# Fresh BNCNN for MNIST (10 classes); print its (total, trainable) parameter counts.
model = BNCNN(in_channels=in_ch_m, num_classes=10, img_size=img_m).to(device)
print("MNIST-BNCNN params:", count_params(model))

# Train with the MNIST hyperparameters from cfg; the best checkpoint goes to
# ./mnist_bn_best.pt.
history = train_model(
    model,
    train_loader=mnist["train"],
    val_loader=mnist["val"],
    device=device,
    epochs=cfg["mnist"]["epochs"],
    lr=cfg["mnist"]["lr"],
    weight_decay=cfg["mnist"]["weight_decay"],
    ckpt_path="./mnist_bn_best.pt",
)

# Score on the test split. Each history["time"] entry spans one epoch's
# training pass plus its validation pass.
acc = evaluate_acc(model, loader=mnist["test"], device=device)
total_time = sum(history["time"])
print(f"BNCNN on MNIST - Test Accuracy: {acc:.4f}, Total Training Time: {total_time:.2f} seconds")

# Optional: record (test accuracy, total parameter count, total seconds).
results["mnist_bn"] = (acc, count_params(model)[0], total_time)

MNIST-BNCNN params: (50378, 50378)
[mnist_bn_best.pt] Epoch 01/10 | time 7.1s | Train 0.1485/0.9555 | Val 0.0707/0.9787 | LR 1.0e-03
[mnist_bn_best.pt] Epoch 02/10 | time 4.7s | Train 0.0500/0.9850 | Val 0.0735/0.9782 | LR 1.0e-03
[mnist_bn_best.pt] Epoch 03/10 | time 4.8s | Train 0.0387/0.9882 | Val 0.0518/0.9845 | LR 1.0e-03
[mnist_bn_best.pt] Epoch 04/10 | time 4.7s | Train 0.0311/0.9903 | Val 0.0506/0.9852 | LR 1.0e-03
[mnist_bn_best.pt] Epoch 05/10 | time 4.7s | Train 0.0264/0.9915 | Val 0.0491/0.9858 | LR 1.0e-03
[mnist_bn_best.pt] Epoch 06/10 | time 4.7s | Train 0.0221/0.9931 | Val 0.0538/0.9842 | LR 1.0e-03
[mnist_bn_best.pt] Epoch 07/10 | time 4.6s | Train 0.0196/0.9940 | Val 0.0526/0.9833 | LR 1.0e-03
[mnist_bn_best.pt] Epoch 08/10 | time 4.6s | Train 0.0187/0.9942 | Val 0.0481/0.9868 | LR 1.0e-03
[mnist_bn_best.pt] Epoch 09/10 | time 4.5s | Train 0.0180/0.9943 | Val 0.0482/0.9867 | LR 1.0e-03
[mnist_bn_best.pt] Epoch 10/10 | time 4.4s | Train 0.0158/0.9953 | Val 0.0473/0.986

In [8]:
# ---------- CIFAR10 ----------
# CIFAR10 loaders, plus the input geometry the model constructor needs.
cifar = make_loaders(name="cifar10", root="./data")
in_ch_c = cifar["in_channels"]
img_c = cifar["img_size"]

In [9]:
 # BNCNN
# Fresh BNCNN for CIFAR10 (10 classes); print its (total, trainable) parameter counts.
model = BNCNN(in_channels=in_ch_c, num_classes=10, img_size=img_c).to(device)
print("CIFAR10-BNCNN params:", count_params(model))

# Train with the CIFAR10 hyperparameters from cfg; the best checkpoint goes to
# ./cifar_bn_best.pt.
history = train_model(
    model,
    train_loader=cifar["train"],
    val_loader=cifar["val"],
    device=device,
    epochs=cfg["cifar10"]["epochs"],
    lr=cfg["cifar10"]["lr"],
    weight_decay=cfg["cifar10"]["weight_decay"],
    ckpt_path="./cifar_bn_best.pt",
)

# Score on the test split. Each history["time"] entry spans one epoch's
# training pass plus its validation pass.
acc = evaluate_acc(model, loader=cifar["test"], device=device)
total_time = sum(history["time"])
print(f"BNCNN on CIFAR10 - Test Accuracy: {acc:.4f}, Total Training Time: {total_time:.2f} seconds")

# Optional: record (test accuracy, total parameter count, total seconds).
results["cifar_bn"] = (acc, count_params(model)[0], total_time)

CIFAR10-BNCNN params: (60554, 60554)
[cifar_bn_best.pt] Epoch 01/10 | time 4.7s | Train 1.3374/0.5301 | Val 1.0964/0.6232 | LR 1.0e-03
[cifar_bn_best.pt] Epoch 02/10 | time 4.7s | Train 0.9813/0.6563 | Val 0.9391/0.6864 | LR 1.0e-03
[cifar_bn_best.pt] Epoch 03/10 | time 4.5s | Train 0.8774/0.6949 | Val 0.9272/0.6780 | LR 1.0e-03
[cifar_bn_best.pt] Epoch 04/10 | time 4.7s | Train 0.8144/0.7186 | Val 0.9296/0.6710 | LR 1.0e-03
[cifar_bn_best.pt] Epoch 05/10 | time 4.7s | Train 0.7593/0.7375 | Val 0.8390/0.7098 | LR 1.0e-03
[cifar_bn_best.pt] Epoch 06/10 | time 4.4s | Train 0.7076/0.7567 | Val 0.8506/0.7088 | LR 1.0e-03
[cifar_bn_best.pt] Epoch 07/10 | time 4.2s | Train 0.6708/0.7690 | Val 0.8823/0.6916 | LR 1.0e-03
[cifar_bn_best.pt] Epoch 08/10 | time 4.3s | Train 0.6350/0.7801 | Val 0.9249/0.6904 | LR 5.0e-04
[cifar_bn_best.pt] Epoch 09/10 | time 4.4s | Train 0.5416/0.8172 | Val 0.7877/0.7316 | LR 5.0e-04
[cifar_bn_best.pt] Epoch 10/10 | time 4.4s | Train 0.5199/0.8260 | Val 0.8215/0.7