# imgproc

In [1]:
"""Realize the function of processing the dataset before training."""
import math
import random
from typing import Any

import cv2
import numpy as np
import torch
from torchvision.transforms import functional as F

__all__ = [
    "image2tensor", "tensor2image",
    "rgb2ycbcr", "bgr2ycbcr", "ycbcr2bgr", "ycbcr2rgb",
    "center_crop", "random_crop", "random_rotate", "random_horizontally_flip", "random_vertically_flip",
]


def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
    """Turn an image array into a ``torch.Tensor`` using torchvision's ``to_tensor``.

    Args:
        image (np.ndarray): Image data to convert.
        range_norm (bool): When ``True`` apply ``x * 2 - 1`` to the tensor
            (maps [0, 1] data onto [-1, 1]).
        half (bool): When ``True`` cast the result to ``torch.half``.

    Returns:
        torch.Tensor: The converted (and optionally rescaled / half precision) tensor.

    Examples:
        >>> image = cv2.imread("image.bmp", cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        >>> tensor_image = image2tensor(image, range_norm=False, half=False)
    """
    converted = F.to_tensor(image)

    # The rescale is applied in place on the tensor produced by ``to_tensor``.
    if range_norm:
        converted.mul_(2.0).sub_(1.0)

    return converted.half() if half else converted


def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
    """Convert a batched image tensor into a ``uint8`` image array.

    The input tensor is never modified: every step works on a new tensor, so
    the caller can keep using ``tensor`` (same values, same shape) afterwards.

    Args:
        tensor (torch.Tensor): Image tensor; the leading dimension is squeezed
            when it has size 1 and the remaining three dimensions are permuted
            from (C, H, W) to (H, W, C).
        range_norm (bool): When ``True`` apply ``(x + 1) / 2`` first
            (maps [-1, 1] data onto [0, 1]).
        half (bool): When ``True`` cast to ``torch.half`` before scaling.

    Returns:
        np.ndarray: ``uint8`` array of shape (H, W, C) with values in [0, 255].
        Fractional values are truncated by the ``uint8`` cast, not rounded.

    Examples:
        >>> tensor = torch.randn([1, 3, 128, 128])
        >>> image = tensor2image(tensor, range_norm=False, half=False)
    """
    # Detach so tensors that track gradients can be converted to numpy.
    tensor = tensor.detach()

    # Out-of-place arithmetic: the caller's tensor must stay untouched.
    if range_norm:
        tensor = tensor.add(1.0).div(2.0)
    if half:
        tensor = tensor.half()

    image = tensor.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy().astype("uint8")

    return image

def cubic(x: Any):
    """Evaluate the piecewise cubic interpolation kernel (port of Matlab's `cubic`).

    Args:
        x: Tensor of distances at which the kernel is evaluated.

    Returns:
        Tensor of kernel values with the same shape as ``x``; zero wherever ``|x| > 2``.
    """
    magnitude = torch.abs(x)
    squared = magnitude ** 2
    cubed = magnitude ** 3

    # Polynomial used on |x| <= 1 and the one used on 1 < |x| <= 2.
    inner_poly = 1.5 * cubed - 2.5 * squared + 1
    outer_poly = -0.5 * cubed + 2.5 * squared - 4 * magnitude + 2

    # 0/1 masks selecting the interval each polynomial applies to.
    inner_mask = (magnitude <= 1).type_as(magnitude)
    outer_mask = ((magnitude > 1) * (magnitude <= 2)).type_as(magnitude)

    return inner_poly * inner_mask + outer_poly * outer_mask


def calculate_weights_indices(in_length: int, out_length: int, scale: float, kernel_width: int, antialiasing: bool):
    """Implementation of `calculate_weights_indices` function in Matlab under Python language.

    For one image axis, compute which input pixels contribute to every output
    pixel and with which (cubic kernel) weight.

    Args:
        in_length (int): Input length.
        out_length (int): Output length.
        scale (float): Scale factor.
        kernel_width (int): Kernel width.
        antialiasing (bool): Whether to apply antialiasing when down-sampling operations.
            Caution: Bicubic down-sampling in PIL uses antialiasing by default.

    Returns:
        tuple: ``(weights, indices, sym_len_s, sym_len_e)`` where ``weights`` and
        ``indices`` are tensors with one row per output pixel, and the two
        ``sym_len_*`` values are ints derived from the smallest / largest index
        (``1 - indices.min()`` and ``indices.max() - in_length``). The returned
        ``indices`` have been shifted by ``sym_len_s - 1``.
    """

    if (scale < 1) and antialiasing:
        # Use a modified kernel (larger kernel width) to simultaneously
        # interpolate and antialiasing
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Coordinates obtained by mapping every output coordinate through the
    # inverse of the scale factor (with a half-pixel style offset).
    u = x / scale + 0.5 * (1 - 1 / scale)

    # First index of the window considered for each output pixel:
    # half a kernel width to the left of u, rounded down.
    left = torch.floor(u - kernel_width / 2)

    # Number of candidate pixels per output pixel (kernel width plus two
    # spare columns, which may be trimmed again below).
    p = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
        out_length, p)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices

    # apply cubic kernel
    if (scale < 1) and antialiasing:
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, p)

    # If a column in weights is all zero, get rid of it. only consider the
    # first and last column.
    # NOTE(review): weights_zero_tmp counts the zeros per column, so the tests
    # below fire when a column contains ANY zero, not only when it is all zero.
    # NOTE(review): when both tests fire, the second narrow runs on the already
    # trimmed tensor (width p - 2) and therefore keeps every remaining column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, p - 2)
        weights = weights.narrow(1, 1, p - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, p - 2)
        weights = weights.narrow(1, 0, p - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # Amount by which the indices reach below 1 / beyond in_length.
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    # Shift the indices so that the smallest one becomes 0.
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)



def imresize(image: Any, scale_factor: float, antialiasing: bool = True) -> Any:
    """Implementation of `imresize` function in Matlab under Python language.

    The image is resized separably: first along the height axis, then along the
    width axis, each time using the weights / indices produced by
    ``calculate_weights_indices`` with a kernel width of 4.

    Args:
        image: The input image. A numpy array is treated as (h, w) or (h, w, c);
            anything else is handled through tensor methods (``unsqueeze``,
            ``size``) — presumably a ``torch.Tensor`` laid out as (h, w) or
            (c, h, w); confirm against callers.
        scale_factor (float): Scale factor. The same scale applies for both height and width.
        antialiasing (bool): Whether to apply antialiasing when down-sampling operations.
            Caution: Bicubic down-sampling in `PIL` uses antialiasing by default. Default: ``True``.

    Returns:
        The resized image, in the same container kind as the input: a numpy
        array of shape (h, w, c) — or (h, w) for 2-D input — when a numpy array
        was given, otherwise a tensor of shape (c, h, w) — or (h, w) for 2-D
        input. Values are neither rounded nor clamped here.
    """
    squeeze_flag = False
    if type(image).__module__ == np.__name__:  # numpy type
        numpy_type = True
        if image.ndim == 2:
            # Add a channel axis so the rest of the code can assume 3 dimensions.
            image = image[:, :, None]
            squeeze_flag = True
        # (h, w, c) -> (c, h, w) float tensor
        image = torch.from_numpy(image.transpose(2, 0, 1)).float()
    else:
        numpy_type = False
        if image.ndim == 2:
            image = image.unsqueeze(0)
            squeeze_flag = True

    in_c, in_h, in_w = image.size()
    # Output size is rounded up.
    out_h, out_w = math.ceil(in_h * scale_factor), math.ceil(in_w * scale_factor)
    kernel_width = 4

    # get weights and indices
    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale_factor, kernel_width, antialiasing)
    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale_factor, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    # NOTE(review): torch.FloatTensor(...) allocates without initialising; the
    # copy_ calls below are what fill the buffer.
    img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
    img_aug.narrow(1, sym_len_hs, in_h).copy_(image)

    # Top padding: the first sym_len_hs rows, in reversed order.
    sym_patch = image[:, :sym_len_hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)

    # Bottom padding: the last sym_len_he rows, in reversed order.
    # NOTE(review): assumes sym_len_he > 0 — a slice starting at -0 would select
    # the whole axis; confirm this cannot happen for the scales in use.
    sym_patch = image[:, -sym_len_he:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_c, out_h, in_w)
    kernel_width = weights_h.size(1)
    # Each output row is the weighted sum of kernel_width consecutive padded
    # rows, starting at the first index stored for that output row.
    for i in range(out_h):
        idx = int(indices_h[i][0])
        for j in range(in_c):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
    out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)

    # Left padding: the first sym_len_ws columns, in reversed order.
    sym_patch = out_1[:, :, :sym_len_ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)

    # Right padding: the last sym_len_we columns, in reversed order.
    sym_patch = out_1[:, :, -sym_len_we:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_c, out_h, out_w)
    kernel_width = weights_w.size(1)
    # Same weighted sum as above, now over columns.
    for i in range(out_w):
        idx = int(indices_w[i][0])
        for j in range(in_c):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])

    # Undo the layout changes made at the top so output mirrors the input kind.
    if squeeze_flag:
        out_2 = out_2.squeeze(0)
    if numpy_type:
        out_2 = out_2.numpy()
        if not squeeze_flag:
            # (c, h, w) -> (h, w, c)
            out_2 = out_2.transpose(1, 2, 0)

    return out_2

def rgb2ycbcr(image: np.ndarray, use_y_channel: bool = False) -> np.ndarray:
    """Convert an RGB image to YCbCr (port of Matlab's `rgb2ycbcr`).

    Args:
        image (np.ndarray): RGB image with the colour channels on the last axis.
        use_y_channel (bool): Return only the Y (luma) channel. Default: ``False``.

    Returns:
        np.ndarray: ``float32`` array; the last axis is dropped when only Y is requested.
    """
    if use_y_channel:
        converted = np.dot(image, [65.481, 128.553, 24.966]) + 16.0
    else:
        # Columns of the matrix produce Y, Cb and Cr respectively.
        converted = np.matmul(image, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]

    # Bring the result back to the unit range.
    converted = converted / 255.
    return converted.astype(np.float32)


def bgr2ycbcr(image: np.ndarray, use_y_channel: bool = False) -> np.ndarray:
    """Convert a BGR image to YCbCr (BGR counterpart of Matlab's `rgb2ycbcr`).

    Args:
        image (np.ndarray): BGR image with the colour channels on the last axis.
        use_y_channel (bool): Return only the Y (luma) channel. Default: ``False``.

    Returns:
        np.ndarray: ``float32`` array; the last axis is dropped when only Y is requested.
    """
    if use_y_channel:
        converted = np.dot(image, [24.966, 128.553, 65.481]) + 16.0
    else:
        # Same coefficients as the RGB variant with the first and last rows swapped.
        converted = np.matmul(image, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]

    # Bring the result back to the unit range.
    converted = converted / 255.
    return converted.astype(np.float32)

def ycbcr2rgb(image: np.ndarray) -> np.ndarray:
    """Implementation of ycbcr2rgb function in Matlab under Python language.

    The input array is left untouched; all arithmetic is done on new arrays.

    Args:
        image (np.ndarray): Image input in YCbCr format, channels on the last axis.

    Returns:
        ndarray: RGB image array data, cast back to the dtype of ``image``.
    """

    image_dtype = image.dtype
    # Out-of-place scaling: ``image *= 255.`` would overwrite the caller's data
    # (and fails outright for integer arrays).
    scaled = image * 255.

    rgb = np.matmul(scaled, [[0.00456621, 0.00456621, 0.00456621],
                             [0, -0.00153632, 0.00791071],
                             [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]

    rgb = rgb / 255.
    return rgb.astype(image_dtype)

def ycbcr2bgr(image: np.ndarray) -> np.ndarray:
    """Implementation of ycbcr2bgr function in Matlab under Python language.

    The input array is left untouched; all arithmetic is done on new arrays.

    Args:
        image (np.ndarray): Image input in YCbCr format, channels on the last axis.

    Returns:
        ndarray: BGR image array data, cast back to the dtype of ``image``.
    """

    image_dtype = image.dtype
    # Out-of-place scaling: ``image *= 255.`` would overwrite the caller's data
    # (and fails outright for integer arrays).
    scaled = image * 255.

    bgr = np.matmul(scaled, [[0.00456621, 0.00456621, 0.00456621],
                             [0.00791071, -0.00153632, 0],
                             [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921]

    bgr = bgr / 255.
    return bgr.astype(image_dtype)




In [2]:
import random

import numpy as np
import torch
from torch.backends import cudnn

# Seed every random number generator in use so that runs can be repeated
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Training runs on the first CUDA device
device = torch.device("cuda", 0)
# Let cuDNN pick the fastest kernels; pays off when input sizes stay constant
cudnn.benchmark = True

# Super-resolution magnification factor
upscale_factor = 4
# Which configuration section below is active
mode = "train"
# Name of the experiment, used when building weight / log file locations
exp_name = "fsrcnn_x4"

if mode == "train":
    # Dataset locations
    train_image_dir = "/kaggle/input/df2kdata/DF2K_train_HR"
    valid_image_dir = "/kaggle/input/df2kdata/DF2K_valid_HR"
    test_lr_image_dir = f"/kaggle/input/super-resolution-benchmarks/Set5/Set5/LRbicx{upscale_factor}"
    test_hr_image_dir = "/kaggle/input/super-resolution-benchmarks/Set5/Set5/GTmod12"

    # Patch size and data loading settings
    image_size, batch_size, num_workers = 36, 64, 4

    # Resuming: first epoch to run and checkpoint path (empty means none)
    start_epoch, resume = 0, ""

    # Number of epochs to train for
    epochs = 100

    # SGD optimizer hyper-parameters
    model_lr, model_momentum = 1e-3, 0.9
    model_weight_decay, model_nesterov = 1e-4, False

    # Print frequency setting
    print_frequency = 200
elif mode == "valid":
    # Locations of the test images and of the produced results
    lr_dir = f"data/Set5/LRbicx{upscale_factor}"
    sr_dir = f"results/test/{exp_name}"
    hr_dir = "data/Set5/GTmod12"

    # Weights to evaluate
    model_path = f"results/{exp_name}/best.pth.tar"

# dataset

In [3]:
"""Realize the function of dataset preparation."""
import os
import queue
import threading

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


__all__ = [
    "TrainValidImageDataset", "TestImageDataset",
    "PrefetchGenerator", "PrefetchDataLoader", "CPUPrefetcher", "CUDAPrefetcher",
]


class TrainValidImageDataset(Dataset):
    """Customize the data set loading function and prepare low/high resolution image data in advance.

    Every item is built from one high-resolution file: the low-resolution
    counterpart is generated with ``imresize``, both are cropped (and rotated
    in training mode), and only the Y channel is returned.

    Args:
        image_dir (str): Train/Valid dataset address.
        image_size (int): High resolution image size.
        upscale_factor (int): Image up scale factor.
        mode (str): ``"Train"`` applies data augmentation (random crop and rotation),
            ``"Valid"`` only takes a center crop.
    """

    def __init__(self, image_dir: str, image_size: int, upscale_factor: int, mode: str) -> None:
        super(TrainValidImageDataset, self).__init__()
        # Get all image file names in folder. ``os.listdir`` returns entries in
        # arbitrary order, so sort them to make index -> file stable across runs.
        self.image_file_names = [os.path.join(image_dir, image_file_name) for image_file_name in sorted(os.listdir(image_dir))]
        # Specify the high-resolution image size, with equal length and width
        self.image_size = image_size
        # How many times the high-resolution image is the low-resolution image
        self.upscale_factor = upscale_factor
        # Load training dataset or validation dataset
        self.mode = mode

    def __getitem__(self, batch_index: int) -> dict:
        """Return ``{"lr": tensor, "hr": tensor}`` with the Y channel of one image pair.

        Raises:
            ValueError: If ``mode`` is neither ``"Train"`` nor ``"Valid"``.
            FileNotFoundError: If the image file cannot be read.
        """
        # Reject an unsupported mode before doing any expensive image work
        if self.mode not in ("Train", "Valid"):
            raise ValueError("Unsupported data processing model, please use `Train` or `Valid`.")

        # Read the high-resolution image; ``cv2.imread`` signals failure with ``None``
        image_path = self.image_file_names[batch_index]
        hr_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if hr_image is None:
            raise FileNotFoundError(f"Unable to read image file: {image_path}")
        hr_image = hr_image.astype(np.float32) / 255.
        # Use high-resolution image to make low-resolution image
        lr_image = imresize(hr_image, 1 / self.upscale_factor)

        if self.mode == "Train":
            # Data augment
            lr_image, hr_image = random_crop(lr_image, hr_image, self.image_size, self.upscale_factor)
            lr_image, hr_image = random_rotate(lr_image, hr_image, angles=[0, 90, 180, 270])
        else:
            lr_image, hr_image = center_crop(lr_image, hr_image, self.image_size, self.upscale_factor)

        # Only extract the image data of the Y channel
        lr_y_image = bgr2ycbcr(lr_image, use_y_channel=True)
        hr_y_image = bgr2ycbcr(hr_image, use_y_channel=True)

        # Convert image data into Tensor stream format (PyTorch).
        # Note: The range of input and output is between [0, 1]
        lr_y_tensor = image2tensor(lr_y_image, range_norm=False, half=False)
        hr_y_tensor = image2tensor(hr_y_image, range_norm=False, half=False)

        return {"lr": lr_y_tensor, "hr": hr_y_tensor}

    def __len__(self) -> int:
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_file_names)


class TestImageDataset(Dataset):
    """Define Test dataset loading methods.

    Low- and high-resolution files are paired by position in the two sorted
    directory listings, so matching images must sort to the same position in
    both folders.

    Args:
        test_lr_image_dir (str): Test dataset address for low resolution image dir.
        test_hr_image_dir (str): Test dataset address for high resolution image dir.
        upscale_factor (int): Image up scale factor.
    """

    def __init__(self, test_lr_image_dir: str, test_hr_image_dir: str, upscale_factor: int) -> None:
        super(TestImageDataset, self).__init__()
        # Get all image file names in folder. ``os.listdir`` returns entries in
        # arbitrary order; sorting both listings keeps LR/HR pairs aligned.
        self.lr_image_file_names = [os.path.join(test_lr_image_dir, x) for x in sorted(os.listdir(test_lr_image_dir))]
        self.hr_image_file_names = [os.path.join(test_hr_image_dir, x) for x in sorted(os.listdir(test_hr_image_dir))]
        # How many times the high-resolution image is the low-resolution image
        self.upscale_factor = upscale_factor

    @staticmethod
    def _read_image(image_path: str) -> np.ndarray:
        """Read one image as float32 scaled by 1/255, failing loudly if unreadable."""
        image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        # ``cv2.imread`` signals failure with ``None`` instead of raising
        if image is None:
            raise FileNotFoundError(f"Unable to read image file: {image_path}")
        return image.astype(np.float32) / 255.

    def __getitem__(self, batch_index: int) -> dict:
        """Return ``{"lr": tensor, "hr": tensor}`` with the Y channel of one image pair.

        Raises:
            FileNotFoundError: If either image file cannot be read.
        """
        # Read the image pair
        lr_image = self._read_image(self.lr_image_file_names[batch_index])
        hr_image = self._read_image(self.hr_image_file_names[batch_index])

        # Only extract the image data of the Y channel
        lr_y_image = bgr2ycbcr(lr_image, use_y_channel=True)
        hr_y_image = bgr2ycbcr(hr_image, use_y_channel=True)

        # Convert image data into Tensor stream format (PyTorch).
        # Note: The range of input and output is between [0, 1]
        lr_y_tensor = image2tensor(lr_y_image, range_norm=False, half=False)
        hr_y_tensor = image2tensor(hr_y_image, range_norm=False, half=False)

        return {"lr": lr_y_tensor, "hr": hr_y_tensor}

    def __len__(self) -> int:
        """Return the number of low-resolution image files."""
        return len(self.lr_image_file_names)


class PrefetchGenerator(threading.Thread):
    """A fast data prefetch generator.

    A daemon thread pulls items from ``generator`` into a bounded queue while
    the consumer iterates over this object. ``None`` is used as the
    end-of-data marker in the queue.

    If the wrapped generator raises, the exception is re-raised in the
    consuming thread by ``__next__`` instead of leaving the consumer blocked
    on an empty queue.

    Args:
        generator: Data generator.
        num_data_prefetch_queue (int): How many early data load queues.
    """

    def __init__(self, generator, num_data_prefetch_queue: int) -> None:
        threading.Thread.__init__(self)
        self.queue = queue.Queue(num_data_prefetch_queue)
        self.generator = generator
        self.daemon = True
        # Exception raised by the worker, handed over to the consumer.
        self._prefetch_error = None
        # Set once the end marker has been consumed.
        self._prefetch_done = False
        # State must be fully set up before the worker thread starts.
        self.start()

    def run(self) -> None:
        """Worker body: fill the queue, then always enqueue the end marker."""
        try:
            for item in self.generator:
                self.queue.put(item)
        except Exception as error:
            # Stored before the end marker is queued, so the consumer sees it.
            self._prefetch_error = error
        finally:
            self.queue.put(None)

    def __next__(self):
        # Keep raising StopIteration after exhaustion instead of blocking on
        # a queue that will never receive another item.
        if self._prefetch_done:
            raise StopIteration
        next_item = self.queue.get()
        if next_item is None:
            self._prefetch_done = True
            if self._prefetch_error is not None:
                error, self._prefetch_error = self._prefetch_error, None
                raise error
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class PrefetchDataLoader(DataLoader):
    """``DataLoader`` whose iterator is wrapped in a ``PrefetchGenerator``.

    Args:
        num_data_prefetch_queue (int): Size handed to the prefetch generator's queue.
        kwargs (dict): Keyword arguments forwarded unchanged to ``DataLoader``.
    """

    def __init__(self, num_data_prefetch_queue: int, **kwargs) -> None:
        # Stored first, then the regular DataLoader setup runs.
        self.num_data_prefetch_queue = num_data_prefetch_queue
        super().__init__(**kwargs)

    def __iter__(self):
        # Wrap the stock DataLoader iterator so batches are fetched by a background thread.
        base_iterator = super().__iter__()
        return PrefetchGenerator(base_iterator, self.num_data_prefetch_queue)




In [4]:
def center_crop(lr_image: np.ndarray, hr_image: np.ndarray, hr_image_size: int, upscale_factor: int) -> [np.ndarray, np.ndarray]:
    """Cut matching square patches out of the middle of an LR/HR image pair.

    Args:
        lr_image (np.ndarray): Low-resolution image, rows and columns on the first two axes.
        hr_image (np.ndarray): High-resolution image, rows and columns on the first two axes.
        hr_image_size (int): Side length of the high-resolution patch.
        upscale_factor (int): Image up scale factor.

    Returns:
        tuple: ``(lr_patch, hr_patch)``.
    """
    height, width = hr_image.shape[:2]

    # Top-left corner of the HR patch
    top_hr = (height - hr_image_size) // 2
    left_hr = (width - hr_image_size) // 2

    # Matching corner and side length in LR coordinates
    top_lr, left_lr = top_hr // upscale_factor, left_hr // upscale_factor
    side_lr = hr_image_size // upscale_factor

    rows_lr, cols_lr = slice(top_lr, top_lr + side_lr), slice(left_lr, left_lr + side_lr)
    rows_hr, cols_hr = slice(top_hr, top_hr + hr_image_size), slice(left_hr, left_hr + hr_image_size)

    return lr_image[rows_lr, cols_lr, ...], hr_image[rows_hr, cols_hr, ...]


def random_crop(lr_image: np.ndarray, hr_image: np.ndarray, hr_image_size: int, upscale_factor: int) -> [np.ndarray, np.ndarray]:
    """Cut matching square patches at a random position out of an LR/HR image pair.

    Args:
        lr_image (np.ndarray): Low-resolution image, rows and columns on the first two axes.
        hr_image (np.ndarray): High-resolution image, rows and columns on the first two axes.
        hr_image_size (int): Side length of the high-resolution patch.
        upscale_factor (int): Image up scale factor.

    Returns:
        tuple: ``(lr_patch, hr_patch)``.
    """
    height, width = hr_image.shape[:2]

    # Draw the top-left corner of the HR patch (row first, then column)
    top_hr = random.randint(0, height - hr_image_size)
    left_hr = random.randint(0, width - hr_image_size)

    # Matching corner and side length in LR coordinates
    top_lr, left_lr = top_hr // upscale_factor, left_hr // upscale_factor
    side_lr = hr_image_size // upscale_factor

    rows_lr, cols_lr = slice(top_lr, top_lr + side_lr), slice(left_lr, left_lr + side_lr)
    rows_hr, cols_hr = slice(top_hr, top_hr + hr_image_size), slice(left_hr, left_hr + hr_image_size)

    return lr_image[rows_lr, cols_lr, ...], hr_image[rows_hr, cols_hr, ...]


def random_rotate(lr_image: np.ndarray, hr_image: np.ndarray, angles: list, lr_center=None, hr_center=None, scale_factor: float = 1.0) -> [np.ndarray, np.ndarray]:
    """Rotate an LR/HR image pair by one angle picked at random from ``angles``.

    Args:
        lr_image (np.ndarray): Low-resolution image.
        hr_image (np.ndarray): High-resolution image.
        angles (list): Candidate rotation angles; one is drawn with ``random.choice``.
        lr_center (tuple[int]): Rotation center of the low-resolution image.
            ``None`` selects ``(width // 2, height // 2)``. ``Default: None``.
        hr_center (tuple[int]): Rotation center of the high-resolution image.
            ``None`` selects ``(width // 2, height // 2)``. ``Default: None``.
        scale_factor (float): Scale passed to ``cv2.getRotationMatrix2D``. Default: 1.0.

    Returns:
        tuple: ``(rotated_lr_image, rotated_hr_image)``, each with the same
        width and height as its input.
    """
    lr_height, lr_width = lr_image.shape[:2]
    hr_height, hr_width = hr_image.shape[:2]

    lr_pivot = (lr_width // 2, lr_height // 2) if lr_center is None else lr_center
    hr_pivot = (hr_width // 2, hr_height // 2) if hr_center is None else hr_center

    # The same angle is applied to both images
    chosen_angle = random.choice(angles)

    jobs = (
        (lr_image, lr_pivot, (lr_width, lr_height)),
        (hr_image, hr_pivot, (hr_width, hr_height)),
    )
    rotated = []
    for picture, pivot, output_size in jobs:
        matrix = cv2.getRotationMatrix2D(pivot, chosen_angle, scale_factor)
        rotated.append(cv2.warpAffine(picture, matrix, output_size))

    return rotated[0], rotated[1]


def random_horizontally_flip(lr_image: np.ndarray, hr_image: np.ndarray, p=0.5) -> [np.ndarray, np.ndarray]:
    """Mirror an LR/HR image pair left-to-right with probability ``p``.

    Args:
        lr_image (np.ndarray): Low-resolution image.
        hr_image (np.ndarray): High-resolution image.
        p (optional, float): Probability of flipping. (Default: 0.5)

    Returns:
        tuple: ``(lr_image, hr_image)``; the inputs themselves when no flip happens.
    """
    # One random draw decides for both images
    if not random.random() < p:
        return lr_image, hr_image

    return cv2.flip(lr_image, 1), cv2.flip(hr_image, 1)


def random_vertically_flip(lr_image: np.ndarray, hr_image: np.ndarray, p=0.5) -> [np.ndarray, np.ndarray]:
    """Mirror an LR/HR image pair top-to-bottom with probability ``p``.

    Args:
        lr_image (np.ndarray): Low-resolution image.
        hr_image (np.ndarray): High-resolution image.
        p (optional, float): Probability of flipping. (Default: 0.5)

    Returns:
        tuple: ``(lr_image, hr_image)``; the inputs themselves when no flip happens.
    """
    # One random draw decides for both images
    if not random.random() < p:
        return lr_image, hr_image

    return cv2.flip(lr_image, 0), cv2.flip(hr_image, 0)

In [5]:
class CPUPrefetcher:
    """Thin iterator wrapper around a data loader that can be restarted.

    Args:
        dataloader (DataLoader): Any iterable with a length; it is re-iterated on ``reset``.
    """

    def __init__(self, dataloader) -> None:
        self.original_dataloader = dataloader
        self.reset()

    def next(self):
        """Return the next batch, or ``None`` once the loader is exhausted."""
        return next(self.data, None)

    def reset(self):
        """Start iterating the wrapped loader from the beginning again."""
        self.data = iter(self.original_dataloader)

    def __len__(self) -> int:
        return len(self.original_dataloader)


class CUDAPrefetcher:
    """Fetch the next batch and copy its tensors to ``device`` on a separate CUDA stream.

    Args:
        dataloader (DataLoader): Loader yielding dict batches; tensor values are moved to ``device``.
        device (torch.device): Specify running device.
    """

    def __init__(self, dataloader, device: torch.device):
        self.batch_data = None
        self.original_dataloader = dataloader
        self.device = device

        self.data = iter(dataloader)
        self.stream = torch.cuda.Stream()
        # Have the first batch ready before the first call to next()
        self.preload()

    def preload(self):
        """Pull one batch from the loader and queue its host-to-device copies."""
        try:
            fetched = next(self.data)
        except StopIteration:
            # Loader exhausted: next() will hand out None from now on
            self.batch_data = None
            return None
        self.batch_data = fetched

        with torch.cuda.stream(self.stream):
            for key, value in self.batch_data.items():
                if not torch.is_tensor(value):
                    continue
                self.batch_data[key] = self.batch_data[key].to(self.device, non_blocking=True)

    def next(self):
        """Return the preloaded batch (``None`` when exhausted) and start loading the following one."""
        # Make the current stream wait for the copies issued on the side stream
        torch.cuda.current_stream().wait_stream(self.stream)
        ready_batch = self.batch_data
        self.preload()
        return ready_batch

    def reset(self):
        """Restart iteration over the wrapped loader and preload its first batch."""
        self.data = iter(self.original_dataloader)
        self.preload()

    def __len__(self) -> int:
        return len(self.original_dataloader)

# model

In [6]:
"""Realize the model definition function."""
from math import sqrt

import torch
from torch import nn


class FSRCNN(nn.Module):
    """FSRCNN network: feature extraction, shrinking, mapping, expanding and a
    transposed convolution that enlarges the single-channel input by ``upscale_factor``.

    Args:
        upscale_factor (int): Image magnification factor.
    """

    def __init__(self, upscale_factor: int) -> None:
        super().__init__()

        # Feature extraction: 1 -> 56 channels, 5x5 kernel.
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(1, 56, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)),
            nn.PReLU(56),
        )

        # Shrinking: 56 -> 12 channels, 1x1 kernel.
        self.shrink = nn.Sequential(
            nn.Conv2d(56, 12, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            nn.PReLU(12),
        )

        # Mapping: four identical 3x3 convolution + PReLU pairs on 12 channels.
        mapping_layers = []
        for _ in range(4):
            mapping_layers.append(nn.Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
            mapping_layers.append(nn.PReLU(12))
        self.map = nn.Sequential(*mapping_layers)

        # Expanding: 12 -> 56 channels, 1x1 kernel.
        self.expand = nn.Sequential(
            nn.Conv2d(12, 56, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            nn.PReLU(56),
        )

        # Deconvolution: 56 -> 1 channel, 9x9 kernel, stride equal to the upscale factor.
        stride = (upscale_factor, upscale_factor)
        output_padding = (upscale_factor - 1, upscale_factor - 1)
        self.deconv = nn.ConvTranspose2d(56, 1, kernel_size=(9, 9), stride=stride, padding=(4, 4), output_padding=output_padding)

        # Replace the default initialisation with the custom scheme below.
        self._initialize_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        stages = (self.feature_extraction, self.shrink, self.map, self.expand, self.deconv)
        out = x
        for stage in stages:
            out = stage(out)

        return out

    # Conv2d weights: zero-mean normal with std sqrt(2 / (out_channels * kernel elements)), zero bias.
    # Deconvolution weights: zero-mean normal with std 0.001, zero bias.
    def _initialize_weights(self) -> None:
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            fan = module.out_channels * module.weight.data[0][0].numel()
            nn.init.normal_(module.weight.data, mean=0.0, std=sqrt(2 / fan))
            nn.init.zeros_(module.bias.data)

        nn.init.normal_(self.deconv.weight.data, mean=0.0, std=0.001)
        nn.init.zeros_(self.deconv.bias.data)

# evaluation metrics

In [7]:
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
from torchmetrics.functional import structural_similarity_index_measure as ssim

class EvaluationMetrics:
    """Compute PSNR and SSIM between super-resolved and high-resolution images."""

    def __init__(self, device=None):
        """Initialize the Image Quality Metrics calculator.

        Args:
            device: Device used for tensors created by this object. When omitted,
                CUDA is used if available, otherwise the CPU.
        """
        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def load_image(self, image_path):
        """Load an image from a file path."""
        image = Image.open(image_path)
        return image

    def image_to_tensor(self, image):
        """Convert an image to a PyTorch tensor with a leading batch dimension."""
        # Fixed: the module is imported as ``transforms``; ``transformer`` was an undefined name.
        transform = transforms.Compose([transforms.ToTensor()])
        image_tensor = transform(image).unsqueeze(0).to(self.device)  # Add batch dimension
        return image_tensor

    @staticmethod
    def normalize_to_01(tensor):
        """Clamp the input tensor to the range [0, 1]."""
        return torch.clamp(tensor, 0, 1)

    def define_loss(self):
        """Define loss functions and store them as instance attributes."""
        self.psnr_criterion = nn.MSELoss().to(self.device)  # Stored as an instance attribute
        self.pixel_criterion = nn.MSELoss().to(self.device)  # Stored as an instance attribute
        print("Loss functions defined successfully.")

    def calculate_psnr(self, sr_image: torch.Tensor, hr_image: torch.Tensor) -> torch.Tensor:
        """
        Calculate the average Peak Signal-to-Noise Ratio (PSNR) for a batch of images.

        Both inputs are clamped to [0, 1] first, so the peak value is 1.0.
        Identical images give an MSE of 0 and therefore an infinite PSNR.

        Args:
            sr_image (torch.Tensor): Super-resolved image tensor of shape (N, C, H, W).
            hr_image (torch.Tensor): High-resolution image tensor of shape (N, C, H, W).

        Returns:
            torch.Tensor: A single scalar tensor containing the average PSNR for the batch.
        """
        sr_image = self.normalize_to_01(sr_image)
        hr_image = self.normalize_to_01(hr_image)
        mse = torch.mean((sr_image - hr_image) ** 2, dim=(1, 2, 3))  # Compute MSE for each image
        psnr = 10 * torch.log10(1.0 / mse)  # Compute PSNR for each image
        return psnr.mean()  # Return the average PSNR for the batch

    def calculate_ssim(self, sr_image: torch.Tensor, hr_image: torch.Tensor) -> torch.Tensor:
        """
        Calculate the average Structural Similarity Index Measure (SSIM) for a batch of images.

        Args:
            sr_image (torch.Tensor): Super-resolved image tensor of shape (N, C, H, W).
            hr_image (torch.Tensor): High-resolution image tensor of shape (N, C, H, W).

        Returns:
            torch.Tensor: A single scalar tensor containing the average SSIM for the batch.
        """
        # Normalize images to [0, 1]
        sr_image = self.normalize_to_01(sr_image)
        hr_image = self.normalize_to_01(hr_image)

        # Compute SSIM one image pair at a time
        ssim_values = [
            ssim(sr.unsqueeze(0), hr.unsqueeze(0), data_range=1.0)
            for sr, hr in zip(sr_image, hr_image)
        ]

        # Convert to tensor and calculate mean
        return torch.tensor(ssim_values, device=self.device).mean()

    def evaluate(self, sr_image_path, hr_image_path, hr_grayscale=False):
        """Evaluate PSNR and SSIM for a pair of SR and HR images.

        Args:
            sr_image_path: Path of the super-resolved image.
            hr_image_path: Path of the high-resolution reference image.
            hr_grayscale (bool): Convert the HR image to single-channel ('L') before comparing.

        Returns:
            dict: ``{'psnr': ..., 'ssim': ...}`` holding scalar tensors.
        """
        # Load images
        sr_image = self.load_image(sr_image_path)
        hr_image = self.load_image(hr_image_path)

        # Convert HR to grayscale if needed
        if hr_grayscale:
            hr_image = hr_image.convert('L')

        # Convert to tensors
        sr_image_tensor = self.image_to_tensor(sr_image)
        hr_image_tensor = self.image_to_tensor(hr_image)

        # Ensure SR and HR have the same shape
        assert sr_image_tensor.shape == hr_image_tensor.shape, "SR and HR images must have the same dimensions!"

        # Calculate metrics
        psnr_value = self.calculate_psnr(sr_image_tensor, hr_image_tensor)
        ssim_value = self.calculate_ssim(sr_image_tensor, hr_image_tensor)

        return {
            'psnr': psnr_value,
            'ssim': ssim_value
        }
'''
# Example usage
if __name__ == "__main__":
    # Paths to images
    sr_image_path = 'D:/HUST/_Intro to DL/FSRCNN-PyTorch/results/fsrcnn_x4/fsrcnn_x4/img_001_x4.png'
    hr_image_path = 'D:/HUST/_Intro to DL/FSRCNN-PyTorch/data/Set5/GTmod12/img_001.png'

    # Initialize the metrics calculator
    metrics_calculator = EvaluationMetrics()

    # Evaluate metrics
    results = metrics_calculator.evaluate(sr_image_path, hr_image_path, hr_grayscale=True)

    # Print results
    print(f"PSNR: {results['psnr']} dB")
    print(f"SSIM: {results['ssim']}")
'''

'\n# Example usage\nif __name__ == "__main__":\n    # Paths to images\n    sr_image_path = \'D:/HUST/_Intro to DL/FSRCNN-PyTorch/results/fsrcnn_x4/fsrcnn_x4/img_001_x4.png\'\n    hr_image_path = \'D:/HUST/_Intro to DL/FSRCNN-PyTorch/data/Set5/GTmod12/img_001.png\'\n\n    # Initialize the metrics calculator\n    metrics_calculator = EvaluationMetrics()\n\n    # Evaluate metrics\n    results = metrics_calculator.evaluate(sr_image_path, hr_image_path, hr_grayscale=True)\n\n    # Print results\n    print(f"PSNR: {results[\'psnr\']} dB")\n    print(f"SSIM: {results[\'ssim\']}")\n'

# train

In [8]:
"""File description: Realize the model training function."""
import os
import shutil
import time
from enum import Enum

import torch
from torch import nn
from torch import optim
from torch.cuda import amp
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import wandb

def main():
    """Run the full FSRCNN training loop.

    Logs in to Weights & Biases, builds the data prefetchers, model, metrics
    helper and optimizer, optionally restores a checkpoint given by ``resume``,
    then for every epoch trains, evaluates on the validation and test sets,
    and writes a checkpoint. The checkpoint with the highest test PSNR is
    copied to ``results/<exp_name>/best.pth.tar`` and the final epoch to
    ``results/<exp_name>/last.pth.tar``.

    Configuration is read from module-level globals (``epochs``, ``model_lr``,
    ``exp_name``, ``device`` ...), which must be defined before calling.
    """
    global train_image_dir, valid_image_dir, test_lr_image_dir, test_hr_image_dir
    global image_size, batch_size, num_workers
    global start_epoch, resume, epochs
    global model_lr, model_momentum, model_weight_decay, model_nesterov
    global print_frequency

    # Initialize WandB project.
    # SECURITY: the API key must never be committed to source. It is read
    # from the WANDB_API_KEY environment variable; when that is unset,
    # wandb.login() falls back to already-configured credentials (e.g. netrc).
    # NOTE(review): the key that was previously hard-coded here is exposed in
    # history and should be revoked.
    wandb_api_key = os.environ.get("WANDB_API_KEY")
    if wandb_api_key:
        wandb.login(key=wandb_api_key)
    else:
        wandb.login()
    wandb.init(
        project="FSRCNN-Training",  # Project name
        name=f"FSRCNN_{epochs}_epochs",  # Run name
        config={  # Hyper-parameters recorded with the run
            "learning_rate": model_lr,
            "momentum": model_momentum,
            "weight_decay": model_weight_decay,
            "batch_size": batch_size,
            "epochs": epochs,
            "image_size": image_size,
        },
        settings=wandb.Settings(init_timeout=300)
    )

    # Initialize training to generate network evaluation indicators
    best_psnr = 0.0

    train_prefetcher, valid_prefetcher, test_prefetcher = load_dataset()
    print("Load train dataset and valid dataset successfully.")

    model = build_model()
    print("Build FSRCNN model successfully.")

    metrics = EvaluationMetrics(device)

    optimizer = define_optimizer(model)
    print("Define all optimizer functions successfully.")

    print("Check whether the pretrained model is restored...")
    if resume:
        # Load checkpoint model (tensors are kept on the storage they were saved from)
        checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
        # Restore the parameters in the training node to this point
        start_epoch = checkpoint["epoch"]
        best_psnr = checkpoint["best_psnr"]
        # Only keep checkpoint entries whose keys exist in the current model
        model_state_dict = model.state_dict()
        new_state_dict = {k: v for k, v in checkpoint["state_dict"].items() if k in model_state_dict}
        # Overwrite the pretrained model weights to the current model
        model_state_dict.update(new_state_dict)
        model.load_state_dict(model_state_dict)
        # Load the optimizer state
        optimizer.load_state_dict(checkpoint["optimizer"])
        print("Loaded pretrained model weights.")

    # Create a folder of super-resolution experiment results.
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    samples_dir = os.path.join("samples", exp_name)
    results_dir = os.path.join("results", exp_name)
    os.makedirs(samples_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)

    # Create training process log file
    writer = SummaryWriter(os.path.join("samples", "logs", exp_name))

    # Initialize the gradient scaler
    scaler = amp.GradScaler()

    try:
        for epoch in range(start_epoch, epochs):
            train(model, train_prefetcher, metrics, optimizer, epoch, scaler, writer)
            # The validation PSNR is only logged; model selection uses the test PSNR.
            _ = validate(model, valid_prefetcher, metrics, epoch, writer, mode="Valid")
            psnr = validate(model, test_prefetcher, metrics, epoch, writer, "Test")
            print("\n")

            # Automatically save the model with the highest index
            is_best = psnr > best_psnr
            best_psnr = max(psnr, best_psnr)
            checkpoint_path = os.path.join(samples_dir, f"epoch_{epoch + 1}.pth.tar")
            torch.save({"epoch": epoch + 1,
                        "best_psnr": best_psnr,
                        "state_dict": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": None},
                       checkpoint_path)
            if is_best:
                shutil.copyfile(checkpoint_path, os.path.join(results_dir, "best.pth.tar"))
            if (epoch + 1) == epochs:
                shutil.copyfile(checkpoint_path, os.path.join(results_dir, "last.pth.tar"))
    finally:
        # Flush pending TensorBoard events and close the W&B run even when
        # training is interrupted or raises.
        writer.close()
        wandb.finish()

In [9]:
def load_dataset() -> "tuple[CUDAPrefetcher, CUDAPrefetcher, CUDAPrefetcher]":
    """Build the train, validation and test data prefetchers.

    Dataset locations and loader settings are read from module-level globals
    (``train_image_dir``, ``valid_image_dir``, ``test_lr_image_dir``,
    ``test_hr_image_dir``, ``image_size``, ``upscale_factor``, ``batch_size``,
    ``num_workers``, ``device``).

    Returns:
        ``(train_prefetcher, valid_prefetcher, test_prefetcher)``, each a
        ``CUDAPrefetcher`` wrapping the corresponding ``DataLoader``.
    """
    # Load train, test and valid datasets
    train_datasets = TrainValidImageDataset(train_image_dir, image_size, upscale_factor, "Train")
    valid_datasets = TrainValidImageDataset(valid_image_dir, image_size, upscale_factor, "Valid")
    test_datasets = TestImageDataset(test_lr_image_dir, test_hr_image_dir, upscale_factor)

    # DataLoader raises ValueError for persistent_workers=True when
    # num_workers == 0, so only keep workers alive when there are any.
    keep_workers_alive = num_workers > 0

    # Generator all dataloader
    train_dataloader = DataLoader(train_datasets,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  pin_memory=True,
                                  drop_last=True,
                                  persistent_workers=keep_workers_alive)
    valid_dataloader = DataLoader(valid_datasets,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=num_workers,
                                  pin_memory=True,
                                  drop_last=False,
                                  persistent_workers=keep_workers_alive)
    # Test images are fed one at a time.
    test_dataloader = DataLoader(test_datasets,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=1,
                                 pin_memory=True,
                                 drop_last=False,
                                 persistent_workers=False)

    # Place all data on the preprocessing data loader
    train_prefetcher = CUDAPrefetcher(train_dataloader, device)
    valid_prefetcher = CUDAPrefetcher(valid_dataloader, device)
    test_prefetcher = CUDAPrefetcher(test_dataloader, device)

    return train_prefetcher, valid_prefetcher, test_prefetcher


def build_model() -> nn.Module:
    """Instantiate FSRCNN for the configured ``upscale_factor`` and move it to ``device``."""
    return FSRCNN(upscale_factor).to(device)




def define_optimizer(model) -> optim.SGD:
    """Create the SGD optimizer with one parameter group per FSRCNN stage.

    The ``deconv`` stage is trained with a tenth of the base learning rate;
    every other stage uses ``model_lr``. Momentum, weight decay and Nesterov
    settings come from the module-level ``model_*`` globals.
    """
    base_rate_stages = (model.feature_extraction, model.shrink, model.map, model.expand)
    param_groups = [{"params": stage.parameters()} for stage in base_rate_stages]
    param_groups.append({"params": model.deconv.parameters(), "lr": model_lr * 0.1})

    return optim.SGD(param_groups,
                     lr=model_lr,
                     momentum=model_momentum,
                     weight_decay=model_weight_decay,
                     nesterov=model_nesterov)


In [10]:
def train(model, train_prefetcher, metrics, optimizer, epoch, scaler, writer) -> None:
    """Train ``model`` for one epoch with mixed precision.

    Args:
        model: Network being trained; called as ``model(lr)``.
        train_prefetcher: Prefetcher exposing ``reset()``, ``next()`` and
            ``len()``; ``next()`` yields dicts with ``"lr"`` and ``"hr"``
            entries and ``None`` when exhausted.
        metrics: Helper providing ``define_loss()``, ``pixel_criterion``,
            ``calculate_psnr`` and ``calculate_ssim``.
        optimizer: Optimizer stepped through ``scaler``.
        epoch: Zero-based epoch index (displayed and logged as ``epoch + 1``).
        scaler: ``amp.GradScaler`` used for loss scaling.
        writer: TensorBoard ``SummaryWriter``.

    Every ``print_frequency`` batches the current loss, PSNR and SSIM are
    written to TensorBoard and W&B and a progress line is printed.
    """
    batches = len(train_prefetcher)

    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":6.6f")
    psnres = AverageMeter("PSNR", ":4.2f")
    ssimes = AverageMeter("SSIM", ":4.3f")

    progress = ProgressMeter(batches, [batch_time, data_time, losses, psnres, ssimes], prefix=f"Epoch: [{epoch + 1}]")

    model.train()

    # Build the loss once per epoch rather than once per batch: define_loss()
    # takes no per-batch input, so repeating it only re-created the criterion.
    metrics.define_loss()
    pixel_criterion = metrics.pixel_criterion

    batch_index = 0
    end = time.time()
    train_prefetcher.reset()
    batch_data = train_prefetcher.next()

    while batch_data is not None:
        data_time.update(time.time() - end)

        lr = batch_data["lr"].to(device, non_blocking=True)
        hr = batch_data["hr"].to(device, non_blocking=True)

        model.zero_grad()

        with amp.autocast():
            sr = model(lr)
            loss = pixel_criterion(sr, hr)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Metrics are for reporting only: compute them once, without autograd.
        with torch.no_grad():
            psnr = metrics.calculate_psnr(sr, hr)
            ssim = metrics.calculate_ssim(sr, hr)

        # Convert to Python floats once so meters and loggers never hold tensors.
        batch_size_now = lr.size(0)
        loss_value = loss.item()
        psnr_value = psnr.item()
        ssim_value = ssim.item()

        # Update the running averages, weighted by the batch size.
        losses.update(loss_value, batch_size_now)
        psnres.update(psnr_value, batch_size_now)
        ssimes.update(ssim_value, batch_size_now)

        batch_time.update(time.time() - end)
        end = time.time()

        if batch_index % print_frequency == 0:
            # One-based global step across epochs.
            global_step = batch_index + epoch * batches + 1

            # Log to TensorBoard
            writer.add_scalar("Train/Loss", loss_value, global_step)
            writer.add_scalar("Train/PSNR", psnr_value, global_step)
            writer.add_scalar("Train/SSIM", ssim_value, global_step)

            # Log to WandB
            wandb.log({
                "Train/Loss": loss_value,
                "Train/PSNR": psnr_value,
                "Train/SSIM": ssim_value,
                "Batch": batch_index,
                "Epoch": epoch + 1,
            })
            progress.display(batch_index)

        batch_data = train_prefetcher.next()
        batch_index += 1


def _metric_to_float(value) -> float:
    """Reduce a metric (tensor of any shape, or plain number) to a Python float."""
    if torch.is_tensor(value):
        return value.float().mean().item()
    return float(value)


def validate(model, valid_prefetcher, metrics, epoch, writer, mode) -> float:
    """Evaluate ``model`` on one dataset and log the average PSNR / SSIM.

    Args:
        model: Network to evaluate; switched to eval mode.
        valid_prefetcher: Prefetcher exposing ``reset()``, ``next()`` and
            ``len()``; ``next()`` yields dicts with ``"lr"`` and ``"hr"``
            entries and ``None`` when exhausted.
        metrics: Helper providing ``calculate_psnr`` and ``calculate_ssim``.
        epoch: Zero-based epoch index (logged as ``epoch + 1``).
        writer: TensorBoard ``SummaryWriter``.
        mode: ``"Valid"`` or ``"Test"``; selects the logging tag prefix.

    Returns:
        The batch-size-weighted average PSNR over the dataset.

    Raises:
        ValueError: If ``mode`` is neither ``"Valid"`` nor ``"Test"``.
    """
    # Reject a bad mode before spending time on the whole evaluation pass.
    if mode not in ("Valid", "Test"):
        raise ValueError("Unsupported mode, please use `Valid` or `Test`.")

    batch_time = AverageMeter("Time", ":6.3f", Summary.NONE)
    psnres = AverageMeter("PSNR", ":4.2f", Summary.AVERAGE)
    ssimes = AverageMeter("SSIM", ":4.3f", Summary.AVERAGE)

    progress = ProgressMeter(len(valid_prefetcher), [batch_time, psnres, ssimes], prefix=f"{mode}: ")

    model.eval()

    batch_index = 0
    end = time.time()
    valid_prefetcher.reset()
    batch_data = valid_prefetcher.next()

    with torch.no_grad():
        while batch_data is not None:
            lr = batch_data["lr"].to(device, non_blocking=True)
            hr = batch_data["hr"].to(device, non_blocking=True)

            with amp.autocast():
                sr = model(lr)

            # Compute each metric once and store plain floats, so the meters
            # (and everything logged from them) never accumulate tensors.
            psnr_value = _metric_to_float(metrics.calculate_psnr(sr, hr))
            ssim_value = _metric_to_float(metrics.calculate_ssim(sr, hr))
            psnres.update(psnr_value, lr.size(0))
            ssimes.update(ssim_value, lr.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if batch_index % print_frequency == 0:
                progress.display(batch_index)

            batch_data = valid_prefetcher.next()
            batch_index += 1

    progress.display_summary()

    # Log metrics to TensorBoard (mode was validated above, so the tag is
    # always "Valid/..." or "Test/...").
    writer.add_scalar(f"{mode}/PSNR", psnres.avg, epoch + 1)
    writer.add_scalar(f"{mode}/SSIM", ssimes.avg, epoch + 1)

    # Log metrics to WandB
    wandb.log({
        f"{mode}/PSNR": psnres.avg,
        f"{mode}/SSIM": ssimes.avg,
        "Epoch": epoch + 1,
    })

    return psnres.avg


In [11]:
class Summary(Enum):
    """Which aggregate an ``AverageMeter`` reports in its end-of-run summary."""

    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3


class AverageMeter(object):
    """Track the latest value of a quantity together with its running average."""

    def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Clear the latest value, the running sum, the count and the average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest reading, weighted by ``n`` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        # Expands to e.g. "{name} {val:6.3f} ({avg:6.3f})" before being filled in.
        template = f"{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})"
        return template.format(**vars(self))

    def summary(self):
        """Return the summary text selected by ``summary_type``.

        Raises:
            ValueError: If ``summary_type`` is not a ``Summary`` member.
        """
        templates = (
            (Summary.NONE, ""),
            (Summary.AVERAGE, "{name} {avg:.2f}"),
            (Summary.SUM, "{name} {sum:.2f}"),
            (Summary.COUNT, "{name} {count:.2f}"),
        )
        for kind, template in templates:
            if self.summary_type is kind:
                return template.format(**vars(self))

        raise ValueError(f"Invalid summary type {self.summary_type}")


class ProgressMeter(object):
    """Print a set of meters as one progress line per batch, plus a final summary."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the ``[batch/total]`` counter and every meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print("\t".join([header, *map(str, self.meters)]))

    def display_summary(self):
        """Print the summary of every meter on one line, introduced by `` *``."""
        print(" ".join([" *"] + [meter.summary() for meter in self.meters]))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running index to the width of the total, e.g. "[{:2d}/53]".
        width = len(str(num_batches // 1))
        slot = f"{{:{width}d}}"
        return f"[{slot}/{slot.format(num_batches)}]"


# Entry point: start training when this cell/file is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mai-quaqducc[0m ([33mai-quaqducc-0312[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.18.7
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20241214_101719-xz183oqb[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mFSRCNN_100_epochs[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/ai-quaqducc-0312/FSRCNN-Training[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/ai-quaqducc-0312/FSRCNN-Training/runs/xz183oqb[0m
  self.pid = os.fork()


Load train dataset and valid dataset successfully.
Build FSRCNN model successfully.
Define all optimizer functions successfully.
Check whether the pretrained model is restored...


  scaler = amp.GradScaler()


Loss functions defined successfully.


  with amp.autocast():


Epoch: [1][ 0/53]	Time 51.415 (51.415)	Data 48.190 (48.190)	Loss 0.243130 (0.243130)	PSNR 7.85 (7.85)	SSIM 0.002 (0.002)
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined succe

  with amp.autocast():


Test: [0/5]	Time  0.284 ( 0.284)	PSNR 9.52 (9.52)	SSIM 0.246 (0.246)
 *  PSNR 7.21 SSIM 0.15


Loss functions defined successfully.
Epoch: [2][ 0/53]	Time 28.416 (28.416)	Data 28.118 (28.118)	Loss 0.224921 (0.224921)	PSNR 8.59 (8.59)	SSIM 0.127 (0.127)
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss functions defined successfully.
Loss fu

  self.pid = os.fork()


Test: [0/5]	Time  0.076 ( 0.076)	PSNR 20.00 (20.00)	SSIM 0.534 (0.534)
 *  PSNR 19.47 SSIM 0.54




# validate

In [12]:
!pip install natsort

  pid, fd = os.forkpty()


Collecting natsort
  Downloading natsort-8.4.0-py3-none-any.whl.metadata (21 kB)
Downloading natsort-8.4.0-py3-none-any.whl (38 kB)
Installing collected packages: natsort
Successfully installed natsort-8.4.0


In [13]:

"""File description: Realize the verification function after model training."""
import os

import cv2
import numpy as np
import torch
from natsort import natsorted




def main() -> None:
    """Super-resolve every test image with a trained FSRCNN and report mean PSNR.

    For each file name found in ``hr_dir`` the matching LR image is read from
    ``lr_dir``, its Y (luma) channel is upscaled by the model, PSNR is measured
    against the HR Y channel, and the result (SR luma merged with the HR
    chroma channels) is written to ``sr_dir``.

    Paths and settings come from module-level globals (``model_path``,
    ``lr_dir``, ``sr_dir``, ``hr_dir``, ``upscale_factor``, ``device``,
    ``exp_name``).

    Raises:
        FileNotFoundError: If an LR or HR image cannot be read.
    """
    # Initialize the super-resolution model
    model = FSRCNN(upscale_factor).to(device)
    print("Build FSRCNN model successfully.")

    # Load the super-resolution model weights
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint["state_dict"])
    print(f"Load FSRCNN model weights `{os.path.abspath(model_path)}` successfully.")

    # Create a folder of super-resolution experiment results
    results_dir = os.path.join("results", "test", exp_name)
    os.makedirs(results_dir, exist_ok=True)
    # cv2.imwrite fails silently (returns False) when the target directory is
    # missing, so make sure the directory images are written to exists.
    if sr_dir:
        os.makedirs(sr_dir, exist_ok=True)

    # Start the verification mode of the model.
    model.eval()
    # Turn on half-precision inference.
    model.half()

    # Get a list of test image file names.
    file_names = natsorted(os.listdir(hr_dir))
    # Get the number of test image files.
    total_files = len(file_names)
    if total_files == 0:
        # Nothing to evaluate; avoid dividing by zero in the average below.
        print(f"No test images found in `{os.path.abspath(hr_dir)}`.")
        return

    # Initialize the image evaluation index.
    total_psnr = 0.0

    for file_name in file_names:
        lr_image_path = os.path.join(lr_dir, file_name)
        sr_image_path = os.path.join(sr_dir, file_name)
        hr_image_path = os.path.join(hr_dir, file_name)

        print(f"Processing `{os.path.abspath(hr_image_path)}`...")
        # Read LR image and HR image. cv2.imread returns None instead of
        # raising when a file is missing or unreadable.
        lr_image = cv2.imread(lr_image_path)
        if lr_image is None:
            raise FileNotFoundError(f"Cannot read LR image `{lr_image_path}`.")
        hr_image = cv2.imread(hr_image_path)
        if hr_image is None:
            raise FileNotFoundError(f"Cannot read HR image `{hr_image_path}`.")
        lr_image = lr_image.astype(np.float32) / 255.0
        hr_image = hr_image.astype(np.float32) / 255.0

        # Convert BGR image to YCbCr image
        lr_ycbcr_image = bgr2ycbcr(lr_image, use_y_channel=False)
        hr_ycbcr_image = bgr2ycbcr(hr_image, use_y_channel=False)

        # Split YCbCr image data; only the LR luma channel is fed to the model.
        lr_y_image, _, _ = cv2.split(lr_ycbcr_image)
        hr_y_image, hr_cb_image, hr_cr_image = cv2.split(hr_ycbcr_image)

        # Convert Y image data convert to Y tensor data
        lr_y_tensor = image2tensor(lr_y_image, range_norm=False, half=True).to(device).unsqueeze_(0)
        hr_y_tensor = image2tensor(hr_y_image, range_norm=False, half=True).to(device).unsqueeze_(0)

        # Only reconstruct the Y channel image data.
        with torch.no_grad():
            sr_y_tensor = model(lr_y_tensor).clamp_(0, 1.0)

        # Cal PSNR. Reduce in float32 (fp16 loses precision on squared errors)
        # and accumulate a Python float rather than a device tensor.
        # This must happen before tensor2image, which modifies its input in place.
        mse = torch.mean((sr_y_tensor.float() - hr_y_tensor.float()) ** 2)
        total_psnr += (10. * torch.log10(1. / mse)).item()

        # Save image: SR luma merged with the HR chroma channels.
        sr_y_image = tensor2image(sr_y_tensor, range_norm=False, half=True)
        sr_y_image = sr_y_image.astype(np.float32) / 255.0
        sr_ycbcr_image = cv2.merge([sr_y_image, hr_cb_image, hr_cr_image])
        sr_image = ycbcr2bgr(sr_ycbcr_image)
        cv2.imwrite(sr_image_path, sr_image * 255.0)

    print(f"PSNR: {total_psnr / total_files:4.2f}dB.\n")


