From 0640b0fe2dc9c16f83132c4c11960817764a334e Mon Sep 17 00:00:00 2001
From: Bauke Brenninkmeijer
Date: Wed, 2 Dec 2020 21:38:16 +0100
Subject: [PATCH 1/5] Add n_discriminator steps

---
 ctgan/synthesizer.py | 86 +++++++++++++++++++++++---------------------
 1 file changed, 46 insertions(+), 40 deletions(-)

diff --git a/ctgan/synthesizer.py b/ctgan/synthesizer.py
index 93a94520..576229a5 100644
--- a/ctgan/synthesizer.py
+++ b/ctgan/synthesizer.py
@@ -130,7 +130,8 @@ def _cond_loss(self, data, c, m):
 
         return (loss * m).sum() / data.size()[0]
 
-    def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=True):
+    def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=True,
+            n_discriminator=5):
         """Fit the CTGAN Synthesizer models to the training data.
 
         Args:
@@ -147,6 +148,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=Tr
             log_frequency (boolean):
                 Whether to use log frequency of categorical levels in
                 conditional sampling. Defaults to ``True``.
+            n_discriminator (int):
+                Number of discriminator updates to do for each generator update.
+                Defaults to 5.
         """
 
         if not hasattr(self, "transformer"):
@@ -196,46 +200,48 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=Tr
         for i in range(epochs):
             self.trained_epoches += 1
             for id_ in range(steps_per_epoch):
-                fakez = torch.normal(mean=mean, std=std)
-
-                condvec = self.cond_generator.sample(self.batch_size)
-                if condvec is None:
-                    c1, m1, col, opt = None, None, None, None
-                    real = data_sampler.sample(self.batch_size, col, opt)
-                else:
-                    c1, m1, col, opt = condvec
-                    c1 = torch.from_numpy(c1).to(self.device)
-                    m1 = torch.from_numpy(m1).to(self.device)
-                    fakez = torch.cat([fakez, c1], dim=1)
-
-                    perm = np.arange(self.batch_size)
-                    np.random.shuffle(perm)
-                    real = data_sampler.sample(self.batch_size, col[perm], opt[perm])
-                    c2 = c1[perm]
-
-                fake = self.generator(fakez)
-                fakeact = self._apply_activate(fake)
-
-                real = torch.from_numpy(real.astype('float32')).to(self.device)
-
-                if c1 is not None:
-                    fake_cat = torch.cat([fakeact, c1], dim=1)
-                    real_cat = torch.cat([real, c2], dim=1)
-                else:
-                    real_cat = real
-                    fake_cat = fake
-
-                y_fake = self.discriminator(fake_cat)
-                y_real = self.discriminator(real_cat)
-
-                pen = self.discriminator.calc_gradient_penalty(
-                    real_cat, fake_cat, self.device)
-                loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
-                self.optimizerD.zero_grad()
-                pen.backward(retain_graph=True)
-                loss_d.backward()
-                self.optimizerD.step()
+                for n in range(n_discriminator):
+                    fakez = torch.normal(mean=mean, std=std)
+
+                    condvec = self.cond_generator.sample(self.batch_size)
+                    if condvec is None:
+                        c1, m1, col, opt = None, None, None, None
+                        real = data_sampler.sample(self.batch_size, col, opt)
+                    else:
+                        c1, m1, col, opt = condvec
+                        c1 = torch.from_numpy(c1).to(self.device)
+                        m1 = torch.from_numpy(m1).to(self.device)
+                        fakez = torch.cat([fakez, c1], dim=1)
+
+                        perm = np.arange(self.batch_size)
+                        np.random.shuffle(perm)
+                        real = data_sampler.sample(self.batch_size, col[perm], opt[perm])
+                        c2 = c1[perm]
+
+                    fake = self.generator(fakez)
+                    fakeact = self._apply_activate(fake)
+
+                    real = torch.from_numpy(real.astype('float32')).to(self.device)
+
+                    if c1 is not None:
+                        fake_cat = torch.cat([fakeact, c1], dim=1)
+                        real_cat = torch.cat([real, c2], dim=1)
+                    else:
+                        real_cat = real
+                        fake_cat = fake
+
+                    y_fake = self.discriminator(fake_cat)
+                    y_real = self.discriminator(real_cat)
+
+                    pen = self.discriminator.calc_gradient_penalty(
+                        real_cat, fake_cat, self.device)
+                    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
+
+                    self.optimizerD.zero_grad()
+                    pen.backward(retain_graph=True)
+                    loss_d.backward()
+                    self.optimizerD.step()
 
                 fakez = torch.normal(mean=mean, std=std)
 
                 condvec = self.cond_generator.sample(self.batch_size)
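
Aside on the change above: the schedule introduced in PATCH 1 is the WGAN-style
"several critic updates per generator update", with n_discriminator playing the
role of the WGAN paper's n_critic. A minimal runnable sketch of just that
control flow; the two helpers are hypothetical stand-ins for the real update
code in synthesizer.py:

    def train_discriminator_once():
        pass  # stand-in: one discriminator update (loss + gradient penalty + step)

    def train_generator_once():
        pass  # stand-in: one generator update

    epochs, steps_per_epoch, n_discriminator = 2, 10, 5

    for i in range(epochs):
        for id_ in range(steps_per_epoch):
            # the inner loop added by the patch: n_discriminator discriminator
            # updates for every single generator update
            for n in range(n_discriminator):
                train_discriminator_once()
            train_generator_once()
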
From 6048013b8d5d38974dd3d7830b73a44bb105e9bb Mon Sep 17 00:00:00 2001
From: Bauke Brenninkmeijer
Date: Wed, 2 Dec 2020 21:54:40 +0100
Subject: [PATCH 2/5] move parameter to init

---
 ctgan/synthesizer.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/ctgan/synthesizer.py b/ctgan/synthesizer.py
index 576229a5..92081d57 100644
--- a/ctgan/synthesizer.py
+++ b/ctgan/synthesizer.py
@@ -32,10 +32,13 @@ class CTGANSynthesizer(object):
             Wheight Decay for the Adam Optimizer. Defaults to 1e-6.
         batch_size (int):
             Number of data samples to process in each step.
+        n_discriminator (int):
+            Number of discriminator updates to do for each generator update.
+            Defaults to 5.
     """
 
     def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
-                 l2scale=1e-6, batch_size=500):
+                 l2scale=1e-6, batch_size=500, n_discriminator=1):
 
         self.embedding_dim = embedding_dim
         self.gen_dim = gen_dim
@@ -45,6 +48,7 @@ def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
         self.batch_size = batch_size
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.trained_epoches = 0
+        self.n_discriminator = n_discriminator
 
     @staticmethod
     def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
@@ -130,8 +134,7 @@ def _cond_loss(self, data, c, m):
 
         return (loss * m).sum() / data.size()[0]
 
-    def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=True,
-            n_discriminator=5):
+    def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=True):
         """Fit the CTGAN Synthesizer models to the training data.
 
         Args:
@@ -148,9 +151,6 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=Tr
             log_frequency (boolean):
                 Whether to use log frequency of categorical levels in
                 conditional sampling. Defaults to ``True``.
-            n_discriminator (int):
-                Number of discriminator updates to do for each generator update.
-                Defaults to 5.
         """
 
         if not hasattr(self, "transformer"):
@@ -201,7 +201,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=Tr
             self.trained_epoches += 1
             for id_ in range(steps_per_epoch):
-                for n in range(n_discriminator):
+                for n in range(self.n_discriminator):
                     fakez = torch.normal(mean=mean, std=std)
 
                     condvec = self.cond_generator.sample(self.batch_size)

From 66781703efd643b3a5a64fd8261b964594061c7e Mon Sep 17 00:00:00 2001
From: Bauke Brenninkmeijer
Date: Thu, 3 Dec 2020 09:50:35 +0100
Subject: [PATCH 3/5] Update synthesizer.py

---
 ctgan/synthesizer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ctgan/synthesizer.py b/ctgan/synthesizer.py
index 92081d57..7f11cbb5 100644
--- a/ctgan/synthesizer.py
+++ b/ctgan/synthesizer.py
@@ -33,8 +33,8 @@ class CTGANSynthesizer(object):
         batch_size (int):
             Number of data samples to process in each step.
         n_discriminator (int):
-            Number of discriminator updates to do for each generator update.
-            Defaults to 5.
+            Number of discriminator updates to do for each generator update. WGAN paper 
+            defaults to 5. Defaults to 1.
""" def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256), From 9228d6fd9e1781fd7c105b5d962326f0aab31fc2 Mon Sep 17 00:00:00 2001 From: Bauke Date: Thu, 3 Dec 2020 10:02:16 +0100 Subject: [PATCH 4/5] remove whitespace --- ctgan/synthesizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ctgan/synthesizer.py b/ctgan/synthesizer.py index 7f11cbb5..330439bd 100644 --- a/ctgan/synthesizer.py +++ b/ctgan/synthesizer.py @@ -33,7 +33,7 @@ class CTGANSynthesizer(object): batch_size (int): Number of data samples to process in each step. n_discriminator (int): - Number of discriminator updates to do for each generator update. WGAN paper + Number of discriminator updates to do for each generator update. WGAN paper defaults to 5. Defaults to 1. """ From 3869c528d69e99f733be8d7e7be76d7294619950 Mon Sep 17 00:00:00 2001 From: Bauke Date: Thu, 3 Dec 2020 10:52:40 +0100 Subject: [PATCH 5/5] Add extra information in docstring and change variable name to discriminator_steps --- ctgan/synthesizer.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/ctgan/synthesizer.py b/ctgan/synthesizer.py index 330439bd..83df6f68 100644 --- a/ctgan/synthesizer.py +++ b/ctgan/synthesizer.py @@ -29,16 +29,17 @@ class CTGANSynthesizer(object): Size of the output samples for each one of the Discriminator Layers. A Linear Layer will be created for each one of the values provided. Defaults to (256, 256). l2scale (float): - Wheight Decay for the Adam Optimizer. Defaults to 1e-6. + Weight Decay for the Adam Optimizer. Defaults to 1e-6. batch_size (int): Number of data samples to process in each step. - n_discriminator (int): - Number of discriminator updates to do for each generator update. WGAN paper - defaults to 5. Defaults to 1. + discriminator_steps (int): + Number of discriminator updates to do for each generator update. + From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper + default is 5. Default used is 1 to match original CTGAN implementation. """ def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256), - l2scale=1e-6, batch_size=500, n_discriminator=1): + l2scale=1e-6, batch_size=500, discriminator_steps=1): self.embedding_dim = embedding_dim self.gen_dim = gen_dim @@ -48,7 +49,7 @@ def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256), self.batch_size = batch_size self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.trained_epoches = 0 - self.n_discriminator = n_discriminator + self.discriminator_steps = discriminator_steps @staticmethod def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): @@ -201,7 +202,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=Tr self.trained_epoches += 1 for id_ in range(steps_per_epoch): - for n in range(self.n_discriminator): + for n in range(self.discriminator_steps): fakez = torch.normal(mean=mean, std=std) condvec = self.cond_generator.sample(self.batch_size)