Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Toru Shimizu
committed
May 31, 2018
1 parent
10afa95
commit da9b292
Showing
1 changed file
with
281 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,281 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import pickle\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import glob\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"import tensorflow as tf\n", | ||
"\n", | ||
"%matplotlib inline" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# data import\n", | ||
"data_dir = \"cars_data/\"\n", | ||
"train_x = np.array(pickle.load(open(data_dir + \"train_feature.pkl\", \"br\")), dtype=\"float32\")\n", | ||
"train_y = np.array(pickle.load(open(data_dir + \"train_label.pkl\", \"br\")), dtype=\"float32\")\n", | ||
"valid_x = np.array(pickle.load(open(data_dir + \"valid_feature.pkl\", \"br\")), dtype=\"float32\")\n", | ||
"valid_y = np.array(pickle.load(open(data_dir + \"valid_label.pkl\", \"br\")), dtype=\"float32\")\n", | ||
"test_x = np.array(pickle.load(open(data_dir + \"test_feature.pkl\", \"br\")), dtype=\"float32\")\n", | ||
"test_y = np.array(pickle.load(open(data_dir + \"test_label.pkl\", \"br\")), dtype=\"float32\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"train_x[:10]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"train_y[:10]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"len(train_x)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"len(valid_x)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"len(test_x)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# define placeholder tensors which work as \"entrances\" of the data\n", | ||
"x1 = tf.placeholder(dtype=tf.float32, shape=[None, None])\n", | ||
"y = tf.placeholder(dtype=tf.int32, shape=[None,])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# define parameters for the 1st layer\n", | ||
"W1 = tf.Variable(tf.random_normal([6, 20], stddev=0.05), dtype=tf.float32)\n", | ||
"b1 = tf.Variable(tf.zeros((1, 20)), dtype=tf.float32)\n", | ||
"# tensor calculation for the 1st layer\n", | ||
"a2 = tf.matmul(x1, W1) + b1\n", | ||
"x2 = tf.sigmoid(a2)\n", | ||
"\n", | ||
"# define parameters for the 2nd layer\n", | ||
"W2 = tf.Variable(tf.random_normal([20, 8], stddev=0.05), dtype=tf.float32)\n", | ||
"b2 = tf.Variable(tf.zeros((1, 8)), dtype=tf.float32)\n", | ||
"# tensor calculation for the 2nd layer\n", | ||
"a3 = tf.matmul(x2, W2) + b2\n", | ||
"x3 = tf.sigmoid(a3)\n", | ||
"\n", | ||
"# the 3rd layer\n", | ||
"W3 = tf.Variable(tf.random_normal([8, 4], stddev=0.05), dtype=tf.float32)\n", | ||
"b3 = tf.Variable(tf.zeros((1, 4)), dtype=tf.float32)\n", | ||
"logits = tf.matmul(x3, W3) + b3" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# the average cost for records in the batch:\n", | ||
"# tf calculates it with the logits and y labels\n", | ||
"cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)\n", | ||
"cost = tf.reduce_mean(cross_entropy)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# define the train op\n", | ||
"alpha = 1.\n", | ||
"train_op = tf.train.GradientDescentOptimizer(alpha).minimize(cost)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# prepare a dir to store model parameters\n", | ||
"if not os.path.isdir(\"cars_models\"):\n", | ||
" ! mkdir cars_models\n", | ||
"saver = tf.train.Saver(max_to_keep=None)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def evaluation(sess, d_x, d_y):\n", | ||
" raw_results = sess.run(logits, feed_dict={x1: d_x})\n", | ||
" results = np.argmax(raw_results, axis=1)\n", | ||
" match_count = np.sum(d_y == results)\n", | ||
" accuracy = float(match_count) / float(d_y.size)\n", | ||
" \n", | ||
" return accuracy" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# training\n", | ||
"\n", | ||
"total_epochs = 2000\n", | ||
"cost_history = list()\n", | ||
"valid_history = list()\n", | ||
"best_iter = 0\n", | ||
"test_accuracy = 0.\n", | ||
"\n", | ||
"# clear the past model dir\n", | ||
"if len(glob.glob(\"cars_models/model.*\")) > 0:\n", | ||
" ! rm cars_models/model.*\n", | ||
" \n", | ||
"# setup a session\n", | ||
"with tf.Session() as sess:\n", | ||
" sess.run(tf.global_variables_initializer())\n", | ||
" \n", | ||
" for itr in range(1, total_epochs + 1):\n", | ||
" fd = {x1: train_x, y: train_y}\n", | ||
" cost_val, t = sess.run((cost, train_op), feed_dict=fd)\n", | ||
" \n", | ||
" # store and display the status once in ten iterations\n", | ||
" if itr % 10 == 0:\n", | ||
" cost_history.append((itr, cost_val))\n", | ||
" print(\"itr: {}, cost: {:.3f}\".format(itr, cost_val))\n", | ||
" \n", | ||
" # validation\n", | ||
" if itr % 50 == 0:\n", | ||
" acc = evaluation(sess, valid_x, valid_y)\n", | ||
" valid_history.append((itr, acc))\n", | ||
" print(\"validation accuracy: {:.3f}\".format(acc))\n", | ||
" print(\"\")\n", | ||
" saver.save(sess, \"cars_models/model.{}\".format(itr))\n", | ||
" \n", | ||
" itr += 1\n", | ||
" \n", | ||
" # test\n", | ||
" sorted_valid_history = sorted(valid_history, key=lambda x: x[1], reverse=True)\n", | ||
" best_itr = sorted_valid_history[0][0]\n", | ||
" best_acc = sorted_valid_history[0][1]\n", | ||
" print(\"best validation performance was {:.3f} at iter {}\".format(best_acc, best_itr))\n", | ||
" \n", | ||
" # restore the model and evaluate it with the test data\n", | ||
" saver.restore(sess, \"cars_models/model.{}\".format(best_itr))\n", | ||
" test_accuracy = evaluation(sess, test_x, test_y)\n", | ||
" print(\"test accuracy: {:.3f}\".format(test_accuracy))\n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# visualize the result\n", | ||
"\n", | ||
"plt.figure(figsize=(10, 5))\n", | ||
"ax1 = plt.subplot(211)\n", | ||
"ax2 = plt.subplot(212)\n", | ||
"\n", | ||
"plt.subplots_adjust(hspace=.4)\n", | ||
"\n", | ||
"train_x_val = [t[0] for t in cost_history]\n", | ||
"train_y_val = [t[1] for t in cost_history]\n", | ||
"valid_x_val = [t[0] for t in valid_history]\n", | ||
"valid_y_val = [t[1] for t in valid_history]\n", | ||
"\n", | ||
"ax1.plot(train_x_val, train_y_val, color=\"green\")\n", | ||
"ax2.plot(valid_x_val, valid_y_val)\n", | ||
"ax2.hlines(test_accuracy, xmin=train_x_val[0], xmax=train_x_val[-1], color=\"red\", linestyles=\"--\")\n", | ||
"ax2.vlines(best_itr, ymin=np.min(valid_y_val), ymax=np.max(valid_y_val), color=\"blue\", linestyles=\":\")\n", | ||
"\n", | ||
"plt.title(\"learning curve for the cars training\")\n", | ||
"ax1.set_title(\"cost\")\n", | ||
"ax2.set_title(\"accuracy\")\n", | ||
"ax2.set_xlabel(\"# of iterations\");" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |