Skip to content

Commit

Permalink
Merge pull request #7 from ydataai/fix/betas
Browse files Browse the repository at this point in the history
fix: betas
  • Loading branch information
fabclmnt committed Oct 26, 2020
2 parents 9960f95 + be983df commit 85ed725
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 56 deletions.
5 changes: 4 additions & 1 deletion examples/cgan_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from ydata_synthetic.preprocessing.credit_fraud import transformations

import pandas as pd
import numpy as np
from sklearn import cluster

#Read the original data and have it preprocessed
Expand Down Expand Up @@ -45,6 +46,8 @@
log_step = 100
epochs = 200+1
learning_rate = 5e-4
beta_1 = 0.5
beta_2 = 0.9
models_dir = './cache'

train_sample = fraud_w_classes.copy().reset_index(drop=True)
Expand All @@ -54,7 +57,7 @@
train_sample[ data_cols ] = train_sample[ data_cols ] / 10 # scale to random noise size, one less thing to learn
train_no_label = train_sample[ data_cols ]

gan_args = [batch_size, learning_rate, noise_dim, train_sample.shape[1], 2, (0, 1), dim]
gan_args = [batch_size, learning_rate, beta_1, beta_2, noise_dim, train_sample.shape[1], 2, (0, 1), dim]
train_args = ['', label_cols[0], epochs, log_step, '']

#Init the Conditional GAN providing the index of the label column as one of the arguments
Expand Down
95 changes: 47 additions & 48 deletions examples/gan_example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/ydata_synthetic/synthesizers/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def __init__(
model_parameters
):
self._model_parameters = model_parameters
[self.batch_size, self.lr, self.noise_dim,
[self.batch_size, self.lr, self.beta_1, self.beta_2, self.noise_dim,
self.data_dim, self.layers_dim] = model_parameters
self.define_gan()

Expand Down
4 changes: 2 additions & 2 deletions src/ydata_synthetic/synthesizers/regular/cgan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
class CGAN():

def __init__(self, model_parameters):
[self.batch_size, lr, self.noise_dim,
[self.batch_size, lr,self.beta_1, self.beta_2, self.noise_dim,
self.data_dim, num_classes, self.classes, layers_dim] = model_parameters

self.generator = Generator(self.batch_size, num_classes). \
Expand All @@ -22,7 +22,7 @@ def __init__(self, model_parameters):
self.discriminator = Discriminator(self.batch_size, num_classes). \
build_model(input_shape=(self.data_dim,), dim=layers_dim)

optimizer = Adam(lr, beta_1=0.5)
optimizer = Adam(lr, beta_1=self.beta_1, beta_2=self.beta_2)

# Build and compile the discriminator
self.discriminator.compile(loss='binary_crossentropy',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def define_gan(self):
self.discriminator = Discriminator(self.batch_size).\
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)

optimizer = Adam(self.lr, 0.5)
optimizer = Adam(self.lr, beta_1=self.beta_1, beta_2=self.beta_2)

# Build and compile the discriminator
self.discriminator.compile(loss='binary_crossentropy',
Expand Down
6 changes: 3 additions & 3 deletions src/ydata_synthetic/synthesizers/regular/wgan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def define_gan(self):
self.critic = Critic(self.batch_size). \
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)

optimizer = Adam(self.lr, beta_1=0.5, beta_2=0.9)
self.critic_optimizer = Adam(self.lr, beta_1=0.5, beta_2=0.9)
optimizer = Adam(self.lr, beta_1=self.beta_1, beta_2=self.beta_2)
self.critic_optimizer = Adam(self.lr, beta_1=self.beta_1, beta_2=self.beta_2)

# Build and compile the critic
self.critic.compile(loss=self.wasserstein_loss,
Expand Down Expand Up @@ -111,7 +111,7 @@ def train(self, data, train_arguments):
# Plot the progress
print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

# If at save interval => save generated events
#If at save interval => save generated events
if epoch % sample_interval == 0:
# Test here data generation step
# save model checkpoints
Expand Down

0 comments on commit 85ed725

Please sign in to comment.