fix(wgan): Gradients solved
fabclmnt committed Oct 30, 2020
1 parent 6bdee5d commit 41fcb1d
Showing 3 changed files with 77 additions and 118 deletions.
5 changes: 3 additions & 2 deletions src/ydata_synthetic/synthesizers/__init__.py
@@ -1,10 +1,11 @@
 from ydata_synthetic.synthesizers.regular.cgan.model import CGAN
 from ydata_synthetic.synthesizers.regular.wgan.model import WGAN
 from ydata_synthetic.synthesizers.regular.vanillagan.model import VanilllaGAN
-from ydata_synthetic.synthesizers.regular.wgan_gp.model import WGAN_GP
+from ydata_synthetic.synthesizers.regular.wgangp.model import WGAN_GP
 
 __all__ = [
     "VanilllaGAN",
     "CGAN",
-    "WGAN"
+    "WGAN",
+    "WGAN_GP"
 ]
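
With this re-export, WGAN_GP becomes importable from the package root. Note that the diff touches the "WGAN" line only to add the trailing comma; had "WGAN_GP" been appended without it, Python's implicit string-literal concatenation would have exported a single bogus name. A quick illustration (assuming the package is installed):

    # Without the new comma, __all__ would have contained
    # "WGAN" "WGAN_GP", which concatenates to "WGANWGAN_GP".
    from ydata_synthetic.synthesizers import CGAN, WGAN, WGAN_GP, VanilllaGAN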
1 change: 0 additions & 1 deletion src/ydata_synthetic/synthesizers/regular/wgan/model.py
@@ -11,7 +11,6 @@
 from tensorflow.keras import Model
 from tensorflow.keras.optimizers import Adam
 
-
 #Auxiliary Keras backend class to calculate the Random Weighted average
 #https://stackoverflow.com/questions/58133430/how-to-substitute-keras-layers-merge-merge-in-tensorflow-keras
 class RandomWeightedAverage(tf.keras.layers.Layer):
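
This file keeps the RandomWeightedAverage layer, which draws one uniform interpolation coefficient per sample and returns a random point on the segment between a real and a generated batch; the gradient penalty is evaluated at such points. Note that tf.random_uniform is the TensorFlow 1.x name; under TF2 the call is tf.random.uniform. A TF2 sketch of the same layer, assuming 2-D (batch, features) tabular inputs rather than the 4-D image shape used in the repo's version:

    import tensorflow as tf

    class RandomWeightedAverage(tf.keras.layers.Layer):
        def __init__(self, batch_size):
            super().__init__()
            self.batch_size = batch_size

        def call(self, inputs, **kwargs):
            # one alpha per sample, broadcast across the feature axis
            alpha = tf.random.uniform((self.batch_size, 1))
            return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])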
189 changes: 74 additions & 115 deletions src/ydata_synthetic/synthesizers/regular/wgangp/model.py
@@ -1,7 +1,7 @@
 import os
 from os import path
 import numpy as np
-from tqdm import tqdm
+import tqdm
 from functools import partial
 
 from ydata_synthetic.synthesizers import gan
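
The import change is load-bearing for the new training loop: import tqdm binds the module, so tqdm.trange is reachable, whereas from tqdm import tqdm binds the class, and tqdm.trange would raise AttributeError. A minimal check:

    import tqdm

    # tqdm.trange(n) is shorthand for tqdm.tqdm(range(n))
    for epoch in tqdm.trange(5, desc='Epoch Iterations'):
        pass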
@@ -12,113 +12,88 @@
 from tensorflow.keras import Model
 from tensorflow.keras.optimizers import Adam
 
-#Auxiliary Keras backend class to calculate the Random Weighted average
-#https://stackoverflow.com/questions/58133430/how-to-substitute-keras-layers-merge-merge-in-tensorflow-keras
-class RandomWeightedAverage(tf.keras.layers.Layer):
-    def __init__(self, batch_size):
-        super().__init__()
-        self.batch_size = batch_size
-
-    def call(self, inputs, **kwargs):
-        alpha = tf.random_uniform((self.batch_size, 1, 1, 1))
-        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])
-
-    def compute_output_shape(self, input_shape):
-        return input_shape[0]
-
 class WGAN_GP(gan.Model):
 
+    GRADIENT_PENALTY_WEIGHT = 10
+
     def __init__(self, model_parameters, n_critic):
         # As recommended in WGAN paper - https://arxiv.org/abs/1701.07875
+        # WGAN-GP - WGAN with Gradient Penalty
         self.n_critic = n_critic
         super().__init__(model_parameters)
 
-    def wasserstein_loss(self, y_true, y_pred):
-        return K.mean(y_true * y_pred)
-
-    def define_gan(self):
+    def define_gan(self):
         self.generator = Generator(self.batch_size). \
             build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim)
 
         self.critic = Critic(self.batch_size). \
             build_model(input_shape=(self.data_dim,), dim=self.layers_dim)
 
-        optimizer = Adam(self.lr, beta_1=self.beta_1, beta_2=self.beta_2)
-
-        # Freeze the generator while discriminator training
-        self.generator.trainable = False
-
-        #Real event input
-        real_event = Input(shape=self.data_dim)
-
-        #Random noise object
-        z = Input(shape=(self.noise_dim,))
-        #Generate new record using the generator from noise
-        record = self.generator(z)
-
-        # Discriminator determines validity of the real and fake events
-        fake = self.critic(record)
-        valid = self.critic(real_event)
-
-        # Construct weighted average between real and the fake envents
-        interpolated_img = RandomWeightedAverage()([real_event, record])
-
-        # Determine validity of weighted sample
-        validity_interpolated = self.critic(interpolated_img)
-
-        partial_gp_loss = partial(self.gradient_penalty_loss,
-                                  averaged_samples=validity_interpolated)
-        partial_gp_loss.__name__ = 'gradient_penalty'  # Keras requires function names
-
-        self._model_critic = Model(inputs=[real_event, z],
-                                   outputs=[valid, fake, validity_interpolated],
-                                   metrics=['accuracy'])
-
-        self._model_critic.compile(loss=[self.wasserstein_loss,
-                                         self.wasserstein_loss,
-                                         partial_gp_loss],
-                                   optimizer=optimizer,
-                                   loss_weights=[1, 1, 10])
-
-        # For the combined model we will only train the generator
-        self.critic.trainable = False
-
-        # Computational graph for the Generator
-        z_gen = Input(shape=(self.noise_dim,))
-        fake_record = self.generator(z_gen)
-        valid = self.critic(fake_record)
-
-        # The combined model (stacked generator and discriminator)
-        # Trains the generator to fool the discriminator
-        # For the WGAN model use the Wassertein loss
-        self._model = Model(z_gen, valid)
-        self._model.compile(loss=self.wasserstein_loss, optimizer=optimizer)
+        self.g_optimizer = Adam(self.lr, beta_1=self.beta_1, beta_2=self.beta_2)
+        self.critic_optimizer = Adam(self.lr, beta_1=self.beta_1, beta_2=self.beta_2)
 
-    def gradient_penalty_loss(self, y_pred, averaged_samples):
-        """
-        Computing gradient penalty based on the prediction for real, fake and weighted events
-        """
-        gradients = K.gradients(y_pred, averaged_samples)[0]
-
-        gradients_sqr = K.square(gradients)
-
-        gradients_sqr_sum = K.sum(gradients_sqr,
-                                  axis=np.arange(1, len(gradients_sqr.shape)))
-
-        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
-
-        # compute lambda * (1 - ||grad||)^2 still for each single sample
-        gradient_penalty = K.square(1 - gradient_l2_norm)
-
-        # return the mean as loss over all the batch samples
-        return K.mean(gradient_penalty)
+    def wasserstein_loss(self, y_true, y_pred):
+        return K.mean(y_true * y_pred)
+
+    def gradient_penalty(self, real, fake):
+        epsilon = tf.random.uniform([real.shape[0], 1], 0.0, 1.0, dtype=tf.dtypes.float32)
+        x_hat = epsilon * real + (1 - epsilon) * fake
+        with tf.GradientTape() as t:
+            t.watch(x_hat)
+            d_hat = self.critic(x_hat)
+        gradients = t.gradient(d_hat, x_hat)
+        ddx = tf.sqrt(tf.reduce_sum(gradients ** 2))
+        d_regularizer = tf.reduce_mean((ddx - 1.0) ** 2)
+        return d_regularizer
+
+    def compute_gradients(self, x):
+        """
+        Compute the gradients for both the Generator and the Critic
+        :param x: real data event
+        :return: generator gradients, critic gradients
+        """
+        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
+            d_loss, g_loss = self.compute_loss(x)
+
+        #Freeze the critic training while training the generator
+        self.critic.trainable = False
+        self.generator.trainable = True
+        gen_gradients = g_tape.gradient(g_loss, self.generator.trainable_variables)
+        disc_gradients = d_tape.gradient(d_loss, self.critic.trainable_variables)
+
+        return gen_gradients, disc_gradients
+
+    def apply_gradients(self, ggradients, dgradients):
+        self.g_optimizer.apply_gradients(
+            zip(ggradients, self.generator.trainable_variables)
+        )
+        self.critic_optimizer.apply_gradients(
+            zip(dgradients, self.critic.trainable_variables)
+        )
+
+    def compute_loss(self, real):
+        """
+        passes through the network and computes the losses
+        """
+        # generating noise from a uniform distribution
+        noise = tf.random.normal([real.shape[0], self.noise_dim], dtype=tf.dtypes.float32)
+
+        # run noise through generator
+        fake = self.generator(noise)
+        # discriminate x and x_gen
+        logits_real = self.critic(real)
+        logits_fake = self.critic(fake)
+
+        # gradient penalty
+        d_regularizer = self.gradient_penalty(real, fake)
+        ### losses
+        d_loss = (
+            tf.reduce_mean(logits_real)
+            - tf.reduce_mean(logits_fake)
+            + d_regularizer * self.GRADIENT_PENALTY_WEIGHT
+        )
+
+        # losses of fake with label "1"
+        g_loss = tf.reduce_mean(logits_fake)
+        return d_loss, g_loss
 
     def get_data_batch(self, train, batch_size, seed=0):
         # np.random.seed(seed)
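
The rework drops the compiled three-output critic graph in favour of eager losses under tf.GradientTape. In the removed code, partial(self.gradient_penalty_loss, averaged_samples=validity_interpolated) appears to pass the critic's output on the interpolates rather than the interpolated samples themselves, so K.gradients would differentiate with respect to the wrong tensor - plausibly the broken gradients the commit title refers to. Two details of the new code are worth flagging: the trainable toggles inside compute_gradients do not affect the result, since the variable lists are handed to tape.gradient explicitly; and ddx collapses the squared gradients of the whole batch into a single norm, whereas the penalty of Gulrajani et al. (arXiv:1704.00028) constrains one gradient norm per sample. A standalone sketch of the per-sample form, assuming 2-D (batch, features) inputs and a critic that maps a batch to one logit per sample:

    import tensorflow as tf

    def gradient_penalty(critic, real, fake):
        batch_size = tf.shape(real)[0]
        # one interpolation coefficient per sample
        epsilon = tf.random.uniform([batch_size, 1], 0.0, 1.0)
        x_hat = epsilon * real + (1 - epsilon) * fake
        with tf.GradientTape() as t:
            t.watch(x_hat)
            d_hat = critic(x_hat)
        gradients = t.gradient(d_hat, x_hat)
        # reduce over the feature axis only: one gradient norm per sample
        ddx = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1))
        return tf.reduce_mean((ddx - 1.0) ** 2)

    # toy check with a linear critic on random tabular batches
    critic = tf.keras.layers.Dense(1)
    gp = gradient_penalty(critic, tf.random.normal([8, 4]), tf.random.normal([8, 4]))

The committed d_loss/g_loss pair also flips the sign convention of the paper (here the critic minimizes mean(real) - mean(fake) and the generator minimizes mean(fake)); the two flips are mutually consistent, amounting to training -D in place of D.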
@@ -133,44 +108,28 @@ def get_data_batch(self, train, batch_size, seed=0):
         x = train.loc[train_ix[start_i: stop_i]].values
         return np.reshape(x, (batch_size, -1))
 
+    @tf.function
+    def train_step(self, train_data):
+        g_gradients, d_gradients = self.compute_gradients(train_data)
+        self.apply_gradients(g_gradients, d_gradients)
+
     def train(self, data, train_arguments):
         [cache_prefix, epochs, sample_interval] = train_arguments
 
-        #Create a summary file
+        # Create a summary file
         train_summary_writer = tf.summary.create_file_writer(path.join('../wgan_gp_test', 'summaries', 'train'))
 
-        # Adversarial ground truths
-        valid = -np.ones((self.batch_size, 1))
-        fake = np.ones((self.batch_size, 1))
-        dummy = -np.zeros((self.batch_size, 1))
-
         with train_summary_writer.as_default():
-            for epoch in tqdm.trange(epochs, desc='Epoch Iterations'):
-
-                for _ in range(self.n_critic):
-                    # ---------------------
-                    # Train the Critic
-                    # ---------------------
-                    batch_data = self.get_data_batch(data, self.batch_size)
-                    noise = tf.random.normal((self.batch_size, self.noise_dim))
-
-                    # Generate a batch of events
-                    gen_data = self.generator(noise)
-
-                    # Train the Critic
-                    d_loss = self._model_critic.train_on_batch([batch_data, noise],
-                                                               [valid, fake, dummy])
-
-                # ---------------------
-                # Train Generator
-                # ---------------------
-                noise = tf.random.normal((self.batch_size, self.noise_dim))
-                # Train the generator (to have the critic label samples as valid)
-                g_loss = self.model.train_on_batch(noise, valid)
-                # Plot the progress
-                print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
-
-                #If at save interval => save generated events
+            for epoch in tqdm.trange(epochs):
+                batch_data = self.get_data_batch(data, self.batch_size).astype(np.float32)
+                self.train_step(batch_data)
+                loss = self.compute_loss(batch_data)
+
+                print(
+                    "Epoch: {} | disc_loss: {} | gen_loss: {}".format(
+                        epoch, loss[0], loss[1]
+                    ))
+
                 if epoch % sample_interval == 0:
                     # Test here data generation step
                     # save model checkpoints
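
For completeness, a hypothetical driver under the conventions visible in this diff: train_arguments unpacks as [cache_prefix, epochs, sample_interval]; get_data_batch indexes data with .loc, suggesting a pandas DataFrame; and gan.Model is assumed to invoke define_gan during construction. Note that the rewritten loop no longer consults self.n_critic, so each epoch performs exactly one critic and one generator update.

    # Hypothetical usage; model_parameters follows the package's gan.Model
    # conventions and is not defined in this diff.
    synth = WGAN_GP(model_parameters, n_critic=5)
    synth.train(data, ['wgan_gp', 500, 100])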
