<a href="https://colab.research.google.com/github/trevormoon/GAN_study/blob/main/GAN_Example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import pandas as pd
import numpy as np
import keras
import keras.backend as K
from keras.layers import Conv2D, Activation, Dropout, Flatten, Dense, BatchNormalization, Reshape, UpSampling2D, Input
from keras.models import Model
from keras.optimizers import RMSprop
from keras.preprocessing.image import array_to_img

import warnings ; warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Load the bitmap array from disk and give every sample an explicit
# (28, 28, 1) layout so it matches the discriminator's Input shape.
Xtrain = np.load('full-numpy_bitmap-camel.npy')
n_samples = Xtrain.shape[0]
# Dividing by 255 presumably maps 8-bit pixel values into [0, 1], the range of
# the generator's sigmoid output — confirm against the file's dtype.
Xtrain = Xtrain.reshape((n_samples, 28, 28, 1)) / 255

# Discriminator: maps a (28, 28, 1) image to a single sigmoid score.
disc_input = Input(shape=(28, 28, 1))

# Four Conv -> ReLU -> Dropout stages that differ only in filter count and
# stride; layers are created in the same order as the hand-unrolled version.
x = disc_input
for n_filters, stride in ((64, 2), (64, 2), (128, 2), (128, 1)):
    x = Conv2D(filters=n_filters, kernel_size=5, strides=stride, padding='same')(x)
    x = Activation('relu')(x)
    x = Dropout(rate=0.4)(x)

# Flatten the feature maps and reduce them to one sigmoid unit.
x = Flatten()(x)
disc_output = Dense(units=1, activation='sigmoid', kernel_initializer='he_normal')(x)

discriminator = Model(disc_input, disc_output)
discriminator.summary()

# Generator: maps a 100-dimensional noise vector to a single-channel image
# with sigmoid-activated pixel values.
gen_dense_size = (7, 7, 64)

gen_input = Input(shape=(100,))

# Project the noise to 7*7*64 units, normalise, then fold into a feature map.
x = Dense(units=np.prod(gen_dense_size))(gen_input)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Reshape(gen_dense_size)(x)

# Three Conv -> BatchNorm -> ReLU stages; the first two are preceded by an
# UpSampling2D layer (default size), the third keeps the spatial size.
for upsample, n_filters in ((True, 128), (True, 64), (False, 64)):
    if upsample:
        x = UpSampling2D()(x)
    x = Conv2D(filters=n_filters, kernel_size=5, padding='same', strides=1)(x)
    x = BatchNormalization(momentum=0.9)(x)
    x = Activation('relu')(x)

# Collapse to one channel and squash to (0, 1).
x = Conv2D(filters=1, kernel_size=5, padding='same', strides=1)(x)
gen_output = Activation('sigmoid')(x)

generator = Model(gen_input, gen_output)
generator.summary()

# Compile the discriminator on its own, before it is frozen below.
# `learning_rate` is the documented argument name; `lr` is a deprecated alias.
discriminator.compile(optimizer=RMSprop(learning_rate=0.0008), loss='binary_crossentropy', metrics=['accuracy'])

# Freeze the discriminator before stacking it behind the generator.
# NOTE(review): the intent appears to be that fitting `model` only updates the
# generator while `discriminator.fit` still updates the discriminator; this
# relies on Keras honouring the trainable state captured at compile time —
# confirm for the installed Keras version.
discriminator.trainable = False

# Stacked model: noise -> generator -> discriminator score.
model_input = Input(shape=(100, ))
model_output = discriminator(generator(model_input))
model = Model(model_input, model_output)

model.compile(optimizer=RMSprop(learning_rate=0.0004), loss='binary_crossentropy', metrics=['accuracy'])

def train_discriminator(x_train, batch_size):
    """Fit the discriminator on one real batch and one generated batch.

    Args:
        x_train: Array of real images; ``batch_size`` samples are drawn from
            its first axis uniformly at random, with replacement.
        batch_size: Number of real images and number of generated images used.
    """
    valid = np.ones((batch_size, 1))   # target label for real images
    fake = np.zeros((batch_size, 1))   # target label for generated images

    # Sample from the array that was passed in rather than from the
    # module-level dataset, so the parameter actually takes effect.
    idx = np.random.randint(0, len(x_train), batch_size)
    true_imgs = x_train[idx]
    # NOTE(review): `fit` is called without `batch_size`, so Keras applies its
    # own default mini-batch size — confirm that is intended.
    discriminator.fit(true_imgs, valid, verbose=0)

    # Generate images from standard-normal noise of dimension 100.
    noise = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(noise)

    discriminator.fit(gen_imgs, fake, verbose=0)

def train_generator(batch_size):
    """Fit the stacked ``model`` on random noise labelled as real.

    Args:
        batch_size: Number of 100-dimensional noise vectors to draw.
    """
    latent = np.random.normal(0, 1, (batch_size, 100))
    # All targets are 1: the stacked model is fitted towards a "real" score
    # for generated images.
    real_labels = np.ones((batch_size, 1))
    model.fit(latent, real_labels, verbose=1)

# Alternate one discriminator update with one generator update, 2000 times,
# using batches of 64 and a tqdm progress bar.
progress = tqdm(range(2000))
for epoch in progress:
    train_discriminator(x_train=Xtrain, batch_size=64)
    train_generator(batch_size=64)