In [31]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Dense, Dropout, Flatten, Reshape, Input, BatchNormalization, Activation, ZeroPadding2D, MaxPooling2D, AveragePooling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from tensorflow.keras.optimizers import Adam
from keras.models import Model, Sequential

import matplotlib.pyplot as plt
import numpy as np
import sys

In [32]:
# Size of the random noise vector fed to the generator.
latent_dim = 100

# Image geometry: height, width, channels (MNIST is 28x28 greyscale).
nh, nw, nc = 28, 28, 1
img_shape = (nh, nw, nc)

# Adam with learning rate 2e-4 and beta_1 = 0.5 (the DCGAN paper's settings).
opt = Adam(learning_rate=0.0002, beta_1=0.5)

In [33]:
def build_generator():
  """Build the DCGAN generator: latent noise vector -> synthetic image.

  A Dense layer projects the ``latent_dim``-dimensional noise to a
  (nh // 4, nw // 4, 128) feature map. Two UpSampling2D + Conv2D stages then
  double its height/width twice, back to (nh, nw). A final tanh keeps pixel
  values in [-1, 1], the same range ``train`` rescales MNIST into.

  Returns:
    keras Model mapping an input of shape (latent_dim,) to an image of shape
    (nh, nw, nc).

  Raises:
    ValueError: if nh or nw is not divisible by 4 (the two 2x upsampling
      stages could not reproduce the target size exactly).
  """
  if nh % 4 or nw % 4:
    raise ValueError("nh and nw must be divisible by 4, got %d x %d" % (nh, nw))

  # Two 2x upsampling stages follow, so start at a quarter of the target size
  # (7x7 for 28x28 MNIST).
  h0, w0 = nh // 4, nw // 4

  model = Sequential()

  model.add(Dense(128 * h0 * w0, activation="relu", input_dim=latent_dim))
  model.add(Reshape((h0, w0, 128)))

  # "Deconvolution" step: UpSampling2D doubles height/width by repeating rows
  # and columns; the Conv2D after it learns to refine the enlarged map.
  model.add(UpSampling2D())
  model.add(Conv2D(128, kernel_size=3, padding="same"))
  model.add(BatchNormalization(momentum=0.8))
  model.add(Activation("relu"))
  model.add(UpSampling2D())
  model.add(Conv2D(64, kernel_size=3, padding="same"))
  model.add(BatchNormalization(momentum=0.8))
  model.add(Activation("relu"))
  model.add(Conv2D(nc, kernel_size=3, padding="same"))
  model.add(Activation("tanh"))

  # summary() prints the table itself and returns None, so it is not wrapped
  # in print() (that would emit a stray "None" line).
  model.summary()

  noise = Input(shape=(latent_dim,))
  img = model(noise)

  return Model(inputs=noise, outputs=img)

In [34]:
# Instantiate the generator (noise -> image); build_generator also prints the
# inner Sequential model's layer summary.
generator = build_generator()

In [35]:
# Summary of the outer functional Model wrapping the generator's Sequential stack.
generator.summary()

In [36]:
def build_discriminator():
  """Build the DCGAN discriminator: image -> probability that it is real.

  Four stride-2 Conv2D stages ('same' padding) with LeakyReLU and dropout
  downsample the input. BatchNorm is used on all but the first stage. For a
  28x28 input the spatial size goes 28 -> 14 -> 7 (zero-padded to 8) -> 4 -> 2,
  and the result is flattened into a single sigmoid unit.

  Returns:
    keras Model mapping an input of shape img_shape to a (batch, 1) output in
    (0, 1).
  """
  model = Sequential()

  model.add(Conv2D(32, kernel_size = 3, strides = 2, input_shape = img_shape, padding = 'same'))
  model.add(LeakyReLU(alpha = 0.2))
  model.add(Dropout(0.25))
  model.add(Conv2D(64, kernel_size = 3, strides = 2, padding = 'same'))
  # Pad one row at the bottom and one column on the right (7x7 -> 8x8 for
  # MNIST) so the next two stride-2 stages halve the size evenly.
  model.add(ZeroPadding2D(padding = ((0,1), (0, 1))))
  model.add(BatchNormalization(momentum = 0.8))
  model.add(LeakyReLU(alpha = 0.2))
  model.add(Dropout(0.25))
  model.add(Conv2D(128, kernel_size = 3, strides = 2, padding = 'same'))
  model.add(BatchNormalization(momentum = 0.8))
  model.add(LeakyReLU(alpha = 0.2))
  model.add(Dropout(0.25))
  model.add(Conv2D(256, kernel_size = 3, strides = 2, padding = 'same'))
  model.add(BatchNormalization(momentum = 0.8))
  model.add(LeakyReLU(alpha = 0.2))
  model.add(Dropout(0.25))
  model.add(Flatten())
  model.add(Dense(1, activation = 'sigmoid'))

  # summary() prints the table itself and returns None, so it is not wrapped
  # in print() (that would emit a stray "None" line).
  model.summary()

  img = Input(shape = img_shape)
  validity = model(img)

  return Model(inputs = img, outputs = validity)

In [37]:
discriminator = build_discriminator()
discriminator.summary()

# Give the discriminator its own optimizer instance, cloned from `opt` so the
# hyperparameters stay defined in one place. Sharing one optimizer object with
# `combined` fails on newer Keras optimizers (TF >= 2.11 / Keras 3 build the
# optimizer for one set of variables and reject the other). On older versions
# it silently makes both models share a single Adam step counter.
discriminator.compile(loss = 'binary_crossentropy',
                      optimizer = Adam.from_config(opt.get_config()),
                      metrics = ['accuracy'])

In [38]:
# Stacked model used for the generator update: noise -> generator -> discriminator.
z = Input(shape = (latent_dim,))
img = generator(z)

# Freeze the discriminator inside `combined` so generator updates do not also
# change the discriminator's weights. The standalone `discriminator` was
# compiled before this flag was flipped, and the pattern relies on Keras
# capturing `trainable` at compile time so it keeps training in its own
# train_on_batch calls.
# NOTE(review): confirm this holds for the installed Keras version.
discriminator.trainable = False
valid = discriminator(img)

combined = Model(z, valid)
# Separate optimizer instance (cloned from `opt`): one optimizer object must
# not be shared between `discriminator` and `combined`. Newer Keras rejects
# variables it was not built with, and older Keras shares the step counter.
combined.compile(loss = 'binary_crossentropy', optimizer = Adam.from_config(opt.get_config()))
combined.summary()

In [39]:
# summary() prints directly and returns None; wrapping it in print() only adds
# a stray "None" line after each table.
generator.summary()
discriminator.summary()
combined.summary()

In [40]:
def train(batch_size = 128, epochs = 4000, log_interval = 1):
  """Train the GAN on MNIST with alternating discriminator/generator updates.

  Note: each "epoch" here is one training *iteration* on a single randomly
  sampled batch, not a full pass over the MNIST training set.

  Uses the module-level ``generator``, ``discriminator`` and ``combined``
  models and updates their weights via ``train_on_batch``.

  Args:
    batch_size: number of real images (and, separately, of generated images)
      per discriminator update; also the generator batch size.
    epochs: number of iterations to run.
    log_interval: print progress every ``log_interval`` iterations and on the
      final one. The default of 1 logs every iteration.

  Raises:
    ValueError: if log_interval < 1.
  """
  if log_interval < 1:
    raise ValueError("log_interval must be >= 1, got %r" % (log_interval,))

  (X_train, _), (_, _) = mnist.load_data()

  # Scale uint8 pixels [0, 255] to [-1, 1], matching the generator's tanh
  # output. Casting first avoids the float64 array NumPy would otherwise build.
  X_train = X_train.astype(np.float32) / 127.5 - 1.
  # Add the trailing channel axis expected by Conv2D.
  X_train = np.expand_dims(X_train, axis=-1)
  print(X_train.shape)

  valid = np.ones((batch_size, 1))
  fakes = np.zeros((batch_size, 1))

  for epoch in range(epochs):
    # Sample a random batch of real images (with replacement).
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    imgs = X_train[idx]

    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    # verbose=0: newer Keras prints a progress bar on every predict() call,
    # which would flood the output once per iteration.
    gen_imgs = generator.predict(noise, verbose=0)

    # Training Discriminator: real batch labelled 1, generated batch labelled 0.
    d_loss_real = discriminator.train_on_batch(imgs, valid)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fakes)
    # Average [loss, accuracy] over the real and fake halves.
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Training Generator: label generated images as real so the generator is
    # pushed to fool the (frozen-in-`combined`) discriminator.
    g_loss = combined.train_on_batch(noise, valid)

    # Progress
    if epoch % log_interval == 0 or epoch == epochs - 1:
      print("epoch: %s D_Loss = %s acc: %s G_Loss = %s"
            % (epoch, d_loss[0], d_loss[1] * 100, g_loss))

In [None]:
# 5000 single-batch iterations (not full passes over MNIST), 128 images per batch.
train(batch_size=128, epochs=5000)

In [None]:
def save_imgs(epoch, r = 5, c = 5):
  """Save an r x c grid of generator samples to ``mnist_<epoch>.png``.

  Args:
    epoch: integer label used only to name the output file.
    r: number of grid rows (default 5).
    c: number of grid columns (default 5).
  """
  noise = np.random.normal(0, 1, (r * c, latent_dim))
  gen_imgs = generator.predict(noise, verbose=0)

  # Rescale the generator's tanh output from [-1, 1] to [0, 1].
  gen_imgs = 0.5 * gen_imgs + 0.5

  # squeeze=False keeps `axs` a 2-D array even when r or c is 1.
  fig, axs = plt.subplots(r, c, squeeze=False)
  # axs.flat walks the grid row by row, matching the sample order.
  for ax, sample in zip(axs.flat, gen_imgs):
    # Fixed vmin/vmax: every tile shares the same [0, 1] grey scale instead of
    # imshow's per-image autoscaling (which would make the rescale pointless).
    ax.imshow(sample[:, :, 0], cmap='gray', vmin=0, vmax=1)
    ax.axis('off')
  fig.savefig("mnist_%d.png" % epoch)
  # Close this specific figure (not just the "current" one) to free its memory.
  plt.close(fig)

In [None]:
# Label the sample grid with the number of training iterations run above.
save_imgs(epoch=5000)