In [16]:
import os

import keras
import numpy as np
from keras import layers
from tensorflow.nn import depth_to_space

from load_data import load_data

In [2]:
# Directory of the installed `tensorflow_datasets` package, handed to the
# project-local `load_data` helper. Set the TFDS_PACKAGE_PATH environment
# variable to run this notebook on another machine; the fallback below is the
# path the notebook was originally written against.
DEFAULT_PACKAGE_PATH = r"C:\Users\nedst\Desktop\synoptic-project-NedStickler\.venv\Lib\site-packages\tensorflow_datasets"
package_path = os.environ.get("TFDS_PACKAGE_PATH", DEFAULT_PACKAGE_PATH)

# Cap on how many entries (along the leading axis) are kept for this run.
N_SAMPLES = 2048

# `load_data` returns a pair; only the first element is used here.
dataset, _ = load_data(package_path)
dataset = dataset[:N_SAMPLES, :, :, :]

In [3]:
# Number of residual blocks; presumably the value passed to
# `generator(residual_blocks)` in a later cell (call not visible here - TODO confirm).
residual_blocks = 5

In [20]:
def d_residual_block(x, n_filters, n_strides):
    """Apply one discriminator stage: 3x3 convolution, batch norm, LeakyReLU.

    Despite the name there is no skip connection; the three layers are simply
    applied in sequence.

    Args:
        x: Input feature-map tensor.
        n_filters: Number of output channels of the convolution.
        n_strides: Stride of the convolution (2 halves the spatial size).

    Returns:
        The activated output tensor.
    """
    stages = (
        layers.Conv2D(n_filters, kernel_size=3, strides=n_strides, padding="same"),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.2),
    )
    for stage in stages:
        x = stage(x)
    return x

def d_downsample_pair(x, n_filters):
    """Run two discriminator stages back to back: stride 1, then stride 2.

    Args:
        x: Input feature-map tensor.
        n_filters: Number of channels used by both stages.

    Returns:
        The output tensor of the second (stride-2) stage.
    """
    for stride in (1, 2):
        x = d_residual_block(x, n_filters, stride)
    return x

def g_residual_block(x_in):
    """Generator residual block: Conv-BN-PReLU-Conv-BN plus an identity skip.

    Args:
        x_in: Input feature-map tensor. It is added element-wise to the output
            of a 64-filter convolution branch, so it must itself have 64
            channels.

    Returns:
        The sum of `x_in` and the convolution branch output.
    """
    x = layers.Conv2D(64, kernel_size=3, padding="same")(x_in)
    x = layers.BatchNormalization()(x)
    # shared_axes=[1, 2] shares the PReLU slope across height and width.
    x = layers.PReLU(shared_axes=[1, 2])(x)
    x = layers.Conv2D(64, kernel_size=3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    # Merge layers take a single list of tensors, not separate positional
    # arguments.
    x = layers.Add()([x_in, x])
    return x

def discriminator(input_shape=(None, None, 3)):
    """Build the discriminator: a convolutional real/fake image classifier.

    Args:
        input_shape: `(height, width, channels)` of the images to classify.
            Height and width may be `None` for variable-size input. In that
            case the final feature map is reduced with global average pooling,
            because `Flatten` followed by `Dense` needs a statically known
            feature size and cannot be built otherwise. With a fully static
            shape the feature map is flattened instead.

    Returns:
        A `keras.Model` mapping an image batch to one sigmoid probability per
        image.
    """
    # HR/SR input
    inputs = layers.Input(input_shape)
    # x / 127.5 - 1: maps 0..255 pixel values to -1..1 (assuming 0-255 input).
    x = layers.Rescaling(scale=1/127.5, offset=-1)(inputs)

    # First convolution block (no batch norm)
    x = layers.Conv2D(64, kernel_size=3, padding="same")(x)
    x = layers.LeakyReLU(0.2)(x)

    # Downsampling stages: four stride-2 convolutions in total, with the
    # channel count growing 64 -> 128 -> 256 -> 512.
    x = d_residual_block(x, 64, 2)
    x = d_downsample_pair(x, 128)
    x = d_downsample_pair(x, 256)
    x = d_downsample_pair(x, 512)

    # Collapse the feature map to a vector and classify.
    spatial_dims = tuple(input_shape)[:2]
    if None in spatial_dims:
        # Unknown spatial size: pooling yields a fixed 512-long vector.
        x = layers.GlobalAveragePooling2D()(x)
    else:
        x = layers.Flatten()(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Dense(1, activation="sigmoid")(x)

    return keras.Model(inputs, x)

def generator(residual_blocks):
    """Build the generator: a residual network with 4x upscaling.

    Args:
        residual_blocks: Number of `g_residual_block` stages in the trunk.

    Returns:
        A `keras.Model` mapping a 3-channel image batch of any spatial size to
        a 3-channel batch with four times the height and width.
    """
    # LR input
    inputs = layers.Input((None, None, 3))
    # Scale by 1/255 (maps 0..255 pixel values to 0..1, assuming 0-255 input).
    x_in = layers.Rescaling(scale=1/255)(inputs)

    # First convolution; its activated output is kept as the long skip input.
    x_in = layers.Conv2D(64, kernel_size=3, padding="same")(x_in)
    x_in = x = layers.PReLU(shared_axes=[1, 2])(x_in)

    # Residual block set
    for _ in range(residual_blocks):
        x = g_residual_block(x)

    # Closing Conv-BN without activation, then the long skip connection.
    x = layers.Conv2D(64, kernel_size=3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    # Merge layers are *called* with one list of tensors (not subscripted).
    x = layers.Add()([x_in, x])

    # Two identical upscaling blocks. Each convolution emits 256 channels and
    # depth_to_space with block size 2 rearranges them into 64 channels at
    # twice the height and width. The TensorFlow op is wrapped in a Lambda
    # layer because Keras 3 rejects raw TensorFlow functions applied to
    # symbolic Keras tensors.
    for _ in range(2):
        x = layers.Conv2D(256, kernel_size=3, padding="same")(x)
        x = layers.Lambda(lambda t: depth_to_space(t, 2))(x)
        x = layers.PReLU(shared_axes=[1, 2])(x)

    # Final convolve back to 3 channels, then x * 127.5 + 127.5.
    # NOTE(review): that rescaling maps -1..1 to 0..255, but the convolution
    # has no activation, so its output is not bounded to -1..1 - confirm
    # whether a tanh was intended here.
    x = layers.Conv2D(3, kernel_size=3, padding="same")(x)
    x = layers.Rescaling(scale=127.5, offset=127.5)(x)

    return keras.Model(inputs, x)

In [21]:
class SRGAN(keras.Model):
    def __init__(self, discriminator, generator, vgg):
        """Store the three sub-models this wrapper operates on.

        Args:
            discriminator: Stored as `self.discriminator`.
            generator: Stored as `self.generator`.
            vgg: Stored as `self.vgg`; presumably a VGG feature extractor for
                a perceptual loss - its use is not visible here (TODO confirm).
        """
        # The base keras.Model must be initialised before attributes are set.
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.vgg = vgg
    
    def compile(self, d_optimiser, g_optimiser, bce_loss, mse_loss):
        """Attach the optimisers and loss callables to the model.

        The parent `compile()` is invoked with no arguments, so no optimiser
        or loss is passed to the base class; the four objects are only stored
        as attributes on this instance.

        Args:
            d_optimiser: Stored as `self.d_optimiser`.
            g_optimiser: Stored as `self.g_optimiser`.
            bce_loss: Stored as `self.bce_loss`.
            mse_loss: Stored as `self.mse_loss`.
        """
        super().compile()
        self.d_optimiser = d_optimiser
        self.g_optimiser = g_optimiser
        self.bce_loss = bce_loss
        self.mse_loss = mse_loss