Skip to content

Commit

Permalink
feat(gan): Changes to example
Browse files Browse the repository at this point in the history
  • Loading branch information
fabclmnt committed May 4, 2020
1 parent e77fc31 commit 0864c90
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 4 deletions.
103 changes: 99 additions & 4 deletions example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
"import pandas as pd\n",
"import numpy as np\n",
"import sklearn.cluster as cluster\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from models.gan import model\n",
"importlib.reload(model)\n",
"\n",
"from models.gan.model import GAN\n",
"from preprocessing.credit_fraud import *\n",
"\n",
"model = GAN"
],
"metadata": {
Expand All @@ -35,8 +38,32 @@
"name": "#%%\n"
}
},
"execution_count": 1,
"outputs": []
"execution_count": 13,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"Bad key \"text.kerning_factor\" on line 4 in\n",
"/home/fabiana/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test_patch.mplstyle.\n",
"You probably need to get an updated matplotlibrc file from\n",
"https://github.com/matplotlib/matplotlib/blob/v3.1.3/matplotlibrc.template\n",
"or from the matplotlib source distribution\n"
]
},
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'preprocessing.dataset'",
"output_type": "error",
"traceback": [
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)",
"\u001B[0;32m<ipython-input-13-5f27d42f34a9>\u001B[0m in \u001B[0;36m<module>\u001B[0;34m\u001B[0m\n\u001B[1;32m 13\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 14\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0mmodels\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mgan\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmodel\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mGAN\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m---> 15\u001B[0;31m \u001B[0;32mfrom\u001B[0m \u001B[0mpreprocessing\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mdataset\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0;34m*\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m",
"\u001B[0;31mModuleNotFoundError\u001B[0m: No module named 'preprocessing.dataset'"
]
}
]
},
{
"cell_type": "code",
Expand Down Expand Up @@ -83,6 +110,10 @@
}
],
"source": [
"#Before training the GAN do not forget to apply the required data transformations\n",
"#Such as Log transformation to some of the variables and Normalization such as MinMax.\n",
"\n",
"\n",
"#For the purpose of this example we will only synthesize the minority class\n",
"train_data = data.loc[ data['Class']==1 ].copy()\n",
"\n",
Expand Down Expand Up @@ -115,7 +146,7 @@
"batch_size = 128\n",
"\n",
"log_step = 100\n",
"epochs = 1000\n",
"epochs = 5000+1\n",
"learning_rate = 5e-4\n",
"models_dir = './cache'\n",
"\n",
Expand Down Expand Up @@ -1217,7 +1248,71 @@
"output_type": "execute_result"
}
],
"source": [],
"source": [
"#Training results visualization\n",
"#Adapt this code\n",
"model_steps = [ 0, 200, 500, 1000, 5000]\n",
"rows = len(model_steps)\n",
"columns = 5\n",
"\n",
"axarr = [[]]*len(model_steps)\n",
"\n",
"fig = plt.figure(figsize=(14,rows*3))\n",
"\n",
"for model_step_ix, model_step in enumerate( model_steps ):\n",
" print(model_step)\n",
" \n",
" axarr[model_step_ix] = plt.subplot(rows, columns, model_step_ix*columns + 1)\n",
" \n",
" for group, color, marker, label in zip(real_samples.groupby('Class_1'), colors, markers, class_labels ):\n",
" plt.scatter( group[1][[col1]], group[1][[col2]], \n",
" label=label, marker=marker, edgecolors=color, facecolors='none' )\n",
" \n",
" plt.title('Actual Fraud Data')\n",
" plt.ylabel(col2) # Only add y label to left plot\n",
" plt.xlabel(col1)\n",
" xlims, ylims = axarr[model_step_ix].get_xlim(), axarr[model_step_ix].get_ylim()\n",
" \n",
" if model_step_ix == 0: \n",
" legend = plt.legend()\n",
" legend.get_frame().set_facecolor('white')\n",
" \n",
" for i, model_name in enumerate( model_names[:] ):\n",
" \n",
" [ model_name, with_class, type0, generator_model ] = models[model_name]\n",
" \n",
" generator_model.load_weights( base_dir + model_name + '_generator_model_weights_step_'+str(model_step)+'.h5')\n",
"\n",
" ax = plt.subplot(rows, columns, model_step_ix*columns + 1 + (i+1) )\n",
" \n",
" if with_class:\n",
" g_z = generator_model.predict([z, labels])\n",
" gen_samples = pd.DataFrame(g_z, columns=data_cols+label_cols)\n",
" for group, color, marker, label in zip( gen_samples.groupby('Class_1'), colors, markers, class_labels ):\n",
" plt.scatter( group[1][[col1]], group[1][[col2]], \n",
" label=label, marker=marker, edgecolors=color, facecolors='none' )\n",
" else:\n",
" g_z = generator_model.predict(z)\n",
" gen_samples = pd.DataFrame(g_z, columns=data_cols)\n",
" gen_samples.to_csv('Generated_sample.csv')\n",
" plt.scatter( gen_samples[[col1]], gen_samples[[col2]], \n",
" label=class_labels[0], marker=markers[0], edgecolors=colors[0], facecolors='none' )\n",
" plt.title(model_name) \n",
" plt.xlabel(data_cols[0])\n",
" ax.set_xlim(xlims), ax.set_ylim(ylims)\n",
"\n",
"\n",
"plt.suptitle('Comparison of GAN outputs', size=16, fontweight='bold')\n",
"plt.tight_layout(rect=[0.075,0,1,0.95])\n",
"\n",
"# Adding text labels for traning steps\n",
"vpositions = np.array([ i._position.bounds[1] for i in axarr ])\n",
"vpositions += ((vpositions[0] - vpositions[1]) * 0.35 )\n",
"for model_step_ix, model_step in enumerate( model_steps ):\n",
" fig.text( 0.05, vpositions[model_step_ix], 'training\\nstep\\n'+str(model_step), ha='center', va='center', size=12)\n",
"\n",
"plt.savefig('Comparison_of_GAN_outputs.png')"
],
"metadata": {
"collapsed": false,
"pycharm": {
Expand Down
1 change: 1 addition & 0 deletions preprocessing/credit_fraud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#Include the needed data transformations here
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pandas==1.0.3
numpy==1.17.4
scikit-learn==0.22.2
matplotlib

tensorflow==2.1.0

0 comments on commit 0864c90

Please sign in to comment.