In [1]:
import os
import cv2
import torch
from natsort import natsorted
import numpy as np
from numpy import ndarray
from PIL import Image
import math
from skimage.metrics import structural_similarity
from typing import Any
from torch import Tensor
from torch import nn
from torch.nn import functional as F_torch
from torchvision import models
from torchvision import transforms
from torchvision.models.feature_extraction import create_feature_extractor

In [2]:
def make_directory(dir_path: str) -> None:
    """Create ``dir_path`` (and any missing parent directories) if it does not exist.

    ``exist_ok=True`` replaces a separate existence check, so a directory that
    appears between the check and the creation (e.g. another process) is not an
    error. If ``dir_path`` exists but is a regular file, ``os.makedirs`` raises
    ``FileExistsError`` instead of silently continuing.
    """
    os.makedirs(dir_path, exist_ok=True)

def tensor_to_image(tensor:torch.Tensor, range_norm: bool, half: bool) :
    """Convert an image tensor into an HWC ``uint8`` numpy array.

    Args:
        tensor: Image tensor; ``squeeze(0)`` drops a leading size-1 batch axis,
            and the remaining (C, H, W) axes are permuted to (H, W, C).
        range_norm: If True, map values with v -> (v + 1) / 2 first
            (i.e. [-1, 1] to [0, 1]).
        half: If True, cast to ``torch.float16`` before scaling.

    Returns:
        numpy ``uint8`` array of shape (H, W, C). Values are scaled by 255,
        clamped to [0, 255] and truncated (not rounded) by ``astype``.
    """
    prepared = tensor.add(1.0).div(2.0) if range_norm else tensor
    if half:
        prepared = prepared.half()

    hwc = prepared.squeeze(0).permute(1, 2, 0)
    scaled = hwc.mul(255).clamp(0, 255)
    return scaled.cpu().numpy().astype("uint8")
def psnr(img1, img2, data_range: float = 1.0):
    """Peak signal-to-noise ratio (dB) between two tensors.

    Args:
        img1: First tensor.
        img2: Second tensor; must broadcast against ``img1``.
        data_range: Peak-to-peak range of the pixel values. The default 1.0
            matches images scaled to [0, 1]; use 255.0 for 0-255 data.

    Returns:
        0-dim tensor ``10 * log10(data_range**2 / MSE)``, where the MSE is the
        mean over all elements. Identical inputs give ``inf``.
    """
    mse = torch.mean((img1 - img2) ** 2)
    return 10. * torch.log10(data_range ** 2 / mse)
def image_to_tensor(image: ndarray, range_norm: bool, half: bool):
    """Convert an HWC numpy image into a CHW tensor.

    Args:
        image: Array indexed as (H, W, C); ``permute(2, 0, 1)`` moves the last
            axis to the front. Values are NOT rescaled (uint8 input stays 0-255).
        range_norm: If True, map values with v -> 2v - 1 (i.e. [0, 1] to [-1, 1]).
        half: If True, return a ``torch.float16`` tensor instead of float32.

    Returns:
        Tensor of shape (C, H, W); no batch dimension is added.
    """
    chw = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float()
    if range_norm:
        chw = chw.mul(2.0).sub(1.0)
    return chw.half() if half else chw

In [3]:
class SRResNet(nn.Module):
    """SRResNet super-resolution generator.

    Layout (module attribute names are kept stable for checkpoint loading):
      * ``conv1``: 9x9 conv + PReLU feature stem.
      * ``trunk``: ``num_rcb`` residual conv blocks.
      * ``conv2``: 3x3 conv + BatchNorm; its output is added to the stem output.
      * ``upsampling``: sub-pixel upsampling blocks.
      * ``conv3``: 9x9 reconstruction conv; ``forward`` clamps the result to [0, 1].

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output image.
        channels: Feature channels used inside the network.
        num_rcb: Number of residual conv blocks in the trunk.
        upscale_factor: Spatial upscale factor; one of 2, 3, 4 or 8.

    Raises:
        ValueError: If ``upscale_factor`` is not supported.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            channels: int,
            num_rcb: int,
            upscale_factor: int
    ) -> None:
        super(SRResNet, self).__init__()
        if upscale_factor not in (2, 3, 4, 8):
            # Without this check an unsupported factor would build an empty
            # upsampling stack, i.e. a model that does not upscale at all.
            raise ValueError(
                f"upscale_factor must be one of 2, 3, 4 or 8, got {upscale_factor}"
            )

        # Low-frequency feature extraction: 9x9 conv (padding keeps H, W) + PReLU.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, channels, (9, 9), (1, 1), (4, 4)),
            nn.PReLU(),
        )

        # High-frequency feature extraction: a stack of residual conv blocks.
        self.trunk = nn.Sequential(*[_ResidualConvBlock(channels) for _ in range(num_rcb)])

        # Feature fusion: 3x3 conv + BatchNorm (no activation). The result is
        # added to the conv1 output in _forward_impl (global skip connection).
        self.conv2 = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
        )

        # Upsampling: x3 uses a single x3 block; powers of two use log2(factor)
        # x2 blocks. math.log2 is exact for powers of two.
        if upscale_factor == 3:
            upsampling = [_UpsampleBlock(channels, 3)]
        else:
            upsampling = [_UpsampleBlock(channels, 2) for _ in range(int(math.log2(upscale_factor)))]
        self.upsampling = nn.Sequential(*upsampling)

        # Reconstruction: 9x9 conv back to out_channels (clamped in _forward_impl).
        self.conv3 = nn.Conv2d(channels, out_channels, (9, 9), (1, 1), (4, 4))

        # Initialize neural network weights
        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        """Super-resolve ``x``; see ``_forward_impl``."""
        return self._forward_impl(x)

    # Kept separate from forward() to support TorchScript (torch.jit.script).
    def _forward_impl(self, x: Tensor) -> Tensor:
        out1 = self.conv1(x)
        out = self.trunk(out1)
        out2 = self.conv2(out)
        # Global residual connection around the trunk.
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)

        # Restrict the output to the [0, 1] image range.
        out = torch.clamp_(out, 0.0, 1.0)

        return out

    def _initialize_weights(self) -> None:
        """Kaiming-normal conv weights, zero conv biases, BatchNorm weights set to 1."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
class _UpsampleBlock(nn.Module):
    def __init__(self, channels: int, upscale_factor: int) -> None:
        super(_UpsampleBlock, self).__init__()
        self.upsample_block = nn.Sequential(
            nn.Conv2d(channels, channels * upscale_factor * upscale_factor, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(2),
            nn.PReLU(),
        )

    def forward(self, x: Tensor) -> Tensor:
        out = self.upsample_block(x)

        return out
class _ResidualConvBlock(nn.Module):
    def __init__(self, channels: int) -> None:
        super(_ResidualConvBlock, self).__init__()
        self.rcb = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.rcb(x)

        out = torch.add(out, identity)

        return out
def srresnet_x4(**kwargs: Any) -> SRResNet:
    """Build an ``SRResNet`` with a fixed x4 upscale factor.

    The remaining constructor arguments (``in_channels``, ``out_channels``,
    ``channels``, ``num_rcb``) are forwarded unchanged as keyword arguments.
    """
    return SRResNet(upscale_factor=4, **kwargs)



In [14]:
# File extensions treated as images when scanning the input directory.
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")


def _list_image_files(directory: str) -> list:
    """Return naturally-sorted names of the image files directly inside ``directory``.

    Skips sub-directories (e.g. ``.ipynb_checkpoints``) and non-image files,
    which would otherwise make ``Image.open`` fail.
    """
    return natsorted([
        name for name in os.listdir(directory)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(directory, name))
    ])


def _load_rgb_cropped(path: str, scale: int) -> Image.Image:
    """Open ``path`` as RGB, cropped so that width and height are multiples of ``scale``.

    RGB conversion guards against RGBA/grayscale/palette files. Cropping ensures
    the x``scale`` super-resolved output has exactly the ground-truth size.
    """
    with Image.open(path) as img:
        rgb = img.convert("RGB")
    width, height = rgb.size
    return rgb.crop((0, 0, width - width % scale, height - height % scale))


def _imwrite_checked(path: str, bgr_image: ndarray) -> None:
    """Write a BGR uint8 image with OpenCV, raising instead of failing silently."""
    if not cv2.imwrite(path, bgr_image):
        raise OSError(f"cv2.imwrite could not write {path!r}")


def main(
        lr_dir: str = "./data/Set5",
        gt_dir: str = "./data/Set5",
        sr_dir: str = "./srgan_results/",
        g_model_weights_path: str = "./srgan_model/SRGAN_x4-ImageNet-8c4a7569.pth.tar",
) -> None:
    """Run x4 super-resolution on a folder of images and report PSNR / SSIM.

    Each image in ``lr_dir`` is cropped to multiples of 4 and bicubically
    downscaled by 4 to form the low-resolution input. The input is
    super-resolved by the pretrained generator and compared with the
    same-named file in ``gt_dir``. The ground truth, the downscaled input and
    the SR output (annotated with its metrics) are written to ``sr_dir``.

    Args:
        lr_dir: Directory whose images are downscaled to make the LR inputs.
        gt_dir: Directory with the matching high-resolution reference images.
        sr_dir: Output directory (created if missing).
        g_model_weights_path: Checkpoint containing a ``"state_dict"`` entry.
    """
    upscale_factor = 4  # the pretrained checkpoint is an x4 model
    # Prefer the first CUDA GPU, but fall back to CPU so the notebook still runs without one.
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

    g_model = SRResNet(in_channels=3,
                       out_channels=3,
                       channels=64,
                       num_rcb=16,
                       upscale_factor=upscale_factor)
    g_model = g_model.to(device=device)
    # NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
    checkpoint = torch.load(g_model_weights_path, map_location=lambda storage, loc: storage)
    g_model.load_state_dict(checkpoint["state_dict"])
    g_model.eval()

    make_directory(sr_dir)

    file_names = _list_image_files(lr_dir)
    total_files = len(file_names)
    if total_files == 0:
        print(f"No image files found in {lr_dir}")
        return

    psnr_metrics_all = 0.0
    ssim_metrics_all = 0.0
    for file_name in file_names:
        lr_image_path = os.path.join(lr_dir, file_name)
        gt_image_path = os.path.join(gt_dir, file_name)

        # Ground truth: uint8 RGB for saving, float32 in [0, 1] for the metrics.
        gt_uint8 = np.array(_load_rgb_cropped(gt_image_path, upscale_factor))
        gt_image = gt_uint8.astype(np.float32) / 255.0
        gt_tensor = image_to_tensor(gt_image, False, False).unsqueeze_(0)
        _imwrite_checked(os.path.join(sr_dir, f'GroundTruth_{file_name}'),
                         cv2.cvtColor(gt_uint8, cv2.COLOR_RGB2BGR))

        # LR input: bicubic downscale by exactly upscale_factor in both dimensions.
        hr_source = _load_rgb_cropped(lr_image_path, upscale_factor)
        width, height = hr_source.size
        downscale = transforms.Resize((height // upscale_factor, width // upscale_factor),
                                      interpolation=transforms.InterpolationMode.BICUBIC)
        lr_image = np.array(downscale(hr_source))
        _imwrite_checked(os.path.join(sr_dir, f'subsample_4_{file_name}'),
                         cv2.cvtColor(lr_image, cv2.COLOR_RGB2BGR))

        # Add a batch axis, move to the device in channels-last (NHWC) memory
        # layout, and scale the LR pixels from [0, 255] to [0, 1].
        lr_tensor = image_to_tensor(lr_image, False, False).unsqueeze_(0)
        lr_tensor = lr_tensor.to(device=device, memory_format=torch.channels_last, non_blocking=True) / 255.0
        gt_tensor = gt_tensor.to(device=device, memory_format=torch.channels_last, non_blocking=True)

        # Inference only: no autograd graph is needed. All RGB channels are reconstructed.
        with torch.no_grad():
            sr_tensor = g_model(lr_tensor)

        # uint8 HWC RGB copy of the SR result (tensor_to_image casts through fp16 here).
        sr_image = tensor_to_image(sr_tensor, False, True)

        # PSNR is computed on the float SR tensor; SSIM on the quantised uint8
        # SR image. Both are compared against the float GT in [0, 1].
        psnr_metrics = psnr(sr_tensor, gt_tensor).item()
        ssim_metrics = float(structural_similarity(
            sr_image.astype(np.float32) / 255.0, gt_image, win_size=11, gaussian_weights=True,
            channel_axis=-1, data_range=1.0, K1=0.01, K2=0.03, sigma=1.5))
        psnr_metrics_all += psnr_metrics
        ssim_metrics_all += ssim_metrics
        print(file_name, f' psnr:{psnr_metrics}')
        print(file_name, f' ssim:{ssim_metrics}')

        # OpenCV expects BGR; (0, 0, 255) is red in BGR.
        sr_bgr = cv2.cvtColor(sr_image, cv2.COLOR_RGB2BGR)
        text = 'psnr:' + str(round(psnr_metrics, 3)) + ' ssim:' + str(ssim_metrics)
        cv2.putText(sr_bgr, text, (40, 50), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 255), 1)
        _imwrite_checked(os.path.join(sr_dir, f'super_resolution_{file_name}'), sr_bgr)

    # Cap the averages: identical images give an infinite PSNR.
    avg_psnr = min(psnr_metrics_all / total_files, 100.0)
    avg_ssim = min(ssim_metrics_all / total_files, 1.0)

    print(f"PSNR: {avg_psnr:4.2f} [dB]\n"
          f"SSIM: {avg_ssim:4.4f} [u]")


if __name__ == "__main__":
    main()


Processing `c:\Users\DELL\Desktop\计算机视觉实践\计算机视觉实践-练习3\data\Set5\baby.png`...


  downscale = transforms.Resize(int(size / 4), interpolation=Image.BICUBIC)
  ssim_metrics = structural_similarity(sr_image.astype(np.float32) / 255.0, gt_image, win_size=11, gaussian_weights=True,


baby.png  psnr:30.62220573425293
baby.png  ssim:0.8199921250343323
Processing `c:\Users\DELL\Desktop\计算机视觉实践\计算机视觉实践-练习3\data\Set5\bird.png`...
bird.png  psnr:29.83331871032715
bird.png  ssim:0.857977569103241
Processing `c:\Users\DELL\Desktop\计算机视觉实践\计算机视觉实践-练习3\data\Set5\butterfly.png`...
butterfly.png  psnr:25.211177825927734
butterfly.png  ssim:0.8504059314727783
Processing `c:\Users\DELL\Desktop\计算机视觉实践\计算机视觉实践-练习3\data\Set5\head.png`...
head.png  psnr:28.8527774810791
head.png  ssim:0.6636013984680176
Processing `c:\Users\DELL\Desktop\计算机视觉实践\计算机视觉实践-练习3\data\Set5\woman.png`...
woman.png  psnr:27.815959930419922
woman.png  ssim:0.8716802000999451
PSNR: 28.47 [dB]
SSIM: 0.8127 [u]
