
# Comparing four models (CNN / ViT / T2T-ViT / EfficientNetV2) on the MNIST dataset



In [None]:
import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# environment: use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
if device.type == 'cuda':
    # Best effort only: the GPU name is informational, so a failed query is ignored.
    try:
        gpu_name = torch.cuda.get_device_name(0)
    except Exception:
        gpu_name = None
    if gpu_name is not None:
        print('GPU:', gpu_name)

# seed
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

Using device: cuda
GPU: Tesla T4


In [None]:
# MNIST: tensors normalized with the commonly used MNIST mean/std (0.1307 / 0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

mnist_kwargs = dict(root='./data', download=True, transform=transform)
trainval = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
testset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

# Fixed train/validation split: the dedicated generator keeps it independent of the global seed.
train_size = 50000
val_size = len(trainval) - train_size
split_generator = torch.Generator().manual_seed(123)
trainset, valset = random_split(trainval, [train_size, val_size], generator=split_generator)

batch_size = 128
# Shared loader settings; pinned host memory only helps when copying to a GPU.
loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=(device.type == 'cuda'))
train_loader = DataLoader(trainset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(valset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(testset, shuffle=False, **loader_kwargs)

len(train_loader), len(val_loader), len(test_loader)

100%|██████████| 9.91M/9.91M [00:00<00:00, 11.6MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 344kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.18MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.99MB/s]


(391, 79, 79)

In [None]:
@torch.no_grad()
def evaluate(model, loader, device=None):
    """Compute loss and accuracy of `model` over every batch of `loader`, without gradients.

    Args:
        model: classifier returning logits of shape (B, n_classes).
        loader: iterable of (inputs, integer class labels) batches.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter (CPU for a parameter-less model).

    Returns:
        (mean cross-entropy per sample, accuracy in percent). Both are NaN when
        `loader` yields no samples.

    Note: the model is left in eval mode.
    """
    model.eval()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Sum per-sample losses, then divide by the sample count, so a smaller final
    # batch is not over-weighted (averaging batch means would bias the result).
    crit = nn.CrossEntropyLoss(reduction='sum')
    total, correct, loss_sum = 0, 0, 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss_sum += crit(logits, y).item()
        correct += (logits.argmax(1) == y).sum().item()
        total += y.size(0)

    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, 100.0 * correct / total

def _train_one_epoch(model, loader, crit, opt, device):
    """Run one optimization pass over `loader`; return (mean loss per sample, accuracy %)."""
    model.train()
    total, correct, loss_sum = 0, 0, 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        logits = model(x)
        loss = crit(logits, y)
        loss.backward()
        opt.step()
        n = y.size(0)
        # Weight by batch size so the smaller final batch does not skew the epoch mean.
        loss_sum += loss.item() * n
        correct += (logits.argmax(1) == y).sum().item()
        total += n
    if total == 0:
        return float('nan'), float('nan')
    return loss_sum / total, 100.0 * correct / total


def train_model(model, train_loader, val_loader, epochs=20, lr=1e-3, wd=1e-4, name="model",
                test_loader=None, device=None):
    """Train `model` with Adam (L2 weight decay) and cosine LR annealing, logging each epoch.

    Args:
        model: classifier returning logits; trained in place.
        train_loader / val_loader: iterables of (inputs, labels) batches.
        epochs: number of passes over `train_loader` (also the cosine period T_max).
        lr, wd: Adam learning rate and weight decay.
        name: label used in the per-epoch log line.
        test_loader: held-out loader evaluated after every epoch, for reporting and
            plotting only (nothing here selects a model with it). Defaults to the
            notebook-level ``test_loader`` for backward compatibility; if none is
            available, test accuracy is recorded as NaN.
        device: where batches are moved; defaults to the model's parameter device.

    Returns:
        dict of per-epoch lists: "train_loss" (mean per sample), "train_acc",
        "val_acc", "test_acc" (percent), and "times" — cumulative *training*
        seconds at the end of each epoch, excluding val/test evaluation so the
        accuracy-vs-time comparison reflects training cost only.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    if test_loader is None:
        # Earlier versions silently read the notebook-global `test_loader`.
        test_loader = globals().get("test_loader")

    crit = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    hist = {
        "train_loss": [],
        "train_acc": [],
        "val_acc": [],
        "test_acc": [],
        "times": [],
    }
    train_seconds = 0.0

    for ep in range(epochs):
        t0 = time.perf_counter()
        train_loss, train_acc = _train_one_epoch(model, train_loader, crit, opt, device)
        if device.type == 'cuda':
            # CUDA kernels run asynchronously; wait so their time is actually counted.
            torch.cuda.synchronize(device)
        train_seconds += time.perf_counter() - t0

        _, val_acc = evaluate(model, val_loader)
        test_acc = evaluate(model, test_loader)[1] if test_loader is not None else float('nan')

        hist["train_loss"].append(train_loss)
        hist["train_acc"].append(train_acc)
        hist["val_acc"].append(val_acc)
        hist["test_acc"].append(test_acc)
        hist["times"].append(train_seconds)

        sch.step()
        print(f"[{name}] Epoch {ep+1:02d}: loss={train_loss:.4f}  train={train_acc:.2f}%  val={val_acc:.2f}%  test={test_acc:.2f}%")

    return hist

In [None]:
# --- Model 1: Simple CNN ---
class SimpleCNN(nn.Module):
    """Small VGG-style baseline: two conv stages followed by an MLP classifier.

    Each stage is conv-ReLU-conv-ReLU-maxpool(2)-dropout. The classifier expects
    64 x 7 x 7 features, i.e. a (B, 1, 28, 28) input after the two 2x poolings.
    """

    def __init__(self, n_classes=10):
        super().__init__()

        def conv_stage(c_in, c_out):
            # Two same-padded 3x3 convs, then halve the spatial size.
            return [
                nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(),
                nn.Conv2d(c_out, c_out, 3, padding=1), nn.ReLU(),
                nn.MaxPool2d(2), nn.Dropout(0.25),
            ]

        self.features = nn.Sequential(*conv_stage(1, 32), *conv_stage(32, 64))
        head_layers = [
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, n_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        return self.classifier(self.features(x))

# --- Model 2: ViT ---
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one linearly.

    A convolution whose kernel and stride both equal `patch_size` is the same as
    flattening every patch and applying one shared linear projection. For a square
    `img_size` x `img_size` input the output is (B, n_patches, embed_dim).
    """

    def __init__(self, img_size=28, patch_size=4, in_channels=1, embed_dim=64):
        super().__init__()
        per_side = img_size // patch_size
        self.n_patches = per_side * per_side
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        grid = self.proj(x)                              # (B, D, H/p, W/p)
        return grid.flatten(start_dim=2).transpose(1, 2)  # (B, tokens, D)

class MSA(nn.Module):
    """Multi-head self-attention with a single fused QKV projection.

    `heads` must divide `dim`; each head attends with width `dim // heads` and
    scaled dot-product scores. Dropout is applied to the attention weights.
    """

    def __init__(self, dim=64, heads=4, p=0.1):
        super().__init__()
        assert dim % heads == 0
        self.h = heads
        self.d = dim // heads
        self.qkv = nn.Linear(dim, dim*3)
        self.out = nn.Linear(dim, dim)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        batch, n_tok, width = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        packed = self.qkv(x).reshape(batch, n_tok, 3, self.h, self.d).permute(2, 0, 3, 1, 4)
        q, k, v = packed.unbind(0)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.d ** -0.5
        weights = self.drop(F.softmax(scores, dim=-1))
        mixed = torch.matmul(weights, v).transpose(1, 2).reshape(batch, n_tok, width)
        return self.out(mixed)

class MLP(nn.Module):
    """Transformer feed-forward layer: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The hidden width is `int(dim * ratio)`; the same dropout module is reused twice.
    """

    def __init__(self, dim=64, ratio=4, p=0.1):
        super().__init__()
        hidden = int(dim * ratio)
        self.fc1 = nn.Linear(dim, hidden)
        self.fc2 = nn.Linear(hidden, dim)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        expanded = self.drop(F.gelu(self.fc1(x)))
        return self.drop(self.fc2(expanded))

class Block(nn.Module):
    """Pre-norm transformer encoder layer.

    Computes x + MSA(LayerNorm(x)), then adds MLP(LayerNorm(.)) to that result.
    """

    def __init__(self, dim=64, heads=4, ratio=4, p=0.1):
        super().__init__()
        self.n1 = nn.LayerNorm(dim)
        self.attn = MSA(dim, heads, p)
        self.n2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, ratio, p)

    def forward(self, x):
        # Residual sub-layers applied in order: attention first, then feed-forward.
        for norm, sublayer in ((self.n1, self.attn), (self.n2, self.mlp)):
            x = x + sublayer(norm(x))
        return x

class ViT(nn.Module):
    """Vision Transformer for single-channel images.

    Patch-embeds the image, prepends a learnable [CLS] token, adds learnable
    positional embeddings (one per patch plus [CLS]), runs `layers` pre-norm
    Blocks, and classifies the layer-normalized [CLS] output.
    """

    def __init__(self, img_size=28, patch=4, dim=64, layers=6, heads=4, n_classes=10, p=0.1):
        super().__init__()
        n_tokens = (img_size // patch) ** 2 + 1  # patches + [CLS]
        self.pe = nn.Parameter(torch.zeros(1, n_tokens, dim))
        self.cls = nn.Parameter(torch.zeros(1, 1, dim))
        self.embed = PatchEmbedding(img_size, patch, 1, dim)
        self.blocks = nn.ModuleList([Block(dim, heads, 4, p) for _ in range(layers)])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, n_classes)
        # Small-std Gaussian init; the in-place writes must not be recorded by autograd.
        with torch.no_grad():
            for param in (self.pe, self.cls):
                param.normal_(std=0.02)

    def forward(self, x):
        tokens = self.embed(x)
        cls_tok = self.cls.expand(tokens.size(0), -1, -1)
        tokens = torch.cat((cls_tok, tokens), dim=1) + self.pe
        for blk in self.blocks:
            tokens = blk(tokens)
        return self.head(self.norm(tokens[:, 0]))

# --- Model 3: T2T-ViT ---
class T2TModule(nn.Module):
    """Convolutional tokenizer used as the T2T-ViT front end.

    Three non-overlapping strided convolutions (kernel/stride 4, 2, 3), each
    followed by GELU, then the spatial grid is flattened into tokens.
    For a 28x28 input the grid shrinks 28 -> 7 -> 3 -> 1, so the output is a
    single token per image: (B, 1, embed_dim).
    NOTE(review): this is a simplification of the overlapping "soft split"
    tokenization of the original T2T-ViT; with one token, the downstream
    attention only mixes it with [CLS].
    """

    def __init__(self, in_ch=1, token_dim=64, embed_dim=64):
        super().__init__()
        self.stage1 = nn.Conv2d(in_ch, token_dim, 4, 4)
        self.stage2 = nn.Conv2d(token_dim, embed_dim, 2, 2)
        self.stage3 = nn.Conv2d(embed_dim, embed_dim, 3, 3)

    def forward(self, x):
        for conv in (self.stage1, self.stage2, self.stage3):
            x = F.gelu(conv(x))
        return x.flatten(2).transpose(1, 2)

class T2TViT(nn.Module):
    """Transformer classifier on top of the convolutional T2TModule tokenizer.

    A learnable [CLS] token is prepended to the T2T tokens and positional embeddings
    are added. The sequence then passes through `layers` pre-norm Blocks, and the
    layer-normalized [CLS] output is classified by a linear head.
    """

    def __init__(self, dim=64, heads=4, layers=6, n_classes=10):
        super().__init__()
        self.t2t = T2TModule(1, 64, dim)
        # One slot for [CLS] plus one per T2T token (T2TModule yields a single token
        # for 28x28 inputs); other token counts are handled in _pos_embed.
        self.pe = nn.Parameter(torch.zeros(1, 1 + 1*1, dim))
        self.cls = nn.Parameter(torch.zeros(1, 1, dim))
        self.blocks = nn.ModuleList([Block(dim, heads) for _ in range(layers)])
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, n_classes)
        with torch.no_grad():
            self.pe.normal_(std=0.02); self.cls.normal_(std=0.02)

    def _pos_embed(self, n_patches):
        """Return positional embeddings of shape (1, 1 + n_patches, dim).

        The [CLS] slot is always used unchanged. Only the patch slots are linearly
        resized, and only when `n_patches` differs from the stored count.
        """
        # Resizing the whole table (including [CLS]) to `n_patches` rows and relying on
        # broadcasting would give [CLS] and the patch token one shared, averaged vector.
        cls_pe = self.pe[:, :1]
        patch_pe = self.pe[:, 1:]
        if patch_pe.size(1) != n_patches:
            patch_pe = F.interpolate(patch_pe.transpose(1, 2), size=n_patches,
                                     mode='linear', align_corners=False).transpose(1, 2)
        return torch.cat([cls_pe, patch_pe], dim=1)

    def forward(self, x):
        B = x.size(0)
        tokens = self.t2t(x)                         # (B, n_patches, dim)
        n_patches = tokens.size(1)
        cls = self.cls.expand(B, -1, -1)
        x = torch.cat([cls, tokens], dim=1) + self._pos_embed(n_patches)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x[:, 0])
        return self.head(x)

# --- Model 4: EfficientNetV2 ---
# The fallback building blocks live at module level rather than inside the factory,
# so fallback models can be pickled / saved with torch.save(model) and the classes
# are not redefined on every call.

def _conv_bn_act(in_c, out_c, k=3, s=1):
    """Conv2d (padding k//2, no bias) -> BatchNorm2d -> SiLU."""
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, k, s, k // 2, bias=False),
        nn.BatchNorm2d(out_c),
        nn.SiLU(),
    )


class _FusedMBConv(nn.Module):
    """Fused-MBConv-like block: k x k expansion conv, then a 1x1 projection conv.

    A residual connection is used when the block keeps both resolution (s == 1)
    and channel count. NOTE(review): unlike the reference Fused-MBConv, the 1x1
    projection here is also followed by SiLU, and there is no stochastic depth.
    """

    def __init__(self, in_c, out_c, expand=4, k=3, s=1):
        super().__init__()
        mid = int(in_c * expand)
        self.block = nn.Sequential(
            _conv_bn_act(in_c, mid, k, s),
            _conv_bn_act(mid, out_c, 1, 1),
        )
        self.use_skip = (s == 1 and in_c == out_c)

    def forward(self, x):
        out = self.block(x)
        return x + out if self.use_skip else out


class _EfficientNetV2Tiny(nn.Module):
    """Small EfficientNetV2-flavoured CNN used when timm's model cannot be built.

    Stem (stride 1, 24 channels), four fused stages with strides 1, 2, 2, 2
    (a 28x28 input goes 28 -> 28 -> 14 -> 7 -> 4), then global average pooling
    and a linear classifier.
    """

    def __init__(self, num_classes=10, in_chans=1):
        super().__init__()
        self.stem = _conv_bn_act(in_chans, 24, 3, 1)
        self.stage1 = _FusedMBConv(24, 24, expand=2, k=3, s=1)
        self.stage2 = _FusedMBConv(24, 48, expand=4, k=3, s=2)
        self.stage3 = _FusedMBConv(48, 64, expand=4, k=3, s=2)
        self.stage4 = _FusedMBConv(64, 96, expand=4, k=3, s=2)
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(96, num_classes),
        )

    def forward(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        return self.head(x)


def try_build_efficientnetv2(num_classes=10, in_chans=1):
    """Build timm's randomly initialised `efficientnetv2_s`, else a tiny local approximation.

    Any exception while importing timm or creating the model triggers the fallback.
    The reason is printed, so the run log shows which model was actually trained.

    Args:
        num_classes: size of the classifier output.
        in_chans: number of input image channels (1 for MNIST).
    """
    try:
        import timm
        return timm.create_model('efficientnetv2_s', pretrained=False,
                                 num_classes=num_classes, in_chans=in_chans)
    except Exception as e:
        # Message reads: "timm/efficientnetv2 unavailable, using the Tiny approximation. Reason:"
        print("timm/efficientnetv2 不可用，改用 Tiny 近似实现。原因：", e)
        return _EfficientNetV2Tiny(num_classes, in_chans)

# Build a model by (case-insensitive) name and move it to the notebook's device.
def build_model(name):
    """Instantiate one of the compared models and place it on `device`.

    Accepted names (any case): "cnn", "vit", "t2t-vit", "efficientnetv2".
    Raises ValueError carrying the lower-cased name for anything else.
    """
    key = name.lower()
    factories = {
        "cnn": SimpleCNN,
        "vit": ViT,
        "t2t-vit": T2TViT,
        "efficientnetv2": try_build_efficientnetv2,
    }
    factory = factories.get(key)
    if factory is None:
        raise ValueError(key)
    return factory().to(device)

In [None]:
# Shared hyper-parameters: every model gets the same epoch budget and optimizer settings.
epochs = 20
lr = 1e-3
wd = 1e-4

histories = {}
for name in ["CNN", "ViT", "T2T-ViT", "EfficientNetV2"]:
    print("\n==== 训练", name, "====")  # "训练" means "training"
    # Re-seed before each model so its weight init and shuffled batch order do not
    # depend on how much randomness the previously trained models consumed. This
    # makes each model's run reproducible on its own.
    set_seed(42)
    model = build_model(name)
    # Model size matters when comparing accuracy and time across architectures.
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[{name}] parameters: {n_params:,}")
    histories[name] = train_model(model, train_loader, val_loader, epochs=epochs, lr=lr, wd=wd, name=name)


==== 训练 CNN ====
[CNN] Epoch 01: loss=0.2620  train=91.59%  val=98.10%  test=98.51%
[CNN] Epoch 02: loss=0.0801  train=97.59%  val=98.66%  test=98.87%
[CNN] Epoch 03: loss=0.0602  train=98.15%  val=98.73%  test=99.00%


In [None]:
# plotting
def plot_metric(metric_key, ylabel, title, x_key=None, xlabel="Epoch", marker=None, hists=None):
    """Plot one recorded metric for every trained model on a shared axis.

    Args:
        metric_key: key into each history dict (e.g. "train_loss", "test_acc").
        ylabel, title: y-axis label and figure title.
        x_key: optional history key used as the x values (e.g. "times"). By default
            the x axis is the 1-based epoch number, matching the training log.
        xlabel: x-axis label.
        marker: optional matplotlib marker for the lines.
        hists: mapping of model name -> history dict; defaults to the notebook-level
            `histories`.

    Returns:
        The matplotlib Axes that was drawn on.
    """
    if hists is None:
        hists = histories
    fig, ax = plt.subplots(figsize=(6, 4))
    for name, h in hists.items():
        ys = h[metric_key]
        xs = h[x_key] if x_key is not None else range(1, len(ys) + 1)
        ax.plot(xs, ys, marker=marker, label=name)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.legend()
    plt.show()
    return ax

# 1) Training loss per epoch
plot_metric("train_loss", "Loss", "Loss comparison");

# 2) Training accuracy per epoch
plot_metric("train_acc", "Accuracy (%)", "train accuracy");

# 3) Test accuracy per epoch
plot_metric("test_acc", "Accuracy (%)", "test accuracy");

# 4) Test accuracy against the cumulative time recorded in each history's "times" list
plot_metric("test_acc", "Test Accuracy (%)", "accuracy-time comparison",
            x_key="times", xlabel="Time (s)", marker='o');