diff --git a/src/ydata_synthetic/synthesizers/regular/wgan/model.py b/src/ydata_synthetic/synthesizers/regular/wgan/model.py index a13fad7b..4a206169 100644 --- a/src/ydata_synthetic/synthesizers/regular/wgan/model.py +++ b/src/ydata_synthetic/synthesizers/regular/wgan/model.py @@ -81,8 +81,8 @@ def train(self, data, train_arguments): [cache_prefix, epochs, sample_interval] = train_arguments # Adversarial ground truths - valid = -np.ones((self.batch_size, 1)) - fake = np.ones((self.batch_size, 1)) + valid = np.ones((self.batch_size, 1)) + fake = -np.ones((self.batch_size, 1)) for epoch in range(epochs): @@ -101,6 +101,12 @@ def train(self, data, train_arguments): d_loss_fake = self.critic.train_on_batch(gen_data, fake) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) + # Critic weight clipping + for l in self.critic.layers: + weights = l.get_weights() + weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights] + l.set_weights(weights) + # --------------------- # Train Generator # --------------------- @@ -121,10 +127,6 @@ def train(self, data, train_arguments): self.generator.save_weights(model_checkpoint_base_name.format('generator', epoch)) self.critic.save_weights(model_checkpoint_base_name.format('critic', epoch)) - # Here is generating new data - #z = tf.random.normal((432, self.noise_dim)) - #gen_data = self.generator(z) - def load(self, path): assert os.path.isdir(path) == True, \ "Please provide a valid path. Path must be a directory."