In [2]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input,Dense,Reshape,Flatten,Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU, UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

import os
# Directory where gan_train() saves its sample grids ("gan_images/gan_mnist_<i>.png").
# exist_ok=True avoids the check-then-create race of a separate os.path.exists() test,
# and still raises if "./gan_images" exists but is not a directory.
os.makedirs("./gan_images", exist_ok=True)

In [3]:
# Generator: maps a 100-dim noise vector to a 28x28x1 image with a tanh output
# (values in [-1, 1], the same range the MNIST pixels are rescaled to for training).
generator = Sequential([
    # Project the noise to 7*7*128 features (LeakyReLU slope 0.2).
    Dense(128 * 7 * 7, input_dim=100, activation=LeakyReLU(0.2)),
    BatchNormalization(),
    Reshape((7, 7, 128)),
    UpSampling2D(),                                   # 7x7   -> 14x14
    Conv2D(64, kernel_size=5, padding='same'),
    BatchNormalization(),
    Activation(LeakyReLU(0.2)),
    UpSampling2D(),                                   # 14x14 -> 28x28
    Conv2D(1, kernel_size=5, padding='same', activation='tanh'),
])

In [4]:
# Show the generator's per-layer output shapes and parameter counts.
generator.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 6272)              633472    
_________________________________________________________________
batch_normalization (BatchNo (None, 6272)              25088     
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 128)         0         
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 14, 14, 64)        204864    
_________________________________________________________________
batch_normalization_1 (Batch (None, 14, 14, 64)        256       
_________________________________________________________________
activation (Activation)      (None, 14, 14, 64)        0

In [5]:
# Build the discriminator model: it scores a 28x28x1 image with a sigmoid output,
# trained toward 1 for real MNIST images and 0 for generated ones.
discriminator = Sequential([
    # Two stride-2 convolutions downsample 28x28 -> 14x14 -> 7x7.
    Conv2D(64, kernel_size=5, strides=2, input_shape=(28, 28, 1), padding="same"),
    Activation(LeakyReLU(0.2)),
    Dropout(0.3),
    Conv2D(128, kernel_size=5, strides=2, padding="same"),
    Activation(LeakyReLU(0.2)),
    Dropout(0.3),
    Flatten(),
    Dense(1, activation='sigmoid'),
])
discriminator.compile(optimizer='adam', loss='binary_crossentropy')
# Freeze only AFTER compiling: the combined `gan` model compiled later then treats
# the discriminator as non-trainable, so GAN steps update only the generator, while
# discriminator.train_on_batch keeps training it.
# NOTE(review): this relies on tf.keras honoring the trainable state captured at
# compile() time — re-verify if moving to Keras 3.
discriminator.trainable = False

In [6]:
# Show the discriminator's per-layer output shapes and parameter counts.
discriminator.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_2 (Conv2D)            (None, 14, 14, 64)        1664      
_________________________________________________________________
activation_1 (Activation)    (None, 14, 14, 64)        0         
_________________________________________________________________
dropout (Dropout)            (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 7, 7, 128)         204928    
_________________________________________________________________
activation_2 (Activation)    (None, 7, 7, 128)         0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 7, 7, 128)         0         
_________________________________________________________________
flatten (Flatten)            (None, 6272)             

In [7]:
# Combined model used for generator updates: noise -> generator -> discriminator.
# The discriminator was set non-trainable before this compile, so training `gan`
# adjusts only the generator (see the trainable/non-trainable counts in the summary).
ginput = Input(shape=(100,))
fake_image = generator(ginput)
dis_output = discriminator(fake_image)
gan = Model(inputs=ginput, outputs=dis_output)
gan.compile(optimizer='adam', loss='binary_crossentropy')
gan.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
sequential (Sequential)      (None, 28, 28, 1)         865281    
_________________________________________________________________
sequential_1 (Sequential)    (None, 1)                 212865    
Total params: 1,078,146
Trainable params: 852,609
Non-trainable params: 225,537
_________________________________________________________________


In [8]:
def _save_sample_grid(step, rows=5, cols=5):
    """Save a rows x cols grid of generator samples to gan_images/gan_mnist_<step>.png."""
    noise = np.random.normal(0, 1, (rows * cols, 100))
    gen_imgs = generator.predict(noise, verbose=0)
    # Undo the [-1, 1] scaling so pixel values fall in [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5
    # squeeze=False keeps `axs` a 2-D array even for a 1-row or 1-column grid.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    # axs.flat walks the grid row by row, matching sample index `count`.
    for count, ax in enumerate(axs.flat):
        ax.imshow(gen_imgs[count, :, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig("gan_images/gan_mnist_%d.png" % step)
    # Release the figure: otherwise every grid stays open in memory and matplotlib
    # warns once more than 20 figures are open. Grids are therefore no longer
    # rendered inline at the end of the cell; view the saved PNGs instead.
    plt.close(fig)


def gan_train(epoch, batch_size, saving_interval):
    """Train the GAN on MNIST, alternating discriminator and generator updates.

    Each of the ``epoch`` iterations draws ONE random mini-batch (it is an
    iteration count, not a full pass over the dataset). The discriminator is
    trained on ``batch_size`` real images labelled 1 and ``batch_size``
    generated images labelled 0. Then the generator is trained through the
    combined ``gan`` model with target 1.

    Args:
        epoch: number of training iterations to run.
        batch_size: number of real (and of generated) images per iteration.
        saving_interval: every ``saving_interval`` iterations (including
            iteration 0) a 5x5 sample grid is written to
            ``gan_images/gan_mnist_<i>.png``; 0 disables saving.

    Uses the module-level ``generator``, ``discriminator`` and ``gan`` models
    built in earlier cells.
    """
    (X_train, _), (_, _) = mnist.load_data()
    X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
    # Subtract half of 255 and divide by that half: maps [0, 255] to [-1, 1],
    # the range of the generator's tanh output.
    X_train = (X_train - 127.5) / 127.5

    true = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for i in range(epoch):
        # Discriminator step on a random batch of real images (target 1).
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        d_loss_real = discriminator.train_on_batch(imgs, true)

        # Discriminator step on generated images (target 0).
        # Standard-normal noise (mean 0, std 1): one 100-dim vector per sample.
        noise = np.random.normal(0, 1, (batch_size, 100))
        # verbose=0 stops newer TF versions printing a progress bar on every call.
        gen_imgs = generator.predict(noise, verbose=0)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)

        # Reported discriminator loss is the mean of the real and fake losses.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Generator step: the discriminator is frozen inside `gan`, so only the
        # generator's weights change; target 1 pushes it to fool the discriminator.
        g_loss = gan.train_on_batch(noise, true)
        print('epoch:%d' % i, 'd_loss:%.4f' % d_loss, 'g_loss:%.4f' % g_loss)

        # Guard against saving_interval == 0 (would raise ZeroDivisionError).
        if saving_interval and i % saving_interval == 0:
            _save_sample_grid(i)

In [None]:
# Run 4001 single-batch iterations of 32 images, saving a sample grid every 200.
gan_train(epoch=4001, batch_size=32, saving_interval=200)

epoch:0 d_loss:0.6970 g_loss:0.5471
epoch:1 d_loss:0.5308 g_loss:0.2316
epoch:2 d_loss:0.4808 g_loss:0.0801
epoch:3 d_loss:0.4624 g_loss:0.0467
epoch:4 d_loss:0.4716 g_loss:0.0481
epoch:5 d_loss:0.4487 g_loss:0.0898
epoch:6 d_loss:0.4461 g_loss:0.1838
epoch:7 d_loss:0.4569 g_loss:0.2870
epoch:8 d_loss:0.4346 g_loss:0.3519
epoch:9 d_loss:0.4449 g_loss:0.3775
epoch:10 d_loss:0.4407 g_loss:0.3696
epoch:11 d_loss:0.4261 g_loss:0.3629
epoch:12 d_loss:0.3936 g_loss:0.3284
epoch:13 d_loss:0.3803 g_loss:0.2598
epoch:14 d_loss:0.3646 g_loss:0.1964
epoch:15 d_loss:0.3065 g_loss:0.1390
epoch:16 d_loss:0.2338 g_loss:0.0949
epoch:17 d_loss:0.1655 g_loss:0.0483
epoch:18 d_loss:0.1153 g_loss:0.0218
epoch:19 d_loss:0.0994 g_loss:0.0047
epoch:20 d_loss:0.0465 g_loss:0.0069
epoch:21 d_loss:0.0683 g_loss:0.0202
epoch:22 d_loss:0.0845 g_loss:0.0041
epoch:23 d_loss:0.0759 g_loss:0.0150
epoch:24 d_loss:0.0533 g_loss:0.0674
epoch:25 d_loss:0.3739 g_loss:0.0013
epoch:26 d_loss:0.1688 g_loss:0.1867
epoch:27 d_

epoch:219 d_loss:0.4870 g_loss:1.4724
epoch:220 d_loss:0.4660 g_loss:1.4176
epoch:221 d_loss:0.5457 g_loss:1.2049
epoch:222 d_loss:0.3892 g_loss:1.5006
epoch:223 d_loss:0.4587 g_loss:1.5464
epoch:224 d_loss:0.3777 g_loss:1.3438
epoch:225 d_loss:0.3765 g_loss:1.6921
epoch:226 d_loss:0.4688 g_loss:1.5606
epoch:227 d_loss:0.3774 g_loss:1.6866
epoch:228 d_loss:0.3342 g_loss:1.5966
epoch:229 d_loss:0.3390 g_loss:1.7813
epoch:230 d_loss:0.5011 g_loss:1.4753
epoch:231 d_loss:0.4711 g_loss:1.4696
epoch:232 d_loss:0.4732 g_loss:1.5174
epoch:233 d_loss:0.4127 g_loss:1.8008
epoch:234 d_loss:0.3902 g_loss:2.0005
epoch:235 d_loss:0.4371 g_loss:1.8254
epoch:236 d_loss:0.4770 g_loss:1.7360
epoch:237 d_loss:0.5043 g_loss:1.3978
epoch:238 d_loss:0.4377 g_loss:1.7050
epoch:239 d_loss:0.5011 g_loss:1.9574
epoch:240 d_loss:0.5340 g_loss:1.7111
epoch:241 d_loss:0.5267 g_loss:1.9219
epoch:242 d_loss:0.5542 g_loss:1.9355
epoch:243 d_loss:0.5908 g_loss:2.1422
epoch:244 d_loss:0.6772 g_loss:2.0549
epoch:245 d_

epoch:435 d_loss:0.4228 g_loss:1.6287
epoch:436 d_loss:0.3236 g_loss:1.9863
epoch:437 d_loss:0.3093 g_loss:2.5221
epoch:438 d_loss:0.2935 g_loss:2.6309
epoch:439 d_loss:0.3514 g_loss:2.2546
epoch:440 d_loss:0.2952 g_loss:2.3919
epoch:441 d_loss:0.3340 g_loss:1.6679
epoch:442 d_loss:0.2855 g_loss:1.6346
epoch:443 d_loss:0.2839 g_loss:2.0789
epoch:444 d_loss:0.2910 g_loss:1.9345
epoch:445 d_loss:0.2961 g_loss:2.3942
epoch:446 d_loss:0.2421 g_loss:2.5871
epoch:447 d_loss:0.2460 g_loss:2.8735
epoch:448 d_loss:0.3094 g_loss:2.7801
epoch:449 d_loss:0.3143 g_loss:2.4031
epoch:450 d_loss:0.3223 g_loss:1.8867
epoch:451 d_loss:0.3041 g_loss:1.4652
epoch:452 d_loss:0.3531 g_loss:1.8164
epoch:453 d_loss:0.3979 g_loss:2.3040
epoch:454 d_loss:0.2694 g_loss:2.9867
epoch:455 d_loss:0.2620 g_loss:2.8551
epoch:456 d_loss:0.2293 g_loss:2.7118
epoch:457 d_loss:0.3187 g_loss:1.9680
epoch:458 d_loss:0.2121 g_loss:2.2807
epoch:459 d_loss:0.2538 g_loss:2.2678
epoch:460 d_loss:0.2164 g_loss:2.2404
epoch:461 d_

epoch:651 d_loss:0.3572 g_loss:2.3043
epoch:652 d_loss:0.3057 g_loss:2.2975
epoch:653 d_loss:0.4028 g_loss:2.4803
epoch:654 d_loss:0.5570 g_loss:2.7617
epoch:655 d_loss:0.4773 g_loss:2.2470
epoch:656 d_loss:0.9949 g_loss:1.9432
epoch:657 d_loss:1.2043 g_loss:1.0019
epoch:658 d_loss:0.7036 g_loss:1.2841
epoch:659 d_loss:0.8112 g_loss:1.3150
epoch:660 d_loss:0.6648 g_loss:1.1918
epoch:661 d_loss:1.0406 g_loss:1.3953
epoch:662 d_loss:0.7415 g_loss:0.8558
epoch:663 d_loss:0.9412 g_loss:0.5370
epoch:664 d_loss:0.9760 g_loss:0.6722
epoch:665 d_loss:0.8440 g_loss:1.0427
epoch:666 d_loss:1.0032 g_loss:1.3983
epoch:667 d_loss:0.9169 g_loss:1.5947
epoch:668 d_loss:0.6648 g_loss:1.7785
epoch:669 d_loss:0.4239 g_loss:1.8742
epoch:670 d_loss:0.3294 g_loss:2.2644
epoch:671 d_loss:0.4616 g_loss:2.3759
epoch:672 d_loss:0.5379 g_loss:1.9630
epoch:673 d_loss:0.3867 g_loss:2.0203
epoch:674 d_loss:0.6261 g_loss:1.9250
epoch:675 d_loss:1.0198 g_loss:1.8111
epoch:676 d_loss:0.9954 g_loss:1.6565
epoch:677 d_