diff --git a/ctgan/synthesizer.py b/ctgan/synthesizer.py
index a57cc2bf..ee9762b6 100644
--- a/ctgan/synthesizer.py
+++ b/ctgan/synthesizer.py
@@ -29,13 +29,20 @@ class CTGANSynthesizer(object):
             Size of the output samples for each one of the Discriminator Layers.
             A Linear Layer will be created for each one of the values provided.
             Defaults to (256, 256).
         l2scale (float):
-            Wheight Decay for the Adam Optimizer. Defaults to 1e-6.
+            Weight Decay for the Adam Optimizer. Defaults to 1e-6.
         batch_size (int):
             Number of data samples to process in each step.
+        discriminator_steps (int):
+            Number of discriminator updates to do for each generator update.
+            From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper
+            default is 5. Default used is 1 to match original CTGAN implementation.
+        log_frequency (boolean):
+            Whether to use log frequency of categorical levels in conditional
+            sampling. Defaults to ``True``.
     """

     def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
-                 l2scale=1e-6, batch_size=500, log_frequency=True):
+                 l2scale=1e-6, batch_size=500, discriminator_steps=1, log_frequency=True):

         self.embedding_dim = embedding_dim
         self.gen_dim = gen_dim
@@ -46,6 +53,7 @@ def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
         self.log_frequency = log_frequency
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.trained_epoches = 0
+        self.discriminator_steps = discriminator_steps

     @staticmethod
     def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
@@ -64,9 +72,6 @@ def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
                 but will be differentiated as if it is the soft sample in autograd
             dim (int):
                 a dimension along which softmax will be computed. Default: -1.
-            log_frequency (boolean):
-                Whether to use log frequency of categorical levels in conditional
-                sampling. Defaults to ``True``.

         Returns:
             Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
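For context, a usage sketch of the new constructor argument (not part of this diff): the import path and `fit`/`sample` calls follow the signatures visible in this file, while the CSV file name and column names are placeholders invented for illustration.

```python
import pandas as pd
from ctgan.synthesizer import CTGANSynthesizer

# Placeholder data: any DataFrame works; the file name and column names below
# are illustrative only and do not come from this diff.
train_data = pd.read_csv('adult.csv')
discrete_columns = ['workclass', 'education', 'income']

# discriminator_steps=5 follows the WGAN paper; leaving it at the default (1)
# keeps the original CTGAN behaviour.
synthesizer = CTGANSynthesizer(discriminator_steps=5)
synthesizer.fit(train_data, discrete_columns, epochs=300)
samples = synthesizer.sample(1000)
```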
@@ -197,46 +202,48 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300):
         for i in range(epochs):
             self.trained_epoches += 1
             for id_ in range(steps_per_epoch):
-                fakez = torch.normal(mean=mean, std=std)
-
-                condvec = self.cond_generator.sample(self.batch_size)
-                if condvec is None:
-                    c1, m1, col, opt = None, None, None, None
-                    real = data_sampler.sample(self.batch_size, col, opt)
-                else:
-                    c1, m1, col, opt = condvec
-                    c1 = torch.from_numpy(c1).to(self.device)
-                    m1 = torch.from_numpy(m1).to(self.device)
-                    fakez = torch.cat([fakez, c1], dim=1)
-
-                    perm = np.arange(self.batch_size)
-                    np.random.shuffle(perm)
-                    real = data_sampler.sample(self.batch_size, col[perm], opt[perm])
-                    c2 = c1[perm]
-
-                fake = self.generator(fakez)
-                fakeact = self._apply_activate(fake)
-
-                real = torch.from_numpy(real.astype('float32')).to(self.device)
-
-                if c1 is not None:
-                    fake_cat = torch.cat([fakeact, c1], dim=1)
-                    real_cat = torch.cat([real, c2], dim=1)
-                else:
-                    real_cat = real
-                    fake_cat = fake
-
-                y_fake = self.discriminator(fake_cat)
-                y_real = self.discriminator(real_cat)
-
-                pen = self.discriminator.calc_gradient_penalty(
-                    real_cat, fake_cat, self.device)
-                loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
-                self.optimizerD.zero_grad()
-                pen.backward(retain_graph=True)
-                loss_d.backward()
-                self.optimizerD.step()
+                for n in range(self.discriminator_steps):
+                    fakez = torch.normal(mean=mean, std=std)
+
+                    condvec = self.cond_generator.sample(self.batch_size)
+                    if condvec is None:
+                        c1, m1, col, opt = None, None, None, None
+                        real = data_sampler.sample(self.batch_size, col, opt)
+                    else:
+                        c1, m1, col, opt = condvec
+                        c1 = torch.from_numpy(c1).to(self.device)
+                        m1 = torch.from_numpy(m1).to(self.device)
+                        fakez = torch.cat([fakez, c1], dim=1)
+
+                        perm = np.arange(self.batch_size)
+                        np.random.shuffle(perm)
+                        real = data_sampler.sample(self.batch_size, col[perm], opt[perm])
+                        c2 = c1[perm]
+
+                    fake = self.generator(fakez)
+                    fakeact = self._apply_activate(fake)
+
+                    real = torch.from_numpy(real.astype('float32')).to(self.device)
+
+                    if c1 is not None:
+                        fake_cat = torch.cat([fakeact, c1], dim=1)
+                        real_cat = torch.cat([real, c2], dim=1)
+                    else:
+                        real_cat = real
+                        fake_cat = fake
+
+                    y_fake = self.discriminator(fake_cat)
+                    y_real = self.discriminator(real_cat)
+
+                    pen = self.discriminator.calc_gradient_penalty(
+                        real_cat, fake_cat, self.device)
+                    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
+
+                    self.optimizerD.zero_grad()
+                    pen.backward(retain_graph=True)
+                    loss_d.backward()
+                    self.optimizerD.step()

                 fakez = torch.normal(mean=mean, std=std)
                 condvec = self.cond_generator.sample(self.batch_size)
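For readers unfamiliar with the pattern this hunk introduces, below is a minimal, standalone sketch of running several critic (discriminator) updates per generator update, as in the WGAN paper. It uses plain PyTorch with toy networks and data; none of these names come from the CTGAN code, and the conditional vectors and gradient penalty used above are omitted for brevity.

```python
import torch

# Toy stand-ins: random "real" data, tiny generator and critic networks.
torch.manual_seed(0)
data_dim, noise_dim, batch_size, discriminator_steps = 8, 4, 64, 5
real_data = torch.randn(1024, data_dim)

generator = torch.nn.Sequential(torch.nn.Linear(noise_dim, 32), torch.nn.ReLU(),
                                torch.nn.Linear(32, data_dim))
discriminator = torch.nn.Sequential(torch.nn.Linear(data_dim, 32), torch.nn.ReLU(),
                                    torch.nn.Linear(32, 1))
optimizerG = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.9))
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.9))

for step in range(100):
    # Several discriminator (critic) updates per generator update,
    # mirroring the inner "for n in range(self.discriminator_steps)" loop above.
    for _ in range(discriminator_steps):
        idx = torch.randint(0, real_data.size(0), (batch_size,))
        real = real_data[idx]
        fake = generator(torch.randn(batch_size, noise_dim)).detach()
        # WGAN critic loss: maximize E[D(real)] - E[D(fake)].
        loss_d = -(discriminator(real).mean() - discriminator(fake).mean())
        optimizerD.zero_grad()
        loss_d.backward()
        optimizerD.step()

    # Single generator update: maximize E[D(fake)].
    fake = generator(torch.randn(batch_size, noise_dim))
    loss_g = -discriminator(fake).mean()
    optimizerG.zero_grad()
    loss_g.backward()
    optimizerG.step()
```

With `discriminator_steps=1` this reduces to the original alternating update, which is why the class default of 1 preserves the previous CTGAN behaviour.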