Support for Tensorflow eager mode #670

Merged: 28 commits merged on Jan 9, 2019

Changes from 14 commits

Commits (28):
c7a6843
support tensorflow eager mode
kuroko1t Nov 26, 2018
40cbd5a
add tf.GradientTape wrapper
kuroko1t Nov 27, 2018
5763876
remove space
kuroko1t Nov 27, 2018
642736b
modify comment
kuroko1t Nov 27, 2018
2c80222
modify eager example
kuroko1t Nov 27, 2018
f4e134b
modify eager example
kuroko1t Nov 27, 2018
4f522b8
add eager utest
kuroko1t Nov 29, 2018
da38b1b
update eager utest
kuroko1t Nov 29, 2018
2179d1e
modify file with autopep8
kuroko1t Nov 30, 2018
a6c60fd
modify eager test
kuroko1t Dec 9, 2018
8f9ff60
modify bcat timing
kuroko1t Dec 9, 2018
73e3e72
add comment
kuroko1t Dec 10, 2018
e101014
fix gpu utest, modify distributedtape, fix example
kuroko1t Dec 11, 2018
5a5adf2
delete Subtest for python2
kuroko1t Dec 13, 2018
ffa00c5
update utest
kuroko1t Dec 17, 2018
c7c1498
mege uber master
kuroko1t Dec 17, 2018
5140e92
Merge branch 'master' into tf_eager
kuroko1t Dec 17, 2018
e4b62b8
update eager mnist sample
kuroko1t Dec 17, 2018
8aeb8f4
update eager mnist sample
kuroko1t Dec 17, 2018
b3467f4
modify broadcast_global_variables
kuroko1t Dec 20, 2018
d9e113a
modify style
kuroko1t Dec 20, 2018
d672eff
modify gradients func
kuroko1t Dec 23, 2018
c0f8a27
change eager utest, check hasattr in eager mode
kuroko1t Jan 2, 2019
4306909
support eager tf version > 1.9.0 for custom op error
kuroko1t Jan 2, 2019
365629a
add run_all_in_graph_and_eager_modes_with_config
kuroko1t Jan 6, 2019
977b9d1
remove run_all_in_graph_and_eager_modes_with_config
kuroko1t Jan 6, 2019
1c1eadc
tf 1.9.0 hasnt watch_accessed_variables, add eager example travis.yml
kuroko1t Jan 7, 2019
b1de533
update example
kuroko1t Jan 8, 2019
81 changes: 81 additions & 0 deletions examples/tensorflow_mnist_eager.py
@@ -0,0 +1,81 @@
# Copyright 2017 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#!/usr/bin/env python

import tensorflow as tf
import horovod.tensorflow as hvd


def main(_):
# Horovod: initialize Horovod.
hvd.init()

# Horovod: pin GPU to be used to process local rank (one GPU per process)
config = tf.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank())

tf.enable_eager_execution(config=config)

mnist_model = tf.keras.Sequential([
tf.keras.layers.Conv2D(16, [3, 3], activation='relu'),
tf.keras.layers.Conv2D(16, [3, 3], activation='relu'),
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(10)
])

# Horovod: adjust learning rate based on number of GPUs.
opt = tf.train.RMSPropOptimizer(0.001 * hvd.size())

(mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data()

dataset = tf.data.Dataset.from_tensor_slices(
(tf.cast(mnist_images[..., tf.newaxis] / 255, tf.float32),
tf.cast(mnist_labels, tf.int64))
)
dataset = dataset.shuffle(1000).batch(32)

# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
checkpoint_dir = './checkpoints'
step_counter = tf.train.get_or_create_global_step()
checkpoint = tf.train.Checkpoint(
model=mnist_model, optimizer=opt, step_counter=step_counter)

hvd.bcast(0, mnist_model.variables)
# Horovod: adjust number of steps based on number of GPUs.
for (batch, (images, labels)) in enumerate(
dataset.take(20000 // hvd.size())):
with tf.GradientTape() as tape:
logits = mnist_model(images, training=True)
loss_value = tf.losses.sparse_softmax_cross_entropy(labels, logits)
# Horovod: broadcast initial variable states
# from rank 0 to all other processes. This is necessary to ensure consistent
# initialization of all workers when training is started with random weights
# or restored from a checkpoint.
hvd.bcast(0, mnist_model.variables) if batch == 0 else None

# Horovod: add Horovod Distributed GradientTape.
tape = hvd.DistributedGradientTape(tape)

grads = tape.gradient(loss_value, mnist_model.variables)
opt.apply_gradients(zip(grads, mnist_model.variables),
global_step=tf.train.get_or_create_global_step())
if batch % 10 == 0 and hvd.local_rank() == 0:
print('Step #%d\tLoss: %.6f' % (batch, loss_value))

checkpoint.save(checkpoint_dir) if hvd.rank() == 0 else None


if __name__ == "__main__":
tf.app.run()
89 changes: 88 additions & 1 deletion horovod/tensorflow/__init__.py
@@ -40,7 +40,7 @@
from horovod.tensorflow.mpi_ops import mpi_threads_supported

import tensorflow as tf

from tensorflow.python.eager import context

def allreduce(tensor, average=True, device_dense='', device_sparse='',
compression=Compression.none):
@@ -97,6 +97,17 @@ def broadcast_global_variables(root_rank):
return tf.group(*[tf.assign(var, broadcast(var, root_rank))
for var in tf.global_variables()])

def bcast(root_rank, variables):

Review comment (Collaborator): Can you rename this to broadcast_variables?
Reply (kuroko1t, Dec 18, 2018): Sure, I renamed it to broadcast_variables.

"""Broadcasts variables from root rank to all other processes.

Arguments:
root_rank: rank of the process from which global variables will be broadcasted
to all other processes.
variables: variables to broadcast
"""
return tf.group(*[tf.assign(var, broadcast(var, root_rank))
for var in variables])
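
For context (not part of the diff): a minimal, hedged sketch of how the new bcast helper could be used in eager mode. The layer and shapes below are illustrative and do not come from this PR.

# Hedged usage sketch for the bcast helper above; the layer and shapes are illustrative.
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
tf.enable_eager_execution()

layer = tf.keras.layers.Dense(4)
layer.build((None, 8))  # create the variables before broadcasting them

# Broadcast the freshly created variables from rank 0 so every worker
# starts from identical weights.
hvd.bcast(0, layer.variables)
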


class BroadcastGlobalVariablesHook(tf.train.SessionRunHook):
"""
@@ -223,3 +234,79 @@ def get_slot_names(self, *args, **kwargs):
def variables(self, *args, **kwargs):
"""Calls this same method on the underlying optimizer."""
return self._optimizer.variables(*args, **kwargs)


class _DistributedGradientTape(tf.GradientTape):
"""An tape that wraps another tf.GradientTape, using an allreduce to
average gradient values before applying gradients to model weights.

Args:
gradtape:
GradientTape to use for computing gradients and applying updates.
persistent:
Boolean controlling whether a persistent gradient tape
is created. False by default, which means at most one call can
be made to the gradient() method on this object.
watch_accessed_variables:
Boolean controlling whether the tape will
automatically `watch` any (trainable) variables accessed while the tape
is active. Defaults to True meaning gradients can be requested from any
result computed in the tape derived from reading a trainable `Variable`.
If False users must explicitly `watch` any `Variable`s they want to
request gradients from.
device_dense:
Device to be used for dense tensors. Uses GPU by default
if Horovod was built with HOROVOD_GPU_ALLREDUCE.
device_sparse:
Device to be used for sparse tensors. Uses GPU by default
if Horovod was built with HOROVOD_GPU_ALLGATHER.
compression:
Compression algorithm used during allreduce to reduce the amount
of data sent during each parameter update step. Defaults to
not using compression.
sparse_as_dense:
Treat all sparse gradients as dense tensors. This can help improve
performance and memory utilization if the original sparse gradient
has high density. Defaults to false.
"""

def __init__(self, persistent, watch_accessed_variables,
tape, device_dense='', device_sparse='',
compression=Compression.none, sparse_as_dense=False):
super(self.__class__, self).__init__(persistent, watch_accessed_variables)
self._tape = tape
self._persistent = persistent
self._watch_accessed_variables = watch_accessed_variables
self._name = "Distributed"
self._device_dense = device_dense
self._device_sparse = device_sparse
self._compression = compression
self._sparse_as_dense = sparse_as_dense

def gradient(self, target, sources, output_gradients=None):
gradients = super(self.__class__, self).gradient(target, sources, output_gradients)
if size() > 1:
averaged_gradients = []

Review comment (Collaborator): Can you refactor this per the changes to DistributedOptimizer in #704? Basically, without using tf.contrib.eager.defun, this will be very slow.
Reply (kuroko1t): OK, I updated it.

with tf.name_scope(self._name + "_Allreduce"):
for grad in gradients:
if self._sparse_as_dense and \
isinstance(grad, tf.IndexedSlices):
grad = tf.convert_to_tensor(grad)
avg_grad = allreduce(grad,
device_dense=self._device_dense,
device_sparse=self._device_sparse,
compression=self._compression)
averaged_gradients.append(avg_grad)
return averaged_gradients
else:
return gradients
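
As a hedged illustration of the refactor suggested in the review thread above (per #704): wrapping the gradient averaging in tf.contrib.eager.defun lets eager mode trace it into a single graph function instead of dispatching each allreduce eagerly. The helper name below is illustrative, not the code in this diff.

# Hedged sketch of a defun-wrapped allreduce, following the #704 suggestion.
# Assumes it lives in horovod/tensorflow/__init__.py, where tf, allreduce,
# and Compression are already available.
def _make_allreduce_grads_fn(name, device_dense, device_sparse, compression):
    def allreduce_grads(grads):
        with tf.name_scope(name + "_Allreduce"):
            return [allreduce(grad,
                              device_dense=device_dense,
                              device_sparse=device_sparse,
                              compression=compression)
                    if grad is not None else grad
                    for grad in grads]
    if tf.executing_eagerly():
        # Tracing once with defun avoids per-op eager dispatch overhead
        # on every training step.
        allreduce_grads = tf.contrib.eager.defun(allreduce_grads)
    return allreduce_grads
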


def DistributedGradientTape(gradtape, device_dense='', device_sparse='', compression=Compression.none,
sparse_as_dense=False):
cls = type(gradtape.__class__.__name__, (gradtape.__class__,),
dict(_DistributedGradientTape.__dict__))
return cls(gradtape._persistent, gradtape._watch_accessed_variables,
gradtape._tape,
device_dense='', device_sparse='',

Review comment (Member): Should pass device_dense etc. through correctly.
Reply (kuroko1t): I updated it.

compression=Compression.none, sparse_as_dense=False)
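
A hedged, self-contained usage sketch of the DistributedGradientTape wrapper defined above, run inside an eager training step. The model, optimizer, and random data are illustrative.

# Hedged usage sketch for DistributedGradientTape; model, optimizer, and data
# are illustrative.
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
tf.enable_eager_execution()

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
opt = tf.train.GradientDescentOptimizer(0.01 * hvd.size())
images = tf.random_normal([32, 784])
labels = tf.random_uniform([32], maxval=10, dtype=tf.int64)

with tf.GradientTape() as tape:
    logits = model(images, training=True)
    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)

# Wrapping the tape makes gradient() return allreduce-averaged gradients.
tape = hvd.DistributedGradientTape(tape)
grads = tape.gradient(loss, model.variables)
opt.apply_gradients(zip(grads, model.variables),
                    global_step=tf.train.get_or_create_global_step())
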
12 changes: 9 additions & 3 deletions horovod/tensorflow/mpi_ops.py
@@ -85,8 +85,10 @@ def _allreduce(tensor, name=None):
A tensor of the same shape and type as `tensor`, summed across all
processes.
"""
if name is None:
if name is None and not tf.executing_eagerly():

Review comment (Collaborator): Is it necessary to make eager execution a special case here? For some things, like the upcoming autotuning framework, we rely on differentiating tensors by their name for things like detecting loops.
Reply (kuroko1t): In eager mode, tensor.name is meaningless, so I added this code (reference here). If we need to distinguish tensors, we need something to use in place of tensor.name; I am looking into that.
Review comment (Collaborator): Makes sense. One issue here is that older versions of TF <1.7 will not support tf.executing_eagerly(). I made a change in #704 to address this by adding a new utility function. I think this should work for now.
Reply (kuroko1t): I see. I will update it.

name = 'HorovodAllreduce_%s' % _normalize_name(tensor.name)
if tf.executing_eagerly():
name = 'HorovodAllreduce'
return MPI_LIB.horovod_allreduce(tensor, name=name)
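
For context (not part of the diff), a hedged illustration of why the eager path above falls back to a fixed op name: an eager tensor in TF 1.x has no graph name, so tensor.name cannot be used to derive a unique 'HorovodAllreduce_...' name.

# Hedged illustration: eager tensors have no usable .name in TF 1.x.
import tensorflow as tf
tf.enable_eager_execution()

t = tf.constant([1.0, 2.0])
try:
    print(t.name)
except AttributeError as err:
    # Accessing .name on an eager tensor raises, hence the fixed
    # 'HorovodAllreduce' fallback name above.
    print('no tensor name in eager mode:', err)
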


@@ -118,8 +120,10 @@ def allgather(tensor, name=None):
the first dimension, which may be greater and is the sum of all first
dimensions of the tensors in different Horovod processes.
"""
if name is None:
if name is None and not tf.executing_eagerly():
name = 'HorovodAllgather_%s' % _normalize_name(tensor.name)
if tf.executing_eagerly():
name = 'HorovodAllgather'
return MPI_LIB.horovod_allgather(tensor, name=name)


@@ -159,8 +163,10 @@ def broadcast(tensor, root_rank, name=None):
A tensor of the same shape and type as `tensor`, with the value broadcasted
from root rank.
"""
if name is None:
if name is None and not tf.executing_eagerly():
name = 'HorovodBroadcast_%s' % _normalize_name(tensor.name)
if tf.executing_eagerly():
name = 'HorovodBroadcast'
return MPI_LIB.horovod_broadcast(tensor, name=name, root_rank=root_rank)

