# The Credit Card Fraud Dataset - Synthesizing the Minority Class

This notebook presents a practical exercise that showcases how to use the YData Synthetic library together with
GANs to synthesize tabular data.
The exercise uses the credit card fraud dataset from Kaggle, which can be found here:
https://www.kaggle.com/mlg-ulb/creditcardfraud

In [1]:
# Note: You can choose to run the notebook on either "CPU" or "GPU"
# Click "Runtime > Change runtime type" and select "GPU"

In [2]:
# Install ydata-synthetic lib
# ! pip install ydata-synthetic

In [1]:
import os

import matplotlib.pyplot as plt
import sklearn.cluster as cluster
from numpy import array, random, sum, unique
from pandas import DataFrame, read_csv

from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
from ydata_synthetic.synthesizers.regular import VanilllaGAN

In [2]:
# Load the credit card fraud CSV; the file's first column becomes the index.
data = read_csv(
    '../../data/creditcard.csv',
    index_col=[0],
)

# Synthesizer class that the training cell instantiates later on.
model = VanilllaGAN

In [3]:
# Feature columns: every column of the dataset except the 'Class' label.
num_cols = [column for column in data.columns if column != 'Class']
# 'Class' is the only column handled as categorical.
cat_cols = ['Class']

print('Dataset columns: {}'.format(num_cols))

# Fixed column ordering (features first, 'Class' last) applied to a copy of the data.
sorted_cols = [
    'V14', 'V4', 'V10', 'V17', 'V12', 'V26', 'Amount', 'V21', 'V8', 'V11',
    'V7', 'V28', 'V19', 'V3', 'V22', 'V6', 'V20', 'V27', 'V16', 'V13',
    'V25', 'V24', 'V18', 'V2', 'V1', 'V5', 'V15', 'V9', 'V23', 'Class',
]
processed_data = data.loc[:, sorted_cols].copy()

Dataset columns: ['V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount']


In [4]:
# Only the minority class is synthesized in this example, so keep just the rows
# whose 'Class' value is 1 (the fraud records).
train_data = data[data['Class'].eq(1)].copy()

print("Dataset info: Number of records - {} Number of variables - {}".format(*train_data.shape))

# Split the fraud records into 2 groups with scikit-learn's K-means.
# fit_predict returns one cluster id per row of train_data; random_state is
# fixed so the clustering is repeatable.
algorithm = cluster.KMeans
args, kwds = (), dict(n_clusters=2, random_state=0)
labels = algorithm(*args, **kwds).fit_predict(train_data.loc[:, num_cols])

# Show how many records fell into each cluster.
print(
    DataFrame(
        {'count': [sum(labels == cluster_id) for cluster_id in unique(labels)]},
        index=unique(labels),
    )
)

# Separate copy of the fraud records in which 'Class' holds the K-means cluster id
# instead of the original label.
fraud_w_classes = train_data.assign(Class=labels)

Dataset info: Number of records - 492 Number of variables - 30
   count
0    455
1     37


# GAN training

Below you can try to train your own generators using the available GAN architectures. You can train them either with labels (created using KMeans) or with no labels at all.

Remember that for this exercise in particular we've decided to synthesize only the minority class from the Credit Fraud dataset.

In [5]:
# --- Model and optimizer settings ---
noise_dim = 32        # forwarded to ModelParameters as noise_dim
dim = 128             # forwarded to ModelParameters as layers_dim
batch_size = 128

learning_rate = 5e-4
beta_1 = 0.5
beta_2 = 0.9

# --- Training schedule ---
epochs = 200 + 1      # one more than 200 so that step 200 itself is reached
log_step = 100        # forwarded to TrainParameters as sample_interval
models_dir = './cache'

# Bundle the model hyper-parameters for the synthesizer constructor.
gan_args = ModelParameters(
    batch_size=batch_size,
    lr=learning_rate,
    betas=(beta_1, beta_2),
    noise_dim=noise_dim,
    layers_dim=dim,
)

# Bundle the training-loop settings passed to synthesizer.train().
train_args = TrainParameters(
    epochs=epochs,
    sample_interval=log_step,
)

In [6]:
# Instantiate the selected GAN class (`model`) with the model parameters, then
# train it on the fraud records labelled with their K-means cluster.
synthesizer = model(gan_args)
synthesizer.train(
    data=fraud_w_classes,
    train_arguments=train_args,
    num_cols=num_cols,
    cat_cols=cat_cols,
)

  1%|▏         | 3/201 [00:00<00:40,  4.84it/s]

0 [D loss: 0.550036, acc.: 50.00%] [G loss: 0.627987]
generated_data
1 [D loss: 0.724137, acc.: 50.00%] [G loss: 0.482935]
2 [D loss: 0.773580, acc.: 44.14%] [G loss: 0.655450]


  3%|▎         | 7/201 [00:01<00:20,  9.55it/s]

3 [D loss: 0.723554, acc.: 32.03%] [G loss: 0.887138]
4 [D loss: 0.686885, acc.: 50.00%] [G loss: 0.951322]
5 [D loss: 0.730200, acc.: 39.84%] [G loss: 0.795471]
6 [D loss: 0.689810, acc.: 51.56%] [G loss: 0.808690]


  5%|▌         | 11/201 [00:01<00:15, 12.43it/s]

7 [D loss: 0.707817, acc.: 35.16%] [G loss: 0.718341]
8 [D loss: 0.690032, acc.: 49.22%] [G loss: 0.724760]
9 [D loss: 0.690305, acc.: 42.97%] [G loss: 0.725031]
10 [D loss: 0.699105, acc.: 42.58%] [G loss: 0.725312]


  7%|▋         | 15/201 [00:01<00:13, 13.97it/s]

11 [D loss: 0.631659, acc.: 78.91%] [G loss: 0.848520]
12 [D loss: 0.742673, acc.: 34.77%] [G loss: 0.677510]
13 [D loss: 0.724618, acc.: 33.20%] [G loss: 0.740743]
14 [D loss: 0.598308, acc.: 85.94%] [G loss: 0.950827]


  8%|▊         | 17/201 [00:01<00:12, 14.42it/s]

15 [D loss: 0.748309, acc.: 35.55%] [G loss: 0.643942]
16 [D loss: 0.821484, acc.: 28.12%] [G loss: 0.584143]
17 [D loss: 0.676677, acc.: 59.77%] [G loss: 0.840448]


 10%|█         | 21/201 [00:01<00:12, 14.21it/s]

18 [D loss: 0.610853, acc.: 83.59%] [G loss: 0.938941]
19 [D loss: 0.650022, acc.: 63.28%] [G loss: 0.887446]
20 [D loss: 0.756313, acc.: 29.30%] [G loss: 0.759600]
21 [D loss: 0.728475, acc.: 33.98%] [G loss: 0.778566]


 12%|█▏        | 25/201 [00:02<00:11, 14.75it/s]

22 [D loss: 0.636214, acc.: 75.39%] [G loss: 0.886067]
23 [D loss: 0.701726, acc.: 48.05%] [G loss: 0.729937]
24 [D loss: 0.692845, acc.: 39.06%] [G loss: 0.725504]
25 [D loss: 0.641558, acc.: 67.97%] [G loss: 0.878462]


 14%|█▍        | 29/201 [00:02<00:11, 15.10it/s]

26 [D loss: 0.724304, acc.: 32.81%] [G loss: 0.764726]
27 [D loss: 0.604908, acc.: 83.98%] [G loss: 0.912407]
28 [D loss: 0.685659, acc.: 40.23%] [G loss: 0.705084]
29 [D loss: 0.784759, acc.: 26.56%] [G loss: 0.616831]


 16%|█▋        | 33/201 [00:02<00:10, 15.39it/s]

30 [D loss: 0.680897, acc.: 48.44%] [G loss: 0.784109]
31 [D loss: 0.488568, acc.: 88.28%] [G loss: 1.022120]
32 [D loss: 0.517986, acc.: 71.88%] [G loss: 0.984776]
33 [D loss: 0.773452, acc.: 37.89%] [G loss: 0.677296]


 18%|█▊        | 37/201 [00:02<00:10, 15.60it/s]

34 [D loss: 0.845428, acc.: 28.91%] [G loss: 0.568926]
35 [D loss: 0.688814, acc.: 51.56%] [G loss: 0.843009]
36 [D loss: 0.638633, acc.: 62.89%] [G loss: 0.920435]
37 [D loss: 0.684682, acc.: 46.88%] [G loss: 0.878682]


 20%|██        | 41/201 [00:03<00:10, 15.65it/s]

38 [D loss: 0.742228, acc.: 33.20%] [G loss: 0.801381]
39 [D loss: 0.747093, acc.: 33.98%] [G loss: 0.878402]
40 [D loss: 0.605564, acc.: 76.56%] [G loss: 1.110654]
41 [D loss: 0.517402, acc.: 88.67%] [G loss: 1.188590]


 22%|██▏       | 45/201 [00:03<00:09, 15.70it/s]

42 [D loss: 0.622090, acc.: 64.06%] [G loss: 0.839116]
43 [D loss: 0.687085, acc.: 52.34%] [G loss: 0.777982]
44 [D loss: 0.699964, acc.: 39.45%] [G loss: 0.749893]
45 [D loss: 0.646823, acc.: 59.38%] [G loss: 0.811863]


 24%|██▍       | 49/201 [00:03<00:09, 15.67it/s]

46 [D loss: 0.545613, acc.: 80.08%] [G loss: 1.006777]
47 [D loss: 0.562644, acc.: 74.61%] [G loss: 1.045896]
48 [D loss: 0.748366, acc.: 50.39%] [G loss: 0.884654]
49 [D loss: 0.733166, acc.: 38.28%] [G loss: 0.893640]


 26%|██▋       | 53/201 [00:03<00:09, 15.48it/s]

50 [D loss: 0.684397, acc.: 48.05%] [G loss: 0.911950]
51 [D loss: 0.673821, acc.: 49.22%] [G loss: 0.949331]
52 [D loss: 0.682803, acc.: 50.00%] [G loss: 0.892081]
53 [D loss: 0.633063, acc.: 64.84%] [G loss: 1.007781]


 28%|██▊       | 57/201 [00:04<00:10, 14.24it/s]

54 [D loss: 0.597956, acc.: 72.27%] [G loss: 0.971366]
55 [D loss: 0.689976, acc.: 60.94%] [G loss: 0.909638]
56 [D loss: 0.755158, acc.: 47.66%] [G loss: 0.840833]
57 [D loss: 0.721160, acc.: 41.41%] [G loss: 0.816230]


 30%|███       | 61/201 [00:04<00:09, 14.95it/s]

58 [D loss: 0.710852, acc.: 41.02%] [G loss: 0.876071]
59 [D loss: 0.722174, acc.: 37.89%] [G loss: 0.795401]
60 [D loss: 0.678059, acc.: 51.17%] [G loss: 0.851614]
61 [D loss: 0.700930, acc.: 44.14%] [G loss: 0.809518]


 32%|███▏      | 65/201 [00:04<00:09, 15.11it/s]

62 [D loss: 0.697455, acc.: 52.73%] [G loss: 0.880072]
63 [D loss: 0.670013, acc.: 54.30%] [G loss: 0.881521]
64 [D loss: 0.672341, acc.: 53.52%] [G loss: 0.852372]
65 [D loss: 0.688349, acc.: 42.19%] [G loss: 0.767715]


 34%|███▍      | 69/201 [00:05<00:08, 15.28it/s]

66 [D loss: 0.673236, acc.: 50.39%] [G loss: 0.785096]
67 [D loss: 0.662382, acc.: 58.59%] [G loss: 0.802408]
68 [D loss: 0.671653, acc.: 53.12%] [G loss: 0.758970]
69 [D loss: 0.678161, acc.: 50.00%] [G loss: 0.772674]


 36%|███▋      | 73/201 [00:05<00:08, 15.42it/s]

70 [D loss: 0.670603, acc.: 55.47%] [G loss: 0.800789]
71 [D loss: 0.671375, acc.: 45.31%] [G loss: 0.751903]
72 [D loss: 0.681129, acc.: 45.70%] [G loss: 0.755616]
73 [D loss: 0.666789, acc.: 52.34%] [G loss: 0.776031]


 38%|███▊      | 77/201 [00:05<00:08, 15.24it/s]

74 [D loss: 0.662781, acc.: 54.30%] [G loss: 0.777716]
75 [D loss: 0.673832, acc.: 46.09%] [G loss: 0.755191]
76 [D loss: 0.664220, acc.: 57.81%] [G loss: 0.774768]
77 [D loss: 0.660048, acc.: 52.34%] [G loss: 0.773231]


 40%|████      | 81/201 [00:05<00:07, 15.06it/s]

78 [D loss: 0.675459, acc.: 53.52%] [G loss: 0.759155]
79 [D loss: 0.677063, acc.: 54.69%] [G loss: 0.761785]
80 [D loss: 0.666177, acc.: 54.30%] [G loss: 0.791037]
81 [D loss: 0.648826, acc.: 56.25%] [G loss: 0.802469]


 42%|████▏     | 85/201 [00:06<00:07, 15.17it/s]

82 [D loss: 0.667839, acc.: 45.70%] [G loss: 0.752686]
83 [D loss: 0.677281, acc.: 55.86%] [G loss: 0.792059]
84 [D loss: 0.645253, acc.: 60.94%] [G loss: 0.804177]
85 [D loss: 0.679732, acc.: 54.30%] [G loss: 0.760794]


 44%|████▍     | 89/201 [00:06<00:07, 15.29it/s]

86 [D loss: 0.684020, acc.: 51.56%] [G loss: 0.811780]
87 [D loss: 0.641267, acc.: 59.77%] [G loss: 0.870217]
88 [D loss: 0.656589, acc.: 56.64%] [G loss: 0.793731]
89 [D loss: 0.655231, acc.: 61.33%] [G loss: 0.778803]


 46%|████▋     | 93/201 [00:06<00:07, 15.38it/s]

90 [D loss: 0.648248, acc.: 59.77%] [G loss: 0.800200]
91 [D loss: 0.635786, acc.: 64.45%] [G loss: 0.797188]
92 [D loss: 0.663567, acc.: 51.56%] [G loss: 0.784666]
93 [D loss: 0.663212, acc.: 56.25%] [G loss: 0.866105]


 48%|████▊     | 97/201 [00:06<00:06, 15.20it/s]

94 [D loss: 0.636874, acc.: 62.11%] [G loss: 0.884161]
95 [D loss: 0.631512, acc.: 69.53%] [G loss: 0.898002]
96 [D loss: 0.658940, acc.: 55.47%] [G loss: 0.810010]
97 [D loss: 0.661417, acc.: 53.12%] [G loss: 0.852796]


 50%|█████     | 101/201 [00:07<00:06, 14.93it/s]

98 [D loss: 0.638799, acc.: 66.02%] [G loss: 0.853503]
99 [D loss: 0.615267, acc.: 71.88%] [G loss: 0.891819]
100 [D loss: 0.653688, acc.: 56.64%] [G loss: 0.866009]
generated_data


 52%|█████▏    | 105/201 [00:07<00:06, 15.11it/s]

101 [D loss: 0.666099, acc.: 50.39%] [G loss: 0.842007]
102 [D loss: 0.651683, acc.: 58.98%] [G loss: 0.923288]
103 [D loss: 0.646693, acc.: 62.50%] [G loss: 0.858193]
104 [D loss: 0.659455, acc.: 58.98%] [G loss: 0.865092]


 54%|█████▍    | 109/201 [00:07<00:05, 15.38it/s]

105 [D loss: 0.645711, acc.: 59.38%] [G loss: 0.880491]
106 [D loss: 0.653809, acc.: 55.08%] [G loss: 0.818559]
107 [D loss: 0.678842, acc.: 43.75%] [G loss: 0.767874]
108 [D loss: 0.646079, acc.: 58.20%] [G loss: 0.876584]


 56%|█████▌    | 113/201 [00:07<00:05, 15.48it/s]

109 [D loss: 0.666158, acc.: 51.56%] [G loss: 0.835043]
110 [D loss: 0.640382, acc.: 59.38%] [G loss: 0.832070]
111 [D loss: 0.643277, acc.: 60.94%] [G loss: 0.845614]
112 [D loss: 0.669900, acc.: 57.03%] [G loss: 0.820596]


 58%|█████▊    | 117/201 [00:08<00:05, 15.34it/s]

113 [D loss: 0.656420, acc.: 57.03%] [G loss: 0.836345]
114 [D loss: 0.641564, acc.: 57.42%] [G loss: 0.876468]
115 [D loss: 0.646709, acc.: 59.38%] [G loss: 0.825197]
116 [D loss: 0.651036, acc.: 54.69%] [G loss: 0.777496]


 60%|██████    | 121/201 [00:08<00:05, 15.05it/s]

117 [D loss: 0.652030, acc.: 58.59%] [G loss: 0.793402]
118 [D loss: 0.649176, acc.: 59.38%] [G loss: 0.811079]
119 [D loss: 0.649228, acc.: 54.69%] [G loss: 0.807060]
120 [D loss: 0.631465, acc.: 57.81%] [G loss: 0.873953]


 62%|██████▏   | 125/201 [00:08<00:04, 15.27it/s]

121 [D loss: 0.642788, acc.: 54.30%] [G loss: 0.842716]
122 [D loss: 0.633139, acc.: 61.72%] [G loss: 0.860812]
123 [D loss: 0.666036, acc.: 56.25%] [G loss: 0.833639]
124 [D loss: 0.637599, acc.: 57.42%] [G loss: 0.924634]


 64%|██████▍   | 129/201 [00:09<00:04, 15.19it/s]

125 [D loss: 0.643670, acc.: 58.59%] [G loss: 0.840004]
126 [D loss: 0.666960, acc.: 49.22%] [G loss: 0.773053]
127 [D loss: 0.638831, acc.: 59.77%] [G loss: 0.855623]
128 [D loss: 0.630438, acc.: 62.11%] [G loss: 0.862262]


 65%|██████▌   | 131/201 [00:09<00:04, 15.05it/s]

129 [D loss: 0.641727, acc.: 61.33%] [G loss: 0.807410]
130 [D loss: 0.647421, acc.: 56.64%] [G loss: 0.860150]
131 [D loss: 0.637578, acc.: 62.11%] [G loss: 0.880380]


 67%|██████▋   | 135/201 [00:09<00:04, 15.14it/s]

132 [D loss: 0.635053, acc.: 63.67%] [G loss: 0.867916]
133 [D loss: 0.642439, acc.: 56.25%] [G loss: 0.792550]
134 [D loss: 0.638722, acc.: 62.11%] [G loss: 0.829996]
135 [D loss: 0.650501, acc.: 54.30%] [G loss: 0.810637]


 69%|██████▉   | 139/201 [00:09<00:04, 15.30it/s]

136 [D loss: 0.642949, acc.: 60.55%] [G loss: 0.825194]
137 [D loss: 0.633457, acc.: 60.16%] [G loss: 0.918275]
138 [D loss: 0.637209, acc.: 61.72%] [G loss: 0.876791]
139 [D loss: 0.633082, acc.: 61.72%] [G loss: 0.847758]


 71%|███████   | 143/201 [00:09<00:03, 14.96it/s]

140 [D loss: 0.673803, acc.: 50.39%] [G loss: 0.753332]
141 [D loss: 0.652552, acc.: 55.08%] [G loss: 0.834769]
142 [D loss: 0.609212, acc.: 71.09%] [G loss: 0.917820]


 73%|███████▎  | 147/201 [00:10<00:03, 15.00it/s]

143 [D loss: 0.585832, acc.: 72.27%] [G loss: 0.918271]
144 [D loss: 0.700710, acc.: 44.14%] [G loss: 0.732573]
145 [D loss: 0.648752, acc.: 57.03%] [G loss: 0.886211]
146 [D loss: 0.631603, acc.: 61.33%] [G loss: 0.965210]


 75%|███████▌  | 151/201 [00:10<00:03, 15.22it/s]

147 [D loss: 0.644220, acc.: 61.72%] [G loss: 0.881742]
148 [D loss: 0.620618, acc.: 63.28%] [G loss: 0.864807]
149 [D loss: 0.630044, acc.: 61.72%] [G loss: 0.887240]
150 [D loss: 0.627703, acc.: 63.28%] [G loss: 0.853771]


 77%|███████▋  | 155/201 [00:10<00:02, 15.39it/s]

151 [D loss: 0.668791, acc.: 53.52%] [G loss: 0.834410]
152 [D loss: 0.611434, acc.: 61.33%] [G loss: 0.921129]
153 [D loss: 0.634519, acc.: 59.77%] [G loss: 0.854780]
154 [D loss: 0.642557, acc.: 53.91%] [G loss: 0.862049]


 79%|███████▉  | 159/201 [00:10<00:02, 15.36it/s]

155 [D loss: 0.655550, acc.: 58.20%] [G loss: 0.857191]
156 [D loss: 0.622134, acc.: 64.45%] [G loss: 0.904369]
157 [D loss: 0.619334, acc.: 66.02%] [G loss: 0.884082]
158 [D loss: 0.628977, acc.: 63.28%] [G loss: 0.854065]


 81%|████████  | 163/201 [00:11<00:02, 15.33it/s]

159 [D loss: 0.625271, acc.: 60.55%] [G loss: 0.915155]
160 [D loss: 0.629663, acc.: 59.77%] [G loss: 0.901985]
161 [D loss: 0.654066, acc.: 49.22%] [G loss: 0.845741]
162 [D loss: 0.617537, acc.: 65.23%] [G loss: 0.928255]


 83%|████████▎ | 167/201 [00:11<00:02, 15.52it/s]

163 [D loss: 0.627255, acc.: 65.62%] [G loss: 0.837423]
164 [D loss: 0.619008, acc.: 59.77%] [G loss: 0.889656]
165 [D loss: 0.624287, acc.: 62.50%] [G loss: 0.928081]
166 [D loss: 0.621571, acc.: 60.94%] [G loss: 0.963335]


 85%|████████▌ | 171/201 [00:11<00:01, 15.38it/s]

167 [D loss: 0.614337, acc.: 67.97%] [G loss: 0.873778]
168 [D loss: 0.631943, acc.: 61.33%] [G loss: 0.901347]
169 [D loss: 0.630981, acc.: 60.94%] [G loss: 0.866399]
170 [D loss: 0.608251, acc.: 66.02%] [G loss: 0.855200]


 87%|████████▋ | 175/201 [00:12<00:01, 15.29it/s]

171 [D loss: 0.616137, acc.: 64.06%] [G loss: 0.882089]
172 [D loss: 0.638764, acc.: 59.38%] [G loss: 0.906505]
173 [D loss: 0.634071, acc.: 60.94%] [G loss: 0.883332]
174 [D loss: 0.613683, acc.: 61.72%] [G loss: 0.965842]


 88%|████████▊ | 177/201 [00:12<00:01, 15.13it/s]

175 [D loss: 0.619862, acc.: 59.38%] [G loss: 0.901156]
176 [D loss: 0.625921, acc.: 62.11%] [G loss: 0.861191]
177 [D loss: 0.634482, acc.: 60.55%] [G loss: 0.858311]


 90%|█████████ | 181/201 [00:12<00:01, 14.95it/s]

178 [D loss: 0.610101, acc.: 64.06%] [G loss: 0.872502]
179 [D loss: 0.625674, acc.: 60.55%] [G loss: 0.905137]
180 [D loss: 0.626871, acc.: 62.89%] [G loss: 0.881574]


 92%|█████████▏| 185/201 [00:12<00:01, 14.93it/s]

181 [D loss: 0.621699, acc.: 63.28%] [G loss: 0.852086]
182 [D loss: 0.620817, acc.: 63.67%] [G loss: 0.896311]
183 [D loss: 0.585868, acc.: 64.84%] [G loss: 0.935561]
184 [D loss: 0.610238, acc.: 65.23%] [G loss: 0.875845]


 94%|█████████▍| 189/201 [00:12<00:00, 15.26it/s]

185 [D loss: 0.620626, acc.: 62.89%] [G loss: 0.904199]
186 [D loss: 0.621455, acc.: 62.89%] [G loss: 0.867477]
187 [D loss: 0.607304, acc.: 63.67%] [G loss: 0.905859]
188 [D loss: 0.597267, acc.: 64.45%] [G loss: 0.951865]


 96%|█████████▌| 193/201 [00:13<00:00, 15.40it/s]

189 [D loss: 0.622227, acc.: 62.11%] [G loss: 0.895168]
190 [D loss: 0.621200, acc.: 62.11%] [G loss: 0.871956]
191 [D loss: 0.604076, acc.: 64.45%] [G loss: 0.847970]
192 [D loss: 0.601647, acc.: 64.45%] [G loss: 0.908888]


 98%|█████████▊| 197/201 [00:13<00:00, 15.47it/s]

193 [D loss: 0.599213, acc.: 64.45%] [G loss: 0.939804]
194 [D loss: 0.595883, acc.: 63.28%] [G loss: 0.898727]
195 [D loss: 0.596841, acc.: 65.23%] [G loss: 0.925748]
196 [D loss: 0.633918, acc.: 58.20%] [G loss: 0.908294]


100%|██████████| 201/201 [00:13<00:00, 14.62it/s]

197 [D loss: 0.610397, acc.: 60.94%] [G loss: 0.952700]
198 [D loss: 0.601393, acc.: 64.06%] [G loss: 0.904418]
199 [D loss: 0.613972, acc.: 64.06%] [G loss: 0.877581]
200 [D loss: 0.604303, acc.: 63.67%] [G loss: 0.965180]
generated_data





In [7]:
# Generator description: print the layer-by-layer summary (output shapes and
# parameter counts) of the trained synthesizer's generator network.
synthesizer.generator.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(128, 32)]               0         
_________________________________________________________________
dense (Dense)                (128, 128)                4224      
_________________________________________________________________
dense_1 (Dense)              (128, 256)                33024     
_________________________________________________________________
dense_2 (Dense)              (128, 512)                131584    
_________________________________________________________________
dense_3 (Dense)              (128, 31)                 15903     
Total params: 184,735
Trainable params: 184,735
Non-trainable params: 0
_________________________________________________________________


In [8]:
# Discriminator description: print the layer-by-layer summary (output shapes and
# parameter counts) of the trained synthesizer's discriminator network.
synthesizer.discriminator.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(128, 31)]               0         
_________________________________________________________________
dense_4 (Dense)              (128, 512)                16384     
_________________________________________________________________
dropout (Dropout)            (128, 512)                0         
_________________________________________________________________
dense_5 (Dense)              (128, 256)                131328    
_________________________________________________________________
dropout_1 (Dropout)          (128, 256)                0         
_________________________________________________________________
dense_6 (Dense)              (128, 128)                32896     
_________________________________________________________________
dense_7 (Dense)              (128, 1)                  129 

In [9]:
# You can easily save the trained synthesizer and load it afterwards.
# exist_ok=True creates the target directory only when it is missing, without a
# separate (race-prone) os.path.exists() check, so the cell is safe to re-run.
os.makedirs("./saved/gan", exist_ok=True)
synthesizer.save(path="./saved/gan/generator_fraud.pkl")

In [10]:
# Registry of generators to plot. Each entry is unpacked by the visualization cell as
# [display name, with_class flag (whether the generator is also fed labels), generator model].
models = {'GAN': ['GAN', False, synthesizer.generator]}

In [13]:
# Compare real fraud samples with generator output at several training steps.
# The figure has one row per entry of `model_steps`; the left panel shows the real
# data and each remaining panel shows the samples of one entry of `models`.

# --- Visualization parameters ---
seed = 17
test_size = 492  # number of fraud cases
# `noise_dim` is deliberately NOT re-declared here: it is taken from the training
# parameters cell so the latent vectors always match the generator's input size.

# Fixed latent vectors, reused for every training step so the panels are comparable.
random.seed(seed)
z = random.normal(size=(test_size, noise_dim))

# Real reference samples: run the fraud records through the synthesizer's
# processor and draw one batch from the result.
real_processed = synthesizer.processor.transform(fraud_w_classes)
real_samples = synthesizer.get_data_batch(real_processed, batch_size)
class_labels = ['Class_1', 'Class_2']
real_samples = DataFrame(real_samples, columns=num_cols + class_labels)
labels = fraud_w_classes['Class']

model_names = ['GAN']
colors = ['deepskyblue', 'blue']
markers = ['o', '^']

# The two features shown on the x and y axes.
col1, col2 = 'V17', 'V10'

# Directory prefix of the generator weight files loaded below.
base_dir = 'cache/'

# Training steps whose saved generator weights are visualized.
model_steps = [0, 100, 200]
rows = len(model_steps)
columns = 1 + len(models)

# One slot per row for the left-hand (real data) axes; filled inside the loop.
axarr = [None] * rows

fig = plt.figure(figsize=(14, rows * 3))

for model_step_ix, model_step in enumerate(model_steps):
    # --- Left panel: actual fraud data ---
    real_ax = fig.add_subplot(rows, columns, model_step_ix * columns + 1)
    axarr[model_step_ix] = real_ax

    # One scatter series per value of the 'Class_1' column.
    for (_, group_df), color, marker, label in zip(
            real_samples.groupby('Class_1'), colors, markers, class_labels):
        real_ax.scatter(group_df[col1], group_df[col2],
                        label=label, marker=marker, edgecolors=color, facecolors='none')

    real_ax.set_title('Actual Fraud Data')
    real_ax.set_ylabel(col2)  # Only add y label to left plot
    real_ax.set_xlabel(col1)
    # Remember the real-data axis limits so the generated panels share the same scale.
    xlims, ylims = real_ax.get_xlim(), real_ax.get_ylim()

    if model_step_ix == 0:
        legend = real_ax.legend()
        legend.get_frame().set_facecolor('white')

    # --- Remaining panels: one per GAN listed in `model_names` / defined in `models` ---
    for i, model_name in enumerate(model_names):
        display_name, with_class, generator_model = models[model_name]

        # NOTE: this overwrites the generator's weights in place; after the loop the
        # generator holds the weights of the last entry of `model_steps`.
        generator_model.load_weights(base_dir + '_generator_model_weights_step_' + str(model_step) + '.h5')

        ax = fig.add_subplot(rows, columns, model_step_ix * columns + 1 + (i + 1))

        if with_class:
            # Generator that also takes labels: plot one series per 'Class_1' value.
            g_z = generator_model([z, labels])
            gen_samples = DataFrame(g_z, columns=num_cols + class_labels)
            for (_, group_df), color, marker, label in zip(
                    gen_samples.groupby('Class_1'), colors, markers, class_labels):
                ax.scatter(group_df[col1], group_df[col2],
                           label=label, marker=marker, edgecolors=color, facecolors='none')
        else:
            g_z = generator_model(z)
            gen_samples = DataFrame(g_z, columns=num_cols + class_labels)
            # Rewritten on every training step, so the file ends up holding the
            # samples of the last step plotted.
            gen_samples.to_csv('../../data/Generated_sample.csv')
            ax.scatter(gen_samples[col1], gen_samples[col2],
                       label=class_labels[0], marker=markers[0], edgecolors=colors[0], facecolors='none')

        ax.set_title(display_name)
        ax.set_xlabel(col1)
        ax.set_xlim(xlims)
        ax.set_ylim(ylims)

fig.suptitle('Comparison of GAN outputs', size=16, fontweight='bold')
# Leave a margin on the left for the step labels and on top for the suptitle.
fig.tight_layout(rect=[0.075, 0, 1, 0.95])

# Adding text labels for training steps, vertically aligned with each row.
# Uses the public Axes.get_position() instead of the private `_position` attribute.
vpositions = array([row_ax.get_position().bounds[1] for row_ax in axarr])
vpositions += (vpositions[0] - vpositions[1]) * 0.35
for model_step_ix, model_step in enumerate(model_steps):
    fig.text(0.05, vpositions[model_step_ix], 'training\nstep\n' + str(model_step),
             ha='center', va='center', size=12)

# exist_ok=True makes the directory creation idempotent without a separate exists() check.
os.makedirs("./img", exist_ok=True)
fig.savefig('img/Comparison_of_GAN_outputs.png', dpi=100)