In [1]:
import tensorflow as tf

from tensorflow.keras.activations import relu
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, Add
from tensorflow.keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D, Concatenate, Conv2DTranspose
from tensorflow.keras.layers import MaxPooling2D, AveragePooling2D, UpSampling2D, Conv2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import regularizers
from tensorflow import identity

import matplotlib.pyplot as plt

import sys
import os

import numpy as np
import PIL

In [2]:
# Image geometry: single-channel 256x256 images
img_rows = 256
img_cols = 256
channels = 1
img_shape = (img_rows, img_cols, channels)

# Hyperparameters
gp_coef = 1        # weight of the gradient-penalty term in the discriminator loss
latent_dim = 100   # length of the generator's noise vector
lr_d = 1e-4        # discriminator learning rate
lr_g = 2e-5        # generator learning rate

# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and is rejected by newer Keras releases.
optimizer_d = Adam(learning_rate = lr_d, beta_1 = 0, beta_2 = 0.9)
optimizer_g = Adam(learning_rate = lr_g, beta_1 = 0, beta_2 = 0.9)

In [None]:
def ResBlock_G(X, filter_num):
    """Residual block for the generator.

    Applies conv3x3 -> LeakyReLU(0.2) -> conv3x3 to X, adds the block input
    back onto the result and finishes with another LeakyReLU(0.2).

    Args:
        X: input feature-map tensor.
        filter_num: number of filters for both convolutions. The Add layer
            sums the conv output with X, so this presumably has to equal
            the channel count of X.

    Returns:
        The output tensor of the block.
    """
    shortcut = identity(X)

    # Both convolutions share the same geometry (3x3, stride 1, 'same').
    conv_kwargs = dict(kernel_size = 3, strides = 1, padding = 'same')

    branch = Conv2D(filter_num, **conv_kwargs)(X)
    branch = LeakyReLU(alpha = 0.2)(branch)
    branch = Conv2D(filter_num, **conv_kwargs)(branch)

    summed = Add()([branch, shortcut])
    return LeakyReLU(alpha = 0.2)(summed)

In [3]:
def ResBlock_D(X, filter_num):
    """Residual block for the discriminator.

    Two 3x3, stride-1, 'same'-padded convolutions separated by a
    LeakyReLU(0.2); the untouched input is added to the second convolution's
    output and the sum goes through a final LeakyReLU(0.2).

    Args:
        X: input feature-map tensor.
        filter_num: number of filters used by both convolutions (presumably
            must match the channel count of X for the Add to be valid).

    Returns:
        The output tensor of the block.
    """
    def conv3x3(tensor):
        # A fresh 3x3 convolution layer applied to `tensor`.
        return Conv2D(filter_num, kernel_size = 3, strides = 1, padding = 'same')(tensor)

    skip = identity(X)

    out = conv3x3(X)
    out = LeakyReLU(alpha = 0.2)(out)
    out = conv3x3(out)

    out = Add()([out, skip])
    out = LeakyReLU(alpha = 0.2)(out)
    return out

In [5]:
def generator():
    """Build the label-conditioned generator model.

    Inputs are a noise vector of length `latent_dim` and a scalar label.
    Each is projected by a Dense layer and reshaped to an 8x8 feature map
    (100 and 16 channels respectively); the two maps are concatenated,
    upsampled through transposed convolutions, residual blocks and
    nearest-neighbour upsampling, and mapped to a single-channel image by a
    tanh convolution.

    Returns:
        A Keras Model taking [noise, label] and producing the generated image.
    """
    noise = Input(shape = (latent_dim,))
    in_label = Input(shape = (1,))

    # Project both inputs to 8x8 feature maps and stack them channel-wise.
    noise_map = Reshape((8,8,100))(Dense(100*8*8)(noise))
    label_map = Reshape((8,8,16))(Dense(8*8*16)(in_label))
    x = Concatenate()([noise_map, label_map])

    # Stride-4 transposed convolution, followed by six residual blocks.
    x = Conv2DTranspose(64, kernel_size = 9, strides = 4, padding = 'same')(x)
    for _ in range(6):
        x = ResBlock_G(x, 64)

    # Two rounds of conv3x3 + 2x nearest-neighbour upsampling.
    for _ in range(2):
        x = Conv2D(256, kernel_size = 3, strides = 1, padding = 'same')(x)
        x = UpSampling2D((2,2), interpolation='nearest')(x)

    # Final stride-2 transposed convolution before the output layer.
    x = Conv2DTranspose(128, kernel_size = 7, strides = 2, padding = 'same')(x)
    x = LeakyReLU(alpha = 0.2)(x)

    # tanh keeps the output in [-1, 1].
    out_layer = Conv2D(1, kernel_size = 11, strides = 1, activation = 'tanh', padding = 'same')(x)

    return Model([noise, in_label], out_layer)

In [10]:
def discriminator():
    """Build the label-conditioned discriminator model.

    The image goes through two stride-2 convolutions, is concatenated with a
    64x64x20 feature map projected from the scalar label, and is then
    reduced by further convolutions, one residual block and a dense head to
    a single unbounded score (the output Dense layer has no activation).

    Returns:
        A Keras Model taking [image, label] and producing one score per sample.
    """
    in_img = Input(shape = img_shape)
    in_label = Input(shape = (1,))

    # Image branch: two stride-2 convolutions.
    fe = Conv2D(64, kernel_size=4, strides=2, padding="same")(in_img)
    fe = LeakyReLU(alpha = 0.2)(fe)
    fe = Conv2D(128, kernel_size=4, strides=2, padding="same")(fe)
    fe = LeakyReLU(alpha = 0.2)(fe)

    # Label branch: project the scalar label to a 64x64x20 map so it can be
    # concatenated with the image features (which must also be 64x64 here).
    ll = Dense(64*64*20)(in_label)
    ll = Reshape((64,64,20))(ll)
    merge = Concatenate()([fe, ll])

    fe = Conv2D(256, kernel_size=4, strides=2, padding="same")(merge)
    fe = LeakyReLU(alpha = 0.2)(fe)
    # Chain from `fe`, not `merge`: feeding `merge` here discarded the
    # 256-filter convolution above and left it out of the model entirely.
    fe = Conv2D(512, kernel_size=5, strides=2, padding="same")(fe)
    fe = LeakyReLU(alpha = 0.2)(fe)

    # Stride-1 convolutions narrowing the channel count 256 -> 128 -> 64.
    fe = Conv2D(256, kernel_size=3, strides=1, padding="same")(fe)
    fe = LeakyReLU(alpha = 0.2)(fe)
    fe = Conv2D(128, kernel_size=3, strides=1, padding="same")(fe)
    fe = LeakyReLU(alpha = 0.2)(fe)
    fe = Conv2D(64, kernel_size=3, strides=1, padding="same")(fe)
    fe = LeakyReLU(alpha = 0.2)(fe)
    fe = ResBlock_D(fe, 64)

    # Dense head.
    fe = Flatten()(fe)
    fe = Dense(512)(fe)
    fe = LeakyReLU(alpha = 0.2)(fe)

    # Linear output: a raw score, not a probability.
    out_layer = Dense(1)(fe)

    return Model([in_img, in_label], out_layer)

In [None]:
# Instantiate the generator and print its layer-by-layer summary.
net_g = generator()
net_g.summary()

In [None]:
# Instantiate the discriminator and print its layer-by-layer summary.
net_d = discriminator()
net_d.summary()

In [13]:
def d_training(real_img, noise, label):
    """Run one discriminator update with a gradient penalty.

    The loss is mean(D(fake)) - mean(D(real)) + gp_coef * penalty, where the
    penalty is the mean squared deviation from 1 of the per-sample gradient
    norm of D at random interpolations between real and generated images.

    Args:
        real_img: batch of real images.
        noise: batch of latent vectors fed to the generator.
        label: batch of conditioning labels, used for both networks.

    Returns:
        Tuple (loss_real, loss_fake, loss): mean discriminator score on real
        images, mean score on generated images, and the total loss.
    """
    with tf.GradientTape() as tape_d:
        fake_img = net_g([noise, label])
        # Make sure real and generated images can be mixed arithmetically.
        real_img = tf.cast(real_img, fake_img.dtype)

        loss_real = tf.reduce_mean(net_d([real_img, label]))
        loss_fake = tf.reduce_mean(net_d([fake_img, label]))

        # One mixing coefficient per sample. The batch dimension is read from
        # the data instead of a global `batch_size`, so the function does not
        # depend on a variable defined in a later cell.
        n_samples = tf.shape(fake_img)[0]
        epsilon = tf.random.uniform([n_samples, 1, 1, 1], 0, 1)

        with tf.GradientTape() as tape_penalty:
            interpolated_img = epsilon*real_img + (1-epsilon)*fake_img
            tape_penalty.watch(interpolated_img)
            interpolated_out = net_d([interpolated_img, label])

        # Taken outside the inner tape but still inside tape_d, so the
        # penalty remains differentiable w.r.t. the discriminator weights.
        grad_interpolated = tape_penalty.gradient(interpolated_out, interpolated_img)
        # The small constant keeps the sqrt gradient finite at a zero norm.
        grad_norm = tf.math.sqrt(
            tf.math.reduce_sum(tf.math.square(grad_interpolated), axis = [1,2,3]) + 1e-12)
        grad_penalty = tf.reduce_mean(tf.math.square(grad_norm-1))

        loss = loss_fake - loss_real + gp_coef*grad_penalty

    # Differentiate and step once recording has stopped, so the tape does not
    # also record the gradient computation itself.
    grad_d = tape_d.gradient(loss, net_d.trainable_weights)
    optimizer_d.apply_gradients(zip(grad_d, net_d.trainable_weights))

    return (loss_real, loss_fake, loss)

In [14]:
def g_training(noise, label):
    """Run one generator update.

    The generator loss is the negated mean discriminator score on generated
    images; no gradient penalty is involved in this step.

    Args:
        noise: batch of latent vectors.
        label: batch of conditioning labels.

    Returns:
        The scalar generator loss.
    """
    with tf.GradientTape() as tape_g:
        gen_img = net_g([noise, label])
        loss = -tf.reduce_mean(net_d([gen_img, label]))

    # Compute gradients and apply the step after the tape has stopped
    # recording; only the generator weights are updated.
    grad_g = tape_g.gradient(loss, net_g.trainable_weights)
    optimizer_g.apply_gradients(zip(grad_g, net_g.trainable_weights))
    return loss

In [15]:
def load_samples_and_rotate():
    """Load the training images and quadruple them with 90-degree rotations.

    Images are read from 'training data/<i>-<j>/' for class index i in 0..7
    (class 5 is skipped and left as the validation set) and subset j in 1..5.
    Every image is stored together with its 90, 180 and 270 degree
    rotations, and each stored array gets the label of its class.

    The array sizes are derived from what was actually loaded, so images and
    labels always have the same length regardless of how many files each
    folder contains.

    Returns:
        Tuple (train_X, train_Y): train_X has shape (N, 256, 256, 1) and
        train_Y has shape (N, 1) with the per-image class label.
    """
    # Import the submodule explicitly: a bare `import PIL` does not guarantee
    # that `PIL.Image` has been loaded.
    from PIL import Image

    dirct_name = 'training data/'
    labels = [0.73, 0.72, 0.7, 0.67, 0.66, 0.62, 0.56, 0.51]
    held_out = 5  # index of the class kept aside for validation (label 0.62)
    rotations = (Image.ROTATE_90, Image.ROTATE_180, Image.ROTATE_270)

    train_set = []
    train_labels = []
    for i in range(len(labels)):
        if i == held_out:
            continue
        for j in range(1,6):
            subset_name = dirct_name + str(i) + '-' + str(j)
            # Sorted for a reproducible order; skip Jupyter's checkpoint folder.
            for item in sorted(os.listdir(subset_name)):
                if item == '.ipynb_checkpoints':
                    continue
                # The context manager closes the file handle once the pixel
                # data has been copied into arrays.
                with Image.open(subset_name+'/'+item) as img:
                    views = [img] + [img.transpose(rot) for rot in rotations]
                    arrays = [np.array(view) for view in views]
                train_set.extend(arrays)
                train_labels.extend([labels[i]] * len(arrays))

    # -1 lets numpy infer the sample count instead of hard-coding it.
    train_X = np.reshape(train_set, (-1,256,256,1))
    train_Y = np.reshape(np.asarray(train_labels, dtype = np.float64), (-1,1))

    return (train_X,train_Y)

In [16]:
# Load the (rotation-augmented) training images and their labels, then map
# pixel values x to (x - 127.5) / 127.5 as float32, i.e. [0, 255] -> [-1, 1],
# the same range as the generator's tanh output.
(train_X,train_Y) = load_samples_and_rotate()
train_X_n = (train_X.astype(np.float32) - 127.5) / 127.5
# Sanity check: both arrays should have the same leading dimension.
print(train_X_n.shape)
print(train_Y.shape)

(6480, 256, 256, 1)
(6480, 1)


In [None]:
# Training schedule
epochs = 500
num_sample = train_X_n.shape[0]
batch_size = 128
num_minibatch = num_sample // batch_size

# savefig does not create missing directories, so make sure it exists.
os.makedirs("saved_images", exist_ok = True)

for i in range(epochs):
    # A true shuffle: every sample index appears exactly once per epoch.
    # (randint would sample with replacement, repeating some images and
    # skipping others.)
    shuffled_idx = np.random.permutation(num_sample)
    for j in range(num_minibatch):
        minibatch_idx = shuffled_idx[j*batch_size:(j+1)*batch_size]
        imgs = train_X_n[minibatch_idx]
        labels = train_Y[minibatch_idx]
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # One discriminator step followed by one generator step on the same
        # noise and labels.
        err_d_real, err_d_fake, loss_total = d_training(imgs, noise, labels)
        err_g = g_training(noise, labels)

    # Every 5 epochs, save a 2x2 grid of samples generated from one shared
    # noise vector under four different labels.
    if i%5 == 0:
        print ("epoch number: %d" % (i))
        r = 2
        c = 2
        # Column vector, matching the generator's label Input(shape=(1,)).
        sampled_labels = np.asarray([0.72,0.7,0.62,0.51]).reshape(-1, 1)
        # The same latent vector repeated once per label.
        noise_g = np.repeat(np.random.normal(0, 1, (1, latent_dim)), len(sampled_labels), axis = 0)
        fake_imgs = net_g.predict([noise_g, sampled_labels])
        # Undo the [-1, 1] normalisation.
        fake_imgs = fake_imgs * 127.5 + 127.5
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for k in range(r):
            for l in range(c):
                axs[k,l].imshow(fake_imgs[cnt,:,:,0], cmap='gray')
                axs[k,l].axis('off')
                cnt += 1
        fig.savefig("saved_images/%d.png" % i)
        # Close this specific figure so figures do not pile up over epochs.
        plt.close(fig)

# Persist both trained networks.
net_g.save('generator_0209')
net_d.save('discriminator_0209')

0 [D loss : -2617.074463] [G loss: -213.648804]
5 [D loss : -853.758057] [G loss: 138.943497]
10 [D loss : -509.043457] [G loss: 170.820892]
15 [D loss : -334.369507] [G loss: 1113.373413]
20 [D loss : -274.868958] [G loss: 1755.954590]
25 [D loss : -269.553650] [G loss: 1157.455078]
30 [D loss : -223.239746] [G loss: 1391.271729]
35 [D loss : -209.037766] [G loss: 984.831604]
40 [D loss : -210.923904] [G loss: 1015.291443]
45 [D loss : -157.022949] [G loss: 887.205933]
50 [D loss : -148.411987] [G loss: 1125.600586]
55 [D loss : -163.724930] [G loss: 806.031921]
