45 changes: 5 additions & 40 deletions examples/mnist/tf/README.md
@@ -9,39 +9,6 @@
# hdfs dfs -rm -r mnist_model
# hdfs dfs -rm -r predictions

${SPARK_HOME}/bin/spark-submit \
--master yarn \
--deploy-mode cluster \
--queue ${QUEUE} \
--num-executors 4 \
--executor-memory 27G \
--py-files TensorFlowOnSpark/tfspark.zip,TensorFlowOnSpark/examples/mnist/tf/mnist_dist_dataset.py \
--conf spark.dynamicAllocation.enabled=false \
--conf spark.yarn.maxAppAttempts=1 \
--archives hdfs:///user/${USER}/Python.zip#Python \
--conf spark.executorEnv.LD_LIBRARY_PATH=$LIB_CUDA:$LIB_JVM:$LIB_HDFS \
--driver-library-path=$LIB_CUDA \
TensorFlowOnSpark/examples/mnist/tf/mnist_spark_dataset.py \
${TF_ROOT}/${TF_VERSION}/examples/mnist/tf/mnist_spark_dataset.py \
--images_labels mnist/csv2/train \
--format csv2 \
--mode train \
--model mnist_model

# to use inference mode, change `--mode train` to `--mode inference` and add `--output predictions`
# one item in csv2 format is `image | label`; to use input data in TFRecord format, change `--format csv2` to `--format tfr`
# to use infiniband, add `--rdma`
```

### _using QueueRunners_
```bash
# for CPU mode:
# export QUEUE=default
# remove references to $LIB_CUDA

# hdfs dfs -rm -r mnist_model
# hdfs dfs -rm -r predictions

${SPARK_HOME}/bin/spark-submit \
--master yarn \
--deploy-mode cluster \
@@ -55,16 +22,14 @@ ${SPARK_HOME}/bin/spark-submit \
--conf spark.executorEnv.LD_LIBRARY_PATH=$LIB_CUDA:$LIB_JVM:$LIB_HDFS \
--driver-library-path=$LIB_CUDA \
TensorFlowOnSpark/examples/mnist/tf/mnist_spark.py \
--images mnist/tfr/train/images \
--labels mnist/tfr/train/labels \
--format csv \
--images_labels mnist/csv2/train \
--format csv2 \
--mode train \
--model mnist_model

# to use inference mode, change `--mode train` to `--mode inference` and add `--output predictions`
# to use input data in TFRecord format, change `--format csv` to `--format tfr`
# one item in csv2 format is `image | label`; to use input data in TFRecord format, change `--format csv2` to `--format tfr`
# to use infiniband, add `--rdma`
```

### _using Spark ML Pipeline_
```bash
@@ -83,7 +48,7 @@ ${SPARK_HOME}/bin/spark-submit \
--queue ${QUEUE} \
--num-executors 4 \
--executor-memory 27G \
--jars hdfs:///user/${USER}/tensorflow-hadoop-1.0-SNAPSHOT.jar \
--py-files TensorFlowOnSpark/tfspark.zip,TensorFlowOnSpark/examples/mnist/tf/mnist_dist_pipeline.py \
--conf spark.dynamicAllocation.enabled=false \
--conf spark.yarn.maxAppAttempts=1 \
@@ -102,6 +67,6 @@ TensorFlowOnSpark/examples/mnist/tf/mnist_spark_pipeline.py \
--inference_output predictions

# to use input data in TFRecord format, change `--format csv` to `--format tfr`
# tensorflow-hadoop-1.0-SNAPSHOT.jar is needed for transforming csv input to TFRecord
# `--tfrecord_dir` is needed for temporarily saving dataframe to TFRecord on hdfs
```
86 changes: 65 additions & 21 deletions examples/mnist/tf/mnist_dist.py
@@ -16,6 +16,7 @@ def print_log(worker_num, arg):

def map_fun(args, ctx):
from datetime import datetime
from tensorflowonspark import TFNode
import math
import os
import tensorflow as tf
@@ -54,6 +55,27 @@ def _parse_tfr(example_proto):
label = tf.to_float(features['label'])
return (image, label)

def build_model(graph, x):
with graph.as_default():
# Variables of the hidden layer
hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
stddev=1.0 / IMAGE_PIXELS), name="hid_w")
hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
tf.summary.histogram("hidden_weights", hid_w)

# Variables of the softmax layer
sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
tf.summary.histogram("softmax_weights", sm_w)

hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
hid = tf.nn.relu(hid_lin)

y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
prediction = tf.argmax(y, 1, name="prediction")
return y, prediction

if job_name == "ps":
server.join()
elif job_name == "worker":
@@ -78,36 +100,21 @@ def _parse_tfr(example_proto):
iterator = ds.make_one_shot_iterator()
x, y_ = iterator.get_next()

# Variables of the hidden layer
hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
stddev=1.0 / IMAGE_PIXELS), name="hid_w")
hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
tf.summary.histogram("hidden_weights", hid_w)

# Variables of the softmax layer
sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
tf.summary.histogram("softmax_weights", sm_w)
# Build core model
y, prediction = build_model(tf.get_default_graph(), x)

# Add training bits
x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
tf.summary.image("x_img", x_img)

hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
hid = tf.nn.relu(hid_lin)

y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))

global_step = tf.train.get_or_create_global_step()

loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
tf.summary.scalar("loss", loss)
train_op = tf.train.AdagradOptimizer(0.01).minimize(
loss, global_step=global_step)

# Test trained model
label = tf.argmax(y_, 1, name="label")
prediction = tf.argmax(y, 1, name="prediction")
correct_prediction = tf.equal(prediction, label)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
tf.summary.scalar("acc", accuracy)
@@ -117,8 +124,10 @@ def _parse_tfr(example_proto):
init_op = tf.global_variables_initializer()

# Create a "supervisor", which oversees the training process and stores model state into HDFS
logdir = ctx.absolute_path(args.model)
print("tensorflow model path: {0}".format(logdir))
model_dir = ctx.absolute_path(args.model)
export_dir = ctx.absolute_path(args.export)
print("tensorflow model path: {0}".format(model_dir))
print("tensorflow export path: {0}".format(export_dir))
summary_writer = tf.summary.FileWriter("tensorboard_%d" % worker_num, graph=tf.get_default_graph())

if args.mode == 'inference':
@@ -130,7 +139,7 @@ def _parse_tfr(example_proto):
with tf.train.MonitoredTrainingSession(master=server.target,
is_chief=(task_index == 0),
scaffold=tf.train.Scaffold(init_op=init_op, summary_op=summary_op, saver=saver),
checkpoint_dir=logdir,
checkpoint_dir=model_dir,
hooks=[tf.train.StopAtStepHook(last_step=args.steps)]) as sess:
print("{} session ready".format(datetime.now().isoformat()))

@@ -163,6 +172,41 @@ def _parse_tfr(example_proto):

print("{} stopping MonitoredTrainingSession".format(datetime.now().isoformat()))

# export model (on chief worker only)
if args.mode == "train" and task_index == 0:
tf.reset_default_graph()

# add placeholders for input images (and optional labels)
x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
label = tf.argmax(y_, 1, name="label")

# add core model
y, prediction = build_model(tf.get_default_graph(), x)

# restore from last checkpoint
saver = tf.train.Saver()
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(model_dir)
print("ckpt: {}".format(ckpt))
assert ckpt, "Invalid model checkpoint path: {}".format(model_dir)
saver.restore(sess, ckpt.model_checkpoint_path)

print("Exporting saved_model to: {}".format(export_dir))
# exported signatures defined in code
signatures = {
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: {
'inputs': { 'image': x },
'outputs': { 'prediction': prediction },
'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME
}
}
TFNode.export_saved_model(sess,
export_dir,
tf.saved_model.tag_constants.SERVING,
signatures)
print("Exported saved_model")

# WORKAROUND for https://github.com/tensorflow/tensorflow/issues/21745
# wait for all other nodes to complete (via done files)
done_dir = "{}/{}/done".format(ctx.absolute_path(args.model), args.mode)
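# NOTE: the diff is truncated here; a minimal sketch of how such a barrier might
# continue, assuming each task writes a marker file and then polls until all
# workers have one (illustrative, not the PR's exact code):
# tf.gfile.MakeDirs(done_dir)
# with tf.gfile.GFile("{}/{}".format(done_dir, ctx.task_index), 'w') as done_file:
#   done_file.write("done")
# while len(tf.gfile.ListDirectory(done_dir)) < len(ctx.cluster_spec['worker']):
#   time.sleep(1)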
106 changes: 106 additions & 0 deletions examples/mnist/tf/mnist_inference.py
@@ -0,0 +1,106 @@
# Copyright 2018 Yahoo Inc.
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.

# This example demonstrates how to leverage Spark for parallel inferencing from a SavedModel.
#
# Normally, you can use TensorFlowOnSpark to just form a TensorFlow cluster for training and inferencing.
# However, in some situations, you may have a SavedModel without the original code for defining the inferencing
# graph. In these situations, we can use Spark to instantiate a single-node TensorFlow instance on each executor,
# where each executor can independently load the model and inference on input data.
#
# Note: this particular example demonstrates use of `tf.data.Dataset` to read the input data for inferencing,
# but it could also be adapted to just use an RDD of TFRecords from Spark.
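#
# A hedged sketch of that RDD adaptation (assuming the tensorflow-hadoop input
# format jar is passed via --jars; names below are illustrative):
#
#   tfr_rdd = sc.newAPIHadoopFile(args.images_labels,
#                                 "org.tensorflow.hadoop.io.TFRecordFileInputFormat",
#                                 keyClass="org.apache.hadoop.io.BytesWritable",
#                                 valueClass="org.apache.hadoop.io.NullWritable")
#   tfr_rdd.mapPartitionsWithIndex(lambda i, it: run_inference(i, it, args)).collect()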

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import logging
import sys
import tensorflow as tf
import time
import traceback

IMAGE_PIXELS = 28

def inference(it, num_workers, args):
from tensorflowonspark import util

# consume worker number from RDD partition iterator
for i in it:
worker_num = i
print("worker_num: {}".format(i))

# setup env for single-node TF
util.single_node_env()

# load saved_model using default tag and signature
sess = tf.Session()
tf.saved_model.loader.load(sess, ['serve'], args.export)

# parse function for TFRecords
def parse_tfr(example_proto):
feature_def = {"label": tf.FixedLenFeature(10, tf.int64),
"image": tf.FixedLenFeature(IMAGE_PIXELS * IMAGE_PIXELS, tf.int64)}
features = tf.parse_single_example(example_proto, feature_def)
norm = tf.constant(255, dtype=tf.float32, shape=(784,))
image = tf.div(tf.to_float(features['image']), norm)
label = tf.to_float(features['label'])
return (image, label)

# define a new tf.data.Dataset (for inferencing)
ds = tf.data.Dataset.list_files("{}/part-*".format(args.images_labels))
ds = ds.shard(num_workers, worker_num)
ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=1)
ds = ds.map(parse_tfr).batch(10)
iterator = ds.make_one_shot_iterator()
image_label = iterator.get_next(name='inf_image')

# create an output file per spark worker for the predictions
tf.gfile.MakeDirs(args.output)
output_file = tf.gfile.GFile("{}/part-{:05d}".format(args.output, worker_num), mode='w')

while True:
try:
# get images and labels from tf.data.Dataset
img, lbl = sess.run(['inf_image:0', 'inf_image:1'])

# inference by feeding these images and labels into the input tensors
# you can view the exported model signatures via:
# saved_model_cli show --dir mnist_export --all

# note that we feed directly into the graph tensors (bypassing the exported signatures)
# also note that we can feed/fetch tensors that were not explicitly exported, e.g. `y_` and `label:0`

labels, preds = sess.run(['label:0', 'prediction:0'], feed_dict={'x:0': img, 'y_:0': lbl})
for i in range(len(labels)):
output_file.write("{} {}\n".format(labels[i], preds[i]))
except tf.errors.OutOfRangeError:
break

output_file.close()

if __name__ == '__main__':
import os
from pyspark.context import SparkContext
from pyspark.conf import SparkConf

sc = SparkContext(conf=SparkConf().setAppName("mnist_inference"))
executors = sc._conf.get("spark.executor.instances")
num_executors = int(executors) if executors is not None else 1

parser = argparse.ArgumentParser()
parser.add_argument("--cluster_size", help="number of nodes in the cluster (for S with labelspark Standalone)", type=int, default=num_executors)
parser.add_argument('--images_labels', type=str, help='Directory for input images with labels')
parser.add_argument("--export", help="HDFS path to export model", type=str, default="mnist_export")
parser.add_argument("--output", help="HDFS path to save predictions", type=str, default="predictions")
args, _ = parser.parse_known_args()
print("args: {}".format(args))

# Not using TFCluster... just running single-node TF instances on each executor
nodes = list(range(args.cluster_size))
nodeRDD = sc.parallelize(nodes, args.cluster_size)
nodeRDD.foreachPartition(lambda worker_num: inference(worker_num, args.cluster_size, args))
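
Since this example fetches tensors by name rather than through the exported signatures, a quick single-process sanity check outside Spark can confirm the export before submitting the job. A minimal sketch, assuming a local `mnist_export` directory produced by training and the `x`/`prediction` tensor names exported above:

```python
import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    # load the SavedModel exported by the chief worker (default 'serve' tag)
    tf.saved_model.loader.load(sess, ['serve'], 'mnist_export')

    # feed one dummy batch of flattened 28x28 images into the exported input tensor
    images = np.zeros((1, 28 * 28), dtype=np.float32)
    preds = sess.run('prediction:0', feed_dict={'x:0': images})
    print(preds)
```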

1 change: 1 addition & 0 deletions examples/mnist/tf/mnist_spark.py
@@ -26,6 +26,7 @@
parser.add_argument("--driver_ps_nodes", help="""run tensorflow PS node on driver locally.
You will need to set cluster_size = num_executors + num_ps""", default=False)
parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
parser.add_argument("--export", help="HDFS path to export model", type=str, default="mnist_export")
parser.add_argument("--format", help="example format: (csv2|tfr)", choices=["csv2", "tfr"], default="tfr")
parser.add_argument("--images_labels", help="HDFS path to MNIST image_label files in parallelized format")
parser.add_argument("--mode", help="train|inference", default="train")
27 changes: 5 additions & 22 deletions tensorflowonspark/pipeline.py
@@ -25,7 +25,7 @@
import tensorflow as tf
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.python.saved_model import loader
from . import TFCluster, gpu_info, dfutil
from . import TFCluster, gpu_info, dfutil, util

import argparse
import copy
@@ -570,32 +570,15 @@ def single_node_env(args):
Args:
:args: command line arguments as either argparse args or argv list
"""
# setup ARGV for the TF process
if isinstance(args, list):
sys.argv = args
elif args.argv:
sys.argv = args.argv

# ensure expanded CLASSPATH w/o glob characters (required for Spark 2.1 + JNI)
if 'HADOOP_PREFIX' in os.environ and 'TFOS_CLASSPATH_UPDATED' not in os.environ:
classpath = os.environ['CLASSPATH']
hadoop_path = os.path.join(os.environ['HADOOP_PREFIX'], 'bin', 'hadoop')
hadoop_classpath = subprocess.check_output([hadoop_path, 'classpath', '--glob']).decode()
logging.debug("CLASSPATH: {0}".format(hadoop_classpath))
os.environ['CLASSPATH'] = classpath + os.pathsep + hadoop_classpath
os.environ['TFOS_CLASSPATH_UPDATED'] = '1'

# reserve GPU, if requested
if tf.test.is_built_with_cuda():
# GPU
num_gpus = args.num_gpus if 'num_gpus' in args else 1
gpus_to_use = gpu_info.get_gpus(num_gpus)
logging.info("Using gpu(s): {0}".format(gpus_to_use))
os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use
# Note: if there is a GPU conflict (CUDA_ERROR_INVALID_DEVICE), the entire task will fail and retry.
else:
# CPU
logging.info("Using CPU")
os.environ['CUDA_VISIBLE_DEVICES'] = ''
# setup ENV for Hadoop-compatibility and/or GPU allocation
num_gpus = args.num_gpus if 'num_gpus' in args else 1
util.single_node_env(num_gpus)


def get_meta_graph_def(saved_model_dir, tag_set):
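
For reference, the inline environment setup removed above is now delegated to `util.single_node_env`. A minimal sketch of such a helper, reconstructed from the deleted lines (the actual `tensorflowonspark/util.py` implementation may differ):

```python
import logging
import os
import subprocess
import tensorflow as tf
from tensorflowonspark import gpu_info

def single_node_env(num_gpus=1):
    # ensure expanded CLASSPATH w/o glob characters (required for Spark 2.1 + JNI)
    if 'HADOOP_PREFIX' in os.environ and 'TFOS_CLASSPATH_UPDATED' not in os.environ:
        classpath = os.environ.get('CLASSPATH', '')
        hadoop_path = os.path.join(os.environ['HADOOP_PREFIX'], 'bin', 'hadoop')
        hadoop_classpath = subprocess.check_output([hadoop_path, 'classpath', '--glob']).decode()
        os.environ['CLASSPATH'] = classpath + os.pathsep + hadoop_classpath
        os.environ['TFOS_CLASSPATH_UPDATED'] = '1'

    # reserve GPU(s) if TF was built with CUDA support; otherwise pin to CPU
    if tf.test.is_built_with_cuda():
        gpus_to_use = gpu_info.get_gpus(num_gpus)
        logging.info("Using gpu(s): {0}".format(gpus_to_use))
        os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use
    else:
        logging.info("Using CPU")
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
```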