## 多任务模型: 模糊图片 → SR修复 → 分类

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import os

from NetSet import SRNet, ClassifyNet
from datasets import CIFAR10Dataset, Fuzz

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
save_dir = "results/multi"
os.makedirs(save_dir, exist_ok=True)

cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std  = (0.2023, 0.1994, 0.2010)
normalize = transforms.Normalize(cifar10_mean, cifar10_std)

class MultiTaskNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.sr = SRNet()
        self.classifier = ClassifyNet()
        self.normalize = normalize

    def forward(self, x):
        sr_out = self.sr(x)
        sr_norm = self.normalize(sr_out)
        cls_out = self.classifier(sr_norm)
        return sr_out, cls_out

## 加载预训练权重

In [2]:
model = MultiTaskNet().to(device)

model.sr.load_state_dict(
    torch.load("results/task1/model_epoch_100.pth", map_location=device))

model.classifier.load_state_dict(
    torch.load("results/task2/model_epoch_100.pth", map_location=device))

print("预训练权重加载完成")

预训练权重加载完成


## 联合微调训练

In [3]:
batch_size = 64
epochs = 50 
lr = 0.0005
momentum = 0.9
lambda_sr = 1.0
lambda_cls = 1.0

train_dataset = CIFAR10Dataset(root_dir='./DS/CIFAR10', train=True)
test_dataset  = CIFAR10Dataset(root_dir='./DS/CIFAR10', train=False)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader  = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

fuzz = Fuzz(scale_factor=2)

inv_mean = torch.tensor(cifar10_mean).view(3,1,1).to(device)
inv_std  = torch.tensor(cifar10_std).view(3,1,1).to(device)

def denormalize(x):
    return x * inv_std + inv_mean

sr_criterion  = nn.L1Loss()
cls_criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for epoch in range(epochs):
    model.train()
    epoch_sr_loss = 0
    epoch_cls_loss = 0
    correct = 0
    total = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        clean = denormalize(inputs)
        blurry = torch.stack([fuzz(img) for img in clean]).to(device)

        optimizer.zero_grad()
        sr_out, cls_out = model(blurry)

        loss_sr  = sr_criterion(sr_out, clean)
        loss_cls = cls_criterion(cls_out, labels)
        loss = lambda_sr * loss_sr + lambda_cls * loss_cls
        loss.backward()
        optimizer.step()

        epoch_sr_loss  += loss_sr.item()
        epoch_cls_loss += loss_cls.item()
        _, predicted = torch.max(cls_out.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    avg_sr  = epoch_sr_loss / len(train_loader)
    avg_cls = epoch_cls_loss / len(train_loader)
    train_acc = 100 * correct / total

    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            clean = denormalize(inputs)
            blurry = torch.stack([fuzz(img) for img in clean]).to(device)
            _, cls_out = model(blurry)
            _, predicted = torch.max(cls_out.data, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()
    test_acc = 100 * test_correct / test_total

    print(f"Epoch [{epoch+1:3d}/{epochs}] SR: {avg_sr:.5f} | CLS: {avg_cls:.5f} | Train: {train_acc:.2f}% | Test: {test_acc:.2f}%")

torch.save(model.state_dict(), f"{save_dir}/multi_model.pth")
print(f"保存至 {save_dir}/")

Epoch [  1/50] SR: 0.04358 | CLS: 0.73513 | Train: 75.56% | Test: 77.04%
Epoch [  2/50] SR: 0.04225 | CLS: 0.64552 | Train: 78.29% | Test: 77.38%
Epoch [  3/50] SR: 0.04131 | CLS: 0.61887 | Train: 79.05% | Test: 78.73%
Epoch [  4/50] SR: 0.04128 | CLS: 0.60075 | Train: 79.69% | Test: 78.71%
Epoch [  5/50] SR: 0.04080 | CLS: 0.58526 | Train: 80.16% | Test: 79.39%
Epoch [  6/50] SR: 0.04053 | CLS: 0.56218 | Train: 81.08% | Test: 79.03%
Epoch [  7/50] SR: 0.04036 | CLS: 0.55823 | Train: 81.07% | Test: 79.35%
Epoch [  8/50] SR: 0.04043 | CLS: 0.54553 | Train: 81.60% | Test: 79.25%
Epoch [  9/50] SR: 0.04017 | CLS: 0.53246 | Train: 82.12% | Test: 80.08%
Epoch [ 10/50] SR: 0.04041 | CLS: 0.52181 | Train: 82.61% | Test: 79.89%
Epoch [ 11/50] SR: 0.04031 | CLS: 0.51273 | Train: 82.68% | Test: 80.36%
Epoch [ 12/50] SR: 0.03995 | CLS: 0.49916 | Train: 83.10% | Test: 79.45%
Epoch [ 13/50] SR: 0.04014 | CLS: 0.49414 | Train: 83.23% | Test: 79.86%
Epoch [ 14/50] SR: 0.04004 | CLS: 0.48610 | Train: 

## 对比: 模糊直接分类 vs 多任务(SR+分类)

In [4]:
baseline = ClassifyNet().to(device)
baseline.load_state_dict(torch.load("results/task2/model_epoch_100.pth", map_location=device))
baseline.eval()
model.eval()

acc_baseline = 0
acc_multi = 0
acc_clean = 0
total = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        clean = denormalize(inputs)
        blurry = torch.stack([fuzz(img) for img in clean]).to(device)

        upscaled = nn.functional.interpolate(blurry, scale_factor=2, mode='bicubic', align_corners=False)
        upscaled_norm = normalize(upscaled)
        pred_baseline = baseline(upscaled_norm).argmax(dim=1)

        _, cls_out = model(blurry)
        pred_multi = cls_out.argmax(dim=1)

        pred_clean = baseline(inputs).argmax(dim=1)

        total += labels.size(0)
        acc_baseline += (pred_baseline == labels).sum().item()
        acc_multi    += (pred_multi == labels).sum().item()
        acc_clean    += (pred_clean == labels).sum().item()

acc_baseline = 100 * acc_baseline / total
acc_multi    = 100 * acc_multi / total
acc_clean    = 100 * acc_clean / total

print(f"清晰图直接分类:       {acc_clean:.2f}%")
print(f"低分辨率图直接分类:  {acc_baseline:.2f}%")
print(f"低分辨率图 → 超分 → 分类:       {acc_multi:.2f}%")
print(f"提升: {acc_multi - acc_baseline:+.2f}%")

清晰图直接分类:       84.80%
低分辨率图直接分类:  47.26%
低分辨率图 → 超分 → 分类:       80.95%
提升: +33.69%
