In [3]:
# from __future__ import print_function, division

import os
import sys

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Convolution2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import numpy as np

class GAN():
    """Deep-convolutional GAN that learns to generate 28x28 MNIST digits.

    The generator maps a ``z_dim``-dimensional standard-normal noise vector
    to a ``(28, 28, 1)`` image with values in [-1, 1] (tanh output).  The
    discriminator maps an image to a single sigmoid output, the probability
    that the image is real.  ``combined`` stacks generator -> frozen
    discriminator and is used to train the generator.
    """

    def __init__(self):
        # Input image size for MNIST data.
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Dimensionality of the latent (noise) vector.
        self.z_dim = 100

        # Discriminator model; trained directly on real/fake batches.
        # Each compiled model gets its own Adam instance: sharing a single
        # optimizer object between two models makes them share optimizer
        # state (e.g. Adam's iteration counter used for bias correction),
        # and newer Keras versions may refuse to reuse one optimizer across
        # different variable sets.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=Adam(0.0002, 0.5),
                                   metrics=['accuracy'])

        # Generator model. It is never trained on its own (only through
        # ``combined``), so it does not need to be compiled.
        self.generator = self.build_generator()

        # Generator -> frozen discriminator. build_combined2() builds the
        # same graph with the functional API and can be used instead.
        self.combined = self.build_combined1()
        self.combined.compile(loss='binary_crossentropy',
                              optimizer=Adam(0.0002, 0.5))

    def build_generator(self):
        """Build the generator: noise ``(z_dim,)`` -> image ``(28, 28, 1)``.

        Two dense layers project the noise to a 7x7x128 feature map, which
        is upsampled twice (7 -> 14 -> 28) with 5x5 convolutions in between.
        The final tanh keeps outputs in [-1, 1], matching the scaling that
        ``train`` applies to the real MNIST images.
        """
        noise_shape = (self.z_dim,)
        model = Sequential()
        model.add(Dense(1024, input_shape=noise_shape))
        model.add(BatchNormalization())
        model.add(Activation('relu'))
        model.add(Dense(128 * 7 * 7))
        model.add(BatchNormalization())
        model.add(Activation('relu'))
        # input_shape only matters on the first layer of a Sequential model,
        # so it is not repeated here.
        model.add(Reshape((7, 7, 128)))
        model.add(UpSampling2D((2, 2)))  # 7x7 -> 14x14
        model.add(Convolution2D(64, 5, 5, border_mode='same'))
        model.add(BatchNormalization())
        model.add(Activation('relu'))
        model.add(UpSampling2D((2, 2)))  # 14x14 -> 28x28
        model.add(Convolution2D(self.channels, 5, 5, border_mode='same'))
        model.add(Activation('tanh'))
        model.summary()
        return model

    def build_discriminator(self):
        """Build the discriminator: image ``(28, 28, 1)`` -> P(real) in [0, 1]."""
        model = Sequential()
        model.add(Convolution2D(64, 5, 5, subsample=(2, 2),
                                border_mode='same', input_shape=self.img_shape))
        model.add(LeakyReLU(0.2))
        # No padding argument here, so this conv uses Keras' default ('valid')
        # padding rather than 'same'.
        model.add(Convolution2D(128, 5, 5, subsample=(2, 2)))
        model.add(LeakyReLU(0.2))
        model.add(Flatten())
        model.add(Dense(256))
        model.add(LeakyReLU(0.2))
        model.add(Dropout(0.5))
        model.add(Dense(1))
        model.add(Activation('sigmoid'))
        return model

    def build_combined1(self):
        """Stack generator and discriminator with Sequential (D frozen).

        ``trainable`` is set to False only after ``self.discriminator`` has
        been compiled, so the discriminator still updates in its own
        ``train_on_batch`` but not when the combined model is trained.
        Keras warns about the mismatch ("Discrepancy between trainable
        weights and collected trainable"); that warning is expected here.
        """
        self.discriminator.trainable = False
        model = Sequential([self.generator, self.discriminator])
        return model

    def build_combined2(self):
        """Functional-API equivalent of ``build_combined1`` (D frozen)."""
        z = Input(shape=(self.z_dim,))
        img = self.generator(z)
        self.discriminator.trainable = False
        valid = self.discriminator(img)
        model = Model(z, valid)
        model.summary()
        return model

    def train(self, epochs, batch_size=128, save_interval=50):
        """Train the GAN with alternating discriminator / generator updates.

        Args:
            epochs: number of training iterations. Each iteration is one
                discriminator update on ``batch_size // 2`` real plus
                ``batch_size // 2`` generated images, followed by one
                generator update on ``batch_size`` noise vectors (it is not
                a full pass over the dataset).
            batch_size: generator batch size; must be at least 2 so the
                discriminator half-batches are non-empty.
            save_interval: save a sample grid every this many iterations;
                a value <= 0 disables saving.

        Raises:
            ValueError: if ``batch_size`` is smaller than 2.
        """
        if batch_size < 2:
            raise ValueError("batch_size must be at least 2, got %d" % batch_size)

        # Load MNIST (labels are not needed).
        (X_train, _), (_, _) = mnist.load_data()

        # Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh.
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        # Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
        X_train = np.expand_dims(X_train, axis=3)

        half_batch = batch_size // 2

        # Target arrays are constant across iterations, so build them once.
        real_labels = np.ones((half_batch, 1))
        fake_labels = np.zeros((half_batch, 1))
        # The generator is trained to make D label its samples as real (1).
        valid_y = np.ones((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train the discriminator
            # ---------------------

            # Half of the batch is generated by the generator...
            noise = np.random.normal(0, 1, (half_batch, self.z_dim))
            gen_imgs = self.generator.predict(noise)

            # ...and half is sampled (with replacement) from the real data.
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]

            # Real and fake samples are trained as separate batches.
            d_loss_real = self.discriminator.train_on_batch(imgs, real_labels)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake_labels)
            # Average [loss, accuracy] over the two batches.
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train the generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.z_dim))
            g_loss = self.combined.train_on_batch(noise, valid_y)

            # Progress report.
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]"
                  % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # Save generated samples at the requested interval.
            if save_interval > 0 and epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch, out_dir="images/gcgan"):
        """Save a 5x5 grid of generated samples to ``out_dir/mnist_<epoch>.png``.

        ``out_dir`` is created if it does not exist, so saving works on a
        fresh checkout without the directory being prepared by hand.
        """
        # Number of rows and columns in the sample grid.
        r, c = 5, 5

        noise = np.random.normal(0, 1, (r * c, self.z_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale generated images from [-1, 1] to [0, 1] for display.
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        # axs.flat walks the grid row by row, matching the sample order.
        for k, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[k, :, :, 0], cmap='gray')
            ax.axis('off')

        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, "mnist_%d.png" % epoch))
        # Close this specific figure so figures do not accumulate across calls.
        plt.close(fig)


if __name__ == '__main__':
    # Build the three models, then run 1000 training iterations with a
    # batch of 32, writing a sample grid every 100 iterations.
    gan = GAN()
    gan.train(1000, save_interval=100, batch_size=32)



_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_7 (Dense)              (None, 1024)              103424    
_________________________________________________________________
batch_normalization_4 (Batch (None, 1024)              4096      
_________________________________________________________________
activation_7 (Activation)    (None, 1024)              0         
_________________________________________________________________
dense_8 (Dense)              (None, 6272)              6428800   
_________________________________________________________________
batch_normalization_5 (Batch (None, 6272)              25088     
_________________________________________________________________
activation_8 (Activation)    (None, 6272)              0         
_________________________________________________________________
reshape_2 (Reshape)          (None, 7, 7, 128)         0         
__________

  'Discrepancy between trainable weights and collected trainable'
  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.679790, acc.: 43.75%] [G loss: 0.602608]


  'Discrepancy between trainable weights and collected trainable'


1 [D loss: 0.723139, acc.: 50.00%] [G loss: 0.535291]
2 [D loss: 0.749040, acc.: 50.00%] [G loss: 0.588050]
3 [D loss: 0.609804, acc.: 50.00%] [G loss: 0.640486]
4 [D loss: 0.637524, acc.: 53.12%] [G loss: 0.558427]
5 [D loss: 0.475876, acc.: 100.00%] [G loss: 0.452320]
6 [D loss: 0.259969, acc.: 100.00%] [G loss: 0.306431]
7 [D loss: 0.134273, acc.: 100.00%] [G loss: 0.152187]
8 [D loss: 0.704814, acc.: 50.00%] [G loss: 0.140937]
9 [D loss: 1.128959, acc.: 50.00%] [G loss: 0.374712]
10 [D loss: 0.662062, acc.: 50.00%] [G loss: 0.679359]
11 [D loss: 0.732388, acc.: 37.50%] [G loss: 0.642520]
12 [D loss: 0.340690, acc.: 96.88%] [G loss: 0.470818]
13 [D loss: 0.226770, acc.: 93.75%] [G loss: 0.392800]
14 [D loss: 0.291762, acc.: 100.00%] [G loss: 0.274160]
15 [D loss: 0.850814, acc.: 46.88%] [G loss: 0.325305]
16 [D loss: 0.774884, acc.: 46.88%] [G loss: 0.398667]
17 [D loss: 0.765256, acc.: 50.00%] [G loss: 0.438745]
18 [D loss: 0.460355, acc.: 90.62%] [G loss: 0.494910]
19 [D loss: 0.7

151 [D loss: 0.463902, acc.: 84.38%] [G loss: 0.492100]
152 [D loss: 0.310937, acc.: 100.00%] [G loss: 0.429779]
153 [D loss: 0.789003, acc.: 53.12%] [G loss: 0.387524]
154 [D loss: 0.751383, acc.: 46.88%] [G loss: 0.341615]
155 [D loss: 0.577995, acc.: 65.62%] [G loss: 0.445442]
156 [D loss: 0.461816, acc.: 87.50%] [G loss: 0.630386]
157 [D loss: 0.745546, acc.: 46.88%] [G loss: 0.662896]
158 [D loss: 0.618476, acc.: 71.88%] [G loss: 0.442732]
159 [D loss: 0.440689, acc.: 81.25%] [G loss: 0.505979]
160 [D loss: 0.383972, acc.: 96.88%] [G loss: 0.477057]
161 [D loss: 0.955576, acc.: 37.50%] [G loss: 0.349443]
162 [D loss: 0.715627, acc.: 53.12%] [G loss: 0.404488]
163 [D loss: 0.420031, acc.: 87.50%] [G loss: 0.527239]
164 [D loss: 0.664609, acc.: 59.38%] [G loss: 0.618135]
165 [D loss: 0.820015, acc.: 31.25%] [G loss: 0.543220]
166 [D loss: 0.505501, acc.: 90.62%] [G loss: 0.494614]
167 [D loss: 0.373321, acc.: 93.75%] [G loss: 0.570128]
168 [D loss: 0.563521, acc.: 81.25%] [G loss: 0

299 [D loss: 0.729697, acc.: 56.25%] [G loss: 0.522760]
300 [D loss: 0.653680, acc.: 59.38%] [G loss: 0.515625]
301 [D loss: 0.574621, acc.: 75.00%] [G loss: 0.559991]
302 [D loss: 0.467990, acc.: 87.50%] [G loss: 0.590580]
303 [D loss: 0.506918, acc.: 84.38%] [G loss: 0.687122]
304 [D loss: 0.659304, acc.: 59.38%] [G loss: 0.605138]
305 [D loss: 0.762595, acc.: 43.75%] [G loss: 0.473519]
306 [D loss: 0.560452, acc.: 78.12%] [G loss: 0.463409]
307 [D loss: 0.475062, acc.: 93.75%] [G loss: 0.539960]
308 [D loss: 0.618023, acc.: 59.38%] [G loss: 0.552833]
309 [D loss: 0.842460, acc.: 37.50%] [G loss: 0.438467]
310 [D loss: 0.543950, acc.: 81.25%] [G loss: 0.514176]
311 [D loss: 0.467466, acc.: 78.12%] [G loss: 0.566512]
312 [D loss: 0.543305, acc.: 81.25%] [G loss: 0.525534]
313 [D loss: 0.640340, acc.: 62.50%] [G loss: 0.517448]
314 [D loss: 0.709880, acc.: 46.88%] [G loss: 0.513265]
315 [D loss: 0.598062, acc.: 71.88%] [G loss: 0.512807]
316 [D loss: 0.527684, acc.: 78.12%] [G loss: 0.

447 [D loss: 0.631819, acc.: 71.88%] [G loss: 1.019301]
448 [D loss: 0.694124, acc.: 53.12%] [G loss: 0.836700]
449 [D loss: 0.523440, acc.: 71.88%] [G loss: 0.734457]
450 [D loss: 0.386636, acc.: 90.62%] [G loss: 0.697042]
451 [D loss: 0.491359, acc.: 71.88%] [G loss: 0.542880]
452 [D loss: 0.619612, acc.: 62.50%] [G loss: 0.458482]
453 [D loss: 0.736099, acc.: 50.00%] [G loss: 0.510305]
454 [D loss: 0.694799, acc.: 50.00%] [G loss: 0.608297]
455 [D loss: 0.815841, acc.: 37.50%] [G loss: 0.646294]
456 [D loss: 0.699968, acc.: 46.88%] [G loss: 0.757479]
457 [D loss: 0.575718, acc.: 81.25%] [G loss: 0.877850]
458 [D loss: 0.682925, acc.: 50.00%] [G loss: 0.829193]
459 [D loss: 0.667107, acc.: 62.50%] [G loss: 0.729000]
460 [D loss: 0.577664, acc.: 71.88%] [G loss: 0.731620]
461 [D loss: 0.593993, acc.: 71.88%] [G loss: 0.687785]
462 [D loss: 0.613915, acc.: 65.62%] [G loss: 0.669252]
463 [D loss: 0.790443, acc.: 40.62%] [G loss: 0.683209]
464 [D loss: 0.687225, acc.: 46.88%] [G loss: 0.

594 [D loss: 0.679709, acc.: 59.38%] [G loss: 0.736798]
595 [D loss: 0.495716, acc.: 78.12%] [G loss: 0.695527]
596 [D loss: 0.510998, acc.: 71.88%] [G loss: 0.648045]
597 [D loss: 0.548328, acc.: 62.50%] [G loss: 0.564053]
598 [D loss: 0.621346, acc.: 62.50%] [G loss: 0.549291]
599 [D loss: 0.709900, acc.: 59.38%] [G loss: 0.594592]
600 [D loss: 0.702818, acc.: 56.25%] [G loss: 0.722911]
601 [D loss: 0.590063, acc.: 68.75%] [G loss: 0.822312]
602 [D loss: 0.602547, acc.: 65.62%] [G loss: 0.825549]
603 [D loss: 0.599898, acc.: 59.38%] [G loss: 0.821343]
604 [D loss: 0.659855, acc.: 65.62%] [G loss: 0.751943]
605 [D loss: 0.659065, acc.: 56.25%] [G loss: 0.735348]
606 [D loss: 0.594298, acc.: 68.75%] [G loss: 0.762591]
607 [D loss: 0.552417, acc.: 81.25%] [G loss: 0.693535]
608 [D loss: 0.603211, acc.: 68.75%] [G loss: 0.611278]
609 [D loss: 0.646463, acc.: 62.50%] [G loss: 0.635883]
610 [D loss: 0.735843, acc.: 50.00%] [G loss: 0.680682]
611 [D loss: 0.680310, acc.: 56.25%] [G loss: 0.

741 [D loss: 0.589056, acc.: 65.62%] [G loss: 1.013323]
742 [D loss: 0.754154, acc.: 46.88%] [G loss: 0.923599]
743 [D loss: 0.728587, acc.: 43.75%] [G loss: 0.788822]
744 [D loss: 0.694883, acc.: 53.12%] [G loss: 0.799528]
745 [D loss: 0.557714, acc.: 84.38%] [G loss: 0.782013]
746 [D loss: 0.510963, acc.: 75.00%] [G loss: 0.716947]
747 [D loss: 0.568292, acc.: 75.00%] [G loss: 0.656537]
748 [D loss: 0.521367, acc.: 75.00%] [G loss: 0.654833]
749 [D loss: 0.606602, acc.: 65.62%] [G loss: 0.513126]
750 [D loss: 0.730767, acc.: 62.50%] [G loss: 0.540645]
751 [D loss: 0.635189, acc.: 56.25%] [G loss: 0.635925]
752 [D loss: 0.782348, acc.: 40.62%] [G loss: 0.732633]
753 [D loss: 0.721061, acc.: 53.12%] [G loss: 0.843701]
754 [D loss: 0.667423, acc.: 65.62%] [G loss: 0.841965]
755 [D loss: 0.757009, acc.: 40.62%] [G loss: 1.021166]
756 [D loss: 0.678957, acc.: 40.62%] [G loss: 0.993238]
757 [D loss: 0.671498, acc.: 56.25%] [G loss: 0.908365]
758 [D loss: 0.761490, acc.: 43.75%] [G loss: 0.

889 [D loss: 0.664948, acc.: 59.38%] [G loss: 0.750603]
890 [D loss: 0.699206, acc.: 65.62%] [G loss: 0.745692]
891 [D loss: 0.647292, acc.: 62.50%] [G loss: 0.779088]
892 [D loss: 0.597092, acc.: 65.62%] [G loss: 0.745177]
893 [D loss: 0.645314, acc.: 56.25%] [G loss: 0.831403]
894 [D loss: 0.744260, acc.: 53.12%] [G loss: 0.839472]
895 [D loss: 0.700513, acc.: 59.38%] [G loss: 0.784212]
896 [D loss: 0.658480, acc.: 56.25%] [G loss: 0.817392]
897 [D loss: 0.658681, acc.: 59.38%] [G loss: 0.819688]
898 [D loss: 0.711230, acc.: 50.00%] [G loss: 0.783757]
899 [D loss: 0.658026, acc.: 59.38%] [G loss: 0.725820]
900 [D loss: 0.615886, acc.: 65.62%] [G loss: 0.774219]
901 [D loss: 0.571441, acc.: 84.38%] [G loss: 0.697828]
902 [D loss: 0.628001, acc.: 59.38%] [G loss: 0.667270]
903 [D loss: 0.712440, acc.: 53.12%] [G loss: 0.635409]
904 [D loss: 0.710423, acc.: 46.88%] [G loss: 0.696296]
905 [D loss: 0.656142, acc.: 65.62%] [G loss: 0.705749]
906 [D loss: 0.568334, acc.: 75.00%] [G loss: 0.