Data Set Interface(数据处理)

In [None]:
from os import listdir # For file operations
from os.path import join # For catalog operations
from PIL import Image # For image processing
from torch.utils.data.dataset import Dataset # Base class for defining datasets
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize # For image preprocessing

#--- Used to determine if a file name is an image file ---#
def is_image_file(filename):
    """Return True when ``filename`` ends with a supported image extension.

    Supported extensions are ``.png``, ``.jpg`` and ``.jpeg``. The comparison
    is case-insensitive, so mixed-case suffixes such as ``.Png`` or ``.Jpeg``
    are accepted in addition to the all-lower and all-upper spellings.

    Args:
        filename: File name or path as a string.

    Returns:
        bool: True if the suffix is one of the supported image extensions.
    """
    # str.endswith accepts a tuple of candidate suffixes, so no explicit loop is needed.
    return filename.lower().endswith(('.png', '.jpg', '.jpeg'))

#--- Round the requested crop size down to a multiple of the upsampling factor ---#
def calculate_valid_crop_size(crop_size, upscale_factor):
    """Return the largest value not exceeding ``crop_size`` that is a multiple of ``upscale_factor``.

    Args:
        crop_size: Requested crop size.
        upscale_factor: Upsampling factor the crop size must be divisible by.

    Returns:
        ``crop_size`` minus its remainder modulo ``upscale_factor``.
    """
    _, remainder = divmod(crop_size, upscale_factor)
    return crop_size - remainder

#--- Training set high-resolution image preprocessing function ---#
def train_hr_transform(crop_size):
    """Build the transform that produces a high-resolution training sample.

    Args:
        crop_size: Size passed to ``RandomCrop``.

    Returns:
        A ``Compose`` that applies ``RandomCrop(crop_size)`` followed by ``ToTensor()``.
    """
    steps = [
        RandomCrop(crop_size),  # pick a random crop_size patch from the input image
        ToTensor(),             # convert the cropped image to a tensor
    ]
    return Compose(steps)

#--- Training set low-resolution image preprocessing function ---#
def train_lr_transform(crop_size, upscale_factor):
    """Build the transform that derives a low-resolution sample from a high-resolution tensor.

    Args:
        crop_size: Size of the high-resolution crop.
        upscale_factor: Factor by which the crop is shrunk.

    Returns:
        A ``Compose`` that converts the tensor to a PIL image, resizes it to
        ``crop_size // upscale_factor`` with bicubic interpolation, and converts
        the result back to a tensor.
    """
    lr_size = crop_size // upscale_factor
    pipeline = [
        ToPILImage(),                                    # tensor -> PIL image so it can be resized
        Resize(lr_size, interpolation=Image.BICUBIC),    # bicubic downscale
        ToTensor(),                                      # PIL image -> tensor
    ]
    return Compose(pipeline)

#--- Preprocessing operations for displaying images ---#
def display_transform():
    """Build the transform used to prepare image tensors for display.

    Returns:
        A ``Compose`` that converts a tensor to a PIL image, applies
        ``Resize(400)`` and ``CenterCrop(400)``, and converts back to a tensor.
    """
    steps = [
        ToPILImage(),     # tensor -> PIL image
        Resize(400),      # an int size rescales the shorter side to 400 and keeps the aspect ratio
        CenterCrop(400),  # cut the central 400x400 region
        ToTensor(),       # PIL image -> tensor
    ]
    return Compose(steps)

#--- Training dataset class ---#
class TrainDatasetFromFolder(Dataset):
    """Training dataset yielding (low-resolution, high-resolution) tensor pairs.

    Every image file found directly inside ``dataset_dir`` is one sample. The
    high-resolution sample is a random crop of the image; the low-resolution
    sample is that crop downscaled by ``upscale_factor``.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor):
        """Collect the image paths and build the HR/LR transforms.

        Args:
            dataset_dir: Directory containing the training images.
            crop_size: Requested crop size; rounded down to a multiple of ``upscale_factor``.
            upscale_factor: Super-resolution factor between the LR and HR samples.
        """
        super(TrainDatasetFromFolder, self).__init__()
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.image_filenames = [
            join(dataset_dir, x) for x in sorted(listdir(dataset_dir))
            if is_image_file(x)
        ]
        # The crop must be divisible by the upscale factor so the LR size is exact.
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)  # random crop + to tensor
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)  # bicubic downscale of the HR crop

    def __getitem__(self, index):
        """Return ``(lr_image, hr_image)`` tensors for the image at ``index``."""
        # Force 3-channel RGB: grayscale, RGBA or CMYK files would otherwise give
        # tensors with a different channel count, which breaks batching and the
        # 3-channel generator input. The context manager closes the file handle.
        with Image.open(self.image_filenames[index]) as source:
            hr_image = self.hr_transform(source.convert('RGB'))
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.image_filenames)

#--- Validation dataset ---#
class ValDatasetFromFolder(Dataset):
    """Validation dataset yielding (LR, bicubic-restored HR, HR) tensor triples.

    Every image file found directly inside ``dataset_dir`` is one sample. The
    HR sample is the largest centered square crop whose side is a multiple of
    ``upscale_factor``; the LR sample is that crop downscaled bicubically, and
    the restored sample is the LR image upscaled bicubically to the HR size.
    """

    def __init__(self, dataset_dir, upscale_factor):
        """Collect the image paths.

        Args:
            dataset_dir: Directory containing the validation images.
            upscale_factor: Super-resolution factor between the LR and HR samples.
        """
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible.
        self.image_filenames = [
            join(dataset_dir, x) for x in sorted(listdir(dataset_dir))
            if is_image_file(x)
        ]

    def __getitem__(self, index):
        """Return ``(lr, hr_restore, hr)`` tensors for the image at ``index``."""
        # Force 3-channel RGB so every sample matches the 3-channel generator
        # input; the context manager closes the file handle.
        with Image.open(self.image_filenames[index]) as source:
            hr_image = source.convert('RGB')
        w, h = hr_image.size
        # Largest square side that fits in the image and is divisible by the upscale factor.
        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)  # centered square HR crop
        lr_image = lr_scale(hr_image)               # bicubic downscale -> LR input
        hr_restore_img = hr_scale(lr_image)         # bicubic upscale of the LR image
        return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.image_filenames)

#--- Test Data Set ---#
class TestDatasetFromFolder(Dataset):
    """Test dataset yielding (name, LR, bicubic-restored HR, HR) samples.

    Low-resolution inputs are read from ``<dataset_dir>/SRF_<factor>/data/`` and
    high-resolution targets from ``<dataset_dir>/SRF_<factor>/target/``. The two
    lists are paired by position.
    """

    def __init__(self, dataset_dir, upscale_factor):
        """Collect the LR and HR image paths.

        Args:
            dataset_dir: Root directory that contains the ``SRF_<factor>`` folder.
            upscale_factor: Super-resolution factor; selects the ``SRF_<factor>`` folder.
        """
        super(TestDatasetFromFolder, self).__init__()
        self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'    # directory of low-resolution inputs
        self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'  # directory of high-resolution targets
        self.upscale_factor = upscale_factor
        # LR and HR files are paired by index, and os.listdir gives no ordering
        # guarantee, so both listings are sorted to keep the pairs aligned.
        # NOTE(review): this assumes both folders use matching file names - confirm against the dataset layout.
        self.lr_filenames = [
            join(self.lr_path, x) for x in sorted(listdir(self.lr_path))
            if is_image_file(x)
        ]
        self.hr_filenames = [
            join(self.hr_path, x) for x in sorted(listdir(self.hr_path))
            if is_image_file(x)
        ]

    def __getitem__(self, index):
        """Return ``(image_name, lr, hr_restore, hr)`` for the pair at ``index``."""
        image_name = self.lr_filenames[index].split('/')[-1]  # file name without the directory part
        # Force 3-channel RGB so every sample matches the 3-channel generator
        # input; the context managers close the file handles.
        with Image.open(self.lr_filenames[index]) as lr_source:
            lr_image = lr_source.convert('RGB')
        w, h = lr_image.size
        with Image.open(self.hr_filenames[index]) as hr_source:
            hr_image = hr_source.convert('RGB')
        # Bicubic upscale of the LR image; Resize takes the size as (height, width).
        hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=Image.BICUBIC)
        hr_restore_img = hr_scale(lr_image)
        return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        """Return the number of low-resolution image files."""
        return len(self.lr_filenames)

Generator（生成器）

In [None]:
import math # For math calculations
import torch
from torch import nn # Importing the nn (neural network) module

#--- Generator Model ---#
class Generator(nn.Module):
    """Super-resolution generator.

    Layout: a 9x9 convolution head, five residual blocks, a 3x3 convolution with
    batch normalization, a skip connection from the head, then
    ``int(log2(scale_factor))`` pixel-shuffle blocks of 2x each and a final 9x9
    convolution back to 3 channels.
    """

    def __init__(self, scale_factor):
        """Build the network.

        Args:
            scale_factor: Upsampling factor. Each upsample block doubles the
                resolution, so the effective factor is
                ``2 ** int(log2(scale_factor))``; a value that is not a power
                of two is truncated (e.g. 3 yields a single 2x block).
        """
        # math.log2 is used instead of math.log(x, 2): the latter divides two
        # floating-point logarithms, and truncating that quotient with int()
        # is fragile around exact powers of two.
        upsample_block_num = int(math.log2(scale_factor))

        super(Generator, self).__init__()
        # Head: 9x9 convolution, 3 -> 64 channels, followed by PReLU.
        self.block1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=9, padding=4), nn.PReLU())
        # 5 residual blocks operating on 64 channels.
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        # 3x3 convolution + batch normalization applied after the residual blocks.
        self.block7 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64))
        # upsample_block_num upsample blocks, each doubling the resolution.
        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        # Final 9x9 convolution, 64 -> 3 channels.
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        """Upsample the input image batch ``x`` and return values in [0, 1]."""
        block1 = self.block1(x)

        # Residual trunk.
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        # Skip connection: add the head output to the trunk output before upsampling.
        block8 = self.block8(block1 + block7)

        # tanh gives [-1, 1]; (t + 1) / 2 rescales that to [0, 1].
        return (torch.tanh(block8) + 1) / 2

#--- Residual blocks in the generator ---#
# Each block learns a correction that is added back onto its own input
class ResidualBlock(nn.Module):
    """conv -> BN -> PReLU -> conv -> BN with an identity skip connection."""

    def __init__(self, channels):
        """Create the layers; both convolutions keep ``channels`` channels and the spatial size."""
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)  # 3x3 with padding 1 keeps height and width
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``x`` plus the residual computed from ``x``."""
        hidden = self.prelu(self.bn1(self.conv1(x)))
        correction = self.bn2(self.conv2(hidden))
        return x + correction

#--- Upsampling blocks in the generator ---#
# Each block raises the spatial resolution of a feature map by up_scale
class UpsampleBLock(nn.Module):
    """Convolution -> pixel shuffle -> PReLU upsampling block."""

    def __init__(self, in_channels, up_scale):
        """Create the layers.

        Args:
            in_channels: Channel count of the input feature map.
            up_scale: Factor passed to ``nn.PixelShuffle``.
        """
        super().__init__()
        # The convolution expands to in_channels * up_scale**2 channels, which
        # the pixel shuffle then rearranges into space.
        expanded_channels = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Apply convolution, pixel shuffle and PReLU, in that order."""
        return self.prelu(self.pixel_shuffle(self.conv(x)))

Discriminator（判别器）

In [None]:
#--- Discriminator Model ---#
# Scores each input image with a value in (0, 1); training pushes it towards 1 for real images and 0 for generated ones
class Discriminator(nn.Module):
    """Convolutional discriminator that ends in a per-image sigmoid score."""

    def __init__(self):
        """Assemble the layer stack as one ``nn.Sequential`` stored in ``self.net``."""
        super().__init__()
        # 1st convolution: 3 -> 64 channels, no batch normalization.
        layers = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
        ]
        # Convolutions 2-8 as (in_channels, out_channels, stride); each is
        # followed by batch normalization and LeakyReLU(0.2). Stride-2 layers
        # halve the spatial size.
        conv_specs = [
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        ]
        for in_ch, out_ch, stride in conv_specs:
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ]
        # Global average pooling to 1x1, then two 1x1 convolutions acting as
        # fully-connected layers that reduce 512 features to one logit.
        layers += [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a 1-D tensor with one sigmoid score per image in the batch."""
        batch_size = x.size(0)
        logits = self.net(x)
        return torch.sigmoid(logits.view(batch_size))

Define the loss function（定义损失函数）

In [None]:
import torch
from torch import nn # Neural network module
from torchvision.models.vgg import vgg16 # Importing the VGG-16 model
import os # For operating the operating system
os.environ['TORCH_HOME'] = './' # The TORCH_HOME environment variable will be set to the current directory, which will be used to specify the path where PyTorch downloads pre-trained models


#--- Definition of generator loss function ---#
class GeneratorLoss(nn.Module):
    """Weighted sum of image MSE, adversarial, perceptual (VGG-16 feature) and TV losses."""

    def __init__(self):
        """Load the pre-trained VGG-16 and freeze its first 31 feature layers as the loss network."""
        super().__init__()
        backbone = vgg16(pretrained=True)  # pre-trained VGG-16
        feature_layers = list(backbone.features)[:31]  # first 31 layers of the feature extractor
        loss_network = nn.Sequential(*feature_layers).eval()
        # The loss network is a fixed feature extractor, so its weights never receive gradients.
        for weight in loss_network.parameters():
            weight.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()  # mean squared error
        self.tv_loss = TVLoss()       # total-variation smoothness loss

    def forward(self, out_labels, out_images, target_images):
        """Return the combined generator loss.

        Args:
            out_labels: Discriminator output for the generated images.
            out_images: Generated images.
            target_images: Ground-truth images.

        Returns:
            ``image + 0.001 * adversarial + 0.006 * perception + 2e-8 * tv``.
        """
        # Adversarial term: mean of (1 - D(G(z))), small when the discriminator scores fakes near 1.
        adversarial_loss = torch.mean(1 - out_labels)
        # Perceptual term: MSE between loss-network features of generated and target images.
        generated_features = self.loss_network(out_images)
        target_features = self.loss_network(target_images)
        perception_loss = self.mse_loss(generated_features, target_features)
        # Pixel-space MSE between generated and target images.
        image_loss = self.mse_loss(out_images, target_images)
        # Total-variation smoothness of the generated images.
        tv_loss = self.tv_loss(out_images)
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss

#--- Definition of TV smoothing loss function ---#
class TVLoss(nn.Module):
    """Total-variation loss: mean squared difference between neighbouring pixels."""

    def __init__(self, tv_loss_weight=1):
        """Store the multiplicative weight applied to the loss (default 1)."""
        super().__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        """Return the weighted TV loss of ``x``, which is indexed as (batch, channel, height, width)."""
        batch_size = x.size(0)
        height = x.size(2)
        width = x.size(3)
        # Shifted views: every pixel paired with its lower / right neighbour.
        lower = x[:, :, 1:, :]
        right = x[:, :, :, 1:]
        # Per-sample element counts of the difference maps, used for normalization.
        count_h = self.tensor_size(lower)
        count_w = self.tensor_size(right)
        # Sum of squared differences in the vertical and horizontal directions.
        h_tv = (lower - x[:, :, :height - 1, :]).pow(2).sum()
        w_tv = (right - x[:, :, :, :width - 1]).pow(2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        """Return the product of dimensions 1, 2 and 3 of ``t`` (elements per sample)."""
        return t.size(1) * t.size(2) * t.size(3)

# Smoke test: build the generator loss (which loads the pre-trained VGG-16 used by
# its perceptual term) and print its module structure
if __name__ == "__main__":
    g_loss = GeneratorLoss()
    print(g_loss)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to ./hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:14<00:00, 38.3MB/s]


GeneratorLoss(
  (loss_network): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding

Define the SSIM metric（定义SSIM评价指标）

In [None]:
from math import exp # Introducing the exponential function exp for Gaussian functions

import torch
import torch.nn.functional as F # Import PyTorch's functional interface module for calling functions such as convolution operations
from torch.autograd import Variable # Import the Variable class from the AutoDerivative module to create the variable that stores the gradient

#--- Define the Gaussian function ---#
def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian kernel as a tensor of length ``window_size``.

    Args:
        window_size: Number of taps; the kernel is centred on ``window_size // 2``.
        sigma: Standard deviation of the Gaussian.

    Returns:
        1-D tensor whose entries sum to 1.
    """
    center = window_size // 2
    denominator = float(2 * sigma ** 2)
    weights = [exp(-(x - center) ** 2 / denominator) for x in range(window_size)]
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()  # normalize so the weights sum to 1

#--- Creating Gaussian Window Functions ---#
def create_window(window_size, channel, sigma=1.5):
    """Build a per-channel 2-D Gaussian window for use as a grouped-convolution kernel.

    Args:
        window_size: Side length of the square window.
        channel: Number of channels; the same window is repeated once per channel.
        sigma: Standard deviation of the Gaussian (default 1.5, the value that
            was previously hard-coded).

    Returns:
        Contiguous tensor of shape ``(channel, 1, window_size, window_size)``.
    """
    _1D_window = gaussian(window_size, sigma).unsqueeze(1)  # column vector (window_size, 1)
    # Outer product of the 1-D kernel with itself gives the 2-D kernel; two
    # leading singleton dimensions are then added.
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # torch.autograd.Variable is a deprecated no-op wrapper, so the tensor is returned directly.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()

#--- Calculation of the Structural Similarity Index (SSIM)---#
def _ssim(img1, img2, window, window_size, channel, size_average=True): # Accepts two images, img1 and img2, and the previously created Gaussian window window to calculate the SSIM value
    # Calculate the mean of the input image and window using the convolution operation (F.conv2d)
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    # Calculate the square of the mean
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Calculate the variance (sigma1_sq and sigma2_sq) and covariance (sigma12) of the image using the convolution operation
    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    # Calculating SSIM images
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    # Returns a value representing the structural similarity between images
    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

#--- Defining SSIM Classes ---#
# Module wrapper around the SSIM computation that caches the Gaussian window between calls
class SSIM(torch.nn.Module):
    """Structural Similarity Index (SSIM) as an ``nn.Module``."""

    def __init__(self, window_size=11, size_average=True):
        """Store the settings and pre-build a single-channel Gaussian window.

        Args:
            window_size: Side length of the Gaussian window.
            size_average: Whether ``forward`` returns one averaged value.
        """
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1  # channel count the cached window was built for
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM between ``img1`` and ``img2``."""
        _, channel, _, _ = img1.size()

        # The cached window is reusable only if it was built for the same channel
        # count and its tensor type string (which differs between CPU and CUDA
        # tensors and between dtypes) matches the input.
        cache_hit = channel == self.channel and self.window.data.type() == img1.data.type()
        if not cache_hit:
            fresh = create_window(self.window_size, channel)
            if img1.is_cuda:
                fresh = fresh.cuda(img1.get_device())  # same GPU as the input
            # Match the input's tensor type, then refresh the cache.
            self.window = fresh.type_as(img1)
            self.channel = channel

        return _ssim(img1, img2, self.window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM: build a Gaussian window matching ``img1`` and compare the two images.

    Args:
        img1, img2: Image batches to compare.
        window_size: Side length of the Gaussian window.
        size_average: Forwarded to ``_ssim``.

    Returns:
        The value returned by ``_ssim``.
    """
    _, channel, _, _ = img1.size()
    kernel = create_window(window_size, channel)
    if img1.is_cuda:
        kernel = kernel.cuda(img1.get_device())  # same GPU as the input
    kernel = kernel.type_as(img1)  # same tensor type as the input
    return _ssim(img1, img2, kernel, window_size, channel, size_average)

Model Training（模型训练）

In [None]:
import os # For processing file paths
from math import log10 # Used to calculate logarithmic values

import pandas as pd # For processing and analyzing data
import torch.optim as optim # Used to define the optimizer
import torch.utils.data # Used to define and manipulate data loaders
import torchvision.utils as utils # For processing image data
from torch.autograd import Variable # Automatic Derivative Functions for Defining Variables
from torch.utils.data import DataLoader # Used to create data loaders
from tqdm import tqdm # Used to display a progress bar in a loop


if __name__ == '__main__': # Check if the current script is executed in the main program
    CROP_SIZE = 240 #opt.crop_size # crop_size, the size of the image crop used for training
    UPSCALE_FACTOR = 4 # oversampling multiplier
    NUM_EPOCHS = 4 # Iteration epoch number

    # Get training set/validation set image data
    train_set = TrainDatasetFromFolder('/content/drive/MyDrive/ImageNet', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('/content/drive/MyDrive/Colab Notebooks/data/ImageNet', upscale_factor=UPSCALE_FACTOR)
    # DataLoader for training and validation sets created through DataLoader
    train_loader = DataLoader(dataset=train_set, num_workers=0, batch_size=16, shuffle=True) # Data loader for training, each batch contains 16 images that are randomly shuffled
    val_loader = DataLoader(dataset=val_set, num_workers=0, batch_size=1, shuffle=False) # Data loader for validation, one image at a time, no shuffling

    netG = Generator(UPSCALE_FACTOR) # Generator Definitions
    netD = Discriminator() # Discriminator Definition
    generator_criterion = GeneratorLoss() # Generator Loss Function

    # Whether to use the GPU and move the model and loss function to the GPU
    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()

    # Generator and discriminator optimizers for optimizing generator and discriminator parameters
    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    # For storing loss and evaluation metrics during training, including generator loss, discriminator loss, discriminator score, generator score, PSNR (Peak Signal-to-Noise Ratio), and SSIM (Structural Similarity Index)
    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

    # A for loop is used to iterate through each epoch, performing training and loss calculations on each batch.
    for epoch in range(1, NUM_EPOCHS + 1):
        train_bar = tqdm(train_loader) # Use tqdm to create a progress bar, train_bar, that displays training progress
        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0} # Used to keep track of some statistical information in the current epoch, including batch_sizes (number of images currently processed), d_loss (discriminator loss), g_loss (generator loss), d_score (discriminator score), g_score (generator score)

        netG.train() # Generator training
        netD.train() # Discriminator training

        # Each epoch of data iteration, traversing each batch in the training dataset
        for data, target in train_bar:
            g_update_first = True # Flag to control whether the generator is updated first
            batch_size = data.size(0) # Get the size of the current batch
            running_results['batch_sizes'] += batch_size # Add the number of images of the current batch to the batch_sizes field of valing_results

            # Optimize the discriminator to maximize D(x)-1-D(G(z))
            real_img = Variable(target) # The goal of preserving real images
            if torch.cuda.is_available():
                real_img = real_img.cuda() # Some pre-processing of the real image, e.g. moving the data to the GPU (if available)
            z = Variable(data) # Save the target of the generated image
            if torch.cuda.is_available():
                z = z.cuda()
            fake_img = netG(z) # Generate fake_img (super-resolution image) from input data z (low-resolution image) via generator netG
            netD.zero_grad() # Zeroing the parameter gradient of the discriminator
            # Use netD to discriminate between the real image and the generated image to get real_out and fake_out respectively (output of the discriminator)
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out # Calculate the discriminator loss d_loss, which is used to maximize D(x)-1- D(G(z)), i.e., to encourage the discriminator to correctly discriminate between the real image and the generated image
            # Backpropagation and optimization of discriminator parameters to reduce d_loss
            d_loss.backward(retain_graph=True)
            optimizerD.step() # optimization discriminator

            # Optimization Generator Minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
            netG.zero_grad() # Zeroing the parameter gradient of the generator
            fake_img = netG(z)
            fake_out = netD(fake_img).mean()
            g_loss = generator_criterion(fake_out, fake_img, real_img) # Calculate generator loss g_loss, including perceptual loss, image loss and smoothing loss

            # Backpropagation and optimization of generator parameters to reduce g_loss
            g_loss.backward()

            optimizerG.step() # Optimization Generator

            # Fold this batch into the running totals. Each value is weighted by the
            # batch size so that dividing by the number of samples seen gives a
            # per-sample average.
            batch_metrics = {'g_loss': g_loss, 'd_loss': d_loss, 'd_score': real_out, 'g_score': fake_out}
            for metric_name, metric_value in batch_metrics.items():
                running_results[metric_name] += metric_value.item() * batch_size

            # Show the averages so far on the progress bar: epoch counter,
            # discriminator loss, generator loss, D(x) and D(G(z)).
            samples_seen = running_results['batch_sizes']
            progress_text = '[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch,
                NUM_EPOCHS,
                running_results['d_loss'] / samples_seen,
                running_results['g_loss'] / samples_seen,
                running_results['d_score'] / samples_seen,
                running_results['g_score'] / samples_seen,
            )
            train_bar.set_description(desc=progress_text)

        # ---- Validation pass for the epoch that just finished ----
        netG.eval()  # switch the generator to evaluation mode before scoring
        # Per-upscale-factor results directory. The visible code only creates it;
        # nothing in this section writes into it.
        out_path = '/content/drive/MyDrive/training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
        os.makedirs(out_path, exist_ok=True)  # idempotent; no separate existence check needed

        with torch.no_grad():  # evaluation only: skip gradient tracking
            val_bar = tqdm(val_loader)  # progress bar over the validation loader
            # 'mse' and 'ssims' are sample-weighted running sums; 'psnr' and 'ssim'
            # hold the averages over all samples seen so far.
            valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
            # The middle item of each validation sample (val_hr_restore) is not used
            # by the metrics below.
            for val_lr, val_hr_restore, val_hr in val_bar:
                batch_size = val_lr.size(0)
                valing_results['batch_sizes'] += batch_size
                lr, hr = val_lr, val_hr  # low-resolution input and high-resolution reference
                if torch.cuda.is_available():
                    lr = lr.cuda()
                    hr = hr.cuda()
                sr = netG(lr)  # super-resolved reconstruction of the low-resolution batch

                # Mean squared error between reconstruction and reference.
                # `.data` is not needed here: nothing tracks gradients under no_grad.
                batch_mse = ((sr - hr) ** 2).mean()
                valing_results['mse'] += batch_mse * batch_size
                # PSNR over all samples so far, using a peak signal value of 1.
                # NOTE(review): valing_results['mse'] is a 0-dim tensor after the first
                # batch, so the log10 imported above this section must accept one - confirm.
                valing_results['psnr'] = 10 * log10(1 / (valing_results['mse'] / valing_results['batch_sizes']))
                # Structural similarity of this batch, then the running average.
                batch_ssim = ssim(sr, hr).item()
                valing_results['ssims'] += batch_ssim * batch_size
                valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
        # Checkpoint both networks. The two numbers in each file name are the
        # super-resolution upscale factor followed by the current epoch number.
        generator_weights = netG.state_dict()
        torch.save(generator_weights, '/content/drive/MyDrive/epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
        discriminator_weights = netD.state_dict()
        torch.save(discriminator_weights, '/content/drive/MyDrive/epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))

        # Append this epoch's per-sample training averages (running sum divided by
        # the number of samples seen) to the history ...
        train_samples = running_results['batch_sizes']
        for metric_key in ('d_loss', 'g_loss', 'd_score', 'g_score'):
            results[metric_key].append(running_results[metric_key] / train_samples)
        # ... followed by the validation metrics, which are already averages.
        for metric_key in ('psnr', 'ssim'):
            results[metric_key].append(valing_results[metric_key])

        # Export the accumulated per-epoch history to CSV every 10th epoch and once
        # more after the final epoch, so runs shorter than 10 epochs (or whose length
        # is not a multiple of 10) still leave a statistics file behind.
        if epoch != 0 and (epoch % 10 == 0 or epoch == NUM_EPOCHS):
            out_path = '/content/drive/MyDrive/statistics'  # directory that holds the statistics files
            os.makedirs(out_path, exist_ok=True)  # to_csv does not create missing directories
            data_frame = pd.DataFrame(
                # One column per tracked metric, one row per completed epoch
                data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                      'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, epoch + 1))  # rows are labelled 1..epoch
            # Join directory and file name with a path separator; plain string
            # concatenation produced ".../statisticssrf_<factor>_train_results.csv".
            csv_name = 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv'
            data_frame.to_csv(os.path.join(out_path, csv_name), index_label='Epoch')

[1/4] Loss_D: 0.9761 Loss_G: 0.0181 D(x): 0.6265 D(G(z)): 0.5953: 100%|██████████| 196/196 [06:52<00:00,  2.11s/it]
100%|██████████| 345/345 [00:23<00:00, 14.92it/s]
[2/4] Loss_D: 1.0008 Loss_G: 0.0097 D(x): 0.7308 D(G(z)): 0.7296: 100%|██████████| 196/196 [04:43<00:00,  1.45s/it]
100%|██████████| 345/345 [00:23<00:00, 14.88it/s]
[3/4] Loss_D: 0.9994 Loss_G: 0.0076 D(x): 0.9190 D(G(z)): 0.9207: 100%|██████████| 196/196 [04:44<00:00,  1.45s/it]
100%|██████████| 345/345 [00:23<00:00, 14.79it/s]
[4/4] Loss_D: 1.0000 Loss_G: 0.0069 D(x): 0.9997 D(G(z)): 0.9997: 100%|██████████| 196/196 [04:44<00:00,  1.45s/it]
100%|██████████| 345/345 [00:23<00:00, 14.95it/s]


Model Testing(模型测试)

In [50]:
import torch # For building and manipulating deep learning models
from PIL import Image # For image reading and processing
from torch.autograd import Variable # Variables for wrapping images into computable gradients
from torchvision.transforms import ToTensor, ToPILImage # For converting between PyTorch Tensor and PIL images

# ---- Inference configuration ----
UPSCALE_FACTOR = 4  # upsampling multiplier passed to the generator
TEST_MODE = True  # True: run on the GPU when one is available; False: force CPU

IMAGE_NAME = "/content/drive/MyDrive/SampleImage/NO.20class4deer.jpg"  # image to super-resolve
MODEL_NAME = '/content/drive/MyDrive/epochs/netG_epoch_4_4.pth'  # trained generator checkpoint
# Output path: the input's file name prefixed with "ImageReconstructed<factor>_"
RESULT_NAME = "/content/drive/MyDrive/ImageReconstructed" + str(UPSCALE_FACTOR) + "_" + IMAGE_NAME.split("/")[-1]

# Fall back to the CPU instead of crashing when TEST_MODE is set but CUDA is missing.
device = torch.device('cuda' if TEST_MODE and torch.cuda.is_available() else 'cpu')

# ---- Build the generator and load its weights ----
model = Generator(UPSCALE_FACTOR).eval()  # evaluation mode for inference
# map_location places the checkpoint tensors on the selected device in both modes.
# Loading stays non-strict (as the GPU path did before), but mismatched keys are
# reported instead of silently leaving layers at their random initial values.
load_report = model.load_state_dict(torch.load(MODEL_NAME, map_location=device), strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print('WARNING: checkpoint does not fully match the generator. Missing keys: %s; unexpected keys: %s'
          % (load_report.missing_keys, load_report.unexpected_keys))
model.to(device)

# ---- Load the input image as a 1 x C x H x W tensor on the selected device ----
# NOTE(review): the image is passed through in its stored mode; if the generator
# expects 3 channels, grayscale or RGBA inputs need .convert('RGB') - confirm.
with Image.open(IMAGE_NAME) as source_image:  # context manager closes the file handle
    image = ToTensor()(source_image).unsqueeze(0)
image = image.to(device)

# ---- Super-resolve and save ----
with torch.no_grad():  # inference only: skip gradient tracking
    out = model(image)
# Clamp to [0, 1] before the 8-bit conversion so out-of-range values cannot wrap around.
out_img = ToPILImage()(out[0].clamp(0, 1).cpu())
out_img.save(RESULT_NAME)