Merged
71 changes: 45 additions & 26 deletions ctgan/synthesizers/ctgan.py
@@ -14,11 +14,11 @@

 class Discriminator(Module):

-    def __init__(self, input_dim, discriminator_dim, pack=10):
+    def __init__(self, input_dim, discriminator_dim, pac=10):
         super(Discriminator, self).__init__()
-        dim = input_dim * pack
-        self.pack = pack
-        self.packdim = dim
+        dim = input_dim * pac
+        self.pac = pac
+        self.pacdim = dim
         seq = []
         for item in list(discriminator_dim):
             seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)]
@@ -49,8 +49,8 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb
         return gradient_penalty

     def forward(self, input):
-        assert input.size()[0] % self.pack == 0
-        return self.seq(input.view(-1, self.packdim))
+        assert input.size()[0] % self.pac == 0
+        return self.seq(input.view(-1, self.pacdim))


 class Residual(Module):
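For context: the pack-to-pac rename matches the PacGAN technique this class implements, where the discriminator scores groups of pac samples jointly so mode collapse is easier to detect. A minimal sketch of the reshape done in forward(), with illustrative sizes (input_dim=3 and pac=10 are assumptions, not values taken from this diff):

import torch

batch_size, input_dim, pac = 500, 3, 10
x = torch.randn(batch_size, input_dim)
assert x.size()[0] % pac == 0            # same guard as Discriminator.forward
packed = x.view(-1, input_dim * pac)     # (500, 3) -> (50, 30)
print(packed.shape)                      # torch.Size([50, 30])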
@@ -122,12 +122,19 @@ class CTGANSynthesizer(BaseSynthesizer):
             Whether to have print statements for progress results. Defaults to ``False``.
         epochs (int):
             Number of training epochs. Defaults to 300.
+        pac (int):
+            Number of samples to group together when applying the discriminator.
+            Defaults to 10.
+        cuda (bool):
+            Whether to attempt to use cuda for GPU computation.
+            If this is False or CUDA is not available, CPU will be used.
+            Defaults to ``True``.
     """

     def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256),
                  generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4,
                  discriminator_decay=0, batch_size=500, discriminator_steps=1, log_frequency=True,
-                 verbose=False, epochs=300):
+                 verbose=False, epochs=300, pac=10, cuda=True):

         assert batch_size % 2 == 0
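A hedged usage sketch of the widened signature; data and discrete_columns are placeholders, and the import path is an assumption based on the package layout shown in this diff:

from ctgan import CTGANSynthesizer   # assumed public export

# data (a DataFrame) and discrete_columns are placeholders, not defined here.
model = CTGANSynthesizer(epochs=10, pac=10, cuda=False)   # new keyword arguments
model.fit(data, discrete_columns)
samples = model.sample(100)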
@@ -145,8 +152,21 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di
         self._log_frequency = log_frequency
         self._verbose = verbose
         self._epochs = epochs
-        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.trained_epochs = 0
+        self.pac = pac
+
+        if not cuda or not torch.cuda.is_available():
+            device = 'cpu'
+        elif isinstance(cuda, str):
+            device = cuda
+        else:
+            device = 'cuda'
+
+        self._device = torch.device(device)
+
+        self._transformer = None
+        self._data_sampler = None
+        self._generator = None

     @staticmethod
     def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
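The new device logic accepts more than a bool: False forces CPU, True uses the default CUDA device when available, and a string pins a specific device. A standalone sketch of the same branch ('cuda:1' below is a hypothetical example value, not something the diff uses):

import torch

def resolve_device(cuda=True):
    # Mirrors the branch added to __init__ above.
    if not cuda or not torch.cuda.is_available():
        return torch.device('cpu')
    if isinstance(cuda, str):
        return torch.device(cuda)   # e.g. 'cuda:1' to pin a particular GPU
    return torch.device('cuda')

print(resolve_device(False))        # cpu, regardless of hardware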
@@ -289,18 +309,19 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
             data_dim
         ).to(self._device)

-        self._discriminator = Discriminator(
+        discriminator = Discriminator(
             data_dim + self._data_sampler.dim_cond_vec(),
-            self._discriminator_dim
+            self._discriminator_dim,
+            pac=self.pac
         ).to(self._device)

-        self._optimizerG = optim.Adam(
+        optimizerG = optim.Adam(
             self._generator.parameters(), lr=self._generator_lr, betas=(0.5, 0.9),
             weight_decay=self._generator_decay
         )

-        self._optimizerD = optim.Adam(
-            self._discriminator.parameters(), lr=self._discriminator_lr,
+        optimizerD = optim.Adam(
+            discriminator.parameters(), lr=self._discriminator_lr,
             betas=(0.5, 0.9), weight_decay=self._discriminator_decay
         )
@@ -343,17 +364,17 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
                        real_cat = real
                        fake_cat = fake

-                    y_fake = self._discriminator(fake_cat)
-                    y_real = self._discriminator(real_cat)
+                    y_fake = discriminator(fake_cat)
+                    y_real = discriminator(real_cat)

-                    pen = self._discriminator.calc_gradient_penalty(
-                        real_cat, fake_cat, self._device)
+                    pen = discriminator.calc_gradient_penalty(
+                        real_cat, fake_cat, self._device, self.pac)
                     loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

-                    self._optimizerD.zero_grad()
+                    optimizerD.zero_grad()
                     pen.backward(retain_graph=True)
                     loss_d.backward()
-                    self._optimizerD.step()
+                    optimizerD.step()

                 fakez = torch.normal(mean=mean, std=std)
                 condvec = self._data_sampler.sample_condvec(self._batch_size)
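For reference, loss_d and pen together form the standard WGAN-GP critic objective, which is what calc_gradient_penalty implements:

    L_D = -\left( \mathbb{E}[D(x_{\mathrm{real}})] - \mathbb{E}[D(x_{\mathrm{fake}})] \right) + \lambda \, \mathbb{E}_{\hat{x}}\left[ \left( \lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1 \right)^2 \right]

pen.backward(retain_graph=True) backpropagates the penalty term first while keeping the graph alive, so loss_d.backward() can reuse it in the same step.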
@@ -370,9 +391,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
                 fakeact = self._apply_activate(fake)

                 if c1 is not None:
-                    y_fake = self._discriminator(torch.cat([fakeact, c1], dim=1))
+                    y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
                 else:
-                    y_fake = self._discriminator(fakeact)
+                    y_fake = discriminator(fakeact)

                 if condvec is None:
                     cross_entropy = 0
@@ -381,9 +402,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):

                 loss_g = -torch.mean(y_fake) + cross_entropy

-                self._optimizerG.zero_grad()
+                optimizerG.zero_grad()
                 loss_g.backward()
-                self._optimizerG.step()
+                optimizerG.step()

                 if self._verbose:
                     print(f"Epoch {i+1}, Loss G: {loss_g.detach().cpu(): .4f},"
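The generator objective mirrors the critic side: an adversarial term plus the conditional cross-entropy computed from condvec, which pushes generated rows to honor the sampled condition:

    L_G = -\mathbb{E}[D(\hat{x})] + \mathrm{CE}(\hat{x}_{\mathrm{cond}}, c)

Here \mathrm{CE} is shorthand for the cross_entropy value assembled above; the notation is ours, not the PR's.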
@@ -444,7 +465,5 @@ def sample(self, n, condition_column=None, condition_value=None):

     def set_device(self, device):
         self._device = device
-        if hasattr(self, '_generator'):
+        if self._generator is not None:
             self._generator.to(self._device)
-        if hasattr(self, '_discriminator'):
-            self._discriminator.to(self._device)
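A plausible reading of this cleanup (our inference, not stated in the PR): the discriminator and both optimizers are now locals of fit, and _generator is always initialized to None in __init__, so set_device needs only a None check and a saved synthesizer carries nothing but what sampling requires. A hedged sketch of the consequence, reusing the hypothetical model from the earlier usage sketch:

import pickle

model.set_device('cpu')            # moves self._generator only
payload = pickle.dumps(model)      # no discriminator or optimizer state inside
restored = pickle.loads(payload)
print(restored.sample(5))          # sampling still works from the generator alone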
15 changes: 12 additions & 3 deletions ctgan/synthesizers/tvae.py
@@ -82,7 +82,9 @@ def __init__(
         decompress_dims=(128, 128),
         l2scale=1e-5,
         batch_size=500,
-        epochs=300
+        epochs=300,
+        loss_factor=2,
+        cuda=True
     ):

         self.embedding_dim = embedding_dim
@@ -91,10 +93,17 @@ def __init__(

         self.l2scale = l2scale
         self.batch_size = batch_size
-        self.loss_factor = 2
+        self.loss_factor = loss_factor
         self.epochs = epochs

-        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        if not cuda or not torch.cuda.is_available():
+            device = 'cpu'
+        elif isinstance(cuda, str):
+            device = cuda
+        else:
+            device = 'cuda'
+
+        self._device = torch.device(device)

     def fit(self, train_data, discrete_columns=tuple()):
         self.transformer = DataTransformer()