## 定义模型

In [1]:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, num_channels=128, num_residual_blocks=8) -> None:
        super().__init__()

        #初始特征提取
        self.head = nn.Sequential(
            nn.Conv2d(3, num_channels, kernel_size=9, padding=4),
            nn.ReLU(inplace=True),
        )

        #残差块组（每个残差块：两层卷积 + 跳跃连接）
        res_blocks = []
        for _ in range(num_residual_blocks):
            res_blocks.append(nn.Sequential(
                nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1),
            ))
        self.res_blocks = nn.ModuleList(res_blocks)
        self.res_relu = nn.ReLU(inplace=True)

        #特征压缩
        self.tail = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        #上采样部分
        self.upsample = nn.Sequential(
            nn.Conv2d(64, 64*4, kernel_size=3, padding=1),
            nn.PixelShuffle(upscale_factor=2),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.head(x)
        for block in self.res_blocks:
            x = self.res_relu(x + block(x))  #残差连接
        x = self.tail(x)
        x = self.upsample(x)
        return x

## 模型训练

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from datasets import DIV2KDataset
from Test import evaluate

#参数配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
lr = 0.001
epochs = 100
save_dir = "results/task1"

os.makedirs(save_dir, exist_ok=True)

#准备数据
train_dataset = DIV2KDataset(root_dir='./DS/DIV2K/train', crop_size=128, scale_factor=2)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
print(f"已加载 {len(train_dataset)} 张训练图片")

#准备模型、Loss、优化器
model = Net()
model.to(device)

criterion = nn.L1Loss() 
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.7)  

#开始训练
print(f"\n开始训练！批次大小：{batch_size}, 共 {epochs} 轮\n")

for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    epoch_psnr = 0
    epoch_ssim = 0

    for i, (in_imgs, out_imgs) in enumerate(train_loader):
        in_imgs = in_imgs.to(device)
        out_imgs = out_imgs.to(device)
        
        preds = model(in_imgs)
        loss = criterion(preds, out_imgs)
        
        #计算评估指标
        e = evaluate(preds, out_imgs)
        psnr = e.psnr()
        ssim = e.ssim()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_psnr += psnr.item()
        epoch_ssim += ssim.item()
    
    scheduler.step()

    # 计算平均指标
    avg_loss = epoch_loss / len(train_loader)
    avg_psnr = epoch_psnr / len(train_loader)
    avg_ssim = epoch_ssim / len(train_loader)
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch [{epoch+1:3d}/{epochs}] Loss: {avg_loss:.5f} | PSNR: {avg_psnr:.2f} | SSIM: {avg_ssim:.4f} | LR: {current_lr:.6f}")

    #可视化
    if (epoch + 1) % 20 == 0:
        model.eval()
        with torch.no_grad():
            #保存对比图
            in_resized = nn.functional.interpolate(in_imgs[:4], scale_factor=2, mode='nearest')
            comparison = torch.cat([in_resized, preds[:4], out_imgs[:4]], dim=0)
            save_image(comparison, f"{save_dir}/epoch_{epoch+1}.png", nrow=4)
            print(f"        对比图已保存：{save_dir}/epoch_{epoch+1}.png")
            
        #保存模型权重
        torch.save(model.state_dict(), f"{save_dir}/model_epoch_{epoch+1}.pth")
        print(f"        模型权重已保存")
        model.train()
        
print("\n训练完成 :)")

已加载 800 张训练图片

开始训练！批次大小：64, 共 100 轮

Epoch [  1/100] Loss: 0.24027 | PSNR: 11.07 | SSIM: 0.3592 | LR: 0.001000
Epoch [  2/100] Loss: 0.14147 | PSNR: 14.74 | SSIM: 0.4520 | LR: 0.001000
Epoch [  3/100] Loss: 0.12317 | PSNR: 15.85 | SSIM: 0.5113 | LR: 0.001000
Epoch [  4/100] Loss: 0.09810 | PSNR: 17.36 | SSIM: 0.5624 | LR: 0.001000
Epoch [  5/100] Loss: 0.08725 | PSNR: 18.45 | SSIM: 0.5923 | LR: 0.001000
Epoch [  6/100] Loss: 0.07074 | PSNR: 19.82 | SSIM: 0.6066 | LR: 0.001000
Epoch [  7/100] Loss: 0.06283 | PSNR: 20.73 | SSIM: 0.6299 | LR: 0.001000
Epoch [  8/100] Loss: 0.06001 | PSNR: 20.98 | SSIM: 0.6446 | LR: 0.001000
Epoch [  9/100] Loss: 0.05588 | PSNR: 21.57 | SSIM: 0.6748 | LR: 0.001000
Epoch [ 10/100] Loss: 0.05302 | PSNR: 21.91 | SSIM: 0.6818 | LR: 0.001000
Epoch [ 11/100] Loss: 0.05166 | PSNR: 21.96 | SSIM: 0.6863 | LR: 0.001000
Epoch [ 12/100] Loss: 0.05126 | PSNR: 22.23 | SSIM: 0.7065 | LR: 0.001000
Epoch [ 13/100] Loss: 0.05412 | PSNR: 21.92 | SSIM: 0.7017 | LR: 0.001000


## 模型测试

In [3]:
import os
from PIL import Image
import torch
from torchvision import transforms

set14_dir = './DS/Set14'
test_dir = f'{save_dir}/vaild'
os.makedirs(test_dir, exist_ok=True)

#加载预训练模型
model = Net()
model.load_state_dict(torch.load(f'{save_dir}/model_epoch_{epochs}.pth'))
model.eval()
model.to(device)
print(f"已加载模型 model_epoch_{epochs}.pth\n开始在 Set14 数据集上评估\n")

#初始化
img_file = [f for f in os.listdir(set14_dir)]
total_psnr = 0
total_ssim = 0
count = 0

with torch.no_grad():
    for img in img_file:
        count += 1
        img_path = os.path.join(set14_dir, img)

        #处理图像
        out_img = Image.open(img_path).convert('RGB')
        w, h = out_img.size
        in_img = out_img.resize((w//2, h//2), Image.BICUBIC)
        to_tensor = transforms.ToTensor()
        in_tensor = to_tensor(in_img).unsqueeze(0).to(device)
        out_tensor = to_tensor(out_img).unsqueeze(0).to(device)

        #模型推理
        pred = model(in_tensor)

        e = evaluate(pred, out_tensor)
        psnr = e.psnr()
        ssim = e.ssim()

        total_psnr += psnr.item()
        total_ssim += ssim.item()

        print(f"[{count}/{len(img_file)}] {img:20s} PSNR: {psnr:.2f} dB | SSIM: {ssim:.4f}")

        #保存对比图 - 对齐张量尺寸
        in_resized = nn.functional.interpolate(in_tensor, scale_factor=2, mode='bilinear', align_corners=False)
        clamp = torch.clamp(pred, 0, 1)
        
        # 裁剪到相同尺寸
        min_h = min(in_resized.size(2), clamp.size(2), out_tensor.size(2))
        min_w = min(in_resized.size(3), clamp.size(3), out_tensor.size(3))
        in_resized = in_resized[:, :, :min_h, :min_w]
        clamp = clamp[:, :, :min_h, :min_w]
        out_tensor_cropped = out_tensor[:, :, :min_h, :min_w]
        
        comparison = torch.cat([in_resized, clamp, out_tensor_cropped], dim=0)
        save_image(comparison, f"{test_dir}/eval_{count}.png")

avg_psnr = total_psnr / len(img_file)
avg_ssim = total_ssim / len(img_file)

print(f"\nSet14 评估结果：")
print(f"    平均 PSNR: {avg_psnr:.2f} dB")
print(f"    平均 SSIM: {avg_ssim:.4f}")
print("\nOVER!")

已加载模型 model_epoch_100.pth
开始在 Set14 数据集上评估

[1/14] zebra.jpeg           PSNR: 26.65 dB | SSIM: 0.8595
[2/14] comic.jpeg           PSNR: 23.18 dB | SSIM: 0.8113
[3/14] coastguard.jpeg      PSNR: 29.18 dB | SSIM: 0.8409
[4/14] bridge.jpeg          PSNR: 27.66 dB | SSIM: 0.8509
[5/14] foreman.jpeg         PSNR: 31.69 dB | SSIM: 0.9451
[6/14] man.jpeg             PSNR: 29.30 dB | SSIM: 0.8755
[7/14] face.jpeg            PSNR: 33.01 dB | SSIM: 0.8584
[8/14] flowers.jpeg         PSNR: 28.21 dB | SSIM: 0.8601
[9/14] pepper.jpeg          PSNR: 29.16 dB | SSIM: 0.8353
[10/14] baboon.jpeg          PSNR: 23.27 dB | SSIM: 0.6846
[11/14] ppt3.jpeg            PSNR: 23.67 dB | SSIM: 0.9125
[12/14] monarch.jpeg         PSNR: 32.34 dB | SSIM: 0.9524
[13/14] barbara.jpeg         PSNR: 26.80 dB | SSIM: 0.8483
[14/14] lenna.jpeg           PSNR: 32.17 dB | SSIM: 0.8679

Set14 评估结果：
    平均 PSNR: 28.31 dB
    平均 SSIM: 0.8574

OVER!
