-
Notifications
You must be signed in to change notification settings - Fork 232
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add new Cramer Loss and Cramer GAN (#102)
- Loading branch information
Showing
6 changed files
with
323 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
#Install ydata-synthetic lib | ||
# pip install ydata-synthetic | ||
import sklearn.cluster as cluster | ||
import numpy as np | ||
import pandas as pd | ||
|
||
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters | ||
from ydata_synthetic.synthesizers.regular import CRAMERGAN | ||
from ydata_synthetic.preprocessing.regular.credit_fraud import transformations | ||
|
||
model = CRAMERGAN | ||
|
||
#Read the original data and have it preprocessed | ||
data = pd.read_csv('data/creditcard.csv', index_col=[0]) | ||
|
||
#Data processing and analysis | ||
data_cols = list(data.columns[ data.columns != 'Class' ]) | ||
label_cols = ['Class'] | ||
|
||
print('Dataset columns: {}'.format(data_cols)) | ||
sorted_cols = ['V14', 'V4', 'V10', 'V17', 'V12', 'V26', 'Amount', 'V21', 'V8', 'V11', 'V7', 'V28', 'V19', 'V3', 'V22', 'V6', 'V20', 'V27', 'V16', 'V13', 'V25', 'V24', 'V18', 'V2', 'V1', 'V5', 'V15', 'V9', 'V23', 'Class'] | ||
processed_data = data[ sorted_cols ].copy() | ||
|
||
#Before training the GAN do not forget to apply the required data transformations | ||
#To ease here we've applied a PowerTransformation | ||
data = transformations(data) | ||
|
||
#For the purpose of this example we will only synthesize the minority class | ||
train_data = data.loc[ data['Class']==1 ].copy() | ||
|
||
print("Dataset info: Number of records - {} Number of variables - {}".format(train_data.shape[0], train_data.shape[1])) | ||
|
||
algorithm = cluster.KMeans | ||
args, kwds = (), {'n_clusters':2, 'random_state':0} | ||
labels = algorithm(*args, **kwds).fit_predict(train_data[ data_cols ]) | ||
|
||
print( pd.DataFrame( [ [np.sum(labels==i)] for i in np.unique(labels) ], columns=['count'], index=np.unique(labels) ) ) | ||
|
||
fraud_w_classes = train_data.copy() | ||
fraud_w_classes['Class'] = labels | ||
|
||
# GAN training | ||
#Define the GAN and training parameters | ||
noise_dim = 32 | ||
dim = 128 | ||
batch_size = 128 | ||
|
||
log_step = 100 | ||
epochs = 300+1 | ||
learning_rate = 5e-4 | ||
beta_1 = 0.5 | ||
beta_2 = 0.9 | ||
models_dir = './cache' | ||
|
||
train_sample = fraud_w_classes.copy().reset_index(drop=True) | ||
train_sample = pd.get_dummies(train_sample, columns=['Class'], prefix='Class', drop_first=True) | ||
label_cols = [ i for i in train_sample.columns if 'Class' in i ] | ||
data_cols = [ i for i in train_sample.columns if i not in label_cols ] | ||
train_sample[ data_cols ] = train_sample[ data_cols ] / 10 # scale to random noise size, one less thing to learn | ||
train_no_label = train_sample[ data_cols ] | ||
|
||
model_parameters = ModelParameters(batch_size=batch_size, | ||
lr=learning_rate, | ||
betas=(beta_1, beta_2), | ||
noise_dim=noise_dim, | ||
n_cols=train_sample.shape[1], | ||
layers_dim=dim) | ||
|
||
train_args = TrainParameters(epochs=epochs, | ||
sample_interval=log_step) | ||
|
||
test_size = 492 # number of fraud cases | ||
noise_dim = 32 | ||
|
||
#Training the CRAMERGAN model | ||
synthesizer = model(model_parameters, gradient_penalty_weight=10) | ||
synthesizer.train(train_sample, train_args) | ||
|
||
#Saving the synthesizer to later generate new events | ||
synthesizer.save(path='models/cramergan_creditcard.pkl') | ||
|
||
#Loading the synthesizer | ||
synth = model.load(path='models/cramergan_creditcard.pkl') | ||
|
||
#Sampling the data | ||
#Note that the data returned it is not inverse processed. | ||
data_sample = synth.sample(100000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
209 changes: 209 additions & 0 deletions
209
src/ydata_synthetic/synthesizers/regular/cramergan/model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
import os | ||
from os import path | ||
import numpy as np | ||
from tqdm import trange | ||
|
||
from ydata_synthetic.synthesizers.gan import BaseModel | ||
from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty | ||
from ydata_synthetic.synthesizers import TrainParameters | ||
|
||
import tensorflow as tf | ||
from tensorflow.keras.layers import Input, Dense, Dropout | ||
from tensorflow.keras import Model | ||
from tensorflow.keras.optimizers import Adam | ||
|
||
class CRAMERGAN(BaseModel): | ||
|
||
__MODEL__='CRAMERGAN' | ||
|
||
def __init__(self, model_parameters, gradient_penalty_weight=10): | ||
"""Create a base CramerGAN. | ||
Based according to the WGAN paper - https://arxiv.org/pdf/1705.10743.pdf | ||
CramerGAN, a solution to biased Wassertein Gradients https://arxiv.org/abs/1705.10743""" | ||
self.gradient_penalty_weight = gradient_penalty_weight | ||
super().__init__(model_parameters) | ||
|
||
def define_gan(self): | ||
self.generator = Generator(self.batch_size). \ | ||
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim) | ||
|
||
self.critic = Critic(self.batch_size). \ | ||
build_model(input_shape=(self.data_dim,), dim=self.layers_dim) | ||
|
||
self.g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2) | ||
self.c_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2) | ||
|
||
# The generator takes noise as input and generates records | ||
z = Input(shape=(self.noise_dim,), batch_size=self.batch_size) | ||
fake = self.generator(z, training=True) | ||
logits = self.critic(fake, training=True) | ||
|
||
# Compile the critic | ||
self.critic.compile(loss=self.c_lossfn, | ||
optimizer=self.c_optimizer, | ||
metrics=['accuracy']) | ||
|
||
# Generator and critic model | ||
_model = Model(z, logits) | ||
_model.compile(loss=self.g_lossfn, optimizer=self.g_optimizer) | ||
|
||
def gradient_penalty(self, real, fake): | ||
gp = gradient_penalty(self.f_crit, real, fake, mode=Mode.CRAMER) | ||
return gp | ||
|
||
def update_gradients(self, x): | ||
"""Compute and apply the gradients for both the Generator and the Critic. | ||
:param x: real data event | ||
:return: generator gradients, critic gradients | ||
""" | ||
# Update the gradients of critic for n_critic times (Training the critic) | ||
|
||
##New generator gradient_tape | ||
noise= tf.random.normal([x.shape[0], self.noise_dim], dtype=tf.dtypes.float32) | ||
noise2= tf.random.normal([x.shape[0], self.noise_dim], dtype=tf.dtypes.float32) | ||
|
||
with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape: | ||
fake=self.generator(noise, training=True) | ||
fake2=self.generator(noise2, training=True) | ||
|
||
g_loss = self.g_lossfn(x, fake, fake2) | ||
|
||
c_loss = self.c_lossfn(x, fake, fake2) | ||
|
||
# Get the gradients of the generator | ||
g_gradients = g_tape.gradient(g_loss, self.generator.trainable_variables) | ||
|
||
# Update the weights of the generator | ||
self.g_optimizer.apply_gradients( | ||
zip(g_gradients, self.generator.trainable_variables) | ||
) | ||
|
||
c_gradient = d_tape.gradient(c_loss, self.critic.trainable_variables) | ||
# Update the weights of the critic using the optimizer | ||
self.c_optimizer.apply_gradients( | ||
zip(c_gradient, self.critic.trainable_variables) | ||
) | ||
|
||
return c_loss, g_loss | ||
|
||
def g_lossfn(self, real, fake, fake2): | ||
"""Compute generator loss function according to the CramerGAN paper. | ||
:param real: A real sample | ||
:param fake: A fake sample | ||
:param fak2: A second fake sample | ||
:return: Loss of the generator | ||
""" | ||
g_loss = tf.norm(self.critic(real, training=True) - self.critic(fake, training=True), axis=1) + \ | ||
tf.norm(self.critic(real, training=True) - self.critic(fake2, training=True), axis=1) - \ | ||
tf.norm(self.critic(fake, training=True) - self.critic(fake2, training=True), axis=1) | ||
return tf.reduce_mean(g_loss) | ||
|
||
def f_crit(self, real, fake): | ||
""" | ||
Computes the critic distance function f between two samples | ||
:param real: A real sample | ||
:param fake: A fake sample | ||
:return: Loss of the critic | ||
""" | ||
return tf.norm(self.critic(real, training=True) - self.critic(fake, training=True), axis=1) - tf.norm(self.critic(real, training=True), axis=1) | ||
|
||
def c_lossfn(self, real, fake, fake2): | ||
""" | ||
:param real: A real sample | ||
:param fake: A fake sample | ||
:param fak2: A second fake sample | ||
:return: Loss of the critic | ||
""" | ||
f_real = self.f_crit(real, fake2) | ||
f_fake = self.f_crit(fake, fake2) | ||
loss_surrogate = f_real - f_fake | ||
gp = self.gradient_penalty(real, [fake, fake2]) | ||
return tf.reduce_mean(- loss_surrogate + self.gradient_penalty_weight*gp) | ||
|
||
@staticmethod | ||
def get_data_batch(train, batch_size, seed=0): | ||
# np.random.seed(seed) | ||
# x = train.loc[ np.random.choice(train.index, batch_size) ].values | ||
# iterate through shuffled indices, so every sample gets covered evenly | ||
start_i = (batch_size * seed) % len(train) | ||
stop_i = start_i + batch_size | ||
shuffle_seed = (batch_size * seed) // len(train) | ||
np.random.seed(shuffle_seed) | ||
train_ix = np.random.choice(list(train.index), replace=False, size=len(train)) # wasteful to shuffle every time | ||
train_ix = list(train_ix) + list(train_ix) # duplicate to cover ranges past the end of the set | ||
x = train.loc[train_ix[start_i: stop_i]].values | ||
return np.reshape(x, (batch_size, -1)) | ||
|
||
def train_step(self, train_data): | ||
critic_loss, g_loss = self.update_gradients(train_data) | ||
return critic_loss, g_loss | ||
|
||
def train(self, data, train_arguments: TrainParameters): | ||
iterations = int(abs(data.shape[0] / self.batch_size) + 1) | ||
|
||
# Create a summary file | ||
train_summary_writer = tf.summary.create_file_writer(path.join('..\cramergan_test', 'summaries', 'train')) | ||
|
||
with train_summary_writer.as_default(): | ||
for epoch in trange(train_arguments.epochs): | ||
for iteration in range(iterations): | ||
batch_data = self.get_data_batch(data, self.batch_size) | ||
c_loss, g_loss = self.train_step(batch_data) | ||
|
||
if iteration % train_arguments.sample_interval == 0: | ||
# Test here data generation step | ||
# save model checkpoints | ||
if path.exists('./cache') is False: | ||
os.mkdir('./cache') | ||
model_checkpoint_base_name = './cache/' + train_arguments.cache_prefix + '_{}_model_weights_step_{}.h5' | ||
self.generator.save_weights(model_checkpoint_base_name.format('generator', iteration)) | ||
self.critic.save_weights(model_checkpoint_base_name.format('critic', iteration)) | ||
|
||
print( | ||
"Epoch: {} | critic_loss: {} | gen_loss: {}".format( | ||
epoch, c_loss, g_loss | ||
)) | ||
|
||
self.g_optimizer=self.g_optimizer.get_config() | ||
self.critic_optimizer=self.c_optimizer.get_config() | ||
|
||
def save(self, path): | ||
"""Strip down the optimizers from the model then save.""" | ||
for attr in ['g_optimizer', 'c_optimizer']: | ||
try: | ||
delattr(self, attr) | ||
except AttributeError: | ||
continue | ||
super().save(path) | ||
|
||
|
||
class Generator(tf.keras.Model): | ||
def __init__(self, batch_size): | ||
"""Simple generator with dense feedforward layers.""" | ||
self.batch_size = batch_size | ||
|
||
def build_model(self, input_shape, dim, data_dim): | ||
input_ = Input(shape=input_shape, batch_size=self.batch_size) | ||
x = Dense(dim, activation='relu')(input_) | ||
x = Dense(dim * 2, activation='relu')(x) | ||
x = Dense(dim * 4, activation='relu')(x) | ||
x = Dense(data_dim)(x) | ||
return Model(inputs=input_, outputs=x) | ||
|
||
class Critic(tf.keras.Model): | ||
def __init__(self, batch_size): | ||
"""Simple critic with dense feedforward and dropout layers.""" | ||
self.batch_size = batch_size | ||
|
||
def build_model(self, input_shape, dim): | ||
input_ = Input(shape=input_shape, batch_size=self.batch_size) | ||
x = Dense(dim * 4, activation='relu')(input_) | ||
x = Dropout(0.1)(x) | ||
x = Dense(dim * 2, activation='relu')(x) | ||
x = Dropout(0.1)(x) | ||
x = Dense(dim, activation='relu')(x) | ||
x = Dense(1)(x) | ||
return Model(inputs=input_, outputs=x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters