ctgan/synthesizer.py (merged): 95 changes, 51 additions and 44 deletions
@@ -29,13 +29,20 @@ class CTGANSynthesizer(object):
             Size of the output samples for each one of the Discriminator Layers. A Linear Layer
             will be created for each one of the values provided. Defaults to (256, 256).
         l2scale (float):
-            Wheight Decay for the Adam Optimizer. Defaults to 1e-6.
+            Weight Decay for the Adam Optimizer. Defaults to 1e-6.
         batch_size (int):
             Number of data samples to process in each step.
+        discriminator_steps (int):
+            Number of discriminator updates to do for each generator update.
+            From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper
+            default is 5. Default used is 1 to match original CTGAN implementation.
+        log_frequency (boolean):
+            Whether to use log frequency of categorical levels in conditional
+            sampling. Defaults to ``True``.
     """

     def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
-                 l2scale=1e-6, batch_size=500, log_frequency=True):
+                 l2scale=1e-6, batch_size=500, discriminator_steps=1, log_frequency=True):

         self.embedding_dim = embedding_dim
         self.gen_dim = gen_dim
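As an aside for reviewers, here is a minimal usage sketch of the new parameter. The dataset and column names are hypothetical; the constructor argument and `fit` signature follow this diff, and the `sample` call assumes the library's existing public API:

    import pandas as pd
    from ctgan import CTGANSynthesizer

    train_data = pd.read_csv('adult.csv')          # hypothetical dataset
    discrete_columns = ['workclass', 'education']  # hypothetical columns

    # Five discriminator updates per generator update, as in the WGAN paper;
    # the default of 1 keeps the original CTGAN behavior.
    synthesizer = CTGANSynthesizer(discriminator_steps=5)
    synthesizer.fit(train_data, discrete_columns=discrete_columns, epochs=300)
    samples = synthesizer.sample(1000)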
@@ -46,6 +53,7 @@ def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
         self.log_frequency = log_frequency
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.trained_epoches = 0
+        self.discriminator_steps = discriminator_steps

     @staticmethod
     def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
@@ -64,9 +72,6 @@ def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
                 but will be differentiated as if it is the soft sample in autograd
             dim (int):
                 a dimension along which softmax will be computed. Default: -1.
-            log_frequency (boolean):
-                Whether to use log frequency of categorical levels in conditional
-                sampling. Defaults to ``True``.

         Returns:
             Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
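For reference, this helper's parameters mirror torch.nn.functional.gumbel_softmax; a short standalone sketch of the two sampling modes the docstring describes:

    import torch
    from torch.nn import functional

    logits = torch.randn(4, 10)

    # Soft, differentiable samples from the Gumbel-Softmax distribution.
    soft = functional.gumbel_softmax(logits, tau=1.0, hard=False, dim=-1)

    # Hard one-hot samples; autograd treats them as the soft samples
    # (straight-through estimator), so gradients still flow.
    hard = functional.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
    assert torch.allclose(hard.sum(dim=-1), torch.ones(4))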
@@ -197,46 +202,48 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300):
         for i in range(epochs):
             self.trained_epoches += 1
             for id_ in range(steps_per_epoch):
-                fakez = torch.normal(mean=mean, std=std)
-
-                condvec = self.cond_generator.sample(self.batch_size)
-                if condvec is None:
-                    c1, m1, col, opt = None, None, None, None
-                    real = data_sampler.sample(self.batch_size, col, opt)
-                else:
-                    c1, m1, col, opt = condvec
-                    c1 = torch.from_numpy(c1).to(self.device)
-                    m1 = torch.from_numpy(m1).to(self.device)
-                    fakez = torch.cat([fakez, c1], dim=1)
-
-                    perm = np.arange(self.batch_size)
-                    np.random.shuffle(perm)
-                    real = data_sampler.sample(self.batch_size, col[perm], opt[perm])
-                    c2 = c1[perm]
-
-                fake = self.generator(fakez)
-                fakeact = self._apply_activate(fake)
-
-                real = torch.from_numpy(real.astype('float32')).to(self.device)
-
-                if c1 is not None:
-                    fake_cat = torch.cat([fakeact, c1], dim=1)
-                    real_cat = torch.cat([real, c2], dim=1)
-                else:
-                    real_cat = real
-                    fake_cat = fake
-
-                y_fake = self.discriminator(fake_cat)
-                y_real = self.discriminator(real_cat)
-
-                pen = self.discriminator.calc_gradient_penalty(
-                    real_cat, fake_cat, self.device)
-                loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
-
-                self.optimizerD.zero_grad()
-                pen.backward(retain_graph=True)
-                loss_d.backward()
-                self.optimizerD.step()
+                for n in range(self.discriminator_steps):
+                    fakez = torch.normal(mean=mean, std=std)
+
+                    condvec = self.cond_generator.sample(self.batch_size)
+                    if condvec is None:
+                        c1, m1, col, opt = None, None, None, None
+                        real = data_sampler.sample(self.batch_size, col, opt)
+                    else:
+                        c1, m1, col, opt = condvec
+                        c1 = torch.from_numpy(c1).to(self.device)
+                        m1 = torch.from_numpy(m1).to(self.device)
+                        fakez = torch.cat([fakez, c1], dim=1)
+
+                        perm = np.arange(self.batch_size)
+                        np.random.shuffle(perm)
+                        real = data_sampler.sample(self.batch_size, col[perm], opt[perm])
+                        c2 = c1[perm]
+
+                    fake = self.generator(fakez)
+                    fakeact = self._apply_activate(fake)
+
+                    real = torch.from_numpy(real.astype('float32')).to(self.device)
+
+                    if c1 is not None:
+                        fake_cat = torch.cat([fakeact, c1], dim=1)
+                        real_cat = torch.cat([real, c2], dim=1)
+                    else:
+                        real_cat = real
+                        fake_cat = fake
+
+                    y_fake = self.discriminator(fake_cat)
+                    y_real = self.discriminator(real_cat)
+
+                    pen = self.discriminator.calc_gradient_penalty(
+                        real_cat, fake_cat, self.device)
+                    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
+
+                    self.optimizerD.zero_grad()
+                    pen.backward(retain_graph=True)
+                    loss_d.backward()
+                    self.optimizerD.step()
+

             fakez = torch.normal(mean=mean, std=std)
             condvec = self.cond_generator.sample(self.batch_size)
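The discriminator update above is the WGAN-GP critic objective: maximize E[D(real)] - E[D(fake)] (so loss_d is its negation), plus a gradient penalty on interpolates between real and fake rows. A standalone sketch of that objective for a generic critic D follows; it is illustrative only and is not CTGAN's own calc_gradient_penalty implementation:

    import torch

    def critic_loss_with_gp(D, real, fake, device, lambda_gp=10.0):
        """WGAN-GP critic loss (Gulrajani et al., 2017), illustrative only."""
        # Wasserstein critic term, matching loss_d in the diff above.
        loss_d = -(torch.mean(D(real)) - torch.mean(D(fake)))

        # Gradient penalty: push the critic's gradient norm toward 1
        # on random interpolates between real and fake samples.
        alpha = torch.rand(real.size(0), 1, device=device)
        interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
        grad = torch.autograd.grad(
            outputs=D(interp).sum(), inputs=interp, create_graph=True)[0]
        penalty = lambda_gp * ((grad.norm(2, dim=1) - 1) ** 2).mean()

        return loss_d + penalty

Calling backward separately on pen (with retain_graph=True) and on loss_d, as the diff does, accumulates the same gradients as backpropagating their sum once.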