
# Import Libraries

In [None]:
%matplotlib inline
import warnings
warnings.filterwarnings("ignore")
import timeit
import os
import sys
import pathlib
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import cv2
import skimage
from skimage.util import random_noise
import tensorflow as tf
from time import strftime, localtime
import random
import pywt
from tensorflow.keras import models, layers
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Conv2DTranspose,\
                                    GlobalAveragePooling2D, AveragePooling2D, MaxPool2D, UpSampling2D,\
                                    BatchNormalization, Activation, ReLU, Flatten, Dense, Input,\
                                    Add, Multiply, Concatenate, Softmax
from tensorflow.keras import initializers, regularizers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.activations import softmax
from keras.applications.vgg19 import VGG19
tf.keras.backend.set_image_data_format('channels_last')
import keras.backend as K
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Model

In [None]:
class DWT_downsampling(tf.keras.layers.Layer):
    """Haar-style wavelet downsampling layer (MWCNN).

    Each 2x2 spatial neighbourhood of a channels-last input
    (batch, H, W, C) is turned into four sub-band coefficients, stacked on
    the channel axis in the order [LL, LH, HL, HH]. The output shape is
    (batch, H/2, W/2, 4*C). The coefficients are unnormalised sums and
    differences; ``IWT_upsampling`` divides them by 4 to invert exactly.

    H and W must be even. Otherwise the four strided slices differ in size
    and cannot be added.

    Adapted from: Chintan, (2021) Image Denoising using Deep Learning [Github].
    https://github.com/chintan1995/Image-Denoising-using-Deep-Learning/blob/main/Models/MWCNN_256x256.ipynb
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x):
        # Polyphase components (0-based): even/odd rows x even/odd columns.
        x1 = x[:, 0::2, 0::2, :]  # even row, even col
        x2 = x[:, 1::2, 0::2, :]  # odd row,  even col
        x3 = x[:, 0::2, 1::2, :]  # even row, odd col
        x4 = x[:, 1::2, 1::2, :]  # odd row,  odd col

        x_LL = x1 + x2 + x3 + x4
        x_LH = -x1 - x3 + x2 + x4
        x_HL = -x1 + x3 - x2 + x4
        x_HH = x1 - x3 - x2 + x4

        # Plain tensor op; no need to build a Keras Concatenate layer on every call.
        return tf.concat([x_LL, x_LH, x_HL, x_HH], axis=-1)

class IWT_upsampling(tf.keras.layers.Layer):
    """Inverse of ``DWT_downsampling``.

    Takes a channels-last tensor (batch, H, W, 4*C) whose channel axis holds
    the [LL, LH, HL, HH] sub-bands. It reconstructs the four polyphase
    components and interleaves them into (batch, 2H, 2W, C).

    Uses ``tf`` ops directly. ``keras.backend.stack/concatenate/reshape``
    are not available in Keras 3.

    Adapted from: Chintan, (2021) Image Denoising using Deep Learning [Github].
    https://github.com/chintan1995/Image-Denoising-using-Deep-Learning/blob/main/Models/MWCNN_256x256.ipynb
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x):
        # Static channel count per sub-band (requires a known channel dimension).
        c = x.shape[-1] // 4
        x_LL = x[:, :, :, 0:c]
        x_LH = x[:, :, :, c:2 * c]
        x_HL = x[:, :, :, 2 * c:3 * c]
        x_HH = x[:, :, :, 3 * c:]

        # Solve the DWT sum/difference system; the /4 undoes the unnormalised forward transform.
        x1 = (x_LL - x_LH - x_HL + x_HH) / 4  # even row, even col
        x2 = (x_LL - x_LH + x_HL - x_HH) / 4  # even row, odd col
        x3 = (x_LL + x_LH - x_HL - x_HH) / 4  # odd row,  even col
        x4 = (x_LL + x_LH + x_HL + x_HH) / 4  # odd row,  odd col

        # Stacking on axis 2 gives (B, H, 2, W, C), where index 0/1 selects the
        # even/odd output row. y1 holds the even columns and y2 the odd ones.
        # Concatenating on channels and reshaping row-major interleaves them
        # into (B, 2H, 2W, C).
        y1 = tf.stack([x1, x3], axis=2)
        y2 = tf.stack([x2, x4], axis=2)
        shape = tf.shape(x)
        return tf.reshape(tf.concat([y1, y2], axis=-1),
                          [shape[0], shape[1] * 2, shape[2] * 2, c])

class Conv_block(tf.keras.layers.Layer):
    """Four stacked 'same'-padded ``Conv2D`` layers, each followed by ReLU.

    Args:
        num_filters: number of output channels of every convolution.
        kernel_size: kernel size of every convolution.
    """

    def __init__(self, num_filters=64, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        # Kept for code that reads this attribute; the convs get their own instances.
        self.initializer = tf.keras.initializers.Orthogonal()
        # An unseeded Keras initializer instance is deterministic. Sharing one
        # across layers gives equal-shaped kernels identical starting weights
        # (TF-Keras 2.x warns about this reuse), so each conv gets a fresh one.
        self.conv_1 = self._make_conv()
        self.conv_2 = self._make_conv()
        self.conv_3 = self._make_conv()
        self.conv_4 = self._make_conv()

    def _make_conv(self):
        """Build one 'same'-padded Conv2D with its own orthogonal initializer."""
        return Conv2D(filters=self.num_filters, kernel_size=self.kernel_size, padding='same',
                      kernel_initializer=tf.keras.initializers.Orthogonal())

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'num_filters': self.num_filters,
            'kernel_size': self.kernel_size
        })
        return config

    def call(self, X):
        # tf.nn.relu is the same op a default ReLU() layer applies, without
        # creating a new layer object on every call.
        for conv in (self.conv_1, self.conv_2, self.conv_3, self.conv_4):
            X = tf.nn.relu(conv(X))
        return X

def build_model2():
    """Build the single-level MWCNN-style residual denoising/SR network.

    Data flow:
        Conv_block(64) -> DWT (half resolution, 256 channels) -> Conv_block(64)
        -> Conv2D(256, relu) -> IWT (full resolution, 64 channels)
        -> add skip from the first Conv_block -> Conv_block(64)
        -> Conv2D(3, linear) -> add the network input (global residual).

    Returns:
        A ``tf.keras.Model`` mapping (batch, H, W, 3) to the same shape.
        H and W must be even for the DWT/IWT round trip.
    """
    inputs = Input(shape=(None, None, 3))

    shallow = Conv_block(num_filters=64)(inputs)
    subbands = DWT_downsampling()(shallow)
    deep = Conv_block(num_filters=64)(subbands)
    expanded = Conv2D(filters=256, kernel_size=3, strides=1, padding='same', activation='relu',
                      kernel_initializer=tf.keras.initializers.Orthogonal())(deep)
    restored = IWT_upsampling()(expanded)

    refine_block = Conv_block(num_filters=64)
    refined = refine_block(Add()([restored, shallow]))
    residual = Conv2D(filters=3, kernel_size=3, strides=1, padding='same', activation='linear',
                      kernel_initializer=tf.keras.initializers.Orthogonal())(refined)

    outputs = tf.keras.layers.Add()([residual, inputs])
    return Model(inputs=inputs, outputs=outputs)

# Data preprocessing

In [None]:
def load_img(file_name):
    """Read an image file as a float32 RGB array of shape (H, W, C).

    Colour images are converted from OpenCV's BGR order to RGB. Single-channel
    images are replicated into three identical channels.

    NOTE(review): IMREAD_UNCHANGED keeps 16-bit PNGs at 16-bit depth. Their
    values would exceed the 0-255 range used elsewhere, so confirm inputs are 8-bit.

    Args:
        file_name: path of the image to read.

    Returns:
        float32 ndarray, channels-last.

    Raises:
        OSError: if OpenCV cannot read or decode the file.
    """
    image = cv2.imread(file_name, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f'Could not read image file: {file_name}')

    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    else:
        image = np.stack((image,) * 3, axis=-1)
    # cv2.imread always returns channels-last arrays, so no channel-first
    # handling is needed. A height-based check would wrongly transpose
    # images that are 1 or 3 pixels tall.
    return image.astype('float32')

def add_noise(image, sigma):
    """Add zero-mean Gaussian noise and clip the result to [0, 255].

    Args:
        image: array of any shape, e.g. (H, W, C) or (H, W), on a 0-255 scale.
        sigma: standard deviation of the noise, in the same units as ``image``.

    Returns:
        The noisy array clipped to [0, 255]. It is float32 when ``image`` is
        float32 or uint8, following NumPy type promotion with the float32 noise.
    """
    # Drawing with image.shape gives the same random stream as the former
    # (row, col, ch) draw for 3-D images, and also supports 2-D inputs.
    noise = np.random.normal(0, sigma, image.shape).astype('float32')
    return np.clip(image + noise, 0, 255)

def augment_parent(image,
                   shear_scale_prob,
                   leave_as_is_probability,
                   crop_size):
    """Produce a randomly augmented HR "parent" from ``image``.

    With probability ``leave_as_is_probability`` the whole image is used
    unchanged. Otherwise:
      1. If a uniform draw in [0, 1) exceeds ``shear_scale_prob``, a random
         shear/scale affine warp is applied. The output keeps the input size.
      2. A random square crop of side ``crop_size`` is taken from the
         (possibly warped) image. ``crop_size`` is first reduced in steps of
         4 until it is smaller than both image sides.
      3. One of the 8 rotation/flip variants is chosen at random.
    Finally, an odd trailing row and/or column is dropped so both spatial
    sides are even, as the DWT/IWT layers require.

    Args:
        image: (H, W, C) array.
        shear_scale_prob: threshold for the shear/scale draw (lower = more often).
        leave_as_is_probability: probability of skipping augmentation.
        crop_size: requested crop side length in pixels.

    Returns:
        (h, w, C) array with even h and w.

    Raises:
        ValueError: if the image is too small to take a crop of at least 2 px.
    """
    random_prob = np.random.rand()
    random_augment = np.random.rand()

    if random_prob < leave_as_is_probability:
        hr_parent = image
    else:
        parent_h, parent_w = image.shape[0], image.shape[1]
        hr_parent = image
        if random_augment > shear_scale_prob:
            shear_x = np.random.randn() * 0.25
            shear_y = np.random.randn() * 0.25
            scale_x = np.random.randn() * 0.15
            scale_y = np.random.randn() * 0.15
            transform_matrix = np.array([[1 + scale_x, shear_x, 0],
                                         [shear_y, 1 + scale_y, 0]])
            # dsize equals the input size, so parent_h/parent_w stay valid.
            hr_parent = cv2.warpAffine(image, transform_matrix, (parent_w, parent_h))

        # Shrink the crop until it fits with at least one pixel of slack on each side.
        while (parent_h - 1 < crop_size) or (parent_w - 1 < crop_size):
            crop_size -= 4
        if crop_size < 2:
            raise ValueError(f'Image of size {parent_h}x{parent_w} is too small to crop')

        left = np.random.randint(0, parent_w - crop_size)
        top = np.random.randint(0, parent_h - crop_size)
        # Crop from hr_parent, not image, so the shear/scale warp above is kept.
        hr_parent = hr_parent[top:top + crop_size, left:left + crop_size, :]

        # k in 0..7: rotate by (k mod 4) quarter turns, and mirror when k > 3.
        random_rot = random.randint(0, 7)
        hr_parent = np.rot90(hr_parent, random_rot, axes=(0, 1))
        if random_rot > 3:
            hr_parent = np.fliplr(hr_parent)

    if hr_parent.shape[0] % 2 != 0:
        hr_parent = hr_parent[:-1, :, :]
    if hr_parent.shape[1] % 2 != 0:
        hr_parent = hr_parent[:, :-1, :]

    return hr_parent

def parent_to_child(hr_parent, scale_factor, sigma):
    """Degrade an HR parent into its noisy LR "child".

    With probability 0.5 the parent keeps its resolution. Otherwise it is
    shrunk by ``1 / scale_factor`` with bicubic interpolation. Gaussian noise
    with standard deviation ``sigma`` is then added via ``add_noise``.
    """
    shrink = 1 / scale_factor
    keep_resolution = np.random.rand() < 0.5

    degraded = hr_parent if keep_resolution else cv2.resize(
        hr_parent, None, fx=shrink, fy=shrink, interpolation=cv2.INTER_CUBIC)

    return add_noise(degraded, sigma)

def hr_lr_generator(image,
                    scale_factor,
                    shear_scale_prob,
                    leave_as_is_probability,
                    crop_size,
                    sigma):
    """Endlessly yield (degraded input, clean target) training pairs from one image.

    Each iteration augments ``image`` into an HR parent (``augment_parent``)
    and degrades it into a child (``parent_to_child``). If the child's shape
    differs from the parent's, the child is resized back to the parent's
    size with bicubic interpolation.

    Yields:
        (x, y): x is the degraded child and y the HR parent, each with a
        leading batch axis of size 1.
    """
    while True:
        parent = augment_parent(image, shear_scale_prob, leave_as_is_probability, crop_size)
        child = parent_to_child(parent, scale_factor, sigma)

        if child.shape != parent.shape:
            target_hw = (parent.shape[1], parent.shape[0])  # cv2 expects (width, height)
            child = cv2.resize(child, target_hw, interpolation=cv2.INTER_CUBIC)

        yield child[np.newaxis, ...], parent[np.newaxis, ...]

# Other processing steps

In [None]:
def predict_image(image, scale_factor, trained_model=None):
    """Upscale ``image`` bicubically and refine it with the trained network.

    Args:
        image: (H, W, C) array on a 0-255 scale.
        scale_factor: bicubic pre-upscaling factor applied before the model.
        trained_model: Keras model to run; defaults to the module-level ``model``.

    Returns:
        uint8 array of the upscaled size: the model output rounded to the
        nearest integer and clipped to [0, 255].
    """
    net = model if trained_model is None else trained_model

    image_upscaled = cv2.resize(image, None, fx=scale_factor, fy=scale_factor,
                                interpolation=cv2.INTER_CUBIC)
    # Add the batch axis and feed float32 explicitly (later gradual steps pass uint8 images).
    batch = np.expand_dims(image_upscaled, axis=0).astype('float32')

    super_image = np.squeeze(net.predict(batch), axis=0)

    # cv2.convertScaleAbs took |x| before saturating, which turned negative
    # predictions into bright pixels. Round and clip to the 8-bit range instead.
    return np.clip(np.rint(super_image), 0, 255).astype(np.uint8)


"""# Gradual SR"""

def get_gradual_factors(SR_factor, gradual_increase_value):
    """Return the intermediate SR factors used for gradual super-resolution.

    ``SR_factor`` is divided repeatedly by ``gradual_increase_value`` until 1
    is reached. The visited factors are returned in increasing order, e.g.
    ``(4, 2) -> [2, 4]`` and ``(8, 2) -> [2, 4, 8]``.

    Raises:
        ValueError: if ``gradual_increase_value <= 1``, or if ``SR_factor`` is
            not a positive integer power of ``gradual_increase_value``. The
            previous float ``!= 1`` loop never terminated in those cases.
    """
    import math

    if gradual_increase_value <= 1:
        raise ValueError(f'gradual_increase_value must be > 1, got {gradual_increase_value}')

    gradual_SR_list = [SR_factor]
    sr_fact = SR_factor / gradual_increase_value
    # isclose guards against float round-off for non-binary step values.
    while sr_fact > 1 and not math.isclose(sr_fact, 1):
        gradual_SR_list.append(int(sr_fact))
        sr_fact /= gradual_increase_value

    if not math.isclose(sr_fact, 1):
        raise ValueError(f'SR_factor={SR_factor} is not a power of '
                         f'gradual_increase_value={gradual_increase_value}')

    gradual_SR_list.reverse()
    return gradual_SR_list

def get_images_paths(input_pd):
    """Return every ``*.png`` file under ``input_pd``, searched recursively, as sorted path strings."""
    png_files = sorted(pathlib.Path(input_pd).rglob('*.png'))
    return list(map(str, png_files))

# Run

In [None]:
input_pd = r''  # Directory holding the input images (each entry is processed by the run cell)
output_pd = r''  # Directory under which a timestamped results folder is created

drop_lr = 0.5  # Factor by which ReduceLROnPlateau reduces the learning rate: new_lr = lr * factor
num_epochs = 1500  # Maximum epochs per model.fit call (one call per gradual SR step per image); EarlyStopping may stop sooner
scale_factor = 2  # Super-resolution factor. NOTE(review): the run cell passes gradual_increase_value, not this, to the generator and predict_image
sigma = 30  # Standard deviation of the Gaussian noise added to LR children (0-255 intensity scale)
gradual_increase_value = 2  # Upscaling factor of each gradual SR step. Inspired by Shocher, Cohen & Irani (2018), "Zero-Shot Super-Resolution Using Deep Internal Learning", CVPR, pp. 3118-3126, doi:10.1109/CVPR.2018.00329
leave_as_is_probability = 0.2  # Probability that the HR parent is the whole, un-augmented image; higher values mean less augmentation
shear_scale_prob = 0  # Shear/scale augmentation of the HR parent is triggered when a uniform draw in [0, 1) exceeds this; lower values mean more shearing/scaling
crop_size = 96  # Side length (pixels) of the square random crop taken from the HR parent; reduced automatically for smaller images
SR_factor = 4  # Overall super-resolution factor; must be a power of gradual_increase_value

In [None]:
# Halve (drop_lr) the learning rate when the training loss plateaus...
reduce_lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='loss', factor=drop_lr, patience=100, verbose=1, mode='min',
    min_delta=0.001, cooldown=20, min_lr=0.00000001)
# ...and stop training once it has not improved for a long stretch.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='loss', min_delta=0.0001, patience=350, verbose=1, mode='min')
callback_list = [reduce_lr_on_plateau, early_stopping]

model = build_model2()

mae_loss_object = tf.keras.losses.MeanAbsoluteError()

# Frozen VGG19 feature extractor (ImageNet weights, no classifier head),
# truncated at 'block1_conv2'; used by the perceptual term of custom_loss.
vgg19 = VGG19(include_top=False, weights='imagenet')
vgg19.trainable = False
for vgg_layer in vgg19.layers:
    vgg_layer.trainable = False
vgg_model = Model(inputs=vgg19.input, outputs=vgg19.get_layer('block1_conv2').output)
vgg_model.trainable = False

def custom_loss(y_true, y_pred):
    """Mean absolute error plus a saturating VGG19 perceptual term.

    loss = MAE(y_true, y_pred) + 10 * (1 - 1 / (1 + P))

    P is the mean squared difference between the frozen ``vgg_model``
    features of the two images. Because P >= 0, the perceptual term lies in
    [0, 10).

    NOTE(review): images are fed to VGG19 without
    ``vgg19.preprocess_input``, so the features come from raw 0-255 RGB.
    Confirm this is intended.
    """
    # Targets can be uint8 in later gradual steps; align dtypes explicitly.
    y_true = tf.cast(y_true, y_pred.dtype)
    mae_loss = mae_loss_object(y_true, y_pred)
    # tf ops instead of keras.backend.mean/square, which Keras 3 no longer provides.
    feature_diff = vgg_model(y_true) - vgg_model(y_pred)
    vgg_loss = tf.reduce_mean(tf.square(feature_diff))
    vgg_loss_adjusted = (1 - (1 / (1 + vgg_loss))) * 10

    return mae_loss + vgg_loss_adjusted



gradual_SR_list = get_gradual_factors(SR_factor, gradual_increase_value)

print('Scaling gradually in order:', gradual_SR_list)

# Every gradual step upscales by gradual_increase_value; len(gradual_SR_list)
# such steps reach SR_factor.
scale_fact = gradual_increase_value

date_time = strftime('_%b_%d_%H_%M_%S', localtime())
# os.path.join: with an empty output_pd, string concatenation with '/' would
# produce an absolute '/<timestamp>/_super' path at the filesystem root.
super_dir = os.path.join(output_pd, date_time, '_super')
os.makedirs(super_dir)

start = timeit.default_timer()
for file in sorted(os.listdir(input_pd)):
    image_path = os.path.join(input_pd, file)
    if not os.path.isfile(image_path):
        continue  # sub-directories etc. cannot be loaded as images

    # Train a fresh network per image (zero-shot). Building the model once
    # carried weights learned on earlier images into later ones, even though
    # the session is cleared after every image.
    model = build_model2()
    model.compile(loss=custom_loss, optimizer=Adam(learning_rate=0.001))
    model.summary()

    image = load_img(image_path)
    print('starting training for', file)

    for _ in gradual_SR_list:
        # batch_size is not passed: the generator already yields batches of
        # one, and Keras documents that batch_size must not be given for generators.
        model.fit(hr_lr_generator(image,
                                  scale_factor=scale_fact,
                                  shear_scale_prob=shear_scale_prob,
                                  leave_as_is_probability=leave_as_is_probability,
                                  crop_size=crop_size,
                                  sigma=sigma),
                  epochs=num_epochs,
                  verbose=1,
                  callbacks=callback_list,
                  steps_per_epoch=1)

        super_image = predict_image(image, scale_factor=scale_fact)
        # The next gradual step trains on, and upscales, this step's output.
        image = super_image

    plt.imsave(os.path.join(super_dir, file), super_image, format='png')

    tf.keras.backend.clear_session()

stop = timeit.default_timer()
print('Done!!!')
print('Time take for all images is', stop-start, 'seconds')
input_pd, output_pd = None, None

