In [1]:
import argparse
import shutil
import subprocess

from com.yahoo.ml.tf import TFCluster
import mnist_dist

In [2]:
def _str2bool(value):
    """Convert a command-line string such as "true"/"False"/"1"/"0" to a bool.

    argparse does not apply ``type`` to non-string defaults, so a bool default
    passes through untouched.  Raises argparse.ArgumentTypeError (reported by
    argparse as a usage error) for unrecognised values.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


# Command-line style options consumed by mnist_dist.map_fun; the notebook builds
# an argv list per run (train / inference) and parses it with this parser.
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--epochs", help="number of epochs", type=int, default=1)
parser.add_argument("-i", "--images", help="HDFS path to MNIST images in parallelized format")
parser.add_argument("-l", "--labels", help="HDFS path to MNIST labels in parallelized format")
parser.add_argument("-f", "--format", help="example format", choices=["csv","pickle","tfr"], default="csv")
parser.add_argument("-m", "--model", help="HDFS path to save/load model during train/test", default="mnist_model")
parser.add_argument("-r", "--readers", help="number of reader/enqueue threads", type=int, default=1)
parser.add_argument("-s", "--steps", help="maximum number of steps", type=int, default=500)
parser.add_argument("-X", "--mode", help="train|inference", default="train")
# Without a converter, "-c False" would store the truthy string "False".
# A bare "-c" now means True; "-c <bool-string>" is parsed explicitly.
parser.add_argument("-c", "--rdma", help="use rdma connection", type=_str2bool,
                    nargs="?", const=True, default=False)

# Number of Spark executors handed to TFCluster.reserve below.
num_executors = 2

In [3]:
# Remove the model directory left over from a previous run, if any.
# "mnist_model" matches the parser's default --model path.
# shutil.rmtree is a portable replacement for `rm -rf` (which needs a POSIX
# shell); ignore_errors=True makes this a no-op when the directory is absent.
shutil.rmtree("mnist_model", ignore_errors=True)

0

In [4]:
# Verify the training images are present by listing their directory.
train_images_files = "csv/train/images"
# universal_newlines=True returns text, so Python 3 prints readable lines
# instead of a b'...' bytes literal (Python 2 output is unchanged).
print(subprocess.check_output(["ls", "-l", train_images_files], universal_newlines=True))

total 213808
-rw-r--r--  1 afeng  staff         0 Feb  8 14:52 _SUCCESS
-rw-r--r--  1 afeng  staff   9338236 Feb  8 14:52 part-00000
-rw-r--r--  1 afeng  staff  11231804 Feb  8 14:52 part-00001
-rw-r--r--  1 afeng  staff  11214784 Feb  8 14:52 part-00002
-rw-r--r--  1 afeng  staff  11226100 Feb  8 14:52 part-00003
-rw-r--r--  1 afeng  staff  11212767 Feb  8 14:52 part-00004
-rw-r--r--  1 afeng  staff  11173834 Feb  8 14:52 part-00005
-rw-r--r--  1 afeng  staff  11214285 Feb  8 14:52 part-00006
-rw-r--r--  1 afeng  staff  11201024 Feb  8 14:52 part-00007
-rw-r--r--  1 afeng  staff  11194141 Feb  8 14:52 part-00008
-rw-r--r--  1 afeng  staff  10449019 Feb  8 14:52 part-00009



In [5]:
# Verify the training labels are present by listing their directory.
train_labels_files = "csv/train/labels"
# universal_newlines=True returns text, so Python 3 prints readable lines
# instead of a b'...' bytes literal (Python 2 output is unchanged).
print(subprocess.check_output(["ls", "-l", train_labels_files], universal_newlines=True))

total 4688
-rw-r--r--  1 afeng  staff       0 Feb  8 14:52 _SUCCESS
-rw-r--r--  1 afeng  staff  204800 Feb  8 14:52 part-00000
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00001
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00002
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00003
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00004
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00005
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00006
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00007
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00008
-rw-r--r--  1 afeng  staff  229120 Feb  8 14:52 part-00009



In [6]:
# Reserve a cluster for training across num_executors Spark executors.
# With num_executors=2 this yields one 'ps' node and one 'worker' node
# (see the printed node descriptions below).
cluster = TFCluster.reserve(
    sc,
    num_executors,
    1,     # parameter-server count -- TODO confirm against TFCluster.reserve signature
    True,  # TensorBoard on: nodes report a tb_port
    TFCluster.InputMode.SPARK,  # training data is fed from Spark RDDs
)

{'addr': ('notforever-lm', 64607), 'task_index': 0, 'port': 64616, 'authkey': UUID('b70557c3-ab4a-4ad8-b59a-fd43aebf921a'), 'worker_num': 0, 'host': 'notforever-lm', 'ppid': 49056, 'job_name': 'ps', 'tb_port': 0}
{'addr': '/var/folders/mk/v_wmtt4n6491cqfl6qsctdv0000l60/T/pymp-xIhNRC/listener-kLHARu', 'task_index': 0, 'port': 64618, 'authkey': UUID('788f951b-40b1-4356-9e06-18887d82175a'), 'worker_num': 1, 'host': 'notforever-lm', 'ppid': 49057, 'job_name': 'worker', 'tb_port': 64617}


In [7]:
# Train: build the argument list, start mnist_dist.map_fun on the reserved
# cluster, then feed (image, label) records to it from Spark.
train_argv = ['--mode', 'train',
              '--images', train_images_files,
              '--labels', train_labels_files]
args = parser.parse_args(train_argv)
cluster.start(mnist_dist.map_fun, args)

# Feed data via Spark RDD: each comma-separated line becomes a list of
# numbers (image values as ints, label values as floats), and the two RDDs
# are zipped line-by-line into (image, label) pairs.
# NOTE(review): RDD.zip requires matching partitioning/element counts, so the
# images and labels part files must line up one-to-one.
images = sc.textFile(args.images).map(lambda line: list(map(int, line.split(','))))
labels = sc.textFile(args.labels).map(lambda line: list(map(float, line.split(','))))
dataRDD = images.zip(labels)
cluster.train(dataRDD, args.epochs)

# Check out TensorBoard at http://localhost:<tb_port>, using the tb_port
# printed by the cluster-reservation cell above.

In [8]:
# Shut down the cluster reserved for training; a fresh cluster is reserved
# later for inference.
cluster.shutdown()

In [9]:
# List the model directory to confirm training wrote its checkpoint files.
# universal_newlines=True returns text, so Python 3 prints readable lines
# instead of a b'...' bytes literal (Python 2 output is unchanged).
print(subprocess.check_output(["ls", "-l", "mnist_model"], universal_newlines=True))

total 8672
-rw-r--r--  1 afeng  staff     263 Feb 10 23:49 checkpoint
-rw-r--r--  1 afeng  staff  113289 Feb 10 23:48 graph.pbtxt
-rw-r--r--  1 afeng  staff  814164 Feb 10 23:48 model.ckpt-0.data-00000-of-00001
-rw-r--r--  1 afeng  staff     372 Feb 10 23:48 model.ckpt-0.index
-rw-r--r--  1 afeng  staff   43896 Feb 10 23:48 model.ckpt-0.meta
-rw-r--r--  1 afeng  staff  814164 Feb 10 23:48 model.ckpt-115.data-00000-of-00001
-rw-r--r--  1 afeng  staff     372 Feb 10 23:48 model.ckpt-115.index
-rw-r--r--  1 afeng  staff   43896 Feb 10 23:48 model.ckpt-115.meta
-rw-r--r--  1 afeng  staff  814164 Feb 10 23:48 model.ckpt-238.data-00000-of-00001
-rw-r--r--  1 afeng  staff     372 Feb 10 23:48 model.ckpt-238.index
-rw-r--r--  1 afeng  staff   43896 Feb 10 23:48 model.ckpt-238.meta
-rw-r--r--  1 afeng  staff  814164 Feb 10 23:48 model.ckpt-361.data-00000-of-00001
-rw-r--r--  1 afeng  staff     372 Feb 10 23:48 model.ckpt-361.index
-rw-r--r--  1 afeng  staff   43896 Feb 10 23:48 model.ckpt-361.m

In [10]:
# Verify the test images are present by listing their directory.
test_images_files = "csv/test/images"
# universal_newlines=True returns text, so Python 3 prints readable lines
# instead of a b'...' bytes literal (Python 2 output is unchanged).
print(subprocess.check_output(["ls", "-l", test_images_files], universal_newlines=True))

total 35720
-rw-r--r--  1 afeng  staff        0 Feb  8 14:53 _SUCCESS
-rw-r--r--  1 afeng  staff  1810248 Feb  8 14:53 part-00000
-rw-r--r--  1 afeng  staff  1806102 Feb  8 14:53 part-00001
-rw-r--r--  1 afeng  staff  1811128 Feb  8 14:53 part-00002
-rw-r--r--  1 afeng  staff  1812952 Feb  8 14:53 part-00003
-rw-r--r--  1 afeng  staff  1810946 Feb  8 14:53 part-00004
-rw-r--r--  1 afeng  staff  1835497 Feb  8 14:53 part-00005
-rw-r--r--  1 afeng  staff  1845261 Feb  8 14:53 part-00006
-rw-r--r--  1 afeng  staff  1850655 Feb  8 14:53 part-00007
-rw-r--r--  1 afeng  staff  1852712 Feb  8 14:53 part-00008
-rw-r--r--  1 afeng  staff  1833942 Feb  8 14:53 part-00009



In [11]:
# Verify the test labels are present by listing their directory.
test_labels_files = "csv/test/labels"
# universal_newlines=True returns text, so Python 3 prints readable lines
# instead of a b'...' bytes literal (Python 2 output is unchanged).
print(subprocess.check_output(["ls", "-l", test_labels_files], universal_newlines=True))

total 800
-rw-r--r--  1 afeng  staff      0 Feb  8 14:53 _SUCCESS
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00000
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00001
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00002
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00003
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00004
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00005
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00006
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00007
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00008
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00009



In [12]:
# Reserve a new cluster for inference, configured the same way as the
# training cluster (the training cluster was shut down above).
cluster = TFCluster.reserve(
    sc,
    num_executors,
    1,     # parameter-server count -- TODO confirm against TFCluster.reserve signature
    True,  # TensorBoard on: nodes report a tb_port
    TFCluster.InputMode.SPARK,  # test data is fed from Spark RDDs
)

{'addr': ('notforever-lm', 64718), 'task_index': 0, 'port': 64729, 'authkey': UUID('8853c86d-851f-40a5-bde0-813d73ba6950'), 'worker_num': 0, 'host': 'notforever-lm', 'ppid': 49057, 'job_name': 'ps', 'tb_port': 0}
{'addr': '/var/folders/mk/v_wmtt4n6491cqfl6qsctdv0000l60/T/pymp-iODloJ/listener-d4Uqp5', 'task_index': 0, 'port': 64728, 'authkey': UUID('042e4b77-c88d-4ae5-a2f6-0be0dd18bc39'), 'worker_num': 1, 'host': 'notforever-lm', 'ppid': 49056, 'job_name': 'worker', 'tb_port': 64722}


In [13]:
# Inference: build the argument list, start mnist_dist.map_fun on the reserved
# cluster, then feed (image, label) records and collect its predictions.
inference_argv = ['--mode', 'inference',
                  '--images', test_images_files,
                  '--labels', test_labels_files]
args = parser.parse_args(inference_argv)
cluster.start(mnist_dist.map_fun, args)

# Prepare data as Spark RDD: each comma-separated line becomes a list of
# numbers (image values as ints, label values as floats), zipped
# line-by-line into (image, label) pairs.
# NOTE(review): RDD.zip requires matching partitioning/element counts, so the
# images and labels part files must line up one-to-one.
images = sc.textFile(args.images).map(lambda line: list(map(int, line.split(','))))
labels = sc.textFile(args.labels).map(lambda line: list(map(float, line.split(','))))
dataRDD = images.zip(labels)

# Feed data for inference and display the first 20 result lines.
prediction_results = cluster.inference(dataRDD)
prediction_results.take(20)
# Check out TensorBoard at http://localhost:<tb_port>, using the tb_port
# printed by the cluster-reservation cell above.

['2017-02-10T23:50:36.504700 Label: 7, Prediction: 7',
 '2017-02-10T23:50:36.504815 Label: 2, Prediction: 2',
 '2017-02-10T23:50:36.504861 Label: 1, Prediction: 1',
 '2017-02-10T23:50:36.504902 Label: 0, Prediction: 0',
 '2017-02-10T23:50:36.504941 Label: 4, Prediction: 4',
 '2017-02-10T23:50:36.504980 Label: 1, Prediction: 1',
 '2017-02-10T23:50:36.505019 Label: 4, Prediction: 4',
 '2017-02-10T23:50:36.505056 Label: 9, Prediction: 9',
 '2017-02-10T23:50:36.505094 Label: 5, Prediction: 6',
 '2017-02-10T23:50:36.505133 Label: 9, Prediction: 9',
 '2017-02-10T23:50:36.505171 Label: 0, Prediction: 0',
 '2017-02-10T23:50:36.505209 Label: 6, Prediction: 6',
 '2017-02-10T23:50:36.505247 Label: 9, Prediction: 9',
 '2017-02-10T23:50:36.505284 Label: 0, Prediction: 0',
 '2017-02-10T23:50:36.505322 Label: 1, Prediction: 1',
 '2017-02-10T23:50:36.505359 Label: 5, Prediction: 5',
 '2017-02-10T23:50:36.505396 Label: 9, Prediction: 9',
 '2017-02-10T23:50:36.505435 Label: 7, Prediction: 7',
 '2017-02-

In [14]:
# Shut down the cluster reserved for inference.
cluster.shutdown()