# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
"""Library for running a computation across multiple devices.
The intent of this library is that you can write an algorithm in a stylized way
and it will be usable with a variety of different `tf.distribute.Strategy`
implementations. Each descendant will implement a different strategy for
distributing the algorithm across multiple devices/machines. Furthermore, these
changes can be hidden inside the specific layers and other library classes that
need special treatment to run in a distributed setting, so that most users'
model definition code can run unchanged. The `tf.distribute.Strategy` API works
the same way with eager and graph execution.
*Guides*
* [TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training)
* [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb)
*Tutorials*
* [Distributed Training Tutorials](https://www.tensorflow.org/tutorials/distribute/)
The tutorials cover how to use `tf.distribute.Strategy` to do distributed
training with native Keras APIs, custom training loops,
and Estimator APIs. They also cover how to save/load models when using
`tf.distribute.Strategy`.
*Glossary*
* _Data parallelism_ is where we run multiple copies of the model
on different slices of the input data. This is in contrast to
_model parallelism_ where we divide up a single copy of a model
across multiple devices.
Note: we only support data parallelism for now, but
hope to add support for model parallelism in the future.
* A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that
TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple
devices on a single machine, or be connected to devices on multiple
machines. Devices used to run computations are called _worker devices_.
Devices used to store variables are _parameter devices_. For some strategies,
such as `tf.distribute.MirroredStrategy`, the worker and parameter devices
will be the same (see mirrored variables below). For others they will be
different. For example, `tf.distribute.experimental.CentralStorageStrategy`
puts the variables on a single device (which may be a worker device or may be
the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the
variables on separate machines called _parameter servers_ (see below).
* A _replica_ is one copy of the model, running on one slice of the
input data. Right now each replica is executed on its own
worker device, but once we add support for model parallelism
a replica may span multiple worker devices.
* A _host_ is the CPU device on a machine with worker devices, typically
used for running input pipelines.
* A _worker_ is defined to be the physical machine(s) containing the physical
devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A
worker may contain one or more replicas, but contains at least one
replica. Typically one worker will correspond to one machine, but in the case
of very large models with model parallelism, one worker may span multiple
machines. We typically run one input pipeline per worker, feeding all the
replicas on that worker.
* _Synchronous_, or more commonly _sync_, training is where the updates from
each replica are aggregated together before updating the model variables. This
is in contrast to _asynchronous_, or _async_ training, where each replica
updates the model variables independently. You may also have replicas
partitioned into groups which are in sync within each group but async between
groups.
* _Parameter servers_: These are machines that hold a single copy of
parameters/variables, used by some strategies (right now just
`tf.distribute.experimental.ParameterServerStrategy`). All replicas that want
to operate on a variable retrieve it at the beginning of a step and send an
update to be applied at the end of the step. These can in principle support
either sync or async training, but right now we only have support for async
training with parameter servers. Compare to
`tf.distribute.experimental.CentralStorageStrategy`, which puts all variables
on a single device on the same machine (and does sync training), and
`tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices
(see below).
* _Replica context_ vs. _Cross-replica context_ vs _Update context_
A _replica context_ applies
when you execute the computation function that was called with `strategy.run`.
Conceptually, you're in replica context when executing the computation
function that is being replicated.
An _update context_ is entered in a `tf.distribute.StrategyExtended.update`
call.
A _cross-replica context_ is entered when you enter a `strategy.scope`. This
is useful for calling `tf.distribute.Strategy` methods which operate across
the replicas (like `reduce_to()`). By default you start in a _replica context_
(the "default single _replica context_") and then some methods can switch you
back and forth.
* _Distributed value_: A distributed value is represented by the base class
`tf.distribute.DistributedValues`. `tf.distribute.DistributedValues` is useful
for representing values on multiple devices, and it contains a map from
replica ids to values. Two representative types of
`tf.distribute.DistributedValues` are `tf.types.experimental.PerReplica` and
`tf.types.experimental.Mirrored` values.
`PerReplica` values exist on the worker devices, with a different value for
each replica. They are produced by iterating through a distributed dataset
returned by `tf.distribute.Strategy.experimental_distribute_dataset` and
`tf.distribute.Strategy.distribute_datasets_from_function`. They are also the
typical result returned by `tf.distribute.Strategy.run`.
`Mirrored` values are like `PerReplica` values, except we know that the values
on all replicas are the same. `Mirrored` values are kept synchronized by the
distribution strategy in use, while `PerReplica` values are left
unsynchronized. `Mirrored` values typically represent model weights. We can
safely read a `Mirrored` value in a cross-replica context by using the value
on any replica, while `PerReplica` values can only be read within a replica
context.
* _Unwrapping_ and _merging_: Consider calling a function `fn` on multiple
replicas, like `strategy.run(fn, args=[w])` with an
argument `w` that is a `tf.distribute.DistributedValues`. This means `w` will
have a map taking replica id `0` to `w0`, replica id `1` to `w1`, etc.
`strategy.run()` unwraps `w` before calling `fn`, so it calls `fn(w0)` on
device `d0`, `fn(w1)` on device `d1`, etc. It then merges the return
values from `fn()`, which leads to one common object if the returned values
are the same object from every replica, or a `DistributedValues` object
otherwise.
* _Reductions_ and _all-reduce_: A _reduction_ is a method of aggregating
multiple values into one value, like "sum" or "mean". If a strategy is doing
sync training, we will perform a reduction on the gradients to a parameter
from all replicas before applying the update. _All-reduce_ is an algorithm for
performing a reduction on values from multiple devices and making the result
available on all of those devices (see the sketch at the end of this
glossary).
* _Mirrored variables_: These are variables that are created on multiple
devices, where we keep the variables in sync by applying the same
updates to every copy. Mirrored variables are created with
`tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...)`.
Normally they are only used in synchronous training.
* _SyncOnRead variables_
_SyncOnRead variables_ are created by
`tf.Variable(...synchronization=tf.VariableSynchronization.ON_READ...)`, and
they are created on multiple devices. In replica context, each
component variable on the local replica can perform reads and writes without
synchronization with each other. When the
_SyncOnRead variable_ is read in cross-replica context, the values from
component variables are aggregated and returned.
_SyncOnRead variables_ add a lot of configuration complexity to the
underlying logic, so we do not encourage users to instantiate and use
_SyncOnRead variables_ on their own. We have mainly used _SyncOnRead
variables_ for use cases such as batch norm and metrics. For performance
reasons, we often don't need to keep these statistics in sync every step and
they can be accumulated on each replica independently. The only time we want
to sync them is reporting or checkpointing, which typically happens in
cross-replica context. _SyncOnRead variables_ are also often used by advanced
users who want to control when variable values are aggregated. For example,
users sometimes want to maintain gradients independently on each replica for a
couple of steps without aggregation.
* _Distribute-aware layers_
Layers are generally called in a replica context, except when defining a
Keras functional model. `tf.distribute.in_cross_replica_context` will let you
determine which case you are in. If in a replica context,
the `tf.distribute.get_replica_context` function will return the default
replica context outside a strategy scope, `None` within a strategy scope, and
a `tf.distribute.ReplicaContext` object inside a strategy scope and within a
`tf.distribute.Strategy.run` function. The `ReplicaContext` object has an
`all_reduce` method for aggregating across all replicas.
Note that we provide a default version of `tf.distribute.Strategy` that is
used when no other strategy is in scope and provides the same API with
reasonable default behavior.
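A minimal sketch tying these terms together (an illustrative example, not
part of this module's API; it assumes two local GPUs are available):
```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():  # Enters the cross-replica context.
  v = tf.Variable(1.0)  # Creates a mirrored variable on both devices.
def replica_fn(x):
  # Runs in a replica context; each replica receives its own slice of `x`.
  ctx = tf.distribute.get_replica_context()
  # All-reduce: sum the per-replica values; every replica gets the result.
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, x * v)
dataset = tf.data.Dataset.from_tensor_slices(
    tf.ones([8], tf.float32)).batch(4)  # Global batch 4 -> 2 per replica.
dist_dataset = strategy.experimental_distribute_dataset(dataset)
for batch in dist_dataset:
  results = strategy.run(replica_fn, args=(batch,))  # Per-replica values.
```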
"""
# pylint: enable=line-too-long
import collections
import copy
import enum # pylint: disable=g-bad-import-order
import functools
import threading
import weakref
import six
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context as eager_context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.trackable import base as trackable
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls
# ------------------------------------------------------------------------------
# Context tracking whether in a strategy.update() or .update_non_slot() call.
_update_replica_id = threading.local()
def get_update_replica_id():
"""Get the current device if in a `tf.distribute.Strategy.update()` call."""
try:
return _update_replica_id.current
except AttributeError:
return None
class UpdateContext(object):
"""Context manager when you are in `update()` or `update_non_slot()`."""
__slots__ = ["_replica_id", "_old_replica_id"]
def __init__(self, replica_id):
self._replica_id = replica_id
self._old_replica_id = None
def __enter__(self):
self._old_replica_id = get_update_replica_id()
_update_replica_id.current = self._replica_id
def __exit__(self, exception_type, exception_value, traceback):
del exception_type, exception_value, traceback
_update_replica_id.current = self._old_replica_id
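# An illustrative (internal) usage sketch; `UpdateContext` is not public API:
#
#   with UpdateContext(replica_id):
#     # Within this block, get_update_replica_id() returns `replica_id` on
#     # the current thread; the previous value is restored on exit.
#     ...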
# ------------------------------------------------------------------------------
# Public utility functions.
@tf_export(v1=["distribute.get_loss_reduction"])
def get_loss_reduction():
"""`tf.distribute.ReduceOp` corresponding to the last loss reduction.
This is used to decide whether loss should be scaled in the optimizer (used
only for the estimator + v1 optimizer use case).
Returns:
`tf.distribute.ReduceOp` corresponding to the last loss reduction for
estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM` otherwise.
"""
if not distribution_strategy_context.get_strategy()._scale_loss_for_estimator: # pylint: disable=protected-access
# If we are not in Estimator context then return 'SUM'. We do not need to
# scale loss in the optimizer.
return reduce_util.ReduceOp.SUM
last_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access
if (last_reduction == losses_impl.Reduction.SUM or
last_reduction == "sum"): # Check for tf.keras.losses.Reduction.SUM
return reduce_util.ReduceOp.SUM
return reduce_util.ReduceOp.MEAN
# ------------------------------------------------------------------------------
# Internal API for validating the current thread mode
def _require_cross_replica_or_default_context_extended(extended,
error_message=None):
"""Verify in cross-replica context."""
context = _get_per_thread_mode()
cross_replica = context.cross_replica_context
if cross_replica is not None and cross_replica.extended is extended:
return
if context is _get_default_replica_mode():
return
strategy = extended._container_strategy() # pylint: disable=protected-access
# We have an error to report, figure out the right message.
if context.strategy is not strategy:
_wrong_strategy_scope(strategy, context)
assert cross_replica is None
if not error_message:
error_message = ("Method requires being in cross-replica context, use "
"get_replica_context().merge_call()")
raise RuntimeError(error_message)
def _wrong_strategy_scope(strategy, context):
# Figure out the right error message.
if not distribution_strategy_context.has_strategy():
raise RuntimeError(
'Need to be inside "with strategy.scope()" for %s' %
(strategy,))
else:
raise RuntimeError(
"Mixing different tf.distribute.Strategy objects: %s is not %s" %
(context.strategy, strategy))
def require_replica_context(replica_ctx):
"""Verify in `replica_ctx` replica context."""
context = _get_per_thread_mode()
if context.replica_context is replica_ctx: return
# We have an error to report, figure out the right message.
if context.replica_context is None:
raise RuntimeError("Need to be inside `call_for_each_replica()`")
if context.strategy is replica_ctx.strategy:
# Two different ReplicaContexts with the same tf.distribute.Strategy.
raise RuntimeError("Mismatching ReplicaContext.")
raise RuntimeError(
"Mismatching tf.distribute.Strategy objects: %s is not %s." %
(context.strategy, replica_ctx.strategy))
def _require_strategy_scope_strategy(strategy):
"""Verify in a `strategy.scope()` in this thread."""
context = _get_per_thread_mode()
if context.strategy is strategy: return
_wrong_strategy_scope(strategy, context)
def _require_strategy_scope_extended(extended):
"""Verify in a `distribution_strategy.scope()` in this thread."""
context = _get_per_thread_mode()
if context.strategy.extended is extended: return
# Report error.
strategy = extended._container_strategy() # pylint: disable=protected-access
_wrong_strategy_scope(strategy, context)
# ------------------------------------------------------------------------------
# Internal context managers used to implement the DistributionStrategy
# base class
class _CurrentDistributionContext(object):
"""Context manager setting the current `tf.distribute.Strategy`.
Also: overrides the variable creator and optionally the current device.
"""
def __init__(self,
strategy,
var_creator_scope,
var_scope=None,
resource_creator_scope=None,
default_device=None):
self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access
strategy)
self._var_creator_scope = var_creator_scope
self._var_scope = var_scope
self._resource_creator_scope = resource_creator_scope
if default_device:
self._device_scope = ops.device(default_device)
else:
self._device_scope = None
self._same_scope_again_count = 0
def __enter__(self):
# Allow this scope to be entered if this strategy is already in scope.
if distribution_strategy_context.has_strategy():
_require_cross_replica_or_default_context_extended(
self._context.strategy.extended)
self._same_scope_again_count += 1
else:
_push_per_thread_mode(self._context)
if self._var_scope:
self._var_scope.__enter__()
self._var_creator_scope.__enter__()
if self._resource_creator_scope:
nest.map_structure(lambda scope: scope.__enter__(),
self._resource_creator_scope)
if self._device_scope:
self._device_scope.__enter__()
return self._context.strategy
def __exit__(self, exception_type, exception_value, traceback):
if self._same_scope_again_count > 0:
self._same_scope_again_count -= 1
return
if self._device_scope:
try:
self._device_scope.__exit__(exception_type, exception_value, traceback)
except RuntimeError as e:
six.raise_from(
RuntimeError("Device scope nesting error: move call to "
"tf.distribute.set_strategy() out of `with` scope."),
e)
try:
self._var_creator_scope.__exit__(
exception_type, exception_value, traceback)
except RuntimeError as e:
six.raise_from(
RuntimeError("Variable creator scope nesting error: move call to "
"tf.distribute.set_strategy() out of `with` scope."),
e)
if self._resource_creator_scope:
try:
if isinstance(self._resource_creator_scope, list):
reversed_resource_creator_scope = self._resource_creator_scope[::-1]
nest.map_structure(
lambda scope: scope.__exit__(exception_type, exception_value, # pylint:disable=g-long-lambda
traceback),
reversed_resource_creator_scope)
else:
self._resource_creator_scope.__exit__(exception_type, exception_value,
traceback)
except RuntimeError as e:
six.raise_from(
RuntimeError("Resource creator scope nesting error: move call "
"to tf.distribute.set_strategy() out of `with` "
"scope."), e)
if self._var_scope:
try:
self._var_scope.__exit__(exception_type, exception_value, traceback)
except RuntimeError as e:
six.raise_from(
RuntimeError("Variable scope nesting error: move call to "
"tf.distribute.set_strategy() out of `with` scope."),
e)
_pop_per_thread_mode()
# TODO(yuefengz): add more replication modes.
@tf_export("distribute.InputReplicationMode")
class InputReplicationMode(enum.Enum):
"""Replication mode for input function.
* `PER_WORKER`: The input function will be called on each worker
independently, creating as many input pipelines as number of workers.
Replicas will dequeue from the local Dataset on their worker.
`tf.distribute.Strategy` doesn't manage any state sharing between such
separate input pipelines.
* `PER_REPLICA`: The input function will be called on each replica separately.
`tf.distribute.Strategy` doesn't manage any state sharing between such
separate input pipelines.
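For example (an illustrative sketch), the mode is typically selected through
`tf.distribute.InputOptions`:
```python
options = tf.distribute.InputOptions(
    experimental_replication_mode=(
        tf.distribute.InputReplicationMode.PER_WORKER))
```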
"""
PER_WORKER = "PER_WORKER"
PER_REPLICA = "PER_REPLICA"
@tf_export("distribute.InputContext")
class InputContext(object):
"""A class wrapping information needed by an input function.
This is a context class that is passed to the user's input function and
contains information about the compute replicas and input pipelines. The
number of compute replicas (in sync training) helps compute the local batch
size from the desired global batch size for each replica. The input pipeline
information can be used to return a different subset of the input in each
replica (e.g. shard the input pipeline, use a different input
source, etc.).
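For example (an illustrative sketch; in practice the strategy constructs the
`InputContext` and passes it to your input function):
>>> ctx = tf.distribute.InputContext(
...     num_input_pipelines=2, input_pipeline_id=0, num_replicas_in_sync=4)
>>> ctx.get_per_replica_batch_size(64)
16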
"""
__slots__ = [
"_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync"
]
def __init__(self,
num_input_pipelines=1,
input_pipeline_id=0,
num_replicas_in_sync=1):
"""Initializes an InputContext object.
Args:
num_input_pipelines: the number of input pipelines in a cluster.
input_pipeline_id: the current input pipeline id, should be an int in
[0,`num_input_pipelines`).
num_replicas_in_sync: the number of replicas that are in sync.
"""
self._num_input_pipelines = num_input_pipelines
self._input_pipeline_id = input_pipeline_id
self._num_replicas_in_sync = num_replicas_in_sync
@property
def num_replicas_in_sync(self):
"""Returns the number of compute replicas in sync."""
return self._num_replicas_in_sync
@property
def input_pipeline_id(self):
"""Returns the input pipeline ID."""
return self._input_pipeline_id
@property
def num_input_pipelines(self):
"""Returns the number of input pipelines."""
return self._num_input_pipelines
def get_per_replica_batch_size(self, global_batch_size):
"""Returns the per-replica batch size.
Args:
global_batch_size: the global batch size which should be divisible by
`num_replicas_in_sync`.
Returns:
the per-replica batch size.
Raises:
ValueError: if `global_batch_size` not divisible by
`num_replicas_in_sync`.
"""
if global_batch_size % self._num_replicas_in_sync != 0:
raise ValueError("The `global_batch_size` %r is not divisible by "
"`num_replicas_in_sync` %r " %
(global_batch_size, self._num_replicas_in_sync))
return global_batch_size // self._num_replicas_in_sync
def __str__(self):
return "tf.distribute.InputContext(input pipeline id {}, total: {})".format(
self.input_pipeline_id, self.num_input_pipelines)
@tf_export("distribute.experimental.ValueContext", v1=[])
class ValueContext(object):
"""A class wrapping information needed by a distribute function.
This is a context class that is passed to the `value_fn` in
`strategy.experimental_distribute_values_from_function` and contains
information about the compute replicas. The `num_replicas_in_sync` and
`replica_id` can be used to customize the value on each replica.
Example usage:
1. Directly constructed.
>>> def value_fn(context):
... return context.replica_id_in_sync_group/context.num_replicas_in_sync
>>> context = tf.distribute.experimental.ValueContext(
... replica_id_in_sync_group=2, num_replicas_in_sync=4)
>>> per_replica_value = value_fn(context)
>>> per_replica_value
0.5
2. Passed in by `experimental_distribute_values_from_function`. {: value=2}
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> def value_fn(value_context):
... return value_context.num_replicas_in_sync
>>> distributed_values = (
... strategy.experimental_distribute_values_from_function(
... value_fn))
>>> local_result = strategy.experimental_local_results(distributed_values)
>>> local_result
(2, 2)
"""
__slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"]
def __init__(self,
replica_id_in_sync_group=0,
num_replicas_in_sync=1):
"""Initializes an ValueContext object.
Args:
replica_id_in_sync_group: the current replica_id, should be an int in
[0,`num_replicas_in_sync`).
num_replicas_in_sync: the number of replicas that are in sync.
"""
self._replica_id_in_sync_group = replica_id_in_sync_group
self._num_replicas_in_sync = num_replicas_in_sync
@property
def num_replicas_in_sync(self):
"""Returns the number of compute replicas in sync."""
return self._num_replicas_in_sync
@property
def replica_id_in_sync_group(self):
"""Returns the replica ID."""
return self._replica_id_in_sync_group
def __str__(self):
return (("tf.distribute.ValueContext(replica id {}, "
" total replicas in sync: ""{})")
.format(self.replica_id_in_sync_group, self.num_replicas_in_sync))
@tf_export("distribute.RunOptions")
class RunOptions(
collections.namedtuple("RunOptions", [
"experimental_enable_dynamic_batch_size",
"experimental_bucketizing_dynamic_shape",
"experimental_xla_options",
])):
"""Run options for `strategy.run`.
This can be used to hold some strategy-specific configs.
Attributes:
experimental_enable_dynamic_batch_size: Boolean. Only applies to
TPUStrategy. Defaults to True. If True, TPUStrategy will enable the dynamic
padder to support dynamic batch sizes for the inputs. Otherwise only static
shape inputs are allowed.
experimental_bucketizing_dynamic_shape: Boolean. Only applies to
TPUStrategy. Defaults to False. If True, TPUStrategy will automatically
bucketize inputs passed into `run` if the input shape is
dynamic. This is a performance optimization to reduce XLA recompilation,
which should not have an impact on correctness.
experimental_xla_options: A `tf.tpu.XLAOptions` instance. Only applies to
TPUStrategy. Controls the XLA compiling options on TPUs. Defaults to None.
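For example (an illustrative sketch; `strategy`, `replica_fn` and `inputs`
are assumed to exist, with `strategy` a `tf.distribute.TPUStrategy`):
```python
options = tf.distribute.RunOptions(
    experimental_enable_dynamic_batch_size=False)
strategy.run(replica_fn, args=(inputs,), options=options)
```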
"""
def __new__(cls,
experimental_enable_dynamic_batch_size=True,
experimental_bucketizing_dynamic_shape=False,
experimental_xla_options=None):
return super(RunOptions,
cls).__new__(cls, experimental_enable_dynamic_batch_size,
experimental_bucketizing_dynamic_shape,
experimental_xla_options)
@tf_export("distribute.InputOptions", v1=[])
class InputOptions(
collections.namedtuple("InputOptions", [
"experimental_fetch_to_device",
"experimental_replication_mode",
"experimental_place_dataset_on_device",
"experimental_per_replica_buffer_size",
])):
"""Run options for `experimental_distribute_dataset(s_from_function)`.
This can be used to hold some strategy-specific configs.
```python
# Set up TPUStrategy
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
dataset = tf.data.Dataset.range(16)
distributed_dataset_on_host = (
strategy.experimental_distribute_dataset(
dataset,
tf.distribute.InputOptions(
experimental_replication_mode=
tf.distribute.InputReplicationMode.PER_WORKER,
experimental_place_dataset_on_device=False,
experimental_per_replica_buffer_size=1)))
```
Attributes:
experimental_fetch_to_device: Boolean. If True, dataset
elements will be prefetched to accelerator device memory. When False,
dataset elements are prefetched to host device memory. Must be False when
using the TPUEmbedding API. experimental_fetch_to_device can only be used
with experimental_replication_mode=PER_WORKER. The default behavior is the
same as setting it to True.
experimental_replication_mode: Replication mode for the input function.
Currently, InputReplicationMode.PER_REPLICA is only supported with
tf.distribute.MirroredStrategy via
experimental_distribute_datasets_from_function.
The default value is InputReplicationMode.PER_WORKER.
experimental_place_dataset_on_device: Boolean. Default to False. When True,
dataset will be placed on the device, otherwise it will remain on the
host. experimental_place_dataset_on_device=True can only be used with
experimental_replication_mode=PER_REPLICA
experimental_per_replica_buffer_size: Integer. Defaults to 1. Indicates the
prefetch buffer size in the replica device memory. Users can set it
to 0 to completely disable prefetching, or to a number greater than
1 to enable a larger buffer size. Note that this option is still
valid with `experimental_fetch_to_device=False`.
"""
def __new__(cls,
experimental_fetch_to_device=None,
experimental_replication_mode=InputReplicationMode.PER_WORKER,
experimental_place_dataset_on_device=False,
experimental_per_replica_buffer_size=1):
if experimental_fetch_to_device is None:
experimental_fetch_to_device = True
return super(InputOptions,
cls).__new__(cls, experimental_fetch_to_device,
experimental_replication_mode,
experimental_place_dataset_on_device,
experimental_per_replica_buffer_size)
# ------------------------------------------------------------------------------
# Base classes for all distribution strategies.
# Base class for v1 Strategy and v2 Strategy classes. For APIs specific to
# v1/v2 Strategy, add to implementing classes of StrategyBase.
# pylint: disable=line-too-long
class StrategyBase(object):
"""A state & compute distribution policy on a list of devices.
See [the guide](https://www.tensorflow.org/guide/distributed_training)
for overview and examples. See `tf.distribute.StrategyExtended` and
[`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute)
for a glossary of concepts mentioned on this page such as "per-replica",
_replica_, and _reduce_.
In short:
* To use it with Keras `compile`/`fit`,
[please
read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras).
* You may pass a descendant of `tf.distribute.Strategy` to
`tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator`
should distribute its computation. See
[guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support).
* Otherwise, use `tf.distribute.Strategy.scope` to specify that a
strategy should be used when building and executing your model.
(This puts you in the "cross-replica context" for this strategy, which
means the strategy is put in control of things like variable placement.)
* If you are writing a custom training loop, you will need to call a few more
methods,
[see the
guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops):
* Start by creating a `tf.data.Dataset` normally.
* Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert
a `tf.data.Dataset` to something that produces "per-replica" values.
If you want to manually specify how the dataset should be partitioned
across replicas, use
`tf.distribute.Strategy.distribute_datasets_from_function`
instead.
* Use `tf.distribute.Strategy.run` to run a function
once per replica, taking values that may be "per-replica" (e.g.
from a `tf.distribute.DistributedDataset` object) and returning
"per-replica" values.
This function is executed in "replica context", which means each
operation is performed separately on each replica.
* Finally use a method (such as `tf.distribute.Strategy.reduce`) to
convert the resulting "per-replica" values into ordinary `Tensor`s.
A custom training loop can be as simple as:
```python
with my_strategy.scope():
@tf.function
def distribute_train_epoch(dataset):
def replica_fn(input):
# process input and return result
return result
total_result = 0
for x in dataset:
per_replica_result = my_strategy.run(replica_fn, args=(x,))
total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
per_replica_result, axis=None)
return total_result
dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
for _ in range(EPOCHS):
train_result = distribute_train_epoch(dist_dataset)
```
This takes an ordinary `dataset` and `replica_fn` and runs it
distributed using a particular `tf.distribute.Strategy` named
`my_strategy` above. Any variables created in `replica_fn` are created
using `my_strategy`'s policy, and library functions called by
`replica_fn` can use the `get_replica_context()` API to implement
distributed-specific behavior.
You can use the `reduce` API to aggregate results across replicas and use
this as a return value from one iteration over a
`tf.distribute.DistributedDataset`. Or
you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to
accumulate metrics across steps in a given epoch.
See the
[custom training loop
tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training)
for a more detailed example.
Note: `tf.distribute.Strategy` does not currently support TensorFlow's
partitioned variables (where a single variable is split across multiple
devices).
"""
# pylint: enable=line-too-long
# TODO(josh11b): Partitioned computations, state; sharding
# TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling
def __init__(self, extended):
self._extended = extended
# Flag that is used to indicate whether distribution strategy is used with
# Estimator. This is required for backward compatibility of loss scaling
# when using v1 optimizer with estimator.
self._scale_loss_for_estimator = False
if not hasattr(extended, "_retrace_functions_for_each_device"):
# pylint: disable=protected-access
# `extended._retrace_functions_for_each_device` dictates
# whether the same function will be retraced when it is called on
# different devices.
try:
extended._retrace_functions_for_each_device = (
len(extended.worker_devices) > 1)
distribution_strategy_replica_gauge.get_cell("num_replicas").set(
self.num_replicas_in_sync)
except: # pylint: disable=bare-except
# Default for the case where extended.worker_devices can't return
# a sensible value.
extended._retrace_functions_for_each_device = True
# Below are the dicts of axis (int) -> `tf.function`.
self._mean_reduce_helper_fns = {}
self._reduce_sum_fns = {}
# Whether this strategy is designed to work with `ClusterCoordinator`.
self._should_use_with_coordinator = False
@property
def extended(self):
"""`tf.distribute.StrategyExtended` with additional methods."""
return self._extended
@tf_contextlib.contextmanager
def _scale_loss_for_estimator_enabled(self):
"""Scope which sets a flag used for scaling losses in optimizer.
Yields:
`_scale_loss_for_estimator_enabled` is a context manager with a
side effect, but doesn't return a value.
"""
self._scale_loss_for_estimator = True
try:
yield
finally:
self._scale_loss_for_estimator = False
# pylint: disable=line-too-long
def scope(self):
"""Context manager to make the strategy current and distribute variables.
This method returns a context manager, and is used as follows:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> # Variable created inside scope:
>>> with strategy.scope():
... mirrored_variable = tf.Variable(1.)
>>> mirrored_variable
MirroredVariable:{
0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
}
>>> # Variable created outside scope:
>>> regular_variable = tf.Variable(1.)
>>> regular_variable
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>
_What happens when Strategy.scope is entered?_
* `strategy` is installed in the global context as the "current" strategy.
Inside this scope, `tf.distribute.get_strategy()` will now return this
strategy. Outside this scope, it returns the default no-op strategy.
* Entering the scope also enters the "cross-replica context". See
`tf.distribute.StrategyExtended` for an explanation on cross-replica and
replica contexts.
* Variable creation inside `scope` is intercepted by the strategy. Each
strategy defines how it wants to affect the variable creation. Sync
strategies like `MirroredStrategy`, `TPUStrategy` and
`MultiWorkerMirroredStrategy` create variables replicated on each replica,
whereas `ParameterServerStrategy` creates variables on the parameter
servers. This is done using a custom `tf.variable_creator_scope`.
* In some strategies, a default device scope may also be entered: in
`MultiWorkerMirroredStrategy`, a default device scope of "/CPU:0" is
entered on each worker.
Note: Entering a scope does not automatically distribute a computation, except
in the case of high-level training frameworks like Keras `model.fit`. If
you're not using `model.fit`, you
need to use the `strategy.run` API to explicitly distribute that computation.
See an example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training).
_What should be in scope and what should be outside?_
There are a number of requirements on what needs to happen inside the scope.
However, in places where we have information about which strategy is in use,
we often enter the scope for the user, so they don't have to do it
explicitly (i.e. calling those either inside or outside the scope is OK).
* Anything that creates variables that should be distributed variables
must be called in a `strategy.scope`. This can be accomplished either by
directly calling the variable creating function within the scope context,
or by relying on another API like `strategy.run` or `keras.Model.fit` to
automatically enter it for you. Any variable that is created outside scope
will not be distributed and may have performance implications. Some common
objects that create variables in TF are Models, Optimizers, Metrics. Such
objects should always be initialized in the scope, and any functions
that may lazily create variables (e.g., `Model.__call__()`, tracing a
`tf.function`, etc.) should similarly be called within scope. Another
source of variable creation can be a checkpoint restore - when variables
are created lazily. Note that any variable created inside a strategy
captures the strategy information. So reading and writing to these
variables outside the `strategy.scope` can also work seamlessly, without
the user having to enter the scope.
* Some strategy APIs (such as `strategy.run` and `strategy.reduce`) that
require being in a strategy's scope enter the scope automatically, which
means when using those APIs you don't need to explicitly enter the scope
yourself.
* When a `tf.keras.Model` is created inside a `strategy.scope`, the Model
object captures the scope information. When high level training framework
methods such as `model.compile`, `model.fit`, etc. are then called, the
captured scope will be automatically entered, and the associated strategy
will be used to distribute the training etc. See a detailed example in
[distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras).
WARNING: Simply calling `model(..)` does not automatically enter the
captured scope -- only high level training framework APIs support this
behavior: `model.compile`, `model.fit`, `model.evaluate`, `model.predict`
and `model.save` can all be called inside or outside the scope.
* The following can be either inside or outside the scope:
* Creating the input datasets
* Defining `tf.function`s that represent your training step
* Saving APIs such as `tf.saved_model.save`. Loading creates variables,
so that should go inside the scope if you want to train the model in a
distributed way.
* Checkpoint saving. As mentioned above - `checkpoint.restore` may
sometimes need to be inside scope if it creates variables.
Returns:
A context manager.
"""
return self._extended._scope(self) # pylint: disable=protected-access
# pylint: enable=line-too-long
@doc_controls.do_not_doc_inheritable # DEPRECATED, moving to `extended`
@deprecated(None, "use extended.colocate_vars_with() instead.")
def colocate_vars_with(self, colocate_with_variable):
"""DEPRECATED: use extended.colocate_vars_with() instead."""
return self._extended.colocate_vars_with(colocate_with_variable)
@doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only
def make_dataset_iterator(self, dataset):
"""DEPRECATED TF 1.x ONLY."""
return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access
@doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only
def make_input_fn_iterator(self,
input_fn,
replication_mode=InputReplicationMode.PER_WORKER):
"""DEPRECATED TF 1.x ONLY."""
if replication_mode != InputReplicationMode.PER_WORKER:
raise ValueError(
"Input replication mode not supported: %r" % replication_mode)
with self.scope():
return self.extended._make_input_fn_iterator( # pylint: disable=protected-access
input_fn, replication_mode=replication_mode)
@doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only
@deprecated(None, "use run() instead")
def experimental_run(self, fn, input_iterator=None):
"""DEPRECATED TF 1.x ONLY."""
with self.scope():
args = (input_iterator.get_next(),) if input_iterator is not None else ()
return self.run(fn, args=args)
def experimental_distribute_dataset(self, dataset, options=None):
# pylint: disable=line-too-long
"""Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`.
The returned `tf.distribute.DistributedDataset` can be iterated over
similar to regular datasets.
NOTE: The user cannot add any more transformations to a
`tf.distribute.DistributedDataset`. You can only create an iterator or
examine the `tf.TypeSpec` of the data generated by it. See API docs of