In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import torch
import torchinfo
import skimage

TRAIN_FOLDER = 'train/'
TEST_FOLDER = 'test/'

EPOCHS = 10
SEED = 42

# Seed torch so weight initialisation (and therefore training) is reproducible across runs.
torch.manual_seed(SEED)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(f"Device: {device}")

In [None]:
class Mutliresolution(torch.nn.Module):
  """SRCNN-like network: three "same"-padded convolutions with ReLU between them.

  Maps a 1-channel image to a 1-channel image of the same spatial size:
  9x9 conv (1 -> 64), ReLU, 5x5 conv (64 -> 32), ReLU, 5x5 conv (32 -> 1).
  """

  def __init__(self) -> None:
    super().__init__()
    # Registration order fixes parameter order and RNG consumption at init.
    self.first_conv2d = torch.nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(9, 9), padding="same")
    self.second_conv2d = torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=(5, 5), padding="same")
    self.third_conv2d = torch.nn.Conv2d(in_channels=32, out_channels=1, kernel_size=(5, 5), padding="same")
    self.relu = torch.nn.ReLU()

  def forward(self, x):
    features = self.relu(self.first_conv2d(x))
    features = self.relu(self.second_conv2d(features))
    return self.third_conv2d(features)

In [None]:
def prepare_image():
    """Placeholder hook for image preparation; currently it only prints a marker."""
    marker = "prepare"
    print(marker)

In [None]:
def fix_image(image):
  """Clamp a numpy array to [0, 1] in place and return None.

  Elements above 1 are set to 1 and elements below 0 are set to 0; NaNs compare
  False on both tests and are therefore left unchanged.
  """
  np.putmask(image, image > 1, 1)
  np.putmask(image, image < 0, 0)

In [None]:
def expand(image):
  """Turn an (H, W, C) image with 0-255 values into a float32 batch of one, (1, C, H, W).

  Values are divided by 255 and saturated to [0, 1]. The input array is not modified
  because the division produces a new array.
  """
  normalized = image / 255
  # Same saturation as fix_image, done inline.
  normalized[normalized > 1] = 1
  normalized[normalized < 0] = 0

  channels_first = normalized.astype(np.float32).transpose(-1, 0, 1)
  return channels_first[np.newaxis]

In [None]:
def load_images_from_folder(folder):
  """Load every decodable image in ``folder`` as an RGB uint8 array of shape (H, W, 3).

  Files are visited in sorted order so the dataset, and everything derived from it,
  is the same on every run; ``os.listdir`` returns entries in arbitrary order.
  Files that OpenCV cannot decode are skipped.
  """
  images = []
  for filename in sorted(os.listdir(folder)):
    # IMREAD_COLOR always yields 8-bit, 3-channel BGR. The previous IMREAD_UNCHANGED (-1)
    # crashed cvtColor on grayscale files and let 16-bit images through, breaking the
    # 0-255 value range assumed elsewhere in this notebook.
    image = cv2.imread(os.path.join(folder, filename), cv2.IMREAD_COLOR)
    if image is None:
      continue
    images.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

  return images

In [None]:
def display_filters(filters, title, ncols=16):
  """Plot convolution filters as a grid of grayscale images, highest variance first.

  Args:
    filters: tensor shaped [N, C, H, W], e.g. ``conv.weight.cpu().detach()``.
    title: figure title.
    ncols: number of grid columns. The default of 16 reproduces the original
      4 x 16 grid for 64 filters; the number of rows is derived from N.
  """
  num_filters = filters.shape[0]
  if num_filters == 0:
    return

  # [N, C, H, W] -> [N, H, W, C] so each filter is laid out like an image.
  filters = filters.detach().cpu().permute(0, 2, 3, 1)
  flat = filters.reshape(num_filters, -1)

  # Sort by descending variance of the raw weights.
  order = torch.argsort(torch.var(flat, dim=1), descending=True)
  filters, flat = filters[order], flat[order]

  # Rescale each filter to [0, 1] using its own min/max. Conv weights are not restricted
  # to [0, 1] and can be negative, so clamping would erase every negative coefficient.
  low = flat.min(dim=1).values.view(-1, 1, 1, 1)
  span = (flat.max(dim=1).values.view(-1, 1, 1, 1) - low).clamp_min(1e-12)
  filters = (filters - low) / span

  ncols = max(1, min(ncols, num_filters))
  nrows = -(-num_filters // ncols)  # ceiling division
  fig, axes = plt.subplots(nrows, ncols, figsize=(ncols, nrows), squeeze=False)
  fig.suptitle(title)

  for index, ax in enumerate(axes.flat):
    if index < num_filters:
      image = filters[index]
      if image.shape[-1] == 1:
        image = image[..., 0]  # single-channel filter -> 2-D array for the gray colormap
      ax.imshow(image.numpy(), cmap="gray", vmin=0, vmax=1)
    ax.axis("off")

  plt.tight_layout()


In [None]:
def display_images(images, num, f1, f2, title = "", rng=None):
    """Show a random sample of ``images`` on an ``f1`` x ``f2`` grid.

    Args:
        images: sequence of (H, W), (H, W, 1) or (H, W, 3) arrays.
        num: how many images to sample; capped at len(images) and f1 * f2.
        f1, f2: number of grid rows and columns.
        title: title shown above the first panel.
        rng: optional numpy Generator for reproducible sampling (unseeded by default).
    """
    if rng is None:
        rng = np.random.default_rng()

    # Sampling more than len(images) without replacement raises, and sampling fewer
    # than f1 * f2 used to raise IndexError while filling the grid, so cap instead.
    num = min(num, len(images), f1 * f2)
    indices = rng.choice(len(images), num, replace=False) if num > 0 else []
    _, axes = plt.subplots(f1, f2, figsize=(8, 3), squeeze=False)

    for slot, ax in enumerate(axes.flat):
        if slot == 0:
            ax.set_title(title)
        if slot < num:
            image = np.asarray(images[indices[slot]])
            if image.ndim == 3 and image.shape[-1] == 1:
                image = image[:, :, 0]
            # Single-channel (luma) images would otherwise be drawn with the default colormap.
            ax.imshow(image, cmap="gray" if image.ndim == 2 else None)
        ax.axis("off")

    plt.tight_layout()

In [None]:
def display_dataset_info(images, title, num, f1, f2):
    """Print basic size statistics for ``images`` and show a random sample of them.

    Args:
        images: list of arrays whose first two dimensions are (height, width).
        title: title for the sample figure.
        num: number of images to sample for display.
        f1, f2: rows and columns of the display grid.
    """
    print(f"Število slik: {len(images)}")
    if len(images) == 0:
        # np.min / np.max raise ValueError on an empty sequence; report the count and stop.
        return

    heights = [image.shape[0] for image in images]
    widths = [image.shape[1] for image in images]

    print(f"Interval širine slik: [{np.min(widths)}, {np.max(widths)}]")
    print(f"Interval višine slik: [{np.min(heights)}, {np.max(heights)}]")
    print(f"Povprečna širina slik: {np.mean(widths) :.2f} +- {np.std(widths) :.2f}")
    print(f"Povprečna višina slik: {np.mean(heights) :.2f} +- {np.std(heights) :.2f}")
    print("Primeri slik:")
    display_images(images, num, f1, f2, title)

In [None]:
def _extract_patches(image, patch_size, stride):
    """Cut an (H, W, C) array into patch_size x patch_size patches starting every ``stride`` pixels.

    The array is zero-padded by ``patch_size`` on the bottom and right, so patches that
    start near the border are still full size. Patches are ordered column by column.
    """
    height, width = image.shape[:2]
    padded = np.pad(image, ((0, patch_size), (0, patch_size), (0, 0)), mode="constant")
    return [padded[y:y + patch_size, x:x + patch_size]
            for x in range(0, width, stride)
            for y in range(0, height, stride)]


def preprocess(images, factor, type, patch_size=33, stride=14):
    """Build (low-res, interpolated, original) luma samples for super-resolution.

    Each image is mod-cropped so its size is a multiple of ``factor``. It is then
    downscaled by ``factor`` and upscaled back to the cropped size with bilinear
    interpolation. All three versions are converted to YCrCb, and only the Y channel
    is kept, as (H, W, 1) arrays.

    Args:
        images: list of RGB uint8 images, as returned by load_images_from_folder.
        factor: downscale factor.
        type: 0 cuts every version into patch_size x patch_size patches with the given
            stride. Any other value keeps whole images.
        patch_size, stride: patch geometry for type 0. The defaults are the original 33 / 14.

    Returns:
        (low_res, high_res, originals). high_res holds the interpolated images, which
        are aligned pixel-for-pixel with originals.
    """
    low_res = []
    high_res = []
    originals = []

    for image in images:
        # OpenCV sizes are (width, height).
        smaller_size = (int(image.shape[1] / factor), int(image.shape[0] / factor))
        if smaller_size[0] < 1 or smaller_size[1] < 1:
            continue  # image is too small to be downscaled by `factor`
        bigger_size = (int(smaller_size[0] * factor), int(smaller_size[1] * factor))

        # Mod-crop: trim the original to bigger_size so it lines up with the
        # down/up-sampled copies; otherwise inputs and targets differ in size near the edges.
        image = image[:bigger_size[1], :bigger_size[0]]

        # cv2.resize's 3rd positional argument is `dst`, so pass interpolation by keyword.
        low_res_image = cv2.resize(image, smaller_size, interpolation=cv2.INTER_LINEAR)
        high_res_image = cv2.resize(low_res_image, bigger_size, interpolation=cv2.INTER_LINEAR)

        # Inputs are RGB (load_images_from_folder converts BGR -> RGB). Converting them with
        # COLOR_BGR2YCR_CB would swap the red and blue weights in the luma channel.
        low_res_y = cv2.cvtColor(low_res_image, cv2.COLOR_RGB2YCR_CB)[:, :, 0:1]
        high_res_y = cv2.cvtColor(high_res_image, cv2.COLOR_RGB2YCR_CB)[:, :, 0:1]
        original_y = cv2.cvtColor(image, cv2.COLOR_RGB2YCR_CB)[:, :, 0:1]

        if type == 0:
            low_res.extend(_extract_patches(low_res_y, patch_size, stride))
            high_res.extend(_extract_patches(high_res_y, patch_size, stride))
            originals.extend(_extract_patches(original_y, patch_size, stride))
        else:
            low_res.append(low_res_y)
            high_res.append(high_res_y)
            originals.append(original_y)

    display_dataset_info(low_res, "LOW RESOLUTION IMAGES, FACTOR: " + str(factor), 5, 1, 5)
    display_dataset_info(high_res, "HIGH RESOLUTION IMAGES, FACTOR: " + str(factor), 5, 1, 5)
    display_dataset_info(originals, "ORIGINAL IMAGES, FACTOR: " + str(factor), 5, 1, 5)
    return low_res, high_res, originals

In [None]:
def train(originals, low_res, epochs, factor, batch_size=1):
  """Train a fresh Mutliresolution network and save it to disk.

  Args:
    originals: target samples (uint8, (H, W, 1)), aligned one-to-one with ``low_res``.
    low_res: input samples. The training cell passes the interpolated (high_res) patches here.
    epochs: number of passes over the data.
    factor: upscaling factor; used only in the checkpoint filename.
    batch_size: samples per optimizer step. The default of 1 matches the original
      per-sample updates; values > 1 require equally sized samples (e.g. patches).

  Returns:
    The trained model, which is also saved with torch.save to
    ``multiresolution-factor{factor}.pt``.

  Raises:
    ValueError: if the input lists differ in length or batch_size < 1.
  """
  if len(originals) != len(low_res):
    raise ValueError(f"originals ({len(originals)}) and inputs ({len(low_res)}) differ in length")
  if batch_size < 1:
    raise ValueError("batch_size must be >= 1")

  model = Mutliresolution().to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
  loss_fn = torch.nn.MSELoss()
  num_samples = len(low_res)

  model.train()

  for epoch in range(epochs):
    print('epoch:' + str(epoch + 1) + ' of ' + str(epochs))
    epoch_loss = 0.0

    for batch_index, start in enumerate(range(0, num_samples, batch_size)):
      # Progress logging only. The optimisation step below must run for EVERY batch;
      # it used to be nested under this `if`, so only every 100th sample was trained on.
      if batch_index % 100 == 0:
        print('training image:' + str(start + 1) + ' of ' + str(num_samples))

      stop = start + batch_size
      inputs = torch.from_numpy(np.concatenate([expand(sample) for sample in low_res[start:stop]])).to(device)
      targets = torch.from_numpy(np.concatenate([expand(sample) for sample in originals[start:stop]])).to(device)

      pred = model(inputs)
      loss = loss_fn(pred, targets)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      epoch_loss += loss.item() * inputs.shape[0]

    if num_samples:
      print(f'mean loss: {epoch_loss / num_samples:.6f}')

  torch.save(model, "multiresolution-factor" + str(factor) + ".pt")
  return model

In [None]:
def _psnr(prediction, target):
  """PSNR in dB between two tensors whose values lie in [0, 1] (inf when identical)."""
  mse = torch.mean((prediction - target) ** 2)
  return (10 * torch.log10(1.0 / mse)).item()


def test(low_res, high_res, originals, model):
  """Evaluate ``model`` on whole test images and report PSNR on the luma channel.

  The training cell trains on interpolated ``high_res`` inputs against ``originals``,
  so the model is applied to ``high_res`` here. ``low_res`` is smaller than the targets
  and is accepted only to keep the call signature unchanged.

  Sets the model to eval mode and runs under ``torch.no_grad()``.

  Returns:
    List of (interpolated_psnr, model_psnr) tuples in dB, one per image.
  """
  model.eval()
  results = []

  with torch.no_grad():
    for original, interpolated in zip(originals, high_res):
      target = torch.from_numpy(expand(original)).to(device)
      source = torch.from_numpy(expand(interpolated)).to(device)
      pred = model(source).clamp(0, 1)

      # Compare on the shared area in case the original was not cropped to a multiple of the factor.
      h = min(target.shape[-2], pred.shape[-2])
      w = min(target.shape[-1], pred.shape[-1])
      target = target[..., :h, :w]
      results.append((_psnr(source[..., :h, :w], target), _psnr(pred[..., :h, :w], target)))

  if results:
    mean_interpolated = sum(r[0] for r in results) / len(results)
    mean_model = sum(r[1] for r in results) / len(results)
    print(f"Mean PSNR (Y) interpolated: {mean_interpolated:.2f} dB, model: {mean_model:.2f} dB")

  return results
    #display_image(pred_img)

In [None]:
images = load_images_from_folder(TRAIN_FOLDER)
if not images:
  # Fail fast: an empty training set would otherwise only surface later as confusing
  # errors or as a model trained on nothing.
  raise FileNotFoundError(f"No readable images found in {TRAIN_FOLDER!r}")

display_dataset_info(images, "DATASET", 2, 1, 2)

In [None]:
# Train one model per upscaling factor (2 and 3). train() receives the interpolated
# patches as its input and the original patches as its target.
for factor in (2, 3):
  low_res, high_res, originals = preprocess(images, factor=factor, type=0)
  train(originals, high_res, epochs=EPOCHS, factor=factor)

In [None]:
for factor in range(2, 4):
  print("==========================================================================================")
  print("FACTOR: " + str(factor))
  print("==========================================================================================")
  # The checkpoint is a whole pickled module (train() calls torch.save(model, ...)), so
  # weights_only=False is required since PyTorch 2.6 made weights_only=True the default.
  # Full unpickling can execute code: only load checkpoints you produced yourself.
  model = torch.load("multiresolution-factor" + str(factor) + ".pt", map_location=device, weights_only=False)
  print(torchinfo.summary(model, (1, 1, 33, 33)))
  display_filters(model.first_conv2d.weight.cpu().detach(), "FACTOR " + str(factor) + " FIRST CONV2D LAYER FILTERS")
  print("\n\n")

In [None]:
for factor in range(2, 4):
  # Whole-module checkpoint: see torch.load's weights_only parameter (default True since PyTorch 2.6).
  # Only load checkpoints you produced yourself; full unpickling can execute code.
  model = torch.load("multiresolution-factor" + str(factor) + ".pt", map_location=device, weights_only=False)

  # Sorted so datasets are evaluated in the same order on every run (os.scandir order is arbitrary).
  for dataset in sorted(f.path for f in os.scandir(TEST_FOLDER) if f.is_dir()):
    # basename works regardless of how many components TEST_FOLDER has.
    dataset_name = os.path.basename(os.path.normpath(dataset))
    images = load_images_from_folder(os.path.join(dataset, "test"))

    display_dataset_info(images, "DATASET: " + dataset_name + ", FACTOR: " + str(factor), 5, 1, 5)

    low_res, high_res, originals = preprocess(images, factor, 1)

    test(low_res, high_res, originals, model)