Merge pull request #13 from ydataai/feat/feature#11
feat: feature#11 (WGAN-GP)
fabclmnt committed Oct 30, 2020
2 parents 0d527b1 + f52a06e commit b7a3c82
Showing 7 changed files with 312 additions and 43 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -36,3 +36,4 @@ Here you can find usage examples of the package and models to synthesize tabular
- [GAN](https://arxiv.org/abs/1406.2661)
- [CGAN (Conditional GAN)](https://arxiv.org/abs/1411.1784)
- [WGAN (Wasserstein GAN)](https://arxiv.org/abs/1701.07875)
- [WGAN-GP (Wasserstein GAN with Gradient Penalty)](https://arxiv.org/abs/1704.00028)
80 changes: 80 additions & 0 deletions examples/wgan_example.py
@@ -0,0 +1,80 @@
#Install ydata-synthetic lib
#! pip install git+https://github.com/ydataai/ydata-synthetic.git

import pandas as pd
import numpy as np
import sklearn.cluster as cluster
import matplotlib.pyplot as plt

from ydata_synthetic.synthesizers import WGAN_GP
from ydata_synthetic.preprocessing.credit_fraud import transformations

model = WGAN_GP

#Read the original data and have it preprocessed
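#(the V1-V28/Amount/Class schema below is the Kaggle credit card fraud dataset)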
data = pd.read_csv('data/creditcard.csv', index_col=[0])

#Data processing and analysis
data_cols = list(data.columns[ data.columns != 'Class' ])
label_cols = ['Class']

print('Dataset columns: {}'.format(data_cols))
sorted_cols = ['V14', 'V4', 'V10', 'V17', 'V12', 'V26', 'Amount', 'V21', 'V8', 'V11', 'V7', 'V28', 'V19', 'V3', 'V22', 'V6', 'V20', 'V27', 'V16', 'V13', 'V25', 'V24', 'V18', 'V2', 'V1', 'V5', 'V15', 'V9', 'V23', 'Class']
processed_data = data[ sorted_cols ].copy()

#Before training the GAN, don't forget to apply the required data transformations
#For simplicity, a power transformation is applied here
data = transformations(data)

#For the purpose of this example we will only synthesize the minority class
train_data = data.loc[ data['Class']==1 ].copy()

print("Dataset info: Number of records - {} Number of variables - {}".format(train_data.shape[0], train_data.shape[1]))

algorithm = cluster.KMeans
args, kwds = (), {'n_clusters':2, 'random_state':0}
labels = algorithm(*args, **kwds).fit_predict(train_data[ data_cols ])

print( pd.DataFrame( [ [np.sum(labels==i)] for i in np.unique(labels) ], columns=['count'], index=np.unique(labels) ) )

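#Use the KMeans cluster assignments as the class labels going forward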
fraud_w_classes = train_data.copy()
fraud_w_classes['Class'] = labels

# GAN training
#Define the GAN and training parameters
noise_dim = 32
dim = 128
batch_size = 128

log_step = 100
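#the +1 keeps epoch 200 (a multiple of log_step) inside the loop so its checkpoint is saved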
epochs = 200+1
learning_rate = 5e-4
beta_1 = 0.5
beta_2 = 0.9
models_dir = './cache'

train_sample = fraud_w_classes.copy().reset_index(drop=True)
train_sample = pd.get_dummies(train_sample, columns=['Class'], prefix='Class', drop_first=True)
label_cols = [ i for i in train_sample.columns if 'Class' in i ]
data_cols = [ i for i in train_sample.columns if i not in label_cols ]
train_sample[ data_cols ] = train_sample[ data_cols ] / 10 # scale to random noise size, one less thing to learn
train_no_label = train_sample[ data_cols ]

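#positional parameters: batch_size, lr, beta_1, beta_2, noise_dim, data_dim and dim (presumably the hidden-layer size)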
gan_args = [batch_size, learning_rate, beta_1, beta_2, noise_dim, train_sample.shape[1], dim]
train_args = ['', epochs, log_step]

#Training the WGAN_GP model
synthesizer = model(gan_args, n_critic=2)
synthesizer.train(train_sample, train_args)

#The WGAN_GP model is now trained,
#so we can easily generate a few samples
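#A minimal sketch of generating data with the new sample() method added to
#synthesizers/gan.py in this commit (assumes WGAN_GP inherits it from
#gan.Model, as WGAN does; 1000 is an arbitrary choice, and sample()
#rounds up to whole batches, so a few extra rows may come back)
synth_data = synthesizer.sample(1000)
print(synth_data.shape)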


4 changes: 3 additions & 1 deletion src/ydata_synthetic/synthesizers/__init__.py
@@ -1,9 +1,11 @@
from ydata_synthetic.synthesizers.regular.cgan.model import CGAN
from ydata_synthetic.synthesizers.regular.wgan.model import WGAN
from ydata_synthetic.synthesizers.regular.vanillagan.model import VanilllaGAN
from ydata_synthetic.synthesizers.regular.wgangp.model import WGAN_GP

__all__ = [
"VanilllaGAN",
"CGAN",
"WGAN"
"WGAN",
"WGAN_GP"
]
17 changes: 13 additions & 4 deletions src/ydata_synthetic/synthesizers/gan.py
@@ -1,4 +1,8 @@
import os
import tqdm

import pandas as pd
import tensorflow as tf
from tensorflow.python import keras

class Model():
@@ -21,10 +25,6 @@ def define_gan(self):
def trainable_variables(self, network):
return network.trainable_variables

@property
def model(self):
return self._model

@property
def model_parameters(self):
return self._model_parameters
@@ -36,6 +36,15 @@ def model_name(self):
def train(self, data, train_arguments):
raise NotImplementedError

def sample(self, n_samples):
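        # generates n_samples // batch_size + 1 whole batches, so the output may exceed n_samples rows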
steps = n_samples // self.batch_size + 1
data = []
for step in tqdm.trange(steps):
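            # NB: uniform noise at sampling time, whereas train() draws normal noise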
z = tf.random.uniform([self.batch_size, self.noise_dim])
records = tf.make_ndarray(tf.make_tensor_proto(self.generator(z, training=False)))
data.append(pd.DataFrame(records))
return pd.concat(data)

def save(self, path, name):
assert os.path.isdir(path) == True, \
"Please provide a valid path. Path must be a directory."
80 changes: 42 additions & 38 deletions src/ydata_synthetic/synthesizers/regular/wgan/model.py
@@ -1,6 +1,7 @@
import os
from os import path
import numpy as np
from tqdm import trange

from ydata_synthetic.synthesizers import gan

@@ -28,6 +29,7 @@ class WGAN(gan.Model):

def __init__(self, model_parameters, n_critic):
# As recommended in WGAN paper - https://arxiv.org/abs/1701.07875
# WGAN-GP - WGAN with Gradient Penalty
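        # n_critic: number of critic updates per generator update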
self.n_critic = n_critic
super().__init__(model_parameters)

@@ -80,52 +82,54 @@ def get_data_batch(self, train, batch_size, seed=0):
def train(self, data, train_arguments):
[cache_prefix, epochs, sample_interval] = train_arguments

#Create a summary file
train_summary_writer = tf.summary.create_file_writer(path.join('.', 'summaries', 'train'))

# Adversarial ground truths
valid = np.ones((self.batch_size, 1))
fake = -np.ones((self.batch_size, 1))

        with train_summary_writer.as_default():
            for epoch in trange(epochs, desc='Epoch Iterations'):

                for _ in range(self.n_critic):
                    # ---------------------
                    #  Train the Critic
                    # ---------------------
                    batch_data = self.get_data_batch(data, self.batch_size)
                    noise = tf.random.normal((self.batch_size, self.noise_dim))

                    # Generate a batch of events
                    gen_data = self.generator(noise)

                    # Train the Critic
                    d_loss_real = self.critic.train_on_batch(batch_data, valid)
                    d_loss_fake = self.critic.train_on_batch(gen_data, fake)
                    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                    # Critic weight clipping
                    for l in self.critic.layers:
                        weights = l.get_weights()
                        weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
                        l.set_weights(weights)

                # ---------------------
                #  Train Generator
                # ---------------------
                noise = tf.random.normal((self.batch_size, self.noise_dim))
                # Train the generator (to have the critic label samples as valid)
                g_loss = self.model.train_on_batch(noise, valid)

                # Plot the progress
                print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

                # If at save interval => save model checkpoints
                if epoch % sample_interval == 0:
                    if path.exists('./cache') is False:
                        os.mkdir('./cache')
                    model_checkpoint_base_name = './cache/' + cache_prefix + '_{}_model_weights_step_{}.h5'
                    self.generator.save_weights(model_checkpoint_base_name.format('generator', epoch))
                    self.critic.save_weights(model_checkpoint_base_name.format('critic', epoch))

def load(self, path):
assert os.path.isdir(path) == True, \
Empty file.
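The diff for the new WGAN_GP model file itself did not render. For orientation, here is a minimal sketch of the gradient-penalty term that gives WGAN-GP its name (illustrative only; the function name and penalty coefficient follow the WGAN-GP paper, not this repository's code):

import tensorflow as tf

LAMBDA = 10  # penalty coefficient; the default suggested in the WGAN-GP paper

def gradient_penalty(critic, real_data, fake_data):
    # Sample random interpolation points between real and generated records
    alpha = tf.random.uniform([tf.shape(real_data)[0], 1], 0., 1.)
    interpolated = real_data + alpha * (fake_data - real_data)
    # Measure the critic's gradient at the interpolated points
    with tf.GradientTape() as tape:
        tape.watch(interpolated)
        critic_out = critic(interpolated, training=True)
    grads = tape.gradient(critic_out, interpolated)
    # Penalize deviation of the gradient norm from 1 (soft Lipschitz constraint),
    # replacing the weight clipping used by plain WGAN above
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1))
    return LAMBDA * tf.reduce_mean((norm - 1.0) ** 2)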
