In [None]:
import tensorflow as tf
from keras.models import Model
from keras.layers import Conv2D, Activation, BatchNormalization, UpSampling2D, Input, LeakyReLU, Add, Dense

import numpy as np
from numpy import cov, trace, iscomplexobj
from scipy.linalg import sqrtm
from keras.applications.inception_v3 import InceptionV3, preprocess_input

import matplotlib.pyplot as plt
import glob
from PIL import Image

In [None]:
# Clone the repository hosting the UIEB dataset folders used below
# (raw-890, reference-890, challenging-60).
!git clone https://github.com/vanathi-g/fyp-datasets.git
# NOTE(review): a `!cd` runs in its own subshell, so it presumably does not change
# the notebook's working directory; later cells use absolute /content/... paths.
!cd fyp-datasets/

In [None]:
# Mount Google Drive at /content/gdrive; the weight and model save folders
# configured in the next cell live under this mount point.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Output locations on the mounted Drive. Both are used as plain string prefixes
# (prefix + file name), so each must keep its trailing slash.
checkpoint_filepath = '/content/gdrive/MyDrive/SRGAN_weights/' # Path to model weights save folder in Google Drive
model_save_filepath = '/content/gdrive/MyDrive/SRGAN_models/' # Path to model save folder

# Image Preprocessing and Display Functions
---
### Image dimensions:
min width, height = 225, 159<br>
max width, height = 2180, 1450

In [None]:
# Image preprocessing
def load_all_images(data_dir, image_shape):
  """Cut every raw/reference image pair into aligned, non-overlapping patches.

  Args:
    data_dir: glob pattern matching the raw images
      (e.g. '/content/fyp-datasets/raw-890/*').
    image_shape: (width, height) of each patch, in pixels.

  Returns:
    Tuple (enhanced_patches, original_patches) of numpy arrays. Note the order:
    reference ("enhanced") patches come first, raw ("original") patches second.
    Index i of both arrays was cropped from the same coordinates.

  Raises:
    FileNotFoundError: if a raw image has no same-named file in the reference
      folder.
  """
  import os  # local import so this cell does not depend on the import cell

  reference_dir = "/content/fyp-datasets/reference-890/"
  width, height = image_shape
  original_images = []
  enhanced_images = []

  for img in glob.glob(data_dir):
    # Pair raw and reference images by file name. Using basename instead of a
    # fixed character offset keeps this correct for any data_dir prefix.
    reference_path = reference_dir + os.path.basename(img)

    # Context managers release both file handles even when the image is skipped.
    with Image.open(img) as img1, Image.open(reference_path) as img2:
      # Images smaller than one patch cannot contribute any crop.
      if img1.size[0] < width or img1.size[1] < height:
        continue

      # Cropping image to extract max number of patches of specified shape from each image
      start = [0, 0]
      while True:
        crop_coords = (start[0], start[1], start[0] + width, start[1] + height)
        original_images.append(np.array(img1.crop(crop_coords)))
        enhanced_images.append(np.array(img2.crop(crop_coords)))

        # Walk left-to-right; wrap to the next row when no full patch fits.
        start[0] += width
        if img1.size[0] - start[0] < width:
          start[0] = 0
          start[1] += height
        # Stop when no full-height row remains.
        if img1.size[1] - start[1] < height:
          break

  return np.array(enhanced_images), np.array(original_images)

In [None]:
def select_random_batch(original_images, enhanced_images, batch_size):
  """Draw one random batch of paired images.

  The same random indices are applied to both collections, so pairs stay
  aligned. Indices are drawn with np.random.choice's default settings.

  Args:
    original_images: indexable collection of raw images.
    enhanced_images: indexable collection of reference images, aligned with
      original_images.
    batch_size: number of pairs to draw.

  Returns:
    Tuple (selected_original, selected_enhanced) of numpy arrays.
  """
  candidate_indices = np.arange(len(original_images))
  picked = np.random.choice(candidate_indices, batch_size)

  def gather(pool):
    # Materialise the chosen items of one collection as an array.
    return np.array([pool[i] for i in picked])

  return gather(original_images), gather(enhanced_images)

In [None]:
def load_test_images(data_dir, batch_size, image_shape):
  """Load a random batch of test images, each cropped from its top-left corner.

  Args:
    data_dir: glob pattern matching the candidate image files.
    batch_size: number of files to draw (np.random.choice default settings).
    image_shape: (width, height) of the crop box anchored at (0, 0).

  Returns:
    numpy array stacking the cropped images.
  """
  candidates = glob.glob(data_dir)
  chosen_paths = np.random.choice(candidates, size=batch_size)
  crop_box = (0, 0, image_shape[0], image_shape[1])
  return np.array([np.array(Image.open(path).crop(crop_box)) for path in chosen_paths])

In [None]:
def display_images(original_img, enhanced_img, generated_img):
    """Show the raw, reference and generated image side by side.

    Each image is rescaled with (x + 1) / 2 before plotting, which undoes the
    `x / 127.5 - 1` normalisation applied elsewhere in this notebook.
    """
    panels = (
        ("Original Image", original_img),
        ("Enhanced Image", enhanced_img),
        ("Generated Image", generated_img),
    )

    fig = plt.figure()
    for position, (title, image) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 3, position)
        ax.imshow((image + 1)/2)
        ax.axis("off")
        ax.set_title(title)

    plt.show()

In [None]:
def plot_losses(loss_arr, intervals, label):
  """Plot one metric curve in green against the epochs it was recorded at.

  Args:
    loss_arr: y values (one per recorded interval).
    intervals: x values (epoch numbers), same length as loss_arr.
    label: legend entry; also appended to the 'Visualizing ' title.
  """
  ax = plt.gca()
  ax.plot(intervals, loss_arr, 'g', label=label)
  ax.set_title('Visualizing ' + label)
  ax.set_xlabel('Epochs')
  ax.set_ylabel('Loss')
  ax.legend()
  plt.show()

# Defining the Model
---

In [None]:
def residual_block(x):
  """Append one residual block to tensor `x` and return the block's output.

  Layout: Conv(64, 3x3) -> ReLU -> BatchNorm -> Conv(64, 3x3) -> BatchNorm,
  followed by an element-wise sum with the block input (skip connection).
  """
  conv_kwargs = dict(kernel_size=3, strides=1, padding="same")
  bn_momentum = 0.8
  shortcut = x

  out = Conv2D(filters=64, **conv_kwargs)(shortcut)
  out = Activation(activation="relu")(out)
  out = BatchNormalization(momentum=bn_momentum)(out)
  out = Conv2D(filters=64, **conv_kwargs)(out)
  out = BatchNormalization(momentum=bn_momentum)(out)

  # Skip connection: add the block input back onto the transformed features.
  return Add()([out, shortcut])

In [None]:
def build_generator():
  """Build the fully-convolutional generator.

  Architecture: 9x9 conv (ReLU) -> 8 residual blocks -> 3x3 conv + BatchNorm
  -> skip connection back to the first conv -> 9x9 conv -> tanh.

  The output has the same spatial size as the input. No upsampling stage is
  built: an earlier version created two UpSampling2D blocks, but the output
  convolution was wired to the skip-connection tensor, so those layers never
  were part of the model. Connecting them would give 4x-larger outputs, which
  the 256x256 discriminator and VGG inputs in this notebook cannot accept.

  Returns:
    Keras Model named 'generator' mapping (None, None, 3) images to
    tanh-activated 3-channel images.
  """
  residual_blocks = 8
  momentum = 0.8
  input_shape = (None, None, 3)  # any spatial size, 3 channels

  # Input Layer of the generator network
  input_layer = Input(shape=input_shape)

  # Pre-residual block
  gen1 = Conv2D(filters=64, kernel_size=9, strides=1, padding='same',
                activation='relu')(input_layer)

  # Stack of residual blocks
  res = gen1
  for _ in range(residual_blocks):
    res = residual_block(res)

  # Post-residual block
  gen2 = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(res)
  gen2 = BatchNormalization(momentum=momentum)(gen2)

  # Long skip connection: pre-residual output (gen1) + post-residual output (gen2)
  gen3 = Add()([gen2, gen1])

  # Output convolution layer; tanh keeps values in [-1, 1]
  gen6 = Conv2D(filters=3, kernel_size=9, strides=1, padding='same')(gen3)
  output = Activation('tanh')(gen6)

  # Keras model
  model = Model(inputs=[input_layer], outputs=[output], name='generator')
  return model

In [None]:
def build_discriminator():
  """Build the discriminator for 256x256 RGB images.

  Eight 3x3 convolution blocks (Conv -> LeakyReLU, plus BatchNorm on all but
  the first) are followed by two Dense layers. There is no Flatten layer, so
  the Dense layers are applied to the 4-D feature map; the training cell uses
  labels of shape (batch, 16, 16, 1) for this model.

  Returns:
    Keras Model named 'discriminator' with a sigmoid output.
  """
  leakyrelu_alpha = 0.2
  momentum = 0.8

  input_layer = Input(shape=(256, 256, 3))

  # (filters, strides, followed by BatchNorm?) for convolution blocks 1..8
  conv_plan = (
      (64, 1, False),
      (64, 2, True),
      (128, 1, True),
      (128, 2, True),
      (256, 1, True),
      (256, 2, True),
      (256, 1, True),
      (256, 2, True),
  )

  features = input_layer
  for filters, strides, use_batchnorm in conv_plan:
    features = Conv2D(filters=filters, kernel_size=3, strides=strides, padding='same')(features)
    features = LeakyReLU(alpha=leakyrelu_alpha)(features)
    if use_batchnorm:
      features = BatchNormalization(momentum=momentum)(features)

  # Dense layer
  features = Dense(units=256)(features)
  features = LeakyReLU(alpha=leakyrelu_alpha)(features)

  # Last dense layer - for classification
  output = Dense(units=1, activation='sigmoid')(features)
  return Model(inputs=[input_layer], outputs=[output], name='discriminator')

In [None]:
def build_vgg():
  """Build a frozen VGG19 feature extractor for 256x256 RGB inputs.

  Returns:
    Keras Model mapping the VGG19 input to the output of layer index 10
    (presumably block3_conv4 - confirm with vgg.summary()).
  """
  # Pre-trained VGG19 (ImageNet weights) without the classification head
  backbone = tf.keras.applications.VGG19(weights='imagenet', include_top=False, input_shape=(256, 256, 3))
  backbone.trainable = False

  feature_layer_index = 10  # TODO: Try changing no. of layers
  feature_maps = [backbone.layers[feature_layer_index].output]

  return Model([backbone.input], feature_maps)

# Evaluation metric - Frechet Inception Distance (FID)
---

In [None]:
def calculate_fid(model, images1, images2):
  """Frechet Inception Distance between two image sets.

  Args:
    model: object whose predict(images) returns a 2-D array of activations,
      one row per image.
    images1, images2: the two image batches to compare.

  Returns:
    ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 . sigma2)), where
    mu/sigma are the mean and covariance of each set's activations.
  """
  def _mean_and_cov(activations):
    # Per-feature mean and feature covariance (rows are samples).
    return activations.mean(axis=0), cov(activations, rowvar=False)

  feats_a = model.predict(images1)
  feats_b = model.predict(images2)
  mean_a, cov_a = _mean_and_cov(feats_a)
  mean_b, cov_b = _mean_and_cov(feats_b)

  # Squared distance between the two mean vectors
  mean_gap = np.sum((mean_a - mean_b)**2.0)

  # Matrix square root of the covariance product; drop any imaginary part
  cov_root = sqrtm(cov_a.dot(cov_b))
  if iscomplexobj(cov_root):
    cov_root = cov_root.real

  return mean_gap + trace(cov_a + cov_b - 2.0 * cov_root)

# Training the Model
---

In [None]:
# Take a checkpoint (to model weights save, model save locations) while training
def take_checkpoint(generator, discriminator, checkpoint_filepath, epoch):
  """Save both networks as full models and as weights-only files.

  Full models go to the module-level `model_save_filepath` folder as
  '<name>_<epoch>.h5'; weights go to `checkpoint_filepath` as
  '<name>_weights_<epoch>.h5'.
  """
  file_tail = str(epoch) + ".h5"
  networks = (("generator_", generator), ("discriminator_", discriminator))

  # Full model files first ...
  for prefix, network in networks:
    network.save(model_save_filepath + prefix + file_tail)

  # ... then the weights-only files.
  for prefix, network in networks:
    network.save_weights(checkpoint_filepath + prefix + "weights_" + file_tail)

In [None]:
# Glob pattern matching every raw image of the dataset.
data_dir = '/content/fyp-datasets/raw-890/*'
# Network input shape; its first two entries are passed to load_all_images
# as the (width, height) of each cropped patch.
image_shape = (256, 256, 3)

In [None]:
# load_all_images returns (enhanced, original) in that order, hence the unpacking order.
all_enhanced_images, all_original_images = load_all_images(data_dir, (image_shape[0], image_shape[1]))
# The two counts should be equal: patches are appended to both lists in pairs.
print(len(all_enhanced_images), len(all_original_images))

In [None]:
# Set the below values as necessary
# Note: one "epoch" in the training cell is a single randomly sampled batch,
# not a full pass over the dataset.
epochs = 7880
batch_size = 8

loss_step = 100 # Interval to write losses to file
display_step = 100 # Interval to display sample raw-reference-generated images
checkpoint_step = 1000 # Interval at which to take checkpoint

# The resume cell hard-codes its starting epoch separately; keep it in sync
# with the epoch number embedded in the two weight file names below.
continue_training = True # Set False if starting from epoch 0, else True
gen_weights_file = 'generator_weights_6000.h5' # Change file name based on where to continue from
disc_weights_file = 'discriminator_weights_6000.h5'
losses_file = 'losses.txt' # File to save losses to (IMPORTANT: for plotting later)

In [None]:
# Common optimizer for all networks (Adam; positional args are learning rate 0.0002 and beta_1 0.5)
# NOTE(review): a single optimizer instance is shared by three compiled models;
# some newer Keras releases may reject reusing one optimizer across different
# variable sets - confirm against the installed version.
common_optimizer = tf.keras.optimizers.Adam(0.0002, 0.5)

# Build and compile VGG19 network to extract features
# (build_vgg freezes it; it is only used as a fixed feature extractor below)
vgg = build_vgg()
vgg.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])

# Build and compile the discriminator network
# NOTE(review): trained standalone with 'mse', while the adversarial model below
# scores its output with 'binary_crossentropy' - confirm the mismatch is intended.
discriminator = build_discriminator()
discriminator.compile(loss='mse', optimizer=common_optimizer, metrics=['accuracy'])

# Build the generator network
generator = build_generator()

# Input layers for original and enhanced images
# (input_enhanced is a model input but is not connected to either output below)
input_original = Input(shape=image_shape)
input_enhanced = Input(shape=image_shape)

# Generate enhanced images from original images
generated_enhanced_images = generator(input_original)

# Extract feature maps of the generated images
features = vgg(generated_enhanced_images)

# Make the discriminator network as non-trainable
# (set after the discriminator's own compile and before the adversarial compile)
discriminator.trainable = False

# Get the probability of generated enhanced images
probs = discriminator(generated_enhanced_images)

# Create and compile an adversarial model
# Outputs: [discriminator score, VGG features]; losses weighted 1e-3 and 1 respectively.
adversarial_model = Model([input_original, input_enhanced], [probs, features])
adversarial_model.compile(loss=['binary_crossentropy', 'mse'], loss_weights=[1e-3, 1], optimizer=common_optimizer)

# For calculating FID
# (pooling='avg' without the top gives one feature vector per image)
inception_model = InceptionV3(include_top=False, pooling='avg', input_shape=(256,256,3))

In [None]:
# Uncomment to see Generator architecture
# tf.keras.utils.plot_model(generator,show_shapes=True)

In [None]:
# Uncomment to see Discriminator architecture
# tf.keras.utils.plot_model(discriminator,show_shapes=True)

In [None]:
# Placeholder so the name exists even if opening the losses file fails below
losses_file_obj = ''

# First epoch of this run: 1 for a fresh start (don't change this default);
# when continuing from a checkpoint, change 6001 based on where to continue from
start = 6001 if continue_training else 1

# If continuing training from checkpoint, restore both networks' weights
if continue_training:
  generator.load_weights(checkpoint_filepath + gen_weights_file)
  discriminator.load_weights(checkpoint_filepath + disc_weights_file)

# Append to the existing losses text file when resuming, otherwise start a new one
losses_file_obj = open(checkpoint_filepath + losses_file, 'a' if continue_training else 'w+')

In [None]:
# Main training loop. Each "epoch" processes one random batch: a discriminator
# update on real/generated images, then a generator update through the
# adversarial model. The losses file is always closed, even on interruption.
try:
  for epoch in range(start, epochs + 1):
      # TRAINING DISCRIMINATOR NETWORK
      # Sample a batch of images. The argument order must follow
      # select_random_batch(original_images, enhanced_images, batch_size);
      # passing the arrays swapped would feed reference images to the generator
      # and present raw images to the discriminator as "real".
      original_images, enhanced_images = select_random_batch(all_original_images, all_enhanced_images, batch_size)

      # Normalizing pixel values from [0, 255] to [-1, 1]
      original_images = original_images / 127.5 - 1.
      enhanced_images = enhanced_images / 127.5 - 1.

      # Generate enhanced images (fake) from original images (real)
      generated_enhanced_images = generator.predict(original_images)

      # Generate batch of real and fake labels
      # (16x16 grid: the discriminator halves 256x256 inputs four times and has no Flatten)
      real_labels = np.ones((batch_size, 16, 16, 1))
      fake_labels = np.zeros((batch_size, 16, 16, 1))

      # Train the discriminator network on real and fake images
      d_loss_real = discriminator.train_on_batch(enhanced_images, real_labels)
      d_loss_fake = discriminator.train_on_batch(generated_enhanced_images, fake_labels)

      # Calculate total discriminator loss (element-wise mean of [loss, accuracy])
      d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

      # TRAINING GENERATOR NETWORK
      original_images, enhanced_images = select_random_batch(all_original_images, all_enhanced_images, batch_size)

      # Normalizing
      original_images = original_images / 127.5 - 1.
      enhanced_images = enhanced_images / 127.5 - 1.

      # Extract feature maps for enhanced images (target of the content loss)
      image_features = vgg.predict(enhanced_images)

      # Train the generator network: it should make the frozen discriminator
      # answer "real" and match the VGG features of the reference images
      g_loss = adversarial_model.train_on_batch([original_images, enhanced_images],
                                        [real_labels, image_features])

      if epoch % loss_step == 0:
        print("Epoch:{}".format(epoch))
        print("d_loss:", d_loss)
        print("g_loss:", g_loss)

        # One line per record: "<epoch> <discriminator loss> <total generator loss>"
        loss_to_write = "{epoch} {d_loss} {g_loss_per}\n".format(epoch=epoch, d_loss=d_loss[0], g_loss_per=g_loss[0])
        losses_file_obj.write(loss_to_write)

      if epoch != 1 and epoch % display_step == 0:
        original_images, enhanced_images = select_random_batch(all_original_images, all_enhanced_images, batch_size)

        # Normalizing
        original_images = original_images / 127.5 - 1.
        enhanced_images = enhanced_images / 127.5 - 1.

        generated_images = generator.predict_on_batch(original_images)

        fid = calculate_fid(inception_model, enhanced_images, generated_images)
        print('FID: %.3f' % fid)

        # FID records are prefixed with "FID" so the plotting cell can tell them apart
        fid_to_write = "FID {epoch} {FID}\n".format(epoch=epoch, FID=fid)
        losses_file_obj.write(fid_to_write)

        # Show at most four sample triplets
        for index, img in enumerate(generated_images):
          if index >= 4:
            break
          display_images(original_images[index], enhanced_images[index], img)

      if epoch != 1 and epoch % checkpoint_step == 0:
        take_checkpoint(generator, discriminator, checkpoint_filepath, epoch)
finally:
  losses_file_obj.close()

# Once training is completed for all epochs, final checkpoint
# (only reached when the loop finished without raising; `epoch` is the last loop value)
take_checkpoint(generator, discriminator, checkpoint_filepath, epoch)

# Visualizing losses
---

In [None]:
# Parse the losses file written during training into plottable series.
# Two record formats exist:
#   "<epoch> <discriminator loss> <generator loss>"
#   "FID <epoch> <fid>"
gen_loss_plot = []
disc_loss_plot = []
loss_epochs = []
FID_plot = []
FID_epochs = []

# The context manager closes the file even if reading fails
with open(checkpoint_filepath + losses_file, 'r') as losses_file_to_plot:
  lines = losses_file_to_plot.readlines()

for line in lines:
  values = line.split()
  # Skip blank lines instead of failing on values[0]
  if not values:
    continue
  if values[0] == "FID":
    FID_epochs.append(int(values[1]))
    FID_plot.append(float(values[2]))
  else:
    loss_epochs.append(int(values[0]))
    disc_loss_plot.append(float(values[1]))
    gen_loss_plot.append(float(values[2]))

In [None]:
# Plot generator loss (third field of each loss record in the losses file)
plot_losses(gen_loss_plot, loss_epochs, 'Generator loss')

In [None]:
# Plot discriminator loss (second field of each loss record in the losses file)
plot_losses(disc_loss_plot, loss_epochs, 'Discriminator loss')

In [None]:
# Plot FID (from the "FID <epoch> <value>" records); note plot_losses labels the y-axis 'Loss'
plot_losses(FID_plot, FID_epochs, 'FID')

# Testing
---


In [None]:
# Loading a saved model - change filename based on which model to load
# compile=False: the generator is only used for prediction in the cells below
generator = tf.keras.models.load_model(model_save_filepath + "generator_2000.h5", compile=False)

In [None]:
# Getting predictions to calculate FID - change number of images in batch as necessary
original_images, enhanced_images = select_random_batch(all_original_images, all_enhanced_images, 1000)
# Normalizing pixel values from [0, 255] to [-1, 1]
enhanced_images = enhanced_images / 127.5 - 1.
original_images = original_images / 127.5 - 1.
# One batched predict call instead of one predict call per image: same outputs,
# without the per-call overhead. batch_size bounds the memory used per step.
generated_images = generator.predict(original_images, batch_size=8)

In [None]:
# FID calculation
# InceptionV3 without its top and with average pooling yields one feature vector per image
inception_model = InceptionV3(include_top=False, pooling='avg', input_shape=(256,256,3))
# Compare reference (enhanced) images against the generator's outputs
fid = calculate_fid(inception_model, enhanced_images, generated_images)
print('FID: %.3f' % fid)

In [None]:
# Testing on UIEB "challenging" set (without reference images) - change image shape as necessary
data_dir = '/content/fyp-datasets/challenging-60/*'
test_images = load_test_images(data_dir=data_dir, batch_size=2, image_shape=(400,350))
# Normalizing pixel values from [0, 255] to [-1, 1]
test_images = test_images / 127.5 - 1.

generated_images = generator.predict_on_batch(test_images)
for index, img in enumerate(generated_images):
  fig = plt.figure()
  ax = fig.add_subplot(1, 2, 1)
  # Rescale from [-1, 1] back to [0, 1] before plotting, as display_images does;
  # imshow would otherwise clip the negative half of the value range.
  ax.imshow((test_images[index] + 1) / 2)
  ax.axis("off")
  ax.set_title("Test Image")

  ax = fig.add_subplot(1, 2, 2)
  ax.imshow((img + 1) / 2)
  ax.axis("off")
  ax.set_title("Generated Image")

  plt.show()