Skip to content

Commit

Permalink
fix: Add weight clipping
Browse files Browse the repository at this point in the history
  • Loading branch information
fabclmnt committed Oct 27, 2020
1 parent 85ed725 commit c82b9b5
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/ydata_synthetic/synthesizers/regular/wgan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def train(self, data, train_arguments):
[cache_prefix, epochs, sample_interval] = train_arguments

# Adversarial ground truths
valid = -np.ones((self.batch_size, 1))
fake = np.ones((self.batch_size, 1))
valid = np.ones((self.batch_size, 1))
fake = -np.ones((self.batch_size, 1))

for epoch in range(epochs):

Expand All @@ -101,6 +101,12 @@ def train(self, data, train_arguments):
d_loss_fake = self.critic.train_on_batch(gen_data, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

# Critic weight clipping
for l in self.critic.layers:
weights = l.get_weights()
weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
l.set_weights(weights)

# ---------------------
# Train Generator
# ---------------------
Expand All @@ -121,10 +127,6 @@ def train(self, data, train_arguments):
self.generator.save_weights(model_checkpoint_base_name.format('generator', epoch))
self.critic.save_weights(model_checkpoint_base_name.format('critic', epoch))

# Here is generating new data
#z = tf.random.normal((432, self.noise_dim))
#gen_data = self.generator(z)

def load(self, path):
assert os.path.isdir(path) == True, \
"Please provide a valid path. Path must be a directory."
Expand Down

0 comments on commit c82b9b5

Please sign in to comment.