<a href="https://colab.research.google.com/github/vincentbriat/Super-resolution-investigation/blob/main/Discriminator_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Detect whether this notebook runs on Google Colab: the google.colab
# package is expected to be importable only inside Colab.

try:
  from google.colab import drive
  IN_COLAB = True
except ImportError:
  # Catch only the failed import. A bare `except:` would also swallow
  # KeyboardInterrupt/SystemExit and hide genuine errors in the package.
  IN_COLAB = False

Run in a Colab notebook

In [None]:
if IN_COLAB:
  # Colab-only setup: install BasicSR (the Setup cell imports RRDBNet from it),
  # mount Google Drive and extract the DIV2K validation archives.
  !pip install basicsr
  drive.mount('/content/drive/')
  # unzip extracts into the current working directory; the folder names below
  # presumably match the archives' top-level directories -- verify.
  # NOTE(review): on a re-run, unzip asks before overwriting the already
  # extracted files; consider `unzip -q -n` (or `-o`) to avoid the prompt.
  !unzip /content/drive/MyDrive/Datasets/DIV2K_valid_HR.zip
  !unzip /content/drive/MyDrive/Datasets/DIV2K_valid_LR_clean.zip
  FOLDER_LR_TEST = 'DIV2K_valid_LR_clean'
  FOLDER_HR_TEST = 'DIV2K_valid_HR'

  # Model weight (.pth) and CSV record files on the mounted Drive. These are
  # relative paths, presumably resolved from /content (Colab's default
  # working directory, where the Drive is mounted) -- TODO confirm.
  STUDENT_MODEL_PATH = 'drive/MyDrive/ML/Indiv_Project/Second_Year/KD/Models/student.pth'
  STUDENT_RECORDS_PATH = 'drive/MyDrive/ML/Indiv_Project/Second_Year/KD/Models/student.csv'
  GENERATOR_MODEL_PATH = 'drive/MyDrive/ML/Indiv_Project/Second_Year/KD/Models/generator.pth'
  GENERATOR_RECORDS_PATH = 'drive/MyDrive/ML/Indiv_Project/Second_Year/KD/Models/generator.csv'

  # Pretrained Real-ESRGAN x4 weights.
  TEACHER_MODEL_PATH = 'drive/MyDrive/ML/Indiv_Project/Second_Year/KD/ESRGAN_models/RealESRGAN_x4plus.pth'

Run on my Windows desktop

In [None]:
if not IN_COLAB:
  # Local (non-Colab) locations of the DIV2K validation data and of the model
  # weight / CSV record files.
  # Raw strings keep every backslash literal. In a plain string, sequences
  # such as '\o' or '\D' are invalid escapes (SyntaxWarning since Python
  # 3.12), and a folder name starting with n, t, r, ... would be silently
  # turned into a control character.
  FOLDER_LR_TEST = r'D:\Downloads\Div2k\DIV2K_valid_LR_clean'
  FOLDER_HR_TEST = r'D:\Downloads\Div2k\DIV2K_valid_HR'

  STUDENT_MODEL_PATH = r'D:\oldDrive\ML\Indiv_Project\Second_Year\KD\Models\student.pth'
  STUDENT_RECORDS_PATH = r'D:\oldDrive\ML\Indiv_Project\Second_Year\KD\Models\student.csv'
  GENERATOR_MODEL_PATH = r'D:\oldDrive\ML\Indiv_Project\Second_Year\KD\Models\generator.pth'
  GENERATOR_RECORDS_PATH = r'D:\oldDrive\ML\Indiv_Project\Second_Year\KD\Models\generator.csv'

  TEACHER_MODEL_PATH = r'D:\oldDrive\ML\Indiv_Project\Second_Year\KD\ESRGAN_models\RealESRGAN_x4plus.pth'

# Setup

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import cv2
import csv
import math
import numpy as np
import os
import queue
import threading
from basicsr.archs.rrdbnet_arch import RRDBNet
import torchvision
from os import listdir, environ, path
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [None]:
class DIV2KLoader(object):
  """Map-style dataset over DIV2K image folders, usable with ``DataLoader``.

  Inputs are read from ``low_res_folder``. When targets are enabled, targets
  are read from ``high_res_folder``. The two sorted file listings are paired
  by position, so both folders are assumed to hold matching images in the
  same name order (this is not checked here).

  Args:
    low_res_folder: folder containing the input images.
    high_res_folder: folder containing the target images, or None.
    preprocessing_input: callable applied to each input PIL image
      (default ``torchvision.transforms.ToTensor()``).
    preprocessing_output: callable applied to each target PIL image
      (default ``torchvision.transforms.ToTensor()``).
    include_targets: when True and ``high_res_folder`` is given, items are
      ``(input, target)`` pairs; otherwise items are inputs only.
  """
  def __init__(self, low_res_folder, high_res_folder = None, preprocessing_input = None, preprocessing_output = None, include_targets = True):
    self.low_res_folder = low_res_folder
    self.high_res_folder = high_res_folder
    self.img_names = sorted(listdir(low_res_folder))
    self.include_targets = include_targets
    self.len = len(self.img_names)

    # Targets need both the flag and a folder. With no folder,
    # os.listdir(None) silently lists the current directory, and
    # Path(None) in __getitem__ would raise.
    if self._has_targets():
      self.target_names = sorted(listdir(high_res_folder))

    # Fall back to a plain PIL -> tensor conversion when no transform is given.
    self.preprocessing_input = (torchvision.transforms.ToTensor()
                                if preprocessing_input is None else preprocessing_input)
    self.preprocessing_output = (torchvision.transforms.ToTensor()
                                 if preprocessing_output is None else preprocessing_output)

  def _has_targets(self):
    """Return True when items should include the matching target image."""
    return bool(self.include_targets) and self.high_res_folder is not None

  def __getitem__(self, i):
    """Return the preprocessed input at index ``i``, paired with its target when enabled."""
    img = Image.open(Path(self.low_res_folder) / self.img_names[i])
    if not self._has_targets():
      return self.preprocessing_input(img)
    target = Image.open(Path(self.high_res_folder) / self.target_names[i])
    return self.preprocessing_input(img), self.preprocessing_output(target)

  def __len__(self):
    return self.len

  def restrict_size(self, size):
    """Limit the dataset to its first ``size`` images.

    Valid sizes are 1..len(img_names), inclusive. Any other value resets the
    dataset to the full listing and prints a warning.
    """
    n_images = len(self.img_names)
    if 0 < size <= n_images:
      self.len = size
    else:
      self.len = n_images
      print(f"Size must be between 1 and {n_images}")

In [None]:
# Patch-extraction mode used by the preprocessing functions below:
#   0 -> resize (no preprocessing branch handles it yet, so the full image is kept)
#   1 -> centre crop
#   2 -> random crop (unfinished: the crop origin is currently fixed at the top-left)

MODE = 1
# Length, in pixels, that resize_ratio gives to the longer image side.
IM_SIZE = 50

def resize_ratio(img, size=IM_SIZE):
  """Scale ``img.size`` so that its longer side equals ``size``.

  The aspect ratio is preserved and the shorter side is truncated to an int.
  A square image takes the height branch and yields ``(size, size)``.

  Returns:
    A ``(width, height)`` tuple.
  """
  w, h = img.size
  if w <= h:
    return int(w * size / h), size
  return size, int(h * size / w)

def preprocessing_output(img):
  """Turn a high-resolution image into a target tensor placed on ``device``.

  MODE 1 keeps the central 400x400 patch. MODE 2 keeps a top-left patch
  four times the size returned by ``resize_ratio``. Any other MODE
  (including 0, "resize", which has no branch here) keeps the whole image.
  """
  im_w, im_h = img.size
  if MODE == 1:
    left, top = (im_w - 400)/2, (im_h - 400)/2
    right, bottom = (im_w + 400)/2, (im_h + 400)/2
    img = img.crop((left, top, right, bottom))
  elif MODE == 2:
    patch_w, patch_h = resize_ratio(img)
    # Fixed top-left origin until random cropping is implemented.
    x0 = y0 = 0
    img = img.crop((x0, y0, x0 + patch_w*4, y0 + patch_h*4))
  to_tensor = torchvision.transforms.ToTensor()
  return to_tensor(img).to(device)

def preprocessing_input(img):
  """Turn a high-resolution image into a low-resolution input tensor on ``device``.

  The image is cropped exactly as in ``preprocessing_output`` and then
  downsampled by 4. In MODE 1 the central 400x400 patch is resized to
  100x100. In MODE 2 the top-left patch is resized to the ``resize_ratio``
  dimensions. Any other MODE keeps the whole image at full resolution.
  """
  im_w, im_h = img.size
  if MODE == 1:
    box = ((im_w - 400)/2, (im_h - 400)/2, (im_w + 400)/2, (im_h + 400)/2)
    img = img.crop(box).resize((100, 100))
  elif MODE == 2:
    small_w, small_h = resize_ratio(img)
    # Fixed top-left origin until random cropping is implemented.
    x0 = y0 = 0
    patch = img.crop((x0, y0, x0 + small_w*4, y0 + small_h*4))
    img = patch.resize((small_w, small_h))
  return torchvision.transforms.ToTensor()(img).to(device)

# Both inputs and targets are read from the HR folder. In MODE 1/2 the input
# preprocessing builds the low-resolution version by downsampling.
data_test = DIV2KLoader(
  FOLDER_HR_TEST,
  high_res_folder=FOLDER_HR_TEST,
  preprocessing_input=preprocessing_input,
  preprocessing_output=preprocessing_output,
)
data_test.restrict_size(10)  # keep only the first 10 images (sorted by name)
dataloader_test = torch.utils.data.DataLoader(
  data_test, batch_size=1, shuffle=False
)

# Discriminator

In [None]:
class ConvBlock(nn.Module):
  """Conv2d, optionally followed by a LeakyReLU activation.

  Args:
    in_channels: number of input channels.
    out_channels: number of output channels.
    kernel_size: convolution kernel size.
    use_act: apply LeakyReLU after the convolution when True; otherwise the
      convolution output is returned unchanged.
    negative_slope: LeakyReLU slope for negative inputs (default 0.2).
    **kwargs: forwarded to ``nn.Conv2d`` (e.g. ``stride``, ``padding``).
  """
  def __init__(self, in_channels: int, out_channels: int, kernel_size: int, use_act: bool, negative_slope: float = 0.2, **kwargs):
    super().__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, **kwargs)
    # A slope of 1 would make LeakyReLU the identity, leaving the network
    # purely linear. 0.2 is the slope ESRGAN uses.
    self.activation = nn.LeakyReLU(negative_slope, inplace=True) if use_act else nn.Identity()

  def forward(self, x):
    return self.activation(self.conv(x))

class UNet(nn.Module):
  """Densely connected residual block of five 3x3 convolutions.

  Despite its name, this is not a U-Net: it has the shape of ESRGAN's
  residual dense block. Convolution ``i`` sees the block input concatenated
  with the outputs of all previous convolutions. The first four produce
  ``middle_channels`` maps followed by LeakyReLU; the last maps back to
  ``in_channels`` without activation. That residual branch is scaled by
  ``residual_scale`` and added to the input, so the output has the input's
  shape.

  Args:
    in_channels: channels of the input and of the output.
    middle_channels: feature maps produced by each intermediate convolution.
    residual_scale: factor applied to the residual branch before the skip add.
  """
  def __init__(self, in_channels, middle_channels = 32, residual_scale = .2):
    super().__init__()
    # forward() reads this; without it every forward pass raises AttributeError.
    self.residual_scale = residual_scale
    # Conv i takes in_channels + i * middle_channels channels (input plus the
    # i earlier outputs). stride=1 / padding=1 keep the spatial size.
    self.block = nn.ModuleList([ConvBlock(in_channels + i * middle_channels,
                                  middle_channels if i<4 else in_channels,
                                  3,
                                  stride=1,
                                  padding=1,
                                  use_act=i<4) for i in range(5)])

  def forward(self, x):
    """Return ``x + residual_scale * dense_branch(x)``."""
    features = x
    # Dense connectivity: each intermediate output is stacked onto the
    # running features along the channel dimension.
    for conv in self.block[:-1]:
      features = torch.cat([features, conv(features)], dim=1)
    # The final conv's output is the residual branch; no concatenation is
    # needed after it.
    return self.residual_scale * self.block[-1](features) + x

class Discriminator(nn.Module):
  