In [56]:
%matplotlib inline

from keras.datasets import mnist
import numpy as np
from PIL import Image
import math
import os

import keras.backend as K
import tensorflow as tf

In [57]:
K.set_image_data_format('channels_first')
print(K.image_data_format())

channels_first


GAN 모델 구축

In [58]:
from keras import models, layers, optimizers

커스텀 loss function 정의

In [59]:
def mse_4d(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true), axis=(1,2,3)) # axis=()은 안에 있는 차원순서대로 하라는 뜻

def mse_4d_tf(label, pred):
    return tf.reduce_mean(tf.square(pred-label), axis=(1,2,3))

In [60]:
# https://stackoverflow.com/questions/54126451/what-does-axis-1-2-3-mean-in-k-sum-in-keras-backend
pred = np.arange(start=0, stop=150).reshape((2, 3, 5, 5))
label = np.arange(start=150, stop=300).reshape((2, 3, 5, 5))

# import tensorflow as tf
# with tf.Session() as session:
#     a = mse_4d(pred, label)
#     print(a.shape)
#     print(a.eval())


In [61]:
class GAN(models.Sequential):
    def __init__(self, input_dim=64):
        super().__init__()
        self.input_dim = input_dim        
        self.generator = self.GENERATOR()
        self.discriminator = self.DISCRIMINATOR()
        self.add(self.generator)
        self.discriminator.trainable = False
        self.add(self.discriminator)
        
        self.compile_all()
        
    def compile_all(self):
        d_optim = optimizers.SGD(lr=0.0005, momentum=0.9, nesterov=True)
        g_optim = optimizers.SGD(lr=0.0005, momentum=0.9, nesterov=True)
        self.generator.compile(loss=mse_4d_tf, optimizer='SGD')
        self.compile(loss='binary_crossentropy', optimizer=g_optim)
        self.discriminator.trainable = True
        self.discriminator.compile(loss='binary_crossentropy', optimizer=d_optim)        
        
    def GENERATOR(self):
        input_dim = self.input_dim
        print("input_dim : ", input_dim)
        
        model = models.Sequential()
        model.add(layers.Dense(1024, activation='tanh', input_dim=input_dim))
        model.add(layers.Dense(128 * 7 * 7, activation='tanh'))
        model.add(layers.BatchNormalization())
        model.add(layers.Reshape((128, 7, 7), input_shape=(128 * 7 * 7,))) #이제 그림 모양으로 변경한다. 앞에가 채널번호임.
        model.add(layers.UpSampling2D(size=(2, 2)))
        model.add(layers.Conv2D(64, (5, 5), padding='same', activation='tanh'))
        model.add(layers.UpSampling2D(size=(2, 2)))
        model.add(layers.Conv2D(1, (5, 5), padding='same', activation='tanh'))
        return model
        
    def DISCRIMINATOR(self):
        model = models.Sequential()
        model.add(layers.Conv2D(64, (5, 5), padding='same', activation='tanh', input_shape=(1, 28, 28)))
        model.add(layers.MaxPooling2D(pool_size=(2, 2)))
        model.add(layers.Conv2D(128, (5, 5), activation='tanh'))
        model.add(layers.MaxPooling2D(pool_size=(2, 2)))
        model.add(layers.Flatten())
        model.add(layers.Dense(1024, activation='tanh'))
        model.add(layers.Dense(1, activation='sigmoid'))
        return model
    
    def get_random_image(self, data_count): #data_count는 batch_size라고 해도 좋을 듯하다.
        input_dim = self.input_dim
        return np.random.uniform(-1, 1, (data_count, input_dim))
    
    def train_both(self, x):
        print("############ 1")
        ln = x.shape[0]
        print("############ 2")
        # train discriminator
        input_image = self.get_random_image(ln)
        print("input_image.shape=", input_image.shape)
        w = self.generator.predict(input_image, verbose=0)
        print("w.shape=", w.shape)
        xw = np.concatenate((x, w)) # 진짜 뒤에 가짜를 붙인다.
        label = [1] * ln + [0] * ln # 진짜와 가짜를 붙였으므로 앞에것은 진짜 label(1), 뒤에것은 가짜 label(0)
        #print("label.shape=", np.array(label).shape)
        label = np.array(label).reshape(-1, 1)
        d_loss = self.discriminator.train_on_batch(xw, label)
        print("############ 3")
        # train generator
        input_image = self.get_random_image(ln)
        self.discriminator.trainable = False #컴파일되어있는것도 참조해서 그런건가?
        label = np.array([1] * ln).reshape(-1, 1)
        print("label.shape=", label.shape)
        g_loss = self.train_on_batch(input_image, label)
        self.discriminator.trainable = True
        


GAN 학습

In [62]:
def get_x(X_train, index, BATCH_SIZE):
    return X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]

def load_data(n_train):
    (X_train, y_train), (_, _) = mnist.load_data()
    return X_train[:n_train]


In [63]:
def train(BATCH_SIZE, epochs, output_fold, input_dim, n_train):
#     BATCH_SIZE = args.batch_size
#     epochs = args.epochs
#     output_fold = args.output_fold
#     input_dim = args.input_dim
#     n_train = args.n_train

    X_train = load_data(n_train)
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = X_train.reshape((X_train.shape[0], 1) + X_train.shape[1:])

    gan = GAN(input_dim)
    
    d_loss_ll = []
    g_loss_ll = []
    for epoch in range(epochs):
        print("Epoch is", epoch)
        print("Number of batches", int(X_train.shape[0] / BATCH_SIZE))

        d_loss_l = []
        g_loss_l = []
        
        for index in range(int(X_train.shape[0] / BATCH_SIZE)):
            x = get_x(X_train, index, BATCH_SIZE)
            
            print("x.shape=", x.shape)
            (d_loss, g_loss) = gan.train_both(x)

            d_loss_l.append(d_loss)
            g_loss_l.append(g_loss)

        if epoch % 10 == 0 or epoch == epochs - 1:
            z = gan.get_z(x.shape[0])
            w = gan.generator.predict(z, verbose=0)
            #save_images(w, output_fold, epoch, 0)

        d_loss_ll.append(d_loss_l)
        g_loss_ll.append(g_loss_l)

    gan.generator.save_weights(output_fold + '/' + 'generator', True)
    gan.discriminator.save_weights(output_fold + '/' + 'discriminator', True)

    np.savetxt(output_fold + '/' + 'd_loss', d_loss_ll)
    np.savetxt(output_fold + '/' + 'g_loss', g_loss_ll)


In [64]:
def main():
    #def train(BATCH_SIZE, epochs, output_fold, input_dim, n_train):
    train(16, 1000, 'GAN_OUT', 10, 32)
    
if __name__ == '__main__':
    main()

input_dim :  10
Epoch is 0
Number of batches 2
x.shape= (16, 1, 28, 28)
############ 1
############ 2
input_image.shape= (16, 10)
w.shape= (16, 1, 28, 28)
############ 3
label.shape= (16, 1)


AttributeError: 'GAN' object has no attribute '_output_tensor_cache'