In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Conv2D, Conv2DTranspose, PReLU, LeakyReLU, Activation, BatchNormalization, Add
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam # Optimizer
from tensorflow.keras.applications.vgg19 import VGG19 # Perceptual loss
from tensorflow.keras.losses import MeanSquaredError # Loss function
from tensorflow.image import psnr, ssim # Evaluation metrics
from tensorflow.keras.utils import Progbar

2023-04-12 03:32:30.304465: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1


In [2]:
# Image shapes as (height, width, channels).
# The high-resolution height/width are 4x the low-resolution ones; the channel count is shared.
lr_shape = (255, 255, 3)
hr_shape = tuple(4 * side for side in lr_shape[:2]) + lr_shape[2:]

In [3]:
# VGG19 feature extractor used as the target space for the perceptual (content) loss
def build_vgg():
    """Return a model mapping an `hr_shape` image to an intermediate VGG19 feature map.

    The backbone is the ImageNet-pretrained VGG19 without its classifier head,
    truncated at the layer with index 10 of `layers`.
    NOTE(review): in the stock Keras VGG19 layer ordering index 10 should be
    `block3_conv4` - confirm against `vgg.summary()`.
    """
    backbone = VGG19(include_top=False, weights='imagenet', input_shape=hr_shape)
    feature_map = backbone.layers[10].output

    return Model(inputs=backbone.inputs, outputs=feature_map)

In [4]:
# Instantiate the truncated VGG19 feature extractor and print its layer table
vgg = build_vgg()
vgg.summary()

2023-04-12 03:32:36.876469: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2023-04-12 03:32:36.877812: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2023-04-12 03:32:36.955119: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 0 with properties: 
pciBusID: 0000:12:00.0 name: Tesla V100-PCIE-32GB computeCapability: 7.0
coreClock: 1.38GHz coreCount: 80 deviceMemorySize: 31.75GiB deviceMemoryBandwidth: 836.37GiB/s
2023-04-12 03:32:36.955697: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 1 with properties: 
pciBusID: 0000:13:00.0 name: Tesla V100-PCIE-32GB computeCapability: 7.0
coreClock: 1.38GHz coreCount: 80 deviceMemorySize: 31.75GiB deviceMemoryBandwidth: 836.37GiB/s
2023-04-12 03:32:36.956213: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1720] Found device 2 with properties: 
pciBusID: 0000:14:00.0 name: Tesl

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 1020, 1020, 3)]   0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 1020, 1020, 64)    1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 1020, 1020, 64)    36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 510, 510, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 510, 510, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 510, 510, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 255, 255, 128)     0     

In [5]:
# Residual block
def residual_block(x):
    """Apply Conv-BN-PReLU-Conv-BN to `x` and add the result back onto `x`.

    Both convolutions are 3x3, 64 filters, "same" padding, so `x` must already
    carry 64 channels for the element-wise sum to be valid.
    """
    shortcut = x

    branch = Conv2D(64, kernel_size=3, padding="same")(x)
    branch = BatchNormalization()(branch)
    # One learnable slope per channel: the PReLU parameter is shared over height and width
    branch = PReLU(shared_axes=[1, 2])(branch)

    branch = Conv2D(64, kernel_size=3, padding="same")(branch)
    branch = BatchNormalization()(branch)

    return Add()([shortcut, branch])

In [6]:
# Upscale the feature map 2x
def upscale_block(x):
    """Double the spatial size of `x` with a stride-2 transposed convolution.

    The transposed convolution uses 256 filters of size 3x3 with "same" padding
    and is followed by a PReLU whose slope is shared over height and width.
    """
    upsampled = Conv2DTranspose(256, kernel_size=3, strides=2, padding="same")(x)

    return PReLU(shared_axes=[1, 2])(upsampled)

In [7]:
# Number of residual blocks chained inside the generator (read by build_generator)
num_residual_block = 16

In [8]:
# Generator model
def build_generator():
    """Build the super-resolution generator: `lr_shape` image in, 4x larger 3-channel image out.

    Layout: 9x9 conv + PReLU head, `num_residual_block` residual blocks,
    3x3 conv + BN merged with the head output through a long skip connection,
    two 2x upscale blocks, and a final 9x9 conv (no activation) producing 3 channels.
    """
    low_res = Input(shape=lr_shape)

    # Head: its output is reused later for the long skip connection
    head = Conv2D(64, kernel_size=9, padding="same")(low_res)
    head = PReLU(shared_axes=[1, 2])(head)

    # Residual trunk
    trunk = head
    for _ in range(num_residual_block):
        trunk = residual_block(trunk)

    trunk = Conv2D(64, kernel_size=3, padding="same")(trunk)
    trunk = BatchNormalization()(trunk)
    merged = Add()([trunk, head])

    # Two consecutive 2x upscales give the overall 4x enlargement
    upsampled = merged
    for _ in range(2):
        upsampled = upscale_block(upsampled)

    high_res = Conv2D(3, kernel_size=9, padding="same")(upsampled)

    return Model(inputs=low_res, outputs=high_res)

In [9]:
# Instantiate the generator and print its layer table
gen = build_generator()
gen.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 255, 255, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 255, 255, 64) 15616       input_2[0][0]                    
__________________________________________________________________________________________________
p_re_lu (PReLU)                 (None, 255, 255, 64) 64          conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 255, 255, 64) 36928       p_re_lu[0][0]                    
____________________________________________________________________________________________

In [10]:
# Discriminator model
def build_discriminator():
    """Build the discriminator: `hr_shape` image in, one sigmoid score out.

    Layout: a 3x3 conv + LeakyReLU stem (no batch norm), seven
    Conv -> LeakyReLU -> BatchNorm stages that alternate stride 2 / stride 1
    while widening from 64 to 512 filters, then Flatten, Dense(512) + LeakyReLU
    and a single sigmoid unit.
    """
    hr_input = Input(shape=hr_shape)

    # Stem: the only convolution stage without batch normalization
    features = Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(hr_input)
    features = LeakyReLU(alpha=0.2)(features)

    # (filters, stride) for the remaining seven convolution stages, in order
    conv_plan = [
        (64, 2),
        (128, 1),
        (128, 2),
        (256, 1),
        (256, 2),
        (512, 1),
        (512, 2),
    ]
    for n_filters, stride in conv_plan:
        features = Conv2D(filters=n_filters, kernel_size=3, strides=stride, padding='same')(features)
        features = LeakyReLU(alpha=0.2)(features)
        features = BatchNormalization()(features)

    # Classifier head
    flat = Flatten()(features)
    hidden = Dense(512)(flat)
    hidden = LeakyReLU(alpha=0.2)(hidden)
    score = Dense(1, activation='sigmoid')(hidden)

    return Model(inputs=hr_input, outputs=score)

In [11]:
# Instantiate the discriminator and print its layer table
dis = build_discriminator()
dis.summary()

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 1020, 1020, 3)]   0         
_________________________________________________________________
conv2d_35 (Conv2D)           (None, 1020, 1020, 64)    1792      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 1020, 1020, 64)    0         
_________________________________________________________________
conv2d_36 (Conv2D)           (None, 510, 510, 64)      36928     
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 510, 510, 64)      0         
_________________________________________________________________
batch_normalization_33 (Batc (None, 510, 510, 64)      256       
_________________________________________________________________
conv2d_37 (Conv2D)           (None, 510, 510, 128)     7385

In [12]:
# Combined SRGAN model (generator followed by frozen discriminator and frozen VGG)
def build_srgan(gen, dis, vgg):
    """Chain `gen` into `dis` and `vgg` so the generator can be trained end to end.

    Args:
        gen: generator model, called on the low-resolution input.
        dis: discriminator model; its `trainable` flag is set to False here.
        vgg: feature extractor model; its `trainable` flag is set to False here.

    Returns:
        A Model with inputs `[lr_input, hr_input]` and outputs
        `[discriminator score of the generated image, VGG features of the generated image]`.
        `hr_input` is declared as a model input but is not connected to any layer
        in this function; only `lr_input` feeds the graph.
    """
    # Side effect on the caller's models: both are frozen from here on
    for frozen_model in (dis, vgg):
        frozen_model.trainable = False

    lr_input = Input(shape=lr_shape)
    hr_input = Input(shape=hr_shape)

    # Super-resolved image and its perceptual features
    sr_image = gen(lr_input)
    sr_features = vgg(sr_image)

    # Discriminator verdict on the super-resolved image
    sr_validity = dis(sr_image)

    return Model(inputs=[lr_input, hr_input], outputs=[sr_validity, sr_features])

In [13]:
# Assemble the combined model; note that build_srgan() also freezes `dis` and `vgg`
srg = build_srgan(gen, dis, vgg)
srg.summary()

Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            [(None, 255, 255, 3) 0                                            
__________________________________________________________________________________________________
model_1 (Functional)            (None, 1020, 1020, 3 2044291     input_4[0][0]                    
__________________________________________________________________________________________________
input_5 (InputLayer)            [(None, 1020, 1020,  0                                            
__________________________________________________________________________________________________
model_2 (Functional)            (None, 1)            1078435649  model_1[0][0]                    
____________________________________________________________________________________________

In [14]:
# Determine evaluation metrics

# Dynamic range handed to psnr/ssim as `max_val`; 1.0 matches the /255.0 scaling
# applied to the image arrays when they are loaded
max_val = 1.0

# Peak signal-to-noise ratio (PSNR)
def PSNR(y_true, y_pred):
    """Return `tensorflow.image.psnr` of the two image tensors, using the module-level `max_val`."""
    score = psnr(y_true, y_pred, max_val=max_val)
    return score

# Structural similarity index measure (SSIM)
def SSIM(y_true, y_pred):
    """Return `tensorflow.image.ssim` of the two image tensors, using the module-level `max_val`."""
    score = ssim(y_true, y_pred, max_val=max_val)
    return score

### Load images for training and evaluation

In [15]:
def _load_unit_range(npy_path):
    """Load the array stored at `npy_path` and divide it by 255.0.

    Presumably the files hold 8-bit pixel values, which this maps to [0, 1];
    the stored dtype is not visible here - TODO confirm.
    """
    return np.load(npy_path) / 255.0


# Training pairs: low-resolution inputs and their high-resolution targets
lr_images_train = _load_unit_range('../Datasets/lr_images_bi_train.npy')
hr_images_train = _load_unit_range('../Datasets/hr_images_train.npy')

# Validation pairs
lr_images_val = _load_unit_range('../Datasets/lr_images_bi_val.npy')
hr_images_val = _load_unit_range('../Datasets/hr_images_val.npy')

### Determine parameters

In [16]:
# Run identifiers: model name, dataset variant and attempt number `n`.
# All three are interpolated into the checkpoint file name when the generator is saved.
model_name = 'gen'
data_type = 'bi'
n = 8

# Training schedule: images per step and number of passes over the training loop
batch_size = 1
epochs = 25

In [17]:
# Pre-compute the VGG targets for the validation set once, outside the training loop.
# The set is pushed through VGG in two slices (first 50 images, then the rest) -
# presumably to bound peak memory; TODO confirm.
val_features_1 = vgg.predict(hr_images_val[:50], batch_size=batch_size)
val_features_2 = vgg.predict(hr_images_val[50:], batch_size=batch_size)
val_features = np.concatenate([val_features_1, val_features_2], axis=0)

# One "real" label (1.0) per validation image, used as the adversarial target
val_real_labels = np.ones((len(val_features), 1))

2023-04-12 03:33:02.030768: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2023-04-12 03:33:02.043304: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2500000000 Hz
2023-04-12 03:33:02.183372: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.7
2023-04-12 03:33:03.741631: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.10


In [18]:
# Compile order matters: tf.keras records every layer's `trainable` flag when
# compile() is called and uses that recorded state in train_on_batch()/fit().
# build_srgan() has already set dis.trainable = False, so the discriminator must
# be unfrozen before its own compile; otherwise the standalone discriminator is
# compiled with no trainable weights and dis.train_on_batch() never updates it.
dis.trainable = True
dis.compile(optimizer=Adam(), loss='binary_crossentropy')

# Freeze the discriminator again before compiling the combined model, so that
# training `srg` only updates the generator.
dis.trainable = False
# Outputs are [discriminator score, VGG features]: BCE weighted 1e-3 plus MSE weighted 1.
# One 'accuracy' metric per output is kept because the training loop unpacks five values.
srg.compile(optimizer=Adam(), loss=['binary_crossentropy','mse'], loss_weights=[1e-3, 1], metrics=[['accuracy'], ['accuracy']])

In [19]:
# Train the models
steps_per_epoch = lr_images_train.shape[0] // batch_size

# The adversarial targets are identical for every step, so build them once
real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))

for epoch in range(epochs):

    print('Epoch %d/%d' % (epoch+1, epochs))
    progbar = Progbar(steps_per_epoch)

    for i in range(steps_per_epoch):
        # Randomly sample a batch of image pairs (indices are drawn with replacement)
        idx = np.random.randint(0, lr_images_train.shape[0], batch_size)
        lr_batch = lr_images_train[idx]
        hr_batch = hr_images_train[idx]

        # Generate a batch of super-resolved images with the current generator
        sr_batch = gen.predict_on_batch(lr_batch)

        # Train the discriminator: real images -> 1, generated images -> 0.
        # The trainable toggles are kept from the original loop. In tf.keras 2.x the
        # flags recorded at compile() time decide what train_on_batch() updates, so
        # the toggles alone do not make the discriminator trainable.
        # NOTE(review): confirm against the installed Keras version.
        dis.trainable = True
        d_loss_real = dis.train_on_batch(hr_batch, real_labels)
        d_loss_fake = dis.train_on_batch(sr_batch, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        dis.trainable = False

        # Train the generator through the combined model.
        # predict_on_batch (as used for `gen` above) runs a single forward pass and
        # avoids the per-call setup work of predict() inside this per-step loop.
        hr_batch_features = vgg.predict_on_batch(hr_batch)
        # Returned values: total loss, adversarial BCE, content MSE, then one
        # accuracy per output, matching the losses/metrics given to srg.compile()
        srg_loss, d_loss_s, g_loss, d_acc, g_acc = srg.train_on_batch([lr_batch, hr_batch], [real_labels, hr_batch_features])

        # Update the progress bar
        progbar.update(i+1, [('Dis', d_loss), ('Gen', g_loss), ('Dis Acc', d_acc), ('Gen Acc', g_acc)])

    # Evaluate the combined model on the validation dataset
    srg_loss_val, d_loss_s_val, g_loss_val, d_acc_val, g_acc_val = srg.evaluate([lr_images_val, hr_images_val], [val_real_labels, val_features], batch_size=batch_size)

    # Save the generator after every epoch
    gen.save('../model/%s%d_%s_%02dof%d.h5' % (model_name, n, data_type, epoch+1, epochs))

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
148/800 [====>.........................] - ETA: 11:43 - Dis: 0.6933 - Gen: 12.2178 - Dis Acc: 1.0000 - Gen Acc: 0.8539

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25


In [22]:
# Discriminator scores for the super-resolved validation images (first five shown).
# NOTE(review): `sr_images_val` is not assigned in any cell visible above (execution
# counts jump from 19 to 22) - presumably it is gen.predict(lr_images_val); confirm
# it is defined before this cell so that Restart & Run All works.
dis_fake_pred = dis.predict(sr_images_val, batch_size=batch_size)
dis_fake_pred[0:5]

array([[0.50186193],
       [0.5010398 ],
       [0.50284064],
       [0.5023785 ],
       [0.5033225 ]], dtype=float32)

In [23]:
# Discriminator scores for the ground-truth high-resolution validation images
# (first five shown), for comparison with the scores of the generated images
dis_real_pred = dis.predict(hr_images_val, batch_size=batch_size)
dis_real_pred[0:5]

array([[0.5015513 ],
       [0.5007172 ],
       [0.5024384 ],
       [0.50227743],
       [0.503248  ]], dtype=float32)