Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions ctgan/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,15 @@ class CTGANSynthesizer(object):
"""

def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
l2scale=1e-6, batch_size=500):
l2scale=1e-6, batch_size=500, log_frequency=True):

self.embedding_dim = embedding_dim
self.gen_dim = gen_dim
self.dis_dim = dis_dim

self.l2scale = l2scale
self.batch_size = batch_size
self.log_frequency = log_frequency
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.trained_epoches = 0

Expand All @@ -63,6 +64,9 @@ def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
but will be differentiated as if it is the soft sample in autograd
dim (int):
a dimension along which softmax will be computed. Default: -1.
log_frequency (boolean):
Whether to use log frequency of categorical levels in conditional
sampling. Defaults to ``True``.

Returns:
Sampled tensor of same shape as logits from the Gumbel-Softmax distribution.
Expand Down Expand Up @@ -130,7 +134,7 @@ def _cond_loss(self, data, c, m):

return (loss * m).sum() / data.size()[0]

def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=True):
def fit(self, train_data, discrete_columns=tuple(), epochs=300):
"""Fit the CTGAN Synthesizer models to the training data.

Args:
Expand All @@ -144,9 +148,6 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=Tr
a ``pandas.DataFrame``, this list should contain the column names.
epochs (int):
Number of training epochs. Defaults to 300.
log_frequency (boolean):
Whether to use log frequency of categorical levels in conditional
sampling. Defaults to ``True``.
"""

if not hasattr(self, "transformer"):
Expand All @@ -162,7 +163,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=Tr
self.cond_generator = ConditionalGenerator(
train_data,
self.transformer.output_info,
log_frequency
self.log_frequency
)

if not hasattr(self, "generator"):
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def test_log_frequency():
counts = sampled['discrete'].value_counts()
assert counts['a'] < 6500

ctgan = CTGANSynthesizer()
ctgan.fit(data, discrete_columns, epochs=100, log_frequency=False)
ctgan = CTGANSynthesizer(log_frequency=False)
ctgan.fit(data, discrete_columns, epochs=100)

sampled = ctgan.sample(10000)
counts = sampled['discrete'].value_counts()
Expand Down