-
Notifications
You must be signed in to change notification settings - Fork 228
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix(gan): VanillaGAN structure correction
- Loading branch information
Showing
10 changed files
with
858 additions
and
5,882 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from models.cgan.model import CGAN
# NOTE(review): "VanilllaGAN" is spelled with a triple "l" here and in
# __all__ below — confirm it matches the class name actually defined in
# models/vanillagan/model.py before "fixing" the spelling.
from models.vanillagan.model import VanilllaGAN

# Public API of the models package.
__all__ = [
    "CGAN",
    "VanilllaGAN"
]
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from abc import ABC | ||
from abc import abstractmethod | ||
|
||
import os | ||
from easydict import EasyDict as edict | ||
from tensorflow.python import keras | ||
|
||
class Model(ABC):
    """Abstract base class for the GAN models in this package.

    Concrete subclasses must implement ``define_gan`` (build the
    generator/discriminator networks and the combined model) and
    ``train`` (the training loop). ``define_gan`` is invoked from
    ``__init__``, so subclasses must set any extra attributes it needs
    *before* calling ``super().__init__``.
    """

    def __init__(
        self,
        model_parameters: edict = None,
    ):
        # model_parameters is unpacked positionally in this fixed order:
        # [batch_size, lr, noise_dim, data_dim, layers_dim].
        self._model_parameters = model_parameters
        [self.batch_size, self.lr, self.noise_dim,
         self.data_dim, self.layers_dim] = model_parameters
        self.define_gan()

    def __call__(self, inputs, **kwargs):
        """Delegate calls straight to the underlying Keras model."""
        return self.model(inputs=inputs, **kwargs)

    @abstractmethod
    def define_gan(self) -> keras.Model:
        """Build the networks; must be provided by subclasses."""
        raise NotImplementedError

    def trainable_variables(self, network):
        """Return the trainable variables of *network*.

        NOTE: the original declared this as a ``@property`` while also
        taking a ``network`` argument, which makes it uncallable
        (property getters receive only ``self``). Exposed as a plain
        method so it can actually be used.
        """
        return network.trainable_variables

    @property
    def model(self):
        # The combined (stacked) Keras model set up by define_gan().
        return self._model

    @property
    def model_parameters(self) -> edict:
        return self._model_parameters

    @property
    def model_name(self) -> str:
        return self.__class__.__name__

    @abstractmethod
    def train(self, data, train_arguments):
        """Run the training loop; must be provided by subclasses."""
        raise NotImplementedError

    def save(self, path, name):
        """Save the generator weights to ``path/name``.

        Raises AssertionError if *path* is not an existing directory.
        """
        assert os.path.isdir(path), \
            "Please provide a valid path. Path must be a directory."
        model_path = os.path.join(path, name)
        self.generator.save_weights(model_path)  # persist generator weights
        return
This file was deleted.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
import os | ||
from os import path | ||
import numpy as np | ||
from functools import partial | ||
|
||
from models import gan | ||
|
||
import tensorflow as tf | ||
from tensorflow.keras.layers import Input, Dense, Dropout | ||
import tensorflow.keras.backend as K | ||
from tensorflow.keras import Model | ||
from tensorflow.keras.optimizers import Adam | ||
|
||
#Auxiliary Keras backend class to calculate the Random Weighted average | ||
#https://stackoverflow.com/questions/58133430/how-to-substitute-keras-layers-merge-merge-in-tensorflow-keras | ||
#Auxiliary Keras backend class to calculate the Random Weighted average
#https://stackoverflow.com/questions/58133430/how-to-substitute-keras-layers-merge-merge-in-tensorflow-keras
class RandomWeightedAverage(tf.keras.layers.Layer):
    """Layer returning a per-sample random convex combination of two inputs.

    Used for the gradient-penalty interpolation in WGAN-GP style training.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def call(self, inputs, **kwargs):
        """Blend ``inputs[0]`` and ``inputs[1]`` with a uniform random alpha.

        ``tf.random_uniform`` is TF1-only API; requirements pin
        tensorflow==2.1.0, where the call is ``tf.random.uniform``.
        """
        alpha = tf.random.uniform((self.batch_size, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

    def compute_output_shape(self, input_shape):
        # Output shape matches either input (they must broadcast together).
        return input_shape[0]
|
||
class WGAN(gan.Model):
    """Wasserstein GAN for tabular data (critic + generator, no GP term).

    See https://arxiv.org/abs/1701.07875 for the training scheme: the
    critic is updated ``n_critic`` times per generator update, using the
    Wasserstein loss with labels -1 (real) / +1 (fake).
    """

    def __init__(self, model_parameters, n_critic):
        # As recommended in WGAN paper - https://arxiv.org/abs/1701.07875
        self.n_critic = n_critic
        super().__init__(model_parameters)

    def wasserstein_loss(self, y_true, y_pred):
        """Wasserstein loss: mean of label * critic score."""
        return K.mean(y_true * y_pred)

    def define_gan(self):
        """Build generator, critic and the combined (stacked) model."""
        self.generator = Generator(self.batch_size). \
            build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim)

        self.critic = Critic(self.batch_size). \
            build_model(input_shape=(self.data_dim,), dim=self.layers_dim)

        optimizer = Adam(self.lr, beta_1=0.5, beta_2=0.9)
        self.critic_optimizer = Adam(self.lr, beta_1=0.5, beta_2=0.9)

        # Build and compile the critic
        self.critic.compile(loss=self.wasserstein_loss,
                            optimizer=self.critic_optimizer,
                            metrics=['accuracy'])

        # The generator takes noise as input and generates records
        z = Input(shape=(self.noise_dim,))
        record = self.generator(z)
        # The critic takes generated records as input and scores validity
        validity = self.critic(record)

        # For the combined model we will only train the generator
        self.critic.trainable = False

        # The combined model (stacked generator and critic) trains the
        # generator to fool the critic; uses the Wasserstein loss.
        self._model = Model(z, validity)
        self._model.compile(loss=self.wasserstein_loss, optimizer=optimizer)

    def get_data_batch(self, train, batch_size, seed=0):
        """Return one batch from *train* (a DataFrame), cycling evenly.

        Iterates through shuffled indices so every sample gets covered
        evenly across successive ``seed`` values; reshuffles (with a
        seed derived from how many passes have completed) once a full
        pass is exhausted.
        """
        start_i = (batch_size * seed) % len(train)
        stop_i = start_i + batch_size
        shuffle_seed = (batch_size * seed) // len(train)
        np.random.seed(shuffle_seed)
        # Wasteful to shuffle every call, but keeps batches reproducible.
        train_ix = np.random.choice(list(train.index), replace=False, size=len(train))
        # Duplicate to cover ranges past the end of the set.
        train_ix = list(train_ix) + list(train_ix)
        x = train.loc[train_ix[start_i: stop_i]].values
        return np.reshape(x, (batch_size, -1))

    def train(self, data, train_arguments):
        """Run the WGAN training loop.

        train_arguments is unpacked positionally:
        [cache_prefix, epochs, sample_interval].
        Checkpoints generator/critic weights under ./cache every
        ``sample_interval`` epochs.
        """
        [cache_prefix, epochs, sample_interval] = train_arguments

        # Adversarial ground truths (Wasserstein convention: -1 real, +1 fake)
        valid = -np.ones((self.batch_size, 1))
        fake = np.ones((self.batch_size, 1))

        for epoch in range(epochs):

            for _ in range(self.n_critic):
                # ---------------------
                #  Train the Critic
                # ---------------------
                batch_data = self.get_data_batch(data, self.batch_size)
                noise = tf.random.normal((self.batch_size, self.noise_dim))

                # Generate a batch of events
                gen_data = self.generator(noise)

                # Train the Critic
                d_loss_real = self.critic.train_on_batch(batch_data, valid)
                d_loss_fake = self.critic.train_on_batch(gen_data, fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------
            noise = tf.random.normal((self.batch_size, self.noise_dim))
            # Train the generator (to have the critic label samples as valid)
            g_loss = self.model.train_on_batch(noise, valid)

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            # If at save interval => save generated events
            if epoch % sample_interval == 0:
                # save model checkpoints
                if path.exists('./cache') is False:
                    os.mkdir('./cache')
                model_checkpoint_base_name = './cache/' + cache_prefix + '_{}_model_weights_step_{}.h5'
                self.generator.save_weights(model_checkpoint_base_name.format('generator', epoch))
                self.critic.save_weights(model_checkpoint_base_name.format('critic', epoch))

    def load(self, path):
        """Load generator weights from *path* and return the generator.

        Raises AssertionError if *path* is not an existing directory.
        """
        assert os.path.isdir(path), \
            "Please provide a valid path. Path must be a directory."
        self.generator = Generator(self.batch_size)
        # load_weights returns a status object, not the model — the
        # original rebound self.generator to its return value, leaving
        # the generator unusable afterwards.
        # NOTE(review): this Generator has no layers built yet; weights
        # likely cannot be restored before build_model is called — confirm.
        self.generator.load_weights(path)
        return self.generator
|
||
|
||
class Generator(tf.keras.Model):
    """Fully-connected generator: noise vector -> synthetic data record."""

    def __init__(self, batch_size):
        # Initialize the Keras Model machinery; the original skipped this,
        # which breaks tf.keras.Model subclasses (attribute tracking is
        # set up in super().__init__, so setting attributes before/without
        # it fails under TF2).
        super().__init__()
        self.batch_size = batch_size

    def build_model(self, input_shape, dim, data_dim):
        """Build and return the functional generator model.

        Three widening ReLU layers (dim, 2*dim, 4*dim) followed by a
        linear projection to data_dim.
        """
        noise = Input(shape=input_shape, batch_size=self.batch_size)  # avoid shadowing builtin 'input'
        x = Dense(dim, activation='relu')(noise)
        x = Dense(dim * 2, activation='relu')(x)
        x = Dense(dim * 4, activation='relu')(x)
        x = Dense(data_dim)(x)
        return Model(inputs=noise, outputs=x)
|
||
class Critic(tf.keras.Model):
    """Fully-connected critic: data record -> unbounded realness score."""

    def __init__(self, batch_size):
        # Initialize the Keras Model machinery; the original skipped this,
        # which breaks tf.keras.Model subclasses (attribute tracking is
        # set up in super().__init__, so setting attributes before/without
        # it fails under TF2).
        super().__init__()
        self.batch_size = batch_size

    def build_model(self, input_shape, dim):
        """Build and return the functional critic model.

        Narrowing ReLU layers (4*dim, 2*dim, dim) with light dropout,
        ending in a single linear output (no sigmoid — Wasserstein loss).
        """
        record = Input(shape=input_shape, batch_size=self.batch_size)  # avoid shadowing builtin 'input'
        x = Dense(dim * 4, activation='relu')(record)
        x = Dropout(0.1)(x)
        x = Dense(dim * 2, activation='relu')(x)
        x = Dropout(0.1)(x)
        x = Dense(dim, activation='relu')(x)
        x = Dense(1)(x)
        return Model(inputs=record, outputs=x)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,5 +2,6 @@ pandas==1.0.3 | |
numpy==1.17.4 | ||
scikit-learn==0.22.2 | ||
matplotlib | ||
easydict | ||
|
||
tensorflow==2.1.0 |