Skip to content

Commit

Permalink
Step 1 of rename cross_tower_ops -> cross_device_ops in
Browse files Browse the repository at this point in the history
`MirroredStrategy.__init__()`.

PiperOrigin-RevId: 218755068
  • Loading branch information
tensorflower-gardener committed Oct 25, 2018
1 parent c1dc36b commit 9a1a8af
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tensorflow/contrib/distribute/python/mirrored_strategy.py
Expand Up @@ -346,24 +346,27 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
num_gpus_per_worker: number of GPUs per worker. This is the same as
`num_gpus` and only one of `num_gpus` and `num_gpus_per_worker` can be
specified.
cross_tower_ops: optional, a descedant of `CrossDeviceOps`. If this is not
cross_device_ops: optional, a descedant of `CrossDeviceOps`. If this is not
set, the `configure` method will try to find the best one.
prefetch_on_device: optional boolean to specify whether to prefetch input
data to devices.
auto_shard_dataset: whether to auto-shard the dataset when there are
multiple workers.
cross_tower_ops: Deprecated alias for `cross_device_ops`.
"""

def __init__(self,
devices=None,
num_gpus=None,
num_gpus_per_worker=None,
cross_tower_ops=None,
cross_device_ops=None,
prefetch_on_device=None,
auto_shard_dataset=False):
auto_shard_dataset=False,
cross_tower_ops=None):
super(MirroredStrategy, self).__init__()

self._cross_tower_ops = cross_tower_ops
assert not (cross_device_ops and cross_tower_ops)
self._cross_tower_ops = cross_device_ops or cross_tower_ops
self._prefetch_on_device = prefetch_on_device
self._auto_shard_dataset = auto_shard_dataset
# Remember num GPUs which might be needed by `configure` method.
Expand Down

0 comments on commit 9a1a8af

Please sign in to comment.