<a href="https://colab.research.google.com/github/trevormoon/GAN_study/blob/main/GAN_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import pandas as pd
import numpy as np
import keras
import keras.backend as K
from keras.layers import Conv2D, Activation, Dropout, Flatten, Dense, BatchNormalization, Reshape, UpSampling2D, Input
from keras.models import Model
from keras.optimizers import RMSprop
from keras.preprocessing.image import array_to_img
import tensorflow as tf
import warnings ; warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
from tqdm import tqdm

In [5]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
X_train, x_test = x_train / 255.0, x_test / 255.0
print(X_train.shape)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
(60000, 28, 28)


In [6]:
disc_input = Input(shape=(28, 28, 1))

x = Conv2D(filters=64, kernel_size=5, strides=2, padding='same')(disc_input)
x = Activation('relu')(x)
x = Dropout(rate=0.4)(x)

x = Conv2D(filters = 64, kernel_size=5, strides=2, padding='same')(x)
x = Activation('relu')(x)
x = Dropout(rate=0.4)(x)

x = Conv2D(filters=128, kernel_size=5, strides=2, padding='same')(x)
x = Activation('relu')(x)
x = Dropout(rate=0.4)(x)

x = Conv2D(filters=128, kernel_size=5, strides=1, padding='same')(x)
x = Activation('relu')(x)
x = Dropout(rate=0.4)(x)

x = Flatten()(x)
disc_output = Dense(units=1, activation='sigmoid', kernel_initializer='he_normal')(x)

discriminator = Model(disc_input, disc_output)
discriminator.summary()



Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 conv2d (Conv2D)             (None, 14, 14, 64)        1664      
                                                                 
 activation (Activation)     (None, 14, 14, 64)        0         
                                                                 
 dropout (Dropout)           (None, 14, 14, 64)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 7, 7, 64)          102464    
                                                                 
 activation_1 (Activation)   (None, 7, 7, 64)          0         
                                                                 
 dropout_1 (Dropout)         (None, 7, 7, 64)          0     

In [7]:
gen_dense_size=(7, 7, 64)

gen_input = Input(shape = (100, ))
x = Dense(units=np.prod(gen_dense_size))(gen_input)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Reshape(gen_dense_size)(x)

x = UpSampling2D()(x)
x = Conv2D(filters=128, kernel_size=5, padding='same', strides=1)(x)
x = BatchNormalization(momentum=0.9)(x)
x = Activation('relu')(x)

x = UpSampling2D()(x)
x = Conv2D(filters = 64, kernel_size=5, padding='same', strides=1)(x)
x = BatchNormalization(momentum=0.9)(x)
x = Activation('relu')(x)

x = Conv2D(filters=64, kernel_size=5, padding='same', strides=1)(x)
x = BatchNormalization(momentum=0.9)(x)
x = Activation('relu')(x)

x = Conv2D(filters=1, kernel_size=5, padding='same', strides=1)(x)
gen_output = Activation('sigmoid')(x)

generator = Model(gen_input, gen_output)
generator.summary()



Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 100)]             0         
                                                                 
 dense_1 (Dense)             (None, 3136)              316736    
                                                                 
 batch_normalization (Batch  (None, 3136)              12544     
 Normalization)                                                  
                                                                 
 activation_4 (Activation)   (None, 3136)              0         
                                                                 
 reshape (Reshape)           (None, 7, 7, 64)          0         
                                                                 
 up_sampling2d (UpSampling2  (None, 14, 14, 64)        0         
 D)                                                        

In [8]:
discriminator.compile(optimizer=RMSprop(lr=0.0008), loss='binary_crossentropy', metrics=['accuracy'])
discriminator.trainable = False
model_input = Input(shape=(100, ))
model_output = discriminator(generator(model_input))
model = Model(model_input, model_output)

model.compile(optimizer=RMSprop(lr=0.0004), loss='binary_crossentropy', metrics=['accuracy'])





In [None]:
def train_discriminator(x_train, batch_size):
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    idx = np.random.randint(0, len(X_train), batch_size)
    true_imgs = X_train[idx]
    discriminator.fit(true_imgs, valid, verbose=0)

    noise = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(noise)

    discriminator.fit(gen_imgs, fake, verbose=0)

def train_generator(batch_size):
    valid = np.ones((batch_size, 1))
    noise = np.random.normal(0, 1, (batch_size, 100))
    model.fit(noise, valid, verbose=1)

for epoch in tqdm(range(2000)):
    train_discriminator(X_train, 64)
    train_generator(64)

  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 1/2000 [00:08<4:30:04,  8.11s/it]



  0%|          | 2/2000 [00:12<3:11:44,  5.76s/it]



  0%|          | 3/2000 [00:19<3:30:33,  6.33s/it]



  0%|          | 4/2000 [00:23<3:00:25,  5.42s/it]



  0%|          | 5/2000 [00:29<3:12:46,  5.80s/it]



  0%|          | 6/2000 [00:33<2:49:37,  5.10s/it]



  0%|          | 7/2000 [00:37<2:36:25,  4.71s/it]



  0%|          | 8/2000 [00:42<2:37:36,  4.75s/it]



  0%|          | 9/2000 [00:48<2:55:35,  5.29s/it]



  0%|          | 10/2000 [01:03<4:37:30,  8.37s/it]



  1%|          | 11/2000 [01:09<4:12:39,  7.62s/it]



  1%|          | 12/2000 [01:13<3:33:22,  6.44s/it]



  1%|          | 13/2000 [01:18<3:13:23,  5.84s/it]



  1%|          | 14/2000 [01:22<3:02:17,  5.51s/it]



  1%|          | 15/2000 [01:26<2:45:45,  5.01s/it]



  1%|          | 16/2000 [01:33<3:00:15,  5.45s/it]



  1%|          | 17/2000 [01:37<2:45:08,  5.00s/it]



  1%|          | 18/2000 [01:40<2:33:27,  4.65s/it]



  1%|          | 19/2000 [01:47<2:52:09,  5.21s/it]



  1%|          | 20/2000 [01:51<2:37:26,  4.77s/it]



  1%|          | 21/2000 [01:57<2:54:20,  5.29s/it]



  1%|          | 22/2000 [02:01<2:41:31,  4.90s/it]



  1%|          | 23/2000 [02:05<2:30:53,  4.58s/it]



  1%|          | 24/2000 [02:12<2:56:42,  5.37s/it]



  1%|▏         | 25/2000 [02:16<2:43:27,  4.97s/it]



  1%|▏         | 26/2000 [02:23<2:58:13,  5.42s/it]



  1%|▏         | 27/2000 [02:27<2:43:31,  4.97s/it]



  1%|▏         | 28/2000 [02:33<2:58:20,  5.43s/it]



  1%|▏         | 29/2000 [02:40<3:08:52,  5.75s/it]



  2%|▏         | 30/2000 [02:46<3:17:15,  6.01s/it]



  2%|▏         | 31/2000 [02:50<2:57:25,  5.41s/it]



  2%|▏         | 32/2000 [02:54<2:42:18,  4.95s/it]



  2%|▏         | 33/2000 [02:59<2:40:40,  4.90s/it]



  2%|▏         | 34/2000 [03:05<2:54:56,  5.34s/it]



  2%|▏         | 35/2000 [03:10<2:50:05,  5.19s/it]



  2%|▏         | 36/2000 [03:14<2:37:49,  4.82s/it]



  2%|▏         | 37/2000 [03:18<2:27:30,  4.51s/it]



  2%|▏         | 38/2000 [03:24<2:46:58,  5.11s/it]



  2%|▏         | 39/2000 [03:28<2:35:10,  4.75s/it]



  2%|▏         | 40/2000 [03:35<2:53:17,  5.30s/it]



  2%|▏         | 41/2000 [03:39<2:45:42,  5.08s/it]



  2%|▏         | 42/2000 [03:43<2:32:31,  4.67s/it]



  2%|▏         | 43/2000 [03:50<2:52:45,  5.30s/it]



  2%|▏         | 44/2000 [03:54<2:39:00,  4.88s/it]



  2%|▏         | 45/2000 [04:00<2:54:10,  5.35s/it]



  2%|▏         | 46/2000 [04:04<2:40:19,  4.92s/it]



  2%|▏         | 47/2000 [04:08<2:29:56,  4.61s/it]



  2%|▏         | 48/2000 [04:15<2:56:15,  5.42s/it]



  2%|▏         | 49/2000 [04:19<2:41:04,  4.95s/it]



  2%|▎         | 50/2000 [04:26<2:57:08,  5.45s/it]



  3%|▎         | 51/2000 [04:30<2:39:49,  4.92s/it]



  3%|▎         | 52/2000 [04:33<2:28:18,  4.57s/it]



  3%|▎         | 53/2000 [04:40<2:51:50,  5.30s/it]



  3%|▎         | 54/2000 [04:44<2:39:09,  4.91s/it]



  3%|▎         | 55/2000 [04:49<2:37:18,  4.85s/it]



  3%|▎         | 56/2000 [04:55<2:51:20,  5.29s/it]



  3%|▎         | 57/2000 [05:02<3:01:05,  5.59s/it]



  3%|▎         | 58/2000 [05:05<2:43:57,  5.07s/it]



  3%|▎         | 59/2000 [05:10<2:34:10,  4.77s/it]



  3%|▎         | 60/2000 [05:16<2:52:53,  5.35s/it]



  3%|▎         | 61/2000 [05:20<2:37:40,  4.88s/it]



  3%|▎         | 62/2000 [05:24<2:32:24,  4.72s/it]



  3%|▎         | 63/2000 [05:30<2:42:08,  5.02s/it]



  3%|▎         | 64/2000 [05:34<2:32:33,  4.73s/it]



  3%|▎         | 65/2000 [05:40<2:48:12,  5.22s/it]



  3%|▎         | 66/2000 [05:44<2:33:56,  4.78s/it]



  3%|▎         | 67/2000 [05:48<2:25:26,  4.51s/it]



  3%|▎         | 68/2000 [05:55<2:48:46,  5.24s/it]



  3%|▎         | 69/2000 [05:59<2:35:57,  4.85s/it]



  4%|▎         | 70/2000 [06:05<2:51:13,  5.32s/it]



  4%|▎         | 71/2000 [06:09<2:36:10,  4.86s/it]



  4%|▎         | 72/2000 [06:13<2:24:19,  4.49s/it]



  4%|▎         | 73/2000 [06:18<2:26:20,  4.56s/it]



  4%|▎         | 74/2000 [06:21<2:18:18,  4.31s/it]



  4%|▍         | 75/2000 [06:25<2:15:46,  4.23s/it]



  4%|▍         | 76/2000 [06:32<2:38:20,  4.94s/it]

In [None]:
original=array_to_img(X_train[0])
plt.imshow(original, cmap='gray')

In [None]:
random_noise=np.random.normal(0, 1, (1, 100))
gen_result=generator.predict(random_noise)
gen_img=array_to_img(gen_result[0])
plt.imshow(gen_img, cmap='gray')