fix(VanillaGAN): Add VanillaGAN again to the repo.
fabclmnt committed Sep 16, 2020
1 parent 4292858 commit 8d8eab3
Showing 3 changed files with 149 additions and 1 deletion.
4 changes: 3 additions & 1 deletion models/__init__.py
@@ -1,7 +1,9 @@
from models.cgan.model import CGAN
from models.wgan.model import WGAN
from models.vanillagan.model import VanillaGAN

__all__ = [
"VanilllaGAN",
"CGAN",
"VanilllaGAN"
"WGAN"
]
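This __init__ makes all three models importable from the package root; a quick check (assuming the models package is on the import path):

from models import CGAN, VanillaGAN, WGAN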
Empty file added models/vanillagan/__init__.py
146 changes: 146 additions & 0 deletions models/vanillagan/model.py
@@ -0,0 +1,146 @@
import os
from os import path
import numpy as np

from models import gan

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam

class VanillaGAN(gan.Model):

def __init__(self, model_parameters):
super().__init__(model_parameters)

def define_gan(self):
self.generator = Generator(self.batch_size).\
build_model(input_shape=(self.noise_dim,), dim=self.layers_dim, data_dim=self.data_dim)

self.discriminator = Discriminator(self.batch_size).\
build_model(input_shape=(self.data_dim,), dim=self.layers_dim)

optimizer = Adam(self.lr, 0.5)

# Build and compile the discriminator
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])

        # The generator takes noise as input and generates records
z = Input(shape=(self.noise_dim,))
record = self.generator(z)

# For the combined model we will only train the generator
self.discriminator.trainable = False

        # The discriminator takes generated records as input and determines validity
validity = self.discriminator(record)

# The combined model (stacked generator and discriminator)
# Trains the generator to fool the discriminator
self._model = Model(z, validity)
self._model.compile(loss='binary_crossentropy', optimizer=optimizer)

def get_data_batch(self, train, batch_size, seed=0):
        # Alternative: pure random sampling is easy to implement, but some samples
        # would be drawn far more or less often than others:
        #   np.random.seed(seed)
        #   x = train.loc[np.random.choice(train.index, batch_size)].values
        # Instead, iterate through shuffled indices so every sample is covered evenly.

start_i = (batch_size * seed) % len(train)
stop_i = start_i + batch_size
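        # Reshuffle with a fresh seed each time the window completes a full pass over the dataset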
shuffle_seed = (batch_size * seed) // len(train)
np.random.seed(shuffle_seed)
train_ix = np.random.choice(list(train.index), replace=False, size=len(train)) # wasteful to shuffle every time
train_ix = list(train_ix) + list(train_ix) # duplicate to cover ranges past the end of the set
x = train.loc[train_ix[start_i: stop_i]].values
return np.reshape(x, (batch_size, -1))

def train(self, data, train_arguments):
        cache_prefix, epochs, sample_interval = train_arguments

# Adversarial ground truths
valid = np.ones((self.batch_size, 1))
fake = np.zeros((self.batch_size, 1))

for epoch in range(epochs):
# ---------------------
# Train Discriminator
# ---------------------
            batch_data = self.get_data_batch(data, self.batch_size, seed=epoch)  # advance the sampling window each epoch
noise = tf.random.normal((self.batch_size, self.noise_dim))

# Generate a batch of events
gen_data = self.generator(noise, training=True)

# Train the discriminator
d_loss_real = self.discriminator.train_on_batch(batch_data, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_data, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

# ---------------------
# Train Generator
# ---------------------
noise = tf.random.normal((self.batch_size, self.noise_dim))
# Train the generator (to have the discriminator label samples as valid)
            g_loss = self._model.train_on_batch(noise, valid)

# Plot the progress
print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

# If at save interval => save generated events
if epoch % sample_interval == 0:
                # Save model checkpoints
                if not path.exists('./cache'):
                    os.mkdir('./cache')
model_checkpoint_base_name = './cache/' + cache_prefix + '_{}_model_weights_step_{}.h5'
self.generator.save_weights(model_checkpoint_base_name.format('generator', epoch))
self.discriminator.save_weights(model_checkpoint_base_name.format('discriminator', epoch))

                # Generate a synthetic sample at this checkpoint (sample size is hard-coded)
                z = tf.random.normal((432, self.noise_dim))
                gen_data = self.generator(z)
                print('generated_data')

    def save(self, path, name):
        assert os.path.isdir(path), "Please provide a valid path. Path must be a directory."
        model_path = os.path.join(path, name)
        self.generator.save_weights(model_path)  # save the trained generator weights

    def load(self, path):
        assert os.path.isdir(path), "Please provide a valid path. Path must be a directory."
        self.generator = Generator(self.batch_size)
        # load_weights returns a status object, so keep the model reference itself.
        # Note: the weights were saved from the functional model returned by
        # build_model, so the generator must be built before restoring them.
        self.generator.load_weights(path)
        return self.generator

class Generator(tf.keras.Model):
    def __init__(self, batch_size):
        super().__init__()  # required when subclassing tf.keras.Model
        self.batch_size = batch_size

    def build_model(self, input_shape, dim, data_dim):
        noise = Input(shape=input_shape, batch_size=self.batch_size)
        x = Dense(dim, activation='relu')(noise)
        x = Dense(dim * 2, activation='relu')(x)
        x = Dense(dim * 4, activation='relu')(x)
        x = Dense(data_dim)(x)
        return Model(inputs=noise, outputs=x)

class Discriminator(tf.keras.Model):
    def __init__(self, batch_size):
        super().__init__()  # required when subclassing tf.keras.Model
        self.batch_size = batch_size

    def build_model(self, input_shape, dim):
        record = Input(shape=input_shape, batch_size=self.batch_size)
        x = Dense(dim * 4, activation='relu')(record)
        x = Dropout(0.1)(x)
        x = Dense(dim * 2, activation='relu')(x)
        x = Dropout(0.1)(x)
        x = Dense(dim, activation='relu')(x)
        x = Dense(1, activation='sigmoid')(x)
        return Model(inputs=record, outputs=x)
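For orientation, a minimal usage sketch follows. It is a hedged illustration, not part of the commit: the GANParameters container and its field names are hypothetical stand-ins, since the parameter structure expected by the models.gan.Model base class is not shown in this diff; the attributes chosen are the ones define_gan and train read (batch_size, lr, noise_dim, layers_dim, data_dim).

from collections import namedtuple

import numpy as np
import pandas as pd

from models import VanillaGAN

# Hypothetical parameter container; the real gan.Model base class may expect
# a different structure that exposes these same attributes.
GANParameters = namedtuple('GANParameters',
                           ['batch_size', 'lr', 'noise_dim', 'layers_dim', 'data_dim'])

data = pd.DataFrame(np.random.rand(1000, 8))  # toy numeric dataset
params = GANParameters(batch_size=128, lr=5e-4, noise_dim=32,
                       layers_dim=128, data_dim=data.shape[1])

synth = VanillaGAN(params)
synth.define_gan()  # build the networks (assuming the base class does not do this itself)
# train_arguments: [cache_prefix, epochs, sample_interval]
synth.train(data, ['vanillagan', 200, 50])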
