# In [None]:
import os
import glob
import time

import PIL
import imageio
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from IPython import display
import matplotlib.pyplot as plt
%matplotlib inline

# Module-level handle to the Keras CIFAR-10 dataset module.
# NOTE(review): not referenced by the visible code (load_data calls
# tf.keras.datasets.cifar10 directly); candidate for removal.
cifar10 = tf.keras.datasets.cifar10

def load_data(label=None):
    (train_x, train_y), (test_x, test_y) = tf.keras.datasets.cifar10.load_data()
    if label:
        df = pd.DataFrame(list(zip(train_x, train_y)), columns=['image', 'label']) 
        df = df[df['label']==label]
        train_x = np.array([i for i in list(df['image'])])
        df = pd.DataFrame(list(zip(test_x, test_y)), columns =['image', 'label']) 
        df = df[df['label']==label]
        test_x = np.array([i for i in list(df['image'])])
    return train_x, test_x

# CIFAR-10 consists of 50,000 training images and 10,000 test images.
train_x, test_x = load_data()
# A bare `train_x.shape` mid-cell is never displayed, so print it explicitly.
print(train_x.shape)

print("step1 데이터 가져오기 완료!")

# Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
# Cast to float32 *before* the arithmetic: uint8 - 127.5 would otherwise
# materialize a float64 array (twice the memory) whose dtype also differs
# from the models' float32 weights.
train_x = (train_x.astype('float32') - 127.5) / 127.5
BUFFER_SIZE = 50000
BATCH_SIZE = 256
train_dataset = tf.data.Dataset.from_tensor_slices(train_x).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
print("step2 데이터 전처리 완료!")

def make_generator_model():
    """Build the DCGAN generator mapping a 100-d noise vector to a 32x32x3 image.

    A bias-free Dense layer produces an 8x8x256 feature map. Three 5x5
    transposed convolutions (strides 1, 2, 2) then upsample it
    8x8 -> 8x8 -> 16x16 -> 32x32. The final layer has 3 channels and a tanh
    activation.
    """
    def upsample(filters, stride, **extra):
        # Shared configuration for every transposed convolution in the stack.
        return layers.Conv2DTranspose(filters, kernel_size=(5, 5), strides=(stride, stride),
                                      padding='same', use_bias=False, **extra)

    stack = [
        layers.Dense(8 * 8 * 256, use_bias=False, input_shape=(100,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape((8, 8, 256)),
        upsample(128, 1),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        upsample(64, 2),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        upsample(3, 2, activation='tanh'),
    ]
    return tf.keras.Sequential(stack)

# Build the generator once at module level; the training step and sample
# rendering below refer to this global.
generator = make_generator_model()
generator.summary()
print("step3 생성자 모델 구현 완료!")

def make_discriminator_model():
    """Build the DCGAN discriminator mapping a 32x32x3 image to a single logit.

    Two 5x5 stride-2 convolutions (64 then 128 filters) are each followed by
    LeakyReLU and 30% dropout. A flatten and a single linear unit come last.
    There is no output activation, so the result is an unnormalized logit.
    """
    model = tf.keras.Sequential()
    for index, filters in enumerate((64, 128)):
        # Only the first convolution declares the input image shape.
        shape_kwargs = {'input_shape': [32, 32, 3]} if index == 0 else {}
        model.add(layers.Conv2D(filters, (5, 5), strides=(2, 2), padding='same', **shape_kwargs))
        model.add(layers.LeakyReLU())
        model.add(layers.Dropout(0.3))
    model.add(layers.Flatten())
    model.add(layers.Dense(1))
    return model

# Build the discriminator once at module level; the training step below
# refers to this global.
discriminator = make_discriminator_model()
discriminator.summary()
print("step4 판별자 모델 구현 완료!")

# Shared binary cross-entropy loss. from_logits=True because the
# discriminator's final Dense(1) layer has no activation, so its outputs are
# raw logits rather than probabilities.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def generator_loss(fake_output):
    """Cross-entropy of the discriminator's logits on fakes against all-"real" targets.

    The generator is rewarded when the discriminator labels its samples as
    real, so every target is 1.
    """
    real_targets = tf.ones_like(fake_output)
    return cross_entropy(real_targets, fake_output)
    
def discriminator_loss(real_output, fake_output):
    """Sum of the discriminator's cross-entropy on real (target 1) and fake (target 0) logits."""
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

def discriminator_accuracy(real_output, fake_output, threshold=0.5):
    """Fraction of real and of fake samples the discriminator classifies correctly.

    The discriminator ends in a linear Dense(1) layer, and its loss uses
    ``from_logits=True``, so its outputs are logits. They are converted to
    probabilities with a sigmoid before being compared with ``threshold``.
    Comparing raw logits against 0.5 would amount to a ~0.62 probability
    cut-off.

    Args:
        real_output: Discriminator logits for real images.
        fake_output: Discriminator logits for generated images.
        threshold: Probability at or above which a sample counts as "real".

    Returns:
        ``(real_accuracy, fake_accuracy)`` as float32 scalar tensors.
    """
    real_prob = tf.sigmoid(real_output)
    fake_prob = tf.sigmoid(fake_output)
    real_accuracy = tf.reduce_mean(tf.cast(tf.math.greater_equal(real_prob, threshold), tf.float32))
    fake_accuracy = tf.reduce_mean(tf.cast(tf.math.less(fake_prob, threshold), tf.float32))
    return real_accuracy, fake_accuracy

# Separate Adam optimizers (learning rate 1e-4) for the generator and the discriminator.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

# NOTE(review): re-assigns the same value set next to the dataset pipeline;
# keep the two assignments in sync or drop this one.
BATCH_SIZE = 256
# Length of the noise vector fed to the generator (its input_shape is (100,)).
noise_dim = 100
# Number of fixed samples rendered in each training preview.
num_examples_to_generate = 16

# Fixed noise reused for every preview so samples are comparable over training.
seed = tf.random.normal([num_examples_to_generate, noise_dim])
# NOTE(review): a bare expression that is not the last line of a notebook cell
# is not displayed, so this line has no visible effect.
seed.shape

print("step5 손실함수 최적화 함수 구현 완료!")


@tf.function
def train_step(images):
    """Run one adversarial update of the generator and the discriminator.

    Args:
        images: A batch of real training images (presumably scaled to
            [-1, 1] by the preprocessing step).

    Returns:
        ``(gen_loss, disc_loss, real_accuracy, fake_accuracy)`` scalar tensors.
    """
    # 1. One noise vector per real image. The actual batch size (rather than
    #    BATCH_SIZE) keeps real/fake counts balanced on the final, smaller
    #    batch of each epoch.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, noise_dim])

    # 2. Record the forward pass on two tapes, one per network.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Accuracy is only a metric, so it is computed outside the tapes.
    real_accuracy, fake_accuracy = discriminator_accuracy(real_output, fake_output)

    # 3. Compute gradients after leaving the tape context (so the gradient ops
    #    are not themselves recorded), then apply each to its own network.
    gen_gradient = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_gradient = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gen_gradient, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_gradient, discriminator.trainable_variables))

    return gen_loss, disc_loss, real_accuracy, fake_accuracy
    
def generate_and_save_images(model, epoch, it, sample_seeds):
    """Render generator samples on a grid, save the figure as a PNG, and show it.

    Args:
        model: Generator, called on ``sample_seeds`` with ``training=False``.
        epoch: Epoch number embedded in the file name.
        it: Iteration number embedded in the file name.
        sample_seeds: Batch of noise vectors to generate from.
    """
    predictions = model(sample_seeds, training=False)

    # Near-square grid: 16 samples still give a 4x4 layout, and larger batches
    # no longer overflow a hard-coded 4x4 grid.
    count = int(predictions.shape[0])
    cols = max(1, int(np.ceil(np.sqrt(count))))
    rows = max(1, int(np.ceil(count / cols)))

    # Start a fresh figure instead of drawing over whatever figure is current.
    plt.figure()
    for i in range(count):
        plt.subplot(rows, cols, i + 1)
        # The generator ends in tanh, so pixel values lie in [-1, 1]. Map them
        # to [0, 1] for imshow, and show all three colour channels rather than
        # channel 0 alone in grayscale.
        plt.imshow((np.asarray(predictions[i]) + 1.0) / 2.0)
        plt.axis('off')

    save_dir = '{}/aiffel/dcgan_newimage/cifar10/generated_samples'.format(os.getenv('HOME'))
    os.makedirs(save_dir, exist_ok=True)  # savefig does not create directories
    plt.savefig('{}/sample_epoch_{:04d}_iter_{:03d}.png'.format(save_dir, epoch, it))
    plt.show()
    
def draw_train_history(history, epoch):
    """Plot loss and discriminator-accuracy curves, save them as a PNG, and show them.

    Args:
        history: Dict with 'gen_loss', 'disc_loss', 'fake_accuracy' and
            'real_accuracy' lists holding one scalar (float or scalar tensor)
            per training batch.
        epoch: Number embedded in the saved file name.
    """
    def as_floats(values):
        # Scalar tensors and plain numbers both convert cleanly via float().
        return [float(v) for v in values]

    # Start a fresh figure instead of drawing over whatever figure is current.
    plt.figure()

    # 1. Top panel: generator and discriminator loss.
    plt.subplot(211)
    plt.plot(as_floats(history['gen_loss']))
    plt.plot(as_floats(history['disc_loss']))
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('batch iters')
    plt.legend(['gen_loss','disc_loss'],loc='upper left')

    # 2. Bottom panel: discriminator accuracy on fake and real samples.
    plt.subplot(212)
    plt.plot(as_floats(history['fake_accuracy']))
    plt.plot(as_floats(history['real_accuracy']))
    plt.title('discriminator accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('batch iters')
    plt.legend(['fake_accuracy','real_accuracy'],loc='upper left')

    # Keep the lower panel's title from overlapping the upper panel's x-label.
    plt.tight_layout()

    # 3. Save (creating the directory if needed) and display.
    save_dir = '{}/aiffel/dcgan_newimage/cifar10/training_history'.format(os.getenv('HOME'))
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig('{}/train_history_{:04d}.png'.format(save_dir, epoch))
    plt.show()
    
# 1. Directory and file prefix under which training checkpoints are written.
checkpoint_dir = os.getenv('HOME')+'/aiffel/dcgan_newimage/cifar10/training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# 2. Track both optimizers and both models so they are saved/restored together.
checkpoints = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator = generator,
                                 discriminator=discriminator)

print("step6 훈련과정 상세 기능 구현 완료!")

def train(dataset, epochs, save_every):
    """Train the GAN, with periodic sample previews, checkpoints and history plots.

    Args:
        dataset: Iterable of real-image batches.
        epochs: Number of passes over ``dataset``.
        save_every: Save a checkpoint every ``save_every`` epochs.

    Returns:
        The per-batch history dict ('gen_loss', 'disc_loss', 'real_accuracy',
        'fake_accuracy'). Returning it is optional for callers; the original
        returned ``None``.
    """
    start = time.time()
    history = {'gen_loss':[], 'disc_loss':[], 'real_accuracy':[], 'fake_accuracy':[]}

    for epoch in range(epochs):
        epoch_start = time.time()
        it = -1  # stays -1 if the dataset yields no batches this epoch
        for it, image_batch in enumerate(dataset):
            gen_loss, disc_loss, real_accuracy, fake_accuracy = train_step(image_batch)
            history['gen_loss'].append(gen_loss)
            history['disc_loss'].append(disc_loss)
            history['real_accuracy'].append(real_accuracy)
            history['fake_accuracy'].append(fake_accuracy)

            # Mid-epoch preview every 50 batches.
            if it % 50 == 0:
                display.clear_output(wait=True)
                generate_and_save_images(generator, epoch+1, it+1, seed)
                print('Epoch {} | iter {}'.format(epoch+1, it+1))
                print('Time for epoch {} : {} sec'.format(epoch+1, int(time.time()-epoch_start)))

        if (epoch + 1) % save_every == 0:
            checkpoints.save(file_prefix=checkpoint_prefix)

        # End-of-epoch preview. Use the current epoch number (not the total
        # `epochs`) so each epoch writes its own file instead of overwriting
        # the same one. `it + 1` matches the 1-based iteration numbering used
        # for the mid-epoch previews.
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, it + 1, seed)
        print('Time for training : {} sec'.format(int(time.time()-start)))

        draw_train_history(history, epoch)

    return history
        
# Training configuration: checkpoint every 5 epochs, 50 epochs in total.
save_every = 5
EPOCHS = 50
train(train_dataset, EPOCHS, save_every)