Skip to content

Commit

Permalink
feat(adult): Add WGANGP example with adult census dataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
fabclmnt committed Dec 28, 2020
1 parent 6e02330 commit c97931d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
27 changes: 27 additions & 0 deletions examples/adult_wgangp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from ydata_synthetic.preprocessing.adult import transformations
from ydata_synthetic.synthesizers import WGAN_GP

#Load and process the data
data, processed_data, preprocessor = transformations()

# WGAN_GP training
#Defininf the training parameters of WGAN_GP

noise_dim = 32
dim = 128
batch_size = 128

log_step = 100
epochs = 200+1
learning_rate = 5e-4
beta_1 = 0.5
beta_2 = 0.9
models_dir = './cache'

gan_args = [batch_size, learning_rate, beta_1, beta_2, noise_dim, processed_data.shape[1], dim]
train_args = ['', epochs, log_step]

synthesizer = WGAN_GP(gan_args, n_critic=2)
synthesizer.train(processed_data, train_args)

synth_data = synthesizer.sample(1000)
11 changes: 4 additions & 7 deletions src/ydata_synthetic/preprocessing/adult.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@

from pmlb import fetch_data

def transformations(auto=True):
if auto:
data = fetch_data('adult')
else:
data = fetch_data('adult')
def transformations():
data = fetch_data('adult')

numerical_features = ['age', 'fnlwgt',
'capital-gain', 'capital-loss',
Expand All @@ -29,8 +26,8 @@ def transformations(auto=True):
('num', numerical_transformer, numerical_features),
('cat', categorical_transformer, categorical_features)])

processed_data = preprocessor.fit_transform(data)
processed_data = pd.DataFrame.sparse.from_spmatrix(preprocessor.fit_transform(processed_data))
processed_data = pd.DataFrame.sparse.from_spmatrix(preprocessor.fit_transform(data))

return data, processed_data, preprocessor


Expand Down

0 comments on commit c97931d

Please sign in to comment.