consider verbose as a frequence, add device print, add training time in progress results
florent-prevision committed Apr 19, 2021
1 parent 86fcd23 commit 77ff57f
Showing 2 changed files with 23 additions and 4 deletions.
16 changes: 12 additions & 4 deletions ctgan/synthesizers/ctgan.py
@@ -1,3 +1,5 @@
import datetime
import time
import warnings

import numpy as np
@@ -118,8 +120,8 @@ class CTGANSynthesizer(BaseSynthesizer):
log_frequency (boolean):
Whether to use log frequency of categorical levels in conditional
sampling. Defaults to ``True``.
verbose (boolean):
Whether to have print statements for progress results. Defaults to ``False``.
verbose (int):
Frequency of print statements for progress results. Defaults to 0 (no print).
epochs (int):
Number of training epochs. Defaults to 300.
pac (int):
@@ -134,7 +136,7 @@ class CTGANSynthesizer(BaseSynthesizer):
def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256),
generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4,
discriminator_decay=1e-6, batch_size=500, discriminator_steps=1,
log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True):
log_frequency=True, verbose=0, epochs=300, pac=10, cuda=True):

assert batch_size % 2 == 0
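With this change, verbose acts as an epoch interval rather than a boolean flag: verbose=0 keeps training silent, while verbose=n prints the device line once and then a progress line every n epochs. A minimal usage sketch, assuming the package-level CTGANSynthesizer import; the toy data and column names are illustrative, not part of this commit:

import numpy as np
import pandas as pd
from ctgan import CTGANSynthesizer

data = pd.DataFrame({
    'continuous': np.random.random(100),
    'discrete': np.random.choice(['a', 'b', 'c'], 100)
})

# verbose=10 -> "Device used: ..." printed once, then a loss/elapsed-time line every 10 epochs.
ctgan = CTGANSynthesizer(epochs=30, verbose=10)
ctgan.fit(data, discrete_columns=['discrete'])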

@@ -162,6 +164,8 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di
device = 'cuda'

self._device = torch.device(device)
if self._verbose != 0:
print(f"Device used: {self._device}")

self._transformer = None
self._data_sampler = None
@@ -328,6 +332,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
std = mean + 1

steps_per_epoch = max(len(train_data) // self._batch_size, 1)
start_time = time.time()
for i in range(epochs):
for id_ in range(steps_per_epoch):

@@ -404,9 +409,12 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
loss_g.backward()
optimizerG.step()

if self._verbose:
if self._verbose == 0:
pass
elif (i + 1) % self._verbose == 0:
print(f"Epoch {i+1}, Loss G: {loss_g.detach().cpu(): .4f},"
f"Loss D: {loss_d.detach().cpu(): .4f}",
f"{datetime.timedelta(seconds=int(time.time() - start_time))}s",
flush=True)

def sample(self, n, condition_column=None, condition_value=None):
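For reference, datetime.timedelta(seconds=int(...)) renders the elapsed wall-clock time as H:MM:SS, so the new progress line ends with something like 0:01:23s. A small standalone sketch of that formatting (the sleep is only a stand-in for training epochs):

import datetime
import time

start_time = time.time()
time.sleep(2)  # stand-in for a few training epochs
elapsed = datetime.timedelta(seconds=int(time.time() - start_time))
print(f"{elapsed}s")  # prints 0:00:02s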
11 changes: 11 additions & 0 deletions tests/integration/test_ctgan.py
@@ -184,3 +184,14 @@ def test_wrong_sampling_conditions():

with pytest.raises(ValueError):
ctgan.sample(1, 'discrete', "d")


def test_verbosity():
data = pd.DataFrame({
'continuous': np.random.random(100),
'discrete': np.random.choice(['a', 'b', 'c'], 100)
})
discrete_columns = ['discrete']

ctgan = CTGANSynthesizer(epochs=1, verbose=1)
ctgan.fit(data, discrete_columns)
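To run just this new test from a Python session (a sketch; the equivalent pytest command-line call on tests/integration/test_ctgan.py works the same way), pass -s so the per-epoch prints stay visible:

import pytest

pytest.main(['tests/integration/test_ctgan.py::test_verbosity', '-s'])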
