Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ctgan/data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def is_discrete_column(column_info):
# Prepare an interval matrix for efficiently sample conditional vector
max_category = max(
[column_info[0].dim for column_info in output_info
if is_discrete_column(column_info)])
if is_discrete_column(column_info)], default=0)

self._discrete_column_cond_st = np.zeros(n_discrete_columns, dtype='int32')
self._discrete_column_n_category = np.zeros(
Expand Down Expand Up @@ -133,7 +133,7 @@ def sample_data(self, n, col, opt):
n rows of matrix data.
"""
if col is None:
idx = np.random.randint(len(self._data), n)
idx = np.random.randint(len(self._data), size=n)
return self._data[idx]

idx = []
Expand Down
15 changes: 15 additions & 0 deletions tests/integration/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@
from ctgan.synthesizers.ctgan import CTGANSynthesizer


def test_ctgan_no_categoricals():
data = pd.DataFrame({
'continuous': np.random.random(1000)
})

ctgan = CTGANSynthesizer(epochs=1)
ctgan.fit(data, [])

sampled = ctgan.sample(100)

assert sampled.shape == (100, 1)
assert isinstance(sampled, pd.DataFrame)
assert set(sampled.columns) == {'continuous'}


def test_ctgan_dataframe():
data = pd.DataFrame({
'continuous': np.random.random(100),
Expand Down