In [2]:
import tensorflow as tf
import numpy as np

In [3]:
def reset_graph(seed=42):
    """Reset the default TensorFlow graph and seed the random generators.

    Args:
        seed: Integer passed to both ``tf.set_random_seed`` and
            ``np.random.seed``. Defaults to 42.
    """
    # Reset first, then seed: the seed is set after the old default graph
    # has been discarded.
    tf.reset_default_graph()
    tf.set_random_seed(seed)
    np.random.seed(seed)

In [4]:
# Fetch MNIST as (train, test) pairs of image and label arrays.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

# Flatten every image into a 28*28-element float32 row and divide by 255.0.
X_train = X_train.reshape(-1, 28 * 28).astype(np.float32) / 255.0
X_test = X_test.reshape(-1, 28 * 28).astype(np.float32) / 255.0

# Labels are cast to int32.
y_train, y_test = y_train.astype(np.int32), y_test.astype(np.int32)

# Hold out the first 5000 training examples as the validation set; the
# remainder stays as the training set.
X_valid, X_train = np.split(X_train, [5000])
y_valid, y_train = np.split(y_train, [5000])

## TensorFlowの高水準APIを使用したMLPの訓練

In [5]:
# A single numeric feature column named "X" with 28*28 values per example.
feature_cols = [tf.feature_column.numeric_column("X", shape=[28 * 28])]

# Canned DNN classifier: hidden layers of 300 and 100 units, 10 classes.
dnn_clf = tf.estimator.DNNClassifier(
    feature_columns=feature_cols,
    hidden_units=[300, 100],
    n_classes=10,
)

# Feed the training arrays in shuffled batches of 50 for 40 epochs.
input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"X": X_train},
    y=y_train,
    batch_size=50,
    num_epochs=40,
    shuffle=True,
)
dnn_clf.train(input_fn=input_fn)

W0828 17:42:18.963798 4543530432 estimator.py:1811] Using temporary folder as model directory: /var/folders/jk/wsby4q_94gz992ztwjsvztq00000gn/T/tmp7ulywb35
W0828 17:42:19.001497 4543530432 deprecation.py:323] From /Users/shimizukousuke/.pyenv/versions/3.6.5/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
W0828 17:42:19.014670 4543530432 deprecation.py:323] From /Users/shimizukousuke/.pyenv/versions/3.6.5/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:62: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines

<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x10f665d30>

In [6]:
# Evaluate the trained classifier on the test arrays, without shuffling.
test_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"X": X_test},
    y=y_test,
    shuffle=False,
)
eval_results = dnn_clf.evaluate(input_fn=test_input_fn)

W0828 17:47:29.535653 4543530432 deprecation.py:323] From /Users/shimizukousuke/.pyenv/versions/3.6.5/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.


In [7]:
# Display the metrics returned by dnn_clf.evaluate() as the cell output.
eval_results

{'accuracy': 0.9779,
 'average_loss': 0.11568203,
 'loss': 14.643295,
 'global_step': 44000}

## 標準のTensorFlowを使用したDNNの訓練
とりあえず、ミニバッチ勾配降下法でMNISTのデータセットを対象に訓練を行う

### 構築フェーズ

In [19]:
# Layer sizes of the network built in the following cells.
n_inputs = 28 * 28 # MNIST: number of input features per example
# Number of neurons in the first and second hidden layers.
n_hidden1 = 300
n_hidden2 = 100
# Number of neurons in the output layer.
n_out_puts = 10

In [20]:
# Graph inputs that are fed at run time. The leading None leaves the batch
# dimension unconstrained.
X = tf.placeholder(tf.float32, shape=(None, n_inputs), name='X')
# `(None)` is just None, i.e. no shape constraint at all. The one-element
# tuple `(None,)` declares the intended 1-D vector of labels.
y = tf.placeholder(tf.int64, shape=(None,), name='y')

In [24]:
def neuron_layer(X, n_neurons, name, activation=None):
    """Build one fully connected layer inside its own name scope.

    Args:
        X: 2-D input tensor; the size of its second dimension is used as
            the number of inputs per neuron.
        n_neurons: Number of neurons (output columns) of the layer.
        name: Name scope under which the layer's ops are created.
        activation: Optional callable applied to the affine output. When
            None, the affine output is returned unchanged.

    Returns:
        ``activation(X @ kernel + bias)``, or ``X @ kernel + bias`` when no
        activation is given.
    """
    with tf.name_scope(name):
        fan_in = int(X.get_shape()[1])
        # Kernel starts from a truncated normal whose standard deviation is
        # 2 / sqrt(fan_in + n_neurons); the bias starts at zero.
        # The kernel is created before the bias on purpose, keeping the
        # order in which ops are added to the graph.
        kernel = tf.Variable(
            tf.truncated_normal((fan_in, n_neurons),
                                stddev=2 / np.sqrt(fan_in + n_neurons)),
            name='kernel')
        bias = tf.Variable(tf.zeros([n_neurons]), name='bias')
        affine = tf.matmul(X, kernel) + bias
        if activation is None:
            return affine
        return activation(affine)

In [25]:
# Stack the layers under a shared 'dnn' name scope: two hidden layers with
# ReLU activation, then an output layer with no activation for the logits.
with tf.name_scope('dnn'):
    hidden1 = neuron_layer(X, n_hidden1, 'hidden1', tf.nn.relu)
    hidden2 = neuron_layer(hidden1, n_hidden2, 'hidden2', tf.nn.relu)
    logits = neuron_layer(hidden2, n_out_puts, 'outputs')

In [26]:
# There is no need to hand-write neuron_layer: TensorFlow already provides tf.layers.dense.
# with tf.name_scope('dnn'):
#     hidden1 = tf.layers.dense(X, n_hidden1, name='hidden1', activation=tf.nn.relu)
#     hidden2 = tf.layers.dense(hidden1, n_hidden2, name='hidden2', activation=tf.nn.relu)
#     logits = tf.layers.dense(hidden2, n_out_puts, name='outputs')