Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 55 additions & 74 deletions examples/mnist/spark/mnist_dist.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017 Yahoo Inc.
#Copyright 2018 Yahoo Inc.
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.

Expand Down Expand Up @@ -58,37 +58,38 @@ def feed_dict(batch):
worker_device="/job:worker/task:%d" % task_index,
cluster=cluster)):

# Variables of the hidden layer
hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
stddev=1.0 / IMAGE_PIXELS), name="hid_w")
hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
tf.summary.histogram("hidden_weights", hid_w)

# Variables of the softmax layer
sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
tf.summary.histogram("softmax_weights", sm_w)

# Placeholders or QueueRunner/Readers for input data
x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")
y_ = tf.placeholder(tf.float32, [None, 10], name="y_")

x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
tf.summary.image("x_img", x_img)

hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
hid = tf.nn.relu(hid_lin)

y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))

global_step = tf.Variable(0)

loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
tf.summary.scalar("loss", loss)

train_op = tf.train.AdagradOptimizer(0.01).minimize(
loss, global_step=global_step)
# Placeholders or QueueRunner/Readers for input data
with tf.name_scope('inputs'):
x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS] , name="x")
y_ = tf.placeholder(tf.float32, [None, 10], name="y_")

x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
tf.summary.image("x_img", x_img)

with tf.name_scope('layer'):
# Variables of the hidden layer
with tf.name_scope('hidden_layer'):
hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units], stddev=1.0 / IMAGE_PIXELS), name="hid_w")
hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
tf.summary.histogram("hidden_weights", hid_w)
hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
hid = tf.nn.relu(hid_lin)

# Variables of the softmax layer
with tf.name_scope('softmax_layer'):
sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10], stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
tf.summary.histogram("softmax_weights", sm_w)
y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))

global_step = tf.train.get_or_create_global_step()

with tf.name_scope('loss'):
loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
tf.summary.scalar("loss", loss)

with tf.name_scope('train'):
train_op = tf.train.AdagradOptimizer(0.01).minimize(loss, global_step=global_step)

# Test trained model
label = tf.argmax(y_, 1, name="label")
Expand All @@ -98,73 +99,53 @@ def feed_dict(batch):
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
tf.summary.scalar("acc", accuracy)

saver = tf.train.Saver()
summary_op = tf.summary.merge_all()
init_op = tf.global_variables_initializer()

# Create a "supervisor", which oversees the training process and stores model state into HDFS
logdir = ctx.absolute_path(args.model)
# logdir = args.model
print("tensorflow model path: {0}".format(logdir))

hooks = [tf.train.StopAtStepHook(last_step=100000)]

if job_name == "worker" and task_index == 0:
summary_writer = tf.summary.FileWriter(logdir, graph=tf.get_default_graph())

if args.mode == "train":
sv = tf.train.Supervisor(is_chief=(task_index == 0),
logdir=logdir,
init_op=init_op,
summary_op=None,
summary_writer=None,
saver=saver,
global_step=global_step,
stop_grace_secs=300,
save_model_secs=10)
else:
sv = tf.train.Supervisor(is_chief=(task_index == 0),
logdir=logdir,
summary_op=None,
saver=saver,
global_step=global_step,
stop_grace_secs=300,
save_model_secs=0)

# The supervisor takes care of session initialization, restoring from
# a checkpoint, and closing when done or an error occurs.
with sv.managed_session(server.target) as sess:
print("{0} session ready".format(datetime.now().isoformat()))

# Loop until the supervisor shuts down or 1000000 steps have completed.
# The MonitoredTrainingSession takes care of session initialization, restoring from
# a checkpoint, and closing when done or an error occurs
with tf.train.MonitoredTrainingSession(master=server.target,
is_chief=(task_index == 0),
checkpoint_dir=logdir,
hooks=hooks) as mon_sess:

step = 0
tf_feed = ctx.get_data_feed(args.mode == "train")
while not sv.should_stop() and not tf_feed.should_stop() and step < args.steps:
# Run a training step asynchronously.
# See `tf.train.SyncReplicasOptimizer` for additional details on how to
# perform *synchronous* training.
while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args.steps:
# Run a training step asynchronously
# See `tf.train.SyncReplicasOptimizer` for additional details on how to
# perform *synchronous* training.

# using feed_dict
batch_xs, batch_ys = feed_dict(tf_feed.next_batch(batch_size))
feed = {x: batch_xs, y_: batch_ys}

if len(batch_xs) > 0:
if args.mode == "train":
_, summary, step = sess.run([train_op, summary_op, global_step], feed_dict=feed)
_, summary, step = mon_sess.run([train_op, summary_op, global_step], feed_dict=feed)
# print accuracy and save model checkpoint to HDFS every 100 steps
if (step % 100 == 0):
print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy,{x: batch_xs, y_: batch_ys})))
print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, mon_sess.run(accuracy,{x: batch_xs, y_: batch_ys})))

if sv.is_chief:
if task_index == 0:
summary_writer.add_summary(summary, step)
else: # args.mode == "inference"
labels, preds, acc = sess.run([label, prediction, accuracy], feed_dict=feed)
labels, preds, acc = mon_sess.run([label, prediction, accuracy], feed_dict=feed)

results = ["{0} Label: {1}, Prediction: {2}".format(datetime.now().isoformat(), l, p) for l,p in zip(labels,preds)]
tf_feed.batch_results(results)
print("acc: {0}".format(acc))
print("results: {0}, acc: {1}".format(results, acc))

if sv.should_stop() or step >= args.steps:
if mon_sess.should_stop() or step >= args.steps:
tf_feed.terminate()

# Ask for all the services to stop.
print("{0} stopping supervisor".format(datetime.now().isoformat()))
sv.stop()

print("{0} stopping MonitoredTrainingSession".format(datetime.now().isoformat()))
summary_writer.close()