In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import subprocess
from tensorflowonspark import TFCluster
import mnist_dist

In [2]:
# Options understood by mnist_dist.map_fun. The notebook never reads sys.argv:
# each cell below passes an explicit argument list to parser.parse_args().
def _str2bool(value):
    """Convert a command-line string such as "true"/"false" or "1"/"0" to a bool.

    A plain string value would be truthy for any non-empty text (including
    "False"), so the text is matched explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))

parser = argparse.ArgumentParser()
parser.add_argument("-e", "--epochs", help="number of epochs", type=int, default=0)
parser.add_argument("-i", "--images", help="HDFS path to MNIST images in parallelized format")
parser.add_argument("-f", "--format", help="example format", choices=["csv","tfr"], default="tfr")
parser.add_argument("-m", "--model", help="HDFS path to save/load model", default="mnist_model")
parser.add_argument("-o", "--output", help="HDFS path to save test/inference output", default="predictions")
parser.add_argument("-r", "--readers", help="number of reader/enqueue threads", type=int, default=1)
parser.add_argument("-s", "--steps", help="maximum number of steps", type=int, default=1000)
parser.add_argument("-X", "--mode", help="train|inference", default="train")
# A bare "-c" enables RDMA; "-c true" / "-c false" are parsed into a real bool
# so that an explicit "false" actually disables it.
parser.add_argument("-c", "--rdma", help="use rdma connection", nargs="?",
                    const=True, default=False, type=_str2bool)
parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
# Number of Spark executors actually launched for this session. It is passed to
# TFCluster.reserve in the training and inference cells, so keep it in sync
# with the executor count your Spark session was started with.
num_executors = 3

In [3]:
# Remove the model and prediction directories left by any previous run so the
# training and inference cells start from a clean slate. The paths come from
# the parser defaults, so the cleanup always targets the same directories the
# later cells write to. "-f" makes a missing path a silent no-op instead of a
# non-zero exit status.
for hdfs_path in (parser.get_default("model"), parser.get_default("output")):
    subprocess.call(["hadoop", "fs", "-rm", "-f", "-R", hdfs_path])

1

In [4]:
# Parse the training options first so the TensorBoard setting passed to the
# cluster comes from the same args object that map_fun receives.
args = parser.parse_args(['--mode', 'train',
                          '--images', 'mnist/tfr/train',
                          '--tensorboard'])
# Reserve a TensorFlow cluster on the Spark executors; the fourth argument
# turns TensorBoard on for this run. The literal 1 is presumably the number of
# parameter-server nodes (TODO confirm against the installed TFCluster.reserve).
cluster = TFCluster.reserve(sc, num_executors, 1, args.tensorboard,
                            TFCluster.InputMode.TENSORFLOW)
# Kick off training; the next cell's cluster.shutdown() waits for it to finish.
cluster.start(mnist_dist.map_fun, args)

2017-02-09 19:38:54,066 INFO (MainThread-7122) Reserving TFSparkNodes w/ TensorBoard
2017-02-09 19:39:05,708 INFO (MainThread-7122) TensorBoard running at: http://ip-172-31-25-197:59675
2017-02-09 19:39:05,710 INFO (MainThread-7122) Starting TensorFlow


In [5]:
# Shut down the cluster. This returns only once training has actually
# completed, so it can take a few minutes.
cluster.shutdown()

2017-02-09 19:39:28,498 INFO (MainThread-7122) Stopping TensorFlow nodes


In [6]:
# List the files written by training (checkpoints, graph). The path is read
# from `args` so it always matches the --model directory training used.
# check_output returns bytes on Python 3; decode so the listing prints as text.
print(subprocess.check_output(["hadoop", "fs", "-ls", args.model]).decode("utf-8"))

Found 17 items
-rw-r--r--   3 root supergroup        265 2017-02-09 19:40 mnist_model/checkpoint
-rw-r--r--   3 root supergroup     142752 2017-02-09 19:39 mnist_model/graph.pbtxt
-rw-r--r--   3 root supergroup     814164 2017-02-09 19:40 mnist_model/model.ckpt-416.data-00000-of-00001
-rw-r--r--   3 root supergroup        372 2017-02-09 19:40 mnist_model/model.ckpt-416.index
-rw-r--r--   3 root supergroup      56894 2017-02-09 19:40 mnist_model/model.ckpt-416.meta
-rw-r--r--   3 root supergroup     814164 2017-02-09 19:40 mnist_model/model.ckpt-543.data-00000-of-00001
-rw-r--r--   3 root supergroup        372 2017-02-09 19:40 mnist_model/model.ckpt-543.index
-rw-r--r--   3 root supergroup      56894 2017-02-09 19:40 mnist_model/model.ckpt-543.meta
-rw-r--r--   3 root supergroup     814164 2017-02-09 19:40 mnist_model/model.ckpt-669.data-00000-of-00001
-rw-r--r--   3 root supergroup        372 2017-02-09 19:40 mnist_model/model.ckpt-669.index
-rw-r--r--   3 root supergroup      56894 20

In [7]:
# Parse the inference options up front: one pass over the test images using
# the model trained above (--model defaults to the same directory). No
# --tensorboard flag is given, so TensorBoard stays off for this run.
inf_args = parser.parse_args(['--mode', 'inference',
                              '--images', 'mnist/tfr/test',
                              '--epochs', '1'])
# Reserve a fresh TensorFlow cluster, taking the TensorBoard switch from the
# parsed options, then start inference with the trained model.
cluster = TFCluster.reserve(sc, num_executors, 1, inf_args.tensorboard,
                            TFCluster.InputMode.TENSORFLOW)
cluster.start(mnist_dist.map_fun, inf_args)

2017-02-09 19:41:32,734 INFO (MainThread-7122) Reserving TFSparkNodes
2017-02-09 19:41:43,519 INFO (MainThread-7122) Starting TensorFlow


In [8]:
# Shut down the cluster. This returns only once inference has actually
# completed.
cluster.shutdown()

2017-02-09 19:41:48,531 INFO (MainThread-7122) Stopping TensorFlow nodes


In [9]:
# Load the inference results: one line per example holding the label and the
# prediction separated by a space. The path is read from `inf_args` so it
# always matches the --output directory inference wrote to.
predictions = sc.textFile(inf_args.output)
predictions.take(10)

[u'7 7',
 u'2 2',
 u'1 1',
 u'0 0',
 u'4 4',
 u'1 1',
 u'4 4',
 u'9 9',
 u'5 6',
 u'9 9']