From f197ecc4c1940e96001d6c8882484c76ca9f140a Mon Sep 17 00:00:00 2001
From: Lee Yang
Date: Thu, 29 Mar 2018 10:10:44 -0700
Subject: [PATCH] mnist estimator example

---
 examples/mnist/estimator/README.md          |  42 +++++
 examples/mnist/estimator/mnist_estimator.py | 186 ++++++++++++++++++++
 2 files changed, 228 insertions(+)
 create mode 100644 examples/mnist/estimator/README.md
 create mode 100644 examples/mnist/estimator/mnist_estimator.py

diff --git a/examples/mnist/estimator/README.md b/examples/mnist/estimator/README.md
new file mode 100644
index 00000000..a35a7051
--- /dev/null
+++ b/examples/mnist/estimator/README.md
@@ -0,0 +1,42 @@
+# MNIST using tf.estimator with tf.layers
+
+Original Source: https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/examples/tutorials/layers/cnn_mnist.py
+
+This is the `tf.estimator` version of MNIST from TensorFlow's [tutorial on layers and estimators](https://www.tensorflow.org/versions/master/tutorials/layers), adapted for TensorFlowOnSpark.
+
+Notes:
+- This example assumes that Spark, TensorFlow, and TensorFlowOnSpark are already installed.
+- To minimize code changes, this example uses InputMode.TENSORFLOW.
+
+#### Launch the Spark Standalone cluster
+
+    export MASTER=spark://$(hostname):7077
+    export SPARK_WORKER_INSTANCES=3
+    export CORES_PER_WORKER=1
+    export TOTAL_CORES=$((${CORES_PER_WORKER}*${SPARK_WORKER_INSTANCES}))
+    export TFoS_HOME=
+
+    ${SPARK_HOME}/sbin/start-master.sh; ${SPARK_HOME}/sbin/start-slave.sh -c $CORES_PER_WORKER -m 3G ${MASTER}
+
+#### Run MNIST using InputMode.TENSORFLOW
+
+In this mode, each worker loads the entire MNIST dataset into memory (downloading the dataset automatically if needed).
+
+    # remove any old artifacts
+    rm -rf ${TFoS_HOME}/mnist_model
+
+    # train and validate
+    ${SPARK_HOME}/bin/spark-submit \
+    --master ${MASTER} \
+    --conf spark.cores.max=${TOTAL_CORES} \
+    --conf spark.task.cpus=${CORES_PER_WORKER} \
+    --conf spark.task.maxFailures=1 \
+    --conf spark.executorEnv.JAVA_HOME="$JAVA_HOME" \
+    ${TFoS_HOME}/examples/mnist/estimator/mnist_estimator.py \
+    --cluster_size ${SPARK_WORKER_INSTANCES} \
+    --model ${TFoS_HOME}/mnist_model
+
+#### Shut down the Spark Standalone cluster
+
+    ${SPARK_HOME}/sbin/stop-slave.sh; ${SPARK_HOME}/sbin/stop-master.sh

diff --git a/examples/mnist/estimator/mnist_estimator.py b/examples/mnist/estimator/mnist_estimator.py
new file mode 100644
index 00000000..9aef2d22
--- /dev/null
+++ b/examples/mnist/estimator/mnist_estimator.py
@@ -0,0 +1,186 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
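+#
+# Adapted for TensorFlowOnSpark: the model function and input functions below
+# are essentially unchanged from the original single-node tutorial; the main
+# difference is the launcher block under __main__, which runs this code on
+# the Spark executors via TFCluster.run() instead of tf.app.run().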
+"""Convolutional Neural Network Estimator for MNIST, built with tf.layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +tf.logging.set_verbosity(tf.logging.INFO) + + +def cnn_model_fn(features, labels, mode): + """Model function for CNN.""" + # Input Layer + # Reshape X to 4-D tensor: [batch_size, width, height, channels] + # MNIST images are 28x28 pixels, and have one color channel + input_layer = tf.reshape(features["x"], [-1, 28, 28, 1]) + + # Convolutional Layer #1 + # Computes 32 features using a 5x5 filter with ReLU activation. + # Padding is added to preserve width and height. + # Input Tensor Shape: [batch_size, 28, 28, 1] + # Output Tensor Shape: [batch_size, 28, 28, 32] + conv1 = tf.layers.conv2d( + inputs=input_layer, + filters=32, + kernel_size=[5, 5], + padding="same", + activation=tf.nn.relu) + + # Pooling Layer #1 + # First max pooling layer with a 2x2 filter and stride of 2 + # Input Tensor Shape: [batch_size, 28, 28, 32] + # Output Tensor Shape: [batch_size, 14, 14, 32] + pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2) + + # Convolutional Layer #2 + # Computes 64 features using a 5x5 filter. + # Padding is added to preserve width and height. + # Input Tensor Shape: [batch_size, 14, 14, 32] + # Output Tensor Shape: [batch_size, 14, 14, 64] + conv2 = tf.layers.conv2d( + inputs=pool1, + filters=64, + kernel_size=[5, 5], + padding="same", + activation=tf.nn.relu) + + # Pooling Layer #2 + # Second max pooling layer with a 2x2 filter and stride of 2 + # Input Tensor Shape: [batch_size, 14, 14, 64] + # Output Tensor Shape: [batch_size, 7, 7, 64] + pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2) + + # Flatten tensor into a batch of vectors + # Input Tensor Shape: [batch_size, 7, 7, 64] + # Output Tensor Shape: [batch_size, 7 * 7 * 64] + pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64]) + + # Dense Layer + # Densely connected layer with 1024 neurons + # Input Tensor Shape: [batch_size, 7 * 7 * 64] + # Output Tensor Shape: [batch_size, 1024] + dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu) + + # Add dropout operation; 0.6 probability that element will be kept + dropout = tf.layers.dropout( + inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN) + + # Logits layer + # Input Tensor Shape: [batch_size, 1024] + # Output Tensor Shape: [batch_size, 10] + logits = tf.layers.dense(inputs=dropout, units=10) + + predictions = { + # Generate predictions (for PREDICT and EVAL mode) + "classes": tf.argmax(input=logits, axis=1), + # Add `softmax_tensor` to the graph. It is used for PREDICT and by the + # `logging_hook`. 
+      "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
+  }
+  if mode == tf.estimator.ModeKeys.PREDICT:
+    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
+
+  # Calculate Loss (for both TRAIN and EVAL modes)
+  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
+
+  # Configure the Training Op (for TRAIN mode)
+  if mode == tf.estimator.ModeKeys.TRAIN:
+    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
+    train_op = optimizer.minimize(
+        loss=loss,
+        global_step=tf.train.get_global_step())
+    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
+
+  # Add evaluation metrics (for EVAL mode)
+  eval_metric_ops = {
+      "accuracy": tf.metrics.accuracy(
+          labels=labels, predictions=predictions["classes"])}
+  return tf.estimator.EstimatorSpec(
+      mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
+
+
+def main(args, ctx):
+  # `args` is the argparse namespace; `ctx` is the TensorFlowOnSpark node
+  # context for this executor, unused here since InputMode.TENSORFLOW defers
+  # the distributed training logic to tf.estimator.
+  # Load training and eval data
+  mnist = tf.contrib.learn.datasets.load_dataset("mnist")
+  train_data = mnist.train.images  # Returns np.array
+  train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
+  eval_data = mnist.test.images  # Returns np.array
+  eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)
+
+  # Create the Estimator
+  mnist_classifier = tf.estimator.Estimator(
+      model_fn=cnn_model_fn, model_dir=args.model)
+
+  # Set up logging for predictions
+  # Log the values in the "softmax_tensor" tensor with label "probabilities"
+  tensors_to_log = {"probabilities": "softmax_tensor"}
+  logging_hook = tf.train.LoggingTensorHook(
+      tensors=tensors_to_log, every_n_iter=50)
+
+  # Train the model
+  train_input_fn = tf.estimator.inputs.numpy_input_fn(
+      x={"x": train_data},
+      y=train_labels,
+      batch_size=args.batch_size,
+      num_epochs=None,
+      shuffle=True)
+  # The original tutorial invoked training directly, as shown below; this
+  # example uses tf.estimator.train_and_evaluate instead (see the end of this
+  # function), which also supports distributed execution.
+  # mnist_classifier.train(
+  #     input_fn=train_input_fn,
+  #     steps=1000,
+  #     hooks=[logging_hook])
+
+  # Evaluate the model and print results
+  eval_input_fn = tf.estimator.inputs.numpy_input_fn(
+      x={"x": eval_data},
+      y=eval_labels,
+      num_epochs=1,
+      shuffle=False)
+  # Likewise, evaluation is now handled by train_and_evaluate:
+  # eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
+  # print(eval_results)
+
+  # Using tf.estimator.train_and_evaluate
+  train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=args.steps, hooks=[logging_hook])
+  eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
+  tf.estimator.train_and_evaluate(mnist_classifier, train_spec, eval_spec)
+
+
+if __name__ == "__main__":
+  # tf.app.run() was the original single-node entry point; it is replaced
+  # here by the TensorFlowOnSpark launcher below.
+  # tf.app.run()
+
+  from pyspark.context import SparkContext
+  from pyspark.conf import SparkConf
+  from tensorflowonspark import TFCluster
+  import argparse
+
+  sc = SparkContext(conf=SparkConf().setAppName("mnist_spark"))
+  executors = sc._conf.get("spark.executor.instances")
+  num_executors = int(executors) if executors is not None else 1
+
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
+  parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
+  parser.add_argument("--model", help="HDFS path to save/load model during train/inference", default="mnist_model")
+  parser.add_argument("--output", help="HDFS path to save test/inference output (not used by this example)", default="predictions")
+  parser.add_argument("--num_ps", help="number of PS nodes in cluster", type=int, default=1)
+  parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
+  args = parser.parse_args()
+  print("args:", args)
+
+  # Launch a TensorFlow cluster on the Spark executors; master_node='master'
+  # designates a 'master' task for tf.estimator's distributed mode.
+  cluster = TFCluster.run(sc, main, args, args.cluster_size, args.num_ps, tensorboard=False, input_mode=TFCluster.InputMode.TENSORFLOW, log_dir=args.model, master_node='master')
+  cluster.shutdown()