feat(wgan-gp): Add gradient penalty to wgan
fabclmnt committed Oct 29, 2020
1 parent 0d527b1 commit 5d37e62
Showing 5 changed files with 265 additions and 38 deletions.
1 change: 1 addition & 0 deletions src/ydata_synthetic/synthesizers/__init__.py
@@ -2,6 +2,7 @@
from ydata_synthetic.synthesizers.regular.wgan.model import WGAN
from ydata_synthetic.synthesizers.regular.vanillagan.model import VanilllaGAN


__all__ = [
    "VanilllaGAN",
    "CGAN",
5 changes: 5 additions & 0 deletions src/ydata_synthetic/synthesizers/gan.py
@@ -1,4 +1,5 @@
import os
import tensorflow as tf
from tensorflow.python import keras

class Model():
@@ -36,6 +37,10 @@ def model_name(self):
    def train(self, data, train_arguments):
        raise NotImplementedError

    @tf.function
    def samples(self, z):
        # Draw synthetic records by running the generator in inference mode
        return self.generator(z, training=False)

    def save(self, path, name):
        assert os.path.isdir(path), \
            "Please provide a valid path. Path must be a directory."
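For reference, the new samples helper wraps generator inference in a tf.function. A minimal usage sketch, assuming synth is any trained synthesizer built on this base class (the variable names here are illustrative, not part of the commit):

import tensorflow as tf

# synth: an assumed, already-trained subclass instance of the Model base class above
noise = tf.random.normal((synth.batch_size, synth.noise_dim))
fake_batch = synth.samples(noise)  # runs the generator with training=False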
81 changes: 43 additions & 38 deletions src/ydata_synthetic/synthesizers/regular/wgan/model.py
@@ -1,6 +1,7 @@
import os
from os import path
import numpy as np
from tqdm import trange  # module-level helper used in the training loop below

from ydata_synthetic.synthesizers import gan

@@ -10,6 +11,7 @@
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam


#Auxiliary Keras backend class to calculate the Random Weighted average
#https://stackoverflow.com/questions/58133430/how-to-substitute-keras-layers-merge-merge-in-tensorflow-keras
class RandomWeightedAverage(tf.keras.layers.Layer):
@@ -28,6 +30,7 @@ class WGAN(gan.Model):

    def __init__(self, model_parameters, n_critic):
        # As recommended in WGAN paper - https://arxiv.org/abs/1701.07875
        # WGAN-GP - WGAN with Gradient Penalty
        self.n_critic = n_critic
        super().__init__(model_parameters)

@@ -80,52 +83,54 @@ def get_data_batch(self, train, batch_size, seed=0):
    def train(self, data, train_arguments):
        [cache_prefix, epochs, sample_interval] = train_arguments

        # Create a summary file
        train_summary_writer = tf.summary.create_file_writer(path.join('.', 'summaries', 'train'))

        # Adversarial ground truths
        valid = np.ones((self.batch_size, 1))
        fake = -np.ones((self.batch_size, 1))

        with train_summary_writer.as_default():
            for epoch in trange(epochs, desc='Epoch Iterations'):

                for _ in range(self.n_critic):
                    # ---------------------
                    #  Train the Critic
                    # ---------------------
                    batch_data = self.get_data_batch(data, self.batch_size)
                    noise = tf.random.normal((self.batch_size, self.noise_dim))

                    # Generate a batch of events
                    gen_data = self.generator(noise)

                    # Train the Critic
                    d_loss_real = self.critic.train_on_batch(batch_data, valid)
                    d_loss_fake = self.critic.train_on_batch(gen_data, fake)
                    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                    # Critic weight clipping, per the original WGAN recipe
                    for l in self.critic.layers:
                        weights = l.get_weights()
                        weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
                        l.set_weights(weights)

                # ---------------------
                #  Train Generator
                # ---------------------
                noise = tf.random.normal((self.batch_size, self.noise_dim))
                # Train the generator (to have the critic label samples as valid)
                g_loss = self.model.train_on_batch(noise, valid)
                # Plot the progress
                print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

                # If at save interval => save model checkpoints
                if epoch % sample_interval == 0:
                    if path.exists('./cache') is False:
                        os.mkdir('./cache')
                    model_checkpoint_base_name = './cache/' + cache_prefix + '_{}_model_weights_step_{}.h5'
                    self.generator.save_weights(model_checkpoint_base_name.format('generator', epoch))
                    self.critic.save_weights(model_checkpoint_base_name.format('critic', epoch))
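For contrast with the new module below: the per-layer clipping above enforces the critic's Lipschitz constraint by brute force, while WGAN-GP replaces it with a penalty on the critic's gradient norm at interpolated points, roughly:

# GP = lambda * E[ (||grad_x D(x_hat)||_2 - 1)^2 ],  x_hat = alpha*x_real + (1-alpha)*x_fake
# with lambda = 10, matching the loss_weights used below (per Gulrajani et al., 2017,
# https://arxiv.org/abs/1704.00028)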

    def load(self, path):
        assert os.path.isdir(path), \
Empty file.
216 changes: 216 additions & 0 deletions src/ydata_synthetic/synthesizers/regular/wgan_gp/model.py
@@ -0,0 +1,216 @@
import os
from os import path
import numpy as np
from tqdm import trange  # module-level helper used in the training loop below
from functools import partial

from ydata_synthetic.synthesizers import gan

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Dropout
import tensorflow.keras.backend as K
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam


#Auxiliary Keras backend class to calculate the Random Weighted average
#https://stackoverflow.com/questions/58133430/how-to-substitute-keras-layers-merge-merge-in-tensorflow-keras
class RandomWeightedAverage(tf.keras.layers.Layer):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def call(self, inputs, **kwargs):
        # Per-sample mixing coefficient; shape (batch_size, 1) so it broadcasts
        # over the 2-D tabular inputs used here
        alpha = tf.random.uniform((self.batch_size, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

    def compute_output_shape(self, input_shape):
        return input_shape[0]
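A quick sketch of what the layer does, with toy shapes and illustrative names only:

import tensorflow as tf

rwa = RandomWeightedAverage(batch_size=4)
real = tf.ones((4, 3))          # stand-in for a real batch
fake = tf.zeros((4, 3))         # stand-in for a generated batch
mixed = rwa([real, fake])       # each row is alpha*real + (1-alpha)*fake
print(mixed.shape)              # (4, 3)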

class WGAN(gan.Model):

    def __init__(self, model_parameters, n_critic):
        # As recommended in WGAN paper - https://arxiv.org/abs/1701.07875
        # WGAN-GP - WGAN with Gradient Penalty
        self.n_critic = n_critic
        super().__init__(model_parameters)

    def wasserstein_loss(self, y_true, y_pred):
        return K.mean(y_true * y_pred)
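With the labels used in train() below (valid = -1 for real events, fake = +1 for generated ones), this one-liner reproduces the usual critic objective. A worked check of the signs, with toy scores:

import numpy as np

d_real = np.array([2.0, 3.0])      # critic scores on real events (toy values)
d_fake = np.array([-1.0, 0.0])     # critic scores on generated events
critic_loss = np.mean(-1 * d_real) + np.mean(1 * d_fake)   # = -3.0
# minimizing this pushes D(real) up and D(fake) down, widening the score gap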

    def define_gan(self):
        self.generator = Generator(self.batch_size). \
            build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim)

        self.critic = Critic(self.batch_size). \
            build_model(input_shape=(self.data_dim,), dim=self.layers_dim)

        optimizer = Adam(self.lr, beta_1=self.beta_1, beta_2=self.beta_2)
        self.critic_optimizer = Adam(self.lr, beta_1=self.beta_1, beta_2=self.beta_2)

        # Freeze the generator while training the critic
        self.generator.trainable = False

        # Real event input
        real_event = Input(shape=(self.data_dim,))

        # Random noise input
        z = Input(shape=(self.noise_dim,))
        # Generate a new record from the noise with the generator
        record = self.generator(z)

        # The critic scores the real and the fake events
        fake = self.critic(record)
        valid = self.critic(real_event)

        # Construct a weighted average between the real and fake events
        interpolated_img = RandomWeightedAverage(self.batch_size)([real_event, record])

        # Score the weighted sample
        validity_interpolated = self.critic(interpolated_img)

        # The penalty is taken w.r.t. the interpolated samples themselves
        partial_gp_loss = partial(self.gradient_penalty_loss,
                                  averaged_samples=interpolated_img)
        partial_gp_loss.__name__ = 'gradient_penalty'  # Keras requires function names

        self._model_critic = Model(inputs=[real_event, z],
                                   outputs=[valid, fake, validity_interpolated])

        self._model_critic.compile(loss=[self.wasserstein_loss,
                                         self.wasserstein_loss,
                                         partial_gp_loss],
                                   optimizer=optimizer,
                                   loss_weights=[1, 1, 10])  # gradient penalty weight (lambda) = 10

        # For the combined model we only train the generator:
        # freeze the critic and unfreeze the generator
        self.critic.trainable = False
        self.generator.trainable = True

        z_gen = Input(shape=(self.noise_dim,))
        fake_record = self.generator(z_gen)
        valid = self.critic(fake_record)

        # The combined model (stacked generator and critic) trains the
        # generator to fool the critic, using the Wasserstein loss
        self._model = Model(z_gen, valid)
        self._model.compile(loss=self.wasserstein_loss, optimizer=optimizer)

    def gradient_penalty_loss(self, y_true, y_pred, averaged_samples):
        """
        Computes the gradient penalty from the critic's prediction on the
        weighted (interpolated) events. Keras calls losses as f(y_true, y_pred);
        averaged_samples is bound via functools.partial in define_gan.
        """
        gradients = K.gradients(y_pred, averaged_samples)[0]

        gradients_sqr = K.square(gradients)

        gradients_sqr_sum = K.sum(gradients_sqr,
                                  axis=np.arange(1, len(gradients_sqr.shape)))

        gradient_l2_norm = K.sqrt(gradients_sqr_sum)

        # compute (1 - ||grad||)^2 for each single sample; lambda is applied
        # through loss_weights in compile()
        gradient_penalty = K.square(1 - gradient_l2_norm)

        # return the mean as loss over all the batch samples
        return K.mean(gradient_penalty)
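K.gradients is a graph-mode construct; for reference, an eager-mode equivalent of this penalty (a sketch, not part of the commit) can be written with tf.GradientTape:

def gradient_penalty(critic, interpolated):
    # interpolated: batch of alpha-mixed real/fake samples (assumed 2-D)
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        scores = critic(interpolated, training=True)
    grads = tape.gradient(scores, interpolated)
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1))
    return tf.reduce_mean(tf.square(1.0 - norm))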


    def get_data_batch(self, train, batch_size, seed=0):
        # np.random.seed(seed)
        # x = train.loc[ np.random.choice(train.index, batch_size) ].values
        # iterate through shuffled indices, so every sample gets covered evenly
        start_i = (batch_size * seed) % len(train)
        stop_i = start_i + batch_size
        shuffle_seed = (batch_size * seed) // len(train)
        np.random.seed(shuffle_seed)
        train_ix = np.random.choice(list(train.index), replace=False, size=len(train))  # wasteful to shuffle every time
        train_ix = list(train_ix) + list(train_ix)  # duplicate to cover ranges past the end of the set
        x = train.loc[train_ix[start_i: stop_i]].values
        return np.reshape(x, (batch_size, -1))
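A quick check of the batching scheme (df is a toy stand-in; the method reads no instance state, so it can be exercised unbound):

import pandas as pd

df = pd.DataFrame({'a': range(10), 'b': range(10)})
b0 = WGAN.get_data_batch(None, df, batch_size=4, seed=0)  # first window of the shuffle
b1 = WGAN.get_data_batch(None, df, batch_size=4, seed=1)  # next window, same shuffle
print(b0.shape, b1.shape)  # (4, 2) (4, 2)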

    def train(self, data, train_arguments):
        [cache_prefix, epochs, sample_interval] = train_arguments

        # Create a summary file
        train_summary_writer = tf.summary.create_file_writer(path.join('.', 'summaries', 'train'))

        # Adversarial ground truths
        valid = -np.ones((self.batch_size, 1))
        fake = np.ones((self.batch_size, 1))
        dummy = np.zeros((self.batch_size, 1))  # dummy targets for the gradient penalty output

        with train_summary_writer.as_default():
            for epoch in trange(epochs, desc='Epoch Iterations'):

                for _ in range(self.n_critic):
                    # ---------------------
                    #  Train the Critic
                    # ---------------------
                    batch_data = self.get_data_batch(data, self.batch_size)
                    noise = tf.random.normal((self.batch_size, self.noise_dim))

                    # Train the critic on real events, generated events and the
                    # gradient penalty in one step; the combined critic model
                    # generates from the noise input internally
                    d_loss = self._model_critic.train_on_batch([batch_data, noise],
                                                               [valid, fake, dummy])

                # ---------------------
                #  Train Generator
                # ---------------------
                noise = tf.random.normal((self.batch_size, self.noise_dim))
                # Train the generator (to have the critic label samples as valid)
                g_loss = self.model.train_on_batch(noise, valid)
                # Plot the progress
                print("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))

                # If at save interval => save model checkpoints
                if epoch % sample_interval == 0:
                    if path.exists('./cache') is False:
                        os.mkdir('./cache')
                    model_checkpoint_base_name = './cache/' + cache_prefix + '_{}_model_weights_step_{}.h5'
                    self.generator.save_weights(model_checkpoint_base_name.format('generator', epoch))
                    self.critic.save_weights(model_checkpoint_base_name.format('critic', epoch))

    def load(self, path):
        assert os.path.isdir(path), \
            "Please provide a valid path. Path must be a directory."
        self.generator = Generator(self.batch_size)
        # load_weights mutates the model in place and does not return it
        self.generator.load_weights(path)
        return self.generator


class Generator(tf.keras.Model):
    def __init__(self, batch_size):
        super().__init__()  # required before setting attributes on a keras Model subclass
        self.batch_size = batch_size

    def build_model(self, input_shape, dim, data_dim):
        input = Input(shape=input_shape, batch_size=self.batch_size)
        x = Dense(dim, activation='relu')(input)
        x = Dense(dim * 2, activation='relu')(x)
        x = Dense(dim * 4, activation='relu')(x)
        x = Dense(data_dim)(x)
        return Model(inputs=input, outputs=x)


class Critic(tf.keras.Model):
    def __init__(self, batch_size):
        super().__init__()  # required before setting attributes on a keras Model subclass
        self.batch_size = batch_size

    def build_model(self, input_shape, dim):
        input = Input(shape=input_shape, batch_size=self.batch_size)
        x = Dense(dim * 4, activation='relu')(input)
        x = Dropout(0.1)(x)
        x = Dense(dim * 2, activation='relu')(x)
        x = Dropout(0.1)(x)
        x = Dense(dim, activation='relu')(x)
        x = Dense(1)(x)
        return Model(inputs=input, outputs=x)
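Putting it together, a hedged end-to-end sketch; the model_parameters layout and the gan.Model attributes are assumptions inferred from the attributes referenced above, not the package's documented API:

synth = WGAN(model_parameters, n_critic=5)   # n_critic=5 as recommended in the WGAN paper
synth.train(data, ['wgan_gp', 300, 50])      # [cache_prefix, epochs, sample_interval]
noise = tf.random.normal((synth.batch_size, synth.noise_dim))
synthetic_batch = synth.samples(noise)       # inference helper from the gan.Model base class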
