diff --git a/examples/mnist/spark/mnist_dist.py b/examples/mnist/spark/mnist_dist.py
index da544731..d8afc835 100755
--- a/examples/mnist/spark/mnist_dist.py
+++ b/examples/mnist/spark/mnist_dist.py
@@ -1,4 +1,4 @@
-# Copyright 2017 Yahoo Inc.
+# Copyright 2018 Yahoo Inc.
 # Licensed under the terms of the Apache 2.0 license.
 # Please see LICENSE file in the project root for terms.
 
@@ -58,37 +58,38 @@ def feed_dict(batch):
         worker_device="/job:worker/task:%d" % task_index,
         cluster=cluster)):
 
-      # Variables of the hidden layer
-      hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
-                          stddev=1.0 / IMAGE_PIXELS), name="hid_w")
-      hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
-      tf.summary.histogram("hidden_weights", hid_w)
-
-      # Variables of the softmax layer
-      sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
-                         stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
-      sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
-      tf.summary.histogram("softmax_weights", sm_w)
-
-      # Placeholders or QueueRunner/Readers for input data
-      x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")
-      y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
-
-      x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
-      tf.summary.image("x_img", x_img)
-
-      hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
-      hid = tf.nn.relu(hid_lin)
-
-      y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
-
-      global_step = tf.Variable(0)
-
-      loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
-      tf.summary.scalar("loss", loss)
-
-      train_op = tf.train.AdagradOptimizer(0.01).minimize(
-          loss, global_step=global_step)
+      # Placeholders or QueueRunner/Readers for input data
+      with tf.name_scope('inputs'):
+        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")
+        y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
+
+        x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
+        tf.summary.image("x_img", x_img)
+
+      with tf.name_scope('layer'):
+        # Variables of the hidden layer
+        with tf.name_scope('hidden_layer'):
+          hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units], stddev=1.0 / IMAGE_PIXELS), name="hid_w")
+          hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
+          tf.summary.histogram("hidden_weights", hid_w)
+          hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
+          hid = tf.nn.relu(hid_lin)
+
+        # Variables of the softmax layer
+        with tf.name_scope('softmax_layer'):
+          sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10], stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
+          sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
+          tf.summary.histogram("softmax_weights", sm_w)
+          y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
+
+      global_step = tf.train.get_or_create_global_step()
+
+      with tf.name_scope('loss'):
+        loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
+        tf.summary.scalar("loss", loss)
+
+      with tf.name_scope('train'):
+        train_op = tf.train.AdagradOptimizer(0.01).minimize(loss, global_step=global_step)
 
       # Test trained model
       label = tf.argmax(y_, 1, name="label")
@@ -98,48 +99,29 @@ def feed_dict(batch):
       accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
       tf.summary.scalar("acc", accuracy)
 
-      saver = tf.train.Saver()
       summary_op = tf.summary.merge_all()
-      init_op = tf.global_variables_initializer()
 
-    # Create a "supervisor", which oversees the training process and stores model state into HDFS
     logdir = ctx.absolute_path(args.model)
+    # logdir = args.model
     print("tensorflow model path: {0}".format(logdir))
-
+    hooks = [tf.train.StopAtStepHook(last_step=100000)]
+
     if job_name == "worker" and task_index == 0:
       summary_writer = tf.summary.FileWriter(logdir, graph=tf.get_default_graph())
 
-    if args.mode == "train":
-      sv = tf.train.Supervisor(is_chief=(task_index == 0),
-                               logdir=logdir,
-                               init_op=init_op,
-                               summary_op=None,
-                               summary_writer=None,
-                               saver=saver,
-                               global_step=global_step,
-                               stop_grace_secs=300,
-                               save_model_secs=10)
-    else:
-      sv = tf.train.Supervisor(is_chief=(task_index == 0),
-                               logdir=logdir,
-                               summary_op=None,
-                               saver=saver,
-                               global_step=global_step,
-                               stop_grace_secs=300,
-                               save_model_secs=0)
-
-    # The supervisor takes care of session initialization, restoring from
-    # a checkpoint, and closing when done or an error occurs.
-    with sv.managed_session(server.target) as sess:
-      print("{0} session ready".format(datetime.now().isoformat()))
-
-      # Loop until the supervisor shuts down or 1000000 steps have completed.
+    # The MonitoredTrainingSession takes care of session initialization, restoring from
+    # a checkpoint, and closing when done or an error occurs.
+    with tf.train.MonitoredTrainingSession(master=server.target,
+                                           is_chief=(task_index == 0),
+                                           checkpoint_dir=logdir,
+                                           hooks=hooks) as mon_sess:
+
       step = 0
       tf_feed = ctx.get_data_feed(args.mode == "train")
-      while not sv.should_stop() and not tf_feed.should_stop() and step < args.steps:
-        # Run a training step asynchronously.
-        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
-        # perform *synchronous* training.
+      while not mon_sess.should_stop() and not tf_feed.should_stop() and step < args.steps:
+        # Run a training step asynchronously.
+        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
+        # perform *synchronous* training.
 
         # using feed_dict
         batch_xs, batch_ys = feed_dict(tf_feed.next_batch(batch_size))
@@ -147,24 +129,23 @@ def feed_dict(batch):
 
         if len(batch_xs) > 0:
           if args.mode == "train":
-            _, summary, step = sess.run([train_op, summary_op, global_step], feed_dict=feed)
+            _, summary, step = mon_sess.run([train_op, summary_op, global_step], feed_dict=feed)
             # print accuracy and save model checkpoint to HDFS every 100 steps
             if (step % 100 == 0):
-              print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy, {x: batch_xs, y_: batch_ys})))
+              print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, mon_sess.run(accuracy, {x: batch_xs, y_: batch_ys})))
-            if sv.is_chief:
+            if task_index == 0:
               summary_writer.add_summary(summary, step)
           else:  # args.mode == "inference"
-            labels, preds, acc = sess.run([label, prediction, accuracy], feed_dict=feed)
+            labels, preds, acc = mon_sess.run([label, prediction, accuracy], feed_dict=feed)
 
             results = ["{0} Label: {1}, Prediction: {2}".format(datetime.now().isoformat(), l, p) for l,p in zip(labels,preds)]
             tf_feed.batch_results(results)
-            print("acc: {0}".format(acc))
+            print("results: {0}, acc: {1}".format(results, acc))
 
-      if sv.should_stop() or step >= args.steps:
+      if mon_sess.should_stop() or step >= args.steps:
         tf_feed.terminate()
 
     # Ask for all the services to stop.
-    print("{0} stopping supervisor".format(datetime.now().isoformat()))
-    sv.stop()
-
+    print("{0} stopping MonitoredTrainingSession".format(datetime.now().isoformat()))
+    if job_name == "worker" and task_index == 0:
+      summary_writer.close()