diff --git a/sdv/tabular/ctgan.py b/sdv/tabular/ctgan.py index f303ed9a6..2c5db4c7a 100644 --- a/sdv/tabular/ctgan.py +++ b/sdv/tabular/ctgan.py @@ -19,8 +19,6 @@ class CTGANModel(BaseTabularModel): 'O': 'label_encoding' } - _cuda = True - def _build_model(self): return self._MODEL_CLASS(**self._model_kwargs) @@ -33,16 +31,6 @@ def _fit(self, table_data): """ self._model = self._build_model() - import torch - if not self._cuda or not torch.cuda.is_available(): - device = 'cpu' - elif isinstance(self._cuda, str): - device = self._cuda - else: - device = 'cuda' - - self._model.device = torch.device(device) - categoricals = [] fields_before_transform = self._metadata.get_fields() for field in table_data.columns: @@ -157,6 +145,9 @@ class CTGAN(CTGANModel): Whether to have print statements for progress results. Defaults to ``False``. epochs (int): Number of training epochs. Defaults to 300. + pac (int): + Number of samples to group together when applying the discriminator. + Defaults to 10. cuda (bool or str): If ``True``, use CUDA. If a ``str``, use the indicated device. If ``False``, do not use cuda at all. @@ -169,7 +160,7 @@ def __init__(self, field_names=None, field_types=None, field_transformers=None, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4, discriminator_decay=0, batch_size=500, discriminator_steps=1, - log_frequency=True, verbose=False, epochs=300, cuda=True): + log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True): super().__init__( field_names=field_names, primary_key=primary_key, @@ -192,11 +183,11 @@ def __init__(self, field_names=None, field_types=None, field_transformers=None, 'discriminator_steps': discriminator_steps, 'log_frequency': log_frequency, 'verbose': verbose, - 'epochs': epochs + 'epochs': epochs, + 'pac': pac, + 'cuda': cuda } - self._cuda = cuda - class TVAE(CTGANModel): """Model wrapping ``TVAESynthesizer`` model. @@ -250,6 +241,8 @@ class TVAE(CTGANModel): Number of data samples to process in each step. epochs (int): Number of training epochs. Defaults to 300. + loss_factor (int): + Multiplier for the reconstruction error. Defaults to 2. cuda (bool or str): If ``True``, use CUDA. If a ``str``, use the indicated device. If ``False``, do not use cuda at all. @@ -260,7 +253,7 @@ class TVAE(CTGANModel): def __init__(self, field_names=None, field_types=None, field_transformers=None, anonymize_fields=None, primary_key=None, constraints=None, table_metadata=None, embedding_dim=128, compress_dims=(128, 128), decompress_dims=(128, 128), - l2scale=1e-5, batch_size=500, epochs=300, cuda=True): + l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True): super().__init__( field_names=field_names, primary_key=primary_key, @@ -277,7 +270,7 @@ def __init__(self, field_names=None, field_types=None, field_transformers=None, 'decompress_dims': decompress_dims, 'l2scale': l2scale, 'batch_size': batch_size, - 'epochs': epochs + 'epochs': epochs, + 'loss_factor': loss_factor, + 'cuda': cuda } - - self._cuda = cuda diff --git a/setup.py b/setup.py index 1c8cd71b9..a29409d31 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ 'torch>=1.4,<2', 'tqdm>=4.14,<5', 'copulas>=0.5.0,<0.6', - 'ctgan>=0.4.0,<0.5', + 'ctgan>=0.4.1.dev0,<0.5', 'deepecho>=0.1.4,<0.2', 'rdt>=0.4.0,<0.5', 'sdmetrics>=0.2.0,<0.3',