Skip to content

Commit

Permalink
feat(cgan): Add Jupyter notebook conditional example.
Browse files Browse the repository at this point in the history
  • Loading branch information
fabclmnt committed Sep 14, 2020
1 parent 7e7018d commit 72d9ca4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
File renamed without changes.
11 changes: 8 additions & 3 deletions models/cgan/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from os import path
import numpy as np

import tensorflow as tf
Expand All @@ -19,7 +20,7 @@ def __init__(self, gan_args):
self.discriminator = Discriminator(self.batch_size, num_classes). \
build_model(input_shape=(self.data_dim,), dim=layers_dim)

optimizer = Adam(lr, 0.5)
optimizer = Adam(lr, beta_1=0.5)

# Build and compile the discriminator
self.discriminator.compile(loss='binary_crossentropy',
Expand Down Expand Up @@ -73,7 +74,7 @@ def train(self, data, train_arguments):
noise = tf.random.normal((self.batch_size, self.noise_dim))

# Generate a batch of new records
gen_records = self.generator.predict([noise, label])
gen_records = self.generator([noise, label], training=True)

# Train the discriminator
d_loss_real = self.discriminator.train_on_batch([batch_x, label], valid)
Expand All @@ -94,7 +95,9 @@ def train(self, data, train_arguments):
if epoch % sample_interval == 0:
# Test here data generation step
# save model checkpoints
model_checkpoint_base_name = data_dir + cache_prefix + '_{}_model_weights_step_{}.h5'
if path.exists('./cache') is False:
os.mkdir('./cache')
model_checkpoint_base_name = './cache/' + cache_prefix + '_{}_model_weights_step_{}.h5'
self.generator.save_weights(model_checkpoint_base_name.format('generator', epoch))
self.discriminator.save_weights(model_checkpoint_base_name.format('discriminator', epoch))

Expand Down Expand Up @@ -159,3 +162,5 @@ def build_model(self, input_shape, dim):





0 comments on commit 72d9ca4

Please sign in to comment.