- Reference
  - Blog: https://work-in-progress.hatenablog.com/entry/2019/04/06/113629
  - Source: https://github.com/eriklindernoren/Keras-GAN/blob/master/gan/gan.py
  - Source: https://github.com/eriklindernoren/Keras-GAN/pull/117
    - Added functionality to save/load Keras model for intermittent training.

In [2]:
import os

In [3]:
# Make sure the output folders for sample images and saved models exist.
for output_dir in ('data/images', 'data/saved_models'):
    os.makedirs(output_dir, exist_ok=True)

In [4]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.models import Sequential, Model, model_from_json
from tensorflow.keras.optimizers import Adam

In [5]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import datetime

In [6]:
class GAN:
    """Fully connected GAN for 28x28x1 images with save/load of its models.

    Typical use: create an instance, call either ``init()`` (fresh models)
    or ``load()`` (models previously written by ``save_models()``), then
    call ``train()``.

    Attributes:
        history: DataFrame with one row per training iteration and the
            columns ``d_loss``, ``acc`` and ``g_loss``.
        discriminator: Compiled model mapping an image to a validity score.
        generator: Model mapping a latent vector to an image.
        combined: Compiled generator -> discriminator stack used to train
            the generator.
    """

    def __init__(self):
        self.history = pd.DataFrame({}, columns=['d_loss', 'acc', 'g_loss'])

        # Output locations and the base file names used by save/load.
        self.img_save_dir = 'data/images'
        self.model_save_dir = 'data/saved_models'
        self.discriminator_name = 'discriminator_model'
        self.generator_name = 'generator_model'
        self.combined_name = 'combined_model'

        # Populated by init() or load().
        self.discriminator = None
        self.generator = None
        self.combined = None

        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

    @staticmethod
    def _make_optimizer():
        """Return a new Adam optimizer (learning rate 0.0002, beta_1 0.5)."""
        # A fresh instance per compiled model: newer Keras releases appear to
        # reject an optimizer that was already built for another model's
        # variables, and separate instances work on older releases too.
        return Adam(0.0002, 0.5)

    def _compile_discriminator(self):
        """Compile ``self.discriminator`` for stand-alone training."""
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=self._make_optimizer(),
                                   metrics=['accuracy'])

    def _build_combined(self):
        """Stack ``self.generator`` and ``self.discriminator`` and compile.

        The stack references the very same generator and discriminator
        objects held by this instance, so training the returned model
        updates ``self.generator`` in place.

        Returns:
            The compiled combined model (latent vector -> validity).
        """
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # trains the generator to fool the discriminator
        combined = Model(z, validity)
        combined.summary()
        combined.compile(loss='binary_crossentropy',
                         optimizer=self._make_optimizer())
        return combined

    def init(self):
        """Build and compile brand-new discriminator, generator and stack."""
        self.discriminator = self.build_discriminator()
        self._compile_discriminator()

        self.generator = self.build_generator()

        self.combined = self._build_combined()

    def load(self):
        """Restore the models written by ``save_models()`` and compile them.

        Only the discriminator and generator are read from disk. The
        combined model is rebuilt from those two objects instead of being
        deserialised from its own JSON: deserialising it would create
        separate copies of the nested models, so further training would no
        longer update ``self.generator``.
        """
        self.discriminator = self.load_model(self.discriminator_name)
        self._compile_discriminator()
        self.discriminator.summary()

        self.generator = self.load_model(self.generator_name)
        self.generator.summary()

        self.combined = self._build_combined()

    def build_generator(self):
        """Create the generator network.

        Returns:
            An uncompiled model mapping a ``latent_dim`` vector to an image
            of shape ``img_shape`` with tanh-activated values.
        """
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(int(np.prod(self.img_shape)), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Create the discriminator network.

        Returns:
            An uncompiled model mapping an image of shape ``img_shape`` to a
            single sigmoid output.
        """
        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def _record_history(self, d_loss, acc, g_loss):
        """Append one row of training metrics to ``self.history``.

        Uses ``.loc`` enlargement because ``DataFrame.append`` no longer
        exists in pandas 2.x. Values are listed in the column order set up
        in ``__init__`` (``d_loss``, ``acc``, ``g_loss``).
        """
        self.history.loc[len(self.history)] = [
            float(d_loss), float(acc), float(g_loss),
        ]

    def train(self, epochs, batch_size=128, sample_interval=50, save_interval=50):
        """Train on the MNIST training images.

        Each "epoch" here is a single iteration: one random batch for the
        discriminator (real and generated) followed by one generator update.

        Args:
            epochs: Number of iterations to run.
            batch_size: Number of images per batch.
            sample_interval: Print metrics and write a sample image grid
                every this many iterations (including iteration 0).
            save_interval: Save all models every this many iterations
                (iteration 0 is skipped).

        Raises:
            RuntimeError: If neither ``init()`` nor ``load()`` has been
                called yet.
        """
        # Fail early, before downloading the dataset, with a clear message.
        if self.discriminator is None or self.generator is None or self.combined is None:
            raise RuntimeError(
                'Models are not built; call init() or load() before train().')

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()
        print(X_train.shape)

        # Rescale -1 to 1 (matches the generator's tanh output range)
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)
        print(X_train.shape)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        print(datetime.datetime.now().isoformat(), 'Epoch Start')

        for epoch in range(epochs):
            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator; each call returns [loss, accuracy],
            # and the two results are averaged element-wise.
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Train the generator: it is rewarded when the discriminator
            # labels its images as valid.
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid)

            self._record_history(d_loss[0], d_loss[1], g_loss)

            if epoch % sample_interval == 0:
                print (datetime.datetime.now().isoformat(), '%d [D loss: %f, acc.: %.2f%%] [G loss: %f]' % (epoch, d_loss[0], 100*d_loss[1], g_loss))
                self.sample_images(epoch)

            if epoch != 0 and epoch % save_interval == 0:
                self.save_models()

        print(datetime.datetime.now().isoformat(), 'Epoch End')

    def sample_images(self, epoch):
        """Save a 5x5 grid of generated images as ``<epoch>.png``.

        Args:
            epoch: Iteration number, used as the image file name inside
                ``img_save_dir``.
        """
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0

        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1

        file_name = '{}.png'.format(epoch)
        file_path = os.path.join(self.img_save_dir, file_name)
        fig.savefig(file_path)

        # Close this specific figure rather than whichever one is current.
        plt.close(fig)

    def plot_history(self, columns=None):
        """Plot the recorded training metrics.

        Args:
            columns: Names of ``history`` columns to plot. ``None`` or an
                empty sequence plots ``d_loss``, ``acc`` and ``g_loss``.
        """
        if columns is None or len(columns) == 0:
            columns = ['d_loss', 'acc', 'g_loss']
        self.history[columns].plot()

    # Backward-compatible alias for the original, misspelt method name.
    plot_hisotry = plot_history

    def save_models(self):
        """Save the discriminator, generator and combined models."""
        self.save_model(self.discriminator, self.discriminator_name)
        self.save_model(self.generator, self.generator_name)
        self.save_model(self.combined, self.combined_name)

    def save_model(self, model, model_name):
        """Write a model's architecture (JSON) and weights (HDF5) to disk.

        Args:
            model: Keras model to save.
            model_name: Base file name; ``.json`` and ``.h5`` are appended
                and the files are placed in ``model_save_dir``.
        """
        json_path = os.path.join(self.model_save_dir, '{}.json'.format(model_name))
        weights_path = os.path.join(self.model_save_dir, '{}.h5'.format(model_name))

        with open(json_path, 'w') as f:
            f.write(model.to_json())

        model.save_weights(weights_path)

        print (datetime.datetime.now().isoformat(), 'Model saved.', model_name)

    def load_model(self, model_name):
        """Rebuild a model from the files written by ``save_model``.

        Args:
            model_name: Base file name used when the model was saved.

        Returns:
            The uncompiled model with its saved weights loaded.
        """
        json_path = os.path.join(self.model_save_dir, '{}.json'.format(model_name))
        weights_path = os.path.join(self.model_save_dir, '{}.h5'.format(model_name))

        with open(json_path, 'r') as f:
            loaded_json = f.read()

        model = model_from_json(loaded_json)
        model.load_weights(weights_path)

        return model

In [7]:
# Create the GAN wrapper; no models exist until init() or load() is called.
gan = GAN()

In [8]:
#gan.load()

In [9]:
# Build and compile fresh discriminator, generator and combined models
# (call gan.load() instead to restore previously saved ones).
gan.init()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 512)               401920    
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 257       
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
__________________________________________________

In [10]:
# Run 100 training iterations with batches of 32 images; print metrics and
# write a sample image grid every 10 iterations, and save the models every
# 10 iterations (except iteration 0).
gan.train(epochs=100, batch_size=32, sample_interval=10, save_interval=10)

(60000, 28, 28)
(60000, 28, 28, 1)
2020-07-20T16:09:20.328264 Epoch Start
2020-07-20T16:09:21.737633 0 [D loss: 0.666037, acc.: 70.31%] [G loss: 0.702898]
2020-07-20T16:09:22.805699 10 [D loss: 0.118922, acc.: 100.00%] [G loss: 1.886346]
2020-07-20T16:09:23.372645 Model saved. discriminator_model
2020-07-20T16:09:23.390075 Model saved. generator_model
2020-07-20T16:09:23.410323 Model saved. combined_model
2020-07-20T16:09:23.853307 20 [D loss: 0.034616, acc.: 100.00%] [G loss: 2.851344]
2020-07-20T16:09:24.308319 Model saved. discriminator_model
2020-07-20T16:09:24.326179 Model saved. generator_model
2020-07-20T16:09:24.346452 Model saved. combined_model
2020-07-20T16:09:24.784131 30 [D loss: 0.026591, acc.: 100.00%] [G loss: 3.409873]
2020-07-20T16:09:25.351566 Model saved. discriminator_model
2020-07-20T16:09:25.370032 Model saved. generator_model
2020-07-20T16:09:25.390199 Model saved. combined_model
2020-07-20T16:09:25.827469 40 [D loss: 0.014306, acc.: 100.00%] [G loss: 3.863801]


In [11]:
# Write the current discriminator, generator and combined models to
# data/saved_models (JSON architecture + HDF5 weights for each).
gan.save_models()

2020-07-20T16:09:46.480285 Model saved. discriminator_model
2020-07-20T16:09:46.496689 Model saved. generator_model
2020-07-20T16:09:46.517935 Model saved. combined_model
