In [1]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [2]:
import os

# Directory where gan_train saves its periodic 5x5 sample grids.
# makedirs(exist_ok=True) is idempotent on re-run and avoids the
# check-then-create race of `if not exists: mkdir`.
os.makedirs("./gan_images", exist_ok=True)

In [3]:
# Generator: maps a 100-dim noise vector to a 28x28x1 image.
# No pooling layers; spatial size grows only via UpSampling2D (7 -> 14 -> 28).
generator = Sequential([
    # Project the noise vector to 128*7*7 features.
    Dense(128 * 7 * 7, input_dim=100, activation=LeakyReLU(0.2)),
    # Batch normalization: standardize activations, then apply a learned scale/shift.
    BatchNormalization(),
    Reshape((7, 7, 128)),          # flat features -> 7x7 map with 128 channels
    UpSampling2D(),                # 7x7 -> 14x14
    Conv2D(64, kernel_size=5, padding='same'),
    BatchNormalization(),
    Activation(LeakyReLU(0.2)),
    UpSampling2D(),                # 14x14 -> 28x28
    # tanh keeps pixel outputs in [-1, 1], the same range MNIST is rescaled to for training.
    Conv2D(1, kernel_size=5, padding='same', activation='tanh'),
])

In [4]:
# Discriminator: scores a 28x28x1 image as real (1) or generated (0).
# It is compiled and trained directly (discriminator.train_on_batch in gan_train);
# it is frozen only inside the combined `gan` model compiled afterwards.
discriminator = Sequential([
    # Strided convolutions downsample instead of pooling: 28x28 -> 14x14 -> 7x7.
    Conv2D(64, kernel_size=5, strides=2, input_shape=(28, 28, 1),
           padding='same', activation=LeakyReLU(0.2)),
    Dropout(0.3),
    Conv2D(128, kernel_size=5, strides=2, padding='same'),
    Activation(LeakyReLU(0.2)),
    Dropout(0.3),
    Flatten(),
    Dense(1, activation='sigmoid'),    # probability that the input is real
])
discriminator.compile(loss='binary_crossentropy', optimizer='adam')
# Order matters: compile first (so this model stays trainable on its own),
# then mark it non-trainable so the `gan` model compiled later treats its
# weights as frozen (Keras captures `trainable` at compile time).
discriminator.trainable = False

In [5]:
# Combined model: noise -> generator -> (frozen) discriminator -> real/fake score.
# Because the discriminator was marked non-trainable before this compile, training
# `gan` updates only the generator's weights (see Trainable params in the summary).
ginput = Input(shape=(100,))
goutput = discriminator(generator(ginput))
gan = Model(inputs=ginput, outputs=goutput)
gan.compile(optimizer='adam', loss='binary_crossentropy')
gan.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 100)]             0         
                                                                 
 sequential (Sequential)     (None, 28, 28, 1)         865281    
                                                                 
 sequential_1 (Sequential)   (None, 1)                 212865    
                                                                 
Total params: 1,078,146
Trainable params: 852,609
Non-trainable params: 225,537
_________________________________________________________________


In [6]:
# Training loop: alternate discriminator and generator updates, saving sample grids.
def _save_sample_grid(step, out_dir, latent_dim):
    """Save a 5x5 grid of generator samples to <out_dir>/gan_mnist_<step>.png."""
    import os

    noise = np.random.normal(0, 1, (25, latent_dim))
    samples = generator.predict(noise, verbose=0)
    # Undo the [-1, 1] training scale so pixels lie in [0, 1] for display.
    samples = samples * 0.5 + 0.5

    # squeeze=False guarantees a 2-D Axes array; .flat walks it row-major,
    # matching sample index 0..24.
    fig, axs = plt.subplots(5, 5, squeeze=False)
    for sample_idx, ax in enumerate(axs.flat):
        ax.imshow(samples[sample_idx, :, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig(os.path.join(out_dir, "gan_mnist_%d.png" % step))
    # Close explicitly: otherwise every grid stays open for the whole run,
    # leaking memory and tripping matplotlib's too-many-open-figures warning.
    plt.close(fig)


def gan_train(epoch, batch_size, saving_interval, out_dir="gan_images"):
    """Train the module-level `generator`/`discriminator` pair on MNIST.

    Each iteration:
      1. trains the discriminator on a random batch of real images (label 1),
      2. trains it on a batch of generated images (label 0),
      3. trains the generator through the combined `gan` model using the same
         noise with label 1 (the discriminator is frozen inside `gan`).
    Losses are printed after every iteration.

    Parameters
    ----------
    epoch : int
        Number of training iterations. Despite the name, each iteration uses a
        single random batch, not a full pass over the dataset.
    batch_size : int
        Number of real and of generated images per discriminator update.
    saving_interval : int
        Save a 5x5 sample grid every `saving_interval` iterations (including
        iteration 0). A value <= 0 disables saving.
    out_dir : str, optional
        Directory for the sample images; created if it does not exist.
    """
    import os

    latent_dim = 100  # must match the generator's input_dim

    # Only the training images are needed; labels and the test split are discarded.
    (X_train, _), (_, _) = mnist.load_data()
    X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
    # Rescale pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
    X_train = (X_train - 127.5) / 127.5

    # Target labels for real (1) and generated (0) images.
    true = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    save_samples = saving_interval > 0
    if save_samples:
        os.makedirs(out_dir, exist_ok=True)

    for i in range(epoch):
        # 1) Discriminator on a random (with replacement) batch of real images.
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        d_loss_real = discriminator.train_on_batch(imgs, true)

        # 2) Discriminator on generated images. verbose=0 keeps predict from
        #    printing a progress bar on every iteration.
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        gen_imgs = generator.predict(noise, verbose=0)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)

        # Mean discriminator loss (logging only), then 3) the generator step.
        d_loss = np.add(d_loss_real, d_loss_fake) * 0.5
        g_loss = gan.train_on_batch(noise, true)
        print('epoch:%d' % i, ' d_loss:%.4f' % d_loss, ' g_loss:%.4f' % g_loss)

        if save_samples and i % saving_interval == 0:
            _save_sample_grid(i, out_dir, latent_dim)

In [None]:
# 4001 iterations of batch size 32; save a sample grid every 200 iterations (0, 200, ..., 4000).
gan_train(epoch=4001, batch_size=32, saving_interval=200)

epoch:0  d_loss:0.6879  g_loss:0.5430
epoch:1  d_loss:0.4394  g_loss:0.1530
epoch:2  d_loss:0.4066  g_loss:0.0173
epoch:3  d_loss:0.4756  g_loss:0.0045
epoch:4  d_loss:0.5134  g_loss:0.0060
epoch:5  d_loss:0.5364  g_loss:0.0388
epoch:6  d_loss:0.4489  g_loss:0.2404
epoch:7  d_loss:0.4870  g_loss:0.4736
epoch:8  d_loss:0.5123  g_loss:0.4979
epoch:9  d_loss:0.5047  g_loss:0.4203
epoch:10  d_loss:0.4365  g_loss:0.4966
epoch:11  d_loss:0.4166  g_loss:0.7088
epoch:12  d_loss:0.3471  g_loss:1.1423
epoch:13  d_loss:0.3312  g_loss:1.7042
epoch:14  d_loss:0.3390  g_loss:2.1136
epoch:15  d_loss:0.4726  g_loss:1.7082
epoch:16  d_loss:0.5518  g_loss:1.0929
epoch:17  d_loss:0.6324  g_loss:0.8172
epoch:18  d_loss:0.5525  g_loss:1.0656
epoch:19  d_loss:0.4065  g_loss:1.1238
epoch:20  d_loss:0.4399  g_loss:0.7258
epoch:21  d_loss:0.3018  g_loss:0.5959
epoch:22  d_loss:0.2178  g_loss:0.6494
epoch:23  d_loss:0.1455  g_loss:1.0400
epoch:24  d_loss:0.1868  g_loss:0.7990
epoch:25  d_loss:0.1772  g_loss:0.5