In [1]:
import os
import tensorflow as tf
from tensorflow.keras.models import Model
#from tensorflow.python.keras.layers import PReLU
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
from models import *
from utils import *
from dataset_load import create_training_and_validation_datasets
import matplotlib.pyplot as plt
import time
from PIL import Image

In [2]:
dataset_folder = os.path.abspath('dataset')
# Dataset configuration handed to create_training_and_validation_datasets.
# NOTE(review): the LR folders are presumably resolved relative to `data_directory`
# — confirm in dataset_load.
dataset_parameters = dict(
    train_low_res='LR/DIV2K_train_LR_bicubic',
    val_low_res='LR/DIV2K_valid_LR_bicubic',
    data_directory=dataset_folder,
)

In [3]:
hr_crop_size: int = 96  # HR crop size passed to random_crop and to discriminatorNet

In [4]:
def _crop_pair(lr, hr):
    """Randomly crop an aligned (LR, HR) pair; HR crop size comes from `hr_crop_size`."""
    return random_crop(lr, hr, hr_crop_size=hr_crop_size)


# Per-sample transformations applied to the training pairs.
train_mappings = [_crop_pair]

In [5]:
train_dataset, valid_dataset = create_training_and_validation_datasets(
    dataset_parameters,
    train_mappings,
)
# To speed up the periodic validation, use only the first 20 validation samples.
valid_dataset_subset = valid_dataset.take(20)

In [6]:
# True: skip the SRResNet pre-training/export cells and point `weights_file` at the
# saved SRGAN generator; False: pre-train and export the SRResNet first.
fine_tune: bool = True

# Train the SRResNet generator model

In [7]:
# Stage 1: pre-train the SRResNet generator with a pixel-wise MSE loss.
# Skipped entirely when fine_tune is True.
if not fine_tune:
    generator = generatorNet()
    checkpoint_dir = f'./ckpt/srresnet_bicubic_x4'

    learning_rate=1e-4

    # Checkpoint tracks step/epoch/PSNR counters plus the optimizer and the generator,
    # so an interrupted run can resume where it stopped.
    checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                     epoch=tf.Variable(0),
                                     psnr=tf.Variable(0.0),
                                     optimizer=tf.keras.optimizers.Adam(learning_rate),
                                     model=generator)

    # Only the 3 most recent checkpoints are kept on disk.
    checkpoint_manager = tf.train.CheckpointManager(checkpoint=checkpoint,
                                                    directory=checkpoint_dir,
                                                    max_to_keep=3)

    if checkpoint_manager.latest_checkpoint:
        checkpoint.restore(checkpoint_manager.latest_checkpoint)
        print(f'Model restored from checkpoint at step {checkpoint.step.numpy()} with validation PSNR {checkpoint.psnr.numpy()}.')

    training_steps = 1000000

    steps_per_epoch = 1000

    # Total epoch budget (a float: training_steps / steps_per_epoch).
    training_epochs = training_steps / steps_per_epoch

    # Run only the epochs not yet recorded in the checkpoint.
    # NOTE(review): checkpoint.epoch is presumably advanced and saved by the
    # SaveCheckpoints callback (from utils) — confirm there.
    if checkpoint.epoch.numpy() < training_epochs:
        remaining_epochs = int(training_epochs - checkpoint.epoch.numpy())
        print(f"Continuing Training from epoch {checkpoint.epoch.numpy()}. Remaining epochs: {remaining_epochs}.")
        save_checkpoint_callback = SaveCheckpoints(checkpoint_manager, steps_per_epoch)
        checkpoint.model.compile(optimizer=checkpoint.optimizer, loss=tf.keras.losses.MeanSquaredError(), metrics=[custom_psnr])
        checkpoint.model.fit(train_dataset,validation_data=valid_dataset_subset, steps_per_epoch=steps_per_epoch, epochs=remaining_epochs, callbacks=[save_checkpoint_callback])
    else:
        print("Training already completed. To continue training, increase the number of training steps")

In [8]:
# Export the pre-trained SRResNet generator weights (HDF5) after stage 1.
if not fine_tune:
    weights_file = "weights/srresnetGenerator.h5"
    # Derive the output folder from the file path: `weights_directory` is only
    # defined in a later cell, so referencing it here raised NameError on a fresh run.
    os.makedirs(os.path.dirname(weights_file), exist_ok=True)
    checkpoint.model.save_weights(weights_file)

# Train SRGAN using SRResNet as the generator

In [9]:
if fine_tune:
    # Fine-tuning starts from the previously exported SRGAN generator; otherwise
    # `weights_file` keeps the SRResNet export path set by the SRResNet export cell.
    weights_file = "weights/srganGenerator.h5"

In [10]:
generator = generatorNet()
print(len(generator.layers))
generator.summary()
# Initialise from the exported weights so SRGAN training starts from the trained
# SRResNet (or, with fine_tune=True, from the previous SRGAN export) instead of from
# random weights. Guarded so a run without an exported file still proceeds.
# A restored SRGAN checkpoint further down overrides these weights.
if os.path.exists(weights_file):
    generator.load_weights(weights_file)
else:
    print(f"No weights found at {weights_file}; the generator starts from random initialization.")

111
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, None, None,  0                                            
__________________________________________________________________________________________________
lambda (Lambda)                 (None, None, None, 3 0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, None, None, 6 15616       lambda[0][0]                     
__________________________________________________________________________________________________
p_re_lu (PReLU)                 (None, None, None, 6 64          conv2d[0][0]                     
__________________________________________________________________________________________

In [None]:
# Discriminator network, configured with the HR training crop size.
discriminator = discriminatorNet(hr_crop_size=hr_crop_size)

In [None]:
# Truncated VGG19 used as the feature extractor for the content loss.
# Index 20 of the headless Keras VGG19 is `block5_conv4` (output after its ReLU).
layer_5_4 = 20
vgg = VGG19(input_shape=(None, None, 3), include_top=False)
perceptual_model = Model(inputs=vgg.input, outputs=vgg.layers[layer_5_4].output)

In [None]:
# Loss objects shared by the SRGAN loss helpers.
binary_cross_entropy = tf.keras.losses.BinaryCrossentropy()
mean_squared_error = tf.keras.losses.MeanSquaredError()

# Learning rate 1e-4 up to optimizer step 100000, 1e-5 afterwards.
learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[100000],
    values=[1e-4, 1e-5],
)

# Two independent Adam instances (same schedule) for generator and discriminator.
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=learning_rate) for _ in range(2)
)

In [None]:
srgan_checkpoint_dir = './ckpt/srgan_bicubic_x4'

# Everything needed to resume SRGAN training: step/PSNR counters,
# both optimizers and both networks.
srgan_checkpoint = tf.train.Checkpoint(
    step=tf.Variable(0),
    psnr=tf.Variable(0.0),
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

# Keep only the three most recent SRGAN checkpoints on disk.
srgan_checkpoint_manager = tf.train.CheckpointManager(
    checkpoint=srgan_checkpoint,
    directory=srgan_checkpoint_dir,
    max_to_keep=3,
)

In [None]:
# Resume from the newest SRGAN checkpoint when one exists.
if srgan_checkpoint_manager.latest_checkpoint is not None:
    # Restores counters, both optimizers and both networks in one call.
    srgan_checkpoint.restore(srgan_checkpoint_manager.latest_checkpoint)
    print(f'Model restored from checkpoint at step {srgan_checkpoint.step.numpy()} with validation PSNR {srgan_checkpoint.psnr.numpy()}.')

In [None]:
@tf.function
def train_step(lr, hr):
    """Perform one simultaneous generator/discriminator update on a batch.

    Args:
        lr: low-resolution input batch; cast to float32 before use.
        hr: matching high-resolution target batch; cast to float32 before use.

    Returns:
        A ``(perceptual_loss, discriminator_loss)`` pair of scalar tensors.
    """
    gen_net = srgan_checkpoint.generator
    disc_net = srgan_checkpoint.discriminator

    # One tape per network; each is differentiated exactly once below.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        lr_batch = tf.cast(lr, tf.float32)
        hr_batch = tf.cast(hr, tf.float32)

        sr_batch = gen_net(lr_batch, training=True)

        # Discriminator outputs for real (HR) and generated (SR) images.
        real_scores = disc_net(hr_batch, training=True)
        fake_scores = disc_net(sr_batch, training=True)

        # Perceptual loss = VGG content loss + 1e-3 * adversarial loss.
        content_term = calculate_content_loss(hr_batch, sr_batch)
        adversarial_term = calculate_generator_loss(fake_scores)
        perc_loss = content_term + 0.001 * adversarial_term
        disc_loss = calculate_discriminator_loss(real_scores, fake_scores)

    gen_vars = gen_net.trainable_variables
    disc_vars = disc_net.trainable_variables
    gen_grads = gen_tape.gradient(perc_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    return perc_loss, disc_loss

@tf.function
def calculate_content_loss(hr, sr):
    """VGG19 feature-space MSE between the HR target and the SR output.

    Both batches go through ``preprocess_input`` and ``perceptual_model``; the
    feature maps are divided by 12.75 before the mean squared error.
    NOTE(review): ``preprocess_input`` expects pixel values in [0, 255]; this assumes
    the generator outputs that range — confirm in models.generatorNet.
    """
    scale = 12.75  # presumably the SRGAN paper's 1/12.75 feature rescaling
    sr_features = perceptual_model(preprocess_input(sr)) / scale
    hr_features = perceptual_model(preprocess_input(hr)) / scale
    return mean_squared_error(hr_features, sr_features)

def calculate_generator_loss(sr_out):
    """Adversarial generator loss.

    Binary cross-entropy of the discriminator output on SR images against an
    all-ones target (the label used for real HR images), so the generator is
    rewarded when its images are scored as real.
    """
    real_targets = tf.ones_like(sr_out)
    return binary_cross_entropy(real_targets, sr_out)

def calculate_discriminator_loss(hr_out, sr_out):
    """Discriminator loss.

    Sum of binary cross-entropy with target 1 on real (HR) outputs and
    binary cross-entropy with target 0 on generated (SR) outputs.
    """
    return (binary_cross_entropy(tf.ones_like(hr_out), hr_out)
            + binary_cross_entropy(tf.zeros_like(sr_out), sr_out))

In [None]:
# Running means of the training losses; reset after every report.
perceptual_loss_metric = tf.keras.metrics.Mean()
discriminator_loss_metric = tf.keras.metrics.Mean()

step = srgan_checkpoint.step.numpy()
steps = 200000

monitor_folder = "monitor_training/srgan_bicubic_x4"
os.makedirs(monitor_folder, exist_ok=True)

now = time.perf_counter()

# One entry per evaluation: the mean validation PSNR at that point in training.
psnr_values = []
# Clamp at zero: a negative count would make `take` yield the whole dataset
# when the checkpointed step is already past `steps`.
for lr, hr in train_dataset.take(max(0, steps - step)):
    srgan_checkpoint.step.assign_add(1)
    step = srgan_checkpoint.step.numpy()

    perceptual_loss, discriminator_loss = train_step(lr, hr)
    perceptual_loss_metric(perceptual_loss)
    discriminator_loss_metric(discriminator_loss)

    if step % 1000 == 0:
        # PSNR over the validation subset for THIS evaluation only. Previously the
        # mean was taken over every value recorded since training started.
        eval_psnrs = []
        for val_lr, val_hr in valid_dataset_subset:
            sr = srgan_checkpoint.generator.predict(val_lr)[0]
            sr = tf.clip_by_value(sr, 0, 255)
            sr = tf.round(sr)
            sr = tf.cast(sr, tf.uint8)
            eval_psnrs.append(custom_psnr(val_hr, sr)[0])

        psnr = float(tf.reduce_mean(eval_psnrs))
        psnr_values.append(psnr)

        # Save the last super-resolved validation image for visual monitoring.
        image = Image.fromarray(sr.numpy())
        image.save(f"{monitor_folder}/{step}.png")

        duration = time.perf_counter() - now
        now = time.perf_counter()

        print(f'{step}/{steps}, psnr = {psnr:.4f}, perceptual loss = {perceptual_loss_metric.result():.4f}, discriminator loss = {discriminator_loss_metric.result():.4f} ({duration:.2f}s)')

        perceptual_loss_metric.reset_states()
        discriminator_loss_metric.reset_states()

        srgan_checkpoint.psnr.assign(psnr)
        srgan_checkpoint_manager.save()

In [None]:
# Export the final SRGAN generator weights (HDF5).
weights_directory = "weights"
weights_file = f'{weights_directory}/srganGenerator.h5'
os.makedirs(weights_directory, exist_ok=True)
srgan_checkpoint.generator.save_weights(weights_file)

In [None]:
# Validation PSNR values collected during SRGAN training, in recording order.
fig, ax = plt.subplots()
ax.plot(range(len(psnr_values)), psnr_values)
ax.set(title="Validation PSNR during SRGAN training", xlabel="Record index", ylabel="PSNR")
plt.show()