In [4]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import argparse
import subprocess
from tensorflowonspark import TFCluster
import mnist_dist
from importlib import *

In [5]:
reload(logging)
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%I:%M:%S')

In [6]:
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format")
parser.add_argument("--format", help="example format", choices=["csv","pickle","tfr"], default="csv")
parser.add_argument("--model", help="HDFS path to save/load model during train/test", default="mnist_model")
parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
parser.add_argument("--steps", help="maximum number of steps", type=int, default=500)
parser.add_argument("--batch_size", help="number of examples per batch", type=int, default=100)
parser.add_argument("--mode", help="train|inference", default="train")
parser.add_argument("--rdma", help="use rdma connection", default=False)
num_executors = 2

In [7]:
#remove existing models if any
subprocess.call(["rm", "-rf", "mnist_model"])

0

In [9]:
#verify training images
train_images_files = "mnist/csv/train"
print(subprocess.check_output(["ls", "-l", train_images_files]))

CalledProcessError: Command '['ls', '-l', 'mnist/csv/train']' returned non-zero exit status 2.

In [None]:
#verify training labels
train_labels_files = "csv/train/labels"
print(subprocess.check_output(["ls", "-l", train_labels_files]))

In [None]:
#Parse arguments for training
args = parser.parse_args(['--mode', 'train', '--steps', '3000', '--epochs', '5',
                          '--images', train_images_files, 
                          '--labels', train_labels_files])


In [None]:
#start the cluster for training
cluster = TFCluster.run(sc, mnist_dist.map_fun, args, num_executors, 1, True, TFCluster.InputMode.SPARK)

In [None]:
#Feed data via Spark RDD
images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')])
labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
dataRDD = images.zip(labels)
cluster.train(dataRDD, args.epochs)

In [None]:
cluster.shutdown()

In [None]:
print(subprocess.check_output(["ls", "-l", "mnist_model"]))

In [None]:
#verify test images
test_images_files = "csv/test/images"
print(subprocess.check_output(["ls", "-l", test_images_files]))

In [None]:
#verify test labels
test_labels_files = "csv/test/labels"
print(subprocess.check_output(["ls", "-l", test_labels_files]))

In [None]:
#Parse arguments for inference
args = parser.parse_args(['--mode', 'inference', 
                          '--images', test_images_files, 
                          '--labels', test_labels_files])

In [None]:
#Start the cluster for inference
cluster = TFCluster.run(sc, mnist_dist.map_fun, args, num_executors, 1, False, TFCluster.InputMode.SPARK)

In [None]:
#prepare data as Spark RDD
images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')])
labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])
dataRDD = images.zip(labels)
#feed data for inference
prediction_results = cluster.inference(dataRDD)
prediction_results.take(20)

In [None]:
cluster.shutdown()