<h2>SRGAN implementation for super-resolution image generation</h2>

The libraries required to run the notebook are imported.

In [None]:
# coding=utf-8
# srgan.ipynb
# ---------
# Licensing Information:  You are free to use or extend these projects for
# educational purposes provided that (1) you do not distribute or publish
# solutions, (2) you retain this notice, and (3) you provide clear
# attribution to UC Berkeley, including a link to http://ai.berkeley.edu and Pablo Doñate.
#
# Using Machine Learning techniques for image enhancement.
# This file has been created by Pablo Doñate Navarro (800710@unizar.es).

import os
import time
import tensorflow as tf
import keras
import numpy as np
import pandas as pd
from keras.optimizers import Adam
from keras.losses import MeanSquaredError, BinaryCrossentropy
from keras.applications.vgg19 import VGG19, preprocess_input
from keras.models import Model
from keras.metrics import Mean
from PIL import Image
import matplotlib.pyplot as plt

from Utils.metrics import rmse_metric
from Utils.callbacks import SaveCustomCheckpoint
from Dataset.dataset_config import DatasetConfig
from Dataset.dataset_loader import create_training_and_validation_datasets
from Dataset.dataset_mappings import random_crop, random_flip, random_rotate, random_lr_jpeg_noise
from Modelos.generator import build_generator
from Modelos.discriminator import build_discriminator

<h3>Dataset preparation:</h3>

In [2]:
# Name of the dataset to use.
dataset_name = "chest_x-ray"

# Location of the dataset folder.
dataset_folder = os.path.abspath(os.path.join(os.getcwd(), "Dataset", dataset_name))

# Location of the training results folder.
training_results_folder = f"training_results/srgan_{dataset_name}"

# Define the dataset parameters.
dataset_parameters = DatasetConfig(dataset_name, save_data_directory=dataset_folder)

Some hyperparameters are defined.

In [3]:
# Size of the HR crops.
crop_size = 48

""" 
Batch size.
      Number of LR images processed simultaneously during training.
      Larger batches generally speed up training, but need more memory and computation. 
"""
batch_size = 16

""" 
Learning rate.
      Step size with which the weights are updated at each iteration.
      It controls how quickly the neural network converges towards the optimal solution during training.
      A learning rate that is too high can cause oscillations or divergence during training,
      while one that is too low can make the model take too long to converge.
"""
learning_rate=1e-4

# Number of training iterations for the combined model.
steps = 200_000

# Number of bits per pixel of the input images.
num_bits_img = 8

The mapping functions applied to the dataset to pre-process the data are defined.

In [4]:
train_mapping = [
    lambda lr, hr: random_crop(lr, hr, hr_crop_size=crop_size, scale=dataset_parameters.scale),
    random_flip,
    random_rotate]
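
The paired crop is the step where `crop_size` and the scale factor interact: the LR patch has to cover the same region as the HR patch, only `scale` times smaller. The real `random_crop` lives in `Dataset.dataset_mappings` and is not shown here, so the following is only a minimal NumPy sketch of the idea. The function name `paired_random_crop`, the scale of 4 and the toy images are assumptions made for illustration.

```python
import numpy as np

def paired_random_crop(lr_img, hr_img, hr_crop_size=48, scale=4, rng=None):
    """Hypothetical stand-in for random_crop: returns matching LR/HR patches."""
    if rng is None:
        rng = np.random.default_rng()
    # The LR patch is `scale` times smaller than the HR patch.
    lr_crop_size = hr_crop_size // scale
    lr_h, lr_w = lr_img.shape[:2]
    # Pick the top-left corner in LR coordinates.
    lr_top = int(rng.integers(0, lr_h - lr_crop_size + 1))
    lr_left = int(rng.integers(0, lr_w - lr_crop_size + 1))
    # The same region in HR coordinates is obtained by multiplying by the scale.
    hr_top, hr_left = lr_top * scale, lr_left * scale
    lr_patch = lr_img[lr_top:lr_top + lr_crop_size, lr_left:lr_left + lr_crop_size]
    hr_patch = hr_img[hr_top:hr_top + hr_crop_size, hr_left:hr_left + hr_crop_size]
    return lr_patch, hr_patch

# Toy grayscale pair: a 256x256 HR image and its 64x64 LR counterpart (assumed scale of 4).
hr_image = np.arange(256 * 256, dtype=np.float32).reshape(256, 256, 1)
lr_image = hr_image[::4, ::4]

lr_patch, hr_patch = paired_random_crop(
    lr_image, hr_image, hr_crop_size=48, scale=4, rng=np.random.default_rng(0)
)
print(lr_patch.shape, hr_patch.shape)
```

With an HR crop of 48 pixels and an assumed scale of 4, each LR patch fed to the generator would be 12x12 pixels.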

The training and validation datasets are created.

In [None]:
train_dataset, valid_dataset = create_training_and_validation_datasets(dataset_parameters, batch_size, train_mapping)

# Only the first 10 elements of the validation dataset are used.
valid_dataset_subset = valid_dataset.take(10)

The generator network is trained.

<h2>COMBINED MODEL</h2>

In [6]:
# Build the generator.
generator = build_generator(scale=dataset_parameters.scale, num_filters=64, num_residual_blocks=16)
generator.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, None, None,  0           []                               
                                 1)]                                                              
                                                                                                  
 lambda (Lambda)                (None, None, None,   0           ['input_1[0][0]']                
                                1)                                                                
                                                                                                  
 conv2d (Conv2D)                (None, None, None,   5248        ['lambda[0][0]']                 
                                64)                                                           

In [7]:
# Build the discriminator.
discriminator = build_discriminator(hr_crop_size=crop_size)
discriminator.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 48, 48, 1)]       0         
                                                                 
 lambda_4 (Lambda)           (None, 48, 48, 1)         0         
                                                                 
 conv2d_37 (Conv2D)          (None, 48, 48, 64)        640       
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 48, 48, 64)        0         
                                                                 
 conv2d_38 (Conv2D)          (None, 24, 24, 64)        36928     
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 24, 24, 64)        0         
                                                                 
 batch_normalization_33 (Bat  (None, 24, 24, 64)       256 

The VGG19 neural network is defined.

In [8]:
""" 
VGG19
      Pretrained neural network used as a feature extractor to compute the content loss during 
      model training, since comparing the two images pixel by pixel may not be sufficient.
"""
layer_5_4 = 20
# The images have a single channel, so it is replicated three times to match the RGB input required by the ImageNet weights.
img_input = keras.layers.Input(shape=(None, None, 1))
img_conc = keras.layers.Concatenate()([img_input, img_input, img_input]) 
vgg = VGG19(weights='imagenet', input_tensor=img_conc, include_top=False)
perceptual_model = Model(vgg.input, vgg.layers[layer_5_4].output)
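
The `Concatenate` layer above joins three copies of the grayscale input along the last axis, which is its default. As a rough illustration of what that does to a batch, here is the same operation in plain NumPy; the batch size and the random pixel values are placeholders, not data from this notebook.

```python
import numpy as np

# Placeholder batch of two 48x48 single-channel images with values in [0, 255].
gray_batch = np.random.default_rng(0).uniform(0, 255, size=(2, 48, 48, 1)).astype(np.float32)

# Replicate the single channel three times along the channel axis.
rgb_like = np.concatenate([gray_batch, gray_batch, gray_batch], axis=-1)

print(gray_batch.shape, "->", rgb_like.shape)
```

Every channel of the result holds the same values, so the network sees a gray image expressed in RGB form.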

In [9]:
"""
Binary cross-entropy
    Loss function that measures how far the discriminator's predictions are from the real (1) / generated (0) labels.
"""
binary_cross_entropy = BinaryCrossentropy()

""" 
Mean squared error
    Loss function that measures the mean squared error between the model's predictions 
    and the true values. 
"""
mean_squared_error = MeanSquaredError()

In [10]:
""" 
Learning rate
    The learning rate schedule is defined: after iteration 100,000 
    the learning rate drops from 1e-4 to 1e-5. This replaces the constant value defined earlier. 
"""
learning_rate=tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=[100000], values=[1e-4, 1e-5])
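
To make the schedule concrete, the lookup that `PiecewiseConstantDecay` performs can be sketched in plain Python: the first value applies while the step is at or below the boundary, and the second value applies afterwards. The helper name `piecewise_constant` is hypothetical and is not part of TensorFlow.

```python
from bisect import bisect_left

def piecewise_constant(step, boundaries=(100_000,), values=(1e-4, 1e-5)):
    """Hypothetical helper: learning rate in effect at a given training step."""
    # bisect_left counts how many boundaries are strictly below `step`,
    # so a step equal to a boundary still uses the earlier value.
    return values[bisect_left(boundaries, step)]

for step in (0, 100_000, 100_001, 200_000):
    print(step, piecewise_constant(step))
```

Under this schedule, the first half of the 200,000 iterations runs at 1e-4 and the second half at 1e-5.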

In [11]:
""" 
Optimizers
        The optimizers for the generator and the discriminator are defined.
        Adam is used for both. 
"""
generator_optimizer = Adam(learning_rate=learning_rate)
discriminator_optimizer = Adam(learning_rate=learning_rate)

The checkpoints that periodically save the training state of the model (generator and discriminator) are defined.

In [12]:
srgan_checkpoint_dir=f'checkpoints/srgan_{dataset_name}'

srgan_checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                       generator_optimizer=generator_optimizer,
                                       discriminator_optimizer=discriminator_optimizer,
                                       generator=generator,
                                       discriminator=discriminator)

srgan_checkpoint_manager = tf.train.CheckpointManager(checkpoint=srgan_checkpoint,
                                                directory=srgan_checkpoint_dir,
                                                max_to_keep=3)

In [13]:
if srgan_checkpoint_manager.latest_checkpoint:
    srgan_checkpoint.restore(srgan_checkpoint_manager.latest_checkpoint)
    print(f'Model restored from checkpoint at step {srgan_checkpoint.step.numpy()}.')

In [14]:
@tf.function
def train_step(lr, hr):
    # Compute the gradients of the generator and discriminator loss functions.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        
        lr = tf.cast(lr, tf.float32)
        hr = tf.cast(hr, tf.float32)
        
        sr = srgan_checkpoint.generator(lr, training=True)

        hr_output = srgan_checkpoint.discriminator(hr, training=True)
        sr_output = srgan_checkpoint.discriminator(sr, training=True)

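        # Content loss: MSE between the VGG19 feature maps of the HR and SR images.
        mse = calculate_content_loss(hr, sr)
        # Adversarial loss: low when the discriminator labels SR images as real.
        gen_loss = calculate_generator_loss(sr_output)
        # Perceptual loss: content loss plus the adversarial loss weighted by 0.001.
        perc_loss = mse + 0.001 * gen_loss
        # Discriminator loss: real images should be classified as 1 and generated images as 0.
        hr_loss, sr_loss = calculate_discriminator_loss(hr_output, sr_output)
        disc_loss = hr_loss + sr_loss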

    gradients_of_generator = gen_tape.gradient(perc_loss, srgan_checkpoint.generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, srgan_checkpoint.discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, srgan_checkpoint.generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, srgan_checkpoint.discriminator.trainable_variables))

    return perc_loss, hr_loss, sr_loss, disc_loss

@tf.function
def calculate_content_loss(hr, sr):
    sr_features = perceptual_model(sr)
    hr_features = perceptual_model(hr)
    return mean_squared_error(hr_features, sr_features)

def calculate_generator_loss(sr_out):
    return binary_cross_entropy(tf.ones_like(sr_out), sr_out)

def calculate_discriminator_loss(hr_out, sr_out):
    hr_loss = binary_cross_entropy(tf.ones_like(hr_out), hr_out)
    sr_loss = binary_cross_entropy(tf.zeros_like(sr_out), sr_out)
    return hr_loss, sr_loss
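
How these losses behave is easier to see with concrete numbers. The NumPy sketch below re-implements binary cross-entropy by hand and evaluates it on made-up discriminator outputs. It assumes the discriminator returns probabilities, which is what `BinaryCrossentropy()` expects with its default arguments. The scores and the content-loss value are invented for illustration.

```python
import numpy as np

def bce(labels, probs, eps=1e-7):
    """Hand-written binary cross-entropy on probabilities (illustrative only)."""
    probs = np.clip(probs, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-(labels * np.log(probs) + (1.0 - labels) * np.log(1.0 - probs))))

# Invented discriminator scores: confident and correct on both kinds of image.
hr_out = np.array([0.99, 0.98])  # scores for real HR images
sr_out = np.array([0.02, 0.01])  # scores for generated SR images

# Discriminator: real images labelled 1, generated images labelled 0.
hr_loss = bce(np.ones_like(hr_out), hr_out)
sr_loss = bce(np.zeros_like(sr_out), sr_out)
disc_loss = hr_loss + sr_loss

# Generator: wants the discriminator to output 1 for generated images.
gen_adv_loss = bce(np.ones_like(sr_out), sr_out)

# Perceptual loss with an invented content-loss value and the 0.001 weight from train_step.
content_loss = 150.0
perc_loss = content_loss + 0.001 * gen_adv_loss

print(f"disc_loss={disc_loss:.4f}, gen_adv_loss={gen_adv_loss:.4f}, perc_loss={perc_loss:.4f}")
```

When the discriminator is confident and correct, its own loss is close to zero while the adversarial term for the generator is large. Because of the 0.001 weight, the perceptual loss is still dominated by the content term.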

In [15]:
gen_loss_array = np.array([])
disc_loss_array = np.array([])
hr_loss_array = np.array([])
sr_loss_array = np.array([])
rmse_error_array = np.array([])
time_array = np.array([])

start_time = time.time()
step = srgan_checkpoint.step.numpy()

# Create the training results directory.
os.makedirs(training_results_folder, exist_ok=True)

for lr, hr in train_dataset.take(steps - step):

    srgan_checkpoint.step.assign_add(1)
    step = srgan_checkpoint.step.numpy()

    # Run one training iteration.
    perceptual_loss, hr_loss, sr_loss, discriminator_loss = train_step(lr, hr)

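    if step % 1000 == 0:

        # Run the generator on the validation subset; only the last pair is kept for the comparison below.
        for valid_lr, valid_hr in valid_dataset_subset:
            sr = srgan_checkpoint.generator.predict(valid_lr)[0]

        # Convert the HR reference to 8-bit before building the PIL image.
        valid_hr = tf.clip_by_value(valid_hr, 0, 255)
        valid_hr = tf.round(valid_hr)
        valid_hr = tf.cast(valid_hr, tf.uint8)
        image_hr = Image.fromarray(valid_hr.numpy().squeeze())

        if step == 1000:

            # Save the LR input and the HR reference once, at the first evaluation.
            valid_lr = tf.clip_by_value(valid_lr, 0, 255)
            valid_lr = tf.round(valid_lr)
            valid_lr = tf.cast(valid_lr, tf.uint8)

            image = Image.fromarray(valid_lr.numpy().squeeze())
            image.save(f"{training_results_folder}/low_res_image.jpeg")

            image_hr.save(f"{training_results_folder}/high_res_image.jpeg")

        # Convert the generated image to 8-bit and save it.
        sr = tf.clip_by_value(sr, 0, 255)
        sr = tf.round(sr)
        sr = tf.cast(sr, tf.uint8)

        image_sr = Image.fromarray(sr.numpy().squeeze())
        image_sr.save(f"{training_results_folder}/{step}.jpeg")

        rmse = rmse_metric(image_sr, image_hr)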

        train_time = time.time()
            
        rmse_error_array = np.append(rmse_error_array, rmse)
        disc_loss_array = np.append(disc_loss_array, discriminator_loss)
        hr_loss_array = np.append(hr_loss_array, hr_loss)
        sr_loss_array = np.append(sr_loss_array, sr_loss)
        gen_loss_array = np.append(gen_loss_array, perceptual_loss)
        current_time = train_time - start_time
        time_array = np.append(time_array, current_time)
        
        print(f'{step}/{steps}, perceptual loss = {perceptual_loss:.4f}, discriminator loss = {discriminator_loss:.4f}, RMSE = {rmse:.4f} ({current_time} s)')

        srgan_checkpoint_manager.save()

2023-06-04 20:44:41.315282: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8100
2023-06-04 20:44:42.493823: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:630] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2023-06-04 20:45:18.295030: W tensorflow/tsl/framework/bfc_allocator.cc:290] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.12GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2023-06-04 20:45:18.889949: W tensorflow/tsl/framework/bfc_allocator.cc:290] Allocator (GPU_0_bfc) ran out of memory trying to allocate 7.03GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2023-06-04 20:45:19.299057: W tensorflow/tsl/framework/bfc_allocator.cc:290] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.55GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2023-06-04 20:45:19.710762: W tensorflow/tsl/framework/bfc_allocator.cc:290] Allocator (GPU_0_bfc) ran out of memory trying to allocate 7.17GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2023-06-04 20:45:22.795004: W tensorflow/tsl/framework/bfc_allocator.cc:290] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.83GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.

1000/200000, perceptual loss = 444.0130, discriminator loss = 0.0000, RMSE = 0.2614 (45.63327693939209 s)
2000/200000, perceptual loss = 242.6172, discriminator loss = 0.0000, RMSE = 0.2953 (80.13908195495605 s)
3000/200000, perceptual loss = 185.3007, discriminator loss = 0.0000, RMSE = 0.3803 (114.86024379730225 s)
4000/200000, perceptual loss = 320.7060, discriminator loss = 0.0000, RMSE = 0.3600 (149.47192573547363 s)
5000/200000, perceptual loss = 235.2681, discriminator loss = 0.0000, RMSE = 0.2771 (184.0727355480194 s)
6000/200000, perceptual loss = 194.4716, discriminator loss = 0.1843, RMSE = 0.1712 (218.53023958206177 s)
7000/200000, perceptual loss = 139.7874, discriminator loss = 0.0000, RMSE = 0.1364 (252.99342846870422 s)
8000/200000, perceptual loss = 181.3584, discriminator loss = 0.0002, RMSE = 0.1457 (287.55118560791016 s)
9000/200000, perceptual loss = 181.2562, discriminator loss = 0.0000, RMSE = 0.1284 (322.13104486465454 s)
10000/200000, perceptual loss = 156.9196
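
The RMSE column in the log comes from `rmse_metric` in `Utils.metrics`, whose implementation is not shown in this notebook. Since the logged values lie between 0 and 1, one plausible reading is an RMSE computed on pixel values normalised by 255. The sketch below follows that assumption; the function name `rmse_normalized` and the test images are hypothetical.

```python
import numpy as np

def rmse_normalized(image_a, image_b, max_value=255.0):
    """Hypothetical RMSE on pixel values scaled to [0, 1]; accepts arrays or PIL images."""
    a = np.asarray(image_a, dtype=np.float64) / max_value
    b = np.asarray(image_b, dtype=np.float64) / max_value
    return float(np.sqrt(np.mean((a - b) ** 2)))

# Toy 8-bit images: an all-black and an all-white 48x48 patch.
black = np.zeros((48, 48), dtype=np.uint8)
white = np.full((48, 48), 255, dtype=np.uint8)

print(rmse_normalized(black, black))  # identical images
print(rmse_normalized(black, white))  # maximally different images
```

Under this assumption, 0 would mean a pixel-perfect reconstruction and 1 the largest possible error for 8-bit images.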

The weights of the generator network are saved.

In [16]:
weights_directory = f"weights/srgan_{dataset_name}"
os.makedirs(weights_directory, exist_ok=True)
weights_file = f'{weights_directory}/generador.h5'
srgan_checkpoint.generator.save_weights(weights_file)

The results are saved to a CSV file.

In [19]:
data = {
    'Time': time_array,
    'RMSE Error': rmse_error_array,
    'Generator Loss': gen_loss_array,
    'Discriminator Loss': disc_loss_array,
    'HR Loss': hr_loss_array,
    'SR Loss': sr_loss_array
}

df = pd.DataFrame(data)
df.to_csv('metrics.csv', index=False)