In [32]:
import warnings
from collections import Counter
warnings.filterwarnings("ignore")

import pickle
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from ctgan import CTGANSynthesizer
from sklearn.model_selection import train_test_split
from utils import *

# --- Experiment configuration ------------------------------------------------
# Dataset label; reused when naming the saved model file and the loss plot.
dataset = 'adult'

# Directory for input data and directory for persisted model artefacts.
DATA_PATH = './data/'
MODELS_PATH = './models'

# One seed shared by the train/test split and the random forest (passed as
# `random_state`) and applied to NumPy's legacy global RNG.
# NOTE(review): only NumPy's global RNG is seeded here; any other RNG used by
# the synthesizer is not seeded in this notebook -- confirm that is intended.
seed = 1
np.random.seed(seed)

In [33]:
from ctgan import load_demo
# Fetch the demo table through ctgan's `load_demo` helper.
# NOTE(review): presumably the Adult census data (matches `dataset = 'adult'`
# in the config cell) -- confirm against the installed ctgan version.
data = load_demo()

# Names of the columns treated as categorical. The list is handed to
# `get_preprocessor` and `get_noise_features` in later cells, so the
# order is kept exactly as written.
categorical_features = [
    "workclass", "education", "marital-status", "occupation",
    "relationship", "race", "sex", "native-country",
]


In [34]:
# Every column except the last is a feature; the last column is the target.
X = data.iloc[:, :-1]

# Encode the target labels as integers. The fitted encoder is kept in `le`.
# NOTE(review): `LabelEncoder` has no explicit import in this notebook; it is
# presumably provided by `from utils import *` -- confirm.
le = LabelEncoder()
y = le.fit_transform(data.iloc[:, -1])

In [27]:
# Build the preprocessing step from the feature table.
# NOTE(review): the preprocessor is constructed from the full `X` (before the
# split); whether `get_preprocessor` fits anything at this point is not visible
# here -- confirm it only declares the transformers.
preprocessor = get_preprocessor(X, categorical_features)

# Hold out 20% of the rows for evaluation; the fixed random_state keeps the
# split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=seed
)

# Full prediction pipeline: preprocessing followed by the random forest.
# `rf` is kept as its own name because a later cell passes it on as the
# black-box model.
rf = RandomForestClassifier(n_jobs=-1, random_state=seed)
clf = Pipeline([
    ('preprocessor', preprocessor),
    ('classifier', rf),
])

# Fit on the training portion and report the score on the held-out portion.
clf.fit(X_train, y_train)
holdout_score = clf.score(X_test, y_test)
print("model score: %.3f" % holdout_score)

model score: 0.860


In [31]:
# --- Train the synthesizer against the black-box random forest ----------------

# Noise matrix used as the synthesizer's training input: one row per 25% of the
# training rows, `z_features` columns.
# NOTE(review): `get_noise_features` / `gen_random_noise` come from `utils`;
# their exact semantics are not visible here.
z_fraction = 0.25
z_features = get_noise_features(X_train, categorical_features)
z_rows = int(z_fraction * X_train.shape[0])
z = gen_random_noise(shape=(z_rows, z_features))

# Hyper-parameters for this run.
batch_size = 500
epochs = 50
confidence_level = 0.9
gen_lr = 2e-5
loss = 'log'

# Create the output directory *before* the long training run, so a missing
# `MODELS_PATH` cannot cause the trained model to be lost at save time.
os.makedirs(MODELS_PATH, exist_ok=True)
model_path = f"{MODELS_PATH}/{dataset}_ctgan_c_{confidence_level}.pkl"

# The synthesizer receives the random forest and the preprocessing step that
# were assembled in the pipeline cell.
rf_ctgan = CTGANSynthesizer(
    batch_size=batch_size,
    blackbox_model=rf,
    preprocessing_pipeline=preprocessor,
    bb_loss=loss,
)

# Note that the synthesizer is fitted on the noise matrix `z`, not on X_train.
# The returned history is what `plot_losses` consumes below.
hist = rf_ctgan.fit(
    train_data=z,
    epochs=epochs,
    confidence_level=confidence_level,
    gen_lr=gen_lr,
)

rf_ctgan.save(model_path)

# Trailing `;` keeps the cell output limited to the figure (replaces the
# former empty `print()` that was used for the same purpose).
plot_losses(hist, title=dataset + ' loss');

Epoch 1, Loss G: 0.5080986133860981, loss_bb: 0.5080986133860981
Epoch 2, Loss G: 0.5064743379269014, loss_bb: 0.5064743379269014
Epoch 3, Loss G: 0.5133853381283394, loss_bb: 0.5133853381283394
Epoch 4, Loss G: 0.4984214786592201, loss_bb: 0.4984214786592201
Epoch 5, Loss G: 0.49789018777142435, loss_bb: 0.49789018777142435
Epoch 6, Loss G: 0.5047797800283932, loss_bb: 0.5047797800283932
Epoch 7, Loss G: 0.5128966421942207, loss_bb: 0.5128966421942207
Epoch 8, Loss G: 0.5180240170559639, loss_bb: 0.5180240170559639
Epoch 9, Loss G: 0.5059346959673963, loss_bb: 0.5059346959673963


KeyboardInterrupt: 