diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py index c34a67ab..0ce4172f 100644 --- a/ctgan/synthesizers/ctgan.py +++ b/ctgan/synthesizers/ctgan.py @@ -14,11 +14,11 @@ class Discriminator(Module): - def __init__(self, input_dim, discriminator_dim, pack=10): + def __init__(self, input_dim, discriminator_dim, pac=10): super(Discriminator, self).__init__() - dim = input_dim * pack - self.pack = pack - self.packdim = dim + dim = input_dim * pac + self.pac = pac + self.pacdim = dim seq = [] for item in list(discriminator_dim): seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)] @@ -49,8 +49,8 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb return gradient_penalty def forward(self, input): - assert input.size()[0] % self.pack == 0 - return self.seq(input.view(-1, self.packdim)) + assert input.size()[0] % self.pac == 0 + return self.seq(input.view(-1, self.pacdim)) class Residual(Module): @@ -122,12 +122,19 @@ class CTGANSynthesizer(BaseSynthesizer): Whether to have print statements for progress results. Defaults to ``False``. epochs (int): Number of training epochs. Defaults to 300. + pac (int): + Number of samples to group together when applying the discriminator. + Defaults to 10. + cuda (bool): + Whether to attempt to use cuda for GPU computation. + If this is False or CUDA is not available, CPU will be used. + Defaults to ``True``. """ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4, discriminator_decay=0, batch_size=500, discriminator_steps=1, log_frequency=True, - verbose=False, epochs=300): + verbose=False, epochs=300, pac=10, cuda=True): assert batch_size % 2 == 0 @@ -145,8 +152,21 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di self._log_frequency = log_frequency self._verbose = verbose self._epochs = epochs - self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.trained_epochs = 0 + self.pac = pac + + if not cuda or not torch.cuda.is_available(): + device = 'cpu' + elif isinstance(cuda, str): + device = cuda + else: + device = 'cuda' + + self._device = torch.device(device) + + self._transformer = None + self._data_sampler = None + self._generator = None @staticmethod def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): @@ -289,18 +309,19 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None): data_dim ).to(self._device) - self._discriminator = Discriminator( + discriminator = Discriminator( data_dim + self._data_sampler.dim_cond_vec(), - self._discriminator_dim + self._discriminator_dim, + pac=self.pac ).to(self._device) - self._optimizerG = optim.Adam( + optimizerG = optim.Adam( self._generator.parameters(), lr=self._generator_lr, betas=(0.5, 0.9), weight_decay=self._generator_decay ) - self._optimizerD = optim.Adam( - self._discriminator.parameters(), lr=self._discriminator_lr, + optimizerD = optim.Adam( + discriminator.parameters(), lr=self._discriminator_lr, betas=(0.5, 0.9), weight_decay=self._discriminator_decay ) @@ -343,17 +364,17 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None): real_cat = real fake_cat = fake - y_fake = self._discriminator(fake_cat) - y_real = self._discriminator(real_cat) + y_fake = discriminator(fake_cat) + y_real = discriminator(real_cat) - pen = self._discriminator.calc_gradient_penalty( - real_cat, fake_cat, self._device) + pen = discriminator.calc_gradient_penalty( + real_cat, fake_cat, self._device, self.pac) loss_d = -(torch.mean(y_real) - torch.mean(y_fake)) - self._optimizerD.zero_grad() + optimizerD.zero_grad() pen.backward(retain_graph=True) loss_d.backward() - self._optimizerD.step() + optimizerD.step() fakez = torch.normal(mean=mean, std=std) condvec = self._data_sampler.sample_condvec(self._batch_size) @@ -370,9 +391,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None): fakeact = self._apply_activate(fake) if c1 is not None: - y_fake = self._discriminator(torch.cat([fakeact, c1], dim=1)) + y_fake = discriminator(torch.cat([fakeact, c1], dim=1)) else: - y_fake = self._discriminator(fakeact) + y_fake = discriminator(fakeact) if condvec is None: cross_entropy = 0 @@ -381,9 +402,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None): loss_g = -torch.mean(y_fake) + cross_entropy - self._optimizerG.zero_grad() + optimizerG.zero_grad() loss_g.backward() - self._optimizerG.step() + optimizerG.step() if self._verbose: print(f"Epoch {i+1}, Loss G: {loss_g.detach().cpu(): .4f}," @@ -444,7 +465,5 @@ def sample(self, n, condition_column=None, condition_value=None): def set_device(self, device): self._device = device - if hasattr(self, '_generator'): + if self._generator is not None: self._generator.to(self._device) - if hasattr(self, '_discriminator'): - self._discriminator.to(self._device) diff --git a/ctgan/synthesizers/tvae.py b/ctgan/synthesizers/tvae.py index fee2a47e..7b0a15a5 100644 --- a/ctgan/synthesizers/tvae.py +++ b/ctgan/synthesizers/tvae.py @@ -82,7 +82,9 @@ def __init__( decompress_dims=(128, 128), l2scale=1e-5, batch_size=500, - epochs=300 + epochs=300, + loss_factor=2, + cuda=True ): self.embedding_dim = embedding_dim @@ -91,10 +93,17 @@ def __init__( self.l2scale = l2scale self.batch_size = batch_size - self.loss_factor = 2 + self.loss_factor = loss_factor self.epochs = epochs - self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + if not cuda or not torch.cuda.is_available(): + device = 'cpu' + elif isinstance(cuda, str): + device = cuda + else: + device = 'cuda' + + self._device = torch.device(device) def fit(self, train_data, discrete_columns=tuple()): self.transformer = DataTransformer()