diff --git a/ctgan/synthesizer.py b/ctgan/synthesizer.py
index 93a94520..a57cc2bf 100644
--- a/ctgan/synthesizer.py
+++ b/ctgan/synthesizer.py
@@ -35,7 +35,10 @@ class CTGANSynthesizer(object):
+        log_frequency (boolean):
+            Whether to use log frequency of categorical levels in conditional
+            sampling. Defaults to ``True``.
     """
 
     def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
-                 l2scale=1e-6, batch_size=500):
+                 l2scale=1e-6, batch_size=500, log_frequency=True):
 
         self.embedding_dim = embedding_dim
         self.gen_dim = gen_dim
@@ -43,6 +46,7 @@ def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
         self.l2scale = l2scale
         self.batch_size = batch_size
+        self.log_frequency = log_frequency
 
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         self.trained_epoches = 0
@@ -130,7 +134,7 @@ def _cond_loss(self, data, c, m):
 
         return (loss * m).sum() / data.size()[0]
 
-    def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=True):
+    def fit(self, train_data, discrete_columns=tuple(), epochs=300):
         """Fit the CTGAN Synthesizer models to the training data.
 
         Args:
@@ -144,9 +148,6 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=Tr
                 a ``pandas.DataFrame``, this list should contain the column names.
             epochs (int):
                 Number of training epochs. Defaults to 300.
-            log_frequency (boolean):
-                Whether to use log frequency of categorical levels in conditional
-                sampling. Defaults to ``True``.
         """
 
         if not hasattr(self, "transformer"):
@@ -162,7 +163,7 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=Tr
             self.cond_generator = ConditionalGenerator(
                 train_data,
                 self.transformer.output_info,
-                log_frequency
+                self.log_frequency
             )
 
         if not hasattr(self, "generator"):
diff --git a/tests/integration/test_ctgan.py b/tests/integration/test_ctgan.py
index 263f5eb9..07308c6f 100644
--- a/tests/integration/test_ctgan.py
+++ b/tests/integration/test_ctgan.py
@@ -66,8 +66,8 @@ def test_log_frequency():
     counts = sampled['discrete'].value_counts()
     assert counts['a'] < 6500
 
-    ctgan = CTGANSynthesizer()
-    ctgan.fit(data, discrete_columns, epochs=100, log_frequency=False)
+    ctgan = CTGANSynthesizer(log_frequency=False)
+    ctgan.fit(data, discrete_columns, epochs=100)
 
     sampled = ctgan.sample(10000)
     counts = sampled['discrete'].value_counts()