In [2]:
# from __future__ import print_function, division

import os
import sys

import matplotlib.pyplot as plt
import numpy as np

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Convolution2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

class GAN():
    """Convolutional GAN that learns to generate MNIST digit images with Keras.

    The discriminator is compiled and trained on its own. The generator is
    trained only through ``self.combined``, which chains the generator into a
    frozen copy of the discriminator.
    """

    def __init__(self):
        # Input size for MNIST images.
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Dimensionality of the latent noise vector z.
        self.z_dim = 100

        # Discriminator model. It is compiled BEFORE build_combined*() sets
        # discriminator.trainable = False, so its own train_on_batch still
        # updates its weights.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=Adam(0.0002, 0.5),
                                   metrics=['accuracy'])

        # Generator model. It is never trained on its own, so it needs no
        # compile step.
        self.generator = self.build_generator()

        # Generator -> frozen discriminator. build_combined2() is an equivalent
        # functional-API construction. Each compiled model gets its own Adam
        # instance so that optimizer state (e.g. the step counter) is not shared.
        self.combined = self.build_combined1()
        self.combined.compile(loss='binary_crossentropy',
                              optimizer=Adam(0.0002, 0.5))

    def build_generator(self):
        """Build the generator: noise of shape (z_dim,) -> image of self.img_shape.

        Two dense layers expand z into a 7x7x128 feature map. The map is then
        upsampled twice (7 -> 14 -> 28) with 5x5 convolutions. The final tanh
        keeps outputs in [-1, 1], matching the scaling applied in train().
        """
        noise_shape = (self.z_dim,)
        model = Sequential()
        model.add(Dense(1024, input_shape=noise_shape))
        model.add(BatchNormalization())
        model.add(Activation('relu'))
        model.add(Dense(128 * 7 * 7))
        model.add(BatchNormalization())
        model.add(Activation('relu'))
        model.add(Reshape((7, 7, 128)))
        model.add(UpSampling2D((2, 2)))  # 7x7 -> 14x14
        model.add(Convolution2D(64, (5, 5), padding='same'))
        model.add(BatchNormalization())
        model.add(Activation('relu'))
        model.add(UpSampling2D((2, 2)))  # 14x14 -> 28x28
        model.add(Convolution2D(self.channels, (5, 5), padding='same'))
        model.add(Activation('tanh'))
        model.summary()
        return model

    def build_discriminator(self):
        """Build the discriminator: image -> sigmoid probability that it is real."""
        model = Sequential()
        # Strided 'same' convolution: 28x28 -> 14x14.
        model.add(Convolution2D(64, (5, 5), strides=(2, 2), padding='same',
                                input_shape=self.img_shape))
        model.add(LeakyReLU(0.2))
        # Strided convolution with the default 'valid' padding: 14x14 -> 5x5.
        model.add(Convolution2D(128, (5, 5), strides=(2, 2)))
        model.add(LeakyReLU(0.2))
        model.add(Flatten())
        model.add(Dense(256))
        model.add(LeakyReLU(0.2))
        model.add(Dropout(0.5))
        model.add(Dense(1))
        model.add(Activation('sigmoid'))
        return model

    def build_combined1(self):
        """Chain generator -> discriminator with Sequential, freezing the discriminator."""
        # Freeze before the combined model is compiled, so that training the
        # combined model only updates the generator's weights.
        self.discriminator.trainable = False
        model = Sequential([self.generator, self.discriminator])
        return model

    def build_combined2(self):
        """Functional-API equivalent of build_combined1(): z -> D(G(z))."""
        z = Input(shape=(self.z_dim,))
        img = self.generator(z)
        # Freeze the discriminator inside the combined model (see build_combined1).
        self.discriminator.trainable = False
        valid = self.discriminator(img)
        model = Model(z, valid)
        model.summary()
        return model

    def train(self, epochs, batch_size=128, save_interval=50):
        """Train the GAN on the MNIST training images.

        Each "epoch" here is ONE training iteration on a randomly sampled
        batch, not a full pass over the dataset.

        Args:
            epochs: number of training iterations to run.
            batch_size: generator batch size. The discriminator sees
                batch_size // 2 real and batch_size // 2 generated images.
            save_interval: save a grid of samples every this many iterations,
                starting at iteration 0. A value of 0 disables saving.

        Raises:
            ValueError: if batch_size < 2, which would make the half batch empty.
        """
        half_batch = batch_size // 2
        if half_batch < 1:
            raise ValueError("batch_size must be at least 2, got %r" % (batch_size,))

        # Load MNIST. Only the training images are needed.
        (X_train, _), (_, _) = mnist.load_data()

        # Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh
        # output, then add the channel axis.
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

        # Labels are constant across iterations, so build them once. They are
        # shaped (n, 1) to match the discriminator's Dense(1) output.
        real_y = np.ones((half_batch, 1))
        fake_y = np.zeros((half_batch, 1))
        # The generator is trained to make the discriminator output "real" (1).
        valid_y = np.ones((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train the discriminator
            # ---------------------

            # Half a batch of generated images.
            noise = np.random.normal(0, 1, (half_batch, self.z_dim))
            gen_imgs = self.generator.predict(noise)

            # Half a batch of real images, sampled with replacement.
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            imgs = X_train[idx]

            # Real and fake images are fed in separate updates. The two
            # [loss, accuracy] results are then averaged.
            d_loss_real = self.discriminator.train_on_batch(imgs, real_y)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake_y)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train the generator (through the frozen discriminator)
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.z_dim))
            g_loss = self.combined.train_on_batch(noise, valid_y)

            # Report progress.
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # Periodically save sample images.
            if save_interval and epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch):
        """Save a 5x5 grid of generated images to images/gcgan/mnist_<epoch>.png.

        The output directory is created if it does not exist.
        """
        # Rows and columns of the sample grid.
        r, c = 5, 5

        noise = np.random.normal(0, 1, (r * c, self.z_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale generator output from [-1, 1] to [0, 1] for display.
        gen_imgs = 0.5 * gen_imgs + 0.5

        path = "images/gcgan/mnist_%d.png" % epoch
        os.makedirs(os.path.dirname(path), exist_ok=True)

        fig, axs = plt.subplots(r, c)
        # axs.flat walks the grid row by row, which matches the original
        # i/j loop order.
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            ax.axis('off')
        fig.savefig(path)
        # Close this specific figure to avoid accumulating open figures
        # across many saves.
        plt.close(fig)


if __name__ == '__main__':
    # Script entry point: 1000 single-batch iterations with batch size 32,
    # writing a sample grid every 100 iterations.
    GAN().train(batch_size=32, epochs=1000, save_interval=100)

W0902 17:57:37.607595 4692112832 deprecation_wrapper.py:119] From /Users/sungwoo/anaconda3/envs/keras/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:541: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0902 17:57:37.609777 4692112832 deprecation_wrapper.py:119] From /Users/sungwoo/anaconda3/envs/keras/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:4432: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

W0902 17:57:37.664798 4692112832 deprecation_wrapper.py:119] From /Users/sungwoo/anaconda3/envs/keras/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:148: The name tf.placeholder_with_default is deprecated. Please use tf.compat.v1.placeholder_with_default instead.

W0902 17:57:37.672465 4692112832 deprecation.py:506] From /Users/sungwoo/anaconda3/envs/keras/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:3733: calling dropout (from tensorflow.python.ops.n

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_3 (Dense)              (None, 1024)              103424    
_________________________________________________________________
batch_normalization_1 (Batch (None, 1024)              4096      
_________________________________________________________________
activation_2 (Activation)    (None, 1024)              0         
_________________________________________________________________
dense_4 (Dense)              (None, 6272)              6428800   
_________________________________________________________________
batch_normalization_2 (Batch (None, 6272)              25088     
_________________________________________________________________
activation_3 (Activation)    (None, 6272)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 128)        

  'Discrepancy between trainable weights and collected trainable'
  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.670552, acc.: 46.88%] [G loss: 0.537502]


  'Discrepancy between trainable weights and collected trainable'


1 [D loss: 0.926850, acc.: 50.00%] [G loss: 0.502425]
2 [D loss: 0.712755, acc.: 50.00%] [G loss: 0.598763]
3 [D loss: 0.622661, acc.: 53.12%] [G loss: 0.675114]
4 [D loss: 0.662687, acc.: 50.00%] [G loss: 0.597554]
5 [D loss: 0.577654, acc.: 59.38%] [G loss: 0.454302]
6 [D loss: 0.317154, acc.: 100.00%] [G loss: 0.370394]
7 [D loss: 0.176194, acc.: 100.00%] [G loss: 0.202387]
8 [D loss: 0.260024, acc.: 87.50%] [G loss: 0.104367]
9 [D loss: 1.389304, acc.: 50.00%] [G loss: 0.236735]
10 [D loss: 0.704517, acc.: 50.00%] [G loss: 0.616106]
11 [D loss: 0.719627, acc.: 34.38%] [G loss: 0.649072]
12 [D loss: 0.510647, acc.: 87.50%] [G loss: 0.520975]
13 [D loss: 0.267875, acc.: 100.00%] [G loss: 0.408986]
14 [D loss: 0.326711, acc.: 100.00%] [G loss: 0.296245]
15 [D loss: 0.812876, acc.: 46.88%] [G loss: 0.331643]
16 [D loss: 0.711654, acc.: 50.00%] [G loss: 0.396956]
17 [D loss: 0.690297, acc.: 46.88%] [G loss: 0.456904]
18 [D loss: 0.492561, acc.: 87.50%] [G loss: 0.447109]
19 [D loss: 0.6

150 [D loss: 0.601584, acc.: 71.88%] [G loss: 0.728379]
151 [D loss: 0.910456, acc.: 31.25%] [G loss: 0.535139]
152 [D loss: 0.532625, acc.: 84.38%] [G loss: 0.494297]
153 [D loss: 0.396972, acc.: 87.50%] [G loss: 0.535485]
154 [D loss: 0.465591, acc.: 87.50%] [G loss: 0.598874]
155 [D loss: 0.769406, acc.: 50.00%] [G loss: 0.440861]
156 [D loss: 0.683315, acc.: 62.50%] [G loss: 0.434806]
157 [D loss: 0.408061, acc.: 93.75%] [G loss: 0.578193]
158 [D loss: 0.411713, acc.: 87.50%] [G loss: 0.642163]
159 [D loss: 0.867611, acc.: 34.38%] [G loss: 0.499100]
160 [D loss: 0.692575, acc.: 62.50%] [G loss: 0.439804]
161 [D loss: 0.498850, acc.: 81.25%] [G loss: 0.456170]
162 [D loss: 0.508497, acc.: 81.25%] [G loss: 0.493709]
163 [D loss: 0.859021, acc.: 34.38%] [G loss: 0.344728]
164 [D loss: 0.758102, acc.: 46.88%] [G loss: 0.391821]
165 [D loss: 0.581569, acc.: 78.12%] [G loss: 0.508416]
166 [D loss: 0.611311, acc.: 68.75%] [G loss: 0.553252]
167 [D loss: 0.702758, acc.: 50.00%] [G loss: 0.

297 [D loss: 0.720559, acc.: 59.38%] [G loss: 0.634566]
298 [D loss: 0.585133, acc.: 68.75%] [G loss: 0.606935]
299 [D loss: 0.497159, acc.: 87.50%] [G loss: 0.636418]
300 [D loss: 0.599915, acc.: 71.88%] [G loss: 0.649275]
301 [D loss: 0.731195, acc.: 50.00%] [G loss: 0.571555]
302 [D loss: 0.573800, acc.: 75.00%] [G loss: 0.571850]
303 [D loss: 0.502262, acc.: 75.00%] [G loss: 0.629014]
304 [D loss: 0.475763, acc.: 84.38%] [G loss: 0.617083]
305 [D loss: 0.608954, acc.: 65.62%] [G loss: 0.626717]
306 [D loss: 0.731175, acc.: 56.25%] [G loss: 0.511843]
307 [D loss: 0.599962, acc.: 68.75%] [G loss: 0.573445]
308 [D loss: 0.562996, acc.: 59.38%] [G loss: 0.468838]
309 [D loss: 0.679782, acc.: 53.12%] [G loss: 0.535441]
310 [D loss: 0.722687, acc.: 50.00%] [G loss: 0.453437]
311 [D loss: 0.681029, acc.: 56.25%] [G loss: 0.456189]
312 [D loss: 0.471370, acc.: 84.38%] [G loss: 0.546408]
313 [D loss: 0.498520, acc.: 84.38%] [G loss: 0.498911]
314 [D loss: 0.671498, acc.: 59.38%] [G loss: 0.

444 [D loss: 0.450114, acc.: 81.25%] [G loss: 0.454380]
445 [D loss: 0.392234, acc.: 87.50%] [G loss: 0.355939]
446 [D loss: 0.537588, acc.: 75.00%] [G loss: 0.306921]
447 [D loss: 0.538266, acc.: 68.75%] [G loss: 0.359956]
448 [D loss: 0.587381, acc.: 65.62%] [G loss: 0.376507]
449 [D loss: 0.671446, acc.: 50.00%] [G loss: 0.428036]
450 [D loss: 0.805127, acc.: 53.12%] [G loss: 0.562454]
451 [D loss: 0.686995, acc.: 65.62%] [G loss: 0.706618]
452 [D loss: 0.613240, acc.: 68.75%] [G loss: 0.821261]
453 [D loss: 0.633106, acc.: 71.88%] [G loss: 0.811379]
454 [D loss: 0.708692, acc.: 53.12%] [G loss: 0.732788]
455 [D loss: 0.554131, acc.: 71.88%] [G loss: 0.544124]
456 [D loss: 0.403427, acc.: 90.62%] [G loss: 0.500440]
457 [D loss: 0.490519, acc.: 75.00%] [G loss: 0.489628]
458 [D loss: 0.572076, acc.: 68.75%] [G loss: 0.464934]
459 [D loss: 0.676254, acc.: 56.25%] [G loss: 0.522751]
460 [D loss: 0.654096, acc.: 59.38%] [G loss: 0.615965]
461 [D loss: 0.653960, acc.: 59.38%] [G loss: 0.

591 [D loss: 0.776995, acc.: 46.88%] [G loss: 0.596335]
592 [D loss: 0.539344, acc.: 71.88%] [G loss: 0.804362]
593 [D loss: 0.676422, acc.: 62.50%] [G loss: 0.773072]
594 [D loss: 0.644549, acc.: 65.62%] [G loss: 0.793167]
595 [D loss: 0.587966, acc.: 71.88%] [G loss: 0.640528]
596 [D loss: 0.528049, acc.: 71.88%] [G loss: 0.646976]
597 [D loss: 0.453025, acc.: 81.25%] [G loss: 0.534260]
598 [D loss: 0.381834, acc.: 90.62%] [G loss: 0.564864]
599 [D loss: 0.533173, acc.: 68.75%] [G loss: 0.451503]
600 [D loss: 0.678373, acc.: 62.50%] [G loss: 0.489263]
601 [D loss: 0.740703, acc.: 37.50%] [G loss: 0.577756]
602 [D loss: 0.821635, acc.: 43.75%] [G loss: 0.680765]
603 [D loss: 0.683611, acc.: 59.38%] [G loss: 0.895083]
604 [D loss: 0.664655, acc.: 65.62%] [G loss: 0.964255]
605 [D loss: 0.655577, acc.: 65.62%] [G loss: 0.871146]
606 [D loss: 0.618092, acc.: 65.62%] [G loss: 0.842987]
607 [D loss: 0.501506, acc.: 71.88%] [G loss: 0.775239]
608 [D loss: 0.331259, acc.: 90.62%] [G loss: 0.

738 [D loss: 0.587827, acc.: 78.12%] [G loss: 0.757252]
739 [D loss: 0.510378, acc.: 81.25%] [G loss: 0.778481]
740 [D loss: 0.604578, acc.: 71.88%] [G loss: 0.742220]
741 [D loss: 0.564570, acc.: 75.00%] [G loss: 0.773871]
742 [D loss: 0.524164, acc.: 84.38%] [G loss: 0.797483]
743 [D loss: 0.626596, acc.: 65.62%] [G loss: 0.839729]
744 [D loss: 0.639117, acc.: 59.38%] [G loss: 0.789906]
745 [D loss: 0.534216, acc.: 75.00%] [G loss: 0.812982]
746 [D loss: 0.563732, acc.: 62.50%] [G loss: 0.797398]
747 [D loss: 0.664294, acc.: 59.38%] [G loss: 0.771454]
748 [D loss: 0.595396, acc.: 71.88%] [G loss: 0.690536]
749 [D loss: 0.617203, acc.: 68.75%] [G loss: 0.703737]
750 [D loss: 0.695555, acc.: 65.62%] [G loss: 0.747285]
751 [D loss: 0.746182, acc.: 50.00%] [G loss: 0.726678]
752 [D loss: 0.584967, acc.: 71.88%] [G loss: 0.761830]
753 [D loss: 0.605675, acc.: 71.88%] [G loss: 0.912273]
754 [D loss: 0.611235, acc.: 65.62%] [G loss: 0.902084]
755 [D loss: 0.573542, acc.: 78.12%] [G loss: 0.

885 [D loss: 0.637986, acc.: 59.38%] [G loss: 0.835599]
886 [D loss: 0.580726, acc.: 78.12%] [G loss: 0.966500]
887 [D loss: 0.654777, acc.: 53.12%] [G loss: 1.036336]
888 [D loss: 0.634794, acc.: 68.75%] [G loss: 0.972688]
889 [D loss: 0.669850, acc.: 56.25%] [G loss: 0.841989]
890 [D loss: 0.625755, acc.: 68.75%] [G loss: 0.819389]
891 [D loss: 0.577399, acc.: 71.88%] [G loss: 0.780542]
892 [D loss: 0.598433, acc.: 75.00%] [G loss: 0.777673]
893 [D loss: 0.583108, acc.: 71.88%] [G loss: 0.876898]
894 [D loss: 0.678227, acc.: 50.00%] [G loss: 0.827656]
895 [D loss: 0.650488, acc.: 62.50%] [G loss: 0.873708]
896 [D loss: 0.607754, acc.: 65.62%] [G loss: 0.932703]
897 [D loss: 0.726486, acc.: 59.38%] [G loss: 0.987862]
898 [D loss: 0.767125, acc.: 56.25%] [G loss: 0.911525]
899 [D loss: 0.645815, acc.: 56.25%] [G loss: 0.984620]
900 [D loss: 0.751375, acc.: 46.88%] [G loss: 0.899482]
901 [D loss: 0.754176, acc.: 53.12%] [G loss: 0.935513]
902 [D loss: 0.567285, acc.: 68.75%] [G loss: 0.