In [1]:
from math import exp, log10
from os import listdir
from os.path import join
import torch
from PIL import Image
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [2]:
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])

# 读取图像转为YCbCr模式，得到Y通道
def load_img(filepath):
    img = Image.open(filepath).convert('YCbCr')
    y, _, _ = img.split()
    return y

# 裁剪大小，宽高一致为300
# 如果想训练自己的数据集，请根据情况修改裁剪大小
CROP_SIZE = 300

# 封装数据集，适配后面的torch.utils.data.DataLoader中的dataset，定义成类似形式
# 类参数为图像文件夹路径和放大倍数
# __len__(self) 定义当被len()函数调用时的行为（返回容器中元素的个数）
#__getitem__(self) 定义获取容器中指定元素的行为，相当于self[key]，即允许类对象可以有索引操作。
#__iter__(self) 定义当迭代容器中的元素的行为
# 返回输入图像和标签，传入DataLoader的dataset参数
class DatasetFromFolder(Dataset):
    def __init__(self, image_dir, zoom_factor):
        super(DatasetFromFolder, self).__init__()
        self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)] # 图像路径列表
        crop_size = CROP_SIZE - (CROP_SIZE % zoom_factor) # 处理放大倍数，防止用户瞎设置，本例只能设置为2，3，4，大小不变
        # 数据集变换
        # 还有一些其他的变换操作，如归一化等，遇到一个积累一个
        self.input_transform = transforms.Compose([transforms.CenterCrop(crop_size), # 从图片中心裁剪成300*300
                                                   transforms.Resize(
                                                       crop_size // zoom_factor),    # Resize, 输入应该是缩放倍数后的图像，因为先缩小后放大
                                                   transforms.Resize(
                                                       crop_size, interpolation=Image.BICUBIC), # 双三次插值
                                                   transforms.ToTensor()]) # 图像转成tensor
        # label标签，超分不是分类问题，定义成一样的就行
        self.target_transform = transforms.Compose(
            [transforms.CenterCrop(crop_size), transforms.ToTensor()])

    def __getitem__(self, index):
        input = load_img(self.image_filenames[index]) # 输入是图像的Y通道，即亮度通道
        target = input.copy()
        input = self.input_transform(input)
        target = self.target_transform(target)
        return input, target

    def __len__(self):
        return len(self.image_filenames) # 图像个数

In [3]:
"""
计算ssim函数
"""


# 计算一维的高斯分布向量
def gaussian(window_size, sigma):
    gauss = torch.Tensor(
        [exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


# 创建高斯核，通过两个一维高斯分布向量进行矩阵乘法得到
# 可以设定channel参数拓展为3通道
def create_window(window_size, channel=1):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(
        _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(
        channel, 1, window_size, window_size).contiguous()
    return window


def psnr(loss):
    return 10 * log10(1 / loss.item())

# 计算SSIM
# 直接使用SSIM的公式，但是在计算均值时，不是直接求像素平均值，而是采用归一化的高斯核卷积来代替。
# 在计算方差和协方差时用到了公式Var(X)=E[X^2]-E[X]^2, cov(X,Y)=E[XY]-E[X]E[Y].
def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
    if val_range is None:
        if torch.max(img1) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(img1) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val
    else:
        L = val_range

    padd = 0
    (_, channel, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device)

    # 图像卷积后的均值
    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)

    # 均值平方
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # 方差
    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd,
                         groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd,
                         groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=padd,
                       groups=channel) - mu1_mu2

    # SSIM默认常数
    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    # 平均SSIM
    if size_average:
        ret = ssim_map.mean()
    else:
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret


# 封装成类
class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Assume 1 channel for SSIM
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = create_window(self.window_size, channel).to(
                img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel

        return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)

In [4]:
# 搭建SRCNN
class SRCNN(nn.Module):
    def __init__(self, upscale_factor):
        super(SRCNN, self).__init__()

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=9, padding=9//2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1)
        self.conv3 = nn.Conv2d(32, 1, kernel_size=5, padding=5//2)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.pixel_shuffle(x)
        return x


In [5]:
# 放大倍数
zoom_factor = 3
nb_epochs = 500
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
torch.cuda.manual_seed(0)
BATCH_SIZE = 4
NUM_WORKERS = 0
trainset = DatasetFromFolder(r"E:\BSDS300-images\BSDS300\images\train", zoom_factor)
valset = DatasetFromFolder(r"E:\BSDS300-images\BSDS300\images\train", zoom_factor)
trainloader = DataLoader(dataset=trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
valloader = DataLoader(dataset=valset, batch_size=BATCH_SIZE,
                       shuffle=False, num_workers=NUM_WORKERS)

model = SRCNN(1).to(device)  # 模型，net.to(device)保证用GPU训练，本地训练可以不加，服务器上一定要加
criterion = nn.MSELoss()  # 损失函数为MSE
optimizer = optim.Adam(
    [
        {"params": model.conv1.parameters(), "lr": 0.0001},
        {"params": model.conv2.parameters(), "lr": 0.0001},
        {"params": model.conv3.parameters(), "lr": 0.00001},
    ]
)  # Adam优化器，设定三个层学习率为论文中的值，最后一层更小
# 一般的优化器传参就是optim.Adam(lr=0.0001), 因为最后一层学习率不同，所以用字典形式传参。



In [6]:
best_psnr = 0.0
for epoch in range(nb_epochs):
    # 训练
    epoch_loss = 0
    for iteration, batch in enumerate(trainloader):  # batchsize为4，则trainloader的长度是50，遍历它
        input, target = batch[0].to(device), batch[1].to(device)  # 每个trainloader是由图像张量和标签构成
        optimizer.zero_grad()  # 训练经典三步第一步：梯度归零
        out = model(input)  # 输入模型得到输出
        loss = criterion(out, target)  # 计算损失
        loss.backward()  # 训练经典三步第二步：反向传播
        optimizer.step()  # 训练经典三步第三步：梯度下降，更新一步参数
        epoch_loss += loss.item()  # 累加损失
    print(f"Epoch {epoch}. Training loss: {epoch_loss / len(trainloader)}")  # 训练完计算损失

Epoch 0. Training loss: 0.22559629157185554
Epoch 1. Training loss: 0.06908052768558264
Epoch 2. Training loss: 0.009506322834640741
Epoch 3. Training loss: 0.008707433631643653
Epoch 4. Training loss: 0.008004619562998415
Epoch 5. Training loss: 0.007428309861570597
Epoch 6. Training loss: 0.00686546387616545
Epoch 7. Training loss: 0.006182491299696267
Epoch 8. Training loss: 0.005528760552406311
Epoch 9. Training loss: 0.005101262568496167
Epoch 10. Training loss: 0.004843900313135236
Epoch 11. Training loss: 0.004683062289841473
Epoch 12. Training loss: 0.004543563791085035
Epoch 13. Training loss: 0.004416957129724324
Epoch 14. Training loss: 0.004303831227589399
Epoch 15. Training loss: 0.004198305427562446
Epoch 16. Training loss: 0.004102836740203202
Epoch 17. Training loss: 0.004010234549641609
Epoch 18. Training loss: 0.003925529937259853
Epoch 19. Training loss: 0.0038497459492646156
Epoch 20. Training loss: 0.0037847025645896793
Epoch 21. Training loss: 0.003727935783099383

In [10]:
# 验证
sum_psnr = 0.0
sum_ssim = 0.0
with torch.no_grad():
    for batch in valloader:
        input, target = batch[0].to(device), batch[1].to(device)
        out = model(input)
        loss = criterion(out, target)
        pr = psnr(loss)
        sm = ssim(input, out)
        sum_psnr += pr
        sum_ssim += sm
print(f"Average PSNR: {sum_psnr / len(valloader)} dB.")
print(f"Average SSIM: {sum_ssim / len(valloader)} ")
avg_psnr = sum_psnr / len(valloader)
if avg_psnr >= best_psnr:
    best_psnr = avg_psnr  # 用psnr衡量模型，保存最好的
    torch.save(model, r"best_model_SRCNN_3.pth")

Average PSNR: 25.622924793060207 dB.
Average SSIM: 0.9302493929862976 
