# In [None]:
import os
import cv2
import torch
from natsort import natsorted
import numpy as np
from numpy import ndarray
from PIL import Image
import math
from skimage.metrics import structural_similarity
from typing import Any
from torch import Tensor
from torch import nn
from torch.nn import functional as F_torch
from torchvision import models
from torchvision import transforms
from torchvision.models.feature_extraction import create_feature_extractor


def make_directory(dir_path: str) -> None:
    """Create ``dir_path`` (including missing parents) if it does not exist.

    ``exist_ok=True`` is used instead of a separate existence check so that
    a directory created by someone else between "check" and "create" does
    not make this call fail.

    Args:
        dir_path: Path of the directory to create.

    Raises:
        FileExistsError: If ``dir_path`` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)


def tensor_to_image(tensor: torch.Tensor, range_norm: bool, half: bool):
    """Convert an image tensor into a ``uint8`` NumPy array.

    Args:
        tensor: Image tensor. Dimension 0 is squeezed and the remaining
            three dimensions are reordered with ``permute(1, 2, 0)``.
        range_norm: When true, apply ``(x + 1) / 2`` first, which maps
            values from [-1, 1] to [0, 1].
        half: When true, cast the tensor to float16 before scaling.

    Returns:
        The values multiplied by 255, clamped to [0, 255], moved to the CPU
        and cast to ``uint8`` (no rounding step is applied before the cast).
    """
    if range_norm:
        tensor = (tensor + 1.0) / 2.0
    if half:
        tensor = tensor.half()

    # Drop the leading dimension, then move the first axis to the end.
    channels_last = tensor.squeeze(0).permute(1, 2, 0)
    scaled = torch.clamp(channels_last * 255, 0, 255)

    return scaled.cpu().numpy().astype("uint8")


def psnr(img1, img2):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    The peak signal value is fixed at 1, i.e. the result is
    ``10 * log10(1 / MSE)``. Identical inputs give an MSE of zero and
    therefore an infinite result.
    """
    mse = (img1 - img2).pow(2).mean()
    return 10.0 * torch.log10(1.0 / mse)


def image_to_tensor(image: ndarray, range_norm: bool, half: bool):
    """Convert an ``(H, W, C)`` NumPy image into a ``(C, H, W)`` float tensor.

    No value scaling is done here apart from ``range_norm``; in particular
    the pixel values are not divided by 255.

    Args:
        image: Array whose three axes are reordered with ``permute(2, 0, 1)``.
        range_norm: When true, apply ``x * 2 - 1`` (maps [0, 1] to [-1, 1]).
        half: When true, return float16 instead of float32.

    Returns:
        The converted tensor.
    """
    contiguous = np.ascontiguousarray(image)
    tensor = torch.from_numpy(contiguous).permute(2, 0, 1).float()
    if range_norm:
        tensor = tensor * 2.0 - 1.0
    return tensor.half() if half else tensor


class SRResNet(nn.Module):
    """Super-resolution generator: stem conv, residual trunk, upsampling, output conv.

    The forward pass adds the stem features to the trunk output (a long skip
    connection), upsamples, projects to ``out_channels`` and clamps the result
    to [0, 1].

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the produced image.
        channels: Feature width used by the stem, trunk and upsampling stages.
        num_rcb: Number of residual blocks in the trunk.
        upscale_factor: Requested upscaling. 2, 4 and 8 are built from
            ``log2(upscale_factor)`` stages of ``_UpsampleBlock(channels, 2)``;
            3 uses a single ``_UpsampleBlock(channels, 3)``. Any other value
            leaves the upsampling stage empty.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            channels: int,
            num_rcb: int,
            upscale_factor: int
    ) -> None:
        super().__init__()

        # Stem: 9x9 convolution followed by PReLU.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, channels, (9, 9), (1, 1), (4, 4)),
            nn.PReLU(),
        )

        # Trunk: a stack of ``num_rcb`` residual blocks.
        self.trunk = nn.Sequential(
            *[_ResidualConvBlock(channels) for _ in range(num_rcb)]
        )

        # Post-trunk 3x3 convolution with batch normalisation.
        self.conv2 = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
        )

        # Upsampling stages, chosen from the requested factor.
        if upscale_factor in (2, 4, 8):
            num_stages = int(math.log(upscale_factor, 2))
            stages = [_UpsampleBlock(channels, 2) for _ in range(num_stages)]
        elif upscale_factor == 3:
            stages = [_UpsampleBlock(channels, 3)]
        else:
            stages = []
        self.upsampling = nn.Sequential(*stages)

        # Final 9x9 projection to the output channel count.
        self.conv3 = nn.Conv2d(channels, out_channels, (9, 9), (1, 1), (4, 4))

        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        """Run the network on ``x``; delegates to ``_forward_impl``."""
        return self._forward_impl(x)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Compute the super-resolved output, clamped in place to [0, 1]."""
        shallow = self.conv1(x)
        deep = self.conv2(self.trunk(shallow))
        # Long skip connection around the whole trunk.
        fused = shallow + deep
        restored = self.conv3(self.upsampling(fused))

        return restored.clamp_(0.0, 1.0)

    def _initialize_weights(self) -> None:
        """Kaiming-normal init for conv weights, zero conv biases, unit BN weights."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                continue
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)


class _UpsampleBlock(nn.Module):
    """Sub-pixel upsampling stage: 3x3 conv -> PixelShuffle -> PReLU.

    The convolution expands the channel count by ``upscale_factor ** 2`` and
    the pixel shuffle rearranges those extra channels into space, so the
    output has ``channels`` channels and ``upscale_factor`` times the input
    height and width.

    Args:
        channels: Number of input (and output) feature channels.
        upscale_factor: Spatial upscaling factor of this stage.
    """

    def __init__(self, channels: int, upscale_factor: int) -> None:
        super(_UpsampleBlock, self).__init__()
        self.upsample_block = nn.Sequential(
            nn.Conv2d(channels, channels * upscale_factor * upscale_factor, (3, 3), (1, 1), (1, 1)),
            # The shuffle factor must match the channel expansion above;
            # a fixed factor of 2 would only be correct for upscale_factor=2.
            nn.PixelShuffle(upscale_factor),
            nn.PReLU(),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` upsampled by ``upscale_factor`` in both spatial dims."""
        out = self.upsample_block(x)

        return out


class _ResidualConvBlock(nn.Module):
    """Residual block: conv -> BN -> PReLU -> conv -> BN, plus identity skip.

    Args:
        channels: Number of feature channels; the block keeps it unchanged.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        layers = [
            # First 3x3 convolution; bias is omitted because BN follows.
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            # Second 3x3 convolution, again followed by BN.
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
        ]
        self.rcb = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Return the residual branch output added to the unchanged input."""
        residual = self.rcb(x)
        return residual + x


def srresnet_x4(**kwargs: Any) -> SRResNet:
    """Build an ``SRResNet`` with ``upscale_factor`` fixed to 4.

    Args:
        **kwargs: Remaining ``SRResNet`` constructor arguments, forwarded
            unchanged.

    Returns:
        The newly constructed model.
    """
    return SRResNet(upscale_factor=4, **kwargs)


def main() -> None:
    """Run x4 super-resolution over a directory of images and report PSNR/SSIM.

    For every file in ``lr_dir`` the image is loaded as ground truth, a
    bicubically downscaled copy (smaller edge divided by 4) is fed to the
    generator, and the output is compared against the ground truth. Three
    files per input are written to ``sr_dir`` (``GroundTruth_*``,
    ``subsample_4_*`` and ``super_resolution_*``). Per-image and average
    metrics are printed.
    """
    # Local import: only needed to probe the installed scikit-image API.
    import inspect

    # Prefer the first CUDA device, but keep the script usable on CPU-only hosts.
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    # Input directory, output directory and ground-truth directory.
    lr_dir = "./data/Set5"
    sr_dir = "./srgan_results/"
    gt_dir = "./data/Set5"
    # Path of the pretrained generator weights.
    g_model_weights_path = './srgan_model/SRGAN_x4-ImageNet-8c4a7569.pth.tar'

    # Build the generator and move it to the selected device.
    g_model = SRResNet(in_channels=3, upscale_factor=4,
                       out_channels=3,
                       channels=64,
                       num_rcb=16)
    g_model = g_model.to(device=device)

    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(g_model_weights_path,
                            map_location=lambda storage, loc: storage)
    g_model.load_state_dict(checkpoint["state_dict"])

    make_directory(sr_dir)
    g_model.eval()

    # Process the inputs in natural (human) sort order.
    file_names = natsorted(os.listdir(lr_dir))
    total_files = len(file_names)
    if total_files == 0:
        # Nothing to average over; avoid dividing by zero below.
        print(f"No files found in {lr_dir}")
        return

    # scikit-image replaced ``multichannel`` with ``channel_axis``; use whichever
    # keyword the installed version declares.
    if "channel_axis" in inspect.signature(structural_similarity).parameters:
        ssim_channel_kwargs = {"channel_axis": -1}
    else:
        ssim_channel_kwargs = {"multichannel": True}

    psnr_metrics_all = 0.0
    ssim_metrics_all = 0.0
    for file_name in file_names:
        lr_image_path = os.path.join(lr_dir, file_name)
        gt_image_path = os.path.join(gt_dir, file_name)

        # Ground truth: RGB float32 array scaled to [0, 1].
        with Image.open(gt_image_path) as gt_img:
            gt_image = np.array(gt_img.convert("RGB")).astype(np.float32) / 255.0
        # Shape becomes [1, C, H, W].
        gt_tensor = image_to_tensor(gt_image, False, False).unsqueeze_(0)
        # OpenCV expects BGR channel order when writing.
        gt_bgr_image = cv2.cvtColor(gt_image, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(sr_dir, f'GroundTruth_{file_name}'),
                    gt_bgr_image * 255)

        # Low-resolution input: resize so the smaller edge is 1/4 of the original.
        with Image.open(lr_image_path) as lr_img:
            size = min(lr_img.size)
            downscale = transforms.Resize(int(size / 4),
                                          interpolation=transforms.InterpolationMode.BICUBIC)
            lr_image = np.array(downscale(lr_img.convert("RGB")))
        lr_bgr_image = cv2.cvtColor(lr_image, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(sr_dir, f'subsample_4_{file_name}'),
                    lr_bgr_image)

        lr_tensor = image_to_tensor(lr_image, False, False).unsqueeze_(0)
        # The uint8 pixel values are scaled to [0, 1] after the transfer.
        lr_tensor = lr_tensor.to(device=device, memory_format=torch.channels_last, non_blocking=True) / 255.0
        gt_tensor = gt_tensor.to(device=device, memory_format=torch.channels_last, non_blocking=True)

        with torch.no_grad():
            sr_tensor = g_model(lr_tensor)

        # NOTE(review): the metrics below need the output to have the same
        # spatial size as the ground truth — presumably input dimensions are
        # multiples of 4; confirm for new datasets.
        sr_image = tensor_to_image(sr_tensor, False, True)
        psnr_metrics = psnr(sr_tensor,
                            gt_tensor)
        ssim_metrics = structural_similarity(sr_image.astype(np.float32) / 255.0, gt_image, win_size=11,
                                             gaussian_weights=True,
                                             data_range=1.0, K1=0.01, K2=0.03, sigma=1.5,
                                             **ssim_channel_kwargs)
        psnr_metrics_all += psnr_metrics
        ssim_metrics_all += ssim_metrics
        print(file_name, f' psnr:{psnr_metrics}')
        print(file_name, f' ssim:{ssim_metrics}')

        # Stamp the metrics onto the saved super-resolved image.
        sr_image = cv2.cvtColor(sr_image, cv2.COLOR_RGB2BGR)
        text = 'psnr:' + str(round(float(psnr_metrics.cpu()), 3)) + ' ssim:' + str(ssim_metrics)
        cv2.putText(sr_image, text, (40, 50), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 255),
                    1)
        cv2.imwrite(os.path.join(sr_dir, f'super_resolution_{file_name}'), sr_image)

    # Cap the averages at their conventional maxima (100 dB / 1.0).
    mean_psnr = psnr_metrics_all / total_files
    mean_ssim = ssim_metrics_all / total_files
    avg_psnr = 100 if mean_psnr > 100 else mean_psnr
    avg_ssim = 1 if mean_ssim > 1 else mean_ssim

    print(f"PSNR: {avg_psnr:4.2f} [dB]\n"
          f"SSIM: {avg_ssim:4.4f} [u]")


# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
