Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/mnist/tf/mnist_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@
parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
parser.add_argument("-X", "--mode", help="train|inference", default="train")
parser.add_argument("-c", "--rdma", help="use rdma connection", default=False)
# Parse the value as a boolean. Without a `type`, argparse stores the raw
# string, so "-p False" (or any non-empty value) would be truthy and would
# enable the option. The non-string default is left untouched by argparse.
parser.add_argument("-p", "--driver_ps_nodes", help="run tensorflow PS node on driver locally",
                    type=lambda v: v.strip().lower() in ("true", "t", "yes", "y", "1"), default=False)
args = parser.parse_args()
print("args:",args)


print("{0} ===== Start".format(datetime.now().isoformat()))
cluster = TFCluster.run(sc, mnist_dist.map_fun, args, args.cluster_size, num_ps, args.tensorboard, TFCluster.InputMode.TENSORFLOW, log_dir=args.model)
cluster = TFCluster.run(sc, mnist_dist.map_fun, args, args.cluster_size, num_ps, args.tensorboard, TFCluster.InputMode.TENSORFLOW,
driver_ps_nodes=args.driver_ps_nodes, log_dir=args.model)
cluster.shutdown()

print("{0} ===== Stop".format(datetime.now().isoformat()))
Expand Down
5 changes: 4 additions & 1 deletion examples/mnist/tf/mnist_spark_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@
parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
parser.add_argument("-X", "--mode", help="train|inference", default="train")
parser.add_argument("-c", "--rdma", help="use rdma connection", default=False)
# Parse the value as a boolean. Without a `type`, argparse stores the raw
# string, so "-p False" (or any non-empty value) would be truthy and would
# enable the option. The non-string default is left untouched by argparse.
parser.add_argument("-p", "--driver_ps_nodes", help="""run tensorflow PS node on driver locally.
You will need to set cluster_size = num_executors + num_ps""",
                    type=lambda v: v.strip().lower() in ("true", "t", "yes", "y", "1"), default=False)
args = parser.parse_args()
print("args:",args)


print("{0} ===== Start".format(datetime.now().isoformat()))
cluster = TFCluster.run(sc, mnist_dist_dataset.map_fun, args, args.cluster_size, num_ps, args.tensorboard, TFCluster.InputMode.TENSORFLOW)
cluster = TFCluster.run(sc, mnist_dist_dataset.map_fun, args, args.cluster_size, num_ps, args.tensorboard,
TFCluster.InputMode.TENSORFLOW, driver_ps_nodes=args.driver_ps_nodes)
cluster.shutdown()

print("{0} ===== Stop".format(datetime.now().isoformat()))
Expand Down
3 changes: 3 additions & 0 deletions examples/mnist/tf/mnist_spark_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
parser.add_argument("--tfrecord_dir", help="HDFS path to temporarily save DataFrame to disk", type=str)
parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
parser.add_argument("--num_ps", help="number of PS nodes in cluster", type=int, default=1)
# Parse the value as a boolean. Without a `type`, argparse stores the raw
# string, so "-p False" (or any non-empty value) would be truthy; the value is
# later handed to setDriverPSNodes(), which expects a real bool.
parser.add_argument("-p", "--driver_ps_nodes", help="""run tensorflow PS node on driver locally.
You will need to set cluster_size = num_executors + num_ps""",
                    type=lambda v: v.strip().lower() in ("true", "t", "yes", "y", "1"), default=False)
parser.add_argument("--protocol", help="Tensorflow network protocol (grpc|rdma)", default="grpc")
parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
Expand Down Expand Up @@ -81,6 +83,7 @@
.setExportDir(args.export_dir) \
.setClusterSize(args.cluster_size) \
.setNumPS(args.num_ps) \
.setDriverPSNodes(args.driver_ps_nodes) \
.setInputMode(TFCluster.InputMode.TENSORFLOW) \
.setTFRecordDir(args.tfrecord_dir) \
.setProtocol(args.protocol) \
Expand Down
30 changes: 27 additions & 3 deletions tensorflowonspark/TFCluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ def tensorboard_url(self):
tb_url = "http://{0}:{1}".format(node['host'], node['tb_port'])
return tb_url

def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mode=InputMode.TENSORFLOW, log_dir=None, queues=['input', 'output']):
def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mode=InputMode.TENSORFLOW,
log_dir=None, driver_ps_nodes=False, queues=['input', 'output']):
"""Starts the TensorFlowOnSpark cluster and Runs the TensorFlow "main" function on the Spark executors

Args:
Expand All @@ -198,6 +199,7 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
:tensorboard: boolean indicating if the chief worker should spawn a Tensorboard server.
:input_mode: TFCluster.InputMode
:log_dir: directory to save tensorboard event logs. If None, defaults to a fixed path on local filesystem.
:driver_ps_nodes: run the PS nodes locally on the driver instead of on the Spark executors; this helps maximize computing resources (esp. GPU). You will need to set cluster_size = num_executors + num_ps.
:queues: *INTERNAL_USE*

Returns:
Expand All @@ -206,10 +208,14 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
logging.info("Reserving TFSparkNodes {0}".format("w/ TensorBoard" if tensorboard else ""))
assert num_ps < num_executors

if driver_ps_nodes and input_mode != InputMode.TENSORFLOW:
raise Exception('running PS nodes on driver locally is only supported in InputMode.TENSORFLOW')

# build a cluster_spec template using worker_nums
cluster_template = {}
cluster_template['ps'] = range(num_ps)
cluster_template['worker'] = range(num_ps, num_executors)
logging.info("worker node range %s, ps node range %s" % (cluster_template['worker'], cluster_template['ps']))

# get default filesystem from spark
defaultFS = sc._jsc.hadoopConfiguration().get("fs.defaultFS")
Expand All @@ -234,7 +240,25 @@ def run(sc, map_fun, tf_args, num_executors, num_ps, tensorboard=False, input_mo
'working_dir': working_dir,
'server_addr': server_addr
}
nodeRDD = sc.parallelize(range(num_executors), num_executors)
if driver_ps_nodes:
nodeRDD = sc.parallelize(range(num_ps, num_executors), num_executors - num_ps)
else:
nodeRDD = sc.parallelize(range(num_executors), num_executors)

if driver_ps_nodes:
    def _start_ps(node_index):
        """Run a single PS node inside the driver process.

        Builds the node function with ``TFSparkNode.run(...)`` and calls it
        directly with a one-element list holding this node's index.

        Args:
          :node_index: index of the PS node to start, taken from ``cluster_template['ps']``.
        """
        logging.info("starting ps node locally %d" % node_index)
        TFSparkNode.run(map_fun,
                        tf_args,
                        cluster_meta,
                        tensorboard,
                        log_dir,
                        queues,
                        background=(input_mode == InputMode.SPARK))([node_index])

    for i in cluster_template['ps']:
        # Hand the index to the thread through ``args`` so it is bound now.
        # A ``lambda: _start_ps(i)`` would look up ``i`` only when the thread
        # actually runs, by which time the loop may already have advanced,
        # so two threads could start the same PS index and skip another.
        ps_thread = threading.Thread(target=_start_ps, args=(i,))
        # Daemon thread: it must not keep the driver process alive on exit.
        ps_thread.daemon = True
        ps_thread.start()

# start TF on a background thread (on Spark driver) to allow for feeding job
def _start():
Expand All @@ -244,7 +268,7 @@ def _start():
tensorboard,
log_dir,
queues,
(input_mode == InputMode.SPARK)))
background=(input_mode == InputMode.SPARK)))
t = threading.Thread(target=_start)
t.start()

Expand Down
5 changes: 4 additions & 1 deletion tensorflowonspark/TFSparkNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,11 @@ def wrapper_fn(args, context):

if job_name == 'ps' or background:
# invoke the TensorFlow main function in a background thread
logging.info("Starting TensorFlow {0}:{1} on cluster node {2} on background process".format(job_name, task_index, worker_num))
logging.info("Starting TensorFlow {0}:{1} as {2} on cluster node {3} on background process".format(
job_name, task_index, job_name, worker_num))
p = multiprocessing.Process(target=wrapper_fn, args=(tf_args, ctx))
if job_name == 'ps':
p.daemon = True
p.start()

# for ps nodes only, wait indefinitely in foreground thread for a "control" event (None == "stop")
Expand Down
9 changes: 8 additions & 1 deletion tensorflowonspark/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,17 @@ def getModelDir(self):

class HasNumPS(Params):
    """Params mixin for the parameter-server settings.

    Declares two params: ``num_ps`` (converted with ``TypeConverters.toInt``)
    and ``driver_ps_nodes`` (converted with ``TypeConverters.toBoolean``),
    each with a setter/getter pair.
    """

    num_ps = Param(
        Params._dummy(),
        "num_ps",
        "Number of PS nodes in cluster",
        typeConverter=TypeConverters.toInt)
    driver_ps_nodes = Param(
        Params._dummy(),
        "driver_ps_nodes",
        "Run PS nodes on driver locally",
        typeConverter=TypeConverters.toBoolean)

    def __init__(self):
        super(HasNumPS, self).__init__()

    # --- num_ps -----------------------------------------------------------
    def setNumPS(self, value):
        """Set ``num_ps`` and return whatever ``_set`` returns."""
        updated = self._set(num_ps=value)
        return updated

    def getNumPS(self):
        """Return the value of ``num_ps``, or its default when unset."""
        param = self.num_ps
        return self.getOrDefault(param)

    # --- driver_ps_nodes --------------------------------------------------
    def setDriverPSNodes(self, value):
        """Set ``driver_ps_nodes`` and return whatever ``_set`` returns."""
        updated = self._set(driver_ps_nodes=value)
        return updated

    def getDriverPSNodes(self):
        """Return the value of ``driver_ps_nodes``, or its default when unset."""
        param = self.driver_ps_nodes
        return self.getOrDefault(param)

class HasOutputMapping(Params):
output_mapping = Param(Params._dummy(), "output_mapping", "Mapping of output tensor to output DataFrame column", typeConverter=TFTypeConverters.toDict)
Expand Down Expand Up @@ -276,6 +281,7 @@ def __init__(self, train_fn, tf_args, export_fn=None):
self._setDefault(input_mapping={},
cluster_size=1,
num_ps=0,
driver_ps_nodes=False,
input_mode=TFCluster.InputMode.SPARK,
protocol='grpc',
tensorboard=False,
Expand Down Expand Up @@ -319,7 +325,8 @@ def _fit(self, dataset):
logging.info("Done saving")

tf_args = self.args.argv if self.args.argv else local_args
cluster = TFCluster.run(sc, self.train_fn, tf_args, local_args.cluster_size, local_args.num_ps, local_args.tensorboard, local_args.input_mode)
cluster = TFCluster.run(sc, self.train_fn, tf_args, local_args.cluster_size, local_args.num_ps,
local_args.tensorboard, local_args.input_mode, driver_ps_nodes=local_args.driver_ps_nodes)
if local_args.input_mode == TFCluster.InputMode.SPARK:
# feed data, using a deterministic order for input columns (lexicographic by key)
input_cols = sorted(self.getInputMapping().keys())
Expand Down