In [1]:
# Mount Google Drive; the input_pd/output_pd folders set in the configuration cell live under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [2]:
# tensorflow_addons provides tfa.layers.SpectralNormalization, used by the KernelGAN generator/discriminator below.
# NOTE(review): consider pinning the version for reproducibility (the install log shows 0.15.0).
!pip install tensorflow_addons

Collecting tensorflow_addons
  Downloading tensorflow_addons-0.15.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)
[?25l[K     |▎                               | 10 kB 23.9 MB/s eta 0:00:01[K     |▋                               | 20 kB 24.3 MB/s eta 0:00:01[K     |▉                               | 30 kB 11.6 MB/s eta 0:00:01[K     |█▏                              | 40 kB 8.2 MB/s eta 0:00:01[K     |█▌                              | 51 kB 5.3 MB/s eta 0:00:01[K     |█▊                              | 61 kB 5.8 MB/s eta 0:00:01[K     |██                              | 71 kB 5.5 MB/s eta 0:00:01[K     |██▍                             | 81 kB 6.2 MB/s eta 0:00:01[K     |██▋                             | 92 kB 6.2 MB/s eta 0:00:01[K     |███                             | 102 kB 5.1 MB/s eta 0:00:01[K     |███▎                            | 112 kB 5.1 MB/s eta 0:00:01[K     |███▌                            | 122 kB 5.1 MB/s eta 0:00:01[K     |███

# Imports

In [3]:
%matplotlib inline
import warnings
warnings.filterwarnings("ignore")

import os
import sys
import pathlib
from matplotlib import pyplot as plt
import pickle
import pandas as pd
import numpy as np
import cv2
import timeit
import skimage
from skimage.util import random_noise
import tensorflow as tf
import datetime
from time import strftime, localtime
from scipy.ndimage import filters, measurements, interpolation
from math import pi
import torch
import torch.nn as nn
# from torchsummary import summary
import time
from PIL import Image
import scipy.io as sio
import matplotlib.pyplot as plt

from torch.nn import functional as F
from scipy.ndimage import measurements, interpolation
from torch.utils.data import Dataset
import random
import pywt
import keras
import tensorflow_addons as tfa
from tensorflow.keras import models, layers
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Conv2DTranspose,\
                                    GlobalAveragePooling2D, AveragePooling2D, MaxPool2D, UpSampling2D,\
                                    BatchNormalization, Activation, ReLU, LeakyReLU, Flatten, Dense, Input,\
                                    Add, Multiply, Concatenate, Softmax
from tensorflow.keras import initializers, regularizers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.activations import softmax, sigmoid
from keras.applications.vgg19 import VGG19
tf.keras.backend.set_image_data_format('channels_last')
import keras.backend as K

# Model

In [4]:
def generator():
  """Build the KernelGAN-style linear generator.

  Six bias-free, spectrally-normalized convolutions with 'valid' padding and no
  activations (kernel sizes 7, 5, 3, 1, 1, 1). The last one has a single filter and
  stride 2, so the network produces a 2x-downscaled single-channel map.
  The input is fixed to a batch of 3 single-channel maps; train_step feeds the
  three colour channels on the batch axis via swap_axis.
  """
  # (filters, kernel_size, strides) of each layer, in order.
  layer_specs = [(64, 7, 1), (64, 5, 1), (64, 3, 1), (64, 1, 1), (64, 1, 1), (1, 1, 2)]
  inputs = Input(shape=(None, None, 1), batch_size=3)
  x = inputs
  for n_filters, k_size, stride in layer_specs:
    conv = Conv2D(filters=n_filters, kernel_size=k_size, strides=stride, padding='valid', use_bias=False)
    x = tfa.layers.SpectralNormalization(conv)(x)
  return Model(inputs=inputs, outputs=x)

def discriminator():
  """Build the fully-convolutional patch discriminator.

  Layers: a 7x7 spectrally-normalized conv, then five (1x1 conv -> BatchNorm -> ReLU)
  stages, then a 1-filter 1x1 conv. A sigmoid is applied to the output map, so each
  element is a probability in (0, 1), not a logit.
  """
  inputs = Input(shape=(None, None, 3))
  features = tfa.layers.SpectralNormalization(Conv2D(filters=64, kernel_size=7, use_bias=True))(inputs)
  for _ in range(5):
    features = tfa.layers.SpectralNormalization(Conv2D(filters=64, kernel_size=1, use_bias=True))(features)
    features = BatchNormalization()(features)
    features = ReLU()(features)
  logits = tfa.layers.SpectralNormalization(Conv2D(filters=1, kernel_size=1, use_bias=True))(features)
  return Model(inputs=inputs, outputs=sigmoid(logits))

def d_loss(real_output, fake_output):
  """Discriminator loss: real patches are labelled 1 and generated patches 0.

  Both terms use the module-level `cross_entropy` loss object; their sum is returned.
  """
  loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
  loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
  return loss_on_real + loss_on_fake

def g_loss(fake_output):
  """Generator loss: push the discriminator to label generated patches as real (1)."""
  target = tf.ones_like(fake_output)
  return cross_entropy(target, fake_output)

# Build the two KernelGAN networks once. These assignments rebind the names `generator`
# and `discriminator` from the builder functions to the built Keras models. Re-running
# this cell on its own would therefore call the models rather than the builders, so
# re-run the definitions above first.
generator = generator()
discriminator = discriminator()
# Separate Adam optimizers (lr 2e-4, beta_1 0.5) for the two players.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate = 2e-4, beta_1=0.5, beta_2=0.999)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate = 2e-4, beta_1=0.5, beta_2=0.999)

class DWT_downsampling(tf.keras.layers.Layer):
    """Single-level Haar-style discrete wavelet transform used as a downsampling layer.

    The input is split into its four polyphase components (even/odd rows x even/odd
    columns). These are combined into LL, LH, HL and HH sub-bands, which are stacked
    along the channel axis: (B, H, W, C) -> (B, H/2, W/2, 4*C).
    H and W must be even; otherwise the polyphase slices have mismatched shapes and the
    sums fail. The sub-bands are unnormalised sums/differences; IWT_upsampling divides
    by 4 to invert them exactly.

    Chintan, (2021) Image Denoising using Deep Learning [Github].
    https://github.com/chintan1995/Image-Denoising-using-Deep-Learning/blob/main/Models/MWCNN_256x256.ipynb
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x):
        x1 = x[:, 0::2, 0::2, :] # x(2i-1, 2j-1): even rows, even cols
        x2 = x[:, 1::2, 0::2, :] # x(2i, 2j-1):   odd rows,  even cols
        x3 = x[:, 0::2, 1::2, :] # x(2i-1, 2j):   even rows, odd cols
        x4 = x[:, 1::2, 1::2, :] # x(2i, 2j):     odd rows,  odd cols

        x_LL = x1 + x2 + x3 + x4
        x_LH = -x1 - x3 + x2 + x4
        x_HL = -x1 + x3 - x2 + x4
        x_HH = x1 - x3 - x2 + x4

        # A plain tensor op instead of instantiating a new Concatenate layer on every call.
        return tf.concat([x_LL, x_LH, x_HL, x_HH], axis=-1)

class IWT_upsampling(tf.keras.layers.Layer):
    """Inverse of DWT_downsampling: rebuilds a 2x larger map from stacked sub-bands.

    The channel axis is read as four equal quarters (LL, LH, HL, HH). Each of the four
    polyphase pixel grids is a signed sum of the sub-bands divided by 4. The grids are
    interleaved back into a tensor of shape (B, 2H, 2W, C/4).

    Chintan, (2021) Image Denoising using Deep Learning [Github].
    https://github.com/chintan1995/Image-Denoising-using-Deep-Learning/blob/main/Models/MWCNN_256x256.ipynb
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x):
        quarter = x.shape[3] // 4
        x_LL = x[:, :, :, 0:quarter]
        x_LH = x[:, :, :, quarter:quarter * 2]
        x_HL = x[:, :, :, quarter * 2:quarter * 3]
        x_HH = x[:, :, :, quarter * 3:]

        # Recover the four polyphase components.
        top_left = (x_LL - x_LH - x_HL + x_HH)/4       # even row, even col
        top_right = (x_LL - x_LH + x_HL - x_HH)/4      # even row, odd col
        bottom_left = (x_LL + x_LH - x_HL - x_HH)/4    # odd row, even col
        bottom_right = (x_LL + x_LH + x_HL + x_HH)/4   # odd row, odd col

        # Interleave rows (stack on axis 2), then columns (concat + reshape).
        even_cols = K.stack([top_left, bottom_left], axis=2)
        odd_cols = K.stack([top_right, bottom_right], axis=2)
        dims = K.shape(x)
        target_shape = K.stack([dims[0], dims[1]*2, dims[2]*2, dims[3]//4])
        return K.reshape(K.concatenate([even_cols, odd_cols], axis=-1), target_shape)

class Conv_block(tf.keras.layers.Layer):
    """Four 'same'-padded Conv2D layers, each followed by a ReLU activation.

    Args:
        num_filters: output channels of every convolution (default 64).
        kernel_size: spatial size of every convolution kernel (default 3).
    """
    def __init__(self, num_filters=64, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        # Separate named attributes (conv_1..conv_4) rather than a list, so attribute
        # names and therefore checkpoint keys stay as before.
        conv_kwargs = dict(filters=self.num_filters, kernel_size=self.kernel_size, padding='same')
        self.conv_1 = Conv2D(**conv_kwargs)
        self.conv_2 = Conv2D(**conv_kwargs)
        self.conv_3 = Conv2D(**conv_kwargs)
        self.conv_4 = Conv2D(**conv_kwargs)

    def get_config(self):
        config = super().get_config().copy()
        config.update({'num_filters': self.num_filters, 'kernel_size': self.kernel_size})
        return config

    def call(self, X):
        for conv in (self.conv_1, self.conv_2, self.conv_3, self.conv_4):
            X = ReLU()(conv(X))
        return X

def build_model():
  """Build the single-level MWCNN-style residual network used for super-resolution.

  Data flow (channel counts in brackets):
    input [3] -> Conv_block [64] -> DWT [256] -> Conv_block [64]
    -> Conv2D + ReLU [256] -> IWT [64] -> (+ first Conv_block output) -> Conv_block [64]
    -> linear Conv2D [3] -> (+ input).
  The final Add means the convolutional path predicts a residual on top of the input.
  """
  image_in = Input(shape=(None, None, 3))
  shallow = Conv_block(num_filters=64)(image_in)
  subbands = DWT_downsampling()(shallow)
  coarse = Conv_block(num_filters=64)(subbands)
  expanded = Conv2D(filters=256, kernel_size=3, strides=1, padding='same', activation='relu')(coarse)
  restored = IWT_upsampling()(expanded)
  refined = Conv_block(num_filters=64)(Add()([restored, shallow]))
  residual = Conv2D(filters=3, kernel_size=3, strides=1, padding='same', activation='linear')(refined)
  return Model(inputs=image_in, outputs=Add()([residual, image_in]))

# Data pre-processing

In [5]:
def load_img(img_path, return_data_type = 'float32'):
    """Read an image file into a channels-last RGB array.

    Args:
        img_path: path of the image file.
        return_data_type: dtype of the returned array (default 'float32').

    Returns:
        np.ndarray of shape (H, W, 3), in RGB order and dtype `return_data_type`, for
        grayscale, BGR and BGRA files. Grayscale is replicated to 3 channels. Values
        are not rescaled, so an 8-bit file yields values in [0, 255].

    Raises:
        FileNotFoundError: if OpenCV cannot read `img_path`.
    """
    image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('Could not read image: %s' % img_path)
    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)                # BGR(A) -> RGB
    else:
        image = np.stack((image,) * 3, axis=-1)                       # Grayscale, channel 1 -> channel 3
    # The array is already channels-last here. The former "move axis 0 to the end if it
    # has size 1 or 3" step could only trigger for images 1 or 3 pixels tall, and then
    # corrupted them, so it has been removed.
    return image.astype(return_data_type)

def add_noise(image, sigma):
    """Add zero-mean Gaussian noise with standard deviation `sigma`, then clip to [0, 255].

    The noise has the same shape as `image`, is drawn from NumPy's global RNG, and is
    cast to float32 before being added.
    """
    gaussian = np.random.normal(0, sigma, image.shape).astype('float32')
    return np.clip(image + gaussian, 0, 255)

def image_crop(img_list, crop_size, leave_as_prob, shear_scale_prob, center_crop_prob):
  """Jointly augment same-sized (H, W, C) images and trim them to even height/width.

  Every image in `img_list` receives exactly the same geometric transform, so pairs
  such as (clean, noisy) stay pixel-aligned.

  Behaviour, driven by two uniform draws `random_chooser` and `random_augment`:
    * random_chooser < leave_as_prob: images are left untouched.
    * otherwise, if random_augment > shear_scale_prob: a random shear/scale affine warp;
      elif random_chooser > center_crop_prob: a centre crop of side `crop_size`
      along each axis larger than it;
      else: a random crop, with crop_size reduced in steps of 4 until it fits;
      then one of the 8 rotations/reflections of the square is applied at random.
  Finally, any odd height/width is trimmed by one pixel. Presumably this is so the
  DWT/IWT layers see even sizes.

  Args:
    img_list: list of array-likes with identical shape (H, W, C). Each entry is
      converted to float32; the list is modified in place and also returned.
    crop_size: side length of the square crop.
    leave_as_prob: probability of leaving the images unchanged.
    shear_scale_prob: probability threshold; the warp happens with probability
      1 - shear_scale_prob once augmentation is chosen.
    center_crop_prob: threshold on random_chooser that selects centre vs random crop.

  Returns:
    The same list, holding the transformed float32 arrays.
  """
  num_imgs = len(img_list)
  for i in range(num_imgs):
    img_list[i] = np.array(img_list[i], dtype=np.float32)
  img_h, img_w, _ = img_list[0].shape

  # Both draws are made unconditionally so the RNG stream does not depend on the branch.
  random_chooser = np.random.rand()
  random_augment = np.random.rand()

  if random_chooser >= leave_as_prob:
    if random_augment > shear_scale_prob:
      # Random shear/scale affine warp, identical for every image.
      shear_x = np.random.randn()*0.25
      shear_y = np.random.randn()*0.25
      scale_x = np.random.randn()*0.15
      scale_y = np.random.randn()*0.15
      transform_matrix = np.array([[1+scale_x, shear_x, 0.0],[shear_y, 1+scale_y, 0.0]])
      transform_matrix = transform_matrix.astype(img_list[0].dtype)
      for i in range(num_imgs):
        img_list[i] = cv2.warpAffine(img_list[i], transform_matrix, (img_w, img_h))
    elif random_chooser > center_crop_prob:
      # Centre crop along every axis that is larger than crop_size.
      if img_h > crop_size:
        start_h = int((img_h - crop_size)/2)
        end_h = int(start_h + crop_size)
        for i in range(num_imgs):
          img_list[i] = img_list[i][start_h:end_h, :, :]
      if img_w > crop_size:
        start_w = int((img_w - crop_size)/2)
        end_w = int(start_w + crop_size)
        for i in range(num_imgs):
          img_list[i] = img_list[i][:, start_w:end_w, :]
    else:
      # Shrink the crop until it fits strictly inside the image, then pick a random corner.
      while (img_h - 1 < crop_size) or (img_w - 1 < crop_size):
        crop_size -= 4
      top_left_x = np.random.randint(0, img_w - crop_size)
      top_left_y = np.random.randint(0, img_h - crop_size)
      for i in range(num_imgs):
        img_list[i] = img_list[i][top_left_y:top_left_y + crop_size, top_left_x:top_left_x + crop_size, :]

    # One of the 8 symmetries of the square: rotate by k*90 degrees, mirror when k > 3.
    random_rot = random.randint(0, 7)
    for i in range(num_imgs):
      img_list[i] = np.rot90(img_list[i], random_rot, axes=(0, 1))
      if random_rot > 3:
        img_list[i] = np.fliplr(img_list[i])

  # Drop the last row/column when the size is odd.
  for i in range(num_imgs):
    h, w = img_list[i].shape[:2]
    img_list[i] = img_list[i][:h - h % 2, :w - w % 2, :]

  return img_list

def swap_axis(image):
  """Exchange axes 0 and 3 of a 4-D array or tensor, e.g. (1, H, W, C) <-> (C, H, W, 1).

  Objects whose exact type is np.ndarray are transposed with NumPy; anything else is
  transposed with TensorFlow, which returns a tf.Tensor.
  """
  permutation = [3, 1, 2, 0]
  if type(image) == np.ndarray:
    return np.transpose(image, axes=permutation)
  return tf.transpose(image, perm=permutation)
  
def parent_to_child(hr_parent, scale_factor, kernel):
  """Create a degraded "LR child" from an "HR parent" crop for internal training.

  Steps:
    1. drop a leading batch axis if `hr_parent` is 4-D;
    2. blur with `kernel` and subsample by `scale_factor` (numeric_kernel);
    3. always add Gaussian noise with the module-level `sigma`;
    4. bicubically resize back to the parent's height/width, so the child has the
       same spatial size as the parent it is paired with.

  Args:
    hr_parent: (H, W, C) or (1, H, W, C) array.
    scale_factor: downscaling factor (e.g. 2).
    kernel: 2-D blur kernel.

  Returns:
    The noisy, re-upsampled child with a leading batch axis of size 1
    ((1, H, W, C) for multi-channel input).
  """
  if len(hr_parent.shape) == 4:
    hr_parent = np.squeeze(hr_parent, axis=0)
  scale_down = 1/scale_factor

  # Round the *scaled* size up. Previously ceil was applied to the integer input shape
  # (a no-op) and the scaled size was then truncated by np.uint.
  output_shape = np.uint(np.ceil(np.array(hr_parent.shape) * np.array(scale_down)))

  lr_child = numeric_kernel(hr_parent, kernel, scale_down, output_shape)
  lr_child = add_noise(lr_child, sigma)
  lr_child = cv2.resize(lr_child, (hr_parent.shape[1], hr_parent.shape[0]), interpolation = cv2.INTER_CUBIC)

  return np.expand_dims(lr_child, axis=0)

def hr_lr_generator(image,
                    scale_factor,
                    final_kernel,
                    shear_scale_prob,
                    leave_as_prob,
                    center_crop_prob,
                    crop_size):
  """Endless Python generator of (lr_child, hr_parent) training pairs drawn from one image.

  Each iteration augments/crops `image` with image_crop and degrades that crop with
  parent_to_child. It yields `(x, y)`: x is the degraded child and y the clean parent
  crop, both with a leading batch axis of size 1.
  """
  def _with_batch_axis(arr):
    # Add a leading batch axis unless the array is already 4-D.
    return arr if len(arr.shape) == 4 else np.expand_dims(arr, axis=0)

  while True:
    hr_parent = image_crop([image], crop_size, leave_as_prob, shear_scale_prob, center_crop_prob)[0]
    lr_child = parent_to_child(hr_parent, scale_factor, final_kernel)
    yield _with_batch_axis(lr_child), _with_batch_axis(hr_parent)

def get_gradual_factors(SR_factor, gradual_increase_value):
  """List the intermediate super-resolution factors used when upscaling gradually.

  Example: SR_factor=8, gradual_increase_value=2 -> [2, 4, 8].

  Args:
    SR_factor: final super-resolution factor. It must be an integer power of
      `gradual_increase_value`; SR_factor == 1 returns [1].
    gradual_increase_value: factor applied per step; must be greater than 1.

  Returns:
    Ascending list of factors ending with SR_factor.

  Raises:
    ValueError: if gradual_increase_value <= 1, or if repeated division of SR_factor
      does not land exactly on 1. Previously both cases looped forever.
  """
  if gradual_increase_value <= 1:
    raise ValueError('gradual_increase_value must be greater than 1, got %r' % (gradual_increase_value,))
  gradual_SR_list = [SR_factor]
  sr_fact = SR_factor/gradual_increase_value
  # '>' instead of '!=' guarantees termination for any input.
  while sr_fact > 1:
    gradual_SR_list.append(int(sr_fact))
    sr_fact = sr_fact/gradual_increase_value
  if SR_factor != 1 and sr_fact != 1:
    raise ValueError('SR_factor %r is not a power of gradual_increase_value %r'
                     % (SR_factor, gradual_increase_value))
  gradual_SR_list.reverse()
  return gradual_SR_list

def get_images_paths(input_pd):
  """Return the paths (as strings) of every '*.png' file below `input_pd`, searched recursively, sorted."""
  png_files = sorted(pathlib.Path(input_pd).rglob('*.png'))
  return [str(png) for png in png_files]


# Kernel post-processing

The kernel post-processing below follows the official KernelGAN implementation: https://github.com/sefibk/KernelGAN

Paper: S. Bell-Kligler, A. Shocher, M. Irani, "Blind Super-Resolution Kernel Estimation using an Internal-GAN", in Neural Information Processing Systems 32: Annual Conference on Neural Information Processing Systems 2019 (NeurIPS 2019), 8-14 December 2019, Vancouver, BC, Canada, pp. 284-293, 2019.

In [6]:
def numeric_kernel(hr_parent, kernel, scale_down, output_shape):
  """Blur each channel of `hr_parent` with `kernel`, then subsample to `output_shape`.

  Args:
    hr_parent: (H, W, C) array.
    kernel: 2-D kernel, applied per channel with scipy's correlate (not convolve).
    scale_down: downscaling ratio (e.g. 0.5).
    output_shape: sequence whose first two entries are the output height and width.

  Returns:
    Array of shape (output_shape[0], output_shape[1], C) with hr_parent's dtype. Along
    each spatial axis, `output_shape[i]` evenly spaced positions between 0 and
    `size - 1/scale_down` are taken, rounded to the nearest integer index.
  """
  from scipy import ndimage  # scipy.ndimage.filters is a deprecated alias namespace

  out_im = np.zeros_like(hr_parent)
  # Iterate over the channel axis. The previous range(hr_parent.ndim) only matched the
  # channel count by coincidence for 3-channel images.
  for channel in range(hr_parent.shape[2]):
    out_im[:, :, channel] = ndimage.correlate(hr_parent[:, :, channel], kernel)
  rows = np.round(np.linspace(0, hr_parent.shape[0] - 1 / scale_down, output_shape[0])).astype(int)
  cols = np.round(np.linspace(0, hr_parent.shape[1] - 1 / scale_down, output_shape[1])).astype(int)
  return out_im[rows[:, None], cols, :]

def zeroize_negligible(k, n=40):
  """Suppress small kernel entries and renormalise the kernel to sum to 1.

  The threshold is 0.75 times the (n+1)-th largest entry of `k`. It is subtracted from
  every entry, the result is clipped to [0, 100], and then divided by its sum.
  """
  threshold = 0.75 * np.sort(k, axis=None)[-(n + 1)]
  kept = np.clip(k - threshold, 0, 100)
  return kept / kept.sum()

def kernel_shift(kernel, sf):
    """Pad and sub-pixel shift `kernel` so its centre of mass sits where an `sf`-times
    downscaling expects it.

    The target centre per axis is shape // 2 + 0.5 * (sf - shape % 2). First the kernel
    is zero-padded on every side by ceil(max |shift|) + 1, so the shift cannot move mass
    off the array. It is then shifted with scipy's spline interpolation.

    Args:
      kernel: 2-D array.
      sf: scale factor (scalar or per-axis sequence).

    Returns:
      The padded, shifted kernel, which is larger than the input.
    """
    from scipy import ndimage  # measurements/interpolation are deprecated alias namespaces

    current_center_of_mass = ndimage.center_of_mass(kernel)
    wanted_center_of_mass = np.array(kernel.shape) // 2 + 0.5 * (np.array(sf) - (np.array(kernel.shape) % 2))
    shift_vec = wanted_center_of_mass - current_center_of_mass
    # Built-in int(): the np.int alias was removed in NumPy 1.24.
    pad_width = int(np.ceil(np.max(np.abs(shift_vec)))) + 1
    kernel = np.pad(kernel, pad_width, 'constant')
    return ndimage.shift(kernel, shift_vec)

def post_process_k(k, n):
    """Convert the raw torch kernel into a cleaned, re-centred numpy kernel.

    `k` is detached and moved to a CPU float32 numpy array. Negligible values are then
    suppressed with zeroize_negligible(…, n), and the kernel is re-centred for a
    scale factor of 2 with kernel_shift.
    """
    k_np = k.detach().cpu().float().numpy()
    return kernel_shift(zeroize_negligible(k_np, n), sf=2)

def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper).

    Every entry k[r, c] adds a scaled copy of `k` at offset (2r, 2c) of a
    (3*s - 2) x (3*s - 2) canvas, where s = k.shape[0]. That is, `k` convolved with a
    copy of itself dilated by 2. The canvas is then cropped by s // 2 on each side and
    normalised to sum to 1.

    Args:
      k: square 2-D x2 kernel.

    Returns:
      Square 2-D array of side 3*s - 2 - 2*(s // 2), summing to 1.
    """
    k_size = k.shape[0]
    # Calculate the big kernel's size
    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
    # Crop the edges, which hold only very small values, to keep the kernel compact
    # and reduce SR run time. Use explicit end indices: with crop == 0 (a 1x1 kernel)
    # the old `[crop:-crop]` slice was empty.
    crop = k_size // 2
    end = big_k.shape[0] - crop
    cropped_big_k = big_k[crop:end, crop:end]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()

def save_final_kernel(k_2, img_name, i):
    """Write a kernel to '<output_pd>/<img_name>_kernel_x{2,4}.mat' under the key 'Kernel'.

    `k_2` is expected to be the x2 kernel for every `i`. When i == 0 it is saved
    unchanged as the x2 kernel. For any other i, analytic_kernel(k_2) is computed and
    saved as the x4 kernel. Uses the module-level `output_pd` directory.
    """
    if i != 0:
        file_name, kernel_to_save = '%s_kernel_x4.mat' % img_name, analytic_kernel(k_2)
    else:
        file_name, kernel_to_save = '%s_kernel_x2.mat' % img_name, k_2
    sio.savemat(os.path.join(output_pd, file_name), {'Kernel': kernel_to_save})

# Run

In [7]:
input_pd = r'/content/gdrive/MyDrive/5. data_all/D1_X4'
output_pd = r'/content/gdrive/MyDrive/checking'

drop_lr = 0.5 # Factor by which ReduceLROnPlateau reduces the learning rate: new_lr = lr * factor
num_epochs = 1500 # Number of epochs to run the SR model per image (per scaling stage)
sigma = 30 # Standard deviation ("width") of the Gaussian noise added to images
gradual_increase_value = 2 # Step factor for gradual super-resolution. This gradual increase is inspired by Shocher, Assaf & Cohen, Nadav & Irani, Michal. (2018). Zero-Shot Super-Resolution Using Deep Internal Learning. 3118-3126. 10.1109/CVPR.2018.00329.
leave_as_prob = 0.3 # Probability of leaving the HR parent un-augmented; higher values reduce random augmentation. (Used for KernelGAN training; the SR training call passes 0.2.)
shear_scale_prob = 0.6 # Probability threshold for random shearing & scaling of the HR parent during augmentation. A lower value increases the probability of shearing & scaling, and vice versa.
center_crop_prob = 1  # If center_crop_prob is small, more crops are taken from the center of the image. If it is large, crops are taken at random locations.
crop_size = 96 # Initial crop size.
SR_factor = 4 # Final super-resolution factor.
# The discriminator ends in a sigmoid, so its outputs are probabilities, not logits.
# from_logits must therefore be False; True applied a second sigmoid inside the loss.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

In [None]:
def predict_image(image, scale_factor):
  """Super-resolve `image` by `scale_factor` with the trained module-level `model`.

  The image is bicubically upscaled by `scale_factor` on both axes, given a batch
  axis, passed through `model.predict`, and converted back to an 8-bit image.

  Args:
    image: (H, W, C) image array.
    scale_factor: upscaling factor for height and width.

  Returns:
    uint8 array of the upscaled size, with values rounded and clipped to [0, 255].
  """
  image_upscaled = cv2.resize(np.float32(image), None, fx = scale_factor, fy = scale_factor, interpolation = cv2.INTER_CUBIC)
  super_image = np.squeeze(model.predict(np.expand_dims(image_upscaled, axis = 0)), axis=0)
  # Clip before casting. cv2.convertScaleAbs took the absolute value first, so
  # slightly negative predictions were mirrored into positive intensities.
  return np.clip(np.rint(super_image), 0, 255).astype(np.uint8)

def train_step(input_image, epoch, crop_size, leave_as_prob, shear_scale_prob, center_crop_prob):
  """Run one adversarial update of the KernelGAN generator and discriminator.

  1. A noisy copy of `input_image` is made (module-level `sigma`).
  2. The clean and noisy images are augmented/cropped identically with image_crop.
  3. The noisy crop, with channels moved onto the batch axis by swap_axis, goes
     through the generator; its output is swapped back.
  4. The discriminator scores the clean crop and the generated map.
  5. One optimizer step is applied to each network.

  `epoch` is accepted but not used.
  """
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    noisy_image = add_noise(input_image, sigma)
    clean_crop, noisy_crop = image_crop([input_image, noisy_image],
                                        crop_size,
                                        leave_as_prob,
                                        shear_scale_prob,
                                        center_crop_prob)

    real_batch = tf.expand_dims(clean_crop, axis=0)
    channels_as_batch = swap_axis(tf.expand_dims(noisy_crop, axis=0))
    fake_batch = swap_axis(generator(channels_as_batch, training=True))

    disc_real_output = discriminator(real_batch, training=True)
    disc_generated_output = discriminator(fake_batch, training=True)
    gen_loss = g_loss(disc_generated_output)
    disc_loss = d_loss(disc_real_output, disc_generated_output)

  gen_vars = generator.trainable_variables
  disc_vars = discriminator.trainable_variables
  generator_gradients = gen_tape.gradient(gen_loss, gen_vars)
  discriminator_gradients = disc_tape.gradient(disc_loss, disc_vars)
  generator_optimizer.apply_gradients(zip(generator_gradients, gen_vars))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients, disc_vars))

def fit(input_image, epochs, crop_size, leave_as_prob, shear_scale_prob, center_crop_prob):
  """Run `epochs` KernelGAN training steps on one image, printing each step's wall time."""
  for step in range(1, epochs + 1):
    tic = time.time()
    train_step(input_image, step - 1, crop_size, leave_as_prob, shear_scale_prob, center_crop_prob)
    print('Time taken for epoch {} is {} sec\n'.format(step, time.time() - tic))

# Callbacks shared by every model.fit call in the main loop. The learning rate is
# multiplied by `drop_lr` after 100 epochs without training-loss improvement, and
# training stops after 350 such epochs.
callback_list = [tf.keras.callbacks.ReduceLROnPlateau(monitor = 'loss',
                                                      factor = drop_lr,
                                                      patience = 100,
                                                      verbose = 0,
                                                      mode = 'min',
                                                      min_delta = 0.001,
                                                      cooldown = 20,
                                                      min_lr = 0.00000001),
                  tf.keras.callbacks.EarlyStopping(monitor = 'loss',
                                                   min_delta = 0.0001,
                                                   patience = 350,
                                                   verbose = 1,
                                                   mode = 'min')
]

# Pixel-wise (MAE) term of custom_loss.
mae_loss_object = tf.keras.losses.MeanAbsoluteError()
# Frozen ImageNet VGG19, truncated at 'block1_conv2', used as the perceptual feature extractor.
# NOTE(review): VGG19 usually expects preprocess_input-normalised inputs; raw pixel values are fed here - confirm intended.
vgg19 = VGG19(include_top=False, weights='imagenet')
vgg19.trainable = False
for l in vgg19.layers:
    l.trainable = False
vgg_model = Model(inputs=vgg19.input, outputs=vgg19.get_layer('block1_conv2').output)
vgg_model.trainable = False
# The super-resolution network, built once in this cell.
model = build_model()
def custom_loss(y_true, y_pred):
    """Training loss for `model`: MAE plus a bounded VGG19 perceptual term.

    The perceptual term is the mean squared difference of `vgg_model` features, mapped
    into [0, 10) by 10 * (1 - 1 / (1 + mse)).
    """
    pixel_term = mae_loss_object(y_true, y_pred)
    feature_diff = vgg_model(y_true) - vgg_model(y_pred)
    vgg_mse = K.mean(K.square(feature_diff))
    perceptual_term = (1 - (1/(1+vgg_mse)))*10
    return pixel_term + perceptual_term

gradual_SR_list = get_gradual_factors(SR_factor, gradual_increase_value)

print('Scaling gradually in order:', gradual_SR_list)

# Every stage upscales the current image by the same step factor.
scale_fact = gradual_increase_value
# KernelGAN optimisation steps per image (previously an inline literal).
kernelgan_iterations = 3000
# Side length of the kernel extracted from the generator; sets the padding of the first impulse convolution.
kernel_size = 13

date_time = strftime('_%b_%d_%H_%M_%S', localtime())
super_dir = output_pd + '/' + date_time + '/' + '_super'
os.makedirs(super_dir)

# Snapshot the current weights of all three networks. KernelGAN and the SR model are
# both trained per image, so each image below restarts from these weights instead of
# continuing from the previous image's training.
# NOTE: if this cell is re-run without rebuilding the networks, the snapshot holds trained weights.
initial_generator_weights = generator.get_weights()
initial_discriminator_weights = discriminator.get_weights()
initial_model_weights = model.get_weights()

start = timeit.default_timer()

for file in os.listdir(input_pd):

  image_path = os.path.join(input_pd, '%s' %file)
  image = load_img(image_path)

  image_tf = tf.convert_to_tensor(image, dtype = tf.float32)

  tf.keras.backend.clear_session()

  # --- Stage 1: estimate the x2 degradation kernel with KernelGAN ---
  generator.set_weights(initial_generator_weights)
  discriminator.set_weights(initial_discriminator_weights)
  # Fresh optimiser state (Adam moments) with the same hyper-parameters as before.
  generator_optimizer = tf.keras.optimizers.Adam.from_config(generator_optimizer.get_config())
  discriminator_optimizer = tf.keras.optimizers.Adam.from_config(discriminator_optimizer.get_config())
  fit(image_tf, kernelgan_iterations, crop_size, leave_as_prob, shear_scale_prob, center_crop_prob)

  # The generator is a stack of bias-free convolutions without activations. Pushing a
  # unit impulse through its kernels (stride 1, PyTorch NCHW/OIHW) yields the
  # equivalent single kernel.
  delta = torch.Tensor([1.]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
  curr_k = None
  for layer in generator.layers[1:]:   # layers[0] is the InputLayer
    layer_weights = layer.get_weights()
    if not layer_weights:
      continue
    # NOTE(review): assumes index 0 of a SpectralNormalization layer's weights is the
    # wrapped Conv2D kernel (Keras HWIO layout); torch expects OIHW.
    weight = torch.from_numpy(np.transpose(layer_weights[0], axes=[3, 2, 0, 1]))
    if curr_k is None:
      curr_k = F.conv2d(delta, weight, padding = kernel_size - 1)
    else:
      curr_k = F.conv2d(curr_k, weight)
  curr_k = curr_k.squeeze().flip([0,1])

  final_kernel = post_process_k(curr_k, n = 40)

  # --- Stage 2: zero-shot super-resolution, one x`scale_fact` step per entry of gradual_SR_list ---
  print('starting training for', file)
  tf.keras.backend.clear_session()
  model.set_weights(initial_model_weights)
  model.compile(loss = custom_loss, optimizer = Adam(learning_rate=0.001))
  # splitext strips only the final extension ("a.b.png" -> "a.b").
  img_name = os.path.splitext(str(file))[0]
  for i in range(len(gradual_SR_list)):
    k = final_kernel if i == 0 else analytic_kernel(final_kernel)
    # save_final_kernel takes the x2 kernel and derives the x4 kernel itself when i != 0.
    # Passing `k` applied analytic_kernel twice and saved the result under the "x4" name.
    save_final_kernel(final_kernel, img_name, i)

    if len(image.shape) == 4:
      image = np.squeeze(image, axis = 0)

    model.fit(hr_lr_generator(image,
                              scale_factor=scale_fact,
                              final_kernel = k,
                              shear_scale_prob = shear_scale_prob,
                              leave_as_prob= 0.2,
                              center_crop_prob = center_crop_prob,
                              crop_size = 96),
              batch_size = 1,
              epochs = num_epochs,
              verbose =1,
              callbacks = callback_list,
              steps_per_epoch = 1)

    super_image = predict_image(image, scale_factor = scale_fact)
    # The super-resolved output becomes the input of the next stage.
    image = super_image
  plt.imsave(os.path.join(super_dir,'%s' %file), super_image, format = 'png')

  tf.keras.backend.clear_session()


print('Done!!!')
stop = timeit.default_timer()

print('Time take for all images is', stop-start, 'seconds')
# NOTE(review): these resets make this cell fail if re-run without re-running the configuration cell.
SR_factor = None
input_pd = None
output_pd = None