Commit 568625c
add per_replica support for keras
kushanam authored and serach24 committed Jun 4, 2021
1 parent fbb913d commit 568625c
Showing 1 changed file with 47 additions and 26 deletions.
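What the change does: `tf.distribute.InputOptions` is now threaded through the distributed iterator type specs, so Keras can drive input pipelines whose replication mode is `InputReplicationMode.PER_REPLICA` (one pipeline per replica) instead of only `PER_WORKER`. Below is a minimal usage sketch, assuming the TF 2.5-era public API; the dataset contents, model, and step counts are illustrative, and PER_REPLICA with `experimental_prefetch_to_device=False` is shown because that combination exercises the plain per-device iterator path added by this commit:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# PER_REPLICA replication: dataset_fn runs once per replica instead of once
# per worker; with prefetch-to-device off, each replica gets a plain
# single-device iterator (see _should_use_multi_device_iterator below).
options = tf.distribute.InputOptions(
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA,
    experimental_prefetch_to_device=False)

def dataset_fn(input_context):
  # Illustrative pipeline; a real one would shard or sample per input_context.
  features = tf.random.uniform((64, 4))
  labels = tf.random.uniform((64, 1))
  return tf.data.Dataset.from_tensor_slices((features, labels)).batch(8)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn, options)

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer="sgd", loss="mse")

# Per-replica distributed datasets have unknown cardinality, so the number of
# steps is given explicitly.
model.fit(dist_dataset, epochs=1, steps_per_epoch=8)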
73 changes: 47 additions & 26 deletions tensorflow/python/distribute/input_lib.py
@@ -29,6 +29,7 @@
 from tensorflow.python.data.experimental.ops import distribute
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.data.ops import optional_ops
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_utils
@@ -759,11 +760,11 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
 
   __slots__ = [
       "_input_workers", "_element_spec", "_strategy",
-      "_enable_get_next_as_optional"
+      "_enable_get_next_as_optional", "_options"
   ]
 
   def __init__(self, input_workers, element_spec, strategy,
-               enable_get_next_as_optional):
+               enable_get_next_as_optional, options):
     # We don't want to allow deserialization of this class because we don't
     # serialize the strategy object. Currently the only places where
     # _deserialize is called is when we save/restore using SavedModels.
@@ -775,6 +776,7 @@ def __init__(self, input_workers, element_spec, strategy,
     self._element_spec = element_spec
     self._strategy = strategy
     self._enable_get_next_as_optional = enable_get_next_as_optional
+    self._options = options
 
   @property
   def value_type(self):
@@ -784,7 +786,7 @@ def _serialize(self):
     # We cannot serialize the strategy object so we convert it to an id that we
     # can use for comparison.
     return (self._input_workers.serialize(),
-            self._element_spec, id(self._strategy))
+            self._element_spec, id(self._strategy), id(self._options))
 
   def _deserialize(self):
     raise ValueError("Deserialization is currently unsupported for "
@@ -816,7 +818,8 @@ def most_specific_compatible_type(self, other):
         other._element_spec)
     return DistributedIteratorSpec(self._input_workers, element_spec,
                                    self._strategy,
-                                   self._enable_get_next_as_optional)
+                                   self._enable_get_next_as_optional,
+                                   self._options)
 
   @property
   def _component_specs(self):
@@ -828,7 +831,8 @@ def _component_specs(self):
         functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
     specs.append(_SingleWorkerDatasetIteratorSpec(input_device,
                                                   compute_devices,
-                                                  element_spec))
+                                                  element_spec,
+                                                  self._options))
     return specs
 
   def _to_components(self, value):
@@ -841,22 +845,25 @@ def _from_components(self, components):
         components=components,
         element_spec=self._element_spec,
         strategy=self._strategy,
-        enable_get_next_as_optional=self._enable_get_next_as_optional)
+        enable_get_next_as_optional=self._enable_get_next_as_optional,
+        options=self._options)
 
   @staticmethod
   def from_value(value):
     # pylint: disable=protected-access
     return DistributedIteratorSpec(value._input_workers, value._element_spec,
                                    value._strategy,
-                                   value._enable_get_next_as_optional)
+                                   value._enable_get_next_as_optional,
+                                   value._options)
 
   def _with_tensor_ranks_only(self):
     element_spec = nest.map_structure(
         lambda s: s._with_tensor_ranks_only(),  # pylint: disable=protected-access
         self._element_spec)
     return DistributedIteratorSpec(self._input_workers, element_spec,
                                    self._strategy,
-                                   self._enable_get_next_as_optional)
+                                   self._enable_get_next_as_optional,
+                                   self._options)
 
 
 class DistributedIterator(DistributedIteratorBase,
@@ -869,14 +876,16 @@ def __init__(self,
                strategy=None,
                components=None,
                element_spec=None,
-               enable_get_next_as_optional=False):
+               enable_get_next_as_optional=False,
+               options=None):
     if input_workers is None:
       raise ValueError("`input_workers` should be "
                        "provided.")
 
     error_message = ("Either `input_workers` or "
                      "both `components` and `element_spec` need to be "
                      "provided.")
+    self._options = options
 
     if iterators is None:
       if (components is None or element_spec is None):
@@ -916,7 +925,8 @@ def _type_spec(self):
     # TODO(b/163362689): remove the comment after the bug if fixed.
     return DistributedIteratorSpec(self._input_workers, self._element_spec,
                                    self._strategy,
-                                   self._enable_get_next_as_optional)
+                                   self._enable_get_next_as_optional,
+                                   self._options)
 
 
 class _IterableInput(DistributedDatasetInterface):
@@ -1290,7 +1300,8 @@ def __iter__(self):
         input_workers=self._input_workers,
         iterators=iterators,
         strategy=self._strategy,
-        enable_get_next_as_optional=self._enable_get_next_as_optional)
+        enable_get_next_as_optional=self._enable_get_next_as_optional,
+        options=self._options)
     iterator._element_spec = self._element_spec  # pylint: disable=protected-access
 
     # When async eager is enabled, sometimes the iterator may not finish
@@ -1585,9 +1596,7 @@ def get_next(self, device, name=None):
     """Get next element for the given device."""
     del name
     with ops.device(self._worker):
-      if isinstance(self._iterator,
-                    (multi_device_iterator_ops.OwnedMultiDeviceIterator,
-                     multi_device_iterator_ops.MultiDeviceIterator)):
+      if _should_use_multi_device_iterator(self._options):
         return self._iterator.get_next(device)
       else:
         return self._iterator.get_next()
@@ -1665,25 +1674,30 @@ def get_next_as_list(self, name=None):
 class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
   """Type specification for `_SingleWorkerOwnedDatasetIterator`."""
 
-  __slots__ = ["_worker", "_devices", "_element_spec"]
+  __slots__ = ["_worker", "_devices", "_element_spec", "_options"]
 
-  def __init__(self, worker, devices, element_spec):
+  def __init__(self, worker, devices, element_spec, options):
     self._worker = worker
     self._devices = tuple(device_util.canonicalize(d) for d in devices)
     self._element_spec = element_spec
+    self._options = options
 
   @property
   def value_type(self):
     return _SingleWorkerOwnedDatasetIterator
 
   def _serialize(self):
-    return (self._worker, self._devices, self._element_spec)
+    return (self._worker, self._devices, self._element_spec, self._options)
 
   @property
   def _component_specs(self):
     specs = []
-    specs.append(multi_device_iterator_ops.MultiDeviceIteratorSpec(
-        self._devices, self._worker, element_spec=self._element_spec))
+    if _should_use_multi_device_iterator(self._options):
+      specs.append(multi_device_iterator_ops.MultiDeviceIteratorSpec(
+          self._devices, self._worker, element_spec=self._element_spec))
+    else:
+      specs.append(iterator_ops.IteratorSpec(
+          element_spec=self._element_spec))
     return specs
 
   def _to_components(self, value):
@@ -1695,13 +1709,14 @@ def _from_components(self, components):
         worker=self._worker,
         devices=self._devices,
         components=components,
-        element_spec=self._element_spec)
+        element_spec=self._element_spec,
+        options=self._options)
 
   @staticmethod
   def from_value(value):
     # pylint: disable=protected-access
     return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices,
-                                            value._element_spec)
+                                            value._element_spec, value._options)
 
 
 class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
@@ -1758,10 +1773,7 @@ def _make_iterator(self):
     if not self._worker:
       raise ValueError("Worked device must be specified when creating an "
                        "owned iterator.")
-    if (self._options is None or self._options.experimental_replication_mode ==
-        InputReplicationMode.PER_WORKER or
-        (self._options.experimental_replication_mode == InputReplicationMode
-         .PER_REPLICA and self._options.experimental_prefetch_to_device)):
+    if _should_use_multi_device_iterator(self._options):
       host_device = device_util.get_host_for_device(self._worker)
       with ops.device(self._worker):
         self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
@@ -1777,7 +1789,7 @@ def element_spec(self):
   @property
   def _type_spec(self):
     return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
-                                            self._element_spec)
+                                            self._element_spec, self._options)
 
   @property
   def output_classes(self):
@@ -1908,7 +1920,7 @@ def _create_iterators_per_worker(worker_datasets,
                                                options=options)
     else:
       iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
-                                              worker_devices)
+                                              worker_devices, options)
     iterators.append(iterator)
   return iterators

@@ -1988,6 +2000,15 @@ def _get_dataset_attributes(dataset):
 
   return batch_size, drop_remainder, prefetch_buffer
 
+
+def _should_use_multi_device_iterator(options):
+  """Determines whether to use multi_device_iterator_ops.OwnedMultiDeviceIterator."""
+  if (options is None
+      or options.experimental_replication_mode == InputReplicationMode.PER_WORKER
+      or (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
+          and options.experimental_prefetch_to_device)):
+    return True
+  return False
 
 
 class MultiStepContext(object):
   """A context object that can be used to capture things when running steps.

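Why the `options` plumbing touches every TypeSpec: a `DistributedIterator`'s `_type_spec` is what lets the iterator be passed as an argument to a `tf.function` and be rebuilt from its components, and `_component_specs` now has to choose between `MultiDeviceIteratorSpec` and a plain `IteratorSpec` based on the replication mode. Without `_options` in the spec (and in `_serialize`, which drives spec comparison and retracing), a PER_REPLICA iterator reconstructed inside a traced function would silently fall back to the multi-device path. A rough sketch of the pattern this enables, reusing the hypothetical `strategy` and `dist_dataset` from the example above:

@tf.function
def train_step(iterator):
  def step_fn(inputs):
    features, labels = inputs
    return tf.reduce_mean(labels)  # placeholder computation
  # next(iterator) yields per-replica values; the iterator argument is
  # matched against DistributedIteratorSpec when the function is traced.
  return strategy.run(step_fn, args=(next(iterator),))

iterator = iter(dist_dataset)
for _ in range(4):
  train_step(iterator)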