In [1]:
import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

import os
import gc
from tqdm import tqdm
import random

import warnings
warnings.filterwarnings('ignore')

In [2]:
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras import optimizers, callbacks, layers
from tensorflow.keras.layers import Dense, Concatenate, Activation, Add, BatchNormalization, Dropout, Input
from tensorflow.keras.models import Model, Sequential, load_model

# Seed every random number generator the notebook relies on so runs are repeatable.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
# here presumably only affects child processes, not this kernel - confirm.
os.environ['PYTHONHASHSEED']=str(SEED)
random.seed(SEED)

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Memory growth must be configured identically on every visible GPU,
        # so enable it on all of them rather than only on the first one.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth has to be set at program start, before the GPUs are initialized.
        print(e)

def mish(x):
    """Mish activation: scale ``x`` by ``tanh(softplus(x))``.

    Args:
        x: input passed to the TensorFlow math ops below.

    Returns:
        The product ``x * tanh(softplus(x))``.
    """
    gate = tf.math.tanh(tf.math.softplus(x))
    return x * gate

def decay(epochs):
    """Step-decay learning-rate schedule with a lower bound.

    The rate starts at 1e-3 and is multiplied by 0.9 once for every 10
    completed epochs; it is never allowed to fall below 5e-5.

    Args:
        epochs: epoch index (0-based).

    Returns:
        The learning rate to use for that epoch.
    """
    base_lr, step_size, factor, floor_lr = 1e-3, 10, 0.9, 5e-5
    completed_steps = epochs // step_size
    scheduled_lr = base_lr * (factor ** completed_steps)
    return max(floor_lr, scheduled_lr)

# Early stopping with a patience of 10 epochs; the best weights are restored on stop.
es = callbacks.EarlyStopping(
    patience=10,
    restore_best_weights=True,
)
# Learning-rate scheduler driven by the `decay` step schedule.
lrs = callbacks.LearningRateScheduler(schedule=decay, verbose=0)


In [3]:
class GAN(keras.Model):
    """Fully connected GAN whose generator and discriminator are trained alternately.

    Args:
        d_shape: shape of one data sample, e.g. ``(784,)``; the generator
            emits ``d_shape[0]`` values per sample.
        z_dim: dimensionality of the latent noise vector fed to the generator.
    """

    def __init__(self, d_shape, z_dim):
        super(GAN, self).__init__()
        self.d_shape = d_shape
        self.z_dim = z_dim

        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()

    def compile(self, g_optim, d_optim, loss_fn):
        """Store the optimizers and the loss used by ``train_step``.

        Args:
            g_optim: optimizer applied to the generator weights.
            d_optim: optimizer applied to the discriminator weights.
            loss_fn: callable ``loss_fn(labels, predictions)`` shared by both updates.
        """
        super(GAN, self).compile()
        self.g_optim = g_optim
        self.d_optim = d_optim
        self.loss_fn = loss_fn

    def build_generator(self):
        """Build the generator: noise ``(z_dim,)`` -> sample of size ``d_shape[0]``.

        The sigmoid output keeps generated values in [0, 1].
        """
        activation = mish
        inputs = Input(shape=(self.z_dim, ))

        x = Dense(128, kernel_initializer='he_normal')(inputs)
        x = Activation(activation)(x)
        x = Dense(256, kernel_initializer='he_normal')(x)
        x = Activation(activation)(x)
        x = Dense(512, kernel_initializer='he_normal')(x)
        x = Activation(activation)(x)

        outputs = Dense(self.d_shape[0], activation='sigmoid', kernel_initializer='he_normal')(x)
        return Model(inputs, outputs)

    def build_discriminator(self):
        """Build the discriminator: sample of shape ``d_shape`` -> probability of being real."""
        # Without a non-linearity the stacked Dense layers collapse into one
        # linear map, so every hidden layer gets a leaky ReLU.
        activation = tf.nn.leaky_relu
        inputs = Input(shape=self.d_shape)

        x = Dense(512, activation=activation)(inputs)
        x = Dense(256, activation=activation)(x)
        x = Dense(128, activation=activation)(x)

        outputs = Dense(1, activation='sigmoid')(x)

        return Model(inputs, outputs)

    def train_step(self, x):
        """Run one discriminator update followed by one generator update.

        Args:
            x: batch of real samples; if a tuple is passed, only its first
                element is used.

        Returns:
            Dict with the discriminator loss (``'d_loss'``) and the generator
            loss (``'g_loss'``) of this step.
        """
        # Tolerate (x, y, ...) batches; labels are not used by the GAN.
        if isinstance(x, tuple):
            x = x[0]
        batch_size = tf.shape(x)[0]

        noise = self.sampler(batch_size)

        fake_x = self.generator(noise)
        all_x = tf.concat([fake_x, x], 0)

        # Discriminator targets: 0 for generated samples, 1 for real samples.
        fake_labels = tf.zeros((batch_size, 1))
        real_labels = tf.ones((batch_size, 1))
        labels = tf.concat([fake_labels, real_labels], 0)

        # The official Keras tutorial says adding noise to the labels is an important trick:
        # labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Discriminator step: fake_x was produced outside the tape, so only
        # the discriminator weights receive gradients here.
        with tf.GradientTape() as tape:
            preds = self.discriminator(all_x)
            d_loss = self.loss_fn(labels, preds)

        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optim.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        # Generator step: try to make the discriminator label fakes as real.
        with tf.GradientTape() as tape:
            preds = self.discriminator(self.generator(noise))
            g_loss = self.loss_fn(real_labels, preds)

        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optim.apply_gradients(zip(grads, self.generator.trainable_weights))

        return {'d_loss': d_loss, 'g_loss': g_loss}

    def sampler(self, batch_size):
        """Draw ``batch_size`` latent vectors of size ``z_dim`` from a standard normal."""
        return tf.random.normal(shape=(batch_size, self.z_dim))


In [4]:
# Labels are discarded; the train and test images are pooled into one array.
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
pooled_images = np.concatenate([x_train, x_test])
# Scale pixel values to [0, 1] and flatten every 28x28 image into a 784-vector.
X = (pooled_images.astype("float32") / 255).reshape(pooled_images.shape[0], 28*28)

# TODO: learn tf.data and feed the model through a tf.data pipeline
# batch_size = 32
# data = tf.data.Dataset.from_tensor_slices(X)
# data = data.shuffle(buffer_size=1024).batch(batch_size).prefetch(32)

In [5]:
# Checklist
## Model architecture: design it carefully for images or other complex data.
## Use callbacks to monitor training progress.
gan = GAN(d_shape=X.shape[1:], z_dim=100)

In [6]:
# Print the generator's layers, output shapes and parameter counts.
gan.generator.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
dense (Dense)                (None, 128)               12928     
_________________________________________________________________
activation (Activation)      (None, 128)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 256)               33024     
_________________________________________________________________
activation_1 (Activation)    (None, 256)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 512)               131584    
_________________________________________________________________
activation_2 (Activation)    (None, 512)              

In [7]:
# Print the discriminator's layers, output shapes and parameter counts.
gan.discriminator.summary()

Model: "functional_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 784)]             0         
_________________________________________________________________
dense_4 (Dense)              (None, 512)               401920    
_________________________________________________________________
dense_5 (Dense)              (None, 256)               131328    
_________________________________________________________________
dense_6 (Dense)              (None, 128)               32896     
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 129       
Total params: 566,273
Trainable params: 566,273
Non-trainable params: 0
_________________________________________________________________


In [8]:
class SampleCallback(callbacks.Callback):
    """Skeleton callback that restores the generator's best weights after training.

    ``on_epoch_end`` is a placeholder: once a validation criterion is
    available it should update ``self.score`` and ``self.best_weights``.
    The model this callback is attached to is expected to expose a
    ``generator`` attribute (as ``GAN`` does).
    """

    def __init__(self):
        super(SampleCallback, self).__init__()
        self.score = float('inf')
        # Stays None until an epoch records a best set of generator weights.
        self.best_weights = None

    def on_epoch_end(self, epoch, logs=None):
        '''
        if you have a validation dataset, use it here
        (i.e.)
        score = criterion(val_y, self.model.generator.predict(val_X))
        if score < self.score:
            self.score = score
            self.best_weights = self.model.generator.get_weights()
        '''
        pass

    def on_train_end(self, logs=None):
        """Load the best recorded generator weights, if any were stored."""
        # Nothing to restore while on_epoch_end is still a placeholder.
        if self.best_weights is not None:
            self.model.generator.set_weights(self.best_weights)
        

In [9]:
# Generator and discriminator each get their own Adam optimizer at the same
# learning rate; both updates share a binary cross-entropy loss.
gan.compile(
    g_optim=optimizers.Adam(learning_rate=2e-4),
    d_optim=optimizers.Adam(learning_rate=2e-4),
    loss_fn=keras.losses.BinaryCrossentropy(),
)

gan.fit(X, epochs=5)

# TODO
## custom callback to check training progress

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x214c274d208>