<a href="https://colab.research.google.com/github/rsanzd/deep-learning-coursera/blob/master/Image_Super_Resolution_ESRGAN_TF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Inference with ESRGAN

In [None]:
import glob
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from PIL import Image
# Ask tensorflow_hub to report progress when it downloads models.
os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "True"

In [None]:
# Mount Google Drive (Colab only) so the data folder under /content/drive can be read.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Define paths: the Drive folder with the data, one sub-folder each for the
# training and test images.
INPUT_DIR = '/content/drive/MyDrive/SpainAI'
TRAIN_FOLDER = 'TrainSet'
TEST_FOLDER = 'TestSet'

# Upscaled test images are written inside the test-set folder.
OUT_DIR = os.path.join(INPUT_DIR, TEST_FOLDER, 'Upscaled_ESRGAN_inference')
# exist_ok avoids the check-then-create race and keeps the cell re-runnable.
os.makedirs(OUT_DIR, exist_ok=True)

# Working directory at start-up.
CWD = os.getcwd()

In [None]:
# Fraction of the training images held out for validation, and the seed that
# makes the train/val split reproducible across runs.
VAL_FRACTION = 0.2
SPLIT_SEED = 42


def _list_pngs(folder):
    """Return the sorted bare file names of the .png files directly in `folder`."""
    # glob.escape keeps special characters in the folder path literal; sorting
    # removes the filesystem-dependent glob order so the seeded split is stable.
    pattern = os.path.join(glob.escape(folder), '*.png')
    return sorted(os.path.basename(path) for path in glob.glob(pattern))


# Listing by full pattern (instead of os.chdir + glob) never changes the
# working directory, even if listing fails.
train_imgs = _list_pngs(os.path.join(INPUT_DIR, TRAIN_FOLDER))
test_imgs = _list_pngs(os.path.join(INPUT_DIR, TEST_FOLDER))

# Seeded local RNG: the split is reproducible and the global `random` state
# is left untouched.
random.Random(SPLIT_SEED).shuffle(train_imgs)
num_val_imgs = int(VAL_FRACTION * len(train_imgs))
val_imgs = train_imgs[:num_val_imgs]
train_imgs = train_imgs[num_val_imgs:]

print('Len of the train files: ', len(train_imgs))
print('Len of the val files: ', len(val_imgs))
print('Len of the test files: ', len(test_imgs))

Len of the files:  100


### Auxiliary function definitions

In [None]:
def preprocess_image(image_path):
  """ Loads image from path and preprocesses to make it model ready.

      The image is decoded as 3-channel RGB, cropped from the top-left so
      height and width are multiples of 4, cast to float32 and given a
      leading batch dimension.

      Args:
        image_path: Path to the image file (str or string tensor).

      Returns:
        float32 tensor of shape [1, height, width, 3] with values in [0, 255].
  """
  # channels=3: the model only supports 3 colour channels, so this strips
  # the alpha channel of RGBA PNGs and also expands grayscale images.
  # expand_animations=False guarantees a 3-D tensor even for GIF files.
  hr_image = tf.image.decode_image(tf.io.read_file(image_path), channels=3,
                                   expand_animations=False)
  # tf.shape (dynamic) rather than .shape (static) so this also works when
  # traced in graph mode, e.g. inside tf.data.Dataset.map.
  hr_size = (tf.shape(hr_image)[:-1] // 4) * 4
  hr_image = tf.image.crop_to_bounding_box(hr_image, 0, 0, hr_size[0], hr_size[1])
  hr_image = tf.cast(hr_image, tf.float32)
  return tf.expand_dims(hr_image, 0)

def save_image(image, filepath):
  """
    Writes an image to disk and reports where it went.

    Args:
      image: A PIL Image (saved unchanged) or a 3D image tensor
        [height, width, channels] of unscaled pixel values, which is clipped
        to [0, 255] and converted to uint8 before saving.
      filepath: Name of the file to save to.
  """
  pil_image = image
  if not isinstance(pil_image, Image.Image):
    pixels = tf.cast(tf.clip_by_value(image, 0, 255), tf.uint8).numpy()
    pil_image = Image.fromarray(pixels)

  pil_image.save(filepath)
  print("Saved as ", filepath)


In [None]:
%matplotlib inline
def plot_image(image, title=""):
  """
    Draws an image tensor on the current matplotlib axes, without axis ticks.

    Args:
      image: 3D image tensor or array [height, width, channels] of unscaled
        pixel values; clipped to [0, 255] and shown as uint8.
      title: Title to display above the image.
  """
  pixels = tf.cast(tf.clip_by_value(np.asarray(image), 0, 255), tf.uint8).numpy()
  plt.imshow(Image.fromarray(pixels))
  plt.axis("off")
  plt.title(title)

## Load the model

In [None]:
# Pretrained ESRGAN generator from TF Hub, used for test-set inference below.
ESRGAN_HANDLE = "https://tfhub.dev/captain-pool/esrgan-tf2/1"
model = hub.load(ESRGAN_HANDLE)

Downloaded https://tfhub.dev/captain-pool/esrgan-tf2/1, Total size: 20.60MB



## Inference on test set

In [None]:
# Image-quality metrics (SSIM, PSNR) of a generated image w.r.t. the original.
def calculate_ssim(gen_image, hr_image, printed=True):
    """Compute SSIM between generated image(s) and the reference image(s).

    Both inputs are clipped to [0, 255] first and compared with max_val=255.

    Args:
        gen_image: Generated image(s), [..., height, width, channels].
        hr_image: Reference image(s), same shape as gen_image.
        printed: If True, print the SSIM averaged over all images.

    Returns:
        Tensor with one SSIM value per image (the leading batch dims).
    """
    ssim = tf.image.ssim(tf.clip_by_value(gen_image, 0, 255),
                         tf.clip_by_value(hr_image, 0, 255),
                         max_val=255)
    if printed:
        # Reduce to a scalar first: "%f" cannot format a multi-image tensor.
        print("SSIM Achieved: %f" % float(tf.reduce_mean(ssim)))

    return ssim

def calculate_psnr(gen_image, hr_image, printed=True):
    """Compute PSNR between generated image(s) and the reference image(s).

    Both inputs are clipped to [0, 255] first and compared with max_val=255.

    Args:
        gen_image: Generated image(s), [..., height, width, channels].
        hr_image: Reference image(s), same shape as gen_image.
        printed: If True, print the PSNR averaged over all images.

    Returns:
        Tensor with one PSNR value per image (the leading batch dims).
    """
    psnr = tf.image.psnr(
        tf.clip_by_value(gen_image, 0, 255),
        tf.clip_by_value(hr_image, 0, 255), max_val=255)
    if printed:
        # Reduce to a scalar first: "%f" cannot format a multi-image tensor.
        print("PSNR Achieved: %f" % float(tf.reduce_mean(psnr)))

    return psnr

In [None]:
# Super-resolve every test image with the pretrained model and save the result
# in OUT_DIR as 'candidate_' + <file name minus up to two leading '_' fields>.
for filename in test_imgs:
    img_path = os.path.join(INPUT_DIR, TEST_FOLDER, filename)
    dest_name = 'candidate_' + filename.split('_', 2)[-1]
    dest_path = os.path.join(OUT_DIR, dest_name)
    lr_image = preprocess_image(img_path)
    sr_image = model(lr_image)
    # Drop the batch dimension before converting to an image file.
    save_image(tf.squeeze(sr_image), filepath=dest_path)

## Fine-tuning the model

### Define the model

In [None]:
# Fine-tuning model: the pretrained ESRGAN generator (trainable) followed by a
# 1x1 convolution that linearly remixes the three output channels.
generator = tf.keras.models.Sequential([
    # Height/width left unspecified so images of any size can be fed; the
    # preprocessing only crops them to multiples of 4, it does not resize.
    hub.KerasLayer("https://tfhub.dev/captain-pool/esrgan-tf2/1", trainable=True,
                   input_shape=(None, None, 3)),
    # Identity-initialised kernel (bias defaults to zeros), so fine-tuning
    # starts from the pretrained ESRGAN output instead of a random channel
    # mixing of it.
    tf.keras.layers.Conv2D(
        filters=3, kernel_size=[1, 1], strides=[1, 1],
        kernel_initializer=tf.constant_initializer(np.eye(3).reshape(1, 1, 3, 3)),
    ),
])

In [None]:
# Layer-by-layer overview of the fine-tuning model with parameter counts.
generator.summary()

Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
keras_layer_5 (KerasLayer)   (None, None, None, 3)     4605955   
_________________________________________________________________
conv2d_5 (Conv2D)            (None, None, None, 3)     12        
Total params: 4,605,967
Trainable params: 4,605,967
Non-trainable params: 0
_________________________________________________________________


In [None]:
# Number of trainable variable tensors (not scalar parameters) in the generator.
len(generator.trainable_variables)

344

In [None]:
# Fine-tuning hyperparameters: number of epochs, Adam learning rate, batch size.
EPOCHS, LR, BS = 100, 5e-5, 32

### Preparing the dataset

In [None]:
# Preprocessing functions for the tf.data training pipeline.
def _decode_rgb_float(image_path):
    """Decode an image file into a float32 [height, width, 3] tensor in [0, 255]."""
    # channels=3 drops any alpha channel and expands grayscale;
    # expand_animations=False guarantees a 3-D tensor even for GIF files.
    image = tf.image.decode_image(tf.io.read_file(image_path), channels=3,
                                  expand_animations=False)
    return tf.cast(image, tf.float32)


def map_image(lr_image_path, hr_image_path):
    """Load one (low-resolution input, high-resolution target) training pair.

    Uses tensor ops only (tf.shape; no .numpy() or PIL), so it can run inside
    `tf.data.Dataset.map`, which traces it in graph mode.

    Args:
        lr_image_path: Path of the low-resolution input image.
        hr_image_path: Path of the matching high-resolution target image.

    Returns:
        (lr_image, hr_image): float32 tensors of shape [height, width, 3] with
        values in [0, 255]. The LR image is cropped from the top-left so its
        height and width are multiples of 4; the HR image is not cropped.
    """
    lr_image = _decode_rgb_float(lr_image_path)
    lr_size = (tf.shape(lr_image)[:-1] // 4) * 4
    lr_image = tf.image.crop_to_bounding_box(lr_image, 0, 0, lr_size[0], lr_size[1])

    # The target must come from its own file, not from the LR input.
    hr_image = _decode_rgb_float(hr_image_path)

    return lr_image, hr_image


# Prepare the training and validation datasets.
batch_size = BS

# train_imgs / val_imgs are bare file names listed inside the training folder,
# while tf.io.read_file resolves relative names against the process working
# directory, so build full paths first.
train_dir = os.path.join(INPUT_DIR, TRAIN_FOLDER)
train_lr_paths = [os.path.join(train_dir, name) for name in train_imgs]
val_lr_paths = [os.path.join(train_dir, name) for name in val_imgs]

# NOTE(review): train_imgs_hr / val_imgs_hr are not defined anywhere in this
# notebook. They must hold the HR target paths, aligned element-wise with
# train_imgs / val_imgs. Also, .batch() requires every image in a batch to
# have the same size.

# Shuffle the path pairs *before* decoding, so the 1024-element shuffle buffer
# holds file names rather than decoded images. Each pair is then preprocessed
# with map_image().
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_lr_paths, train_imgs_hr))
    .shuffle(buffer_size=1024)
    .map(map_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

val_dataset = (
    tf.data.Dataset.from_tensor_slices((val_lr_paths, val_imgs_hr))
    .map(map_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Alternative: build the datasets with ImageDataGenerator instead.

### Define losses and metrics

In [None]:
def ssim_metric(gen, target):
    """Mean SSIM of a batch of generated images against their targets (higher is better)."""
    return tf.reduce_mean(calculate_ssim(gen, target, printed=False))

def ssim_loss(gen, target):
    """SSIM loss: 1 - mean SSIM, so a perfect reconstruction scores 0."""
    return 1 - ssim_metric(gen, target)

### Training the model

In [None]:
# Adam optimiser for the custom fine-tuning loop; LR is its learning rate.
optimizer = tf.keras.optimizers.Adam(LR)

In [None]:
# Custom training loop. Each step minimises ssim_loss between the generated
# batch and the HR targets. At the end of every epoch, the mean of the
# per-batch SSIMs is reported for the training and validation sets.
for epoch in range(EPOCHS):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    train_ssim_batch_acc = []
    val_ssim_batch_acc = []
    # Running count of training samples. The last batch may be smaller than
    # batch_size, so real batch sizes are summed instead of multiplied.
    seen_samples = 0

    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Record the forward pass so the tape can differentiate the loss
        # with respect to the generator's trainable weights.
        with tf.GradientTape() as tape:
            gen_batch_train = generator(x_batch_train, training=True)
            loss_value = ssim_loss(gen_batch_train, y_batch_train)

        grads = tape.gradient(loss_value, generator.trainable_weights)
        # One gradient-descent step on the generator's variables.
        optimizer.apply_gradients(zip(grads, generator.trainable_weights))

        train_ssim_batch_acc.append(ssim_metric(gen_batch_train, y_batch_train))
        seen_samples += int(tf.shape(x_batch_train)[0])

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % seen_samples)

    train_ssim = tf.reduce_mean(train_ssim_batch_acc)
    print("Training ssim over epoch: %.4f" % (float(train_ssim),))

    # Validation: forward passes only, in inference mode.
    for x_batch_val, y_batch_val in val_dataset:
        gen_batch_val = generator(x_batch_val, training=False)
        val_ssim_batch_acc.append(ssim_metric(gen_batch_val, y_batch_val))

    val_ssim = tf.reduce_mean(val_ssim_batch_acc)
    print("Validation ssim over epoch: %.4f" % (float(val_ssim),))
    print("Time taken: %.2fs" % (time.time() - start_time))

# NOTE(review): this relative path writes to the current working directory,
# not under INPUT_DIR on the mounted Drive. Confirm this is intended.
generator.save('gen_model.h5')

NameError: ignored