<a href="https://colab.research.google.com/github/youngchul-sung/three-minutes-keras/blob/master/ex7_2_gan_cnn_mnist_tf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
%tensorflow_version 2.x

import math
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import models, layers, optimizers
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Conv1D, Reshape, Flatten, Lambda


# Use channels-first ordering for image tensors; the layer shapes declared
# further down, e.g. (1, 28, 28) and (128, 7, 7), are written in that order.
keras.backend.set_image_data_format('channels_first')
# Call the getter so the active format string is printed instead of the
# function object itself.
print(keras.backend.image_data_format())

################################
# GAN modeling
################################

def mse_4d(y_true, y_pred):
    """Mean squared error reduced over axes 1, 2 and 3 (Keras backend ops).

    Args:
        y_true: Target tensor.
        y_pred: Predicted tensor.

    Returns:
        The mean of the squared difference taken over axes (1, 2, 3).
    """
    backend = keras.backend
    squared_error = backend.square(y_pred - y_true)
    return backend.mean(squared_error, axis=(1, 2, 3))

def mse_4d_tf(y_true, y_pred):
    """Mean squared error reduced over axes 1, 2 and 3 (TensorFlow ops).

    Args:
        y_true: Target tensor.
        y_pred: Predicted tensor.

    Returns:
        The mean of the squared difference taken over axes (1, 2, 3).
    """
    difference = y_pred - y_true
    squared_error = tf.square(difference)
    return tf.reduce_mean(squared_error, axis=(1, 2, 3))


class GAN(models.Sequential):
    """Generator stacked on a discriminator, trained adversarially.

    The instance itself is the stacked model (generator followed by the
    discriminator) and is the model trained during the generator step.
    ``self.generator`` and ``self.discriminator`` are also compiled as
    stand-alone models.
    """

    def __init__(self, input_dim=64):
        """Build both sub-models, stack them and compile all three models.

        Args:
            input_dim: Length of the noise vector fed to the generator.
        """
        super().__init__()
        self.input_dim = input_dim

        self.generator = self.GENERATOR()
        self.discriminator = self.DISCRIMINATOR()
        self.add(self.generator)
        # The discriminator is flagged non-trainable while the stacked model
        # is assembled and compiled; compile_all() switches the flag back on
        # before compiling the discriminator on its own.
        self.discriminator.trainable = False
        self.add(self.discriminator)

        self.compile_all()

    def GENERATOR(self):
        """Create the generator network.

        Two dense layers expand the noise vector to 128 * 7 * 7 values, which
        are reshaped to (128, 7, 7) and passed through two
        upsample-by-2 + 5x5 convolution stages, ending in a single-filter
        tanh convolution.

        NOTE(review): the (128, 7, 7) reshape is written for a channels-first
        image data format - confirm the backend is configured accordingly.

        Returns:
            An uncompiled ``Sequential`` model.
        """
        input_dim = self.input_dim

        model = models.Sequential()
        model.add(layers.Dense(1024, activation='tanh', input_dim=input_dim))
        model.add(layers.Dense(128 * 7 * 7, activation='tanh'))
        model.add(layers.BatchNormalization())
        model.add(layers.Reshape((128, 7, 7), input_shape=(128 * 7 * 7,)))
        model.add(layers.UpSampling2D(size=(2, 2)))
        model.add(layers.Conv2D(64, (5, 5), padding='same', activation='tanh'))
        model.add(layers.UpSampling2D(size=(2, 2)))
        model.add(layers.Conv2D(1, (5, 5), padding='same', activation='tanh'))
        return model

    def DISCRIMINATOR(self):
        """Create the discriminator network.

        Two 5x5 convolution + 2x2 max-pooling stages are followed by a
        1024-unit dense layer and a single sigmoid output unit. The declared
        input shape is (1, 28, 28).

        Returns:
            An uncompiled ``Sequential`` model.
        """
        model = models.Sequential()
        model.add(layers.Conv2D(64, (5, 5), padding='same', activation='tanh',
                                input_shape=(1, 28, 28)))
        model.add(layers.MaxPooling2D(pool_size=(2, 2)))
        model.add(layers.Conv2D(128, (5, 5), activation='tanh'))
        model.add(layers.MaxPooling2D(pool_size=(2, 2)))
        model.add(layers.Flatten())
        model.add(layers.Dense(1024, activation='tanh'))
        model.add(layers.Dense(1, activation='sigmoid'))
        return model

    def compile_all(self):
        """Compile the generator, the stacked model and the discriminator."""
        # `learning_rate` is the supported keyword for the step size; the
        # older `lr` alias is deprecated and not accepted by newer releases.
        d_optim = optimizers.SGD(learning_rate=0.0005, momentum=0.9,
                                 nesterov=True)
        g_optim = optimizers.SGD(learning_rate=0.0005, momentum=0.9,
                                 nesterov=True)
        # Within this class the stand-alone generator is only used through
        # predict().
        self.generator.compile(loss=mse_4d_tf, optimizer="SGD")
        # The stacked model is compiled while the discriminator is still
        # flagged non-trainable (set in __init__).
        self.compile(loss='binary_crossentropy', optimizer=g_optim)
        # Re-enable the flag before compiling the discriminator by itself.
        self.discriminator.trainable = True
        self.discriminator.compile(loss='binary_crossentropy', optimizer=d_optim)

    def get_z(self, ln):
        """Draw ``ln`` noise vectors uniformly from [-1, 1).

        Args:
            ln: Number of noise vectors to draw.

        Returns:
            A NumPy array of shape ``(ln, self.input_dim)``.
        """
        input_dim = self.input_dim
        return np.random.uniform(-1, 1, (ln, input_dim))

    def train_both(self, x):
        """Run one discriminator update followed by one generator update.

        Args:
            x: Batch of real samples; its first dimension is the batch size.

        Returns:
            A ``(d_loss, g_loss)`` tuple holding the values returned by the
            two ``train_on_batch`` calls.
        """
        ln = x.shape[0]
        # Discriminator step: real samples are labelled 1, generated ones 0.
        z = self.get_z(ln)
        w = self.generator.predict(z, verbose=0)
        xw = np.concatenate((x, w))
        y2 = np.array([1] * ln + [0] * ln)
        d_loss = self.discriminator.train_on_batch(xw, y2)

        # Generator step: train the stacked model on fresh noise with every
        # target set to 1, toggling the discriminator's trainable flag off
        # around the call and back on afterwards.
        z = self.get_z(ln)
        self.discriminator.trainable = False
        g_loss = self.train_on_batch(z, np.array([1] * ln))
        self.discriminator.trainable = True

        return d_loss, g_loss


################################
# Training the GAN
################################
def combine_images(generated_images):
    """Tile a batch of images into a single 2-D grid image.

    The grid has ``int(sqrt(n))`` columns and as many rows as are needed to
    hold all ``n`` images; cells without an image stay zero. Only index 0
    along axis 1 of each image is copied into the grid.

    Args:
        generated_images: Array indexed as (image, channel, row, column).

    Returns:
        A 2-D array with the same dtype as ``generated_images``.
    """
    count = generated_images.shape[0]
    n_cols = int(math.sqrt(count))
    n_rows = int(math.ceil(float(count) / n_cols))

    spatial = generated_images.shape[2:]
    tile_h = spatial[0]
    tile_w = spatial[1]

    canvas = np.zeros((n_rows * tile_h, n_cols * tile_w),
                      dtype=generated_images.dtype)
    for position, tile in enumerate(generated_images):
        # Row-major placement: fill each grid row left to right.
        row, col = divmod(position, n_cols)
        top = row * tile_h
        left = col * tile_w
        canvas[top:top + tile_h, left:left + tile_w] = tile[0, :, :]
    return canvas


def get_x(X_train, index, BATCH_SIZE):
    """Return the ``index``-th run of ``BATCH_SIZE`` consecutive items.

    Args:
        X_train: Sliceable collection of samples.
        index: Zero-based batch number.
        BATCH_SIZE: Number of items per batch.

    Returns:
        ``X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]``.
    """
    start = index * BATCH_SIZE
    stop = start + BATCH_SIZE
    return X_train[start:stop]


def save_images(generated_images, output_fold, epoch, index):
    """Write a tiled grid of generated images to a PNG file.

    The file is saved as ``<output_fold>/<epoch>_<index>.png``.

    Args:
        generated_images: Batch of images passed to ``combine_images``.
        output_fold: Directory path (string) to write into.
        epoch: Value used as the first part of the file name.
        index: Value used as the second part of the file name.
    """
    grid = combine_images(generated_images)
    # Map -1 to 0 and +1 to 255 before the cast to uint8.
    grid = grid * 127.5 + 127.5
    file_name = str(epoch) + "_" + str(index) + ".png"
    picture = Image.fromarray(grid.astype(np.uint8))
    picture.save(output_fold + '/' + file_name)


def load_data(n_train):
    """Load the MNIST training images and keep the first ``n_train``.

    Args:
        n_train: Number of leading training images to return.

    Returns:
        The first ``n_train`` entries of the MNIST training image array.
    """
    # Only the training images are used; labels and the test split are dropped.
    (train_images, _train_labels), (_, _) = mnist.load_data()
    return train_images[:n_train]


def train(args):
    """Train a GAN on MNIST images and write samples, weights and losses.

    Sample image grids are written on every 10th epoch and on the last one.
    After training, the generator and discriminator weights and the
    per-epoch loss lists are saved inside ``args.output_fold``.

    Args:
        args: Object exposing ``batch_size``, ``epochs``, ``output_fold``,
            ``input_dim`` and ``n_train`` attributes.
    """
    BATCH_SIZE = args.batch_size
    epochs = args.epochs
    output_fold = args.output_fold
    input_dim = args.input_dim
    n_train = args.n_train

    os.makedirs(output_fold, exist_ok=True)
    print('Output_fold is', output_fold)

    X_train = load_data(n_train)

    # Shift and scale so that 0 maps to -1 and 255 maps to +1.
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    # Insert a size-1 axis right after the sample axis.
    X_train = X_train.reshape((X_train.shape[0], 1) + X_train.shape[1:])

    gan = GAN(input_dim)

    # Only full batches are used; a trailing partial batch is skipped.
    n_batches = X_train.shape[0] // BATCH_SIZE

    d_loss_ll = []
    g_loss_ll = []
    for epoch in range(epochs):
        print("Epoch is", epoch)
        print("Number of batches", n_batches)

        d_loss_l = []
        g_loss_l = []
        for index in range(n_batches):
            x = get_x(X_train, index, BATCH_SIZE)

            d_loss, g_loss = gan.train_both(x)

            d_loss_l.append(d_loss)
            g_loss_l.append(g_loss)

        if epoch % 10 == 0 or epoch == epochs - 1:
            # Sample one batch worth of images. BATCH_SIZE is used rather
            # than the last training batch's length so this step does not
            # depend on a loop variable that is unbound when there are no
            # full batches; every trained batch has exactly BATCH_SIZE items.
            z = gan.get_z(BATCH_SIZE)
            w = gan.generator.predict(z, verbose=0)
            save_images(w, output_fold, epoch, 0)

        d_loss_ll.append(d_loss_l)
        g_loss_ll.append(g_loss_l)

    gan.summary()

    gan.generator.save_weights(output_fold + '/' + 'generator', True)
    gan.discriminator.save_weights(output_fold + '/' + 'discriminator', True)

    # One row per epoch, one column per batch.
    np.savetxt(output_fold + '/' + 'd_loss', d_loss_ll)
    np.savetxt(output_fold + '/' + 'g_loss', g_loss_ll)


################################
# Running the GAN example
################################

class ARGS:
    """Settings container for a training run.

    Calling ``ARGS()`` with no arguments yields the original fixed settings;
    any of them can be overridden by keyword.

    Args:
        batch_size: Number of samples per training batch.
        epochs: Number of passes over the training data.
        output_fold: Directory that receives images, weights and losses.
        input_dim: Length of the generator's noise vector.
        n_train: Number of training images to use.
    """

    def __init__(self, batch_size=16, epochs=1000, output_fold='GAN_OUT',
                 input_dim=10, n_train=32):
        self.batch_size = batch_size
        self.epochs = epochs
        self.output_fold = output_fold
        self.input_dim = input_dim
        self.n_train = n_train

def main():
    """Run training with the default ``ARGS`` settings."""
    train(ARGS())


# Entry point: run the example when this module is executed directly.
if __name__ == '__main__':
    main()

<function image_data_format at 0x7f2fcef1e8c8>
Output_fold is GAN_OUT
Epoch is 0
Number of batches 2
Epoch is 1
Number of batches 2
Epoch is 2
Number of batches 2
Epoch is 3
Number of batches 2
Epoch is 4
Number of batches 2
Epoch is 5
Number of batches 2
Epoch is 6
Number of batches 2
Epoch is 7
Number of batches 2
Epoch is 8
Number of batches 2
Epoch is 9
Number of batches 2
Epoch is 10
Number of batches 2
Epoch is 11
Number of batches 2
Epoch is 12
Number of batches 2
Epoch is 13
Number of batches 2
Epoch is 14
Number of batches 2
Epoch is 15
Number of batches 2
Epoch is 16
Number of batches 2
Epoch is 17
Number of batches 2
Epoch is 18
Number of batches 2
Epoch is 19
Number of batches 2
Epoch is 20
Number of batches 2
Epoch is 21
Number of batches 2
Epoch is 22
Number of batches 2
Epoch is 23
Number of batches 2
Epoch is 24
Number of batches 2
Epoch is 25
Number of batches 2
Epoch is 26
Number of batches 2
Epoch is 27
Number of batches 2
Epoch is 28
Number of batches 2
Epoch is 29
