In [51]:
import tensorflow as tf
from keras.datasets import mnist
from keras.layers import Dense, Flatten, Reshape, Input, BatchNormalization
from keras.layers.activation import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt


In [52]:
img_r = 28
img_c = 28
channels = 1
image_shape = (img_r, img_c, channels)

In [53]:
def generator_build():
    noise_shape = (100,)

    model = Sequential()
    model.add(Dense(256, input_shape = noise_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum = 0.8))
    model.add(Dense(512, input_shape = noise_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum = 0.8))
    model.add(Dense(1024, input_shape = noise_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum = 0.8))
    model.add(Dense(np.prod(image_shape), activation='tanh'))
    model.add(Reshape(image_shape))

    noise = Input(shape=noise_shape)
    img = model(noise)

    return Model(noise, img)

In [54]:
def discriminator_build():

    model = Sequential()
    model.add(Flatten(input_shape=image_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation = 'sigmoid'))

    img = Input(shape=image_shape)
    validity = model(img)

    return Model(img, validity)

In [55]:
optimizer = Adam(0.0002, 0.5)

discriminator = discriminator_build()
discriminator.compile(loss = 'binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

generator = generator_build()
generator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

z = Input(shape=(100,))
img = generator(z)

discriminator.trainable = False

valid = discriminator(img)

combined = Model(z, valid)
combined.compile(loss = 'binary_crossentropy', optimizer=optimizer)

In [56]:
def save_images(epoch):
    r,c = 5,5
    noise = np.random.normal(0,1,(r*c,100))
    gen_imgs = generator.predict(noise)

    gen_imgs = 0.5*gen_imgs+0.5

    fig, axs = plt.subplots(r,c)
    cnt = 0

    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[cnt,:,:,0], cmap='gray')
            axs[i,j].axis('off')
    fig.savefig("./image/mnist_%d.png"% epoch)
    plt.close()

In [60]:
def train(epochs, batch_size=128, save_interval=500):
    (x_train, _), (_, _) = mnist.load_data()

    x_train = (x_train.astype(np.float32) - 127.5) / 127.5

    x_train = np.expand_dims(x_train, axis=3)

    half_batch = int(batch_size/2)

    for i in range(epochs):
        idx = np.random.randint(0, x_train.shape[0], half_batch)
        imgs = x_train[idx]

        noise = np.random.normal(0, 1, (half_batch, 100))

        gen_img = generator.predict(noise)

        d_loss_real = discriminator.train_on_batch(imgs, np.ones((half_batch,1)))
        d_loss_fake = discriminator.train_on_batch(gen_img, np.zeros((half_batch,1)))

        d_loss = 0.5*np.add(d_loss_real, d_loss_fake)


        noise = np.random.normal(0, 1, (batch_size, 100))
        valid_y = np.array([1]*batch_size)
        g_loss = combined.train_on_batch(noise, valid_y)

        if i % save_interval == 0:
            save_images(i)

train(epochs=1000, batch_size=32, save_interval=10)




In [58]:
print(combined)

<keras.engine.functional.Functional object at 0x000002B9DCBF5B70>


In [59]:
(x_train, x_test), (y_train, y_test) = mnist.load_data()

print(len(x_train))

60000
