[DO NOT MERGE] Integrate w/ Spark ML Pipelines #114

Closed
wants to merge 39 commits

39 commits
6d48d2a
initial POC for ML pipelines
Jul 20, 2017
bc355e9
single-node inferencing
Jul 24, 2017
5fed5ff
fix file
Jul 25, 2017
4d78b7f
separate pipeline class from client/mnist code
Jul 25, 2017
f61edbb
generic inferencing from saved_model
Jul 27, 2017
e760799
cleanup code
Jul 27, 2017
43a76fa
minor updates
Jul 28, 2017
f4a2dd5
move TFPipeline into tfos dir
Jul 28, 2017
3054df4
Merge branch 'master' into leewyang_pipeline2
Jul 31, 2017
1860e7b
expand CLASSPATH; add gpu allocation to TFModel
Jul 31, 2017
8d48f39
add ml params
Aug 1, 2017
0c11751
refactor merge as a mixin class
Aug 1, 2017
f4e38e3
separate tf_args from spark-submit args
Aug 1, 2017
6431f34
move pipeline example
Aug 1, 2017
94348d0
support exporting model with multiple signatures
Aug 3, 2017
9be76fe
remove mnist_spark_session.py; rename rdma->protocol, HasPredictionCo…
Aug 3, 2017
9f1b86b
remove mnist.py
Aug 3, 2017
6095ed9
support inputCol param
Aug 4, 2017
4241e24
configurable params for input/output columns & tensors
Aug 9, 2017
1305728
configurable mapping of input/output columns to tensors
Aug 9, 2017
94ac9b1
add signatures as Param
Aug 9, 2017
1b7b2d8
support input_mapping for training
Aug 11, 2017
48f4b62
misc. cleanup
Aug 11, 2017
9b268e3
revert to hard-coded signatures on export; support tensor alias/name …
Aug 15, 2017
5385a62
use dict for mappings instead of arrays of strings
Aug 15, 2017
fe222b5
inference-only example from saved_model
Aug 15, 2017
8c16196
minor updates
Aug 17, 2017
9dcacc7
remove label from exported signature
Aug 17, 2017
06e9a38
use tensor name strings instead of tensors
Aug 17, 2017
2b0c338
support multiple output tensors as DataFrame columns
Aug 17, 2017
fc18537
swap output mapping key/values
Aug 18, 2017
2902657
fix bug w/ output mapping order
Aug 18, 2017
b10375b
add support for inferencing from checkpoint
Aug 21, 2017
e22e8b6
merge w/ latest master
Aug 22, 2017
23afd7a
tests for Spark ML pipeline
Aug 23, 2017
fe7bdf9
support InputMode.TENSORFLOW in pipeline
Sep 1, 2017
48b6f96
add utility to load/save TFRecords as DataFrames; bypass auto-convers…
Sep 5, 2017
f094aae
tests for dfutil
Sep 6, 2017
0dc2c63
minor fixes
Sep 7, 2017
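
Commit 48b6f96 above adds a dfutil helper for loading/saving TFRecords as Spark DataFrames. A minimal sketch of the intended round trip follows; the loadTFRecords/saveAsTFRecords entry points and their argument order are assumptions based on the commit message, and the tensorflow-hadoop connector (used elsewhere in this PR) is assumed to be on the classpath.

# Hedged sketch, not part of this diff: round-trip a DataFrame through TFRecords
# via the dfutil module added in commit 48b6f96. Function names and argument
# order are assumptions inferred from the commit message.
from pyspark.sql import SparkSession
from tensorflowonspark import dfutil

spark = SparkSession.builder.appName("dfutil_demo").getOrCreate()
df = spark.createDataFrame([(0, [0.0] * 784)], ["label", "image"])

dfutil.saveAsTFRecords(df, "mnist/tfr")                      # DataFrame -> TFRecords
df2 = dfutil.loadTFRecords(spark.sparkContext, "mnist/tfr")  # TFRecords -> DataFrame
df2.show(1)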
175 changes: 175 additions & 0 deletions examples/mnist/pipeline/spark/mnist_dist.py
@@ -0,0 +1,175 @@
# Copyright 2017 Yahoo Inc.
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.

# Distributed MNIST on grid based on TensorFlow MNIST example

from __future__ import absolute_import
from __future__ import division
from __future__ import nested_scopes
from __future__ import print_function

def print_log(worker_num, arg):
  print("{0}: {1}".format(worker_num, arg))

def map_fun(args, ctx):
  from tensorflowonspark import TFNode
  from datetime import datetime
  import math
  import numpy
  import tensorflow as tf
  import time

  worker_num = ctx.worker_num
  job_name = ctx.job_name
  task_index = ctx.task_index

  IMAGE_PIXELS = 28

  # Delay PS nodes a bit, since workers seem to reserve GPUs more quickly/reliably (w/o conflict)
  if job_name == "ps":
    time.sleep((worker_num + 1) * 5)

  # Parameters
  hidden_units = 128
  batch_size = args.batch_size

  # Get TF cluster and server instances
  cluster, server = TFNode.start_cluster_server(ctx, 1, args.protocol == 'rdma')

  def feed_dict(batch):
    # Convert from dict of named arrays to two numpy arrays of the proper type
    images = batch['image']
    labels = batch['label']
    xs = numpy.array(images)
    xs = xs.astype(numpy.float32)
    xs = xs / 255.0
    ys = numpy.array(labels)
    ys = ys.astype(numpy.uint8)
    return (xs, ys)

  if job_name == "ps":
    server.join()
  elif job_name == "worker":

    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % task_index,
        cluster=cluster)):

      # Variables of the hidden layer
      hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
                                              stddev=1.0 / IMAGE_PIXELS), name="hid_w")
      hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
      tf.summary.histogram("hidden_weights", hid_w)

      # Variables of the softmax layer
      sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
                                             stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
      sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
      tf.summary.histogram("softmax_weights", sm_w)

      # Placeholders or QueueRunner/Readers for input data
      x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")
      y_ = tf.placeholder(tf.float32, [None, 10], name="y_")

      x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
      tf.summary.image("x_img", x_img)

      hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
      hid = tf.nn.relu(hid_lin)

      y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))

      global_step = tf.Variable(0)

      loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
      tf.summary.scalar("loss", loss)

      train_op = tf.train.AdagradOptimizer(0.01).minimize(
          loss, global_step=global_step)

      # Test trained model
      label = tf.argmax(y_, 1, name="label")
      prediction = tf.argmax(y, 1, name="prediction")
      correct_prediction = tf.equal(prediction, label)

      accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
      tf.summary.scalar("acc", accuracy)

      saver = tf.train.Saver()
      summary_op = tf.summary.merge_all()
      init_op = tf.global_variables_initializer()

    # Create a "supervisor", which oversees the training process and stores model state into HDFS
    logdir = TFNode.hdfs_path(ctx, args.model_dir)
    print("tensorflow model path: {0}".format(logdir))
    summary_writer = tf.summary.FileWriter("tensorboard_%d" % (worker_num), graph=tf.get_default_graph())

    sv = tf.train.Supervisor(is_chief=(task_index == 0),
                             logdir=logdir,
                             init_op=init_op,
                             summary_op=None,
                             saver=saver,
                             global_step=global_step,
                             stop_grace_secs=300,
                             save_model_secs=10)

    # The supervisor takes care of session initialization, restoring from
    # a checkpoint, and closing when done or an error occurs.
    with sv.managed_session(server.target) as sess:
      print("{0} session ready".format(datetime.now().isoformat()))

      # Loop until the supervisor shuts down or args.steps steps have completed.
      step = 0
      tf_feed = TFNode.DataFeed(ctx.mgr, input_mapping=args.input_mapping)
      while not sv.should_stop() and not tf_feed.should_stop() and step < args.steps:
        # Run a training step asynchronously.
        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
        # perform *synchronous* training.

        # using feed_dict
        batch_xs, batch_ys = feed_dict(tf_feed.next_batch(batch_size))
        feed = {x: batch_xs, y_: batch_ys}

        if len(batch_xs) > 0:
          _, summary, step = sess.run([train_op, summary_op, global_step], feed_dict=feed)
          # print accuracy and save model checkpoint to HDFS every 100 steps
          if step % 100 == 0:
            print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy, {x: batch_xs, y_: batch_ys})))

          if sv.is_chief:
            summary_writer.add_summary(summary, step)

      if sv.should_stop() or step >= args.steps:
        tf_feed.terminate()

      if sv.is_chief and args.export_dir:
        print("{0} exporting saved_model to: {1}".format(datetime.now().isoformat(), args.export_dir))
        # exported signatures defined in code
        signatures = {
          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: {
            'inputs': { 'image': x },
            'outputs': { 'prediction': prediction },
            'method_name': tf.saved_model.signature_constants.PREDICT_METHOD_NAME
          },
          'featurize': {
            'inputs': { 'image': x },
            'outputs': { 'features': hid },
            'method_name': 'featurize'
          }
        }
        TFNode.export_saved_model(sess,
                                  args.export_dir,
                                  tf.saved_model.tag_constants.SERVING,
                                  signatures)
      else:
        # non-chief workers should wait for chief
        while not sv.should_stop():
          print("Waiting for chief")
          time.sleep(5)

    # Ask for all the services to stop.
    print("{0} stopping supervisor".format(datetime.now().isoformat()))
    sv.stop()
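
For orientation, here is a hedged sketch of how a driver might hand map_fun above to the Estimator half of this pipeline integration. The TFEstimator constructor and fit() behavior are assumptions inferred from the TFModel usage in mnist_inference.py below; train_df is a hypothetical DataFrame with 'col1' (image) and 'col2' (label) columns.

# Hypothetical driver-side sketch (not part of this diff): wrap map_fun in the
# Spark ML Estimator introduced by this PR. Constructor and setter names are
# assumptions mirroring the TFModel API shown in the next file.
from tensorflowonspark.pipeline import TFEstimator

estimator = TFEstimator(map_fun, args) \
    .setInputMapping({'col1': 'image', 'col2': 'label'}) \
    .setModelDir(args.model_dir) \
    .setExportDir(args.export_dir) \
    .setBatchSize(args.batch_size) \
    .setSteps(args.steps)

model = estimator.fit(train_df)  # runs map_fun across the cluster; returns a TFModel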
130 changes: 130 additions & 0 deletions examples/mnist/pipeline/spark/mnist_inference.py
@@ -0,0 +1,130 @@
# Copyright 2017 Yahoo Inc.
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql import SparkSession

import argparse
import numpy
import tensorflow as tf
from datetime import datetime

from tensorflowonspark.pipeline import TFModel

sc = SparkContext(conf=SparkConf().setAppName("mnist_spark"))
spark = SparkSession(sc)

executors = sc._conf.get("spark.executor.instances")
num_executors = int(executors) if executors is not None else 1
num_ps = 1

parser = argparse.ArgumentParser()

######## PARAMS ########

## TFoS/cluster
parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
parser.add_argument("--export_dir", help="HDFS path to exported saved_model", type=str)
parser.add_argument("--model_dir", help="HDFS path to model checkpoint", type=str)

######## ARGS ########

# Spark input/output
parser.add_argument("--format", help="example format: (csv|pickle|tfr)", choices=["csv","pickle","tfr"], default="csv")
parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format")
parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")

args = parser.parse_args()
print("args:",args)

print("{0} ===== Start".format(datetime.now().isoformat()))

if args.format == "tfr":
images = sc.newAPIHadoopFile(args.images, "org.tensorflow.hadoop.io.TFRecordFileInputFormat",
keyClass="org.apache.hadoop.io.BytesWritable",
valueClass="org.apache.hadoop.io.NullWritable")
def toNumpy(bytestr):
example = tf.train.Example()
example.ParseFromString(bytestr)
features = example.features.feature
image = numpy.array(features['image'].int64_list.value)
label = numpy.array(features['label'].int64_list.value)
return (image, label)
dataRDD = images.map(lambda x: toNumpy(str(x[0])))
else:
if args.format == "csv":
images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')])
labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
else: # args.format == "pickle":
images = sc.pickleFile(args.images)
labels = sc.pickleFile(args.labels)
print("zipping images and labels")
dataRDD = images.zip(labels)

# Pipeline API
df = spark.createDataFrame(dataRDD, ['col1', 'col2'])

model = TFModel(args) \
    .setBatchSize(args.batch_size)

#
# Using saved_model w/ signature defs and tensor aliases
#

# prediction
model.setTagSet(tf.saved_model.tag_constants.SERVING) \
     .setExportDir(args.export_dir) \
     .setSignatureDefKey(tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) \
     .setInputMapping({'col1': 'image'}) \
     .setOutputMapping({'prediction': 'col_out'})

# featurize
# model.setTagSet(tf.saved_model.tag_constants.SERVING) \
#      .setExportDir(args.export_dir) \
#      .setSignatureDefKey('featurize') \
#      .setInputMapping({'col1': 'image'}) \
#      .setOutputMapping({'features': 'col_out'})

#
# Using saved_model w/ custom/direct mappings of tensors
#

# prediction
# model.setTagSet(tf.saved_model.tag_constants.SERVING) \
#      .setExportDir(args.export_dir) \
#      .setInputMapping({'col1': 'x'}) \
#      .setOutputMapping({'prediction': 'col_out'})

# featurize
# model.setTagSet(tf.saved_model.tag_constants.SERVING) \
#      .setExportDir(args.export_dir) \
#      .setInputMapping({'col1': 'x'}) \
#      .setOutputMapping({'prediction': 'col_out', 'Relu': 'col_out2'})

#
# Using checkpoint w/ custom/direct mappings of tensors
#

# prediction
# model.setModelDir(args.model_dir) \
#      .setInputMapping({'col1': 'x'}) \
#      .setOutputMapping({'prediction': 'col_out'})

# featurize
# model.setModelDir(args.model_dir) \
#      .setInputMapping({'col1': 'x'}) \
#      .setOutputMapping({'prediction': 'col_out', 'Relu': 'col_out2'})

print("{0} ===== Model.transform()".format(datetime.now().isoformat()))
preds = model.transform(df)
preds.write.json(args.output)

print("{0} ===== Stop".format(datetime.now().isoformat()))