Skip to content

Commit

Permalink
fix: cgan discriminator repeated label (#101)
Browse files Browse the repository at this point in the history
  • Loading branch information
fabclmnt committed Oct 28, 2021
1 parent ba4f86d commit 1c84754
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions examples/regular/cgan_example.py
Expand Up @@ -25,7 +25,7 @@
train_data = data.loc[ data['Class']==1 ].copy()

#Create a new class column using KMeans - This will mainly be useful if we want to leverage conditional GAN
print("Dataset info: Number of records - {} Number of varibles - {}".format(train_data.shape[0], train_data.shape[1]))
print("Dataset info: Number of records - {} Number of variables - {}".format(train_data.shape[0], train_data.shape[1]))
algorithm = cluster.KMeans
args, kwds = (), {'n_clusters':2, 'random_state':0}
labels = algorithm(*args, **kwds).fit_predict(train_data[ data_cols ])
Expand Down Expand Up @@ -63,7 +63,7 @@
lr=learning_rate,
betas=(beta_1, beta_2),
noise_dim=noise_dim,
n_cols=train_sample.shape[1],
n_cols=train_sample.shape[1] - len(label_cols), # Don't count the label columns here
layers_dim=dim)

train_args = TrainParameters(epochs=epochs,
Expand Down
5 changes: 3 additions & 2 deletions src/ydata_synthetic/synthesizers/regular/cgan/model.py
Expand Up @@ -93,14 +93,15 @@ def train(self, data: Union[DataFrame, array],
# ---------------------
batch_x = self.get_data_batch(data, self.batch_size)
label = batch_x[:, train_arguments.label_dim]
data_cols = [i for i in range(batch_x.shape[1] - 1)] # All data without the label columns
noise = tf.random.normal((self.batch_size, self.noise_dim))

# Generate a batch of new records
gen_records = self.generator([noise, label], training=True)

# Train the discriminator
d_loss_real = self.discriminator.train_on_batch([batch_x, label], valid)
d_loss_fake = self.discriminator.train_on_batch([gen_records, label], fake)
d_loss_real = self.discriminator.train_on_batch([batch_x[:, data_cols], label], valid) # Separate labels
d_loss_fake = self.discriminator.train_on_batch([gen_records, label], fake) # Separate labels
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

# ---------------------
Expand Down

0 comments on commit 1c84754

Please sign in to comment.