From 2e4e5354b8910c1ab4cd79798813d1ede1481eb3 Mon Sep 17 00:00:00 2001 From: Felipe Hofmann Date: Fri, 5 Mar 2021 02:05:29 -0800 Subject: [PATCH 1/7] Adds parameters --- sdv/tabular/ctgan.py | 18 ++++++++++++++---- setup.py | 2 +- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/sdv/tabular/ctgan.py b/sdv/tabular/ctgan.py index f87ac958d..bc5a871ee 100644 --- a/sdv/tabular/ctgan.py +++ b/sdv/tabular/ctgan.py @@ -143,6 +143,9 @@ class CTGAN(CTGANModel): Whether to have print statements for progress results. Defaults to ``False``. epochs (int): Number of training epochs. Defaults to 300. + pac (int): + Number of samples to group together when applying the discriminator. + Defaults to 10. cuda (bool or str): If ``True``, use CUDA. If a ``str``, use the indicated device. If ``False``, do not use cuda at all. @@ -155,7 +158,7 @@ def __init__(self, field_names=None, field_types=None, field_transformers=None, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4, discriminator_decay=0, batch_size=500, discriminator_steps=1, - log_frequency=True, verbose=False, epochs=300, cuda=True): + log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True): super().__init__( field_names=field_names, primary_key=primary_key, @@ -178,7 +181,9 @@ def __init__(self, field_names=None, field_types=None, field_transformers=None, 'discriminator_steps': discriminator_steps, 'log_frequency': log_frequency, 'verbose': verbose, - 'epochs': epochs + 'epochs': epochs, + 'pac': pac, + 'cuda': cuda } self._cuda = cuda @@ -236,6 +241,8 @@ class TVAE(CTGANModel): Number of data samples to process in each step. epochs (int): Number of training epochs. Defaults to 300. + loss_factor (int): + TODO. Defaults to 2. cuda (bool or str): If ``True``, use CUDA. If a ``str``, use the indicated device. If ``False``, do not use cuda at all. @@ -246,7 +253,7 @@ class TVAE(CTGANModel): def __init__(self, field_names=None, field_types=None, field_transformers=None, anonymize_fields=None, primary_key=None, constraints=None, table_metadata=None, embedding_dim=128, compress_dims=(128, 128), decompress_dims=(128, 128), - l2scale=1e-5, batch_size=500, epochs=300, cuda=True): + l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True): super().__init__( field_names=field_names, primary_key=primary_key, @@ -263,7 +270,10 @@ def __init__(self, field_names=None, field_types=None, field_transformers=None, 'decompress_dims': decompress_dims, 'l2scale': l2scale, 'batch_size': batch_size, - 'epochs': epochs + 'epochs': epochs, + 'loss_factor': loss_factor, + 'cuda': cuda } + #update the cuda logic self._cuda = cuda diff --git a/setup.py b/setup.py index 1c8cd71b9..b75ca4b36 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ 'torch>=1.4,<2', 'tqdm>=4.14,<5', 'copulas>=0.5.0,<0.6', - 'ctgan>=0.4.0,<0.5', + 'ctgan>=0.4.0.dev0,<0.5', 'deepecho>=0.1.4,<0.2', 'rdt>=0.4.0,<0.5', 'sdmetrics>=0.2.0,<0.3', From 038a7b4da52dcdac6a4f1bf3084fdec22a37798f Mon Sep 17 00:00:00 2001 From: Felipe Hofmann Date: Tue, 9 Mar 2021 11:50:56 -0800 Subject: [PATCH 2/7] Need to check changes --- sdv/metadata/utils.py | 2 +- sdv/tabular/ctgan.py | 2 +- setup.py | 2 +- tutorials/single_table_data/.DS_Store | Bin 0 -> 6148 bytes 4 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 tutorials/single_table_data/.DS_Store diff --git a/sdv/metadata/utils.py b/sdv/metadata/utils.py index 86021f0e9..719da62fa 100644 --- a/sdv/metadata/utils.py +++ b/sdv/metadata/utils.py @@ -1,10 +1,10 @@ """Tools to generate strings from regular expressions.""" import re -import sre_parse import string import numpy as np +import sre_parse def _literal(character, max_repeat): diff --git a/sdv/tabular/ctgan.py b/sdv/tabular/ctgan.py index bc5a871ee..d08f832f2 100644 --- a/sdv/tabular/ctgan.py +++ b/sdv/tabular/ctgan.py @@ -275,5 +275,5 @@ def __init__(self, field_names=None, field_types=None, field_transformers=None, 'cuda': cuda } - #update the cuda logic + # update the cuda logic self._cuda = cuda diff --git a/setup.py b/setup.py index b75ca4b36..a29409d31 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ 'torch>=1.4,<2', 'tqdm>=4.14,<5', 'copulas>=0.5.0,<0.6', - 'ctgan>=0.4.0.dev0,<0.5', + 'ctgan>=0.4.1.dev0,<0.5', 'deepecho>=0.1.4,<0.2', 'rdt>=0.4.0,<0.5', 'sdmetrics>=0.2.0,<0.3', diff --git a/tutorials/single_table_data/.DS_Store b/tutorials/single_table_data/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 GIT binary patch literal 6148 zcmeH~Jr2S!425mzP>H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0 Date: Tue, 9 Mar 2021 20:54:59 -0800 Subject: [PATCH 3/7] Fix cuda logic --- sdv/tabular/ctgan.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/sdv/tabular/ctgan.py b/sdv/tabular/ctgan.py index d08f832f2..95c038fe8 100644 --- a/sdv/tabular/ctgan.py +++ b/sdv/tabular/ctgan.py @@ -18,8 +18,6 @@ class CTGANModel(BaseTabularModel): 'O': 'label_encoding' } - _cuda = True - def _build_model(self): return self._MODEL_CLASS(**self._model_kwargs) @@ -32,16 +30,6 @@ def _fit(self, table_data): """ self._model = self._build_model() - import torch - if not self._cuda or not torch.cuda.is_available(): - device = 'cpu' - elif isinstance(self._cuda, str): - device = self._cuda - else: - device = 'cuda' - - self._model.device = torch.device(device) - categoricals = [ field for field, meta in self._metadata.get_fields().items() @@ -186,8 +174,6 @@ def __init__(self, field_names=None, field_types=None, field_transformers=None, 'cuda': cuda } - self._cuda = cuda - class TVAE(CTGANModel): """Model wrapping ``TVAESynthesizer`` model. @@ -274,6 +260,3 @@ def __init__(self, field_names=None, field_types=None, field_transformers=None, 'loss_factor': loss_factor, 'cuda': cuda } - - # update the cuda logic - self._cuda = cuda From 9c2aef23193f317c4d35015efee4346b7a20a900 Mon Sep 17 00:00:00 2001 From: Felipe Hofmann Date: Tue, 9 Mar 2021 20:56:28 -0800 Subject: [PATCH 4/7] Fix lint --- sdv/metadata/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdv/metadata/utils.py b/sdv/metadata/utils.py index 719da62fa..86021f0e9 100644 --- a/sdv/metadata/utils.py +++ b/sdv/metadata/utils.py @@ -1,10 +1,10 @@ """Tools to generate strings from regular expressions.""" import re +import sre_parse import string import numpy as np -import sre_parse def _literal(character, max_repeat): From 45d731eab9c8db062c23cbba7c299c09aafca2bd Mon Sep 17 00:00:00 2001 From: Felipe Hofmann Date: Tue, 9 Mar 2021 21:20:53 -0800 Subject: [PATCH 5/7] Update documentation --- sdv/tabular/ctgan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdv/tabular/ctgan.py b/sdv/tabular/ctgan.py index 95c038fe8..15495dd8b 100644 --- a/sdv/tabular/ctgan.py +++ b/sdv/tabular/ctgan.py @@ -228,7 +228,7 @@ class TVAE(CTGANModel): epochs (int): Number of training epochs. Defaults to 300. loss_factor (int): - TODO. Defaults to 2. + Multiplier for the reconstruction error. Defaults to 2. cuda (bool or str): If ``True``, use CUDA. If a ``str``, use the indicated device. If ``False``, do not use cuda at all. From 2867ed8cf64875bd41ef82e8975e820708b48738 Mon Sep 17 00:00:00 2001 From: fealho Date: Tue, 9 Mar 2021 21:26:48 -0800 Subject: [PATCH 6/7] Delete .DS_Store Remove .DS_Store --- tutorials/single_table_data/.DS_Store | Bin 6148 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tutorials/single_table_data/.DS_Store diff --git a/tutorials/single_table_data/.DS_Store b/tutorials/single_table_data/.DS_Store deleted file mode 100644 index 5008ddfcf53c02e82d7eee2e57c38e5672ef89f6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeH~Jr2S!425mzP>H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0 Date: Tue, 9 Mar 2021 21:33:52 -0800 Subject: [PATCH 7/7] Fix lint --- sdv/tabular/ctgan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdv/tabular/ctgan.py b/sdv/tabular/ctgan.py index 6d8bd31f8..2c5db4c7a 100644 --- a/sdv/tabular/ctgan.py +++ b/sdv/tabular/ctgan.py @@ -49,7 +49,7 @@ def _fit(self, table_data): if kind in ['O', 'b']: categoricals.append(field) - + self._model.fit( table_data, discrete_columns=categoricals