In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import subprocess
from com.yahoo.ml.tf import TFCluster
import mnist_dist

In [31]:
def _str2bool(value):
    """Convert a command-line string such as "true"/"false" into a bool.

    Without an explicit ``type``, argparse stores the raw string, so
    "--rdma False" would yield the truthy string "False".

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y"):
        return True
    if normalized in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


# Options handed to mnist_dist.map_fun via cluster.start(); later cells build
# explicit argument lists and pass them to parser.parse_args().
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--epochs", help="number of epochs", type=int, default=1)
parser.add_argument("-i", "--images", help="HDFS path to MNIST images in parallelized format")
parser.add_argument("-l", "--labels", help="HDFS path to MNIST labels in parallelized format")
parser.add_argument("-f", "--format", help="example format", choices=["csv","pickle","tfr"], default="csv")
parser.add_argument("-m", "--model", help="HDFS path to save/load model during train/test", default="mnist_model")
parser.add_argument("-o", "--output", help="HDFS path to save test/inference output", default="predictions")
parser.add_argument("-r", "--readers", help="number of reader/enqueue threads", type=int, default=1)
parser.add_argument("-s", "--steps", help="maximum number of steps", type=int, default=500)
parser.add_argument("-X", "--mode", help="train|inference", default="train")
# Parse the value as a real boolean ("--rdma false" -> False); a bare "--rdma" means True.
parser.add_argument("-c", "--rdma", help="use rdma connection", type=_str2bool, nargs="?", const=True, default=False)
parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
# Number of Spark executors to reserve for the TensorFlow cluster.
num_executors = 2

In [34]:
# Remove any model directory left from a previous run; presumably training would
# otherwise pick up the stale checkpoints stored there (see the listing further down).
# check_call (rather than call) raises CalledProcessError if rm fails, e.g. on a
# permissions error, instead of silently discarding the non-zero exit status.
subprocess.check_call(["rm", "-rf", "mnist_model"])

0

In [35]:
# Verify the training images are present.
# universal_newlines=True makes check_output return str instead of bytes on Python 3,
# so print() shows the listing line by line rather than as a b'...' repr.
train_images_files = "csv/train/images"
print(subprocess.check_output(["ls", "-l", train_images_files], universal_newlines=True))

total 213808
-rw-r--r--  1 afeng  staff         0 Feb  8 14:52 _SUCCESS
-rw-r--r--  1 afeng  staff   9338236 Feb  8 14:52 part-00000
-rw-r--r--  1 afeng  staff  11231804 Feb  8 14:52 part-00001
-rw-r--r--  1 afeng  staff  11214784 Feb  8 14:52 part-00002
-rw-r--r--  1 afeng  staff  11226100 Feb  8 14:52 part-00003
-rw-r--r--  1 afeng  staff  11212767 Feb  8 14:52 part-00004
-rw-r--r--  1 afeng  staff  11173834 Feb  8 14:52 part-00005
-rw-r--r--  1 afeng  staff  11214285 Feb  8 14:52 part-00006
-rw-r--r--  1 afeng  staff  11201024 Feb  8 14:52 part-00007
-rw-r--r--  1 afeng  staff  11194141 Feb  8 14:52 part-00008
-rw-r--r--  1 afeng  staff  10449019 Feb  8 14:52 part-00009



In [36]:
# Verify the training labels are present.
# universal_newlines=True makes check_output return str instead of bytes on Python 3,
# so print() shows the listing line by line rather than as a b'...' repr.
train_labels_files = "csv/train/labels"
print(subprocess.check_output(["ls", "-l", train_labels_files], universal_newlines=True))

total 4688
-rw-r--r--  1 afeng  staff       0 Feb  8 14:52 _SUCCESS
-rw-r--r--  1 afeng  staff  204800 Feb  8 14:52 part-00000
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00001
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00002
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00003
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00004
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00005
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00006
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00007
-rw-r--r--  1 afeng  staff  245760 Feb  8 14:52 part-00008
-rw-r--r--  1 afeng  staff  229120 Feb  8 14:52 part-00009



In [None]:
# Reserve a cluster for training: num_executors executors, 1 parameter server,
# with data fed from Spark RDDs (InputMode.SPARK).
# The 4th positional parameter of TFCluster.reserve is the tensorboard flag, not an
# input mode; the previous call passed InputMode.TENSORFLOW in that slot. Pass an
# explicit boolean instead.
# NOTE(review): the training args below request --tensorboard, but that option is
# parsed after this reservation and is not passed here; set True to launch TensorBoard.
cluster = TFCluster.reserve(sc, num_executors, 1, False, TFCluster.InputMode.SPARK)

In [37]:
# Training configuration: save the model to mnist_model, read the CSV training
# images/labels listed above, and request TensorBoard.
train_argv = [
    '--model', 'mnist_model',
    '--images', train_images_files,
    '--labels', train_labels_files,
    '--tensorboard',
]
args = parser.parse_args(train_argv)

In [38]:
# Start the TensorFlow nodes on the reserved cluster with mnist_dist.map_fun and args.
cluster.start(mnist_dist.map_fun, args)

# Feed data via Spark RDDs: every comma-separated line of the image files is parsed
# into ints, every line of the label files into floats.
images = (sc.textFile(args.images)
          .map(lambda row: list(map(int, row.split(',')))))
labels = (sc.textFile(args.labels)
          .map(lambda row: list(map(float, row.split(',')))))
# RDD.zip pairs rows positionally; both RDDs need matching partition counts and sizes.
dataRDD = images.zip(labels)
cluster.train(dataRDD, args.epochs)

In [39]:
# Shut down the training cluster; the inference phase further down reserves a new one.
cluster.shutdown()

connecting to ('notforever-lm', 57027), 4fe1cc03-8c73-479d-a6f4-81f579d44295


In [40]:
# Verify that checkpoints were written to the model directory.
# universal_newlines=True yields str (not bytes) on Python 3 so the listing prints cleanly.
print(subprocess.check_output(["ls", "-l", "mnist_model"], universal_newlines=True))

total 8672
-rw-r--r--  1 afeng  staff     263 Feb  8 17:41 checkpoint
-rw-r--r--  1 afeng  staff  113289 Feb  8 17:41 graph.pbtxt
-rw-r--r--  1 afeng  staff  814164 Feb  8 17:41 model.ckpt-0.data-00000-of-00001
-rw-r--r--  1 afeng  staff     372 Feb  8 17:41 model.ckpt-0.index
-rw-r--r--  1 afeng  staff   43896 Feb  8 17:41 model.ckpt-0.meta
-rw-r--r--  1 afeng  staff  814164 Feb  8 17:41 model.ckpt-118.data-00000-of-00001
-rw-r--r--  1 afeng  staff     372 Feb  8 17:41 model.ckpt-118.index
-rw-r--r--  1 afeng  staff   43896 Feb  8 17:41 model.ckpt-118.meta
-rw-r--r--  1 afeng  staff  814164 Feb  8 17:41 model.ckpt-243.data-00000-of-00001
-rw-r--r--  1 afeng  staff     372 Feb  8 17:41 model.ckpt-243.index
-rw-r--r--  1 afeng  staff   43896 Feb  8 17:41 model.ckpt-243.meta
-rw-r--r--  1 afeng  staff  814164 Feb  8 17:41 model.ckpt-369.data-00000-of-00001
-rw-r--r--  1 afeng  staff     372 Feb  8 17:41 model.ckpt-369.index
-rw-r--r--  1 afeng  staff   43896 Feb  8 17:41 model.ckpt-369.m

In [41]:
# Verify the test images are present.
# universal_newlines=True makes check_output return str instead of bytes on Python 3,
# so print() shows the listing line by line rather than as a b'...' repr.
test_images_files = "csv/test/images"
print(subprocess.check_output(["ls", "-l", test_images_files], universal_newlines=True))

total 35720
-rw-r--r--  1 afeng  staff        0 Feb  8 14:53 _SUCCESS
-rw-r--r--  1 afeng  staff  1810248 Feb  8 14:53 part-00000
-rw-r--r--  1 afeng  staff  1806102 Feb  8 14:53 part-00001
-rw-r--r--  1 afeng  staff  1811128 Feb  8 14:53 part-00002
-rw-r--r--  1 afeng  staff  1812952 Feb  8 14:53 part-00003
-rw-r--r--  1 afeng  staff  1810946 Feb  8 14:53 part-00004
-rw-r--r--  1 afeng  staff  1835497 Feb  8 14:53 part-00005
-rw-r--r--  1 afeng  staff  1845261 Feb  8 14:53 part-00006
-rw-r--r--  1 afeng  staff  1850655 Feb  8 14:53 part-00007
-rw-r--r--  1 afeng  staff  1852712 Feb  8 14:53 part-00008
-rw-r--r--  1 afeng  staff  1833942 Feb  8 14:53 part-00009



In [42]:
# Verify the test labels are present.
# universal_newlines=True makes check_output return str instead of bytes on Python 3,
# so print() shows the listing line by line rather than as a b'...' repr.
test_labels_files = "csv/test/labels"
print(subprocess.check_output(["ls", "-l", test_labels_files], universal_newlines=True))

total 800
-rw-r--r--  1 afeng  staff      0 Feb  8 14:53 _SUCCESS
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00000
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00001
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00002
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00003
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00004
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00005
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00006
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00007
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00008
-rw-r--r--  1 afeng  staff  40000 Feb  8 14:53 part-00009



In [43]:
# Reserve a cluster for inference: num_executors executors, 1 parameter server,
# with data fed from Spark RDDs (InputMode.SPARK).
# The 4th positional parameter of TFCluster.reserve is the tensorboard flag, not an
# input mode; the previous call passed InputMode.TENSORFLOW in that slot. Pass an
# explicit boolean instead (TensorBoard stays off for inference).
cluster = TFCluster.reserve(sc, num_executors, 1, False, TFCluster.InputMode.SPARK)

{'addr': ('notforever-lm', 57115), 'task_index': 0, 'port': 57125, 'authkey': UUID('51b0f93e-44d1-4a9c-a69d-48214d2d0d5f'), 'worker_num': 0, 'host': 'notforever-lm', 'ppid': 8996, 'job_name': 'ps', 'tb_port': 0}
{'addr': '/var/folders/mk/v_wmtt4n6491cqfl6qsctdv0000l60/T/pymp-9rwgyX/listener-eNZblt', 'task_index': 0, 'port': 57116, 'authkey': UUID('1f09964a-ab71-4722-ba5d-025727eac04a'), 'worker_num': 1, 'host': 'notforever-lm', 'ppid': 8995, 'job_name': 'worker', 'tb_port': 0}


In [44]:
# Inference configuration: load the model from mnist_model and read the CSV test
# images/labels listed above.
inference_argv = [
    '--model', 'mnist_model',
    '--mode', 'inference',
    '--images', test_images_files,
    '--labels', test_labels_files,
]
args = parser.parse_args(inference_argv)

In [45]:
# Start the TensorFlow nodes on the reserved cluster with mnist_dist.map_fun and args.
cluster.start(mnist_dist.map_fun, args)

# Prepare the test data as Spark RDDs: image lines parsed into ints, label lines into floats.
images = (sc.textFile(args.images)
          .map(lambda row: list(map(int, row.split(',')))))
labels = (sc.textFile(args.labels)
          .map(lambda row: list(map(float, row.split(',')))))
# RDD.zip pairs rows positionally; both RDDs need matching partition counts and sizes.
dataRDD = images.zip(labels)

# Feed the data for inference and display the first 20 results.
prediction_results = cluster.inference(dataRDD)
prediction_results.take(20)

['2017-02-08T17:43:10.800489 Label: [ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.], Prediction: 7',
 '2017-02-08T17:43:10.801168 Label: [ 0.  0.  1.  0.  0.  0.  0.  0.  0.  0.], Prediction: 2',
 '2017-02-08T17:43:10.801530 Label: [ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.], Prediction: 1',
 '2017-02-08T17:43:10.802064 Label: [ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.], Prediction: 0',
 '2017-02-08T17:43:10.802509 Label: [ 0.  0.  0.  0.  1.  0.  0.  0.  0.  0.], Prediction: 4',
 '2017-02-08T17:43:10.802918 Label: [ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.], Prediction: 1',
 '2017-02-08T17:43:10.803290 Label: [ 0.  0.  0.  0.  1.  0.  0.  0.  0.  0.], Prediction: 4',
 '2017-02-08T17:43:10.803642 Label: [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.], Prediction: 9',
 '2017-02-08T17:43:10.803990 Label: [ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.], Prediction: 6',
 '2017-02-08T17:43:10.804336 Label: [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.], Prediction: 9',
 '2017-02-08T17:43:10.804742 Label: [ 1.  0.  0.  

In [46]:
# Shut down the inference cluster.
cluster.shutdown()

connecting to ('notforever-lm', 57115), 51b0f93e-44d1-4a9c-a69d-48214d2d0d5f
