<a href="https://colab.research.google.com/github/tomneo2004/Machine-Learning-Practice/blob/master/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# References

- code: https://github.com/eriklindernoren/Keras-GAN/blob/master/gan/gan.py#L141

- paper: https://arxiv.org/abs/1406.2661

# GAN model

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from keras.models import Sequential, Model
from keras.layers import Input, Dense, BatchNormalization, Flatten, Reshape, LeakyReLU
from keras.optimizers.legacy import Adam

class GAN():
    """Fully connected GAN (Goodfellow et al., 2014) built with Keras.

    The generator maps latent noise vectors to images with values in
    [-1, 1] (tanh output). The discriminator maps images to a single
    sigmoid score: the probability that the image is real.

    Args:
        image_shape: Shape of one image, e.g. ``X_train.shape[1:]``.
        latent_dim: Length of the noise vector fed to the generator.
    """

    def __init__(self, image_shape, latent_dim=100):
        super().__init__()

        self.latent_dim = latent_dim
        self.image_shape = image_shape

        # Each compiled model gets its own Adam instance. A shared optimizer
        # would also share its iteration counter, which both models' updates
        # advance, skewing Adam's bias correction for each of them.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss="binary_crossentropy",
                                   optimizer=Adam(0.0002, 0.5),
                                   metrics=["accuracy"])

        self.generator = self.build_generator()

        # Stacked model z -> generator -> discriminator, used to train the
        # generator. The discriminator is frozen *before* this compile so
        # that `combine` only updates generator weights. This relies on
        # Keras 2 capturing `trainable` at compile time, so the discriminator
        # compiled above keeps learning -- confirm if upgrading Keras.
        z = Input(shape=(self.latent_dim,))
        image = self.generator(z)
        self.discriminator.trainable = False
        validity = self.discriminator(image)
        self.combine = Model(z, validity)
        self.combine.compile(loss="binary_crossentropy",
                             optimizer=Adam(0.0002, 0.5))

    def build_discriminator(self):
        """Return a model mapping an image to a real/fake probability."""
        model = Sequential([
            Flatten(),
            Dense(512),
            LeakyReLU(alpha=0.2),
            Dense(256),
            LeakyReLU(alpha=0.2),
            Dense(1, activation="sigmoid"),
        ])

        inputs = Input(shape=self.image_shape)
        validity = model(inputs)
        return Model(inputs, validity)

    def build_generator(self):
        """Return a model mapping a latent vector to an image in [-1, 1]."""
        model = Sequential([
            Dense(256, "relu"),
            BatchNormalization(momentum=0.8),
            Dense(512, "relu"),
            BatchNormalization(momentum=0.8),
            Dense(1024, "relu"),
            BatchNormalization(momentum=0.8),
            Dense(np.prod(self.image_shape), "tanh"),
            Reshape(self.image_shape),
        ])

        inputs = Input(shape=(self.latent_dim,))
        img = model(inputs)
        return Model(inputs, img)

    @staticmethod
    def _scale_to_tanh_range(X):
        """Map pixel values from [0, 255] to float32 in [-1, 1].

        This matches the range of the generator's tanh output.
        """
        return np.asarray(X, dtype=np.float32) / 127.5 - 1.0

    @staticmethod
    def _rescale_for_display(images):
        """Map images from [-1, 1] to [0, 1] for ``imshow``.

        A trailing channel axis of size 1 is dropped so that grayscale
        images are 2-D per sample, as ``imshow`` expects.
        """
        images = np.clip(0.5 * np.asarray(images) + 0.5, 0.0, 1.0)
        if images.ndim == 4 and images.shape[-1] == 1:
            images = images[..., 0]
        return images

    def fit(self, X, epochs, batch_size, sample_interval=200):
        """Train with one discriminator step and one generator step per epoch.

        Note: an "epoch" here is a single random mini-batch, not a full
        pass over ``X``.

        Args:
            X: Real images indexed along axis 0. Pixel values are assumed to
                be in [0, 255].
            epochs: Number of training iterations.
            batch_size: Number of real and of fake images per step.
            sample_interval: Save a grid of generated samples every this many
                iterations, including iteration 0. A falsy value disables
                sampling.
        """
        X = self._scale_to_tanh_range(X)

        # The target labels never change, so build them once.
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):
            # --- train discriminator on a real batch and a fake batch ---
            batch_idx = np.random.randint(0, X.shape[0], batch_size)
            real_imgs = X[batch_idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # verbose=0: otherwise Keras prints a progress bar on every call.
            fake_imgs = self.generator.predict(noise, verbose=0)

            valid_loss = self.discriminator.train_on_batch(real_imgs, valid)
            fake_loss = self.discriminator.train_on_batch(fake_imgs, fake)
            # Average the [loss, accuracy] pairs of the two half-steps.
            d_loss = 0.5 * np.add(valid_loss, fake_loss)

            # --- train generator: push the discriminator to call fakes real ---
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combine.train_on_batch(noise, valid)

            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]"
                  % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            if sample_interval and epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch, rows=5, cols=5):
        """Save a ``rows`` x ``cols`` grid of generated images.

        The grid is written to ``images/<epoch>.png``.
        """
        os.makedirs("images", exist_ok=True)
        noise = np.random.normal(0, 1, (rows * cols, self.latent_dim))
        gen_imgs = self._rescale_for_display(
            self.generator.predict(noise, verbose=0))
        # 2-D samples are single-channel: render them in grayscale.
        cmap = "gray" if gen_imgs.ndim == 3 else None

        fig, axs = plt.subplots(rows, cols, squeeze=False)
        for ax, img in zip(axs.flat, gen_imgs):
            ax.imshow(img, cmap=cmap)
            ax.axis("off")
        fig.savefig("images/%d.png" % epoch)
        plt.close(fig)

In [None]:
from keras.datasets.mnist import load_data

# MNIST digits via keras.datasets.mnist.load_data: ((train_x, train_y), (test_x, test_y)).
train_split, test_split = load_data()
X_train, y_train = train_split
X_test, y_test = test_split

In [None]:
# The image shape is everything after the sample axis of the training data.
gan = GAN(image_shape=X_train.shape[1:])
gan.fit(X_train, epochs=30000, batch_size=32)

0 [D loss: 0.702423, acc.: 39.06%] [G loss: 0.646549]
1 [D loss: 0.354849, acc.: 81.25%] [G loss: 0.669724]
2 [D loss: 0.336779, acc.: 82.81%] [G loss: 0.738376]
3 [D loss: 0.320447, acc.: 90.62%] [G loss: 0.789457]
4 [D loss: 0.286810, acc.: 98.44%] [G loss: 0.895287]
5 [D loss: 0.276762, acc.: 96.88%] [G loss: 0.996196]
6 [D loss: 0.219063, acc.: 98.44%] [G loss: 1.185659]
7 [D loss: 0.168526, acc.: 100.00%] [G loss: 1.323694]
8 [D loss: 0.133856, acc.: 100.00%] [G loss: 1.460893]
9 [D loss: 0.098270, acc.: 100.00%] [G loss: 1.643706]
10 [D loss: 0.101310, acc.: 100.00%] [G loss: 1.782477]
11 [D loss: 0.101777, acc.: 100.00%] [G loss: 1.968493]
12 [D loss: 0.067075, acc.: 100.00%] [G loss: 1.983751]
13 [D loss: 0.096913, acc.: 100.00%] [G loss: 2.262216]
14 [D loss: 0.071114, acc.: 100.00%] [G loss: 2.430569]
15 [D loss: 0.056285, acc.: 100.00%] [G loss: 2.553810]
16 [D loss: 0.066790, acc.: 100.00%] [G loss: 2.676374]
17 [D loss: 0.043185, acc.: 100.00%] [G loss: 2.780374]
18 [D los