diff --git a/examples/mnist/tf/README.md b/examples/mnist/tf/README.md
index 39833696..3500eafe 100644
--- a/examples/mnist/tf/README.md
+++ b/examples/mnist/tf/README.md
@@ -9,39 +9,6 @@
 # hdfs dfs -rm -r mnist_model
 # hdfs dfs -rm -r predictions
 
-${SPARK_HOME}/bin/spark-submit \
---master yarn \
---deploy-mode cluster \
---queue ${QUEUE} \
---num-executors 4 \
---executor-memory 27G \
---py-files TensorFlowOnSpark/tfspark.zip,TensorFlowOnSpark/examples/mnist/tf/mnist_dist_dataset.py \
---conf spark.dynamicAllocation.enabled=false \
---conf spark.yarn.maxAppAttempts=1 \
---archives hdfs:///user/${USER}/Python.zip#Python \
---conf spark.executorEnv.LD_LIBRARY_PATH=$LIB_CUDA:$LIB_JVM:$LIB_HDFS \
---driver-library-path=$LIB_CUDA \
-TensorFlowOnSpark/examples/mnist/tf/mnist_spark_dataset.py \
-${TF_ROOT}/${TF_VERSION}/examples/mnist/tf/mnist_spark_dataset.py \
---images_labels mnist/csv2/train \
---format csv2 \
---mode train \
---model mnist_model
-
-# to use inference mode, change `--mode train` to `--mode inference` and add `--output predictions`
-# one item in csv2 format is `image | label`, to use input data in TFRecord format, change `--format csv` to `--format tfr`
-# to use infiniband, add `--rdma`
-```
-
-### _using QueueRunners_
-```bash
-# for CPU mode:
-# export QUEUE=default
-# remove references to $LIB_CUDA
-
-# hdfs dfs -rm -r mnist_model
-# hdfs dfs -rm -r predictions
-
 ${SPARK_HOME}/bin/spark-submit \
 --master yarn \
 --deploy-mode cluster \
@@ -55,16 +22,14 @@ ${SPARK_HOME}/bin/spark-submit \
 --conf spark.executorEnv.LD_LIBRARY_PATH=$LIB_CUDA:$LIB_JVM:$LIB_HDFS \
 --driver-library-path=$LIB_CUDA \
 TensorFlowOnSpark/examples/mnist/tf/mnist_spark.py \
---images mnist/tfr/train/images \
---labels mnist/tfr/train/labels \
---format csv \
+--images_labels mnist/csv2/train \
+--format csv2 \
 --mode train \
 --model mnist_model
 
 # to use inference mode, change `--mode train` to `--mode inference` and add `--output predictions`
-# to use input data in TFRecord format, change `--format csv` to `--format tfr`
+# each record in csv2 format is `image | label`; to use input data in TFRecord format, change `--format csv2` to `--format tfr`
 # to use infiniband, add `--rdma`
-```
 
 ### _using Spark ML Pipeline_
 ```bash
@@ -83,7 +48,7 @@ ${SPARK_HOME}/bin/spark-submit \
 --queue ${QUEUE} \
 --num-executors 4 \
 --executor-memory 27G \
---jars hdfs:///user/${USER}/tensorflow-hadoop-1.0-SNAPSHOT.jar \
+--jars hdfs:///user/${USER}/tensorflow-hadoop-1.0-SNAPSHOT.jar \
 --py-files TensorFlowOnSpark/tfspark.zip,TensorFlowOnSpark/examples/mnist/tf/mnist_dist_pipeline.py \
 --conf spark.dynamicAllocation.enabled=false \
 --conf spark.yarn.maxAppAttempts=1 \
@@ -102,6 +67,6 @@ TensorFlowOnSpark/examples/mnist/tf/mnist_spark_pipeline.py \
 --inference_output predictions
 
 # to use input data in TFRecord format, change `--format csv` to `--format tfr`
-# tensorflow-hadoop-1.0-SNAPSHOT.jar is needed for transforming csv input to TFRecord
+# tensorflow-hadoop-1.0-SNAPSHOT.jar is needed for transforming csv input to TFRecord
 # `--tfrecord_dir` is needed for temporarily saving dataframe to TFRecord on hdfs
 ```
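For readers unfamiliar with the `csv2` layout mentioned in the README hunk above, here is a minimal sketch of a parser for one record. Only the `image | label` separator comes from the README; the comma-separated pixel and one-hot label encodings are assumptions for illustration, not the PR's code.

```python
# Hypothetical parser for one csv2 record ("image | label"), per the README note.
def parse_csv2_record(line):
    image_part, label_part = line.split('|')
    image = [float(p) for p in image_part.strip().split(',')]  # assumed: 784 pixel values
    label = [float(v) for v in label_part.strip().split(',')]  # assumed: one-hot, 10 classes
    return image, label

# usage (hypothetical record): image, label = parse_csv2_record("0,0,...,255|0,0,0,0,0,1,0,0,0,0")
```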
diff --git a/examples/mnist/tf/mnist_dist.py b/examples/mnist/tf/mnist_dist.py
index 2b1bdb1f..cc95e93d 100644
--- a/examples/mnist/tf/mnist_dist.py
+++ b/examples/mnist/tf/mnist_dist.py
@@ -16,6 +16,7 @@ def print_log(worker_num, arg):
 
 def map_fun(args, ctx):
   from datetime import datetime
+  from tensorflowonspark import TFNode
   import math
   import os
   import tensorflow as tf
@@ -54,6 +55,27 @@ def _parse_tfr(example_proto):
     label = tf.to_float(features['label'])
     return (image, label)
 
+  def build_model(graph, x):
+    with graph.as_default():
+      # Variables of the hidden layer
+      hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
+                                              stddev=1.0 / IMAGE_PIXELS), name="hid_w")
+      hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
+      tf.summary.histogram("hidden_weights", hid_w)
+
+      # Variables of the softmax layer
+      sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
+                                             stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
+      sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
+      tf.summary.histogram("softmax_weights", sm_w)
+
+      hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
+      hid = tf.nn.relu(hid_lin)
+
+      y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
+      prediction = tf.argmax(y, 1, name="prediction")
+      return y, prediction
+
   if job_name == "ps":
     server.join()
   elif job_name == "worker":
@@ -78,26 +100,13 @@
       iterator = ds.make_one_shot_iterator()
       x, y_ = iterator.get_next()
 
-      # Variables of the hidden layer
-      hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
-                                              stddev=1.0 / IMAGE_PIXELS), name="hid_w")
-      hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
-      tf.summary.histogram("hidden_weights", hid_w)
-
-      # Variables of the softmax layer
-      sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
-                                             stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
-      sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
-      tf.summary.histogram("softmax_weights", sm_w)
+      # Build core model
+      y, prediction = build_model(tf.get_default_graph(), x)
 
+      # Add training bits
       x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
       tf.summary.image("x_img", x_img)
-      hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
-      hid = tf.nn.relu(hid_lin)
-
-      y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
-
       global_step = tf.train.get_or_create_global_step()
 
       loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
 
@@ -105,9 +114,7 @@
       train_op = tf.train.AdagradOptimizer(0.01).minimize(
           loss, global_step=global_step)
 
-      # Test trained model
       label = tf.argmax(y_, 1, name="label")
-      prediction = tf.argmax(y, 1, name="prediction")
       correct_prediction = tf.equal(prediction, label)
       accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
       tf.summary.scalar("acc", accuracy)
@@ -117,8 +124,10 @@
       init_op = tf.global_variables_initializer()
 
       # Create a "supervisor", which oversees the training process and stores model state into HDFS
-      logdir = ctx.absolute_path(args.model)
-      print("tensorflow model path: {0}".format(logdir))
+      model_dir = ctx.absolute_path(args.model)
+      export_dir = ctx.absolute_path(args.export)
+      print("tensorflow model path: {0}".format(model_dir))
+      print("tensorflow export path: {0}".format(export_dir))
       summary_writer = tf.summary.FileWriter("tensorboard_%d" % worker_num, graph=tf.get_default_graph())
 
     if args.mode == 'inference':
@@ -130,7 +139,7 @@
     with tf.train.MonitoredTrainingSession(master=server.target,
                                            is_chief=(task_index == 0),
                                            scaffold=tf.train.Scaffold(init_op=init_op, summary_op=summary_op, saver=saver),
-                                           checkpoint_dir=logdir,
+                                           checkpoint_dir=model_dir,
                                            hooks=[tf.train.StopAtStepHook(last_step=args.steps)]) as sess:
       print("{} session ready".format(datetime.now().isoformat()))
 
@@ -163,6 +172,41 @@
 
       print("{} stopping MonitoredTrainingSession".format(datetime.now().isoformat()))
 
+    # export model (on chief worker only)
+    if args.mode == "train" and task_index == 0:
+      tf.reset_default_graph()
+
+      # add placeholders for input images (and optional labels)
+      x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name='x')
+      y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
+      label = tf.argmax(y_, 1, name="label")
+
+      # add core model
+      y, prediction = build_model(tf.get_default_graph(), x)
+
+      # restore from last checkpoint
+      saver = tf.train.Saver()
+      with tf.Session() as sess:
+        ckpt = tf.train.get_checkpoint_state(model_dir)
+        print("ckpt: {}".format(ckpt))
+        assert ckpt, "Invalid model checkpoint path: {}".format(model_dir)
+        saver.restore(sess, ckpt.model_checkpoint_path)
+
+        print("Exporting saved_model to: {}".format(export_dir))
+        # exported signatures defined in code
+        signatures = {
+          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: {
+            'inputs': { 'image': x },
+            'outputs': { 'prediction': prediction },
+            'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME
+          }
+        }
+        TFNode.export_saved_model(sess,
+                                  export_dir,
+                                  tf.saved_model.tag_constants.SERVING,
+                                  signatures)
+        print("Exported saved_model")
+
     # WORKAROUND for https://github.com/tensorflow/tensorflow/issues/21745
     # wait for all other nodes to complete (via done files)
     done_dir = "{}/{}/done".format(ctx.absolute_path(args.model), args.mode)
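The refactoring above exists so the chief worker can rebuild the identical graph from `build_model()` and export it after training. For context, here is a sketch of an equivalent export flow using only stock TF 1.x `tf.saved_model` APIs instead of the `TFNode.export_saved_model` helper; `build_fn` stands in for `build_model`, and the input shape follows the MNIST example.

```python
import tensorflow as tf  # TF 1.x APIs, matching the diff

def export_from_checkpoint(build_fn, model_dir, export_dir):
    # rebuild the inference graph from scratch, as the chief worker does above
    tf.reset_default_graph()
    x = tf.placeholder(tf.float32, [None, 28 * 28], name='x')
    y, prediction = build_fn(tf.get_default_graph(), x)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # restore weights from the last training checkpoint
        ckpt = tf.train.get_checkpoint_state(model_dir)
        saver.restore(sess, ckpt.model_checkpoint_path)

        # write a SavedModel with a standard predict signature
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        signature = tf.saved_model.signature_def_utils.predict_signature_def(
            inputs={'image': x}, outputs={'prediction': prediction})
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
        builder.save()
```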
diff --git a/examples/mnist/tf/mnist_inference.py b/examples/mnist/tf/mnist_inference.py
new file mode 100644
index 00000000..2ba508d6
--- /dev/null
+++ b/examples/mnist/tf/mnist_inference.py
@@ -0,0 +1,106 @@
+# Copyright 2018 Yahoo Inc.
+# Licensed under the terms of the Apache 2.0 license.
+# Please see LICENSE file in the project root for terms.
+
+# This example demonstrates how to leverage Spark for parallel inferencing from a SavedModel.
+#
+# Normally, you can use TensorFlowOnSpark to just form a TensorFlow cluster for training and inferencing.
+# However, in some situations, you may have a SavedModel without the original code for defining the inferencing
+# graph. In these situations, we can use Spark to instantiate a single-node TensorFlow instance on each executor,
+# where each executor can independently load the model and inference on input data.
+#
+# Note: this particular example demonstrates use of `tf.data.Dataset` to read the input data for inferencing,
+# but it could also be adapted to just use an RDD of TFRecords from Spark.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import logging
+import sys
+import tensorflow as tf
+import time
+import traceback
+
+IMAGE_PIXELS = 28
+
+
+def inference(it, num_workers, args):
+  from tensorflowonspark import util
+
+  # consume worker number from RDD partition iterator
+  for i in it:
+    worker_num = i
+  print("worker_num: {}".format(i))
+
+  # setup env for single-node TF
+  util.single_node_env()
+
+  # load saved_model using default tag and signature
+  sess = tf.Session()
+  tf.saved_model.loader.load(sess, ['serve'], args.export)
+
+  # parse function for TFRecords
+  def parse_tfr(example_proto):
+    feature_def = {"label": tf.FixedLenFeature(10, tf.int64),
+                   "image": tf.FixedLenFeature(IMAGE_PIXELS * IMAGE_PIXELS, tf.int64)}
+    features = tf.parse_single_example(example_proto, feature_def)
+    norm = tf.constant(255, dtype=tf.float32, shape=(784,))
+    image = tf.div(tf.to_float(features['image']), norm)
+    label = tf.to_float(features['label'])
+    return (image, label)
+
+  # define a new tf.data.Dataset (for inferencing)
+  ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels))
+  ds = ds.shard(num_workers, worker_num)
+  ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=1)
+  ds = ds.map(parse_tfr).batch(10)
+  iterator = ds.make_one_shot_iterator()
+  image_label = iterator.get_next(name='inf_image')
+
+  # create an output file per spark worker for the predictions
+  tf.gfile.MakeDirs(args.output)
+  output_file = tf.gfile.GFile("{}/part-{:05d}".format(args.output, worker_num), mode='w')
+
+  while True:
+    try:
+      # get images and labels from tf.data.Dataset
+      img, lbl = sess.run(['inf_image:0', 'inf_image:1'])
+
+      # inference by feeding these images and labels into the input tensors
+      # you can view the exported model signatures via:
+      #   saved_model_cli show --dir mnist_export --all
+
+      # note that we feed directly into the graph tensors (bypassing the exported signatures)
+      # also note that we can feed/fetch tensors that were not explicitly exported, e.g. `y_` and `label:0`
+      labels, preds = sess.run(['label:0', 'prediction:0'], feed_dict={'x:0': img, 'y_:0': lbl})
+
+      for i in range(len(labels)):
+        output_file.write("{} {}\n".format(labels[i], preds[i]))
+    except tf.errors.OutOfRangeError:
+      break
+
+  output_file.close()
+
+
+if __name__ == '__main__':
+  import os
+  from pyspark.context import SparkContext
+  from pyspark.conf import SparkConf
+
+  sc = SparkContext(conf=SparkConf().setAppName("mnist_inference"))
+  executors = sc._conf.get("spark.executor.instances")
+  num_executors = int(executors) if executors is not None else 1
+
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--cluster_size", help="number of nodes in the cluster (for Spark Standalone)", type=int, default=num_executors)
+  parser.add_argument('--images_labels', type=str, help='Directory for input images with labels')
+  parser.add_argument("--export", help="HDFS path to export model", type=str, default="mnist_export")
+  parser.add_argument("--output", help="HDFS path to save predictions", type=str, default="predictions")
+  args, _ = parser.parse_known_args()
+  print("args: {}".format(args))
+
+  # Not using TFCluster... just running single-node TF instances on each executor
+  nodes = list(range(args.cluster_size))
+  nodeRDD = sc.parallelize(nodes, args.cluster_size)
+  nodeRDD.foreachPartition(lambda it: inference(it, args.cluster_size, args))
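As the header comment of `mnist_inference.py` notes, the same pattern can be adapted to an RDD of serialized TFRecords instead of a `tf.data.Dataset`. A rough sketch of that variant follows; it is not part of the PR. The feature names and the `label:0`/`prediction:0` tensor names match the code above, while everything else is an assumption for illustration.

```python
import tensorflow as tf  # TF 1.x, as in the example above

IMAGE_PIXELS = 28

def infer_partition(examples, export_dir):
    # one single-node TF instance per executor, as in mnist_inference.py
    sess = tf.Session()
    tf.saved_model.loader.load(sess, ['serve'], export_dir)

    # add a parse path for raw TFRecord bytes on top of the restored graph
    serialized = tf.placeholder(tf.string, shape=[None])
    feature_def = {"label": tf.FixedLenFeature(10, tf.int64),
                   "image": tf.FixedLenFeature(IMAGE_PIXELS * IMAGE_PIXELS, tf.int64)}
    features = tf.parse_example(serialized, feature_def)
    image = tf.to_float(features['image']) / 255.0
    label = tf.to_float(features['label'])

    # for simplicity, run the whole partition as one batch
    batch = list(examples)
    img, lbl = sess.run([image, label], feed_dict={serialized: batch})
    labels, preds = sess.run(['label:0', 'prediction:0'],
                             feed_dict={'x:0': img, 'y_:0': lbl})
    return zip(labels, preds)

# hypothetical usage, where tfr_rdd holds serialized tf.train.Example bytes:
# predictions = tfr_rdd.mapPartitions(lambda it: infer_partition(it, "mnist_export"))
```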
diff --git a/examples/mnist/tf/mnist_spark.py b/examples/mnist/tf/mnist_spark.py
index 2a25d6e3..83ce77fe 100644
--- a/examples/mnist/tf/mnist_spark.py
+++ b/examples/mnist/tf/mnist_spark.py
@@ -26,6 +26,7 @@
 parser.add_argument("--driver_ps_nodes", help="""run tensorflow PS node on driver locally.
     You will need to set cluster_size = num_executors + num_ps""", default=False)
 parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
+parser.add_argument("--export", help="HDFS path to export model", type=str, default="mnist_export")
 parser.add_argument("--format", help="example format: (csv2|tfr)", choices=["csv2", "tfr"], default="tfr")
 parser.add_argument("--images_labels", help="HDFS path to MNIST image_label files in parallelized format")
 parser.add_argument("--mode", help="train|inference", default="train")
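A quick, self-contained illustration of the new flag's behavior: the `add_argument` line is copied from the hunk above; the rest is just a demo harness.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--export", help="HDFS path to export model", type=str, default="mnist_export")

print(parser.parse_args([]).export)                        # -> mnist_export (default)
print(parser.parse_args(["--export", "my_export"]).export) # -> my_export
```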
diff --git a/tensorflowonspark/pipeline.py b/tensorflowonspark/pipeline.py
index e721bd6b..06ad023b 100755
--- a/tensorflowonspark/pipeline.py
+++ b/tensorflowonspark/pipeline.py
@@ -25,7 +25,7 @@
 import tensorflow as tf
 from tensorflow.contrib.saved_model.python.saved_model import reader
 from tensorflow.python.saved_model import loader
-from . import TFCluster, gpu_info, dfutil
+from . import TFCluster, gpu_info, dfutil, util
 
 import argparse
 import copy
@@ -570,32 +570,15 @@ def single_node_env(args):
   Args:
     :args: command line arguments as either argparse args or argv list
   """
+  # setup ARGV for the TF process
   if isinstance(args, list):
     sys.argv = args
   elif args.argv:
     sys.argv = args.argv
 
-  # ensure expanded CLASSPATH w/o glob characters (required for Spark 2.1 + JNI)
-  if 'HADOOP_PREFIX' in os.environ and 'TFOS_CLASSPATH_UPDATED' not in os.environ:
-    classpath = os.environ['CLASSPATH']
-    hadoop_path = os.path.join(os.environ['HADOOP_PREFIX'], 'bin', 'hadoop')
-    hadoop_classpath = subprocess.check_output([hadoop_path, 'classpath', '--glob']).decode()
-    logging.debug("CLASSPATH: {0}".format(hadoop_classpath))
-    os.environ['CLASSPATH'] = classpath + os.pathsep + hadoop_classpath
-    os.environ['TFOS_CLASSPATH_UPDATED'] = '1'
-
-  # reserve GPU, if requested
-  if tf.test.is_built_with_cuda():
-    # GPU
-    num_gpus = args.num_gpus if 'num_gpus' in args else 1
-    gpus_to_use = gpu_info.get_gpus(num_gpus)
-    logging.info("Using gpu(s): {0}".format(gpus_to_use))
-    os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use
-    # Note: if there is a GPU conflict (CUDA_ERROR_INVALID_DEVICE), the entire task will fail and retry.
-  else:
-    # CPU
-    logging.info("Using CPU")
-    os.environ['CUDA_VISIBLE_DEVICES'] = ''
+  # setup ENV for Hadoop-compatibility and/or GPU allocation
+  num_gpus = args.num_gpus if 'num_gpus' in args else 1
+  util.single_node_env(num_gpus)
 
 
 def get_meta_graph_def(saved_model_dir, tag_set):
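One subtlety worth noting in the refactored code: `'num_gpus' in args` works because `argparse.Namespace` supports containment tests, so the GPU count safely defaults to 1 when the attribute was never defined. A self-contained check:

```python
import argparse

args = argparse.Namespace(argv=None, num_gpus=2)
print('num_gpus' in args)                          # True
print(args.num_gpus if 'num_gpus' in args else 1)  # 2

bare = argparse.Namespace(argv=None)
print(bare.num_gpus if 'num_gpus' in bare else 1)  # 1 (fallback path)
```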
diff --git a/tensorflowonspark/util.py b/tensorflowonspark/util.py
index 11f1f65c..dc167226 100644
--- a/tensorflowonspark/util.py
+++ b/tensorflowonspark/util.py
@@ -7,11 +7,37 @@
 from __future__ import nested_scopes
 from __future__ import print_function
 
+import logging
 import os
 import socket
+import subprocess
 import errno
 from socket import error as socket_error
+from . import gpu_info
 
+
+def single_node_env(num_gpus=1):
+  """Setup environment variables for Hadoop compatibility and GPU allocation"""
+  import tensorflow as tf
+  # ensure expanded CLASSPATH w/o glob characters (required for Spark 2.1 + JNI)
+  if 'HADOOP_PREFIX' in os.environ and 'TFOS_CLASSPATH_UPDATED' not in os.environ:
+    classpath = os.environ['CLASSPATH']
+    hadoop_path = os.path.join(os.environ['HADOOP_PREFIX'], 'bin', 'hadoop')
+    hadoop_classpath = subprocess.check_output([hadoop_path, 'classpath', '--glob']).decode()
+    os.environ['CLASSPATH'] = classpath + os.pathsep + hadoop_classpath
+    os.environ['TFOS_CLASSPATH_UPDATED'] = '1'
+
+  # reserve GPU, if requested
+  if tf.test.is_built_with_cuda():
+    gpus_to_use = gpu_info.get_gpus(num_gpus)
+    logging.info("Using gpu(s): {0}".format(gpus_to_use))
+    os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use
+  else:
+    # CPU
+    logging.info("Using CPU")
+    os.environ['CUDA_VISIBLE_DEVICES'] = ''
+
+
 def get_ip_address():
   """Simple utility to get host IP address."""
   try:
@@ -24,7 +50,7 @@ def get_ip_address():
     ip_address = socket.gethostbyname(socket.getfqdn())
   finally:
     s.close()
-
+
   return ip_address
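For context, the CLASSPATH branch of the new helper can be exercised standalone. `hadoop classpath --glob` is the stock Hadoop CLI command that expands wildcard classpath entries into concrete jar paths, which the libhdfs JNI layer requires; the guard environment variable below mirrors `TFOS_CLASSPATH_UPDATED` above so repeated calls stay idempotent. A minimal sketch, assuming a local Hadoop install at `$HADOOP_PREFIX`:

```python
import os
import subprocess

def expand_hadoop_classpath():
    # expand glob entries (e.g. share/hadoop/common/*) into concrete jar paths
    hadoop = os.path.join(os.environ['HADOOP_PREFIX'], 'bin', 'hadoop')
    expanded = subprocess.check_output([hadoop, 'classpath', '--glob']).decode()
    os.environ['CLASSPATH'] = os.environ.get('CLASSPATH', '') + os.pathsep + expanded

if 'HADOOP_PREFIX' in os.environ and 'CLASSPATH_EXPANDED' not in os.environ:
    expand_hadoop_classpath()
    os.environ['CLASSPATH_EXPANDED'] = '1'  # hypothetical guard, analogous to TFOS_CLASSPATH_UPDATED
```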