feat: Add new Cramer Loss and Cramer GAN (#102)
fabclmnt committed Oct 28, 2021
1 parent 1c84754 commit a9de1ab
Showing 6 changed files with 323 additions and 6 deletions.
87 changes: 87 additions & 0 deletions examples/regular/cramergan_example.py
@@ -0,0 +1,87 @@
#Install ydata-synthetic lib
# pip install ydata-synthetic
import sklearn.cluster as cluster
import numpy as np
import pandas as pd

from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
from ydata_synthetic.synthesizers.regular import CRAMERGAN
from ydata_synthetic.preprocessing.regular.credit_fraud import transformations

model = CRAMERGAN

#Read the original data and have it preprocessed
data = pd.read_csv('data/creditcard.csv', index_col=[0])

#Data processing and analysis
data_cols = list(data.columns[ data.columns != 'Class' ])
label_cols = ['Class']

print('Dataset columns: {}'.format(data_cols))
sorted_cols = ['V14', 'V4', 'V10', 'V17', 'V12', 'V26', 'Amount', 'V21', 'V8', 'V11', 'V7', 'V28', 'V19', 'V3', 'V22', 'V6', 'V20', 'V27', 'V16', 'V13', 'V25', 'V24', 'V18', 'V2', 'V1', 'V5', 'V15', 'V9', 'V23', 'Class']
processed_data = data[ sorted_cols ].copy()

#Before training the GAN, don't forget to apply the required data transformations
#For simplicity, a PowerTransformation is applied here
data = transformations(data)

#For the purpose of this example we will only synthesize the minority class
train_data = data.loc[ data['Class']==1 ].copy()

print("Dataset info: Number of records - {} Number of variables - {}".format(train_data.shape[0], train_data.shape[1]))

algorithm = cluster.KMeans
args, kwds = (), {'n_clusters':2, 'random_state':0}
labels = algorithm(*args, **kwds).fit_predict(train_data[ data_cols ])

print( pd.DataFrame( [ [np.sum(labels==i)] for i in np.unique(labels) ], columns=['count'], index=np.unique(labels) ) )

fraud_w_classes = train_data.copy()
fraud_w_classes['Class'] = labels

# GAN training
#Define the GAN and training parameters
noise_dim = 32
dim = 128
batch_size = 128

log_step = 100
epochs = 300+1
learning_rate = 5e-4
beta_1 = 0.5
beta_2 = 0.9
models_dir = './cache'

train_sample = fraud_w_classes.copy().reset_index(drop=True)
train_sample = pd.get_dummies(train_sample, columns=['Class'], prefix='Class', drop_first=True)
label_cols = [ i for i in train_sample.columns if 'Class' in i ]
data_cols = [ i for i in train_sample.columns if i not in label_cols ]
train_sample[ data_cols ] = train_sample[ data_cols ] / 10 # scale to random noise size, one less thing to learn
train_no_label = train_sample[ data_cols ]

model_parameters = ModelParameters(batch_size=batch_size,
lr=learning_rate,
betas=(beta_1, beta_2),
noise_dim=noise_dim,
n_cols=train_sample.shape[1],
layers_dim=dim)

train_args = TrainParameters(epochs=epochs,
sample_interval=log_step)

test_size = 492 # number of fraud cases in the original dataset

#Training the CRAMERGAN model
synthesizer = model(model_parameters, gradient_penalty_weight=10)
synthesizer.train(train_sample, train_args)

#Saving the synthesizer to later generate new events
synthesizer.save(path='models/cramergan_creditcard.pkl')

#Loading the synthesizer
synth = model.load(path='models/cramergan_creditcard.pkl')

#Sampling the data
#Note that the returned data is not inverse-transformed back to the original feature space.
data_sample = synth.sample(100000)
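Since `sample` returns records in the transformed space, they usually need to be mapped back before use. Below is a minimal sketch of that inverse step, assuming you fit your own sklearn PowerTransformer on the features instead of relying on the library's `transformations` helper (whose fitted transformer object is not exposed in this example), and assuming `sample` returns a numpy array with the features in the first columns:

#Hedged sketch, not part of this commit
from sklearn.preprocessing import PowerTransformer

power = PowerTransformer(method='yeo-johnson', standardize=True)
transformed = power.fit_transform(processed_data[data_cols]) #fit on the real features

#... train the synthesizer on the transformed data as above ...

#Map the synthetic records back to the original feature space
restored = power.inverse_transform(data_sample[:, :len(data_cols)])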
25 changes: 22 additions & 3 deletions src/ydata_synthetic/synthesizers/loss.py
@@ -2,6 +2,13 @@
from tensorflow import random, reshape, shape, math, GradientTape, reduce_mean
from tensorflow import norm as tfnorm

from enum import Enum

class Mode(Enum):
WGANGP = 'wgangp'
DRAGAN = 'dragan'
CRAMER = 'cramer'

## Original code loss from
## https://github.com/LynnHo/DCGAN-LSGAN-WGAN-GP-DRAGAN-Tensorflow-2/blob/master/tf2gan/loss.py
def gradient_penalty(f, real, fake, mode):
@@ -23,12 +30,24 @@ def _interpolate(a, b=None):
grad = t.gradient(pred, x)
norm = tfnorm(reshape(grad, [shape(grad)[0], -1]), axis=1)
gp = reduce_mean((norm - 1.)**2)

return gp

    def _gradient_penalty_cramer(f_crit, real, fake):
        epsilon = random.uniform([real.shape[0], 1], 0.0, 1.0)
        x_hat = epsilon * real + (1 - epsilon) * fake[0]
        with GradientTape() as t:
            t.watch(x_hat)
            f_x_hat = f_crit(x_hat, fake[1])
        gradients = t.gradient(f_x_hat, x_hat)
        c_dx = tfnorm(reshape(gradients, [shape(gradients)[0], -1]), axis=1)
        c_regularizer = (c_dx - 1.0) ** 2
        return c_regularizer

    if mode == Mode.DRAGAN:
        gp = _gradient_penalty(f, real)
    elif mode == Mode.CRAMER:
        gp = _gradient_penalty_cramer(f, real, fake)
    elif mode == Mode.WGANGP:
        gp = _gradient_penalty(f, real, fake)

    return gp
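For reference, the Cramer branch above implements the standard interpolated gradient penalty. With x the real batch, x_g and x_{g'} the two generated batches passed in through `fake`, and \epsilon \sim U(0,1) drawn per record, the code computes

\hat{x} = \epsilon\,x + (1 - \epsilon)\,x_g, \qquad GP = \big(\lVert \nabla_{\hat{x}} f(\hat{x}, x_{g'}) \rVert_2 - 1\big)^2

per sample; the batch mean is taken by the caller (`c_lossfn` in the CramerGAN model below).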
4 changes: 3 additions & 1 deletion src/ydata_synthetic/synthesizers/regular/__init__.py
@@ -3,11 +3,13 @@
from ydata_synthetic.synthesizers.regular.vanillagan.model import VanilllaGAN
from ydata_synthetic.synthesizers.regular.wgangp.model import WGAN_GP
from ydata_synthetic.synthesizers.regular.dragan.model import DRAGAN
from ydata_synthetic.synthesizers.regular.cramergan.model import CRAMERGAN

__all__ = [
"VanilllaGAN",
"CGAN",
"WGAN",
"WGAN_GP",
"DRAGAN"
"DRAGAN",
"CRAMERGAN"
]
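With this export in place, the new synthesizer can be imported exactly as the example script does:

from ydata_synthetic.synthesizers.regular import CRAMERGAN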
0 changes: 0 additions & 0 deletions src/ydata_synthetic/synthesizers/regular/cramergan/__init__.py
Empty file.
209 changes: 209 additions & 0 deletions src/ydata_synthetic/synthesizers/regular/cramergan/model.py
@@ -0,0 +1,209 @@
import os
from os import path
import numpy as np
from tqdm import trange

from ydata_synthetic.synthesizers.gan import BaseModel
from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty
from ydata_synthetic.synthesizers import TrainParameters

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam

class CRAMERGAN(BaseModel):

__MODEL__='CRAMERGAN'

def __init__(self, model_parameters, gradient_penalty_weight=10):
"""Create a base CramerGAN.
        Based on the CramerGAN paper, "The Cramer Distance as a Solution to Biased Wasserstein Gradients":
        https://arxiv.org/abs/1705.10743"""
self.gradient_penalty_weight = gradient_penalty_weight
super().__init__(model_parameters)

def define_gan(self):
self.generator = Generator(self.batch_size). \
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim)

self.critic = Critic(self.batch_size). \
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)

self.g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2)
self.c_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2)

# The generator takes noise as input and generates records
z = Input(shape=(self.noise_dim,), batch_size=self.batch_size)
fake = self.generator(z, training=True)
logits = self.critic(fake, training=True)

# Compile the critic
self.critic.compile(loss=self.c_lossfn,
optimizer=self.c_optimizer,
metrics=['accuracy'])

# Generator and critic model
_model = Model(z, logits)
_model.compile(loss=self.g_lossfn, optimizer=self.g_optimizer)

def gradient_penalty(self, real, fake):
gp = gradient_penalty(self.f_crit, real, fake, mode=Mode.CRAMER)
return gp

    def update_gradients(self, x):
        """Compute and apply the gradients for both the Generator and the Critic.
        :param x: real data batch
        :return: critic loss, generator loss
        """
        # Sample two independent noise batches; the Cramer losses compare two generated samples
        noise = tf.random.normal([x.shape[0], self.noise_dim], dtype=tf.dtypes.float32)
        noise2 = tf.random.normal([x.shape[0], self.noise_dim], dtype=tf.dtypes.float32)

        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            fake = self.generator(noise, training=True)
            fake2 = self.generator(noise2, training=True)

            g_loss = self.g_lossfn(x, fake, fake2)

            c_loss = self.c_lossfn(x, fake, fake2)

# Get the gradients of the generator
g_gradients = g_tape.gradient(g_loss, self.generator.trainable_variables)

# Update the weights of the generator
self.g_optimizer.apply_gradients(
zip(g_gradients, self.generator.trainable_variables)
)

c_gradient = d_tape.gradient(c_loss, self.critic.trainable_variables)
# Update the weights of the critic using the optimizer
self.c_optimizer.apply_gradients(
zip(c_gradient, self.critic.trainable_variables)
)

return c_loss, g_loss

def g_lossfn(self, real, fake, fake2):
"""Compute generator loss function according to the CramerGAN paper.
:param real: A real sample
:param fake: A fake sample
        :param fake2: A second fake sample
:return: Loss of the generator
"""
g_loss = tf.norm(self.critic(real, training=True) - self.critic(fake, training=True), axis=1) + \
tf.norm(self.critic(real, training=True) - self.critic(fake2, training=True), axis=1) - \
tf.norm(self.critic(fake, training=True) - self.critic(fake2, training=True), axis=1)
return tf.reduce_mean(g_loss)

def f_crit(self, real, fake):
"""
Computes the critic distance function f between two samples
:param real: A real sample
:param fake: A fake sample
        :return: Critic distance between the two samples
"""
return tf.norm(self.critic(real, training=True) - self.critic(fake, training=True), axis=1) - tf.norm(self.critic(real, training=True), axis=1)

def c_lossfn(self, real, fake, fake2):
"""
:param real: A real sample
:param fake: A fake sample
:param fak2: A second fake sample
:return: Loss of the critic
"""
f_real = self.f_crit(real, fake2)
f_fake = self.f_crit(fake, fake2)
loss_surrogate = f_real - f_fake
gp = self.gradient_penalty(real, [fake, fake2])
return tf.reduce_mean(- loss_surrogate + self.gradient_penalty_weight*gp)

@staticmethod
def get_data_batch(train, batch_size, seed=0):
# np.random.seed(seed)
# x = train.loc[ np.random.choice(train.index, batch_size) ].values
# iterate through shuffled indices, so every sample gets covered evenly
start_i = (batch_size * seed) % len(train)
stop_i = start_i + batch_size
shuffle_seed = (batch_size * seed) // len(train)
np.random.seed(shuffle_seed)
train_ix = np.random.choice(list(train.index), replace=False, size=len(train)) # wasteful to shuffle every time
train_ix = list(train_ix) + list(train_ix) # duplicate to cover ranges past the end of the set
x = train.loc[train_ix[start_i: stop_i]].values
return np.reshape(x, (batch_size, -1))

def train_step(self, train_data):
critic_loss, g_loss = self.update_gradients(train_data)
return critic_loss, g_loss

def train(self, data, train_arguments: TrainParameters):
iterations = int(abs(data.shape[0] / self.batch_size) + 1)

# Create a summary file
        train_summary_writer = tf.summary.create_file_writer(path.join('..', 'cramergan_test', 'summaries', 'train'))

with train_summary_writer.as_default():
for epoch in trange(train_arguments.epochs):
for iteration in range(iterations):
batch_data = self.get_data_batch(data, self.batch_size)
c_loss, g_loss = self.train_step(batch_data)

if iteration % train_arguments.sample_interval == 0:
# Test here data generation step
# save model checkpoints
                        if not path.exists('./cache'):
                            os.mkdir('./cache')
model_checkpoint_base_name = './cache/' + train_arguments.cache_prefix + '_{}_model_weights_step_{}.h5'
self.generator.save_weights(model_checkpoint_base_name.format('generator', iteration))
self.critic.save_weights(model_checkpoint_base_name.format('critic', iteration))

print(
"Epoch: {} | critic_loss: {} | gen_loss: {}".format(
epoch, c_loss, g_loss
))

        # Keep only the optimizer configs, so the trained model can be pickled
        self.g_optimizer = self.g_optimizer.get_config()
        self.c_optimizer = self.c_optimizer.get_config()

def save(self, path):
"""Strip down the optimizers from the model then save."""
for attr in ['g_optimizer', 'c_optimizer']:
try:
delattr(self, attr)
except AttributeError:
continue
super().save(path)


class Generator(tf.keras.Model):
    def __init__(self, batch_size):
        """Simple generator with dense feedforward layers."""
        super().__init__()
        self.batch_size = batch_size

def build_model(self, input_shape, dim, data_dim):
input_ = Input(shape=input_shape, batch_size=self.batch_size)
x = Dense(dim, activation='relu')(input_)
x = Dense(dim * 2, activation='relu')(x)
x = Dense(dim * 4, activation='relu')(x)
x = Dense(data_dim)(x)
return Model(inputs=input_, outputs=x)

class Critic(tf.keras.Model):
    def __init__(self, batch_size):
        """Simple critic with dense feedforward and dropout layers."""
        super().__init__()
        self.batch_size = batch_size

def build_model(self, input_shape, dim):
input_ = Input(shape=input_shape, batch_size=self.batch_size)
x = Dense(dim * 4, activation='relu')(input_)
x = Dropout(0.1)(x)
x = Dense(dim * 2, activation='relu')(x)
x = Dropout(0.1)(x)
x = Dense(dim, activation='relu')(x)
x = Dense(1)(x)
return Model(inputs=input_, outputs=x)
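To summarize the objectives implemented by `g_lossfn`, `f_crit` and `c_lossfn` above: writing h for the critic network, x for a real batch, and x_g, x_{g'} for the two independently generated batches, the code computes

L_G = \mathbb{E}\big[\lVert h(x) - h(x_g)\rVert_2 + \lVert h(x) - h(x_{g'})\rVert_2 - \lVert h(x_g) - h(x_{g'})\rVert_2\big]

f(x) = \lVert h(x) - h(x_{g'})\rVert_2 - \lVert h(x)\rVert_2

L_C = \mathbb{E}\big[f(x_g) - f(x)\big] + \lambda\,GP

where \lambda is the `gradient_penalty_weight` passed to the constructor. This is the surrogate of the energy distance proposed in the CramerGAN paper.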
4 changes: 2 additions & 2 deletions src/ydata_synthetic/synthesizers/regular/dragan/model.py
@@ -9,7 +9,7 @@
from tensorflow.keras import Model, initializers

from ydata_synthetic.synthesizers.gan import BaseModel
from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty

class DRAGAN(BaseModel):

@@ -33,7 +33,7 @@ def define_gan(self):
self.d_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2, clipvalue=0.001)

def gradient_penalty(self, real, fake):
        gp = gradient_penalty(self.discriminator, real, fake, mode=Mode.DRAGAN)
return gp

def update_gradients(self, x):
