In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras import layers, Model
from keras.models import load_model

# Defining Utility functions

In [None]:
# Short aliases for the Keras layer classes/functions used by the model builders below
Input = layers.Input
Conv2D = layers.Conv2D
Dense = layers.Dense
Flatten = layers.Flatten
BatchNormalization = layers.BatchNormalization
UpSampling2D = layers.UpSampling2D
PReLU = layers.PReLU
LeakyReLU = layers.LeakyReLU
add = layers.add

# Fix NumPy's global random seed
np.random.seed(10)

# Dataset configuration
input_dir = "/content/drive/MyDrive/val2017"  # root folder that is searched for images
number_of_images = 1000                       # total number of images to load (train + test)
train_test_ratio = 0.8                        # fraction of the loaded images used for training
hr_wdt, hr_hgt = 256, 256                     # width / height of the high-resolution images
downscale = 4                                 # HR -> LR down-scaling factor per side

In [None]:
def load_path(path):
    """Return `path` and every directory nested below it, each exactly once.

    The result is ordered depth-first with sibling directories visited in
    sorted order, so the listing (and therefore which images get loaded
    first) is the same on every run.

    Args:
        path: Root directory to walk.

    Returns:
        list[str]: `path` followed by all of its (transitive) subdirectories.

    Raises:
        OSError: propagated from `os.listdir` when `path` is not a readable
            directory.
    """
    directories = [path]
    # sorted(): os.listdir returns entries in arbitrary order
    for elem in sorted(os.listdir(path)):
        child = os.path.join(path, elem)
        if os.path.isdir(child):
            # The recursive call already includes `child` itself, so it must
            # not be appended a second time here.
            directories.extend(load_path(child))
    return directories
    
def load_data_from_dirs(dirs, ext):
    """Load high-/low-resolution image pairs from the given directories.

    Files whose name ends with `ext` are read with `cv2.imread`. An image is
    kept only if it is strictly larger than (`hr_wdt`, `hr_hgt`); it is then
    resized to exactly that size (HR) and the HR image is resized again by a
    factor of `downscale` per side (LR). Loading stops as soon as the
    notebook-level `number_of_images` pairs have been collected.

    Args:
        dirs: Iterable of directory paths to scan (non-recursively).
        ext: File-name suffix to accept, e.g. '.jpg'.

    Returns:
        tuple[list, list]: (files_hr, files_lr), two equally long lists of
        image arrays as returned by cv2 (channel order is whatever
        `cv2.imread` produces).
    """
    files_hr, files_lr = [], []
    lr_size = (hr_wdt // downscale, hr_hgt // downscale)

    for d in dirs:
        # sorted(): os.listdir order is arbitrary; sorting makes the selected
        # subset of images reproducible between runs.
        for f in sorted(os.listdir(d)):
            # Stop scanning entirely once the quota is met, so no further
            # files are decoded in this or any later directory.
            if len(files_hr) >= number_of_images:
                return files_hr, files_lr
            if not f.endswith(ext):
                continue

            image = cv2.imread(os.path.join(d, f))
            if image is None:
                # cv2.imread signals an unreadable/corrupt file with None
                continue

            h, w = image.shape[:2]
            if h > hr_hgt and w > hr_wdt:
                hr_image = cv2.resize(image, (hr_wdt, hr_hgt))
                files_hr.append(hr_image)
                files_lr.append(cv2.resize(hr_image, lr_size))
                # Coarse progress report instead of one line per image
                if len(files_hr) % 100 == 0:
                    print("Loaded %d / %d images" % (len(files_hr), number_of_images))

    return files_hr, files_lr
    
def load_training_testing_data(directory, ext, number_of_images = 1000, train_test_ratio = 0.8):
    """Load images below `directory` and split them into train / test arrays.

    The first `int(number_of_images * train_test_ratio)` loaded images form
    the training set, the remaining ones (up to `number_of_images`) the test
    set. Pixel values are divided by 255.

    NOTE: `load_data_from_dirs` limits the number of loaded images using the
    notebook-level `number_of_images`, not this function's parameter; asking
    for more images here than that global allows will raise.

    Args:
        directory: Root directory; it and its subdirectories are scanned.
        ext: File-name suffix of the images to load, e.g. '.jpg'.
        number_of_images: Total number of images to use (train + test).
        train_test_ratio: Fraction of `number_of_images` used for training.

    Returns:
        tuple of four numpy arrays:
        (x_train_lr, x_train_hr, x_test_lr, x_test_hr).

    Raises:
        ValueError: if fewer than `number_of_images` usable images were found.
    """
    number_of_train_images = int(number_of_images * train_test_ratio)

    files_hr, files_lr = load_data_from_dirs(load_path(directory), ext)

    if len(files_hr) < number_of_images:
        # The original error path referenced an undefined name and an
        # un-imported module; raise a proper exception instead.
        raise ValueError(
            "Number of image files are less then you specified. "
            "Please reduce number of images to %d" % len(files_hr)
        )

    def _scaled(images):
        # Stack the images into one array and rescale pixel values by 1/255
        return np.array(images) / 255

    x_train_hr = _scaled(files_hr[:number_of_train_images])
    x_train_lr = _scaled(files_lr[:number_of_train_images])

    x_test_hr = _scaled(files_hr[number_of_train_images:number_of_images])
    x_test_lr = _scaled(files_lr[number_of_train_images:number_of_images])

    return x_train_lr, x_train_hr, x_test_lr, x_test_hr


In [None]:
# Load the dataset once and keep the four resulting arrays at notebook level
train_lr, train_hr, test_lr, test_hr = load_training_testing_data(
    directory=input_dir,
    ext='.jpg',
    number_of_images=number_of_images,
    train_test_ratio=train_test_ratio,
)

# Defining the Architecture

### VGG for perceptual loss

In [None]:
from keras.applications.vgg19 import VGG19

def build_vgg():
    """Build the VGG19 feature extractor used for the perceptual loss.

    Loads VGG19 with ImageNet weights (no classification head) for inputs of
    the notebook-level `hr_shape`, and returns a model that maps an image to
    the activations of the layer at index 9 of that network.

    The previous version computed `vgg.layers[9].output` but never used it,
    so the returned model emitted the final VGG output instead of the
    intended intermediate features.

    Returns:
        keras Model: image of shape `hr_shape` -> layer-9 feature maps.
    """
    vgg = VGG19(weights="imagenet", input_shape=hr_shape, include_top=False)
    # Layer index 9 should be `block3_conv3` in the stock Keras VGG19;
    # TODO confirm with vgg.summary() for the installed Keras version.
    return Model(inputs=vgg.input, outputs=vgg.layers[9].output)


### Blocks used in Generator and Discriminator

In [None]:
def res_block(ip):
    """Residual block of the generator.

    Applies Conv(64, 3x3) -> BatchNorm -> PReLU -> Conv(64, 3x3) -> BatchNorm
    to `ip` and adds the result back onto `ip` (skip connection).

    Args:
        ip: Input tensor; it is added element-wise to the 64-channel branch
            output, so it has to be compatible with that shape.

    Returns:
        Tensor: `ip` plus the output of the convolutional branch.
    """
    branch = Conv2D(64, (3, 3), padding="same")(ip)
    branch = BatchNormalization(momentum=0.5)(branch)
    # shared_axes=[1, 2]: the PReLU slopes are shared over axes 1 and 2
    branch = PReLU(shared_axes=[1, 2])(branch)

    branch = Conv2D(64, (3, 3), padding="same")(branch)
    branch = BatchNormalization(momentum=0.5)(branch)

    return add([ip, branch])

def upscale_block(ip):
    """Up-scaling block of the generator.

    Applies Conv(256, 3x3), then `UpSampling2D(size=2)`, then PReLU.

    Args:
        ip: Input tensor.

    Returns:
        Tensor: the up-sampled, activated feature maps.
    """
    x = Conv2D(256, (3, 3), padding="same")(ip)
    x = UpSampling2D(size=2)(x)
    return PReLU(shared_axes=[1, 2])(x)

def discriminator_block(ip, filters, strides=1, bn=True):
    """Basic building block of the discriminator.

    Applies Conv(`filters`, 3x3, `strides`) -> LeakyReLU(0.2) and, when `bn`
    is true, a BatchNormalization layer afterwards.

    Args:
        ip: Input tensor.
        filters: Number of convolution filters.
        strides: Convolution stride.
        bn: Whether to append batch normalization.

    Returns:
        Tensor: output of the block.
    """
    x = Conv2D(filters, (3, 3), strides=strides, padding="same")(ip)
    x = LeakyReLU(alpha=0.2)(x)
    return BatchNormalization(momentum=0.8)(x) if bn else x

### Generator

In [None]:
def create_gen(gen_ip):
    """Build the generator model on top of the input tensor `gen_ip`.

    Structure: Conv(64, 9x9) + PReLU, then `num_res_block` residual blocks
    (notebook-level setting), a Conv(64, 3x3) + BatchNorm whose output is
    added to the output of the first stage, two up-scaling blocks and a
    final Conv(3, 9x9) producing the output image.

    Args:
        gen_ip: Keras Input tensor for the low-resolution image.

    Returns:
        keras Model mapping `gen_ip` to the generated image tensor.
    """
    x = Conv2D(64, (9, 9), padding="same")(gen_ip)
    x = PReLU(shared_axes=[1, 2])(x)

    # Kept for the long skip connection around the residual trunk
    skip = x

    for _ in range(num_res_block):
        x = res_block(x)

    x = Conv2D(64, (3, 3), padding="same")(x)
    x = BatchNormalization(momentum=0.5)(x)
    x = add([x, skip])

    # Two up-scaling blocks, each built around UpSampling2D(size=2)
    for _ in range(2):
        x = upscale_block(x)

    output_img = Conv2D(3, (9, 9), padding="same")(x)

    return Model(inputs=gen_ip, outputs=output_img)

### Discriminator

In [None]:
def create_disc(disc_ip):
    """Build the discriminator model on top of the input tensor `disc_ip`.

    Eight discriminator blocks (the first one without batch normalization)
    with 64, 64, 128, 128, 256, 256, 512, 512 filters, every second block
    using stride 2, followed by Flatten -> Dense(1024) -> LeakyReLU(0.2) ->
    Dense(1, sigmoid).

    Args:
        disc_ip: Keras Input tensor for the image to classify.

    Returns:
        keras Model mapping `disc_ip` to a single sigmoid output.
    """
    df = 64

    # (filters, strides) of blocks 2..8; block 1 is special (no batch norm)
    block_specs = [
        (df, 2),
        (df * 2, 1), (df * 2, 2),
        (df * 4, 1), (df * 4, 2),
        (df * 8, 1), (df * 8, 2),
    ]

    x = discriminator_block(disc_ip, df, bn=False)
    for filters, strides in block_specs:
        x = discriminator_block(x, filters, strides=strides)

    x = Flatten()(x)
    x = Dense(df * 16)(x)
    x = LeakyReLU(alpha=0.2)(x)
    validity = Dense(1, activation='sigmoid')(x)

    return Model(disc_ip, validity)

### GAN

In [None]:
def create_comb(gen_model, disc_model, vgg, lr_ip, hr_ip):
    """Assemble the combined model used to train the generator.

    The generator output for `lr_ip` is passed through `vgg` and through the
    discriminator. `disc_model.trainable` is set to False here (a side effect
    on the passed-in model). `hr_ip` is declared as a model input but is not
    used to compute either output.

    Args:
        gen_model: Generator model.
        disc_model: Discriminator model; its `trainable` flag is modified.
        vgg: Feature-extractor model applied to the generated image.
        lr_ip: Input tensor for the low-resolution image.
        hr_ip: Input tensor for the high-resolution image.

    Returns:
        keras Model with inputs [lr_ip, hr_ip] and outputs
        [discriminator verdict, vgg features] of the generated image.
    """
    sr_img = gen_model(lr_ip)
    sr_features = vgg(sr_img)

    # Freeze the discriminator before wiring it into the combined model
    disc_model.trainable = False
    sr_validity = disc_model(sr_img)

    combined_inputs = [lr_ip, hr_ip]
    combined_outputs = [sr_validity, sr_features]
    return Model(inputs=combined_inputs, outputs=combined_outputs)

# Training the network

### Loading data in batches

In [None]:
# Training hyper-parameters
batch_size = 16
epochs_pretraining_generator = 300
epochs_training_gan = 500
num_res_block = 8
image_shape = (hr_wdt, hr_hgt, 3)

# Cut the training arrays into consecutive full batches; images that do not
# fill a complete final batch are left out.
num_full_batches = train_hr.shape[0] // batch_size
batch_starts = range(0, num_full_batches * batch_size, batch_size)
train_hr_batches = [train_hr[s:s + batch_size] for s in batch_starts]
train_lr_batches = [train_lr[s:s + batch_size] for s in batch_starts]

# Per-image shapes taken from the loaded data (axes 1..3 of the 4-D arrays)
hr_shape = tuple(train_hr.shape[1:4])
lr_shape = tuple(train_lr.shape[1:4])

lr_ip = Input(shape=lr_shape)
hr_ip = Input(shape=hr_shape)

# The generator is compiled on its own with an MSE loss for pre-training
generator = create_gen(lr_ip)
generator.compile(loss="mse", optimizer="adam")

### Pre-training Generator

* The GAN losses don't stabilize unless you first pretrain the generator. 
* The network will still end up being able to improve image quality if you don't, but it will be only from the VGG loss and the GAN part will be basically useless. 
* First we train the generator on its own (no GAN) to minimize the MSE between its outputs for the low-resolution training inputs and the high-resolution training targets.
* Once the generator is pre-trained, we continue training it with the VGG perceptual loss plus the GAN loss, training the discriminator inside the same training loop.

For more info, see https://www.fast.ai/2019/05/03/decrappify/

In [None]:
def pre_train_gen(epochs, batch_size, train_hr_batches, train_lr_batches,generator):
    """Pre-train the generator alone on (low-res -> high-res) batches.

    For every epoch, `generator.train_on_batch` is called once per batch
    (using whatever loss the generator was compiled with), the mean batch
    loss is printed, and every 5th epoch the full model is saved to
    /content/output/gen/pre_trained_e_<epoch>.h5.

    Args:
        epochs: Number of passes over all batches.
        batch_size: Unused here; the batches are already formed.
        train_hr_batches: List of high-resolution target batches.
        train_lr_batches: List of low-resolution input batches, aligned with
            `train_hr_batches`.
        generator: Compiled model exposing `train_on_batch` and `save`.

    Returns:
        None.
    """
    save_dir = "/content/output/gen"
    # Create the checkpoint folder up front: on a fresh runtime it does not
    # exist, and failing here is cheaper than failing at the first save.
    os.makedirs(save_dir, exist_ok=True)

    for e in range(epochs):
        gen_losses = []
        for lr_imgs, hr_imgs in zip(train_lr_batches, train_hr_batches):
            gen_losses.append(generator.train_on_batch(lr_imgs, hr_imgs))

        # Mean loss over the batches of this epoch
        gen_loss = np.sum(np.array(gen_losses), axis=0) / len(gen_losses)

        print("epoch:", e+1 ,"gen_loss:", gen_loss)

        if (e+1) % 5 == 0:
            generator.save(save_dir + "/pre_trained_e_" + str(e+1) + ".h5")

# Run the MSE-only pre-training of the generator
pre_train_gen(
    epochs=epochs_pretraining_generator,
    batch_size=batch_size,
    train_hr_batches=train_hr_batches,
    train_lr_batches=train_lr_batches,
    generator=generator,
)

### Loading the pretrained generator and training the complete GAN

In [None]:
# Load a pre-trained generator checkpoint written by pre_train_gen
# (this file name is the one saved after pre-training epoch 45).
generator = load_model("/content/output/gen/pre_trained_e_45.h5")
# Alternative: start from an untrained generator instead of the checkpoint
#generator = create_gen(lr_ip)
discriminator = create_disc(hr_ip)
# The discriminator is compiled here, i.e. before create_comb() below sets
# its `trainable` flag to False for use inside the combined model.
discriminator.compile(loss="binary_crossentropy", optimizer="adam", metrics=['accuracy'])

# Feature extractor for the perceptual loss; marked non-trainable
vgg = build_vgg()
vgg.trainable = False

gan_model = create_comb(generator, discriminator, vgg, lr_ip, hr_ip)
# Two outputs -> two losses: binary cross-entropy on the discriminator verdict
# (weight 1e-3) and MSE on the VGG features (weight 1).
gan_model.compile(loss=["binary_crossentropy","mse"], loss_weights=[1e-3, 1], optimizer="adam")

In [None]:
def train_gan(epochs, batch_size, train_hr_batches, train_lr_batches,generator,discriminator,
              vgg_model=None, combined_model=None):
    """Alternate discriminator and generator updates for `epochs` epochs.

    Per batch: the generator produces images from the low-res batch; the
    discriminator is trained on those (label 0) and on the real high-res
    batch (label 1); then the combined model is trained with target label 1
    for the discriminator output and the feature-extractor activations of the
    real high-res batch as the feature target. After each epoch the mean
    losses are printed, and every 20th epoch the discriminator and generator
    weights are saved under /content/output/disc and /content/output/gen.

    Args:
        epochs: Number of passes over all batches.
        batch_size: Number of images per batch; used to size the label
            arrays, so every batch must contain exactly this many images.
        train_hr_batches: List of high-resolution batches.
        train_lr_batches: List of low-resolution batches, aligned with
            `train_hr_batches`.
        generator: Model exposing `predict_on_batch` and `save_weights`.
        discriminator: Compiled model exposing `train_on_batch` and
            `save_weights`; its `trainable` flag is toggled during training.
        vgg_model: Feature extractor exposing `predict_on_batch`. Defaults to
            the notebook-level `vgg`.
        combined_model: Compiled combined model whose `train_on_batch`
            returns three values. Defaults to the notebook-level `gan_model`.

    Returns:
        None.
    """
    # Fall back to the notebook-level models so existing calls keep working
    if vgg_model is None:
        vgg_model = vgg
    if combined_model is None:
        combined_model = gan_model

    disc_dir = "/content/output/disc"
    gen_dir = "/content/output/gen"
    # Make sure the checkpoint folders exist before any training time is spent
    os.makedirs(disc_dir, exist_ok=True)
    os.makedirs(gen_dir, exist_ok=True)

    # The labels are the same for every batch and epoch: build them once
    gen_label = np.zeros((batch_size, 1))
    real_label = np.ones((batch_size, 1))

    for e in range(epochs):
        g_losses = []
        d_losses = []
        for lr_imgs, hr_imgs in zip(train_lr_batches, train_hr_batches):
            gen_imgs = generator.predict_on_batch(lr_imgs)

            discriminator.trainable = True
            d_loss_gen = discriminator.train_on_batch(gen_imgs, gen_label)
            d_loss_real = discriminator.train_on_batch(hr_imgs, real_label)
            discriminator.trainable = False

            d_loss = 0.5 * np.add(d_loss_gen, d_loss_real)

            # predict_on_batch: a single forward pass for this one batch
            image_features = vgg_model.predict_on_batch(hr_imgs)

            g_loss, _, _ = combined_model.train_on_batch(
                [lr_imgs, hr_imgs], [real_label, image_features]
            )

            d_losses.append(d_loss)
            g_losses.append(g_loss)

        # Mean losses over the batches of this epoch
        g_loss = np.sum(np.array(g_losses), axis=0) / len(g_losses)
        d_loss = np.sum(np.array(d_losses), axis=0) / len(d_losses)

        print("epoch:", e+1 ,"g_loss:", g_loss, "d_loss:", d_loss)

        if (e+1) % 20 == 0:
            discriminator.save_weights(disc_dir + "/e_" + str(e+1) + ".h5")
            generator.save_weights(gen_dir + "/e_" + str(e+1) + ".h5")
# Run the adversarial training of generator and discriminator
train_gan(
    epochs=epochs_training_gan,
    batch_size=batch_size,
    train_hr_batches=train_hr_batches,
    train_lr_batches=train_lr_batches,
    generator=generator,
    discriminator=discriminator,
)

In [None]:
# Super-resolve the held-out low-resolution images in mini-batches, so the
# whole test set is never pushed through the network in a single pass.
test_gen = generator.predict(test_lr, batch_size=batch_size)

sample_idx = 4  # which test image to display

# The images were read with cv2.imread (BGR channel order) and never
# converted, while matplotlib expects RGB: reverse the channel axis for
# display. The generator output is not bounded, so clip it to [0, 1].
plt.imshow(np.clip(test_gen[sample_idx][..., ::-1], 0, 1))
plt.title("Generated (super-resolved)")
plt.axis("off")
plt.show()

plt.imshow(test_hr[sample_idx][..., ::-1])
plt.title("Ground truth (high-resolution)")
plt.axis("off")
plt.show()