In [None]:
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

# Configure the GPU(s): let TensorFlow grow its memory allocation on demand.
# Every visible device is configured, and the loop is simply a no-op on a
# CPU-only machine (indexing physical_devices[0] would raise IndexError there).
physical_devices = tf.config.list_physical_devices("GPU")
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

In [None]:
def load_vgg(hr_shape):
    """Build a frozen VGG19 feature extractor.

    Args:
        hr_shape: value forwarded to VGG19 as ``input_shape``.

    Returns:
        A ``keras.Model`` named ``'vgg'`` mapping the VGG19 input to the
        output of the layer at index 10 of the headless network.
    """
    backbone = keras.applications.VGG19(
        weights="imagenet",
        include_top=False,
        input_shape=hr_shape,
    )
    # The extractor is only used to compute features, never trained.
    backbone.trainable = False

    # NOTE(review): index 10 is presumably `block3_conv4` in the stock Keras
    # VGG19 layer list - confirm this is the intended feature depth.
    feature_map = backbone.layers[10].output
    return keras.Model(inputs=backbone.input, outputs=feature_map, name='vgg')

In [None]:
# Image shape with unspecified height/width and 3 channels.
# Presumably the value passed to load_vgg(hr_shape) - confirm against the caller.
hr_shape = (None, None, 3)

In [None]:
# B residual block
def res_block(_input):
    """Apply one residual block and return ``_input + branch(_input)``.

    The branch is Conv(k3n64s1) -> BatchNorm -> PReLU -> Conv(k3n64s1) ->
    BatchNorm; its result is added element-wise to the block input.
    """
    branch_layers = (
        layers.Conv2D(64, (3, 3), padding='same'),   # k3n64s1
        layers.BatchNormalization(momentum=0.5),
        layers.PReLU(shared_axes=[1, 2]),
        layers.Conv2D(64, (3, 3), padding='same'),   # k3n64s1
        layers.BatchNormalization(momentum=0.5),
    )

    branch = _input
    for layer in branch_layers:
        branch = layer(branch)

    # Skip connection around the whole branch.
    return layers.add([_input, branch])

def upscale_block(_input):
    """Double the spatial resolution of ``_input``.

    Applies Conv(k3n256s1), then ``UpSampling2D(size=2)``, then a PReLU whose
    parameters are shared across axes 1 and 2.
    """
    stages = (
        layers.Conv2D(256, (3, 3), padding='same'),  # k3n256s1
        layers.UpSampling2D(size=2),
        layers.PReLU(shared_axes=[1, 2]),
    )

    upscaled = _input
    for stage in stages:
        upscaled = stage(upscaled)
    return upscaled

def create_generator(_input, num_residual_blocks=16):
    """Assemble the generator network on top of the given input tensor.

    Args:
        _input: Keras input tensor the generator is built on.
        num_residual_blocks: how many ``res_block`` stages to chain.

    Returns:
        A ``keras.Model`` named ``'generator'`` whose output is a 3-channel
        tensor produced after two ``upscale_block`` stages.
    """
    # Stem: k9n64s1 + PReLU. Kept around for the long skip connection below.
    stem = layers.Conv2D(64, (9, 9), padding='same')(_input)
    stem = layers.PReLU(shared_axes=[1, 2])(stem)

    # Residual trunk.
    trunk = stem
    for _ in range(num_residual_blocks):
        trunk = res_block(trunk)

    # Post-residual k3n64s1 + BatchNorm, then add the stem back in.
    trunk = layers.Conv2D(64, (3, 3), padding='same')(trunk)
    trunk = layers.BatchNormalization(momentum=0.5)(trunk)
    trunk = layers.add([trunk, stem])

    # Two consecutive upscale stages.
    upsampled = trunk
    for _ in range(2):
        upsampled = upscale_block(upsampled)

    # Output projection: k9n3s1.
    sr_image = layers.Conv2D(3, (9, 9), padding='same')(upsampled)

    return keras.Model(inputs=_input, outputs=sr_image, name='generator')

In [None]:
def discriminator_block(_input, filters, strides=1, bn=True):
    """Conv(3x3) -> optional BatchNorm -> LeakyReLU(0.2).

    Args:
        _input: tensor to transform.
        filters: number of convolution filters.
        strides: convolution stride.
        bn: when true, insert ``BatchNormalization(momentum=0.8)`` after the
            convolution.

    Returns:
        The activated feature tensor.
    """
    features = layers.Conv2D(filters, (3, 3), strides=strides, padding='same')(_input)
    normalised = layers.BatchNormalization(momentum=0.8)(features) if bn else features
    return layers.LeakyReLU(alpha=0.2)(normalised)

def discriminator(_input):
    """Assemble the discriminator network on top of the given input tensor.

    Eight ``discriminator_block`` stages are followed by Flatten ->
    Dense(1024) -> LeakyReLU(0.2) -> Dense(1, sigmoid).

    Returns:
        A ``keras.Model`` named ``'discriminator'`` with a single sigmoid unit
        as output.
    """
    # (filters, strides, batch-norm) for each convolutional stage, in order.
    conv_plan = (
        (64, 1, False),   # k3n64s1 (no batch norm on the first stage)
        (64, 2, True),    # k3n64s2
        (128, 1, True),   # k3n128s1
        (128, 2, True),   # k3n128s2
        (256, 1, True),   # k3n256s1
        (256, 2, True),   # k3n256s2
        (512, 1, True),   # k3n512s1
        (512, 2, True),   # k3n512s2
    )

    features = _input
    for n_filters, stride, use_bn in conv_plan:
        features = discriminator_block(features, n_filters, strides=stride, bn=use_bn)

    # Classification head.
    flat = layers.Flatten()(features)
    hidden = layers.Dense(1024)(flat)
    hidden = layers.LeakyReLU(alpha=0.2)(hidden)
    score = layers.Dense(1, activation="sigmoid")(hidden)

    return keras.Model(inputs=_input, outputs=score, name='discriminator')

In [None]:
from PIL import Image

In [None]:
# Low-resolution side length; only used here to derive scale_factor.
lr_dim = 64
# High-resolution window side length; create_windows steps and slices by this.
hr_dim = 256
# Side length every dataset image is resized to when loaded.
load_img_dim = 1024
# Number of training epochs (not referenced by any visible cell).
epochs = 100
# Integer ratio between high- and low-resolution side lengths (256 // 64 == 4).
scale_factor = hr_dim // lr_dim
# Batch size handed to the dataset loader.
batch_size = 1

In [None]:
# Load the images under ./dataset/DIV2K_train_HR as unlabeled, shuffled RGB
# batches resized to load_img_dim x load_img_dim, then divide every pixel value
# by 255 (presumably mapping 0-255 inputs to [0, 1] - confirm the loader's range).
hr_dataset = keras.preprocessing.image_dataset_from_directory(
    directory="./dataset/DIV2K_train_HR",
    labels=None,
    image_size=(load_img_dim, load_img_dim),
    batch_size=batch_size,
    shuffle=True,
    color_mode="rgb",
).map(lambda x: x/255)

In [None]:
# stretch the image width by dim_ratio (the height is left unchanged)
def unpack(img, dim_ratio):
    """Convert a float image to a PIL image and scale its width by ``dim_ratio``.

    Args:
        img: NumPy array of pixel values; multiplied by 255 before conversion,
            so values are presumably expected in [0, 1].
        dim_ratio: factor applied to the image width.

    Returns:
        A ``PIL.Image.Image`` of size ``(int(width * dim_ratio), height)``.
    """
    # Round and clip before the uint8 cast: a bare astype() truncates toward
    # zero (a value a hair below an integer loses a whole level) and values
    # outside 0..255 wrap around instead of saturating.
    pixels = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)
    picture = Image.fromarray(pixels)

    width, height = picture.size  # PIL reports size as (width, height)
    return picture.resize((int(width * dim_ratio), int(height)))

# enlarges the image in the given axis to the dim_ratio
def unpack_in_axis(img, dim_ratio, axis):
    """Convert a float image to a PIL image and scale one side by ``dim_ratio``.

    Args:
        img: NumPy array of pixel values; multiplied by 255 before conversion,
            so values are presumably expected in [0, 1].
        dim_ratio: factor applied to the selected side.
        axis: ``1`` scales the width, ``0`` scales the height. Any other value
            leaves the size untouched.

    Returns:
        The (possibly resized) ``PIL.Image.Image``.
    """
    # Round and clip before the uint8 cast: a bare astype() truncates toward
    # zero and lets values outside 0..255 wrap around instead of saturating.
    pixels = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)
    picture = Image.fromarray(pixels)

    width, height = picture.size  # PIL reports size as (width, height)
    if axis == 1:
        picture = picture.resize((int(width * dim_ratio), int(height)))
    elif axis == 0:
        picture = picture.resize((int(width), int(height * dim_ratio)))

    return picture


# squeezes the image into a square whose side is its smaller spatial dimension
def pack(img):
    """Resize an image to a square and report how it was squeezed.

    Args:
        img: NumPy array of pixel values, multiplied by 255 before conversion.
            Assumes a 3-D (rows, cols, channels) layout - TODO confirm.

    Returns:
        A tuple ``(square, dim_ratio, pack_axis)`` where ``square`` is the
        resized image as float32 divided by 255, ``dim_ratio`` is
        ``img.shape[0] / img.shape[1]`` of the input, and ``pack_axis`` is the
        index of the smaller of the input's leading dimensions.
    """
    pack_axis = np.argmin(img.shape[:-1])
    lower_dim = img.shape[pack_axis]
    dim_ratio = img.shape[0] / img.shape[1]

    # Round and clip before the uint8 cast: a bare astype() truncates toward
    # zero and lets values outside 0..255 wrap around instead of saturating.
    pixels = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)
    picture = Image.fromarray(pixels)

    # Resize both sides to the smaller dimension.
    side = int(lower_dim)
    picture = picture.resize((side, side))

    square = np.array(picture).astype(np.float32)
    return square/255, dim_ratio, pack_axis
    

In [None]:
def create_windows(img, window_size=None):
    """Cut the first image of a batch into square windows.

    Args:
        img: batch whose first element ``img[0]`` is indexed as
            (rows, cols, channels).
        window_size: side length and stride of each window. Defaults to the
            notebook-level ``hr_dim``.

    Returns:
        A list of windows in row-major order (all windows of the first row
        band, then the next band, ...). Windows on the bottom/right edge are
        smaller when the image side is not a multiple of ``window_size``.
    """
    if window_size is None:
        window_size = hr_dim

    image = img[0]
    return [
        image[row:row + window_size, col:col + window_size, :]
        for row in range(0, image.shape[0], window_size)
        for col in range(0, image.shape[1], window_size)
    ]

In [None]:
# Keep at most the first 800 batches of the dataset for windowing.
items = hr_dataset.take(800)

In [None]:
from tqdm import tqdm
# Slice every batch in `items` into windows with create_windows and collect
# the per-batch window lists.
windowed_dataset = []
for item in tqdm(items):
    windowed_dataset.append(create_windows(item))

# Stack the per-batch lists into one array. Note that `item` stays bound to the
# last batch after the loop, and a later cell reads it.
windowed_dataset = np.stack(windowed_dataset)

In [None]:
# Inspect the shape of the stacked window array built in the previous cell.
# (The name `test` is never defined in this notebook and raised NameError on a
# fresh kernel.)
windowed_dataset.shape

In [None]:
# Preview the first image of `item`, the last batch left bound by the
# windowing loop above.
plt.imshow(item[0])

In [None]:
def convolute(img, model, window_size=64):
    """Run ``model`` over an image one square tile at a time.

    Args:
        img: array indexed as (rows, cols, channels).
        model: callable applied to each tile; it receives the tile with a
            leading axis of length 1 added.
        window_size: side length and stride of each tile (default 64).

    Returns:
        A list with one model output per tile, in row-major tile order.
    """
    upscaled = []

    for row in range(0, img.shape[0], window_size):
        for col in range(0, img.shape[1], window_size):
            tile = img[row:row + window_size, col:col + window_size, :]
            # Add the leading batch axis before calling the model.
            upscaled.append(model(np.expand_dims(tile, axis=0)))

    return upscaled

In [None]:
# stitch the windows into one image
def stitch(windows, dims=None, window_size=256):
    """Paste row-major ordered windows back into a single image.

    Args:
        windows: sequence of windows; ``windows[i][0]`` is the tile that gets
            pasted (each window carries a leading axis of length 1).
        dims: ``(rows, cols, channels)`` shape of the output image. Required.
        window_size: side length of each square tile (default 256).

    Returns:
        A NumPy array of shape ``dims`` filled with the tiles.

    Raises:
        ValueError: if ``dims`` is not provided.
    """
    if dims is None:
        raise ValueError("stitch() requires dims=(rows, cols, channels)")

    # create empty image of the right size
    img = np.zeros(dims)

    row, col = 0, 0
    for window in windows:
        img[row:row + window_size, col:col + window_size, :] = window[0]
        col += window_size
        # `col` walks along axis 1, so it must wrap at the image width
        # (dims[1]); comparing against dims[0] misplaces tiles whenever the
        # image is not square.
        if col >= dims[1]:
            col = 0
            row += window_size

    return img

In [None]:
class SRGAN(keras.Model):
    """Couples a generator, a discriminator and a VGG feature extractor into a
    single Keras model with a custom adversarial ``train_step``.

    The class reads two notebook-level constants: ``hr_dim`` (default window
    size for ``create_windows``) and ``scale_factor`` (down-sampling factor
    used to derive the low-resolution input inside ``train_step``).
    """

    def __init__(self, discriminator, generator, vgg):
        """Store the three sub-models.

        Args:
            discriminator: model scoring images; its output is compared
                against 0/1 labels with ``loss_bce``.
            generator: model mapping a low-resolution batch to an image batch.
            vgg: feature extractor used for the content loss.
        """
        super(SRGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.vgg = vgg

    def compile(self, loss_mse, loss_bce, opt_dis, opt_gen):
        """Register the losses, optimizers and running-mean loss trackers.

        Args:
            loss_mse: callable ``(y_true, y_pred)`` used for the content loss.
            loss_bce: callable ``(y_true, y_pred)`` used for the discriminator
                and adversarial losses.
            opt_dis: optimizer applied to the discriminator variables.
            opt_gen: optimizer applied to the generator variables.
        """
        super(SRGAN, self).compile()
        self.loss_mse = loss_mse
        self.loss_bce = loss_bce
        self.opt_dis = opt_dis
        self.opt_gen = opt_gen
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")
        self.p_loss_metric = keras.metrics.Mean(name="p_loss")

    @property
    def metrics(self):
        """Trackers Keras should manage (and reset) for this model.

        All three trackers are listed; leaving ``p_loss`` out would let its
        running mean accumulate without ever being reset.
        """
        return [self.d_loss_metric, self.g_loss_metric, self.p_loss_metric]

    def create_windows(self, img, window_size=None):
        """Cut the first image of a batch into square windows.

        Args:
            img: batch whose first element is indexed as (rows, cols, channels).
            window_size: side length and stride of each window; defaults to
                the notebook-level ``hr_dim``.

        Returns:
            The windows stacked into one NumPy array, in row-major order.
        """
        if window_size is None:
            window_size = hr_dim

        image = img[0]
        windows = [
            image[row:row + window_size, col:col + window_size, :]
            for row in range(0, image.shape[0], window_size)
            for col in range(0, image.shape[1], window_size)
        ]
        return np.stack(windows)

    # stitch the windows into one image
    def stitch(self, windows, dims=None, window_size=256):
        """Paste row-major ordered windows back into a single image.

        Args:
            windows: sequence of windows; ``windows[i][0]`` is the tile pasted.
            dims: ``(rows, cols, channels)`` shape of the output. Required.
            window_size: side length of each square tile (default 256).

        Returns:
            A NumPy array of shape ``dims`` filled with the tiles.

        Raises:
            ValueError: if ``dims`` is not provided.
        """
        if dims is None:
            raise ValueError("stitch() requires dims=(rows, cols, channels)")

        img = np.zeros(dims)

        row, col = 0, 0
        for window in windows:
            img[row:row + window_size, col:col + window_size, :] = window[0]
            col += window_size
            # `col` walks along axis 1, so wrap at the width (dims[1]), not
            # the height; the two only coincide for square images.
            if col >= dims[1]:
                col = 0
                row += window_size

        return img

    def downscale(self, x, factor=2):
        """Shrink a batch ``x`` by an integer ``factor`` along axes 1 and 2
        using area resampling."""
        return tf.image.resize(x, (x.shape[1] // factor, x.shape[2] // factor), method="area")

    def train_step(self, hr_img):
        """Run one optimisation step for both generator and discriminator.

        Args:
            hr_img: batch of high-resolution images. The low-resolution input
                is derived from it with ``downscale(hr_img, scale_factor)``.

        Returns:
            Dict with the running means of ``d_loss`` (discriminator loss),
            ``g_loss`` (content + 1e-3 * adversarial loss) and ``p_loss``
            (content loss).
        """
        # The batch is not windowed here: the result was never used, and the
        # NumPy-based create_windows cannot operate on symbolic tensors.
        lr_img = self.downscale(hr_img, scale_factor)

        batch_size = tf.shape(hr_img)[0]
        real_labels = tf.ones((batch_size, 1))
        fake_labels = tf.zeros((batch_size, 1))

        # Label order must match `dis_output` below (real first, then fake);
        # otherwise the discriminator is trained towards inverted targets.
        labels = tf.concat([real_labels, fake_labels], axis=0)
        # add noise to labels
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # NOTE(review): sub-models are called without an explicit `training`
        # flag, as in the original - confirm the BatchNorm layers run in the
        # intended mode.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
            fake_images = self.generator(lr_img)

            # discriminator pass on real and generated images
            self.discriminator.trainable = True
            d_hr_out = self.discriminator(hr_img)
            d_fake_out = self.discriminator(fake_images)
            dis_output = tf.concat([d_hr_out, d_fake_out], axis=0)

            # Keras-style losses take (y_true, y_pred): targets first.
            d_loss = self.loss_bce(labels, dis_output)
            d_loss = tf.reduce_sum(d_loss)

            # adversarial loss: the generator wants fakes scored as real (1)
            adv_loss = self.loss_bce(tf.ones_like(d_fake_out), d_fake_out)

            # content loss on rescaled VGG features
            gen_features = self.vgg(fake_images)/12.75
            hr_features = self.vgg(hr_img)/12.75
            content_loss = self.loss_mse(hr_features, gen_features)

            # total loss (perceptual loss = content loss + adversarial loss)
            total_loss = content_loss + (1e-3 * adv_loss)

        gradients_gen = gen_tape.gradient(total_loss, self.generator.trainable_variables)
        gradients_dis = dis_tape.gradient(d_loss, self.discriminator.trainable_variables)

        self.opt_gen.apply_gradients(zip(gradients_gen, self.generator.trainable_variables))
        self.opt_dis.apply_gradients(zip(gradients_dis, self.discriminator.trainable_variables))

        # Update the running means and report them, so the logged values are
        # the tracked averages rather than the last batch only.
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(total_loss)
        self.p_loss_metric.update_state(content_loss)

        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
            "p_loss": self.p_loss_metric.result(),
        }