In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import cv2
import os
import torch
import torchinfo
import skimage

from torch.utils.data import Dataset, DataLoader

TRAIN_FOLDER = 'train/'
TEST_FOLDER = 'test/'

EPOCHS = 5000
BATCH_SIZE = 32
LR = 1e-4
WINDOW_SIZE = 33    # side length (pixels) of the square training patches
WINDOW_OFFSET = 14  # stride (pixels) between neighbouring training patches
SEED = 42           # seeds torch's global RNG (weight init, DataLoader shuffling)

# Make weight initialisation and batch order reproducible across runs.
# (GPU kernels may still introduce small nondeterminism.)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Device: {device}")

In [None]:
class Mutliresolution(torch.nn.Module):
    """Three-layer fully convolutional network: 1-channel image in, 1-channel image out.

    Every convolution uses ``padding="same"``, so the output has the same spatial
    size as the input. Layer widths are 1 -> 64 -> 32 -> 1 with kernel sizes
    9, 5 and 5, and a ReLU after the first two layers.

    The (misspelled) class name and the attribute names are kept as-is: whole
    modules are pickled with ``torch.save(model, ...)`` and later code reads
    ``model.first_conv2d`` directly.
    """

    def __init__(self) -> None:
        super().__init__()
        conv = torch.nn.Conv2d
        self.first_conv2d = conv(1, 64, kernel_size=(9, 9), padding="same")
        self.second_conv2d = conv(64, 32, kernel_size=(5, 5), padding="same")
        self.third_conv2d = conv(32, 1, kernel_size=(5, 5), padding="same")
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.first_conv2d(x))
        hidden = self.relu(self.second_conv2d(hidden))
        return self.third_conv2d(hidden)

In [None]:
class ImageDataset(Dataset):
    """Map-style dataset that pairs ``X[i]`` with ``Y[i]``.

    The length is taken from ``X``; items are returned unchanged as ``(x, y)``
    tuples so the DataLoader's default collation can batch them.
    """

    def __init__(self, X, Y):
        self.X, self.Y = X, Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        sample = self.X[index]
        target = self.Y[index]
        return sample, target

In [None]:
def fix_image(image):
    """Scale pixel values by 1/255, clip them to [0, 1] and return float32.

    Values that land outside [0, 1] after scaling (e.g. from data with more than
    8 bits per channel) are clamped to the nearest bound.
    """
    normalized = np.clip(image / 255, 0, 1)
    return normalized.astype(np.float32)

In [None]:
def expand(image):
    """Reorder a 3-D ``(H, W, C)`` array to channel-first ``(C, H, W)``.

    Returns a transposed view; no data is copied.
    """
    return np.transpose(image, (2, 0, 1))

In [None]:
def load_images_from_folder(folder):
    """Load every image OpenCV can decode in ``folder`` as an RGB float32 array.

    Files are visited in sorted name order so that the returned list (and every
    patch list built from it) is the same on every filesystem. Files that
    ``cv2.imread`` cannot decode (it returns ``None``) are skipped.

    Returns:
        list of ``(H, W, 3)`` float32 arrays with values in [0, 1].
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        # IMREAD_COLOR always yields an 8-bit, 3-channel BGR image: grayscale files
        # are expanded to 3 channels (cvtColor(BGR2RGB) rejects 2-D input), alpha is
        # dropped, and 16-bit files are converted to 8 bits so fix_image's /255 is
        # valid (IMREAD_UNCHANGED kept 16-bit data, which /255 saturated to white).
        image = cv2.imread(os.path.join(folder, filename), cv2.IMREAD_COLOR)
        if image is None:
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        images.append(fix_image(image))

    return images

In [None]:
def display_image(image, cmap="viridis"):
    """Show a single image in the current figure with the axes hidden."""
    artist = plt.imshow(image, cmap=cmap)
    artist.axes.set_axis_off()
    plt.show()

In [None]:
def display_org_lr_hr(img_org, img_lr, img_hr, metrics, zoom_pos, zoom_size=(50, 50)):
    """Plot original / low-res / reconstructed images side by side, plus zoomed crops.

    Top row: the three full images, each with a red rectangle marking the zoom
    window, titled with PSNR/SSIM against the original. Bottom row: the crops.

    Args:
        img_org: original image, displayable by ``imshow``.
        img_lr: degraded image at the original size, same layout as ``img_org``.
        img_hr: reconstructed image in channel-first YCrCb layout ``(3, H, W)``;
            converted with ``COLOR_YCR_CB2BGR`` and clipped to [0, 1] for display.
        metrics: dict with keys 'psnr_org_lr', 'psnr_org_hr', 'ssim_org_lr',
            'ssim_org_hr'.
        zoom_pos: ``(x, y)`` top-left corner of the zoom window in pixels.
        zoom_size: ``(width, height)`` of the zoom window in pixels.
    """
    img_hr = np.clip(cv2.cvtColor(img_hr.transpose(1, 2, 0), cv2.COLOR_YCR_CB2BGR), 0, 1)

    x, y = zoom_pos
    w, h = zoom_size
    panels = [
        (img_org, "ORG / PSNR / SSIM"),
        (img_lr, f"LR / {metrics['psnr_org_lr']:.2f} dB / {metrics['ssim_org_lr']:.4f}"),
        # The HR title previously contained a stray '␣' character instead of a space.
        (img_hr, f"HR / {metrics['psnr_org_hr']:.2f} dB / {metrics['ssim_org_hr']:.4f}"),
    ]

    _, axes = plt.subplots(2, 3, figsize=(10, 5))
    for col, (image, title) in enumerate(panels):
        full_ax, zoom_ax = axes[0, col], axes[1, col]

        full_ax.imshow(image)
        full_ax.add_patch(patches.Rectangle(zoom_pos, w, h, edgecolor="r", facecolor="none"))
        full_ax.axis("off")
        full_ax.set_title(title)

        zoom_ax.imshow(image[y:y + h, x:x + w])
        zoom_ax.axis("off")

    plt.tight_layout()

In [None]:
def display_filters(filters, title, ncols=16):
    """Plot convolution kernels as a grid of grayscale tiles, highest variance first.

    Args:
        filters: weight tensor shaped ``[N, C, H, W]`` (e.g. ``conv.weight``);
            values are clamped to [0, 1] before plotting.
        title: title placed above the first tile.
        ncols: tiles per row. Rows are added as needed so all N filters are shown
            (the default reproduces the original 4 x 16 grid for 64 filters);
            unused tiles are left blank.
    """
    # [N, C, H, W] -> [N, H, W, C] so each filter is an imshow-compatible image.
    filters = filters.permute(0, 2, 3, 1).clamp(0, 1)

    # Sort by descending variance over each filter's values.
    order = torch.argsort(torch.var(filters, dim=(1, 2, 3)), descending=True)
    filters = filters[order]

    count = filters.shape[0]
    nrows = max(1, -(-count // ncols))  # ceil division without importing math
    _, axes = plt.subplots(nrows, ncols, figsize=(ncols, nrows), squeeze=False)

    for idx, ax in enumerate(axes.flat):
        if idx == 0:
            ax.set_title(title)
        if idx < count:
            ax.imshow(filters[idx], cmap="gray")
        ax.axis("off")

    plt.tight_layout()

In [None]:
def display_images(images, num, f1, f2, title="", seed=None):
    """Show a random sample of ``images`` (without replacement) on an f1 x f2 grid.

    Args:
        images: sequence of imshow-compatible arrays.
        num: number of images to sample; capped at ``len(images)`` and at
            ``f1 * f2`` so small datasets or small grids no longer raise.
        f1: grid rows.
        f2: grid columns.
        title: title placed above the first tile.
        seed: optional seed for the sampling RNG, for reproducible figures.

    Tiles without an image are left blank.
    """
    count = min(num, len(images), f1 * f2)
    indices = np.random.default_rng(seed).choice(len(images), count, replace=False)
    _, axes = plt.subplots(f1, f2, figsize=(8, 3), squeeze=False)

    # axes.flat walks the grid row-major, i.e. slot == r * f2 + c.
    for slot, ax in enumerate(axes.flat):
        if slot == 0:
            ax.set_title(title)
        if slot < count:
            ax.imshow(images[indices[slot]])
        ax.axis("off")

    plt.tight_layout()

In [None]:
def display_dataset_info(images, title, num, f1, f2):
    """Print size statistics of ``images`` and plot a random sample of them.

    Printed (in Slovenian): image count, min/max width and height, and mean +- std
    of width and height. ``title``, ``num``, ``f1`` and ``f2`` are forwarded to
    ``display_images``.
    """
    heights = [img.shape[0] for img in images]
    widths = [img.shape[1] for img in images]

    print(f"Število slik: {len(images)}")
    print(f"Interval širine slik: [{np.min(widths)}, {np.max(widths)}]")
    print(f"Interval višine slik: [{np.min(heights)}, {np.max(heights)}]")
    print(f"Povprečna širina slik: {np.mean(widths) :.2f} +- {np.std(widths) :.2f}")
    print(f"Povprečna višina slik: {np.mean(heights) :.2f} +- {np.std(heights) :.2f}")
    print("Primeri slik:")
    display_images(images, num, f1, f2, title=title)

In [None]:
def preprocess(images, factor, type):
    """Build aligned (degraded, original) Y-channel pairs for training or testing.

    Each image is downscaled by ``factor`` and upscaled back to its original size
    (bilinear both ways). Both versions are converted to YCrCb and only channel 0
    is kept. Size statistics and samples of both lists are displayed.

    Args:
        images: list of ``(H, W, 3)`` images (as from ``load_images_from_folder``).
        factor: downscaling factor.
        type: 0 -> cut both versions into WINDOW_SIZE x WINDOW_SIZE patches on a
            WINDOW_OFFSET grid covering the whole image; any other value -> keep
            whole images.

    Returns:
        ``(low_res, originals)``: equally long lists of ``(h, w, 1)`` arrays where
        ``low_res[k]`` is the degraded counterpart of ``originals[k]``.
    """
    low_res = []
    originals = []

    for image in images:
        height, width = image.shape[:2]
        smaller_size = (int(width / factor), int(height / factor))
        bigger_size = (int(width), int(height))

        # The third positional argument of cv2.resize is `dst`, so interpolation
        # must be passed by keyword to take effect.
        low_res_image = cv2.resize(image, smaller_size, interpolation=cv2.INTER_LINEAR)
        low_res_image = cv2.resize(low_res_image, bigger_size, interpolation=cv2.INTER_LINEAR)

        # NOTE(review): load_images_from_folder returns RGB (it applies COLOR_BGR2RGB),
        # but COLOR_BGR2YCR_CB treats channel 0 as blue, so this luma uses swapped R/B
        # weights. test() and display_org_lr_hr() rely on the same BGR conversion, so
        # switch all three to the RGB variants together.
        low_res_image = cv2.cvtColor(low_res_image, cv2.COLOR_BGR2YCR_CB)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2YCR_CB)

        if type == 0:
            # Both versions are full size here, so both must be tiled over the full
            # image. Tiling the low-res one only over `smaller_size` covered just its
            # top-left corner and misaligned every (low_res, original) pair.
            low_res.extend(_extract_patches(low_res_image, width, height))
            originals.extend(_extract_patches(image, width, height))
        else:
            low_res.append(low_res_image[:, :, 0:1])
            originals.append(image[:, :, 0:1])

    display_dataset_info(low_res, "LOW RESOLUTION IMAGES, FACTOR: " + str(factor), 5, 1, 5)
    display_dataset_info(originals, "ORIGINAL IMAGES, FACTOR: " + str(factor), 5, 1, 5)

    return low_res, originals


def _extract_patches(image, width, height):
    """Return channel-0 WINDOW_SIZE x WINDOW_SIZE patches of ``image`` on a WINDOW_OFFSET grid.

    The image is zero-padded by WINDOW_SIZE on the bottom and right so patches that
    start near the edge are still full size. Patches are ordered x-major, then y.
    """
    padded = cv2.copyMakeBorder(image, 0, WINDOW_SIZE, 0, WINDOW_SIZE, cv2.BORDER_CONSTANT, value=0)
    return [
        padded[y:y + WINDOW_SIZE, x:x + WINDOW_SIZE, 0:1]
        for x in range(0, width, WINDOW_OFFSET)
        for y in range(0, height, WINDOW_OFFSET)
    ]

In [None]:
def train(originals, low_res, epochs, factor):
    """Train a fresh ``Mutliresolution`` model with MSE loss and save it to disk.

    Args:
        originals: list of ``(h, w, 1)`` target arrays.
        low_res: list of ``(h, w, 1)`` input arrays, aligned index-by-index with
            ``originals``.
        epochs: number of passes over the data.
        factor: downscaling factor; only used in the checkpoint file name.

    The whole pickled module is written to ``multiresolution-factor{factor}.pt``.
    The input lists are not modified.
    """
    model = Mutliresolution().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    loss_fn = torch.nn.MSELoss()

    # HWC -> CHW views collected into new lists, so the caller's lists stay intact.
    dataset = ImageDataset(
        [expand(patch) for patch in low_res],
        [expand(patch) for patch in originals],
    )
    dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

    model.train()

    for epoch in range(epochs):
        print('epoch:' + str(epoch + 1) + ' of ' + str(epochs))

        for step, (batch_lr, batch_hr) in enumerate(dataloader):
            if (step + 1) % 100 == 0:
                print('training image:' + str(step + 1) + ' of ' + str(len(dataloader)))

            # The optimisation step runs for every batch. It used to be nested under
            # the progress `if`, so only every 100th batch was trained on.
            batch_lr = batch_lr.to(device)
            batch_hr = batch_hr.to(device)

            loss = loss_fn(model(batch_lr), batch_hr)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    torch.save(model, "multiresolution-factor" + str(factor) + ".pt")

In [None]:
def test(originals, low_res, unmodified, factor, model, num_display=3, zoom_pos=(200, 160)):
    """Run ``model`` on whole test images and plot comparisons with PSNR/SSIM.

    Only the first ``num_display`` images are processed, because nothing is
    computed or returned for the rest.

    Args:
        originals: per-image ``(H, W, 1)`` Y-channel ground truth (preprocess type 1).
        low_res: per-image ``(H, W, 1)`` degraded Y-channel inputs, aligned with
            ``originals``.
        unmodified: the source images both lists were built from.
        factor: downscaling factor used to build the degraded images.
        model: network assumed to map a ``(N, 1, H, W)`` batch to ``(N, 1, H, W)``
            (as ``Mutliresolution`` does).
        num_display: number of leading images to evaluate and plot.
        zoom_pos: ``(x, y)`` top-left corner of the zoomed crop in each figure.
    """
    from skimage.metrics import peak_signal_noise_ratio, structural_similarity

    model.eval()
    with torch.no_grad():
        for step in range(min(num_display, len(low_res))):
            original_y = torch.from_numpy(expand(originals[step])).to(device)  # (1, H, W)
            low_res_y = torch.from_numpy(expand(low_res[step])).to(device)     # (1, H, W)

            source = unmodified[step]
            smaller_size = (int(source.shape[1] / factor), int(source.shape[0] / factor))
            bigger_size = (int(source.shape[1]), int(source.shape[0]))
            # Degrade as preprocess() does: downscale, then upscale the *downscaled*
            # image. Previously the original was upscaled, so the "LR" panel and the
            # chroma channels were the undegraded image. Interpolation is passed by
            # keyword because the third positional argument of cv2.resize is `dst`.
            low_res_rgb = cv2.resize(source, smaller_size, interpolation=cv2.INTER_LINEAR)
            low_res_rgb = cv2.resize(low_res_rgb, bigger_size, interpolation=cv2.INTER_LINEAR)

            low_res_ycbcr = cv2.cvtColor(low_res_rgb, cv2.COLOR_BGR2YCR_CB)
            low_res_ycbcr = torch.from_numpy(low_res_ycbcr.transpose(2, 0, 1)[None]).to(device)

            # Add a batch dimension: (1, H, W) -> (1, 1, H, W).
            pred = model(low_res_y.unsqueeze(0)).clamp(0, 1)

            # Predicted Y plus the degraded image's chroma -> (1, 3, H, W).
            high_res_image = torch.cat([pred, low_res_ycbcr[:, 1:3]], dim=1)

            # Take channel 0 as (H, W) arrays. `[:, 0:1]` on these unbatched tensors
            # sliced the first pixel *row*, so metrics covered a single row.
            low_res_np = low_res_y[0].cpu().numpy()
            original_np = original_y[0].cpu().numpy()
            high_res_np = pred[0, 0].cpu().numpy()

            # Float images in [0, 1]: skimage needs an explicit data_range for float
            # SSIM (it otherwise errors or assumes the [-1, 1] dtype range).
            metrics = {
                "psnr_org_lr": peak_signal_noise_ratio(original_np, low_res_np, data_range=1.0),
                "psnr_org_hr": peak_signal_noise_ratio(original_np, high_res_np, data_range=1.0),
                "ssim_org_lr": structural_similarity(original_np, low_res_np, data_range=1.0),
                "ssim_org_hr": structural_similarity(original_np, high_res_np, data_range=1.0),
            }
            display_org_lr_hr(source, low_res_rgb, high_res_image.cpu().numpy()[0], metrics, zoom_pos)

In [None]:
# Load the training images and summarise them: 15 random samples on a 3 x 5 grid.
images = load_images_from_folder(TRAIN_FOLDER)
display_dataset_info(images, title="DATASET", num=15, f1=3, f2=5)

In [None]:
# Train and save one model per downscaling factor (2 and 3).
for factor in (2, 3):
    low_res, originals = preprocess(images, factor, type=0)
    train(originals, low_res, epochs=EPOCHS, factor=factor)

In [None]:
for factor in range(2, 4):
    print("==========================================================================================")
    print("FACTOR: " + str(factor))
    print("==========================================================================================")

    checkpoint = "multiresolution-factor" + str(factor) + ".pt"
    # The checkpoint holds a whole pickled nn.Module (written with torch.save(model, ...)).
    # torch >= 2.6 defaults to weights_only=True, which refuses custom classes, so opt
    # out explicitly; torch < 1.13 has no such parameter. Only load checkpoints you
    # produced yourself: unpickling can execute arbitrary code.
    try:
        model = torch.load(checkpoint, map_location=device, weights_only=False)
    except TypeError:
        model = torch.load(checkpoint, map_location=device)

    print(torchinfo.summary(model, (1, 1, WINDOW_SIZE, WINDOW_SIZE)))

    display_filters(model.first_conv2d.weight.cpu().detach(), "FACTOR " + str(factor) + " FIRST CONV2D LAYER FILTERS")

    print("\n\n")

In [None]:
for factor in range(2, 4):
    checkpoint = "multiresolution-factor" + str(factor) + ".pt"
    # The checkpoint holds a whole pickled nn.Module (written with torch.save(model, ...)).
    # torch >= 2.6 defaults to weights_only=True, which refuses custom classes, so opt
    # out explicitly; torch < 1.13 has no such parameter. Only load checkpoints you
    # produced yourself: unpickling can execute arbitrary code.
    try:
        model = torch.load(checkpoint, map_location=device, weights_only=False)
    except TypeError:
        model = torch.load(checkpoint, map_location=device)

    # Each sub-directory of TEST_FOLDER is one test set whose images live in a nested
    # "test" folder. Sort by name so the output order is stable, and use the entry
    # name directly (splitting the path on "/" broke for nested TEST_FOLDER paths).
    test_sets = sorted(
        (entry for entry in os.scandir(TEST_FOLDER) if entry.is_dir()),
        key=lambda entry: entry.name,
    )

    for test_set in test_sets:
        images = load_images_from_folder(os.path.join(test_set.path, "test"))

        display_dataset_info(images, "DATASET: " + test_set.name + ", FACTOR: " + str(factor), 5, 1, 5)

        low_res, originals = preprocess(images, factor, 1)

        test(originals, low_res, images, factor, model)