-
Notifications
You must be signed in to change notification settings - Fork 227
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix(VanillaGAN): Add VanillaGAN again to the repo.
- Loading branch information
Showing
3 changed files
with
149 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
from models.cgan.model import CGAN | ||
from models.wgan.model import WGAN | ||
from models.vanillagan.model import VanilllaGAN | ||
|
||
__all__ = [ | ||
"VanilllaGAN", | ||
"CGAN", | ||
"VanilllaGAN" | ||
"WGAN" | ||
] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
import os | ||
from os import path | ||
import numpy as np | ||
|
||
from models import gan | ||
|
||
import tensorflow as tf | ||
from tensorflow.keras.layers import Input, Dense, Dropout | ||
from tensorflow.keras import Model | ||
from tensorflow.keras.optimizers import Adam | ||
|
||
class VanilllaGAN(gan.Model): | ||
|
||
def __init__(self, model_parameters): | ||
super().__init__(model_parameters) | ||
|
||
def define_gan(self): | ||
self.generator = Generator(self.batch_size).\ | ||
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim) | ||
|
||
self.discriminator = Discriminator(self.batch_size).\ | ||
build_model(input_shape=(self.data_dim,), dim=self.layers_dim) | ||
|
||
optimizer = Adam(self.lr, 0.5) | ||
|
||
# Build and compile the discriminator | ||
self.discriminator.compile(loss='binary_crossentropy', | ||
optimizer=optimizer, | ||
metrics=['accuracy']) | ||
|
||
# The generator takes noise as input and generates imgs | ||
z = Input(shape=(self.noise_dim,)) | ||
record = self.generator(z) | ||
|
||
# For the combined model we will only train the generator | ||
self.discriminator.trainable = False | ||
|
||
# The discriminator takes generated images as input and determines validity | ||
validity = self.discriminator(record) | ||
|
||
# The combined model (stacked generator and discriminator) | ||
# Trains the generator to fool the discriminator | ||
self._model = Model(z, validity) | ||
self._model.compile(loss='binary_crossentropy', optimizer=optimizer) | ||
|
||
def get_data_batch(self, train, batch_size, seed=0): | ||
# # random sampling - some samples will have excessively low or high sampling, but easy to implement | ||
# np.random.seed(seed) | ||
# x = train.loc[ np.random.choice(train.index, batch_size) ].values | ||
# iterate through shuffled indices, so every sample gets covered evenly | ||
|
||
start_i = (batch_size * seed) % len(train) | ||
stop_i = start_i + batch_size | ||
shuffle_seed = (batch_size * seed) // len(train) | ||
np.random.seed(shuffle_seed) | ||
train_ix = np.random.choice(list(train.index), replace=False, size=len(train)) # wasteful to shuffle every time | ||
train_ix = list(train_ix) + list(train_ix) # duplicate to cover ranges past the end of the set | ||
x = train.loc[train_ix[start_i: stop_i]].values | ||
return np.reshape(x, (batch_size, -1)) | ||
|
||
def train(self, data, train_arguments): | ||
[cache_prefix, epochs, sample_interval] = train_arguments | ||
|
||
# Adversarial ground truths | ||
valid = np.ones((self.batch_size, 1)) | ||
fake = np.zeros((self.batch_size, 1)) | ||
|
||
for epoch in range(epochs): | ||
# --------------------- | ||
# Train Discriminator | ||
# --------------------- | ||
batch_data = self.get_data_batch(data, self.batch_size) | ||
noise = tf.random.normal((self.batch_size, self.noise_dim)) | ||
|
||
# Generate a batch of events | ||
gen_data = self.generator(noise, training=True) | ||
|
||
# Train the discriminator | ||
d_loss_real = self.discriminator.train_on_batch(batch_data, valid) | ||
d_loss_fake = self.discriminator.train_on_batch(gen_data, fake) | ||
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) | ||
|
||
# --------------------- | ||
# Train Generator | ||
# --------------------- | ||
noise = tf.random.normal((self.batch_size, self.noise_dim)) | ||
# Train the generator (to have the discriminator label samples as valid) | ||
g_loss = self.model.train_on_batch(noise, valid) | ||
|
||
# Plot the progress | ||
print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss)) | ||
|
||
# If at save interval => save generated events | ||
if epoch % sample_interval == 0: | ||
#Test here data generation step | ||
# save model checkpoints | ||
if path.exists('./cache') is False: | ||
os.mkdir('./cache') | ||
model_checkpoint_base_name = './cache/' + cache_prefix + '_{}_model_weights_step_{}.h5' | ||
self.generator.save_weights(model_checkpoint_base_name.format('generator', epoch)) | ||
self.discriminator.save_weights(model_checkpoint_base_name.format('discriminator', epoch)) | ||
|
||
#Here is generating the data | ||
z = tf.random.normal((432, self.noise_dim)) | ||
gen_data = self.generator(z) | ||
print('generated_data') | ||
|
||
def save(self, path, name): | ||
assert os.path.isdir(path) == True, \ | ||
"Please provide a valid path. Path must be a directory." | ||
model_path = os.path.join(path, name) | ||
self.generator.save_weights(model_path) # Load the generator | ||
return | ||
|
||
def load(self, path): | ||
assert os.path.isdir(path) == True, \ | ||
"Please provide a valid path. Path must be a directory." | ||
self.generator = Generator(self.batch_size) | ||
self.generator = self.generator.load_weights(path) | ||
return self.generator | ||
|
||
class Generator(tf.keras.Model): | ||
def __init__(self, batch_size): | ||
self.batch_size=batch_size | ||
|
||
def build_model(self, input_shape, dim, data_dim): | ||
input= Input(shape=input_shape, batch_size=self.batch_size) | ||
x = Dense(dim, activation='relu')(input) | ||
x = Dense(dim * 2, activation='relu')(x) | ||
x = Dense(dim * 4, activation='relu')(x) | ||
x = Dense(data_dim)(x) | ||
return Model(inputs=input, outputs=x) | ||
|
||
class Discriminator(tf.keras.Model): | ||
def __init__(self,batch_size): | ||
self.batch_size=batch_size | ||
|
||
def build_model(self, input_shape, dim): | ||
input = Input(shape=input_shape, batch_size=self.batch_size) | ||
x = Dense(dim * 4, activation='relu')(input) | ||
x = Dropout(0.1)(x) | ||
x = Dense(dim * 2, activation='relu')(x) | ||
x = Dropout(0.1)(x) | ||
x = Dense(dim, activation='relu')(x) | ||
x = Dense(1, activation='sigmoid')(x) | ||
return Model(inputs=input, outputs=x) |