Skip to content

Commit

Permalink
add tf_learn_cars
Browse files Browse the repository at this point in the history
  • Loading branch information
Toru Shimizu committed May 31, 2018
1 parent 10afa95 commit da9b292
Showing 1 changed file with 281 additions and 0 deletions.
281 changes: 281 additions & 0 deletions tf_learn_cars.ipynb
@@ -0,0 +1,281 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pickle\n",
"import matplotlib.pyplot as plt\n",
"import glob\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# data import\n",
"data_dir = \"cars_data/\"\n",
"train_x = np.array(pickle.load(open(data_dir + \"train_feature.pkl\", \"br\")), dtype=\"float32\")\n",
"train_y = np.array(pickle.load(open(data_dir + \"train_label.pkl\", \"br\")), dtype=\"float32\")\n",
"valid_x = np.array(pickle.load(open(data_dir + \"valid_feature.pkl\", \"br\")), dtype=\"float32\")\n",
"valid_y = np.array(pickle.load(open(data_dir + \"valid_label.pkl\", \"br\")), dtype=\"float32\")\n",
"test_x = np.array(pickle.load(open(data_dir + \"test_feature.pkl\", \"br\")), dtype=\"float32\")\n",
"test_y = np.array(pickle.load(open(data_dir + \"test_label.pkl\", \"br\")), dtype=\"float32\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_x[:10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_y[:10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"len(train_x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"len(valid_x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"len(test_x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# define placeholder tensors which work as \"entrances\" of the data\n",
"x1 = tf.placeholder(dtype=tf.float32, shape=[None, None])\n",
"y = tf.placeholder(dtype=tf.int32, shape=[None,])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# define parameters for the 1st layer\n",
"W1 = tf.Variable(tf.random_normal([6, 20], stddev=0.05), dtype=tf.float32)\n",
"b1 = tf.Variable(tf.zeros((1, 20)), dtype=tf.float32)\n",
"# tensor calculation for the 1st layer\n",
"a2 = tf.matmul(x1, W1) + b1\n",
"x2 = tf.sigmoid(a2)\n",
"\n",
"# define parameters for the 2nd layer\n",
"W2 = tf.Variable(tf.random_normal([20, 8], stddev=0.05), dtype=tf.float32)\n",
"b2 = tf.Variable(tf.zeros((1, 8)), dtype=tf.float32)\n",
"# tensor calculation for the 2nd layer\n",
"a3 = tf.matmul(x2, W2) + b2\n",
"x3 = tf.sigmoid(a3)\n",
"\n",
"# the 3rd layer\n",
"W3 = tf.Variable(tf.random_normal([8, 4], stddev=0.05), dtype=tf.float32)\n",
"b3 = tf.Variable(tf.zeros((1, 4)), dtype=tf.float32)\n",
"logits = tf.matmul(x3, W3) + b3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the average cost for records in the batch:\n",
"# tf calculates it with the logits and y labels\n",
"cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)\n",
"cost = tf.reduce_mean(cross_entropy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# define the train op\n",
"alpha = 1.\n",
"train_op = tf.train.GradientDescentOptimizer(alpha).minimize(cost)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# prepare a dir to store model parameters\n",
"if not os.path.isdir(\"cars_models\"):\n",
" ! mkdir cars_models\n",
"saver = tf.train.Saver(max_to_keep=None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def evaluation(sess, d_x, d_y):\n",
" raw_results = sess.run(logits, feed_dict={x1: d_x})\n",
" results = np.argmax(raw_results, axis=1)\n",
" match_count = np.sum(d_y == results)\n",
" accuracy = float(match_count) / float(d_y.size)\n",
" \n",
" return accuracy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# training\n",
"\n",
"total_epochs = 2000\n",
"cost_history = list()\n",
"valid_history = list()\n",
"best_iter = 0\n",
"test_accuracy = 0.\n",
"\n",
"# clear the past model dir\n",
"if len(glob.glob(\"cars_models/model.*\")) > 0:\n",
" ! rm cars_models/model.*\n",
" \n",
"# setup a session\n",
"with tf.Session() as sess:\n",
" sess.run(tf.global_variables_initializer())\n",
" \n",
" for itr in range(1, total_epochs + 1):\n",
" fd = {x1: train_x, y: train_y}\n",
" cost_val, t = sess.run((cost, train_op), feed_dict=fd)\n",
" \n",
" # store and display the status once in ten iterations\n",
" if itr % 10 == 0:\n",
" cost_history.append((itr, cost_val))\n",
" print(\"itr: {}, cost: {:.3f}\".format(itr, cost_val))\n",
" \n",
" # validation\n",
" if itr % 50 == 0:\n",
" acc = evaluation(sess, valid_x, valid_y)\n",
" valid_history.append((itr, acc))\n",
" print(\"validation accuracy: {:.3f}\".format(acc))\n",
" print(\"\")\n",
" saver.save(sess, \"cars_models/model.{}\".format(itr))\n",
" \n",
" itr += 1\n",
" \n",
" # test\n",
" sorted_valid_history = sorted(valid_history, key=lambda x: x[1], reverse=True)\n",
" best_itr = sorted_valid_history[0][0]\n",
" best_acc = sorted_valid_history[0][1]\n",
" print(\"best validation performance was {:.3f} at iter {}\".format(best_acc, best_itr))\n",
" \n",
" # restore the model and evaluate it with the test data\n",
" saver.restore(sess, \"cars_models/model.{}\".format(best_itr))\n",
" test_accuracy = evaluation(sess, test_x, test_y)\n",
" print(\"test accuracy: {:.3f}\".format(test_accuracy))\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# visualize the result\n",
"\n",
"plt.figure(figsize=(10, 5))\n",
"ax1 = plt.subplot(211)\n",
"ax2 = plt.subplot(212)\n",
"\n",
"plt.subplots_adjust(hspace=.4)\n",
"\n",
"train_x_val = [t[0] for t in cost_history]\n",
"train_y_val = [t[1] for t in cost_history]\n",
"valid_x_val = [t[0] for t in valid_history]\n",
"valid_y_val = [t[1] for t in valid_history]\n",
"\n",
"ax1.plot(train_x_val, train_y_val, color=\"green\")\n",
"ax2.plot(valid_x_val, valid_y_val)\n",
"ax2.hlines(test_accuracy, xmin=train_x_val[0], xmax=train_x_val[-1], color=\"red\", linestyles=\"--\")\n",
"ax2.vlines(best_itr, ymin=np.min(valid_y_val), ymax=np.max(valid_y_val), color=\"blue\", linestyles=\":\")\n",
"\n",
"plt.title(\"learning curve for the cars training\")\n",
"ax1.set_title(\"cost\")\n",
"ax2.set_title(\"accuracy\")\n",
"ax2.set_xlabel(\"# of iterations\");"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit da9b292

Please sign in to comment.