In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import logging
import os
import shutil
import subprocess

from com.yahoo.ml.tf import TFCluster
import mnist_dist

In [2]:
# logging.basicConfig() does nothing if the root logger already has handlers
# (which the notebook environment may have installed), so the logging module is
# reloaded first to get a fresh root logger that the format/level below apply to.
try:
    from importlib import reload  # Python 3.4+: reload() is no longer a builtin
except ImportError:
    pass  # Python 2: reload() is a builtin
reload(logging)
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%I:%M:%S')

In [3]:
def _str2bool(value):
    """Convert a command-line flag value such as "true"/"False"/"1"/"0" to a bool.

    Without a ``type=`` converter argparse keeps the raw string, so ``--rdma False``
    would yield the non-empty (and therefore truthy) string "False".

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


# Command-line style options for the training/inference cells below; those cells
# call parser.parse_args([...]) with an explicit argv list.
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--epochs", help="number of epochs", type=int, default=1)
parser.add_argument("-i", "--images", help="HDFS path to MNIST images in parallelized format")
parser.add_argument("-l", "--labels", help="HDFS path to MNIST labels in parallelized format")
parser.add_argument("-f", "--format", help="example format", choices=["csv","pickle","tfr"], default="csv")
parser.add_argument("-m", "--model", help="HDFS path to save/load model during train/test", default="mnist_model")
parser.add_argument("-r", "--readers", help="number of reader/enqueue threads", type=int, default=1)
parser.add_argument("-s", "--steps", help="maximum number of steps", type=int, default=500)
# Reject typos up front instead of silently running an unexpected mode.
parser.add_argument("-X", "--mode", help="train|inference", choices=["train", "inference"], default="train")
parser.add_argument("-c", "--rdma", help="use rdma connection", type=_str2bool, default=False)
# Number of Spark executors passed to TFCluster.reserve() in the reserve cells.
num_executors = 2

In [16]:
# Remove any model left over from a previous run (equivalent of `rm -rf mnist_model`).
# Done in pure Python so no shell `rm` is needed; unlike the ignored return code
# of a subprocess, a removal failure here raises instead of passing silently.
model_dir = "mnist_model"
if os.path.isdir(model_dir):
    shutil.rmtree(model_dir)
elif os.path.exists(model_dir):
    os.remove(model_dir)

0

In [5]:
# Verify the training images are present by listing their directory.
train_images_files = "csv/train/images"
# universal_newlines=True makes check_output return text on Python 2 and 3 alike;
# on Python 3 the default bytes result would print as one b'...' literal.
print(subprocess.check_output(["ls", "-l", train_images_files], universal_newlines=True))

total 213808
-rw-r--r--  1 afeng  staff         0 Feb  8 14:52 _SUCCESS
-rw-r--r--  1 afeng  staff   9338236 Feb  8 14:52 part-00000
-rw-r--r--  1 afeng  staff  11231804 Feb  8 14:52 part-00001
-rw-r--r--  1 afeng  staff  11214784 Feb  8 14:52 part-00002
-rw-r--r--  1 afeng  staff  11226100 Feb  8 14:52 part-00003
-rw-r--r--  1 afeng  staff  11212767 Feb  8 14:52 part-00004
-rw-r--r--  1 afeng  staff  11173834 Feb  8 14:52 part-00005
-rw-r--r--  1 afeng  staff  11214285 Feb  8 14:52 part-00006
-rw-r--r--  1 afeng  staff  11201024 Feb  8 14:52 part-00007
-rw-r--r--  1 afeng  staff  11194141 Feb  8 14:52 part-00008
-rw-r--r--  1 afeng  staff  10449019 Feb  8 14:52 part-00009



In [6]:
# Verify the training labels are present by listing their directory.
train_labels_files = "csv/train/labels"
# universal_newlines=True makes check_output return text on Python 2 and 3 alike;
# on Python 3 the default bytes result would print as one b'...' literal.
print(subprocess.check_output(["ls", "-l", train_labels_files], universal_newlines=True))

total 4688
-rw-r--r--  1 afeng  staff       0 Feb  8 14:52 _SUCCESS
-rw-r--r--  1 afeng  staff  204800 Feb  8 14:52 part-00000
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00001
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00002
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00003
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00004
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00005
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00006
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00007
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00008
-rw-r--r--  1 afeng  staff  229120 Feb  8 14:52 part-00009



In [17]:
# Reserve a TensorFlow cluster on the Spark executors for training.
cluster = TFCluster.reserve(
    sc,
    num_executors,
    1,     # presumably the number of parameter-server nodes — confirm against TFCluster.reserve
    True,  # launch TensorBoard
    TFCluster.InputMode.SPARK,  # training data is fed from Spark RDDs in the next cell
)

09:18:56 INFO:Reserving TFSparkNodes w/ TensorBoard
09:19:06 INFO:TensorBoard running at: http://notforever-lm:54463


In [18]:
# Start TensorFlow on the reserved cluster, then feed it the training set from Spark.
# While this runs, TensorBoard is reachable at the URL logged by the reserve cell;
# it may take a moment before the TensorFlow graph shows up there.
args = parser.parse_args([
    '--mode', 'train',
    '--images', train_images_files,
    '--labels', train_labels_files,
])
cluster.start(mnist_dist.map_fun, args)

# Each image line becomes a list of ints and each label line a list of floats.
# RDD.zip pairs them element by element, which requires both RDDs to have the
# same number of partitions and the same number of elements per partition.
images = sc.textFile(args.images).map(lambda row: list(map(int, row.split(','))))
labels = sc.textFile(args.labels).map(lambda row: list(map(float, row.split(','))))
dataRDD = images.zip(labels)
# Feed the (image, label) pairs to the cluster for args.epochs epochs.
cluster.train(dataRDD, args.epochs)

09:19:27 INFO:Starting TensorFlow
09:19:33 INFO:Feeding training data


In [19]:
# Stop the TensorFlow nodes of the training cluster reserved above.
cluster.shutdown()

09:21:02 INFO:Stopping TensorFlow nodes


In [20]:
# Inspect the checkpoint files written to the model directory by training.
# universal_newlines=True makes check_output return text on Python 2 and 3 alike.
print(subprocess.check_output(["ls", "-l", "mnist_model"], universal_newlines=True))

total 8672
-rw-r--r--  1 afeng  staff     263 Feb 12 09:20 checkpoint
-rw-r--r--  1 afeng  staff  113289 Feb 12 09:19 graph.pbtxt
-rw-r--r--  1 afeng  staff  814164 Feb 12 09:19 model.ckpt-0.data-00000-of-00001
-rw-r--r--  1 afeng  staff     372 Feb 12 09:19 model.ckpt-0.index
-rw-r--r--  1 afeng  staff   43896 Feb 12 09:19 model.ckpt-0.meta
-rw-r--r--  1 afeng  staff  814164 Feb 12 09:19 model.ckpt-117.data-00000-of-00001
-rw-r--r--  1 afeng  staff     372 Feb 12 09:19 model.ckpt-117.index
-rw-r--r--  1 afeng  staff   43896 Feb 12 09:19 model.ckpt-117.meta
-rw-r--r--  1 afeng  staff  814164 Feb 12 09:19 model.ckpt-240.data-00000-of-00001
-rw-r--r--  1 afeng  staff     372 Feb 12 09:19 model.ckpt-240.index
-rw-r--r--  1 afeng  staff   43896 Feb 12 09:19 model.ckpt-240.meta
-rw-r--r--  1 afeng  staff  814164 Feb 12 09:20 model.ckpt-364.data-00000-of-00001
-rw-r--r--  1 afeng  staff     372 Feb 12 09:20 model.ckpt-364.index
-rw-r--r--  1 afeng  staff   43896 Feb 12 09:20 model.ckpt-364.m

In [11]:
# Verify the test images are present by listing their directory.
test_images_files = "csv/test/images"
# universal_newlines=True makes check_output return text on Python 2 and 3 alike;
# on Python 3 the default bytes result would print as one b'...' literal.
print(subprocess.check_output(["ls", "-l", test_images_files], universal_newlines=True))

total 35720
-rw-r--r--  1 afeng  staff        0 Feb  8 14:53 _SUCCESS
-rw-r--r--  1 afeng  staff  1810248 Feb  8 14:53 part-00000
-rw-r--r--  1 afeng  staff  1806102 Feb  8 14:53 part-00001
-rw-r--r--  1 afeng  staff  1811128 Feb  8 14:53 part-00002
-rw-r--r--  1 afeng  staff  1812952 Feb  8 14:53 part-00003
-rw-r--r--  1 afeng  staff  1810946 Feb  8 14:53 part-00004
-rw-r--r--  1 afeng  staff  1835497 Feb  8 14:53 part-00005
-rw-r--r--  1 afeng  staff  1845261 Feb  8 14:53 part-00006
-rw-r--r--  1 afeng  staff  1850655 Feb  8 14:53 part-00007
-rw-r--r--  1 afeng  staff  1852712 Feb  8 14:53 part-00008
-rw-r--r--  1 afeng  staff  1833942 Feb  8 14:53 part-00009



In [12]:
# Verify the test labels are present by listing their directory.
test_labels_files = "csv/test/labels"
# universal_newlines=True makes check_output return text on Python 2 and 3 alike;
# on Python 3 the default bytes result would print as one b'...' literal.
print(subprocess.check_output(["ls", "-l", test_labels_files], universal_newlines=True))

total 800
-rw-r--r--  1 afeng  staff      0 Feb  8 14:53 _SUCCESS
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00000
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00001
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00002
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00003
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00004
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00005
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00006
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00007
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00008
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00009



In [21]:
# Reserve a fresh TensorFlow cluster on the Spark executors for inference.
cluster = TFCluster.reserve(
    sc,
    num_executors,
    1,     # presumably the number of parameter-server nodes — confirm against TFCluster.reserve
    True,  # launch TensorBoard
    TFCluster.InputMode.SPARK,  # test data is fed from Spark RDDs in the next cell
)

09:21:16 INFO:Reserving TFSparkNodes w/ TensorBoard
09:21:27 INFO:TensorBoard running at: http://notforever-lm:54558


In [22]:
# Start TensorFlow in inference mode and feed it the test set from Spark.
# TensorBoard is reachable at the URL logged by the reserve cell above.
args = parser.parse_args([
    '--mode', 'inference',
    '--images', test_images_files,
    '--labels', test_labels_files,
])
cluster.start(mnist_dist.map_fun, args)

# Prepare the data as Spark RDDs: image lines -> lists of ints, label lines -> lists
# of floats, paired element by element with RDD.zip (needs matching partitioning).
images = sc.textFile(args.images).map(lambda row: list(map(int, row.split(','))))
labels = sc.textFile(args.labels).map(lambda row: list(map(float, row.split(','))))
dataRDD = images.zip(labels)

# Feed the data for inference and display the first 20 results.
prediction_results = cluster.inference(dataRDD)
prediction_results.take(20)

09:22:03 INFO:Starting TensorFlow
09:22:08 INFO:Feeding inference data


['2017-02-12T09:22:10.997358 Label: 7, Prediction: 7',
 '2017-02-12T09:22:10.997468 Label: 2, Prediction: 2',
 '2017-02-12T09:22:10.997511 Label: 1, Prediction: 1',
 '2017-02-12T09:22:10.997549 Label: 0, Prediction: 0',
 '2017-02-12T09:22:10.997586 Label: 4, Prediction: 4',
 '2017-02-12T09:22:10.997623 Label: 1, Prediction: 1',
 '2017-02-12T09:22:10.997661 Label: 4, Prediction: 4',
 '2017-02-12T09:22:10.997697 Label: 9, Prediction: 9',
 '2017-02-12T09:22:10.997734 Label: 5, Prediction: 6',
 '2017-02-12T09:22:10.997771 Label: 9, Prediction: 9',
 '2017-02-12T09:22:10.997808 Label: 0, Prediction: 0',
 '2017-02-12T09:22:10.997845 Label: 6, Prediction: 6',
 '2017-02-12T09:22:10.997881 Label: 9, Prediction: 9',
 '2017-02-12T09:22:10.997917 Label: 0, Prediction: 0',
 '2017-02-12T09:22:10.997954 Label: 1, Prediction: 1',
 '2017-02-12T09:22:10.997990 Label: 5, Prediction: 5',
 '2017-02-12T09:22:10.998026 Label: 9, Prediction: 9',
 '2017-02-12T09:22:10.998063 Label: 7, Prediction: 7',
 '2017-02-

In [23]:
# Stop the TensorFlow nodes of the inference cluster reserved above.
cluster.shutdown()

09:22:36 INFO:Stopping TensorFlow nodes
