In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from pyspark.context import SparkContext
from pyspark.conf import SparkConf

import argparse
import os
import numpy

import sys
import tensorflow as tf
import threading
from datetime import datetime

from com.yahoo.ml.tf import TFCluster
import mnist_dist



In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--images", help="HDFS path to MNIST images in parallelized format")
parser.add_argument("-f", "--format", help="example format: (csv|pickle|tfr)", choices=["csv","pickle","tfr"], default="csv")
parser.add_argument("-l", "--labels", help="HDFS path to MNIST labels in parallelized format")
parser.add_argument("-m", "--model", help="HDFS path to save/load model during train/test", default="mnist_model")
parser.add_argument("-o", "--output", help="HDFS path to save test/inference output", default="predictions")
parser.add_argument("-r", "--readers", help="number of reader/enqueue threads", type=int, default=1)
parser.add_argument("-s", "--steps", help="maximum number of steps", type=int, default=1000)
parser.add_argument("-X", "--mode", help="train|test", default="train")
parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
args = parser.parse_args(['-f', 'csv', '-m', 'mnist_test_model', '-r', '1', '-i', 'mnist/csv/test/images', '-l', 'mnist/csv/test/labels', '-X', 'test'])
print(args)

Namespace(format='csv', images='mnist/csv/test/images', labels='mnist/csv/test/labels', mode='test', model='mnist_test_model', output='predictions', readers=1, steps=1000, tensorboard=False)


In [3]:
images = sc.textFile(args.images).map(lambda ln: [int(x) for x in ln.split(',')])
labels = sc.textFile(args.labels).map(lambda ln: [float(x) for x in ln.split(',')])

dataRDD = images.zip(labels)
dataRDD.count()

10000

In [4]:
num_executors = int(sc._conf.get("spark.executor.instances"))
num_ps = 1



In [5]:
# print("reserving TFNodes", "w/ TensorBoard" if args.tensorboard else "")
cluster = TFCluster.reserve(sc, num_executors, num_ps, False, TFCluster.InputMode.SPARK)


In [6]:
for node in cluster.cluster_info:
  print(node)


{'addr': ('gpbl191n04.blue.ygrid.yahoo.com', 38707), 'task_index': 0, 'port': 50943, 'authkey': UUID('bbdd8c81-6495-4ab1-b32d-29f031eed90e'), 'worker_num': 0, 'host': 'gpbl191n04.blue.ygrid.yahoo.com', 'ppid': 106407, 'job_name': 'ps', 'tb_port': 0}
{'addr': '/tmp/pymp-vF6Xqe/listener-MNNwWC', 'task_index': 0, 'port': 54972, 'authkey': UUID('32857dd3-ba20-46d1-be59-a4d771d13094'), 'worker_num': 1, 'host': 'gpbl191n17.blue.ygrid.yahoo.com', 'ppid': 14546, 'job_name': 'worker', 'tb_port': 0}
{'addr': '/tmp/pymp-xTeEsP/listener-JaUVXw', 'task_index': 1, 'port': 40427, 'authkey': UUID('12d40b04-9bf0-405a-afff-e45069938a7a'), 'worker_num': 2, 'host': 'gpbl191n04.blue.ygrid.yahoo.com', 'ppid': 106509, 'job_name': 'worker', 'tb_port': 0}
{'addr': '/tmp/pymp-hERBmx/listener-4XWkUp', 'task_index': 2, 'port': 50226, 'authkey': UUID('8d39f405-ff29-4d31-a321-526eb262442f'), 'worker_num': 3, 'host': 'gpbl191n04.blue.ygrid.yahoo.com', 'ppid': 106411, 'job_name': 'worker', 'tb_port': 0}


In [7]:
cluster.start(mnist_dist.map_fun, args)


In [8]:
resultRDD = cluster.test(dataRDD)

In [9]:
resultRDD.take(25)


[7, 2, 1, 0, 4, 1, 4, 9, 6, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5, 4]

In [10]:
cluster.shutdown()

connecting to ('gpbl191n04.blue.ygrid.yahoo.com', 38707), bbdd8c81-6495-4ab1-b32d-29f031eed90e
