<a href="https://colab.research.google.com/github/raspberryscorn/2023/blob/main/GAN_2023.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GAN



## import modules

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from matplotlib import pyplot as plt
from tensorflow import keras
import tensorflow as tf
from PIL import Image
import numpy as np
import imageio
import glob
from IPython import display
from matplotlib import pyplot as plt
import time

## dataset

In [None]:
datasets = {
    'mnist' : keras.datasets.mnist,
    'fashion_mnist' : keras.datasets.fashion_mnist,
    'cifar10' : keras.datasets.cifar10
}

In [None]:
class DataReader():
    def __init__(self, dataset):
        self.dataset = dataset
        data_len = 60000 if self.dataset == 'mnist' else 50000
        (self.train_X, _), (_, _) = datasets[dataset].load_data()
        self.train_X = self.preprocess(self.train_X)
        self.train_dataset = tf.data.Dataset.from_tensor_slices(self.train_X).shuffle(data_len).batch(256)

    def preprocess(self, images):
        if self.dataset != 'cifar10':
            images = images.reshape(images.shape[0], 28, 28, 1).astype('float32')
        images = images / 127.5 - 1
        return images

    def show_processed_images(self):
        plt.figure(figsize=(10, 10))
        for i in range(25):
            plt.subplot(5, 5, i + 1)
            plt.xticks([])
            plt.yticks([])
            plt.grid(False)
            plt.imshow(self.train_X[i])
        plt.show()



## gan model

In [None]:
class GAN:
    def __init__(self, dataset, noise_dim=100):
        self.in_channels = 8 if dataset=='cifar10' else 7
        self.out_channels = 3 if dataset=='cifar10' else 1
        self.img_shape = (32, 32, 3) if dataset=='cifar10' else (28, 28, 1)

    def make_generator(self):
        model = keras.Sequential([
        keras.layers.Dense(self.in_channels*self.in_channels*128, use_bias=False, input_shape=(noise_dim,)),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),

        keras.layers.Reshape((self.in_channels, self.in_channels, 128)),

        keras.layers.Conv2DTranspose(32, (5, 5), strides=(1, 1), padding='same', use_bias=False),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),

        keras.layers.Conv2DTranspose(8, (5, 5), strides=(2, 2), padding='same', use_bias=False),
        keras.layers.BatchNormalization(),
        keras.layers.LeakyReLU(),

        keras.layers.Conv2DTranspose(self.out_channels, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
        ])

        return model

    def make_discriminator(self):
        model = keras.Sequential([
        keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=self.img_shape),
        keras.layers.LeakyReLU(),
        keras.layers.Dropout(0.3),

        keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        keras.layers.LeakyReLU(),
        keras.layers.Dropout(0.3),

        keras.layers.Flatten(),
        keras.layers.Dense(1)
        ])

        return model

## train functions

In [None]:
def loss_D(real_output, fake_output):
    real_loss = keras.losses.BinaryCrossentropy(from_logits=True)(tf.ones_like(real_output), real_output)
    fake_loss = keras.losses.BinaryCrossentropy(from_logits=True)(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss


def loss_G(fake_output):
    return keras.losses.BinaryCrossentropy(from_logits=True)(tf.ones_like(fake_output), fake_output)


generator_optimizer = keras.optimizers.Adam(1e-4)
discriminator_optimizer = keras.optimizers.Adam(1e-4)

noise_dim = 100
seed = tf.random.normal([36, noise_dim])


# `tf.function`이 어떻게 사용되는지 주목해 주세요.
# 이 데코레이터는 함수를 "컴파일"합니다.
@tf.function
def train_step(generator, discriminator, images):
    noise = tf.random.normal([256, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = loss_G(fake_output)
        disc_loss = loss_D(real_output, fake_output)

    gradient_G = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradient_D = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradient_G, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradient_D, discriminator.trainable_variables))

    return gen_loss, disc_loss


def generate_and_save_images(model, epoch, test_input, dataset):
    predictions = model(test_input, training=False)
    fig = plt.figure(figsize=(6, 6))

    for i in range(predictions.shape[0]):
        plt.subplot(6, 6, i+1)
        if dataset != 'cifar10':
            plt.imshow(((predictions[i, :, :, 0]) + 1)/2)
        else:
            plt.imshow(((predictions[i]) + 1)/2)
        plt.axis('off')

    plt.savefig(f'results/{dataset}/image_at_epoch_{epoch:04d}.png')
    plt.close(fig)


def train(generator, discriminator, dataset, epochs, dataset_name):
    os.makedirs(f"results/{dataset_name}", exist_ok=True)

    for epoch in range(epochs):
        start = time.time()
        for image_batch in dataset:
            gen_loss, disc_loss = train_step(generator, discriminator, image_batch)
        duration = time.time() - start
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, seed, dataset_name)
        print("Epoch " + str(epoch + 1) + "   Generator Loss : " + str(float(gen_loss))[:7]
                        + "   Discriminator Loss : " + str(float(disc_loss))[:7]
                        + "   Time : " + str(duration)[:5] + " seconds")

    # 마지막 에포크가 끝난 후 생성합니다.
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed, dataset_name)


def gif_generation(dataset):
    anim_file = f'results/{dataset}.gif'

    with imageio.get_writer(anim_file, mode='I') as writer:
        filenames = glob.glob(f'results/{dataset}/image*.png')
        filenames = sorted(filenames)
        last = -1
        for i, filename in enumerate(filenames):
            frame = 2 * (i ** 0.5)
            if round(frame) > round(last):
                last = frame
            else:
                continue
            image = imageio.imread(filename)
            writer.append_data(image)
        image = imageio.imread(filename)
        writer.append_data(image)

    import IPython
    if IPython.version_info > (6, 2, 0, ''):
        display.Image(filename=anim_file)


## main

In [None]:
EPOCHS = 10  # 예제 기본값은 200입니다.

dataset = 'mnist'
# dataset = 'fashion_mnist'
# dataset = 'cifar10'


# 데이터를 읽어옵니다.
dr = DataReader(dataset)

# GAN을 불러옵니다.
gan = GAN(dataset, noise_dim)

# Generator
generator = gan.make_generator()
# Discriminator
discriminator = gan.make_discriminator()

# 인공신경망을 학습시킵니다.
print("\n\n************ TRAINING START ************ ")
train(generator, discriminator, dr.train_dataset, EPOCHS, dataset_name=dataset)

# GIF 애니메이션을 저장합니다.
gif_generation(dataset)

  image = imageio.imread(filename)
  image = imageio.imread(filename)
