<a href="https://colab.research.google.com/github/pierclgr/SuperResolution/blob/master/MPRNet.ipynb">
      <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Preliminary steps
We first import all the libraries that we need.

In [47]:
import os
import torch
import torch.utils.data as data
import torchvision.transforms as T
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import time
from torch.utils import data
import numpy as np
import random
from skimage import io

Then, we set up a function that we can use to obtain the backend device on which the training will be executed.

In [48]:
# import torch_xla library if runtime is using a Colab TPU
if 'COLAB_TPU_ADDR' in os.environ:
    import torch_xla.core.xla_model as xm

def get_device() -> torch.device:
    """
    Get the current machine device to use

    The device is chosen in this order: the Colab TPU when the COLAB_TPU_ADDR environment variable is set,
    otherwise the CUDA GPU when one is available, otherwise the CPU. When a GPU is selected, the output of
    `nvidia-smi` is printed as well. The name of the chosen device is always printed.

    :returns: the current machine device (torch.device)
    """

    if 'COLAB_TPU_ADDR' in os.environ:
        # the runtime is using a Colab TPU: xm is imported at module level under this same condition
        device = xm.xla_device()
    elif torch.cuda.is_available():
        device = torch.device('cuda')

        # print the details of the given GPU; the context manager closes the pipe once the output is read
        with os.popen('nvidia-smi') as stream:
            print(stream.read())
    else:
        # neither TPU nor GPU available, fall back to the CPU
        device = torch.device('cpu')

    print(f">>> Using {device} device")

    return device

# get the backend device once (this also prints which device was selected) and keep it in the
# module-level `device` variable
device = get_device()

Fri Dec 10 23:47:47 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P8    27W / 149W |      3MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Data preparation
The first step is to prepare the data that will be used for the training and testing.

## Training data
As the first step, we need to prepare the training data by creating a PyTorch DataLoader.

### Obtain dataset
In order to do so, we need to first copy the DIV2K dataset zip file from Google Drive by mounting it, then we need to unzip it. After extracting the images, we delete the zip file in order to free space.

In [49]:
# NOTE: lines starting with `!` are IPython/Colab shell escapes, not Python statements, so this cell
# only runs inside a notebook runtime

# mount Google Drive under /content/drive so that the dataset archive stored there can be read
from google.colab import drive
drive.mount('/content/drive')

# create the local data folder (if missing) and copy the DIV2K zip file into it from the mounted drive
!echo "Copying the dataset .zip file from Google Drive (may take a while)..."
!mkdir -p /content/data/ && cp /content/drive/MyDrive/Colab\ Notebooks/ML4CV/div2k.zip /content/data/
!echo "Done!"

# extract the archive into /content/data/ (-qq suppresses unzip's output)
!echo "Unzipping the .zip file (may take a while)..."
!unzip -qq /content/data/div2k.zip -d /content/data/
!echo "Done!"

# delete the local copy of the zip file to free space (the copy on Google Drive is not touched)
!echo "Deleting zip file to free space..."
!rm /content/data/div2k.zip
!echo "Done!"

KeyboardInterrupt: ignored

### PyTorch Dataset
We first create functions to perform random crops to extract patches and generate random rotations and flips.

In [50]:
def random_crop(lr: np.ndarray, hr: np.ndarray, scale: int = 2, patch_size: int = 64) -> tuple:
    """
    Extracts a random patch of the given size from an input lr image and the corresponding hr patch scaled using the
    given scale.

    Both images are indexed as (height, width) or (height, width, channels), so grayscale images without a channel
    axis are accepted too. The returned patches are copies and do not share memory with the inputs.

    :param lr: low-resolution image to extract the patch from (ndarray)
    :param hr: high-resolution image to extract the patch from (ndarray)
    :param scale: scale to use for the extraction of the hr patch (int, default 2)
    :param patch_size: size of the low resolution (square) patch (int, default 64)
    :return: tuple containing the extracted lr and hr patches
    :raises ValueError: if lr is not a 2D or 3D array, or if it is smaller than patch_size along height or width
    """

    # only (H, W) and (H, W, C) layouts are supported
    if lr.ndim not in (2, 3):
        raise ValueError(f"lr must be a 2D or 3D array, got {lr.ndim} dimensions")

    # extract size of the lr image (the first two axes, so that a missing channel axis is not a problem)
    height, width = lr.shape[:2]

    # fail with an explicit message instead of the "empty range" error that randint would raise
    if patch_size > height or patch_size > width:
        raise ValueError(f"patch_size {patch_size} is larger than the lr image size {height}x{width}")

    # extract random starting coordinates of the patch in the lr image
    x = random.randint(0, width - patch_size)
    y = random.randint(0, height - patch_size)

    # compute the size and the starting coordinates of the patch in the hr image
    hr_patch_size = patch_size * scale
    hx, hy = x * scale, y * scale

    # extract the patch from the two images
    lr = lr[y:y + patch_size, x:x + patch_size].copy()
    hr = hr[hy:hy + hr_patch_size, hx:hx + hr_patch_size].copy()

    return lr, hr


def random_horizontal_flip(lr: np.ndarray, hr: np.ndarray, p: float = .5) -> tuple:
    """
    Mirrors the given lr and hr patches left-to-right with probability p; both patches are always flipped together
    or left untouched together. The returned arrays are copies of the inputs.

    :param lr: low-resolution patch to flip (ndarray)
    :param hr: high-resolution patch to flip (ndarray)
    :param p: probability of the flipping (float, default 0.5)
    :return: tuple containing the flipped (or not) lr and hr patches
    """

    # a single random draw decides for both patches
    do_flip = random.random() < p

    # nothing to mirror: just hand back independent copies
    if not do_flip:
        return lr.copy(), hr.copy()

    # mirror both patches and materialize the flipped views as new arrays
    return np.fliplr(lr).copy(), np.fliplr(hr).copy()


def random_90_rotation(lr: np.ndarray, hr: np.ndarray) -> tuple:
    """
    Rotates the given lr and hr patches by the same randomly chosen number of quarter turns (0, 1 or 3, as passed
    to np.rot90). The returned arrays are copies of the inputs.

    :param lr: low-resolution patch to rotate (ndarray)
    :param hr: high-resolution patch to rotate (ndarray)
    :return: tuple containing the rotated (or not) lr and hr patches
    """

    # one draw shared by both patches, so that they stay aligned
    quarter_turns = random.choice([0, 1, 3])

    # rotate each patch and materialize the rotated view as a new array
    rotated_lr, rotated_hr = (np.rot90(patch, quarter_turns).copy() for patch in (lr, hr))

    return rotated_lr, rotated_hr

We now have to create the PyTorch Dataset in order to make data loadable using a PyTorch DataLoader.

In [71]:
class DIV2KDataset(data.Dataset):
    """
    PyTorch dataset loading DIV2K HR and LR images

    Images are read from the following layout, where <file> is the same file name in every folder:
        <dataset_path>/<split>/hr/<file>
        <dataset_path>/<split>/lr/<degradation>/x<scale>/<file>

    Each item is a flat tuple (scale, lr_patch, hr_patch, scale, lr_patch, hr_patch, ...) with one triple per
    configured scale; the patches are randomly cropped (and optionally flipped/rotated) and converted to tensors.
    """

    def __init__(self, dataset_path: str, scales: list = None, split: str = "train",
                 degradation: str = "bicubic", patch_size: int = 64, augment: bool = True) -> None:
        """
        Constructor method of the class

        :param dataset_path: path of the folder containing dataset images (str)
        :param scales: list containing the resolution scales to consider (list, default None meaning [2, 3, 4])
        :param split: split of the dataset to use, lower-cased before use (str, default "train")
        :param degradation: type of degraded images to use, lower-cased before use (str, default "bicubic")
        :param patch_size: size of the square (patch_size x patch_size) lr patches to extract (int, default 64)
        :param augment: flag to control the augmentation of images (bool, default true)
        """

        super(DIV2KDataset, self).__init__()

        # define dataset path
        self.dataset_path = dataset_path

        # define scales to use if not given
        self.scales = scales if scales else [2, 3, 4]

        # define degradation method to use
        self.degradation = degradation.lower()

        # define patch size
        self.patch_size = patch_size

        # define split
        self.split = split.lower()

        # define augmentation
        self.augment = augment

        # extract the image file names from the dataset; the normalized split is used here so that the listed
        # folder is the same one that __getitem__ reads from
        self.filenames = sorted(os.listdir(os.path.join(self.dataset_path, self.split, "hr")))

        # define transform
        self.transform = T.Compose([T.ToTensor()])

    def __len__(self) -> int:
        """
        Returns the length of the dataset

        :return: length of the dataset (number of HR image files found)
        """

        return len(self.filenames)

    def __getitem__(self, item: int) -> tuple:
        """
        Get patches of a HR image and of the corresponding LR images in all the scales

        :param item: the chosen item index in the dataset
        :return: flat tuple containing, for each scale in order, the scale, the LR patch tensor and the HR patch
            tensor
        """

        # select the image to pick
        file_name = self.filenames[item]

        # extract the HR image from the HR folder
        hr_image_path = os.path.join(self.dataset_path, self.split, "hr", file_name)
        hr_image = io.imread(hr_image_path)

        # define the output tuple as empty
        output_tuple = ()

        # extract the LR images from the LR folder
        for scale in self.scales:

            # extract the LR image for the current scale from the LR folder
            lr_image_path = os.path.join(self.dataset_path, self.split, "lr", self.degradation,
                                         "x" + str(scale), file_name)
            lr_image = io.imread(lr_image_path)

            # extract the LR and HR patches from the current scaled LR image and the HR image, honouring the
            # patch size configured in the constructor
            lr_patch, hr_patch = random_crop(lr_image, hr_image, scale, patch_size=self.patch_size)

            # if augmentation is required
            if self.augment:
                # flip the patches
                lr_patch, hr_patch = random_horizontal_flip(lr_patch, hr_patch)

                # rotate the patches
                lr_patch, hr_patch = random_90_rotation(lr_patch, hr_patch)

            # add the current scale_factor-LR-HR triple to the output tuple
            output_tuple += (scale, self.transform(lr_patch), self.transform(hr_patch))

        return output_tuple

We then create the collate function for the DataLoader.

In [72]:
def collate_fn(batch: list) -> tuple:
    """
    Collate function turning a list of dataset samples into a single-scale batch

    Every sample is a flat sequence of (scale, lr, hr) triples; one triple position is picked at random for the
    whole batch and the tensors found at that position are stacked.

    :param batch: list containing the batch samples extracted from the dataset using its __getitem__ method (list)
    :return: tuple containing the chosen scale and the stacked LR and HR batches in that scale
    """

    # transpose the batch: one column per sample field
    columns = list(zip(*batch))

    # number of (scale, lr, hr) triples per sample
    n_triples = int(len(columns) / 3)

    # pick the triple to use for this batch and locate its first column
    offset = 3 * random.randint(0, n_triples - 1)
    scale_column, lr_column, hr_column = columns[offset:offset + 3]

    # the scale is taken from the first sample; lr and hr tensors are stacked into unique PyTorch tensors
    return scale_column[0], torch.stack(lr_column), torch.stack(hr_column)

We now test the dataset and DataLoader speed in loading images.

In [None]:
# build the training dataset with the constructor defaults and wrap it in a DataLoader that uses the custom
# collate function, 2 worker processes and pinned memory
train_dataset = DIV2KDataset("/content/data/div2k")
train_dataloader = data.DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn, num_workers=2, pin_memory=True)

# load a single sample once before timing; the returned value is not used
train_dataset.__getitem__(0)

# iterate once over the whole DataLoader without processing the batches, so that only loading is timed
start = time.time()
for _ in tqdm(train_dataloader):
    pass
end = time.time()

# elapsed wall-clock time of the full pass, in seconds
print(end - start)