In [None]:
import datetime
import os
import time

import tensorflow as tf
from numpy import load, zeros, ones
from numpy.random import randint
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import Input, Conv2D, LeakyReLU, BatchNormalization, Concatenate, Activation, Dropout, UpSampling2D
from tensorflow.keras.models import Sequential,Model
from tensorflow.keras.optimizers import Adam

In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# NOTE(review): shutil.copy copies a single file, but the source given here is
# the mount directory and the destination is an empty string, so this call
# looks like it will raise when the cell runs -- confirm the intended source
# file and destination path.
import shutil
shutil.copy("/content/drive/", "")

In [None]:
def discriminator(image_shape):
    """Build and compile the discriminator for (source, candidate) image pairs.

    The model takes two images of shape ``image_shape``, concatenates them
    along the channel axis, and passes the result through four stride-2
    convolutions followed by a stride-1 convolution that reduces the feature
    map to a single sigmoid channel, i.e. one real/fake score per output
    location.

    The rest of the notebook calls this model with a two-element list
    (``define_gan`` and ``train``), so it is built with the functional API and
    two inputs rather than as a single-input ``Sequential``.

    Args:
        image_shape: Shape of one image without the batch axis.

    Returns:
        A compiled Keras ``Model`` with inputs ``[source, candidate]``.
    """
    init = RandomNormal(stddev=0.02)

    in_source = Input(shape=image_shape)
    in_candidate = Input(shape=image_shape)
    # Stack both images channel-wise so every filter sees the pair together.
    merged = Concatenate()([in_source, in_candidate])

    # First stage has no batch normalisation, matching the original layout.
    d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
    d = LeakyReLU(alpha=0.2)(d)

    # Remaining downsampling stages: conv -> batch norm -> leaky ReLU.
    for n_filters in (128, 256, 512):
        d = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
        d = BatchNormalization()(d)
        d = LeakyReLU(alpha=0.2)(d)

    # Collapse to a single-channel score map.
    d = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(d)

    model = Model([in_source, in_candidate], patch_out)

    # `lr` is a deprecated alias; `learning_rate` is the supported keyword.
    opt = Adam(learning_rate=2e-5, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])

    return model

In [None]:
def generator(image_shape = (256, 256, 3)):
    """Build the (uncompiled) image-to-image generator.

    Layout:
      * Encoder: a 7x7 convolution at full resolution followed by two stride-2
        3x3 convolutions (64 -> 128 -> 256 filters).
      * Bottleneck: three two-convolution blocks whose output is concatenated
        (not added) with the block input, so the channel count grows with
        every block.
      * Decoder: two 2x upsampling stages, each followed by a 1x1 convolution,
        dropout, and a skip concatenation with the encoder feature map of the
        same resolution.
      * Output: 7x7 convolution to 3 channels, batch norm, ``tanh``.

    Every ``BatchNormalization`` and ``Dropout`` layer is called with
    ``training=True``, so they stay in training mode during ``predict`` too.

    Args:
        image_shape: Shape of one input image without the batch axis.

    Returns:
        A Keras ``Model`` mapping an input image to a 3-channel ``tanh`` image.
    """
    init = RandomNormal(stddev=0.01)
    in_image = Input(shape=image_shape)

    # Encoder stage at full resolution (kept as e3 for the last skip connection).
    e = Conv2D(64, (7,7), padding='same', kernel_initializer=init)(in_image)
    e = BatchNormalization()(e, training=True)
    e3 = LeakyReLU(alpha=0.2)(e)

    # Encoder stage at 1/2 resolution (kept as e2 for the first skip connection).
    e = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(e3)
    e = BatchNormalization()(e, training=True)
    e2 = LeakyReLU(alpha=0.2)(e)

    # Encoder stage at 1/4 resolution; e1 is the running bottleneck tensor.
    e = Conv2D(256, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(e2)
    e = BatchNormalization()(e, training=True)
    e1 = LeakyReLU(alpha=0.2)(e)

    # Bottleneck blocks: conv-BN-LeakyReLU-conv-BN, then concatenate with the block input.
    for _ in range(3):
        e = Conv2D(256, (3,3), padding='same', kernel_initializer=init)(e1)
        e = BatchNormalization()(e, training=True)
        e = LeakyReLU(alpha=0.2)(e)

        e = Conv2D(256, (3,3), padding='same', kernel_initializer=init)(e)
        e = BatchNormalization()(e, training=True)

        e1 = Concatenate()([e, e1])

    # Decoder stage 1: starts from the bottleneck output e1
    # (the original referenced an undefined name `d1` here).
    d = UpSampling2D((2, 2))(e1)
    d = Conv2D(128, (1, 1), kernel_initializer=init)(d)
    d = Dropout(0.5)(d, training=True)
    d = Concatenate()([d, e2])
    d = BatchNormalization()(d, training=True)
    d = LeakyReLU(alpha=0.2)(d)

    # Decoder stage 2: back to the input resolution, skip connection from e3.
    d = UpSampling2D((2, 2))(d)
    d = Conv2D(64, (1, 1), kernel_initializer=init)(d)
    d = Dropout(0.5)(d, training=True)
    d = Concatenate()([d, e3])
    d = BatchNormalization()(d, training=True)
    d = LeakyReLU(alpha=0.2)(d)

    # Output head.
    d = Conv2D(3, (7,7), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d, training=True)
    out_im = Activation('tanh')(d)

    model = Model(in_image, out_im)
    return model

In [None]:
def define_gan(generator_model, discriminator_model, image_shape):
    """Build and compile the combined generator + discriminator model.

    The combined model maps a source image to two outputs: the discriminator's
    score for the (source, generated) pair and the generated image itself. It
    is compiled with binary cross-entropy on the score and mean absolute error
    on the image, weighted 1 : 100.

    Side effect: every discriminator layer that is not a
    ``BatchNormalization`` layer has ``trainable`` set to ``False`` on the
    object passed in.

    Args:
        generator_model: Model mapping a source image to a generated image.
        discriminator_model: Model called with ``[source, candidate]``.
        image_shape: Shape of one source image without the batch axis.

    Returns:
        The compiled combined Keras ``Model``.
    """
    # Make weights in the discriminator not trainable
    for layer in discriminator_model.layers:
        if not isinstance(layer, BatchNormalization):
            layer.trainable = False

    # Define the source image
    input_source_image = Input(shape=image_shape)

    # Connect the source image to the generator input
    generated_output = generator_model(input_source_image)

    # Connect the source input and generator output to the discriminator input
    discriminator_output = discriminator_model([input_source_image, generated_output])

    # Source image as input, generated image, and classification output
    gan_model = Model(input_source_image, [discriminator_output, generated_output])

    # Compile the model; `lr` is a deprecated alias, `learning_rate` is the
    # supported keyword.
    optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
    # Use binary cross-entropy for the discriminator and mean absolute error for the generator
    gan_model.compile(loss=['binary_crossentropy', 'mae'], optimizer=optimizer, loss_weights=[1, 100])

    return gan_model

In [None]:
def load_real_samples(filename):
    """Load two image arrays from an ``.npz`` archive and rescale them.

    The archive must contain the keys ``arr_0`` and ``arr_1``. Both arrays are
    mapped with ``(x - 127.5) / 127.5``, which takes values in [0, 255] to
    [-1, 1].

    Args:
        filename: Path (or file object) accepted by ``numpy.load``.

    Returns:
        A two-element list ``[scaled arr_1, scaled arr_0]`` -- note that the
        order is swapped relative to the archive.
    """
    # The archive object holds an open file handle; the context manager
    # closes it as soon as both arrays have been read into memory.
    with load(filename) as data:
        M1, M2 = data['arr_0'], data['arr_1']
    # scale from [0,255] to [-1,1]
    M1 = (M1 - 127.5) / 127.5
    M2 = (M2 - 127.5) / 127.5
    return [M2, M1]

In [None]:
def generate_real_samples(dataset, n_samples, patch_shape):
    """Draw a random batch of paired samples together with "real" labels.

    Args:
        dataset: Pair ``(trainA, trainB)`` of arrays indexed along axis 0.
        n_samples: Number of pairs to draw (indices are drawn with replacement).
        patch_shape: Side length of the square label map.

    Returns:
        ``([batchA, batchB], y)`` where both batches use the same random
        indices and ``y`` is an all-ones array of shape
        ``(n_samples, patch_shape, patch_shape, 1)``.
    """
    sources, targets = dataset
    # One shared index vector keeps the two batches aligned pair-by-pair.
    picks = randint(0, sources.shape[0], n_samples)
    real_labels = ones((n_samples, patch_shape, patch_shape, 1))
    return [sources[picks], targets[picks]], real_labels

In [None]:
def generate_fake_samples(g_model, samples, patch_shape):
    """Run the generator on a batch and pair the result with "fake" labels.

    Args:
        g_model: Object exposing ``predict(samples)``.
        samples: Batch passed unchanged to ``g_model.predict``.
        patch_shape: Side length of the square label map.

    Returns:
        ``(generated, y)`` where ``y`` is an all-zeros array of shape
        ``(len(generated), patch_shape, patch_shape, 1)``.
    """
    generated = g_model.predict(samples)
    n_generated = len(generated)
    fake_labels = zeros((n_generated, patch_shape, patch_shape, 1))
    return generated, fake_labels

In [None]:
def save_model(step, g_model, d_model, dataset, n_samples=3):
    """Save generator and discriminator checkpoints for a training step.

    Files are written under the module-level ``model_output`` prefix as
    ``gen_model_NNNNNN.h5`` and ``disc_model_NNNNNN.h5``, where ``NNNNNN`` is
    ``step + 1`` zero-padded to six digits.

    Args:
        step: Zero-based training step index.
        g_model: Generator; must expose ``save(path)``.
        d_model: Discriminator; must expose ``save(path)``.
        dataset: Unused; kept so existing callers keep working.
        n_samples: Unused; kept so existing callers keep working.
    """
    step_number = step + 1
    # save the generator model
    filename2 = model_output + 'gen_model_%06d.h5' % step_number
    g_model.save(filename2)
    # save the discriminator model
    filename3 = model_output + 'disc_model_%06d.h5' % step_number
    d_model.save(filename3)
    # The original printed an undefined `filename1` here, which raised
    # NameError after both models were saved; report the step number instead.
    print('[.] Saved Step : %06d' % step_number)
    print('[.] Saved Model: %s' % (filename2))
    print('[.] Saved Model: %s' % (filename3))

In [None]:
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=12):
    """Alternate discriminator and generator updates for a fixed number of steps.

    Each step draws a random real batch, generates a fake batch from it,
    updates the discriminator once on the real pairs and once on the fake
    pairs, then updates the generator through the combined model. A progress
    line is printed every step, and ``save_model`` is called once per epoch's
    worth of steps.

    Args:
        d_model: Compiled discriminator taking ``[source, candidate]``.
        g_model: Generator used to produce fake candidates.
        gan_model: Compiled combined model returning three loss values.
        dataset: Pair of arrays ``(trainA, trainB)``.
        n_epochs: Number of passes used to compute the total step count.
        n_batch: Samples per step.
    """
    # Side length of the label map is read from the discriminator's output.
    patch_size = d_model.output_shape[1]
    sources, _ = dataset
    steps_per_epoch = int(len(sources) / n_batch)
    total_steps = steps_per_epoch * n_epochs
    save_every = steps_per_epoch * 1
    print("[!] Number of steps {}".format(total_steps))
    print("[!] Saves model/step output at every {}".format(save_every))

    for step in range(total_steps):
        step_no = step + 1
        tick = time.time()

        # Real and fake batches share the same source images.
        [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, patch_size)
        X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, patch_size)

        # Discriminator: one update on real pairs, one on fake pairs.
        d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
        d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)

        # Generator, via the combined model; only the total loss is reported.
        g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])

        elapsed = time.time() - tick
        # Remaining time is extrapolated from the duration of this step alone.
        remaining = datetime.timedelta(seconds=elapsed * (total_steps - step_no))
        time_left = str(remaining).split('.')[0].zfill(8)
        print(
            '[*] %06d, d1[%.3f] d2[%.3f] g[%06.3f] ---> time[%.2f], time_left[%.08s]'
            % (step_no, d_loss1, d_loss2, g_loss, elapsed, time_left)
        )

        if step_no % save_every == 0:
            save_model(step, g_model, d_model, dataset)

In [None]:
# Load the two scaled image arrays and take the per-image shape from the first
# one (everything after its leading sample axis).
# NOTE(review): Drive is mounted at '/content/drive' earlier in the notebook,
# while this path starts at '/MyDrive/' -- confirm the archive is actually
# reachable at this location.
dataset = load_real_samples('/MyDrive/img.npz')
image_shape = dataset[0].shape[1:]

In [None]:
# Build the three models. Order matters: discriminator() compiles d_model
# before define_gan() marks its non-batch-norm layers as not trainable while
# wiring it into the combined model.
d_model = discriminator(image_shape)
g_model = generator(image_shape)
gan_model = define_gan(g_model, d_model, image_shape)

In [None]:
# Output locations. `model_output` is read by save_model() as a path prefix,
# so it must keep its trailing slash.
# NOTE(review): Drive is mounted at '/content/drive' earlier in the notebook,
# while these paths start at '/MyDrive/' -- confirm the base directory.
# `dir` shadows the builtin of the same name; the name is kept unchanged in
# case other cells reference it.
dir = '/MyDrive/'
fileName = 'Enhancement Model'
step_output = dir + fileName + "/Step Output/"
model_output = dir + fileName + "/Model Output/"

# Create whichever folders are missing. The original only created the
# subfolders when the project folder itself was absent, so a pre-existing
# project folder without them made the later model saves fail. os.mkdir still
# raises if the base directory does not exist.
for folder in (dir + fileName, step_output, model_output):
    if not os.path.isdir(folder):
        os.mkdir(folder)

# train() names its batch-size parameter `n_batch`; the original passed
# `batch=12`, which raises TypeError.
train(d_model, g_model, gan_model, dataset, n_batch=12)