In [1]:
!pip install ydata-synthetic
import pandas as pd
import numpy as np
from sklearn.preprocessing import PowerTransformer
from ydata_synthetic.synthesizers.regular import WGAN_GP
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
from tensorflow.random import uniform
from tensorflow.dtypes import float32



In [3]:
# Column roles later passed to the synthesizer's train() call:
# one categorical column and eight numeric ones.
cat_cols = ['Outcome']
num_cols = [
    'Pregnancies',
    'Glucose',
    'BloodPressure',
    'SkinThickness',
    'Insulin',
    'BMI',
    'DiabetesPedigreeFunction',
    'Age',
]

# Read the dataset from the working directory and echo it for inspection.
data = pd.read_csv('./data.csv')
print(data)

     Pregnancies  Glucose  ...  Age  Outcome
0              6      148  ...   50        1
1              1       85  ...   31        0
2              8      183  ...   32        1
3              1       89  ...   21        0
4              0      137  ...   33        1
..           ...      ...  ...  ...      ...
763           10      101  ...   63        0
764            2      122  ...   27        0
765            5      121  ...   30        0
766            1      126  ...   47        1
767            1       93  ...   23        0

[768 rows x 9 columns]


In [4]:
# Define the GAN and training parameters
noise_dim = 128  # forwarded to ModelParameters as noise_dim
dim = 128        # forwarded to ModelParameters as layers_dim
batch_size = 10

log_step = 100   # forwarded to TrainParameters as sample_interval
epochs = 50
# Two learning rates are supplied; presumably one per network
# (generator / critic) -- TODO confirm the expected order.
learning_rate = [5e-4, 3e-3]
beta_1 = 0.5
beta_2 = 0.9

gan_args = ModelParameters(batch_size=batch_size,
                           lr=learning_rate,
                           betas=(beta_1, beta_2),
                           noise_dim=noise_dim,
                           layers_dim=dim)

train_args = TrainParameters(epochs=epochs,
                             sample_interval=log_step)

n_critic = 3

# Number of synthetic rows requested from the synthesizer in the sampling cell.
sample_size = 50000000

# NOTE: this cell previously built
#     noise = uniform([sample_size, noise_dim], dtype=float32)
# That tensor was never passed to the synthesizer (sampling is invoked with a
# row count only), yet 50,000,000 x 128 float32 values need roughly 25.6 GB of
# memory. The allocation has been removed.

# Training the GAN model
model = WGAN_GP
synthesizer = model(gan_args, n_critic)
synthesizer.train(data, train_args, num_cols, cat_cols)

  2%|▏         | 1/50 [00:02<02:01,  2.49s/it]

Epoch: 0 | disc_loss: 0.2902154326438904 | gen_loss: -0.16385281085968018


  4%|▍         | 2/50 [00:02<01:02,  1.29s/it]

Epoch: 1 | disc_loss: -0.04795543849468231 | gen_loss: -0.10544504225254059


  6%|▌         | 3/50 [00:03<00:42,  1.10it/s]

Epoch: 2 | disc_loss: -0.02909155935049057 | gen_loss: 0.10356004536151886


  8%|▊         | 4/50 [00:03<00:33,  1.38it/s]

Epoch: 3 | disc_loss: 0.1073603704571724 | gen_loss: -0.03828207775950432


 10%|█         | 5/50 [00:04<00:28,  1.60it/s]

Epoch: 4 | disc_loss: -0.14477434754371643 | gen_loss: 0.05113510414958


 12%|█▏        | 6/50 [00:04<00:24,  1.78it/s]

Epoch: 5 | disc_loss: -0.1609516441822052 | gen_loss: -0.03536919131875038


 14%|█▍        | 7/50 [00:05<00:22,  1.91it/s]

Epoch: 6 | disc_loss: -0.08231465518474579 | gen_loss: -0.07890507578849792


 16%|█▌        | 8/50 [00:05<00:20,  2.02it/s]

Epoch: 7 | disc_loss: -0.10179103165864944 | gen_loss: -0.04180557280778885


 18%|█▊        | 9/50 [00:06<00:19,  2.09it/s]

Epoch: 8 | disc_loss: -0.10420244187116623 | gen_loss: -0.0161609910428524


 20%|██        | 10/50 [00:06<00:18,  2.14it/s]

Epoch: 9 | disc_loss: -0.15892794728279114 | gen_loss: 0.06333713233470917


 22%|██▏       | 11/50 [00:06<00:17,  2.18it/s]

Epoch: 10 | disc_loss: -0.09089508652687073 | gen_loss: 0.006445896811783314


 24%|██▍       | 12/50 [00:07<00:17,  2.20it/s]

Epoch: 11 | disc_loss: -0.2597293555736542 | gen_loss: 0.0915500670671463


 26%|██▌       | 13/50 [00:07<00:16,  2.22it/s]

Epoch: 12 | disc_loss: 0.03570307046175003 | gen_loss: -0.0466182641685009


 28%|██▊       | 14/50 [00:08<00:16,  2.23it/s]

Epoch: 13 | disc_loss: -0.2024146020412445 | gen_loss: 0.015064539387822151


 30%|███       | 15/50 [00:08<00:15,  2.24it/s]

Epoch: 14 | disc_loss: -0.1622469574213028 | gen_loss: 0.08596664667129517


 32%|███▏      | 16/50 [00:09<00:15,  2.25it/s]

Epoch: 15 | disc_loss: -0.08859257400035858 | gen_loss: -0.049301933497190475


 34%|███▍      | 17/50 [00:09<00:14,  2.25it/s]

Epoch: 16 | disc_loss: -0.059393394738435745 | gen_loss: 0.00346413254737854


 36%|███▌      | 18/50 [00:10<00:14,  2.25it/s]

Epoch: 17 | disc_loss: 0.007794640958309174 | gen_loss: -0.048375438898801804


 38%|███▊      | 19/50 [00:10<00:13,  2.26it/s]

Epoch: 18 | disc_loss: -0.07433837652206421 | gen_loss: 0.02506088651716709


 40%|████      | 20/50 [00:10<00:13,  2.26it/s]

Epoch: 19 | disc_loss: -0.06206878274679184 | gen_loss: 0.06690545380115509


 42%|████▏     | 21/50 [00:11<00:12,  2.25it/s]

Epoch: 20 | disc_loss: -0.0014637894928455353 | gen_loss: -0.06982740759849548


 44%|████▍     | 22/50 [00:11<00:12,  2.24it/s]

Epoch: 21 | disc_loss: -0.0940803587436676 | gen_loss: -0.009809223003685474


 46%|████▌     | 23/50 [00:12<00:11,  2.25it/s]

Epoch: 22 | disc_loss: -0.13067229092121124 | gen_loss: 0.008552921935915947


 48%|████▊     | 24/50 [00:12<00:11,  2.26it/s]

Epoch: 23 | disc_loss: -0.028910629451274872 | gen_loss: -0.004777704365551472


 50%|█████     | 25/50 [00:13<00:11,  2.25it/s]

Epoch: 24 | disc_loss: -0.18213339149951935 | gen_loss: 0.10452502965927124


 52%|█████▏    | 26/50 [00:13<00:10,  2.27it/s]

Epoch: 25 | disc_loss: -0.08740465342998505 | gen_loss: -0.045747481286525726


 54%|█████▍    | 27/50 [00:14<00:10,  2.27it/s]

Epoch: 26 | disc_loss: -0.0850708931684494 | gen_loss: 0.05970175191760063


 56%|█████▌    | 28/50 [00:14<00:09,  2.27it/s]

Epoch: 27 | disc_loss: -0.18615059554576874 | gen_loss: 0.023173833265900612


 58%|█████▊    | 29/50 [00:14<00:09,  2.28it/s]

Epoch: 28 | disc_loss: -0.1403001844882965 | gen_loss: 0.07088662683963776


 60%|██████    | 30/50 [00:15<00:08,  2.26it/s]

Epoch: 29 | disc_loss: -0.1419713795185089 | gen_loss: 0.07942787557840347


 62%|██████▏   | 31/50 [00:15<00:08,  2.26it/s]

Epoch: 30 | disc_loss: -0.12709267437458038 | gen_loss: -0.057902365922927856


 64%|██████▍   | 32/50 [00:16<00:08,  2.24it/s]

Epoch: 31 | disc_loss: -0.039059899747371674 | gen_loss: -0.17425687611103058


 66%|██████▌   | 33/50 [00:16<00:07,  2.23it/s]

Epoch: 32 | disc_loss: -0.02997516840696335 | gen_loss: 0.03412641957402229


 68%|██████▊   | 34/50 [00:17<00:07,  2.22it/s]

Epoch: 33 | disc_loss: -0.2036609947681427 | gen_loss: 0.04964857175946236


 70%|███████   | 35/50 [00:17<00:06,  2.22it/s]

Epoch: 34 | disc_loss: -0.1986587643623352 | gen_loss: 0.13818714022636414


 72%|███████▏  | 36/50 [00:18<00:06,  2.23it/s]

Epoch: 35 | disc_loss: -0.17901469767093658 | gen_loss: 0.169156014919281


 74%|███████▍  | 37/50 [00:18<00:05,  2.24it/s]

Epoch: 36 | disc_loss: -0.16044919192790985 | gen_loss: 0.11791130155324936


 76%|███████▌  | 38/50 [00:18<00:05,  2.24it/s]

Epoch: 37 | disc_loss: -0.16277945041656494 | gen_loss: 0.21605190634727478


 78%|███████▊  | 39/50 [00:19<00:04,  2.23it/s]

Epoch: 38 | disc_loss: -0.0993216335773468 | gen_loss: 0.14175589382648468


 80%|████████  | 40/50 [00:19<00:04,  2.25it/s]

Epoch: 39 | disc_loss: -0.20962314307689667 | gen_loss: 0.07760071754455566


 82%|████████▏ | 41/50 [00:20<00:03,  2.26it/s]

Epoch: 40 | disc_loss: -0.06500864773988724 | gen_loss: 0.03179243952035904


 84%|████████▍ | 42/50 [00:20<00:03,  2.26it/s]

Epoch: 41 | disc_loss: -0.09191881865262985 | gen_loss: 0.05251043289899826


 86%|████████▌ | 43/50 [00:21<00:03,  2.25it/s]

Epoch: 42 | disc_loss: -0.06766289472579956 | gen_loss: 0.04613421857357025


 88%|████████▊ | 44/50 [00:21<00:02,  2.24it/s]

Epoch: 43 | disc_loss: -0.04317223280668259 | gen_loss: 0.23048171401023865


 90%|█████████ | 45/50 [00:22<00:02,  2.24it/s]

Epoch: 44 | disc_loss: -0.019317686557769775 | gen_loss: 0.061155933886766434


 92%|█████████▏| 46/50 [00:22<00:01,  2.25it/s]

Epoch: 45 | disc_loss: -0.1357932835817337 | gen_loss: 0.012329155579209328


 94%|█████████▍| 47/50 [00:22<00:01,  2.26it/s]

Epoch: 46 | disc_loss: -0.1169242262840271 | gen_loss: 0.015503637492656708


 96%|█████████▌| 48/50 [00:23<00:00,  2.26it/s]

Epoch: 47 | disc_loss: -0.07670173048973083 | gen_loss: 0.1705242544412613


 98%|█████████▊| 49/50 [00:23<00:00,  2.26it/s]

Epoch: 48 | disc_loss: -0.13615919649600983 | gen_loss: 0.05437643080949783


100%|██████████| 50/50 [00:24<00:00,  2.06it/s]

Epoch: 49 | disc_loss: -0.16945521533489227 | gen_loss: 0.21999402344226837





In [5]:
# Draw synthetic rows from the trained synthesizer, then keep at most
# `sample_size` of them (the generator appears to work in whole batches, so
# the raw result can be slightly longer than requested -- TODO confirm).
gs_samples = synthesizer.sample(sample_size)
gs_samples = gs_samples[:sample_size]
print(gs_samples)

# Persist the synthetic dataset without the row index.
gs_samples.to_csv('generated.csv', index=False)

Synthetic data generation: 100%|██████████| 5000001/5000001 [3:49:04<00:00, 363.77it/s]


          Pregnancies  Glucose  ...  Age  Outcome
0                   1       81  ...   25        1
1                   4       97  ...   33        0
2                   3      133  ...   32        1
3                   3      106  ...   34        0
4                   2       95  ...   27        1
...               ...      ...  ...  ...      ...
49999995            4      130  ...   37        0
49999996            2       94  ...   36        1
49999997            3      157  ...   23        1
49999998            3      119  ...   30        1
49999999            3      114  ...   30        1

[50000000 rows x 9 columns]
