In [1]:
import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

import os
import gc
from tqdm import tqdm
import random

import warnings
warnings.filterwarnings('ignore')

In [2]:
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras import optimizers, callbacks, layers
from tensorflow.keras.layers import Dense, Concatenate, Activation, Add, BatchNormalization, Dropout, Input
from tensorflow.keras.models import Model, Sequential, load_model

from sklearn.metrics import mean_squared_error as mse

SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)
os.environ['PYTHONHASHSEED']=str(SEED)
random.seed(SEED)
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as e:
        # 프로그램 시작시에 메모리 증가가 설정되어야만 합니다
        print(e)

def mish(x):
    return x*tf.math.tanh(tf.math.softplus(x))

def decay(epochs):
    init = 1e-3
    drop = 10
    ratio = 0.9
    return max(5e-5, (init * (ratio ** (epochs//drop))))

es = callbacks.EarlyStopping(patience=10, restore_best_weights=True)
lrs = callbacks.LearningRateScheduler(decay, verbose=0)


In [3]:
class cGAN(keras.Model):
    def __init__(self, x_dim, y_dim, z_dim):
        super(cGAN, self).__init__()
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.z_dim = z_dim
        
        self.d_steps = 4
        self.gp_weight = 10
        
        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()
    
    def compile(self, g_optim, d_optim, g_loss_fn, d_loss_fn):
        super(cGAN, self).compile()
        self.g_optim = g_optim
        self.d_optim = d_optim
        self.g_loss_fn = g_loss_fn
        self.d_loss_fn = d_loss_fn
        
    def build_generator(self):
        activation = mish
        inputs_z = Input(shape=(self.z_dim, ))
        inputs_x = Input(shape=(self.x_dim, ))
        
        x = Concatenate()([inputs_z, inputs_x])
        
        x = Dense(128, kernel_initializer='he_normal')(x)
        x = Activation(activation)(x)
        x = Dense(64, kernel_initializer='he_normal')(x)
        x = Activation(activation)(x)

        outputs = Dense(self.y_dim, kernel_initializer='he_normal')(x)
        return Model([inputs_z, inputs_x], outputs)
    
    def build_discriminator(self):
        inputs_x = Input(shape = (self.x_dim, ))
        inputs_y = Input(shape = (self.y_dim, ))

        x = Concatenate()([inputs_x, inputs_y])
        x = Dense(64, activation=mish)(x)
        x = Dense(32, activation=mish)(x)
        
        outputs = Dense(1)(x)
        
        return Model([inputs_x, inputs_y], outputs)
    
    def gradient_penalty(self, batch_size, x, y, y_pred):
        alpha = tf.random.normal([batch_size, 1], 0.0, 1.0)
        diff = y - y_pred
        interpolated = y + alpha * diff
        
        with tf.GradientTape() as tape:
            tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator([x, interpolated], training=True)
    
        # 2. Calculate the gradients w.r.t to this interpolated image.
        grads = tape.gradient(pred, [interpolated])[0]
        # 3. Calcuate the norm of the gradients
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp
    
    def train_step(self, data):
        x, y = data
        batch_size = tf.shape(x)[0]
        
        real_label = tf.ones((batch_size, 1))
        fake_label = tf.ones((batch_size, 1))*-1
        labels = tf.concat([real_label, fake_label], -1)
        
        for i in range(self.d_steps):
            noise = self.sampler(batch_size)
            
            with tf.GradientTape() as tape:
                y_pred = self.generator([noise, x])
                fake_validity = self.discriminator([x, y_pred])
                real_validity = self.discriminator([x, y])
                validity = tf.concat([real_validity, fake_validity], -1)

                cost = self.d_loss_fn(real_validity, fake_validity)
                gp = self.gradient_penalty(batch_size, x, y, y_pred)
                
                d_loss = cost + gp*self.gp_weight
                
            grads = tape.gradient(d_loss, self.discriminator.trainable_variables)
            self.d_optim.apply_gradients(zip(grads, self.discriminator.trainable_variables))
            
            
        noise = self.sampler(batch_size)
        with tf.GradientTape() as tape:
            y_preds = self.generator([noise, x])
            validity = self.discriminator([x, y_preds])
            g_loss = self.g_loss_fn(validity)
            
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optim.apply_gradients(zip(grads, self.generator.trainable_weights))

        return {'d_loss': d_loss, 'g_loss': g_loss}
    
    def sampler(self, batch_size):
        return tf.random.normal((batch_size, self.z_dim), 0, 0.5)

In [4]:
X = np.random.normal(0, 1, (1000, 128))
y = np.random.normal(0, 1, (1000, 10))

In [5]:
def gen_loss(preds):
    return -tf.reduce_mean(preds)

def disc_loss(y, y_pred):
    real_loss = tf.reduce_mean(y)
    fake_loss = tf.reduce_mean(y_pred)
    return fake_loss - real_loss

gan = cGAN(128, 10, 30)

In [6]:
gan.compile(
    optimizers.RMSprop(2e-4),
    optimizers.RMSprop(2e-4),
    gen_loss,
    disc_loss)

In [7]:
class MSECallback(callbacks.Callback):
    def __init__(self):
        pass
    
    def on_epoch_end(self, epoch, logs=None):
        # TODO: redefine validation data
        if epoch % 4 == 0:
            true = y
            noise = self.model.sampler(y.shape[0])
            pred = self.model.generator.predict([noise, X])
            print(f'\n epoch {epoch} error')
            print(mse(y, pred))
        

In [8]:
gan.fit(X, y, 
        epochs=40,
       callbacks=[MSECallback()],
       verbose=0)


 epoch 0 error
3.1085530466622897

 epoch 4 error
2.5773129929305663

 epoch 8 error
2.362732527255694

 epoch 12 error
1.8791958815347187

 epoch 16 error
1.7382884525368474

 epoch 20 error
1.654639301853771

 epoch 24 error
1.729948321400752

 epoch 28 error
1.8014968250443129

 epoch 32 error
1.9001921689319279

 epoch 36 error
1.9271275913882788


<tensorflow.python.keras.callbacks.History at 0x2a08114dcc8>

In [9]:
gan.generator.summary()

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 30)]         0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 128)]        0                                            
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 158)          0           input_1[0][0]                    
                                                                 input_2[0][0]                    
__________________________________________________________________________________________________
dense (Dense)                   (None, 128)          20352       concatenate[0][0]     

In [10]:
gan.discriminator.summary()

Model: "functional_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 128)]        0                                            
__________________________________________________________________________________________________
input_4 (InputLayer)            [(None, 10)]         0                                            
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 138)          0           input_3[0][0]                    
                                                                 input_4[0][0]                    
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 64)           8896        concatenate_1[0][0]   