Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions examples/mnist/keras/mnist_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ def build_and_compile_cnn_model():
# callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=args.model_dir)]
tf.io.gfile.makedirs(args.model_dir)
filepath = args.model_dir + "/weights-{epoch:04d}"
callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=filepath, verbose=1, save_weights_only=True)]
callbacks = [
tf.keras.callbacks.ModelCheckpoint(filepath=filepath, verbose=1, save_weights_only=True),
tf.keras.callbacks.TensorBoard(log_dir=args.model_dir)
]

with strategy.scope():
multi_worker_model = build_and_compile_cnn_model()
Expand Down Expand Up @@ -90,5 +93,5 @@ def build_and_compile_cnn_model():
args = parser.parse_args()
print("args:", args)

cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=0, tensorboard=args.tensorboard, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief')
cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=0, tensorboard=args.tensorboard, input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief', log_dir=args.model_dir)
cluster.shutdown()
13 changes: 10 additions & 3 deletions tensorflowonspark/TFSparkNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def run(fn, tf_args, cluster_meta, tensorboard, log_dir, queues, background):
"""
def _mapfn(iter):
import tensorflow as tf
from packaging import version

# Note: consuming the input iterator helps Pyspark re-use this worker,
for i in iter:
Expand Down Expand Up @@ -198,10 +199,12 @@ def _mapfn(iter):
logger.debug("CLASSPATH: {0}".format(hadoop_classpath))
os.environ['CLASSPATH'] = classpath + os.pathsep + hadoop_classpath

# start TensorBoard if requested
# start TensorBoard if requested, on 'worker:0' if available (for backwards-compatibility), otherwise on 'chief:0' or 'master:0'
job_names = sorted([k for k in cluster_template.keys() if k in ['chief', 'master', 'worker']])
tb_job_name = 'worker' if 'worker' in job_names else job_names[0]
tb_pid = 0
tb_port = 0
if tensorboard and job_name == 'worker' and task_index == 0:
if tensorboard and job_name == tb_job_name and task_index == 0:
tb_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tb_sock.bind(('', 0))
tb_port = tb_sock.getsockname()[1]
Expand All @@ -223,7 +226,11 @@ def _mapfn(iter):
raise Exception("Unable to find 'tensorboard' in: {}".format(search_path))

# launch tensorboard
tb_proc = subprocess.Popen([pypath, tb_path, "--logdir=%s" % logdir, "--port=%d" % tb_port], env=os.environ)
if version.parse(tf.__version__) >= version.parse('2.0.0'):
tb_proc = subprocess.Popen([pypath, tb_path, "--reload_multifile=True", "--logdir=%s" % logdir, "--port=%d" % tb_port], env=os.environ)
else:
tb_proc = subprocess.Popen([pypath, tb_path, "--logdir=%s" % logdir, "--port=%d" % tb_port], env=os.environ)

tb_pid = tb_proc.pid

# check server to see if this task is being retried (i.e. already reserved)
Expand Down
9 changes: 4 additions & 5 deletions tensorflowonspark/compat.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
# Copyright 2019 Yahoo Inc / Verizon Media
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.
"""Helper functions to abstract API changes between TensorFlow versions."""
"""Helper functions to abstract API changes between TensorFlow versions, intended for end-user TF code."""

import tensorflow as tf

TF_VERSION = tf.__version__
from packaging import version


def export_saved_model(model, export_dir, is_chief=False):
    """Export a Keras model as a TF SavedModel, dispatching on TF version.

    TF 2.0.0 only had the ``tf.keras.experimental.export_saved_model`` API,
    which is invoked on the chief worker only; later versions use
    ``model.save(..., save_format='tf')`` on every caller (presumably so
    collective save ops run on all workers under a distribution strategy --
    confirm against the strategy in use).

    Args:
        model: a compiled ``tf.keras`` model to export.
        export_dir: destination directory for the SavedModel.
        is_chief: True on the chief worker (only consulted on TF 2.0.0).
    """
    # NOTE: the diff contained both the old string-equality check
    # (TF_VERSION == '2.0.0') and the new packaging-based comparison; the
    # post-change code below keeps only the robust version.parse() form.
    if version.parse(tf.__version__) == version.parse('2.0.0'):
        if is_chief:
            tf.keras.experimental.export_saved_model(model, export_dir)
    else:
        model.save(export_dir, save_format='tf')


def disable_auto_shard(options):
    """Disable tf.data auto-sharding on *options*, across TF versions.

    TF 2.0.0 exposed a boolean ``experimental_distribute.auto_shard`` flag;
    later releases replaced it with the ``auto_shard_policy`` enum, so the
    correct attribute must be chosen at runtime.

    Args:
        options: a ``tf.data.Options`` instance, mutated in place.
    """
    # NOTE: the diff contained both the old string-equality check
    # (TF_VERSION == '2.0.0') and the new packaging-based comparison; the
    # post-change code below keeps only the robust version.parse() form.
    if version.parse(tf.__version__) == version.parse('2.0.0'):
        options.experimental_distribute.auto_shard = False
    else:
        options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF