In [3]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys

import numpy as np

class GAN:
    """Fully connected GAN that learns to generate MNIST digits.

    The generator maps a ``z_dim``-dimensional standard-normal noise vector
    to a ``img_shape`` image with values in [-1, 1] (tanh output). The
    discriminator maps an image to a single sigmoid score, trained so that
    real images score 1 and generated ones score 0. ``self.combined`` stacks
    generator -> frozen discriminator and is the model used to update the
    generator.
    """

    def __init__(self):
        # Input image size for MNIST data.
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Dimensionality of the latent noise vector.
        self.z_dim = 100

        # Adam hyper-parameters (learning rate, beta_1). Each compiled model
        # gets its OWN optimizer instance: sharing one couples their internal
        # state (e.g. the step counter used for bias correction), and recent
        # tf.keras versions refuse to reuse an optimizer across models.
        learning_rate, beta_1 = 0.0002, 0.5

        # Discriminator model, trained directly on real / generated batches.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=Adam(learning_rate, beta_1),
            metrics=['accuracy'])

        # Generator model. It is never trained on its own (only through
        # self.combined), so it does not need to be compiled.
        self.generator = self.build_generator()

        # Generator -> discriminator stack used to train the generator.
        # build_combined2() builds an equivalent model with the functional API.
        self.combined = self.build_combined1()
        self.combined.compile(loss='binary_crossentropy',
            optimizer=Adam(learning_rate, beta_1))

    def build_generator(self):
        """Build the generator: noise of shape (z_dim,) -> image of img_shape.

        Three Dense/LeakyReLU/BatchNorm stages (256, 512, 1024 units), then a
        tanh Dense layer with prod(img_shape) units reshaped to img_shape.
        """
        noise_shape = (self.z_dim,)
        model = Sequential()

        model.add(Dense(256, input_shape=noise_shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # tanh keeps outputs in [-1, 1], matching the scaling applied to the
        # real training images in train().
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        return model

    def build_discriminator(self):
        """Build the discriminator: image of img_shape -> sigmoid score of shape (1,)."""
        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        return model

    def build_combined1(self):
        """Stack generator -> discriminator with the Sequential API.

        The discriminator is marked non-trainable *after* it was compiled on
        its own. The already-compiled discriminator keeps updating its weights,
        while this combined model only updates the generator. Keras warns
        about this ("Discrepancy between trainable weights..."), which is
        expected here.
        """
        self.discriminator.trainable = False
        model = Sequential([self.generator, self.discriminator])
        return model

    def build_combined2(self):
        """Functional-API equivalent of build_combined1 (z -> generator -> frozen discriminator)."""
        z = Input(shape=(self.z_dim,))
        img = self.generator(z)
        self.discriminator.trainable = False
        valid = self.discriminator(img)
        model = Model(z, valid)
        model.summary()
        return model

    def train(self, epochs, batch_size=128, save_interval=50):
        """Train the GAN on MNIST.

        Each "epoch" here is a single training iteration on one mini-batch,
        not a full pass over the dataset.

        Args:
            epochs: number of training iterations to run.
            batch_size: batch size for the generator step. The discriminator
                sees ``batch_size // 2`` real images and as many generated ones.
            save_interval: save a grid of generated samples every this many
                iterations (starting at iteration 0).

        Raises:
            ValueError: if ``batch_size < 2`` (the half-batches would be empty)
                or ``save_interval <= 0``.
        """
        if batch_size < 2:
            raise ValueError("batch_size must be at least 2, got %r" % (batch_size,))
        if save_interval <= 0:
            raise ValueError("save_interval must be positive, got %r" % (save_interval,))

        # Load the MNIST training images (labels are not needed).
        (X_train, _), (_, _) = mnist.load_data()

        # Scale pixel values from [0, 255] to [-1, 1] to match the generator's
        # tanh output, and add a trailing channel axis.
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

        half_batch = batch_size // 2

        # Target labels are constant across iterations, so build them once.
        # Shapes match the discriminator's (batch, 1) output.
        real_y = np.ones((half_batch, 1))
        fake_y = np.zeros((half_batch, 1))
        # The generator is trained to make the discriminator output "real" (1).
        valid_y = np.ones((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train the discriminator
            # ---------------------

            # Half a batch of generated images.
            noise = np.random.normal(0, 1, (half_batch, self.z_dim))
            gen_imgs = self.generator.predict(noise)

            # Half a batch of real images sampled (with replacement) from MNIST.
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]

            # Real and generated batches are trained in separate steps; the
            # reported loss/accuracy is the mean of the two.
            d_loss_real = self.discriminator.train_on_batch(imgs, real_y)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake_y)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train the generator (through the frozen discriminator)
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.z_dim))
            g_loss = self.combined.train_on_batch(noise, valid_y)

            # Progress report.
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # Save a sample grid at the requested interval.
            if epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch, out_dir="images/gan"):
        """Generate a 5x5 grid of samples and save it as ``out_dir/mnist_<epoch>.png``.

        Args:
            epoch: iteration number, used in the file name.
            out_dir: output directory; created if it does not exist.
        """
        import os

        # Rows and columns of the sample grid.
        r, c = 5, 5

        noise = np.random.normal(0, 1, (r * c, self.z_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale generated images from [-1, 1] to [0, 1] for display.
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        try:
            # axs.flat walks the grid row by row, matching sample order.
            for cnt, ax in enumerate(axs.flat):
                ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                ax.axis('off')
            os.makedirs(out_dir, exist_ok=True)
            fig.savefig(os.path.join(out_dir, "mnist_%d.png" % epoch))
        finally:
            # Close this specific figure so figures do not accumulate over a
            # long training run, even if saving fails.
            plt.close(fig)


if __name__ == '__main__':
    # Build the generator, discriminator and combined models, then train for
    # 1000 iterations with batches of 32, saving a sample grid every 100.
    gan = GAN()
    gan.train(
        epochs=1000,
        batch_size=32,
        save_interval=100,
    )

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_3 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_15 (Dense)             (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_11 (LeakyReLU)   (None, 512)               0         
_________________________________________________________________
dense_16 (Dense)             (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_12 (LeakyReLU)   (None, 256)               0         
_________________________________________________________________
dense_17 (Dense)             (None, 1)                 257       
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
_________________________________________________________________
____

  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.629129, acc.: 40.62%] [G loss: 0.622638]


  'Discrepancy between trainable weights and collected trainable'


1 [D loss: 0.376094, acc.: 75.00%] [G loss: 0.724654]
2 [D loss: 0.338856, acc.: 81.25%] [G loss: 0.865592]
3 [D loss: 0.324815, acc.: 78.12%] [G loss: 0.947941]
4 [D loss: 0.242271, acc.: 93.75%] [G loss: 1.189906]
5 [D loss: 0.239594, acc.: 96.88%] [G loss: 1.244903]
6 [D loss: 0.213583, acc.: 90.62%] [G loss: 1.446320]
7 [D loss: 0.139553, acc.: 100.00%] [G loss: 1.530020]
8 [D loss: 0.147485, acc.: 100.00%] [G loss: 1.693536]
9 [D loss: 0.124042, acc.: 100.00%] [G loss: 1.816324]
10 [D loss: 0.086729, acc.: 100.00%] [G loss: 1.815099]
11 [D loss: 0.101408, acc.: 100.00%] [G loss: 1.914283]
12 [D loss: 0.109316, acc.: 100.00%] [G loss: 2.033732]
13 [D loss: 0.070656, acc.: 100.00%] [G loss: 2.156601]
14 [D loss: 0.093805, acc.: 100.00%] [G loss: 2.088259]
15 [D loss: 0.084493, acc.: 100.00%] [G loss: 2.261597]
16 [D loss: 0.070523, acc.: 100.00%] [G loss: 2.325712]
17 [D loss: 0.072979, acc.: 100.00%] [G loss: 2.390721]
18 [D loss: 0.060172, acc.: 100.00%] [G loss: 2.423348]
19 [D l

149 [D loss: 0.369417, acc.: 84.38%] [G loss: 3.878480]
150 [D loss: 1.354254, acc.: 50.00%] [G loss: 3.614594]
151 [D loss: 0.111090, acc.: 100.00%] [G loss: 3.629623]
152 [D loss: 0.551473, acc.: 71.88%] [G loss: 2.358260]
153 [D loss: 0.204672, acc.: 87.50%] [G loss: 2.813787]
154 [D loss: 0.175048, acc.: 93.75%] [G loss: 3.734856]
155 [D loss: 0.476932, acc.: 75.00%] [G loss: 3.013251]
156 [D loss: 0.115857, acc.: 96.88%] [G loss: 3.677480]
157 [D loss: 0.249564, acc.: 93.75%] [G loss: 2.640384]
158 [D loss: 0.285134, acc.: 90.62%] [G loss: 3.629907]
159 [D loss: 0.166087, acc.: 96.88%] [G loss: 3.961783]
160 [D loss: 0.249619, acc.: 87.50%] [G loss: 3.017440]
161 [D loss: 0.326613, acc.: 84.38%] [G loss: 3.377290]
162 [D loss: 0.410821, acc.: 81.25%] [G loss: 3.552097]
163 [D loss: 0.176907, acc.: 100.00%] [G loss: 2.808477]
164 [D loss: 0.230296, acc.: 90.62%] [G loss: 3.191915]
165 [D loss: 0.366565, acc.: 84.38%] [G loss: 2.191305]
166 [D loss: 0.232115, acc.: 87.50%] [G loss: 

298 [D loss: 0.744286, acc.: 40.62%] [G loss: 0.725923]
299 [D loss: 0.767985, acc.: 43.75%] [G loss: 0.729696]
300 [D loss: 0.614529, acc.: 53.12%] [G loss: 0.950136]
301 [D loss: 0.773065, acc.: 37.50%] [G loss: 0.844122]
302 [D loss: 0.761313, acc.: 43.75%] [G loss: 0.724131]
303 [D loss: 0.873984, acc.: 25.00%] [G loss: 0.602031]
304 [D loss: 0.769646, acc.: 37.50%] [G loss: 0.653271]
305 [D loss: 0.754203, acc.: 37.50%] [G loss: 0.680697]
306 [D loss: 0.703333, acc.: 43.75%] [G loss: 0.798347]
307 [D loss: 0.856223, acc.: 21.88%] [G loss: 0.671659]
308 [D loss: 0.722055, acc.: 46.88%] [G loss: 0.708060]
309 [D loss: 0.832800, acc.: 31.25%] [G loss: 0.623030]
310 [D loss: 0.773714, acc.: 40.62%] [G loss: 0.622650]
311 [D loss: 0.715605, acc.: 50.00%] [G loss: 0.667351]
312 [D loss: 0.729464, acc.: 40.62%] [G loss: 0.720230]
313 [D loss: 0.767680, acc.: 43.75%] [G loss: 0.653711]
314 [D loss: 0.721303, acc.: 37.50%] [G loss: 0.655999]
315 [D loss: 0.730025, acc.: 37.50%] [G loss: 0.

446 [D loss: 0.664754, acc.: 53.12%] [G loss: 0.672377]
447 [D loss: 0.648059, acc.: 56.25%] [G loss: 0.661273]
448 [D loss: 0.667011, acc.: 46.88%] [G loss: 0.672075]
449 [D loss: 0.678079, acc.: 56.25%] [G loss: 0.679965]
450 [D loss: 0.643066, acc.: 53.12%] [G loss: 0.704520]
451 [D loss: 0.647444, acc.: 65.62%] [G loss: 0.716516]
452 [D loss: 0.668329, acc.: 53.12%] [G loss: 0.716248]
453 [D loss: 0.655367, acc.: 65.62%] [G loss: 0.718661]
454 [D loss: 0.642480, acc.: 71.88%] [G loss: 0.698003]
455 [D loss: 0.672721, acc.: 56.25%] [G loss: 0.687448]
456 [D loss: 0.678614, acc.: 56.25%] [G loss: 0.681117]
457 [D loss: 0.645864, acc.: 62.50%] [G loss: 0.688395]
458 [D loss: 0.698021, acc.: 50.00%] [G loss: 0.680559]
459 [D loss: 0.664455, acc.: 50.00%] [G loss: 0.689848]
460 [D loss: 0.664225, acc.: 46.88%] [G loss: 0.680749]
461 [D loss: 0.658230, acc.: 56.25%] [G loss: 0.681917]
462 [D loss: 0.645291, acc.: 59.38%] [G loss: 0.678644]
463 [D loss: 0.646046, acc.: 46.88%] [G loss: 0.

594 [D loss: 0.678114, acc.: 50.00%] [G loss: 0.713923]
595 [D loss: 0.625403, acc.: 75.00%] [G loss: 0.721043]
596 [D loss: 0.632130, acc.: 65.62%] [G loss: 0.690931]
597 [D loss: 0.651274, acc.: 56.25%] [G loss: 0.670592]
598 [D loss: 0.664446, acc.: 50.00%] [G loss: 0.683542]
599 [D loss: 0.649318, acc.: 56.25%] [G loss: 0.709596]
600 [D loss: 0.663244, acc.: 53.12%] [G loss: 0.688307]
601 [D loss: 0.677596, acc.: 50.00%] [G loss: 0.707338]
602 [D loss: 0.660056, acc.: 62.50%] [G loss: 0.691089]
603 [D loss: 0.691200, acc.: 62.50%] [G loss: 0.729094]
604 [D loss: 0.690358, acc.: 46.88%] [G loss: 0.803655]
605 [D loss: 0.638366, acc.: 75.00%] [G loss: 0.802583]
606 [D loss: 0.666099, acc.: 65.62%] [G loss: 0.729439]
607 [D loss: 0.671577, acc.: 56.25%] [G loss: 0.753449]
608 [D loss: 0.649054, acc.: 65.62%] [G loss: 0.748643]
609 [D loss: 0.740645, acc.: 34.38%] [G loss: 0.736203]
610 [D loss: 0.637874, acc.: 56.25%] [G loss: 0.727694]
611 [D loss: 0.673338, acc.: 53.12%] [G loss: 0.

741 [D loss: 0.675362, acc.: 50.00%] [G loss: 0.689016]
742 [D loss: 0.675856, acc.: 53.12%] [G loss: 0.682107]
743 [D loss: 0.700160, acc.: 46.88%] [G loss: 0.727668]
744 [D loss: 0.624783, acc.: 65.62%] [G loss: 0.765933]
745 [D loss: 0.669664, acc.: 50.00%] [G loss: 0.759584]
746 [D loss: 0.695063, acc.: 43.75%] [G loss: 0.723350]
747 [D loss: 0.677984, acc.: 46.88%] [G loss: 0.712702]
748 [D loss: 0.666948, acc.: 59.38%] [G loss: 0.736302]
749 [D loss: 0.665341, acc.: 53.12%] [G loss: 0.713291]
750 [D loss: 0.649577, acc.: 59.38%] [G loss: 0.721857]
751 [D loss: 0.666139, acc.: 53.12%] [G loss: 0.743199]
752 [D loss: 0.628134, acc.: 62.50%] [G loss: 0.770421]
753 [D loss: 0.700577, acc.: 53.12%] [G loss: 0.722671]
754 [D loss: 0.692345, acc.: 40.62%] [G loss: 0.700320]
755 [D loss: 0.667044, acc.: 59.38%] [G loss: 0.730281]
756 [D loss: 0.671875, acc.: 53.12%] [G loss: 0.735020]
757 [D loss: 0.668140, acc.: 59.38%] [G loss: 0.711637]
758 [D loss: 0.663997, acc.: 62.50%] [G loss: 0.

891 [D loss: 0.664319, acc.: 59.38%] [G loss: 0.779497]
892 [D loss: 0.625675, acc.: 62.50%] [G loss: 0.756995]
893 [D loss: 0.589943, acc.: 75.00%] [G loss: 0.738217]
894 [D loss: 0.652587, acc.: 59.38%] [G loss: 0.752003]
895 [D loss: 0.634320, acc.: 71.88%] [G loss: 0.770870]
896 [D loss: 0.602519, acc.: 68.75%] [G loss: 0.791279]
897 [D loss: 0.642360, acc.: 56.25%] [G loss: 0.756295]
898 [D loss: 0.632413, acc.: 62.50%] [G loss: 0.763950]
899 [D loss: 0.608939, acc.: 71.88%] [G loss: 0.721839]
900 [D loss: 0.657096, acc.: 56.25%] [G loss: 0.734435]
901 [D loss: 0.649951, acc.: 68.75%] [G loss: 0.762222]
902 [D loss: 0.598403, acc.: 81.25%] [G loss: 0.765035]
903 [D loss: 0.675492, acc.: 50.00%] [G loss: 0.788597]
904 [D loss: 0.602254, acc.: 68.75%] [G loss: 0.849429]
905 [D loss: 0.644332, acc.: 56.25%] [G loss: 0.773429]
906 [D loss: 0.607097, acc.: 71.88%] [G loss: 0.789918]
907 [D loss: 0.631685, acc.: 56.25%] [G loss: 0.815407]
908 [D loss: 0.645749, acc.: 62.50%] [G loss: 0.