diff --git a/example.ipynb b/example.ipynb
index effe9574..1e0a8a5a 100644
--- a/example.ipynb
+++ b/example.ipynb
@@ -22,11 +22,14 @@
     "import pandas as pd\n",
     "import numpy as np\n",
     "import sklearn.cluster as cluster\n",
+    "import matplotlib.pyplot as plt\n",
     "\n",
     "from models.gan import model\n",
     "importlib.reload(model)\n",
     "\n",
     "from models.gan.model import GAN\n",
+    "from preprocessing.credit_fraud import *\n",
+    "\n",
     "model = GAN"
    ],
    "metadata": {
@@ -35,8 +38,32 @@
      "name": "#%%\n"
     }
    },
-   "execution_count": 1,
-   "outputs": []
+   "execution_count": 13,
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Bad key \"text.kerning_factor\" on line 4 in\n",
+      "/home/fabiana/anaconda3/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test_patch.mplstyle.\n",
+      "You probably need to get an updated matplotlibrc file from\n",
+      "https://github.com/matplotlib/matplotlib/blob/v3.1.3/matplotlibrc.template\n",
+      "or from the matplotlib source distribution\n"
+     ]
+    },
+    {
+     "ename": "ModuleNotFoundError",
+     "evalue": "No module named 'preprocessing.dataset'",
+     "output_type": "error",
+     "traceback": [
+      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
+      "\u001B[0;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)",
+      "\u001B[0;32m\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 13\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 14\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0mmodels\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mgan\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmodel\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mGAN\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m---> 15\u001B[0;31m \u001B[0;32mfrom\u001B[0m \u001B[0mpreprocessing\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mdataset\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0;34m*\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m",
+      "\u001B[0;31mModuleNotFoundError\u001B[0m: No module named 'preprocessing.dataset'"
+     ]
+    }
+   ]
  },
  {
   "cell_type": "code",
@@ -83,6 +110,10 @@
    }
   ],
   "source": [
+    "#Before training the GAN, do not forget to apply the required data transformations,\n",
+    "#such as a log transformation of some of the variables and MinMax normalization.\n",
+    "\n",
+    "\n",
     "#For the purpose of this example we will only synthesize the minority class\n",
     "train_data = data.loc[ data['Class']==1 ].copy()\n",
     "\n",
@@ -115,7 +146,7 @@
     "batch_size = 128\n",
     "\n",
     "log_step = 100\n",
-    "epochs = 1000\n",
+    "epochs = 5000+1\n",
     "learning_rate = 5e-4\n",
     "models_dir = './cache'\n",
     "\n",
@@ -1217,7 +1248,71 @@
      "output_type": "execute_result"
     }
    ],
-   "source": [],
+   "source": [
+    "#Training results visualization\n",
+    "#Adapt the model names and training steps below to your own runs\n",
+    "model_steps = [ 0, 200, 500, 1000, 5000]\n",
+    "rows = len(model_steps)\n",
+    "columns = 5\n",
+    "\n",
+    "axarr = [[]]*len(model_steps)\n",
+    "\n",
+    "fig = plt.figure(figsize=(14,rows*3))\n",
+    "\n",
+    "for model_step_ix, model_step in enumerate( model_steps ):\n",
+    "    print(model_step)\n",
+    "    \n",
+    "    axarr[model_step_ix] = plt.subplot(rows, columns, model_step_ix*columns + 1)\n",
+    "    \n",
+    "    for group, color, marker, label in zip(real_samples.groupby('Class_1'), colors, markers, class_labels ):\n",
+    "        plt.scatter( group[1][[col1]], group[1][[col2]], \n",
+    "                     label=label, marker=marker, edgecolors=color, facecolors='none' )\n",
+    "    \n",
+    "    plt.title('Actual Fraud Data')\n",
+    "    plt.ylabel(col2) # Only add y label to left plot\n",
+    "    plt.xlabel(col1)\n",
+    "    xlims, ylims = axarr[model_step_ix].get_xlim(), axarr[model_step_ix].get_ylim()\n",
+    "    \n",
+    "    if model_step_ix == 0:\n",
+    "        legend = plt.legend()\n",
+    "        legend.get_frame().set_facecolor('white')\n",
+    "    \n",
+    "    for i, model_name in enumerate( model_names[:] ):\n",
+    "        \n",
+    "        [ model_name, with_class, type0, generator_model ] = models[model_name]\n",
+    "        \n",
+    "        generator_model.load_weights( base_dir + model_name + '_generator_model_weights_step_'+str(model_step)+'.h5')\n",
+    "\n",
+    "        ax = plt.subplot(rows, columns, model_step_ix*columns + 1 + (i+1) )\n",
+    "        \n",
+    "        if with_class:\n",
+    "            g_z = generator_model.predict([z, labels])\n",
+    "            gen_samples = pd.DataFrame(g_z, columns=data_cols+label_cols)\n",
+    "            for group, color, marker, label in zip( gen_samples.groupby('Class_1'), colors, markers, class_labels ):\n",
+    "                plt.scatter( group[1][[col1]], group[1][[col2]], \n",
+    "                             label=label, marker=marker, edgecolors=color, facecolors='none' )\n",
+    "        else:\n",
+    "            g_z = generator_model.predict(z)\n",
+    "            gen_samples = pd.DataFrame(g_z, columns=data_cols)\n",
+    "            gen_samples.to_csv('Generated_sample.csv')\n",
+    "            plt.scatter( gen_samples[[col1]], gen_samples[[col2]], \n",
+    "                         label=class_labels[0], marker=markers[0], edgecolors=colors[0], facecolors='none' )\n",
+    "        plt.title(model_name)\n",
+    "        plt.xlabel(data_cols[0])\n",
+    "        ax.set_xlim(xlims), ax.set_ylim(ylims)\n",
+    "\n",
+    "\n",
+    "plt.suptitle('Comparison of GAN outputs', size=16, fontweight='bold')\n",
+    "plt.tight_layout(rect=[0.075,0,1,0.95])\n",
+    "\n",
+    "# Adding text labels for training steps\n",
+    "vpositions = np.array([ i._position.bounds[1] for i in axarr ])\n",
+    "vpositions += ((vpositions[0] - vpositions[1]) * 0.35 )\n",
+    "for model_step_ix, model_step in enumerate( model_steps ):\n",
+    "    fig.text( 0.05, vpositions[model_step_ix], 'training\\nstep\\n'+str(model_step), ha='center', va='center', size=12)\n",
+    "\n",
+    "plt.savefig('Comparison_of_GAN_outputs.png')"
+   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
diff --git a/preprocessing/credit_fraud.py b/preprocessing/credit_fraud.py
new file mode 100644
index 00000000..dd638dea
--- /dev/null
+++ b/preprocessing/credit_fraud.py
@@ -0,0 +1 @@
+#Include here the needed transformations
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index cfa8f7ac..c5c5894d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 pandas==1.0.3
 numpy==1.17.4
 scikit-learn==0.22.2
+matplotlib
 tensorflow==2.1.0
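The new preprocessing/credit_fraud.py committed above is still an empty stub, and the notebook only notes that a log transformation and MinMax normalization are required before training. Below is a minimal sketch of what that module could contain; it is not part of this diff, it assumes the Kaggle credit card fraud schema (numeric feature columns, an 'Amount' column and a 'Class' label), and the transformations() helper name and the use of scikit-learn's MinMaxScaler are illustrative assumptions, not the author's confirmed implementation.

# preprocessing/credit_fraud.py -- illustrative sketch only; column names and helper name are assumptions
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler


def transformations(data: pd.DataFrame):
    """Log-transform the skewed 'Amount' column and scale all features to [0, 1].

    Returns the transformed frame, the feature column names, and the fitted scaler,
    which can later be used to map generated samples back to the original scale.
    """
    processed = data.copy()

    # Monetary amounts are heavily skewed; log1p compresses the long tail.
    processed['Amount'] = np.log1p(processed['Amount'])

    # Everything except the label is treated as a feature to be scaled.
    data_cols = [col for col in processed.columns if col != 'Class']

    # MinMax scaling keeps each feature in [0, 1], a range a GAN generator covers easily.
    scaler = MinMaxScaler()
    processed[data_cols] = scaler.fit_transform(processed[data_cols])

    return processed, data_cols, scaler

With a helper like this, the wildcard import in the notebook would supply the transformed frame, e.g. data, data_cols, scaler = transformations(pd.read_csv('data/creditcard.csv')) (the CSV path is likewise an assumption), before the minority class is selected with data['Class'] == 1 for training.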