In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
import cv2 
import os
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras import layers, Model
from sklearn.model_selection import train_test_split
from keras import Model
from keras.layers import Conv2D
from keras.layers import PReLU
from keras.layers import BatchNormalization
from keras.layers import Flatten
from keras.layers import UpSampling2D
from keras.layers import LeakyReLU
from keras.layers import Dense
from keras.layers import Input
from keras.layers import add
from tqdm import tqdm




In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
# Since the data is too large to fit in memory, we will use a generator to load the data in batches.
def load_image(path):
    try:
        img = cv2.imread(path)
        # If the image has not 3 channels, generate a 3 channels image from the gray scale image
        if img.shape[2]!=3:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img
    except:
        print(path)
        return None

def load_data(batch_images, data_augmentation=0):
    imgs_hr, imgs_lr = np.empty((0, 1080, 1920, 3)), np.empty((0, 270, 480, 3))
    for img_path in batch_images:
        img_hr = load_image(img_path)
        # If the image is not at 1920x1080, resize it to 1920x1080
        if img_hr.shape[0]!=1080 or img_hr.shape[1]!=1920:
            img_hr = cv2.resize(img_hr, (1920, 1080), interpolation=cv2.INTER_CUBIC)
        
        # Make data augmentation here
        if data_augmentation>0:
            img_hr = aug_image(img_hr, data_augmentation)
        img_lr = cv2.resize(img_hr, (480, 270), interpolation=cv2.INTER_CUBIC)
        
        imgs_hr = np.append(imgs_hr, [img_hr], axis=0)
        imgs_lr = np.append(imgs_lr, [img_lr], axis=0)
    return imgs_hr, imgs_lr

# Augment the image by applying random rotation, random zoom and random translation
def aug_image(img,num_of_aug):
    img_list = []
    for i in range(num_of_aug):
        # Random rotation
        angle = np.random.randint(0,360)
        img_rot = rotate_image(img, angle)
        # Random zoom
        zoom_factor = np.random.randint(1,5)
        img_zoom = zoom_image(img_rot, zoom_factor)
        # Random translation
        x_shift = np.random.randint(-50,50)
        y_shift = np.random.randint(-50,50)
        img_shift = shift_image(img_zoom, x_shift, y_shift)
        img_list.append(img_shift)
    return img_list[np.random.randint(0,num_of_aug)]


def rotate_image(img, angle):
    rows,cols = img.shape[0:2]
    M = cv2.getRotationMatrix2D((cols/2,rows/2),angle,1)
    img_rot = cv2.warpAffine(img,M,(cols,rows))
    return img_rot

def zoom_image(img, zoom_factor):
    rows,cols = img.shape[0:2]
    M = cv2.getRotationMatrix2D((cols/2,rows/2),0,zoom_factor)
    img_zoom = cv2.warpAffine(img,M,(cols,rows))
    return img_zoom

def shift_image(img, x_shift, y_shift):
    rows,cols = img.shape[0:2]
    M = np.float32([[1,0,x_shift],[0,1,y_shift]])
    img_shift = cv2.warpAffine(img,M,(cols,rows))
    return img_shift

In [3]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Add, Lambda
from tensorflow.python.keras.layers import PReLU

from utils.normalization import normalize_01, denormalize_m11


upsamples_per_scale = {
    2: 1,
    4: 2,
    8: 3
}


def pixel_shuffle(scale):
    return lambda x: tf.nn.depth_to_space(x, scale)


def upsample(x_in, num_filters):
    x = Conv2D(num_filters, kernel_size=3, padding='same')(x_in)
    x = Lambda(pixel_shuffle(scale=2))(x)
    return PReLU(shared_axes=[1, 2])(x)


def residual_block(block_input, num_filters, momentum=0.8):
    x = Conv2D(num_filters, kernel_size=3, padding='same')(block_input)
    x = BatchNormalization(momentum=momentum)(x)
    x = PReLU(shared_axes=[1, 2])(x)
    x = Conv2D(num_filters, kernel_size=3, padding='same')(x)
    x = BatchNormalization(momentum=momentum)(x)
    x = Add()([block_input, x])
    return x


def build_srresnet(scale=4, num_filters=64, num_res_blocks=16):
    if scale not in upsamples_per_scale:
        raise ValueError(f"available scales are: {upsamples_per_scale.keys()}")

    num_upsamples = upsamples_per_scale[scale]

    lr = Input(shape=(None, None, 3))
    x = Lambda(normalize_01)(lr)

    x = Conv2D(num_filters, kernel_size=9, padding='same')(x)
    x = x_1 = PReLU(shared_axes=[1, 2])(x)

    for _ in range(num_res_blocks):
        x = residual_block(x, num_filters)

    x = Conv2D(num_filters, kernel_size=3, padding='same')(x)
    x = BatchNormalization()(x)
    x = Add()([x_1, x])

    for _ in range(num_upsamples):
        x = upsample(x, num_filters * 4)

    x = Conv2D(3, kernel_size=9, padding='same', activation='tanh')(x)
    sr = Lambda(denormalize_m11)(x)

    return Model(lr, sr)

In [4]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Add, Lambda, LeakyReLU, Flatten, Dense
from tensorflow.python.keras.layers import PReLU

from utils.normalization import normalize_m11


def discriminator_block(x_in, num_filters, strides=1, batchnorm=True, momentum=0.8):
    x = Conv2D(num_filters, kernel_size=3, strides=strides, padding='same')(x_in)
    if batchnorm:
        x = BatchNormalization(momentum=momentum)(x)
    return LeakyReLU(alpha=0.2)(x)


def build_discriminator(hr_crop_size):
    x_in = Input(shape=(hr_crop_size, hr_crop_size, 3))
    x = Lambda(normalize_m11)(x_in)

    x = discriminator_block(x, 64, batchnorm=False)
    x = discriminator_block(x, 64, strides=2)

    x = discriminator_block(x, 128)
    x = discriminator_block(x, 128, strides=2)

    x = discriminator_block(x, 256)
    x = discriminator_block(x, 256, strides=2)

    x = discriminator_block(x, 512)
    x = discriminator_block(x, 512, strides=2)

    x = Flatten()(x)

    x = Dense(1024)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Dense(1, activation='sigmoid')(x)

    return Model(x_in, x)

In [5]:
from utils.dataset_mappings import random_crop, random_flip, random_rotate, random_lr_jpeg_noise
from utils.metrics import psnr_metric
from utils.config import config
from utils.callbacks import SaveCustomCheckpoint
from tensorflow.python.data.experimental import AUTOTUNE


hr_crop_size = 96

train_mappings = [
    lambda lr, hr: random_crop(lr, hr, hr_crop_size=hr_crop_size, scale=4), 
    random_flip, 
    random_rotate, 
    random_lr_jpeg_noise]

def data_loader_train(paths, batch_size, mappings,nb_batches=10):
    batch_images_hr, batch_images_lr = [], []
    for i in range(nb_batches):
        batch_paths= np.random.choice(a=paths, size=batch_size)
        images_hr, images_lr = load_data(batch_paths, data_augmentation=0)
        batch_images_hr.append(images_hr)
        batch_images_lr.append(images_lr)
        del images_hr, images_lr, batch_paths

    # Convert the list into tf dataset
    batch_images_lr = np.concatenate(batch_images_lr, axis=0)
    batch_images_hr = np.concatenate(batch_images_hr, axis=0)
    batch_images_lr = tf.data.Dataset.from_tensor_slices(batch_images_lr)
    batch_images_hr = tf.data.Dataset.from_tensor_slices(batch_images_hr)

    dataset = tf.data.Dataset.zip((batch_images_lr, batch_images_hr))
    del batch_images_lr, batch_images_hr
    for mapping in mappings:
        dataset = dataset.map(mapping,num_parallel_calls=AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat()
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)
    return dataset


train_path = './data/test'
train_set_path = []
for path in os.listdir(train_path):
    train_set_path.append(os.path.join(train_path, path))

# train_dataset = data_loader_train(paths=train_set_path, batch_size=10, mappings=train_mappings)
# valid_dataset_subset = train_dataset.take(4)



In [6]:
from tensorflow.keras.optimizers import Adam

# valid_dataset_subset = train_dataset.take(10)


generator = build_srresnet(scale=4, num_filters=64, num_res_blocks=16)
generator.summary()

checkpoint_dir=f'./ckpt/sr_resnet'

learning_rate=1e-4

checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                 epoch=tf.Variable(0),
                                 psnr=tf.Variable(0.0),
                                 optimizer=Adam(learning_rate),
                                 model=generator)

checkpoint_manager = tf.train.CheckpointManager(checkpoint=checkpoint,
                                                directory=checkpoint_dir,
                                                max_to_keep=3)

if checkpoint_manager.latest_checkpoint:
    checkpoint.restore(checkpoint_manager.latest_checkpoint)
    print(f'Model restored from checkpoint at step {checkpoint.step.numpy()} with validation PSNR {checkpoint.psnr.numpy()}.')





Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, None, None, 3)]      0         []                            
                                                                                                  
 lambda (Lambda)             (None, None, None, 3)        0         ['input_1[0][0]']             
                                                                                                  
 conv2d (Conv2D)             (None, None, None, 64)       15616     ['lambda[0][0]']              
                                                                                                  
 tf.math.negative (TFOpLamb  (None, None, None, 64)       0         ['conv2d[0][0]']              
 da)                                                                                        

In [None]:
from tensorflow.keras.losses import MeanSquaredError, BinaryCrossentropy, MeanAbsoluteError
import gc

training_steps = 20_000

steps_per_epoch = 1000

training_epochs = training_steps / steps_per_epoch
gc.collect()
loss_for_plot = []


if checkpoint.epoch.numpy() < training_epochs:

    remaining_epochs = int(training_epochs - checkpoint.epoch.numpy())
    print(f"Continuing Training from epoch {checkpoint.epoch.numpy()}. Remaining epochs: {remaining_epochs}.")
    save_checkpoint_callback = SaveCustomCheckpoint(checkpoint_manager, steps_per_epoch)
    checkpoint.model.compile(optimizer=checkpoint.optimizer, loss=MeanSquaredError(), metrics=[psnr_metric])
    checkpoint.model.fit(train_dataset,validation_data=valid_dataset_subset, steps_per_epoch=steps_per_epoch, epochs=remaining_epochs, callbacks=[save_checkpoint_callback])
        # Compute the loss on the validation set
    loss = checkpoint.model.evaluate(valid_dataset_subset, steps=1)
    loss_for_plot.append(loss)
else:
    print("Training already completed. To continue training, increase the number of training steps")

weights_directory = f"/content/drive/MyDrive/weights/srresnet"
os.makedirs(weights_directory, exist_ok=True)
weights_file = f'{weights_directory}/generator.h5'
checkpoint.model.save_weights(weights_file)
loss_df = pd.DataFrame(loss_for_plot)
loss_df.to_csv(f'{weights_directory}/loss.csv', index=False)


In [None]:
weights_directory = f"weights/srresnet"
os.makedirs(weights_directory, exist_ok=True)
weights_file = f'{weights_directory}/generator.h5'
checkpoint.model.save_weights(weights_file)

Train the discriminator

In [None]:
generator = build_srresnet(scale=4)
generator.load_weights(weights_file)

In [7]:
discriminator = build_discriminator(hr_crop_size=hr_crop_size)
discriminator.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 96, 96, 3)]       0         
                                                                 
 lambda_4 (Lambda)           (None, 96, 96, 3)         0         
                                                                 
 conv2d_37 (Conv2D)          (None, 96, 96, 64)        1792      
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 96, 96, 64)        0         
                                                                 
 conv2d_38 (Conv2D)          (None, 48, 48, 64)        36928     
                                                                 
 batch_normalization_33 (Ba  (None, 48, 48, 64)        256       
 tchNormalization)                                               
                                                           

In [None]:
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
layer_5_4 = 20
vgg = VGG19(input_shape=(None, None, 3), include_top=False)
perceptual_model = Model(vgg.input, vgg.layers[layer_5_4].output)

In [None]:
binary_cross_entropy = BinaryCrossentropy()
mean_squared_error = MeanSquaredError()

In [None]:
from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay
learning_rate=PiecewiseConstantDecay(boundaries=[10000], values=[1e-4, 1e-5])

In [None]:
generator_optimizer = Adam(learning_rate=learning_rate)
discriminator_optimizer = Adam(learning_rate=learning_rate)

In [None]:
srgan_checkpoint_dir=f'./ckpt/srgan'

srgan_checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                       psnr=tf.Variable(0.0),
                                       generator_optimizer=Adam(learning_rate),
                                       discriminator_optimizer=Adam(learning_rate),
                                       generator=generator,
                                       discriminator=discriminator)

srgan_checkpoint_manager = tf.train.CheckpointManager(checkpoint=srgan_checkpoint,
                                                directory=srgan_checkpoint_dir,
                                                max_to_keep=3)

In [None]:
if srgan_checkpoint_manager.latest_checkpoint:
    srgan_checkpoint.restore(srgan_checkpoint_manager.latest_checkpoint)
    print(f'Model restored from checkpoint at step {srgan_checkpoint.step.numpy()} with validation PSNR {srgan_checkpoint.psnr.numpy()}.')


In [None]:
@tf.function
def train_step(lr, hr):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        lr = tf.cast(lr, tf.float32)
        hr = tf.cast(hr, tf.float32)

        sr = srgan_checkpoint.generator(lr, training=True)

        hr_output = srgan_checkpoint.discriminator(hr, training=True)
        sr_output = srgan_checkpoint.discriminator(sr, training=True)

        con_loss = calculate_content_loss(hr, sr)
        gen_loss = calculate_generator_loss(sr_output)
        perc_loss = con_loss + 0.001 * gen_loss
        disc_loss = calculate_discriminator_loss(hr_output, sr_output)

    gradients_of_generator = gen_tape.gradient(perc_loss, srgan_checkpoint.generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, srgan_checkpoint.discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, srgan_checkpoint.generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, srgan_checkpoint.discriminator.trainable_variables))

    return perc_loss, disc_loss

@tf.function
def calculate_content_loss(hr, sr):
    sr = preprocess_input(sr)
    hr = preprocess_input(hr)
    sr_features = perceptual_model(sr) / 12.75
    hr_features = perceptual_model(hr) / 12.75
    return mean_squared_error(hr_features, sr_features)

def calculate_generator_loss(sr_out):
    return binary_cross_entropy(tf.ones_like(sr_out), sr_out)

def calculate_discriminator_loss(hr_out, sr_out):
    hr_loss = binary_cross_entropy(tf.ones_like(hr_out), hr_out)
    sr_loss = binary_cross_entropy(tf.zeros_like(sr_out), sr_out)
    return hr_loss + sr_loss

In [None]:
from tensorflow.keras.metrics import Mean
import time
from PIL import Image

perceptual_loss_metric = Mean()
discriminator_loss_metric = Mean()

step = srgan_checkpoint.step.numpy()
steps = 20000

monitor_folder = f"monitor_training/srgan"
os.makedirs(monitor_folder, exist_ok=True)

now = time.perf_counter()

loss_for_plot = []

for lr, hr in train_dataset.take(steps - step):
    srgan_checkpoint.step.assign_add(1)
    step = srgan_checkpoint.step.numpy()

    perceptual_loss, discriminator_loss = train_step(lr, hr)
    perceptual_loss_metric(perceptual_loss)
    discriminator_loss_metric(discriminator_loss)

    if step % 200 == 0:
        psnr_values = []
        loss_array = []
        for lr, hr in valid_dataset_subset:
            sr = srgan_checkpoint.generator.predict(lr)[0]
            sr = tf.clip_by_value(sr, 0, 255)
            sr = tf.round(sr)
            sr = tf.cast(sr, tf.uint8)

            loss_array.append(srgan_checkpoint.generator.evaluate(lr, hr, steps=1))

            psnr_value = psnr_metric(hr, sr)[0]
            psnr_values.append(psnr_value)
            psnr = tf.reduce_mean(psnr_values)


        mean_loss = np.mean(loss_array, axis=0)
        image = Image.fromarray(sr.numpy())
        image.save(f"{monitor_folder}/{step}.png" )
        
        duration = time.perf_counter() - now
        
        now = time.perf_counter()

        loss_for_plot.append([step, perceptual_loss_metric.result().numpy(), discriminator_loss_metric.result().numpy(), psnr.numpy(),mean_loss])
        
        print(f'{step}/{steps}, psnr = {psnr}, perceptual loss = {perceptual_loss_metric.result():.4f}, discriminator loss = {discriminator_loss_metric.result():.4f} ({duration:.2f}s)')
        
        perceptual_loss_metric.reset_states()
        discriminator_loss_metric.reset_states()
        
        srgan_checkpoint.psnr.assign(psnr)
        srgan_checkpoint_manager.save()

In [None]:
weights_directory = f"weights/srgan"
os.makedirs(weights_directory, exist_ok=True)
weights_file = f'{weights_directory}/generator.h5'
srgan_checkpoint.generator.save_weights(weights_file)
# Save the loss for each epoch in a csv file
loss_df = pd.DataFrame(loss_for_plot)
loss_df.to_csv(f'{weights_directory}/loss.csv', index=False)