diff --git a/docs/.buildinfo b/docs/.buildinfo
index 4ce8b70f..36fb09ab 100644
--- a/docs/.buildinfo
+++ b/docs/.buildinfo
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
-config: 874109c9e8f56215fdcb46cac4aab9f9
+config: abbb35398bf3c41c0f421213a6263bf9
tags: 645f666f9bcd5a90fca523b33c5a78b7
diff --git a/docs/_modules/index.html b/docs/_modules/index.html
index a086bc32..10c1b7d0 100644
--- a/docs/_modules/index.html
+++ b/docs/_modules/index.html
@@ -4,27 +4,18 @@
+ logging.warn("Unknown scheme {0} with relative path: {1}".format(ctx.defaultFS,path))
+ return"{0}/{1}".format(ctx.defaultFS,path)
[docs]defstart_cluster_server(ctx,num_gpus=1,rdma=False):"""Function that wraps the creation of TensorFlow ``tf.train.Server`` for a node in a distributed TensorFlow cluster.
@@ -155,12 +146,12 @@
Source code for tensorflowonspark.TFNode
# Create and start a server for the local task.server=tf.train.Server(cluster,ctx.job_name,ctx.task_index)
-
- return(cluster,server)
+
+ return(cluster,server)
[docs]defnext_batch(mgr,batch_size,qname='input'):
- """*DEPRECATED*. Use TFNode.DataFeed class instead."""
- raiseException("DEPRECATED: Use TFNode.DataFeed class instead")
+ """*DEPRECATED*. Use TFNode.DataFeed class instead."""
+ raiseException("DEPRECATED: Use TFNode.DataFeed class instead")
[docs]defexport_saved_model(sess,export_dir,tag_set,signatures):"""Convenience function to export a saved_model using provided arguments
@@ -203,16 +194,16 @@
[docs]defbatch_results(mgr,results,qname='output'):
- """*DEPRECATED*. Use TFNode.DataFeed class instead."""
- raiseException("DEPRECATED: Use TFNode.DataFeed class instead")
+ """*DEPRECATED*. Use TFNode.DataFeed class instead."""
+ raiseException("DEPRECATED: Use TFNode.DataFeed class instead")
[docs]defterminate(mgr,qname='input'):
- """*DEPRECATED*. Use TFNode.DataFeed class instead."""
- raiseException("DEPRECATED: Use TFNode.DataFeed class instead")
+ """*DEPRECATED*. Use TFNode.DataFeed class instead."""
+ raiseException("DEPRECATED: Use TFNode.DataFeed class instead")
[docs]classDataFeed(object):"""This class manages the *InputMode.SPARK* data feeding process from the perspective of the TensorFlow application.
@@ -278,12 +269,12 @@
[docs]defabsolute_path(self,path):
- """Convenience function to access ``TFNode.hdfs_path`` directly from this object instance."""
- returnTFNode.hdfs_path(self,path)
+ """Convenience function to access ``TFNode.hdfs_path`` directly from this object instance."""
+ returnTFNode.hdfs_path(self,path)
[docs]defstart_cluster_server(self,num_gpus=1,rdma=False):
- """Convenience function to access ``TFNode.start_cluster_server`` directly from this object instance."""
- returnTFNode.start_cluster_server(self,num_gpus,rdma)
+ """Convenience function to access ``TFNode.start_cluster_server`` directly from this object instance."""
[docs]defexport_saved_model(self,sess,export_dir,tag_set,signatures):
- """Convenience function to access ``TFNode.export_saved_model`` directly from this object instance."""
- TFNode.export_saved_model(sess,export_dir,tag_set,signatures)
+ """Convenience function to access ``TFNode.export_saved_model`` directly from this object instance."""
[docs]defget_data_feed(self,train_mode=True,qname_in='input',qname_out='output',input_mapping=None):
- """Convenience function to access ``TFNode.DataFeed`` directly from this object instance."""
- returnTFNode.DataFeed(self.mgr,train_mode,qname_in,qname_out,input_mapping)
+ """Convenience function to access ``TFNode.DataFeed`` directly from this object instance."""
+ returnTFNode.DataFeed(self.mgr,train_mode,qname_in,qname_out,input_mapping)
This also manages a reference to the TFManager "singleton" per executor. Since Spark can spawn more than one python-worker
per executor, this will reconnect to the "singleton" instance as needed. """
- mgr=None#: TFManager instance
- cluster_id=None#: Unique ID for a given TensorFlowOnSpark cluster, used for invalidating state for new clusters.
+ mgr=None#: TFManager instance
+ cluster_id=None#: Unique ID for a given TensorFlowOnSpark cluster, used for invalidating state for new clusters.def_get_manager(cluster_info,host,ppid):"""Returns this executor's "singleton" instance of the multiprocessing.Manager, reconnecting per python-worker if needed.
@@ -144,6 +138,14 @@
Source code for tensorflowonspark.TFSparkNode
authkey =node['authkey']TFSparkNode.mgr=TFManager.connect(addr,authkey)break
+
+ ifTFSparkNode.mgrisNone:
+ msg="No TFManager found on this node, please ensure that:\n"+ \
+ "1. Spark num_executors matches TensorFlow cluster_size\n"+ \
+ "2. Spark cores/tasks per executor is 1.\n"+ \
+ "3. Spark dynamic allocation is disabled."
+ raiseException(msg)
+
logging.info("Connected to TFSparkNode.mgr on {0}, ppid={1}, state={2}".format(host,ppid,str(TFSparkNode.mgr.get('state'))))returnTFSparkNode.mgr
@@ -199,7 +201,7 @@
Source code for tensorflowonspark.TFSparkNode
addr =Noneifjob_name=='ps':# PS nodes must be remotely accessible in order to shutdown from Spark driver.
- TFSparkNode.mgr=TFManager.start(authkey,['control'],'remote')
+ TFSparkNode.mgr=TFManager.start(authkey,['control','error'],'remote')addr=(host,TFSparkNode.mgr.address[1])else:# worker nodes only need to be locally accessible within the executor for data feeding
@@ -285,7 +287,11 @@
Source code for tensorflowonspark.TFSparkNode
# construct a TensorFlow clusterspec from cluster_info
sorted_cluster_info=sorted(cluster_info,key=lambdak:k['worker_num'])spec={}
+ last_worker_num=-1fornodeinsorted_cluster_info:
+ if(node['worker_num']==last_worker_num):
+ raiseException("Duplicate worker/task in cluster_info")
+ last_worker_num=node['worker_num']logging.info("node: {0}".format(node))(njob,nhost,nport)=(node['job_name'],node['host'],node['port'])hosts=[]ifnjobnotinspecelsespec[njob]
@@ -315,11 +321,21 @@
Source code for tensorflowonspark.TFSparkNode
sys.argv=argsfn(args,context)
+ defwrapper_fn_background(args,context):
+ """Wrapper function that signals exceptions to foreground process."""
+ errq=TFSparkNode.mgr.get_queue('error')
+ try:
+ wrapper_fn(args,context)
+ exceptException:
+ errq.put(traceback.format_exc())
+ errq.join()
+
ifjob_name=='ps'orbackground:# invoke the TensorFlow main function in a background threadlogging.info("Starting TensorFlow {0}:{1} as {2} on cluster node {3} on background process".format(job_name,task_index,job_name,worker_num))
- p=multiprocessing.Process(target=wrapper_fn,args=(tf_args,ctx))
+
+ p=multiprocessing.Process(target=wrapper_fn_background,args=(tf_args,ctx))ifjob_name=='ps':p.daemon=Truep.start()
@@ -327,8 +343,15 @@
Source code for tensorflowonspark.TFSparkNode
# for ps nodes only, wait indefinitely in foreground thread for a "control" event (None == "stop")
ifjob_name=='ps':queue=TFSparkNode.mgr.get_queue('control')
+ equeue=TFSparkNode.mgr.get_queue('error')done=Falsewhilenotdone:
+ while(queue.empty()andequeue.empty()):
+ time.sleep(1)
+ if(notequeue.empty()):
+ e_str=equeue.get()
+ equeue.task_done()
+ raiseException("exception in ps:\n"+e_str)msg=queue.get(block=True)logging.info("Got msg: {0}".format(msg))ifmsgisNone:
@@ -341,8 +364,8 @@
Source code for tensorflowonspark.TFSparkNode
logging.info("Starting TensorFlow {0}:{1} on cluster node {2} on foreground thread".format(job_name,task_index,worker_num))wrapper_fn(tf_args,ctx)logging.info("Finished TensorFlow {0}:{1} on cluster node {2}".format(job_name,task_index,worker_num))
-
- return_mapfn
+
+ return_mapfn
[docs]deftrain(cluster_info,cluster_meta,qname='input'):"""Feeds Spark partitions into the shared multiprocessing.Queue.
@@ -358,7 +381,13 @@
Source code for tensorflowonspark.TFSparkNode
def _train(iter):# get shared queue, reconnecting if necessarymgr=_get_manager(cluster_info,util.get_ip_address(),os.getppid())
- queue=mgr.get_queue(qname)
+ try:
+ queue=mgr.get_queue(qname)
+ equeue=mgr.get_queue('error')
+ except(AttributeError,KeyError):
+ msg="Queue '{}' not found on this node, check for exceptions on other nodes.".format(qname)
+ raiseException(msg)
+
state=str(mgr.get('state'))logging.info("mgr.state={0}".format(state))terminating=state=="'terminating'"
@@ -368,15 +397,23 @@
Source code for tensorflowonspark.TFSparkNode
for iteminiter:count+=1logging.info("Skipped {0} items from partition".format(count))
-
else:logging.info("Feeding partition {0} into {1} queue {2}".format(iter,qname,queue))count=0foriteminiter:count+=1queue.put(item,block=True)
+
# wait for consumers to finish processing all items in queue before "finishing" this iterator
- queue.join()
+ joinThr=Thread(target=queue.join)
+ joinThr.start()
+ while(joinThr.isAlive()):
+ if(notequeue.empty()):
+ e_str=equeue.get()
+ equeue.task_done()
+ raiseException("exception in worker:\n"+e_str)
+ time.sleep(1)
+# queue.join()logging.info("Processed {0} items in partition".format(count))# check if TF is terminating feed after this partition
@@ -392,8 +429,8 @@
Source code for tensorflowonspark.TFSparkNode
# ignore any errors while requesting stop
logging.debug("Error while requesting stop: {0}".format(e))return[terminating]
-
- return_train
+
+ return_train
[docs]definference(cluster_info,qname='input'):"""Feeds Spark partitions into the shared multiprocessing.Queue and returns inference results.
@@ -408,7 +445,12 @@
Source code for tensorflowonspark.TFSparkNode
def _inference(iter):# get shared queue, reconnecting if necessarymgr=_get_manager(cluster_info,util.get_ip_address(),os.getppid())
- queue_in=mgr.get_queue(qname)
+ try:
+ queue_in=mgr.get_queue(qname)
+ equeue=mgr.get_queue('error')
+ except(AttributeError,KeyError):
+ msg="Queue '{}' not found on this node, check for exceptions on other nodes.".format(qname)
+ raiseException(msg)logging.info("Feeding partition {0} into {1} queue {2}".format(iter,qname,queue_in))count=0
@@ -424,7 +466,15 @@
Source code for tensorflowonspark.TFSparkNode
return []# wait for consumers to finish processing all items in queue before "finishing" this iterator
- queue_in.join()
+ joinThr=Thread(target=queue_in.join)
+ joinThr.start()
+ while(joinThr.isAlive()):
+ if(notequeue.empty()):
+ e_str=equeue.get()
+ equeue.task_done()
+ raiseException("exception in worker:\n"+e_str)
+ time.sleep(1)
+
logging.info("Processed {0} items in partition".format(count))# read result queue
@@ -438,8 +488,8 @@
[docs]defshutdown(cluster_info,queues=['input']):"""Stops all TensorFlow nodes by feeding ``None`` into the multiprocessing.Queues.
@@ -469,15 +519,19 @@
Source code for tensorflowonspark.TFSparkNode
# terminate any listening queues
logging.info("Stopping all queues")forqinqueues:
- queue=mgr.get_queue(q)
- logging.info("Feeding None into {0} queue".format(q))
- queue.put(None,block=True)
+ try:
+ queue=mgr.get_queue(q)
+ logging.info("Feeding None into {0} queue".format(q))
+ queue.put(None,block=True)
+ except(AttributeError,KeyError):
+ msg="Queue '{}' not found on this node, check for exceptions on other nodes.".format(q)
+ raiseException(msg)logging.info("Setting mgr.state to 'stopped'")mgr.set('state','stopped')return[True]
-
- return_shutdown
raiseException("Unable to find free GPU:\n{0}".format(smi_output))return','.join(free_gpus[:num_gpu])
- exceptsubprocess.CalledProcessErrorase:
- print("nvidia-smi error",e.output)
+ exceptsubprocess.CalledProcessErrorase:
+ print("nvidia-smi error",e.output)# Function to get the gpu informationdef_get_free_gpu(max_gpu_utilization=40,min_free_memory=0.5,num_gpu=1):
@@ -221,12 +212,14 @@
Args:
:meta: a dictonary of metadata about a node """
- withself.lock:
- self.reservations.append(meta)
+ withself.lock:
+ self.reservations.append(meta)
[docs]defdone(self):"""Returns True if the ``required`` number of reservations have been fulfilled."""
- withself.lock:
- returnlen(self.reservations)>=self.required
+ withself.lock:
+ returnlen(self.reservations)>=self.required
[docs]defget(self):"""Get the list of current reservations."""
- withself.lock:
- returnself.reservations
+ withself.lock:
+ returnself.reservations
[docs]defremaining(self):"""Get a count of remaining/unfulfilled reservations."""
- withself.lock:
- returnself.required-len(self.reservations)
[docs]defawait_reservations(self,sc,status={},timeout=600):"""Block until all reservations are received."""
+ timespent=0whilenotself.reservations.done():logging.info("waiting for {0} reservations".format(self.reservations.remaining()))
+ # check status flags for any errors
+ if'error'instatus:
+ sc.cancelAllJobs()
+ sc.stop()
+ sys.exit(1)time.sleep(1)
- logging.info("all reservations completed")
- returnself.reservations.get()
+ timespent+=1
+ if(timespent>timeout):
+ raiseException("timed out waiting for reservations to complete")
+ logging.info("all reservations completed")
[docs]defawait_reservations(self):"""Poll until all reservations completed, then return cluster_info."""done=Falsewhilenotdone:done=self._request('QUERY')
- time.sleep(1)
- returnself.get_reservations()
+ time.sleep(1)
+ returnself.get_reservations()
[docs]defrequest_stop(self):"""Request server stop."""
- resp=self._request('STOP')
- returnresp
[docs]defget_ip_address():"""Simple utility to get host IP address."""s=socket.socket(socket.AF_INET,socket.SOCK_DGRAM)
- s.connect(("8.8.8.8",80))
- returns.getsockname()[0]
This module provides a high-level API to manage the TensorFlowOnSpark cluster.
-
There are three main phases of operation:
-
-
Reservation/Startup - reserves a port for the TensorFlow process on each executor, starts a multiprocessing.Manager to
-listen for data/control messages, and then launches the Tensorflow main function on the executors.
-
Data feeding - For InputMode.SPARK only. Sends RDD data to the TensorFlow nodes via each executor’s multiprocessing.Manager. PS
-nodes will tie up their executors, so they won’t receive any subsequent data feeding tasks.
-
Shutdown - sends a shutdown control message to the multiprocessing.Managers of the PS nodes and pushes end-of-feed markers into the data
-queues of the worker nodes.
For InputMode.SPARK only: Feeds Spark RDD partitions into the TensorFlow worker nodes and returns an RDD of results
-
It is the responsibility of the TensorFlow “main” function to interpret the rows of the RDD and provide valid data for the output RDD.
-
This will use the distributed TensorFlow cluster for inferencing, so the TensorFlow “main” function should be capable of inferencing.
-Per Spark design, the output RDD will be lazily-executed only when a Spark action is invoked on the RDD.
-
-
Args:
-
-
-
-
-
dataRDD:
input data as a Spark RDD
-
-
qname:
INTERNAL_USE
-
-
-
-
-
Returns:
-
A Spark RDD representing the output of the TensorFlow inferencing
For InputMode.SPARK only. Feeds Spark RDD partitions into the TensorFlow worker nodes
-
It is the responsibility of the TensorFlow “main” function to interpret the rows of the RDD.
-
Since epochs are implemented via RDD.union() and the entire RDD must generally be processed in full, it is recommended
-to set num_epochs to closely match your training termination condition (e.g. steps or accuracy). See TFNode.DataFeed
-for more details.
-
-
Args:
-
-
-
-
-
dataRDD:
input data as a Spark RDD.
-
-
num_epochs:
number of times to repeat the dataset during training.
Starts the TensorFlowOnSpark cluster and Runs the TensorFlow “main” function on the Spark executors
-
-
Args:
-
-
-
-
-
sc:
SparkContext
-
-
map_fun:
user-supplied TensorFlow “main” function
-
-
tf_args:
argparse args, or command-line ARGV. These will be passed to the map_fun.
-
-
num_executors:
number of Spark executors. This should match your Spark job’s --num_executors.
-
-
num_ps:
number of Spark executors which are reserved for TensorFlow PS nodes. All other executors will be used as TensorFlow worker nodes.
-
-
tensorboard:
boolean indicating if the chief worker should spawn a Tensorboard server.
-
-
input_mode:
TFCluster.InputMode
-
-
log_dir:
directory to save tensorboard event logs. If None, defaults to a fixed path on local filesystem.
-
-
driver_ps_nodes:
-
run the PS nodes on the driver locally instead of on the spark executors; this helps maximize computing resources (esp. GPU). You will need to set cluster_size = num_executors + num_ps
-
-
queues:
INTERNAL_USE
-
-
-
-
-
Returns:
-
A TFCluster object representing the started cluster.
Push a batch of output results to the Spark output RDD of TFCluster.inference().
-
Note: this currently expects a one-to-one mapping of input to output data, so the length of the results array should match the length of
+
Push a batch of output results to the Spark output RDD of TFCluster.inference().
+
Note: this currently expects a one-to-one mapping of input to output data, so the length of the results array should match the length of
the previously retrieved batch of input data.
Args:
@@ -111,12 +102,12 @@
Navigation
Gets a batch of items from the input RDD.
If multiple tensors are provided per row in the input RDD, e.g. tuple of (tensor1, tensor2, …, tensorN) and:
-
no input_mapping was provided to the DataFeed constructor, this will return an array of batch_size tuples,
+
no input_mapping was provided to the DataFeed constructor, this will return an array of batch_size tuples,
and the caller is responsible for separating the tensors.
-
an input_mapping was provided to the DataFeed constructor, this will return a dictionary of N tensors,
-with tensor names as keys and arrays of length batch_size as values.
+
an input_mapping was provided to the DataFeed constructor, this will return a dictionary of N tensors,
+with tensor names as keys and arrays of length batch_size as values.
-
Note: if the end of the data is reached, this may return with fewer than batch_size items.
+
Note: if the end of the data is reached, this may return with fewer than batch_size items.
class TFNodeContext(worker_num, job_name, task_index, cluster_spec, defaultFS, working_dir, mgr)[source]¶
-
Encapsulates unique metadata for a TensorFlowOnSpark node/executor and provides methods to interact with Spark and HDFS.
+
Bases: object
+
Encapsulates unique metadata for a TensorFlowOnSpark node/executor and provides methods to interact with Spark and HDFS.
An instance of this object will be passed to the TensorFlow “main” function via the ctx argument.
To simplify the end-user API, this class now mirrors the functions of the TFNode module.
@@ -67,7 +59,7 @@
Navigation
-
worker_num:
integer identifier for this executor, per nodeRDD=sc.parallelize(range(num_executors),num_executors).
+
worker_num:
integer identifier for this executor, per nodeRDD=sc.parallelize(range(num_executors),num_executors).
job_name:
TensorFlow job name (e.g. ‘ps’ or ‘worker’) of this TF node, per cluster_spec.
@@ -75,7 +67,7 @@
Navigation
cluster_spec:
dictionary for constructing a tf.train.ClusterSpec.
-
defaultFS:
string representation of default FileSystem, e.g. file:// or hdfs://<namenode>:8020/.
+
defaultFS:
string representation of default FileSystem, e.g. file:// or hdfs://<namenode>:8020/.
working_dir:
the current working directory for local filesystems, or YARN containers.
Low-level functions used by the high-level TFCluster APIs to manage cluster state.
This class is not intended for end-users (see TFNode for end-user APIs).
For cluster management, this wraps the per-node cluster logic as Spark RDD mapPartitions functions, where the RDD is expected to be
-a “nodeRDD” of the form: nodeRDD=sc.parallelize(range(num_executors),num_executors).
+a “nodeRDD” of the form: nodeRDD=sc.parallelize(range(num_executors),num_executors).
For data feeding, this wraps the feeding logic as Spark RDD mapPartitions functions on a standard “dataRDD”.
This also manages a reference to the TFManager “singleton” per executor. Since Spark can spawn more than one python-worker
per executor, this will reconnect to the “singleton” instance as needed.
@@ -170,7 +162,7 @@
Navigation
fn:
TensorFlow “main” function provided by the user.
-
tf_args:
argparse args, or command line ARGV. These will be passed to the fn.
+
tf_args:
argparse args, or command line ARGV. These will be passed to the fn.
cluster_meta:
dictionary of cluster metadata (e.g. cluster_id, reservation.Server address, etc).
mapPartition function to convert an RDD of serialized tf.train.Example bytestring into an RDD of Row.
-
Note: TensorFlow represents both strings and binary types as tf.train.BytesList, and we need to
-disambiguate these types for Spark DataFrames DTypes (StringType and BinaryType), so we require a “hint”
-from the caller in the binary_features argument.
-
-
Args:
-
-
-
-
-
iter:
the RDD partition iterator
-
-
binary_features:
-
a list of tf.train.Example features which are expected to be binary/bytearrays.
-
-
-
-
-
Returns:
-
An array/iterator of DataFrame Row with features converted into columns.
Given a tf.train.Example, infer the Spark DataFrame schema (StructFields).
-
Note: TensorFlow represents both strings and binary types as tf.train.BytesList, and we need to
-disambiguate these types for Spark DataFrames DTypes (StringType and BinaryType), so we require a “hint”
-from the caller in the binary_features argument.
-
-
Args:
-
-
-
-
-
example:
a tf.train.Example
-
-
binary_features:
-
a list of tf.train.Example features which are expected to be binary/bytearrays.
This will attempt to automatically convert the tf.train.Example features into Spark DataFrame columns of equivalent types.
-
Note: TensorFlow represents both strings and binary types as tf.train.BytesList, and we need to
-disambiguate these types for Spark DataFrames DTypes (StringType and BinaryType), so we require a “hint”
-from the caller in the binary_features argument.
-
-
Args:
-
-
-
-
-
sc:
SparkContext
-
-
input_dir:
location of TFRecords on disk.
-
-
binary_features:
-
a list of tf.train.Example features which are expected to be binary/bytearrays.
-
-
-
-
-
Returns:
-
A Spark DataFrame mirroring the tf.train.Example schema.
mapPartition function to convert a Spark RDD of Row into an RDD of serialized tf.train.Example bytestring.
-
Note that tf.train.Example is a fairly flat structure with limited datatypes, e.g. tf.train.FloatList,
-tf.train.Int64List, and tf.train.BytesList, so most DataFrame types will be coerced into one of these types.
-
-
Args:
-
-
-
-
-
dtypes:
the DataFrame.dtypes of the source DataFrame.
-
-
-
-
-
Returns:
-
A mapPartition function which converts the source DataFrame into tf.train.Example bytestrings.
This module extends the TensorFlowOnSpark API to support Spark ML Pipelines.
-
It provides a TFEstimator class to fit a TFModel using TensorFlow. The TFEstimator will actually spawn a TensorFlowOnSpark cluster
-to conduct distributed training, but due to architectural limitations, the TFModel will only run single-node TensorFlow instances
-when inferencing on the executors. The executors will run in parallel, but the TensorFlow model must fit in the memory
-of each executor.
-
There is also an option to provide a separate “export” function, which allows users to export a different graph for inferencing vs. training.
-This is useful when the training graph uses InputMode.TENSORFLOW with queue_runners, but the inferencing graph needs placeholders.
-And this is especially useful for exporting saved_models for TensorFlow Serving.
-tfrecord_dir = Param(parent='undefined', name='tfrecord_dir', doc='Path to temporarily export a DataFrame as TFRecords (for InputMode.TENSORFLOW apps)')¶
Spark ML Estimator which launches a TensorFlowOnSpark cluster for distributed training.
-
The columns of the DataFrame passed to the fit() method will be mapped to TensorFlow tensors according to the setInputMapping() method.
-
If an export_fn was provided to the constructor, it will be run on a single executor immediately after the distributed training has completed.
-This allows users to export a TensorFlow saved_model with a different execution graph for inferencing, e.g. replacing an input graph of
-TFReaders and QueueRunners with Placeholders.
-
For InputMode.TENSORFLOW, the input DataFrame will be exported as TFRecords to a temporary location specified by the tfrecord_dir.
-The TensorFlow application will then be expected to read directly from this location during training. However, if the input DataFrame was
-produced by the dfutil.loadTFRecords() method, i.e. originated from TFRecords on disk, then the tfrecord_dir will be set to the
-original source location of the TFRecords with the additional export step.
-
-
Args:
-
-
-
-
-
train_fn:
TensorFlow “main” function for training.
-
-
tf_args:
Arguments specific to the TensorFlow “main” function.
Spark ML Model backed by a TensorFlow model checkpoint/saved_model on disk.
-
During transform(), each executor will run an independent, single-node instance of TensorFlow in parallel, so the model must fit in memory.
-The model/session will be loaded/initialized just once for each Spark Python worker, and the session will be cached for
-subsequent tasks/partitions to avoid re-loading the model for each partition.
-
-
Args:
-
-
-
-
-
tf_args:
Dictionary of arguments specific to TensorFlow “main” function.
Simple utility to shutdown a Spark StreamingContext by signaling the reservation Server.
-Note: use the reservation server address (host, port) reported in the driver logs.