In [1]:
import pandas as pd
import numpy as np
from hyperopt import hp, fmin, tpe, STATUS_OK, Trials
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import roc_curve, roc_auc_score
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.single_table import CTGANSynthesizer
from sdv.single_table import TVAESynthesizer
from sdv.single_table import CopulaGANSynthesizer
from sdv.metadata import SingleTableMetadata
from sklearn.tree import DecisionTreeClassifier
import xgboost as xgb
import time
import utilities

## Load data
Load data and create train test split from the smaller dataset that contains 10% of the full data

In [2]:
df = pd.read_csv("../data/adult.csv")
df.loc[df["income"] == "<=50K", "income"] = 0
df.loc[df["income"] == ">50K", "income"] = 1
df, df_te = train_test_split(df, test_size = 0.1, random_state = 5)
df_te.to_csv("../data/small_adult.csv", index=False)

df.head()
cat_col = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'gender', 'native-country']
params_xgb = {
        'eval_metric': 'auc'
}


In [3]:
target = 'income'
target_encoder = utilities.MultiColumnTargetEncoder(cat_col, target)

In [4]:
df_original = pd.read_csv("../data/small_adult.csv")
df_original.replace('?', np.NaN,inplace=True)
df_original.dropna(axis=0,how='any',inplace=True)

In [5]:
df = df_original.copy()

In [6]:
df_train, df_test = train_test_split(df, test_size = 0.2,  random_state = 5)

df_train.to_csv("../data/train.csv", index=False)
df_test.to_csv("../data/test.csv", index=False)

df_train_modified = target_encoder.transform(df_train)
df_test_modified = target_encoder.transform_test_data(df_test)

df_train_modified.to_csv("../data/train_modified.csv", index=False)
df_test_modified.to_csv("../data/test_modified.csv", index=False)

x_train = df_train_modified.loc[:, df_train_modified.columns != target]
y_train = df_train_modified[target]

x_test = df_test_modified.loc[:, df_test_modified.columns != target]
y_test = df_test_modified[target]

In [7]:
x_train

Unnamed: 0,age,fnlwgt,educational-num,capital-gain,capital-loss,hours-per-week,workclass_target_encoded,education_target_encoded,marital-status_target_encoded,occupation_target_encoded,relationship_target_encoded,race_target_encoded,gender_target_encoded,native-country_target_encoded
3863,36,275338,13,0,0,40,0.216647,0.400943,0.463214,0.465164,0.518293,0.271972,0.116522,0.25909
1058,19,248749,10,0,0,20,0.216647,0.202156,0.041958,0.271663,0.021195,0.271972,0.318722,0.25909
2491,44,244974,13,0,0,44,0.314815,0.400943,0.463214,0.337349,0.460474,0.271972,0.318722,0.25909
2971,33,306309,14,0,0,50,0.314815,0.563452,0.041958,0.443089,0.102247,0.271972,0.318722,0.25909
793,19,28145,9,0,0,52,0.216647,0.164599,0.041958,0.044818,0.021195,0.271972,0.116522,0.25909
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3296,29,153416,13,0,0,55,0.216647,0.400943,0.041958,0.465164,0.102247,0.271972,0.116522,0.25909
1867,38,197711,6,0,0,40,0.216647,0.054348,0.104082,0.100437,0.102247,0.284211,0.116522,0.37500
4435,67,152102,9,0,0,65,0.288401,0.164599,0.090090,0.140625,0.102247,0.271972,0.318722,0.25909
2443,33,123291,9,0,0,84,0.216647,0.164599,0.463214,0.100437,0.460474,0.271972,0.318722,0.25909


In [8]:
x_test

Unnamed: 0,age,fnlwgt,educational-num,capital-gain,capital-loss,hours-per-week,workclass_target_encoded,education_target_encoded,marital-status_target_encoded,occupation_target_encoded,relationship_target_encoded,race_target_encoded,gender_target_encoded,native-country_target_encoded
445,22,178818,10,0,0,20,0.314815,0.202156,0.041958,0.443089,0.021195,0.271972,0.116522,0.25909
4271,62,173601,13,0,0,40,0.216647,0.400943,0.463214,0.100437,0.460474,0.271972,0.318722,0.25909
2504,23,134446,9,0,0,54,0.216647,0.164599,0.090909,0.100437,0.054688,0.101639,0.318722,0.25909
4750,46,96652,11,0,0,40,0.314815,0.277027,0.090909,0.137300,0.054688,0.101639,0.116522,0.25909
4378,62,174711,9,0,0,32,0.216647,0.164599,0.041958,0.137300,0.102247,0.271972,0.116522,0.25909
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3720,36,49626,9,0,0,40,0.216647,0.164599,0.463214,0.465164,0.460474,0.271972,0.318722,0.25909
1999,20,218962,9,0,0,40,0.216647,0.164599,0.041958,0.100437,0.102247,0.271972,0.318722,0.25909
2375,33,183778,12,0,0,40,0.288401,0.330357,0.104082,0.271663,0.102247,0.271972,0.318722,0.25909
132,53,297796,6,0,0,40,0.288401,0.054348,0.463214,0.140625,0.460474,0.271972,0.318722,0.25909


In [22]:
# df = df_original.copy()
# df_modified = target_encoder.transform(df)

# for col in cat_col:
#     df[col] = df[col].astype('category')
# df, df_te = train_test_split(df, test_size = 0.2,  random_state = 5)
# df.to_csv("../data/train.csv", index=False)
# df_te.to_csv("../data/test.csv", index=False)
# target = 'income'

# x_train = df.loc[:, df.columns != target]
# y_train = df[target]

# x_test = df_te.loc[:, df_te.columns != target]
# y_test = df_te[target]

## Create Supervised Synthesizers

In [22]:
params_range = {
            'method': "CTGAN",
            'epochs':  1000,
            'batch_size':  hp.randint('batch_size',1, 5), # multiple of 100
            'g_dim1':  hp.randint('g_dim1',1, 3), # multiple of 128
            'g_dim2':  hp.randint('g_dim2',1, 3), # multiple of 128
            'g_dim3':  hp.randint('g_dim3',0, 3), # multiple of 128
            'd_dim1':  hp.randint('d_dim1',1, 3), # multiple of 128
            'd_dim2':  hp.randint('d_dim2',1, 3), # multiple of 128
            'd_dim3':  hp.randint('d_dim3',0, 3), # multiple of 128
           } 

def fit_synth(df, params):
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(data=df)
    method = params['method']
    if method == "GaussianCopula":
        synth = GaussianCopulaSynthesizer(metadata=metadata)
    elif method == "CTGAN" or method =="CopulaGAN":
        epoch = params['epochs']
        batch_size = params['batch_size']*100
        if params["g_dim3"] != 0:
            generator_dim = (128*params['g_dim1'], 128*params['g_dim2'], 128*params['g_dim3'])
        else:
            generator_dim = (128*params['g_dim1'], 128*params['g_dim2'])
        if params["d_dim3"] != 0:
            discriminator_dim = (128*params['d_dim1'], 128*params['d_dim2'], 128*params['d_dim3'])
        else:
            discriminator_dim = (128*params['d_dim1'], 128*params['d_dim2'])
        discriminator_lr = params['d_lr']
        generator_lr = params['g_lr']
        if method == "CTGAN":
            synth = CTGANSynthesizer(metadata=metadata, epochs=epoch, batch_size=batch_size, generator_dim=generator_dim, 
                                     discriminator_dim=discriminator_dim, generator_lr=generator_lr, 
                                     discriminator_lr=discriminator_lr)
        if method == "CopulaGAN":
            synth = CopulaGANSynthesizer(metadata=metadata, epochs=epoch, batch_size=batch_size, generator_dim=generator_dim,
                                         discriminator_dim=discriminator_dim, generator_lr=generator_lr,
                                         discriminator_lr=discriminator_lr)
    elif method == "TVAE":
        epoch = params['epochs']
        batch_size = params['batch_size']*100
        if params["c_dim3"] != 0:
            compress_dims = (64*params['c_dim1'], 64*params['c_dim2'], 64*params['c_dim3'])
        else:
            compress_dims = (64*params['c_dim1'], 64*params['c_dim2'])
        if params["d_dim3"] != 0:
            decompress_dims = (64*params['d_dim1'], 64*params['d_dim2'], 64*params['d_dim3'])
        else:
            decompress_dims = (64*params['d_dim1'], 64*params['d_dim2'])
        synth = TVAESynthesizer(metadata=metadata, epochs=epoch, batch_size=batch_size, compress_dims=compress_dims, 
                                 decompress_dims=decompress_dims)
    else:
        raise ValueError("Invalid model name: " + method)
    return synth

def downstream_loss(sampled, df_te, target, classifier):
    x_samp = sampled.loc[:, sampled.columns != target]
    y_samp = sampled[target]
    x_test = df_te.loc[:, sampled.columns != target]
    y_test = df_te[target]
    print("------------------------")
    print(x_samp)
    print(x_test)
    if classifier == "XGB":
        for column in x_samp.columns:
            if x_samp[column].dtype == 'object':
                x_samp[column] = x_train[column].astype('category')
                x_test[column] = x_test[column].astype('category')
        print("++++++++++++++++")
        print(x_samp)
        print(x_test)
        dtrain = xgb.DMatrix(data=x_samp, label=y_samp, enable_categorical=True)
        dtest = xgb.DMatrix(data=x_test, label=y_test, enable_categorical=True)
        print("+++++++++++++++++")
        clf = xgb.train(params_xgb, dtrain, 1000, verbose_eval=False)
        clf_probs = clf.predict(dtest)
        print(clf_probs)
        clf_auc = roc_auc_score(y_test.values.astype(float), clf_probs)
        return clf_auc
    else:
        raise ValueError("Invalid classifier: " + classifier)
        
    
    

In [23]:
params_range = {
    'N_sim': 10000,
    'target': 'income',
    'loss': 'ROCAUC',
    'method': 'CTGAN',
    'epochs':  1000,
    'batch_size':  hp.randint('batch_size',1, 5), # multiple of 100
    'g_dim1':  hp.randint('g_dim1',1, 3), # multiple of 128
    'g_dim2':  hp.randint('g_dim2',1, 3), # multiple of 128
    'g_dim3':  hp.randint('g_dim3',0, 3), # multiple of 128
    'd_dim1':  hp.randint('d_dim1',1, 3), # multiple of 128
    'd_dim2':  hp.randint('d_dim2',1, 3), # multiple of 128
    'd_dim3':  hp.randint('d_dim3',0, 3), # multiple of 128
    'd_lr': 2e-4, "g_lr": 2e-4
} 


In [24]:
def objective_maximize(params):
    global best_test_roc 
    global best_synth
    print("1. Reached here !!!!!!!")
    synth = fit_synth(df_train_modified, params)
    print(synth)
    print("2. Reached here !!!!!!!")
    synth.fit(df_train_modified)
    print("3. Reached here !!!!!!!")
    N_sim = params["N_sim"]
    sampled = synth.sample(num_rows = N_sim)
    print('Sample data: ', sampled)
    print("4. Reached here !!!!!!!")
    clf_auc = downstream_loss(sampled, df_test_modified, target, classifier = "XGB")

    if clf_auc > best_test_roc:
        best_test_roc = clf_auc
        best_synth = sampled
    
    return {
        'loss' : 1 - clf_auc,
        'status' : STATUS_OK,
        'eval_time ': time.time(),
        'test_roc' : clf_auc,
        }


def trainDT(max_evals:int):
    global best_test_roc
    global best_synth
    
    best_test_roc = 0
    trials = Trials()
    start = time.time()
    clf_best_param = fmin(fn=objective_maximize,
                    space=params_range,
                    max_evals=max_evals,
                   # rstate=np.random.default_rng(42),
                    algo=tpe.suggest,
                    trials=trials)
    print(clf_best_param)
    print('It takes %s minutes' % ((time.time() - start)/60))
    return best_test_roc, best_synth, clf_best_param

In [25]:
best_test_roc, best_synth, clf_best_param = trainDT(10)

1. Reached here !!!!!!!                               
<sdv.single_table.ctgan.CTGANSynthesizer object at 0x15f391390>
2. Reached here !!!!!!!                               
  0%|          | 0/10 [00:00<?, ?trial/s, best loss=?]


























3. Reached here !!!!!!!                               
Sample data:                                          
      age  fnlwgt  educational-num  capital-gain  capital-loss  \
0      53  217873               10           106             2   
1      24  180628               11             0             2   
2      48  214273                4            57             0   
3      65  109974               10             0             0   
4      37  190525                9            46             5   
...   ...     ...              ...           ...           ...   
9995   53  573656               10            28             0   
9996   22  423877                9            42             0   
9997   39  141467               13          5330             1   
9998   57  208188               11            25             0   
9999   37  112804                9            14             0   

      hours-per-week  income  workclass_target_encoded  \
0                 40       0           


























3. Reached here !!!!!!!                                                          
Sample data:                                                                     
      age  fnlwgt  educational-num  capital-gain  capital-loss  \                
0      30   65187                9          4849             2   
1      41  104456               14             0             0   
2      48  216944               10            64             0   
3      27  154641               13            75             4   
4      40  139015                9          3363             1   
...   ...     ...              ...           ...           ...   
9995   65  378189               11            35             0   
9996   61  210390                9             0             0   
9997   85  205318                9             0             1   
9998   52  459555               14            14             0   
9999   20  241172               10             0             0   

      hours-per-week  incom


























3. Reached here !!!!!!!                                                          
Sample data:                                                                     
      age  fnlwgt  educational-num  capital-gain  capital-loss  \                
0      49  551862               10            55             1   
1      42  220838               13             0             0   
2      63  357876               10            44          1881   
3      58  243624                9             0             0   
4      42  400060                7             0             0   
...   ...     ...              ...           ...           ...   
9995   50  467692               13             0             0   
9996   46  203185                9            14             0   
9997   45  115538               13             9             1   
9998   61  326139                9             0             2   
9999   22  346342                5             0             0   

      hours-per-week  incom


























3. Reached here !!!!!!!                                                          
Sample data:                                                                     
      age  fnlwgt  educational-num  capital-gain  capital-loss  \                
0      25  214404               13             0          1680   
1      46  119080               14             0             0   
2      23  124661               13             0             0   
3      45  310770               10             0             2   
4      54  204068                6             0          1996   
...   ...     ...              ...           ...           ...   
9995   35  363281                4             0             4   
9996   49  174646                9            46             0   
9997   29  189647                9            37             3   
9998   48  135862               13             0             0   
9999   26  255210               10             0             1   

      hours-per-week  incom


























3. Reached here !!!!!!!                                                          
Sample data:                                                                     
      age  fnlwgt  educational-num  capital-gain  capital-loss  \                
0      55  143403               15             1             0   
1      31  294513                7           103             0   
2      32  199103                9           100          1950   
3      22  293828                9            93             1   
4      25  287091                9             0             0   
...   ...     ...              ...           ...           ...   
9995   53  367325                9             0             1   
9996   53   50623               13            99             0   
9997   42   97568               10            29             0   
9998   45  186745               10            90             0   
9999   36  158626               10            80             0   

      hours-per-week  incom


























3. Reached here !!!!!!!                                                           
Sample data:                                                                      
      age  fnlwgt  educational-num  capital-gain  capital-loss  \                 
0      25  215708                9            22             1   
1      29  198688               14             0             1   
2      50  125976                9            61             0   
3      45   88291               12            98             0   
4      46  145807               13            24             1   
...   ...     ...              ...           ...           ...   
9995   47   56951               10             0             0   
9996   45  143867               10             0             0   
9997   35  251038                5           150             0   
9998   54  102407               10             7             1   
9999   23  232988                4             4             4   

      hours-per-week  in


























3. Reached here !!!!!!!                                                           
Sample data:                                                                      
      age  fnlwgt  educational-num  capital-gain  capital-loss  \                 
0      44  214543                9          3753             2   
1      42   67299               10             0             0   
2      45  405764                9             0             0   
3      36  290747                9            28             2   
4      40  293810                9             0          1964   
...   ...     ...              ...           ...           ...   
9995   57  378972                7             0             3   
9996   20   54564               13            46             0   
9997   44  438715               13             0             3   
9998   79  520624               11             0             3   
9999   42   41198               13            40             0   

      hours-per-week  in


























3. Reached here !!!!!!!                                                           
Sample data:                                                                      
      age  fnlwgt  educational-num  capital-gain  capital-loss  \                 
0      45  188754                9             0             0   
1      66  458295               10             0             0   
2      24  159191                9             0             4   
3      31  233701               11            32             0   
4      29  129242                9            38             0   
...   ...     ...              ...           ...           ...   
9995   32  281246                4             0             2   
9996   53  216945                9             0          1741   
9997   33  189486               10             0             0   
9998   30  131035               10            27             1   
9999   33  156608                9             0             3   

      hours-per-week  in


























3. Reached here !!!!!!!                                                           
Sample data:                                                                      
      age  fnlwgt  educational-num  capital-gain  capital-loss  \                 
0      26  488062               10             0             0   
1      32  230935                9             0             0   
2      25  150598                7            26             2   
3      65  197327               16            83             4   
4      50  187418               13             2             1   
...   ...     ...              ...           ...           ...   
9995   59  528662                9            19             5   
9996   45  184006               10             0             0   
9997   46  125720                9            23          1842   
9998   32  197194                9            63             0   
9999   44  100812                9             0             0   

      hours-per-week  in


























3. Reached here !!!!!!!                                                           
Sample data:                                                                      
      age  fnlwgt  educational-num  capital-gain  capital-loss  \                 
0      32  141244                9             0             0   
1      45  150700               14             0             3   
2      61  170242                9             0             0   
3      42   99419               10            49             4   
4      43  214185               13             0             1   
...   ...     ...              ...           ...           ...   
9995   48  257647                5             0             2   
9996   39  213399                9             0             0   
9997   51  122699               12             0             0   
9998   34  188883               10             0             2   
9999   24  295560                4            29             0   

      hours-per-week  in

In [26]:
best_test_roc

0.8437410187345021

In [27]:
best_synth

Unnamed: 0,age,fnlwgt,educational-num,capital-gain,capital-loss,hours-per-week,income,workclass_target_encoded,education_target_encoded,marital-status_target_encoded,occupation_target_encoded,relationship_target_encoded,race_target_encoded,gender_target_encoded,native-country_target_encoded
0,32,141244,9,0,0,40,0,0.217574,0.163273,0.042203,0.045764,0.035743,0.272015,0.317889,0.258805
1,45,150700,14,0,3,40,0,0.216459,0.576930,0.113043,0.454884,0.111376,0.271562,0.116522,0.259680
2,61,170242,9,0,0,64,0,0.289927,0.164506,0.463214,0.138267,0.463029,0.271437,0.318076,0.258960
3,42,99419,10,49,4,21,0,0.215932,0.203169,0.096757,0.463132,0.102456,0.271420,0.117301,0.259228
4,43,214185,13,0,1,40,1,0.335999,0.398792,0.105110,0.460847,0.462290,0.271779,0.317359,0.259037
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,48,257647,5,0,2,40,0,0.216307,0.059685,0.089488,0.058799,0.047623,0.271781,0.116522,0.077401
9996,39,213399,9,0,0,40,0,0.289980,0.164816,0.041958,0.044818,0.052550,0.272409,0.318600,0.259050
9997,51,122699,12,0,0,35,0,0.314267,0.292643,0.106349,0.457622,0.040434,0.272240,0.116522,0.259487
9998,34,188883,10,0,2,40,0,0.216428,0.206236,0.463214,0.351169,0.459488,0.271806,0.318135,0.259580


In [28]:
clf_best_param

{'batch_size': 1,
 'd_dim1': 1,
 'd_dim2': 2,
 'd_dim3': 2,
 'g_dim1': 2,
 'g_dim2': 2,
 'g_dim3': 2}