# tpu.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Library of TPU helper functions."""
import collections
import enum
from typing import Any, Callable, Iterable, List, Optional, Text, Tuple, Union
from absl import logging
import numpy as np
from tensorflow.compiler.tf2xla.python import xla as tf2xla
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as embedding_pb2
from tensorflow.python import tf2
from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import cond
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tensor_tracer
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu import tpu_name_util
from tensorflow.python.tpu import tpu_replication
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.types import core as core_types
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import traceback_utils
from tensorflow.python.util import variable_utils
from tensorflow.python.util.tf_export import tf_export
ops.NotDifferentiable("TPUReplicatedInput")
# Ops which can be safely pruned from XLA compile if they have no consumers.
# These ops should also have no inputs.
_UNCONNECTED_OPS_TO_PRUNE = set(["Placeholder", "VarHandleOp"])
_POST_DEVICE_REWRITE_ATTR = "_post_device_rewrite"
_TPU_COMPILATION_STATUS_ATTR = "_tpu_compilation_status"
_PIVOT_FOR_CLUSTER = "_pivot_for_cluster"
core = tpu_name_util.core
def _tpu_system_device_name(job: Optional[Text]) -> Text:
"""Returns the device name for the TPU_SYSTEM device of `job`."""
if job is None:
return "/device:TPU_SYSTEM:0"
else:
return "/job:%s/device:TPU_SYSTEM:0" % job
@tf_export(v1=["tpu.initialize_system"])
def initialize_system(
embedding_config: Optional[embedding_pb2.TPUEmbeddingConfiguration] = None,
job: Optional[Text] = None,
compilation_failure_closes_chips: bool = True,
tpu_cancellation_closes_chips: Optional[bool] = None,
) -> core_types.Tensor:
"""Initializes a distributed TPU system for use with TensorFlow.
Args:
embedding_config: If not None, a `TPUEmbeddingConfiguration` proto
describing the desired configuration of the hardware embedding lookup
tables. If embedding_config is None, no hardware embeddings can be used.
job: The job (the XXX in TensorFlow device specification /job:XXX) that
contains the TPU devices that will be initialized. If job=None it is
assumed there is only one job in the TensorFlow flock, and an error will
be returned if this assumption does not hold.
    compilation_failure_closes_chips: Whether to close TPU chips when there is
      a compilation failure.
    tpu_cancellation_closes_chips: Whether to close TPU chips when a TPU
      execution is cancelled. If the value is None, the behavior will be
      determined by the command line flag `tpu_cancellation_closes_chips` for
      the TPU worker. WARNING: this argument only applies to the TFRT TPU
      runtime.
Returns:
A serialized `TopologyProto` that describes the TPU system. Note:
the topology must be evaluated using `Session.run` before it can be used.
"""
config_string = ("" if embedding_config is None else
embedding_config.SerializeToString())
# The enum is defined in core/tpu/kernels/tpu_execute_op_options.h.
tpu_cancellation_closes_chips_enum = 0
if tpu_cancellation_closes_chips is not None:
if tpu_cancellation_closes_chips:
tpu_cancellation_closes_chips_enum = 1
else:
tpu_cancellation_closes_chips_enum = 2
with ops.device(_tpu_system_device_name(job)):
topology = tpu_ops.configure_distributed_tpu(
compilation_failure_closes_chips=compilation_failure_closes_chips,
tpu_cancellation_closes_chips=tpu_cancellation_closes_chips_enum,
)
if embedding_config is None:
return topology
# This set of control dependencies is needed as this function is expected to
# return an op which will return the topology when executed, but we need to
# call the embedding initialization op between initializing the TPU and
# returning the topology.
with ops.control_dependencies([topology]):
embedding_init = tpu_ops.configure_tpu_embedding(config=config_string)
with ops.control_dependencies([embedding_init]):
return array_ops.identity(topology, name="tpu_init_identity")
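# A minimal usage sketch, assuming TF1-style graph mode; `tpu_grpc_url` below
# is a hypothetical address of the TPU worker to connect to:
#
#   import tensorflow.compat.v1 as tf
#
#   topology_op = tf.tpu.initialize_system()
#   with tf.Session(tpu_grpc_url) as sess:
#     serialized_topology = sess.run(topology_op)  # serialized `TopologyProto`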
def initialize_system_for_tpu_embedding(
embedding_config: embedding_pb2.TPUEmbeddingConfiguration,
job: Optional[Text] = None,
) -> ops.Operation:
"""Initializes a distributed TPU Embedding system for use with TensorFlow.
The following two are equivalent:
1. initialize_system() with embedding_config.
2. initialize_system() without embedding_config, then
initialize_system_for_tpu_embedding().
initialize_system() should not be called with embedding_config if
initialize_system_for_tpu_embedding() is meant to be called later.
Args:
embedding_config: a `TPUEmbeddingConfiguration` proto describing the desired
configuration of the hardware embedding lookup tables.
job: The job (the XXX in TensorFlow device specification /job:XXX) that
contains the TPU devices that will be initialized. If job=None it is
assumed there is only one job in the TensorFlow flock, and an error will
be returned if this assumption does not hold.
Returns:
A no-op.
"""
config_string = embedding_config.SerializeToString()
with ops.device(_tpu_system_device_name(job)):
return tpu_ops.configure_tpu_embedding(config=config_string)
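# A rough sketch of the two-step pattern described in the docstring above
# (option 2): initialize the TPU system without an embedding config, then
# configure the embedding tables separately. `my_embedding_config` is a
# hypothetical `TPUEmbeddingConfiguration` proto built by the caller; in TF1
# graph mode both returned ops would be evaluated, in order, via `Session.run`.
#
#   topology = initialize_system()
#   embedding_init = initialize_system_for_tpu_embedding(my_embedding_config)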
@tf_export(v1=["tpu.shutdown_system"])
def shutdown_system(job: Optional[Text] = None) -> ops.Operation:
  """Shuts down a running distributed TPU system.
Args:
job: The job (the XXX in TensorFlow device specification /job:XXX) that
contains the TPU devices that will be shutdown. If job=None it is
assumed there is only one job in the TensorFlow flock, and an error will
be returned if this assumption does not hold.
"""
with ops.device(_tpu_system_device_name(job)):
shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu()
return shutdown_distributed_tpu
@auto_control_deps.register_acd_resource_resolver
def tpu_replicated_input_resolver(
op: ops.Operation,
resource_reads: object_identity.ObjectIdentitySet,
resource_writes: object_identity.ObjectIdentitySet) -> bool:
"""Replaces TPUReplicatedInput outputs with its inputs in resource_inputs."""
# Ignore TPUReplicatedInput for ACD purposes since we will be directly adding
# control deps on the replicated inputs.
if op.type == "TPUReplicatedInput":
if resource_reads or resource_writes:
resource_reads.clear()
resource_writes.clear()
return True
else:
return False
# Replace tensors in `resource_inputs` which are outputs of TPUReplicatedInput
  # with the actual replicated inputs. This allows ACD to correctly add control
# deps when there are multiple calls to `run` in a
# `tf.function`.
def replace_with_unreplicated_resources(resource_inputs):
"""Replaces handles in `resource_inputs` with their unreplicated inputs."""
to_remove = []
to_add = []
for resource in resource_inputs:
if resource.op.type == "TPUReplicatedInput":
to_remove.append(resource)
to_add.extend(resource.op.inputs)
for t in to_remove:
resource_inputs.discard(t)
resource_inputs.update(to_add)
return to_add or to_remove
return bool(replace_with_unreplicated_resources(resource_reads) or
replace_with_unreplicated_resources(resource_writes))
@tf_export(v1=["tpu.PaddingSpec"])
class PaddingSpec(enum.IntEnum):
"""Represents the type of padding policies for tpu.replicate."""
  # By default the policy is AUTO: a dynamic input shape dimension will be
  # padded to the maximum size of that dimension across all replicas.
AUTO = 0
# Bucketize the dynamic input shape dimension into a power of 2.
POWER_OF_TWO = 1
@tf_export("tpu.XLAOptions")
class XLAOptions(
collections.namedtuple("XLAOptions", [
"use_spmd_for_xla_partitioning",
"enable_xla_dynamic_padder",
])):
"""XLA compilation options.
Attributes:
use_spmd_for_xla_partitioning: Boolean. Whether to use XLA's SPMD
partitioner instead of MPMD partitioner when compiler partitioning is
requested.
    enable_xla_dynamic_padder: Boolean. Whether to enable the XLA dynamic
      padder infrastructure to handle dynamically shaped inputs inside XLA.
      True by default. Disabling this may cause correctness issues with
      dynamically shaped inputs, as XLA will simply assume the inputs have
      padded shapes. However, users can optionally set it to False to improve
      device time if masking is already handled on the user side.
"""
def __new__(cls,
use_spmd_for_xla_partitioning=True,
enable_xla_dynamic_padder=True):
return super(XLAOptions, cls).__new__(cls, use_spmd_for_xla_partitioning,
enable_xla_dynamic_padder)
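# Illustrative sketch only: `XLAOptions` is a plain namedtuple, so non-default
# options can be constructed and passed to `tpu.replicate` through its
# `xla_options` argument. `my_computation` and `my_inputs` are hypothetical,
# and `tf` is assumed to be the imported TensorFlow module.
#
#   options = XLAOptions(use_spmd_for_xla_partitioning=False,
#                        enable_xla_dynamic_padder=True)
#   tf.compat.v1.tpu.replicate(my_computation, inputs=my_inputs,
#                              xla_options=options)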
@tf_export(v1=["tpu.replicate"])
@traceback_utils.filter_traceback
def replicate(
computation: Callable[..., Any],
inputs: Optional[List[List[core_types.Tensor]]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None,
maximum_shapes: Optional[Any] = None,
padding_spec: Optional[PaddingSpec] = None,
xla_options: Optional[XLAOptions] = None) -> List[Any]:
"""Builds a graph operator that runs a replicated TPU computation.
Example for the basic usage that `inputs` has static shape:
```python
def computation(x):
x = x + 1
return tf.math.reduce_mean(x)
x = tf.convert_to_tensor([1., 2., 3.])
y = tf.convert_to_tensor([4., 5., 6.])
tf.compat.v1.tpu.replicate(computation, inputs=[[x], [y]])
```
  If the `inputs` have dynamic shapes and you would like to automatically
  bucketize the inputs to avoid XLA recompilation, see the advanced example
  below:
```python
def computation(x):
x = x + 1
return tf.math.reduce_mean(x)
# Assume input tensors in two replicas `x` and `y` both have dynamic shape
# ([None, 2]).
tf.compat.v1.tpu.replicate(
computation,
    inputs=[[x], [y]],
maximum_shapes=[tf.TensorShape([None, None])],
padding_spec=tf.compat.v1.tpu.PaddingSpec.POWER_OF_TWO)
```
Args:
computation: A Python function that builds the computation to replicate.
inputs: A list of lists of input tensors or `None` (equivalent to
`[[]]`), indexed by `[replica_num][input_num]`. All replicas must
have the same number of inputs. Each input can be a nested structure
containing values that are convertible to tensors. Note that passing an
      N-dimension list of compatible values will result in an N-dimension list of
      scalar tensors rather than a single rank-N tensor. If you need different
behavior, convert part of inputs to tensors with `tf.convert_to_tensor`.
infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
of arguments as inputs to computation.
device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation and physical cores in
the TPU topology. Uses a default device assignment if `None`. The
`DeviceAssignment` may be omitted if each replica of the computation uses
only one core, and there is either only one replica, or the number of
replicas is equal to the number of cores in the TPU system.
name: (Deprecated) Does nothing.
maximum_shapes: A nested structure of tf.TensorShape representing the shape
to which the respective component of each input element in each replica
should be padded. Any unknown dimensions (e.g.
tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like
object) will be padded to the maximum size of that dimension over all
replicas. The structure of `maximum_shapes` needs to be the same as
`inputs[0]`.
padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
padding policy when the `inputs` to `tpu.replicate` is dynamic.
One usage is to enable automatic bucketizing on the inputs by setting the
      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help reduce
      recompilation on the XLA side.
xla_options: An instance of `tpu.XLAOptions` which indicates the options
passed to XLA compiler. Use `None` for default options.
Returns:
    A list of outputs, indexed by `[replica_num]`. Each output can be a nested
    structure the same as what computation() returns, with a few exceptions.
Exceptions include:
1) None output: a NoOp would be returned which control-depends on
computation.
2) Single value output: A tuple containing the value would be returned.
3) Operation-only outputs: a NoOp would be returned which
control-depends on computation.
TODO(b/121383831): Investigate into removing these special cases.
Raises:
ValueError: If all replicas do not have equal numbers of input tensors.
ValueError: If the number of inputs per replica does not match
the number of formal parameters to `computation`.
ValueError: If the static `inputs` dimensions don't match with the values
given in `maximum_shapes`.
ValueError: If the structure of inputs per replica does not match
the structure of `maximum_shapes`.
"""
return split_compile_and_replicate(
computation,
inputs,
infeed_queue,
device_assignment,
name,
maximum_shapes=maximum_shapes,
padding_spec=padding_spec,
xla_options=xla_options)[1]
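# A brief sketch of consuming the return value: the result is indexed by
# replica number, and each entry mirrors what `computation` returns (with the
# special cases noted in the docstring, e.g. a single value comes back wrapped
# in a tuple). `computation`, `x` and `y` are the hypothetical names from the
# docstring example above.
#
#   outputs = tf.compat.v1.tpu.replicate(computation, inputs=[[x], [y]])
#   replica_0_result = outputs[0]
#   replica_1_result = outputs[1]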
def _ceil_to_pow_of_n(x, n):
  """Rounds `x` up to the smallest power of `n` that is at least `x`."""
x = math_ops.cast(x, dtypes.float32)
lognx = math_ops.log(x) / math_ops.log(n * 1.0)
lognx = math_ops.ceil(lognx)
result = math_ops.pow(n * 1.0, lognx)
result = math_ops.cast(result, dtypes.int32)
return result
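# Worked example: with n=2, `_ceil_to_pow_of_n(5, 2)` computes
# ceil(log2(5)) = 3 and returns 2**3 = 8, while an exact power of two such as
# 8 maps to itself. The result is returned as an int32 tensor.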
def _pad_all_input(
inputs: Iterable[core_types.Tensor],
padded_shapes: List[Optional[tensor_shape.TensorShape]],
padding_spec: PaddingSpec
) -> Tuple[List[List[Any]], List[dynamic_padding.PaddingMap]]:
"""Pad all input tensors given padded_shapes.
The real shape tensors will be concatenated with the padded original inputs.
Args:
inputs: The original inputs.
padded_shapes: A list of padded shapes for each input. If an entry is None,
no padding is performed.
padding_spec: An enum specified by `tpu.PaddingSpec`. This describes the
padding policy when the `inputs` to `tf.tpu.replicate` is dynamic.
One usage is to enable automatic bucketizing on the inputs by setting the
      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help reduce
      recompilation on the XLA side.
Returns:
The padded inputs and a PaddingMap list which maps the padded input
dimension to the real shape argument index.
"""
  # maximum_static_shapes[idx][i] indicates the maximum static size of the ith
  # dimension of the idx-th input across all replicas.
maximum_static_shapes = []
  # need_padding[idx][i] indicates whether the ith dimension of the idx-th
  # input needs padding.
need_padding = []
input_shape_tensors = []
for core_idx, inputs_per_core in enumerate(inputs):
for idx, input_tensor in enumerate(inputs_per_core):
input_shape = input_tensor.get_shape().as_list()
if core_idx == 0:
input_shape_tensors.append([])
maximum_static_shapes.append(input_shape)
need_padding.append(np.full_like(input_shape, False, dtype=bool))
else:
for i, s in enumerate(input_shape):
if s is None or s != maximum_static_shapes[idx][i]:
need_padding[idx][i] = True
maximum_static_shapes[idx] = max(input_shape,
maximum_static_shapes[idx])
# Append _POST_DEVICE_REWRITE_ATTR attributes to the real shape ops.
real_input_shape = array_ops.shape(input_tensor)
real_input_shape.op._set_attr( # pylint: disable=protected-access
_POST_DEVICE_REWRITE_ATTR,
attr_value_pb2.AttrValue(b=True))
input_shape_tensors[idx].append(real_input_shape)
maximum_shapes = []
for shapes_per_input in input_shape_tensors:
maximum_shapes.append(
math_ops.reduce_max(array_ops_stack.stack(shapes_per_input), axis=0))
padded_inputs = []
real_shapes = []
padding_maps = []
for core_idx, inputs_per_core in enumerate(inputs):
padded_inputs.append([])
real_shapes.append([])
real_shape_idx = len(inputs_per_core) - 1
for idx, input_tensor in enumerate(inputs_per_core):
input_shape_tensor = input_shape_tensors[idx][core_idx]
input_shape = input_tensor.get_shape().as_list()
padded_shape = padded_shapes[idx]
# If we have no padded_shape, then skip padding.
if any(need_padding[idx]) and padded_shape is not None:
for i, s in enumerate(input_shape):
if need_padding[idx][i]:
if core_idx == 0:
real_shape_idx += 1
padding_map = dynamic_padding.PaddingMap()
padding_map.arg_index = idx
padding_map.shape_index = i
padding_map.padding_arg_index = real_shape_idx
padding_maps.append(padding_map)
real_shapes[core_idx].append(
math_ops.cast(input_shape_tensor[i], dtypes.int32))
paddings = []
for i, s in enumerate(padded_shape.dims):
if need_padding[idx][i]:
            # The minimum padded dimension size is 2, as XLA doesn't support a
            # dynamic dimension of size 1.
minimum_dynamic_dim_size = 2
if s.value is not None:
# Pad to the given maximum value.
max_dim_size = max(s.value, minimum_dynamic_dim_size)
else:
# If maximum value is not given, then pad to the maximum dimension
# among all the cores.
max_dim_size = math_ops.maximum(maximum_shapes[idx][i],
minimum_dynamic_dim_size)
if padding_spec == PaddingSpec.POWER_OF_TWO:
max_dim_size = _ceil_to_pow_of_n(max_dim_size, 2)
# Pad to the given maximum value.
padding = [0, max_dim_size - input_shape_tensor[i]]
else:
padding = [0, 0]
paddings.append(padding)
if input_tensor.get_shape().is_fully_defined():
# TODO(rxsang): This is a hack to make sure padded_input has dynamic
# shapes, so any tf.size/tf.shape op performed on it won't be constant
# folded. Do we have better ways to do it?
padded_input = cond.cond(
array_ops.constant(True),
lambda: array_ops.pad(input_tensor, paddings), # pylint: disable=cell-var-from-loop
lambda: input_tensor)
else:
padded_input = array_ops.pad(input_tensor, paddings)
# Append _POST_DEVICE_REWRITE_ATTR attributes to all padded inputs.
padded_input.op._set_attr( # pylint: disable=protected-access
_POST_DEVICE_REWRITE_ATTR,
attr_value_pb2.AttrValue(b=True))
padded_inputs[core_idx].append(padded_input)
else:
padded_inputs[core_idx].append(input_tensor)
num_replicas = len(padded_inputs)
for i in range(num_replicas):
padded_inputs[i].extend(real_shapes[i])
return padded_inputs, padding_maps
def _flatten_and_filter_composite(maybe_composite, non_composite_output,
                                  composite_output=None):
  """For an input, replaces the input with a tuple if the input is composite.
  If `maybe_composite` is not composite, return the parameter
  `non_composite_output`; otherwise return a tuple which repeats the value of
  the parameter `composite_output` as many times as there are components of
  the composite tensor.
This is useful for computing a mask when flattening nested data with
`expand_composites=True`. For example
```python
nest.flatten(data, expand_composites=True)
```
and
```python
  nest.flatten(nest.map_structure(
      lambda x: _flatten_and_filter_composite(x, False, True), data))
```
  will have the same length, and the second will be True where the tensor in
  the first is derived from expanding a composite tensor.
Args:
maybe_composite: A value to test for being a composite tensor.
non_composite_output: The value to return when `maybe_composite` is not a
composite.
composite_output: the value to fill the output tuple with if
`maybe_composite` is a composite.
Returns:
`non_composite_output` or a tuple with multiple copies of
`composite_output`.
"""
if isinstance(maybe_composite, composite_tensor.CompositeTensor):
num_components = len(nest.flatten(maybe_composite, expand_composites=True))
return (composite_output,) * num_components
return non_composite_output
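# A small illustration with hypothetical values: a plain tensor passes
# `non_composite_output` straight through, whereas a composite such as a
# SparseTensor, which flattens into three component tensors, yields a 3-tuple
# of `composite_output` (`sparse_tensor` would need to be imported from
# tensorflow.python.framework for this to run):
#
#   _flatten_and_filter_composite(constant_op.constant(1.0), False, True)
#   # -> False
#   sp = sparse_tensor.SparseTensor(indices=[[0, 0]], values=[1.0],
#                                   dense_shape=[1, 2])
#   _flatten_and_filter_composite(sp, False, True)
#   # -> (True, True, True)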
def split_compile_and_replicate(
computation: Callable[..., Any],
inputs: Optional[List[List[core_types.Tensor]]] = None,
infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
device_assignment: Optional[device_assignment_lib.DeviceAssignment] = None,
name: Optional[Text] = None,
use_tpu: bool = True,
maximum_shapes: Optional[Any] = None,
padding_spec: Optional[PaddingSpec] = None,
xla_options: Optional[XLAOptions] = None,
) -> List[List[core_types.Tensor]]:
  """Builds graph operators that run compilation and replicated computation.
  This is a lower-level interface than `replicate` that returns separate compile
  and execute output tensors. In the generated graph the compile op feeds into
  the execute op and no additional compilation is incurred when running the
compile op before the execute op. The compile op returns additional
information about the compilation but does not return the compiled program.
Args:
computation: A Python function that builds the computation to replicate.
inputs: A list of lists of input tensors or `None` (equivalent to
`[[]]`), indexed by `[replica_num][input_num]`. All replicas must
have the same number of inputs. Each input can be a nested structure
containing values that are convertible to tensors. Note that passing an
      N-dimension list of compatible values will result in an N-dimension list of
      scalar tensors rather than a single rank-N tensor. If you need different
behavior, convert part of inputs to tensors with `tf.convert_to_tensor`.
infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
of arguments as inputs to computation.
device_assignment: If not `None`, a `DeviceAssignment` describing the
      mapping between logical cores in the computation and physical cores in
the TPU topology. Uses a default device assignment if `None`. The
`DeviceAssignment` may be omitted if each replica of the computation uses
only one core, and there is either only one replica, or the number of
replicas is equal to the number of cores in the TPU system.
name: (Deprecated) Does nothing.
use_tpu: When false, the input `computation` is executed on the XLA CPU/GPU
backends. Currently, only supports a default placement (computation is
placed on GPU if one is available, and on CPU if not).
maximum_shapes: A nested structure of tf.TensorShape representing the shape
to which the respective component of each input element in each replica
should be padded. Any unknown dimensions (e.g.
tf.compat.v1.Dimension(None) in a tf.TensorShape or -1 in a tensor-like
object) will be padded to the maximum size of that dimension over all
replicas. The structure of `maximum_shapes` needs to be the same as
`inputs[0]`.
padding_spec: An enum specified by `tf.tpu.PaddingSpec`. This describes the
padding policy when the `inputs` to `tf.tpu.replicate` is dynamic.
One usage is to enable automatic bucketizing on the inputs by setting the
      value to `tpu.PaddingSpec.POWER_OF_TWO`, which can help reduce
      recompilation on the XLA side.
xla_options: An instance of `tpu.XLAOptions` which indicates the options
passed to XLA compiler. Use `None` for default options.
Returns:
A list of lists with the first list corresponding to the compile op and the
second a list of output tensors, indexed by `[replica_num][output_num]`.
Raises:
ValueError: If all replicas do not have equal numbers of input tensors.
ValueError: If the number of inputs per replica does not match
the number of formal parameters to `computation`.
ValueError: If the static `inputs` dimensions don't match with the values
given in `maximum_shapes`.
ValueError: If the structure of inputs per replica does not match
the structure of `maximum_shapes`.
"""
del name
inputs = [[]] if inputs is None else inputs
xla_options = xla_options or XLAOptions()
metadata_kwargs = {}
if device_assignment is not None:
# Turn the Numpy array into a flattened list so we can pass it as an
# operator attribute.
metadata_kwargs = {
"topology":
device_assignment.topology.serialized(),
"device_assignment":
device_assignment.core_assignment.flatten().tolist()
}
metadata_kwargs["num_cores_per_replica"] = (
device_assignment.num_cores_per_replica)
# This entry is used for enabling automatic outside compilation.
metadata_kwargs["allow_soft_placement"] = config.get_soft_device_placement()
if config.get_soft_device_placement():
logging.info("Automatic outside compilation is enabled. "
"Ops without XLA kernels will be automatically "
"placed on CPU.")
if not isinstance(inputs, list):
raise TypeError("tpu.replicate() inputs must be a list of lists/tuples, "
f"received {type(inputs)}")
if any(not isinstance(inp, (list, tuple)) for inp in inputs):
raise TypeError(
"tpu.replicate() inputs must be a list of lists/tuples, "
f"received types: {[type(inp) for inp in inputs]}")
num_replicas = len(inputs)
# No replicas? Nothing to do.
if num_replicas == 0:
return []
# Checks all replicas have the same structure.
for i in range(1, num_replicas):
nest.assert_same_structure(inputs[0], inputs[i])
# Explicitly read variables.
inputs = variable_utils.convert_variables_to_tensors(inputs)
# Flatten inputs. This structure may contain None values, which will be
# handled later.
flat_inputs_with_nones = [
nest.flatten(per_replica_input, expand_composites=True)
for per_replica_input in inputs
]
# Mask parallel to one replica's inputs with True for tensors coming from
# composites.
is_composite = nest.flatten(nest.map_structure(
lambda x: _flatten_and_filter_composite(x, False, True), inputs[0]))
# Converts inputs to Tensors, replacing Nones with a placeholder 0 since
# tpu_ops.tpu_replicated_input() can't handle non-Tensor values.
flat_inputs = []
for inp in flat_inputs_with_nones:
flat_inputs.append([
constant_op.constant(0) if x is None else ops.convert_to_tensor(x)
for x in inp
])
# Verifies that all replicas have matching numbers and types of inputs
flat_input_types = [x.dtype for x in flat_inputs[0]]
input_arity = len(inputs[0])
flat_input_arity = len(flat_input_types)
for i in range(num_replicas):
if len(inputs[i]) != input_arity:
raise ValueError("Replicas must have the same number of inputs. "
"Replica 0 had {} inputs, replica {} had {} "
"inputs.".format(input_arity, i, len(inputs[i])))
types = [x.dtype for x in flat_inputs[i]]
if types != flat_input_types:
raise ValueError("Replicas must have matching input types. Replica 0 had "
"input types {}, replica {} had input types {}".format(
flat_input_types, i, types))
arg_error = xla.check_function_argument_count(
computation, input_arity, infeed_queue)
if arg_error is not None:
if infeed_queue is None:
raise TypeError(
"Supplied computation cannot be called with the specified inputs. "
f"You specified {input_arity} inputs: {[i.name for i in inputs[0]]}, "
f"but the computation needs {arg_error}")
else:
      raise TypeError(
          "Supplied computation cannot be called with the specified inputs. "
          f"You specified {input_arity} inputs: {[i.name for i in inputs[0]]} "
f"and {infeed_queue.number_of_tuple_elements} additional inputs "
f"from infeed, but the computation needs {arg_error}")
dynamic_shape_inputs = False
if maximum_shapes:
if infeed_queue:
raise ValueError(
"Dynamic input shapes are not supported with infeed queues")
# Make sure maximum_shapes has the same structure as inputs.
nest.assert_same_structure(inputs[0], maximum_shapes, check_types=False)
# Flatten padded shapes:
# For composite tensor components, we don't want to pad them. For each
# entry of maximum_shapes that corresponds to a composite tensor, replace it
# by a tuple of Nones of the same length as the number of components of the
# composite tensor. When we flatten a second time, this makes
# flat_maximum_shapes have the same length as flat_inputs[i]. We can then
# avoid padding these tensors. The assumption is that they will be used by
# outside compilation or that the components are statically shaped and will
# be used by tpu compatible ops.
flat_maximum_shapes = nest.flatten(
[_flatten_and_filter_composite(x, y)
for x, y in zip(nest.flatten(inputs[0]),
nest.flatten(maximum_shapes))])
flat_maximum_shapes = [
tensor_shape.TensorShape(s) if s is not None else None
for s in flat_maximum_shapes
]
nest.assert_same_structure(flat_inputs[0], flat_maximum_shapes,
check_types=False)
unpadded_inputs = flat_inputs
flat_inputs, padding_maps = _pad_all_input(unpadded_inputs,
flat_maximum_shapes,
padding_spec)
if padding_maps:
dynamic_shape_inputs = True
logging.info("TPU has inputs with dynamic shapes: %s", inputs[0])
metadata_kwargs["step_marker_location"] = getattr(
computation, "step_marker_location", "STEP_MARK_AT_ENTRY")
metadata_kwargs["use_spmd_for_xla_partitioning"] = \
xla_options.use_spmd_for_xla_partitioning
graph = ops.get_default_graph()
# Fan-in: Builds a TPUReplicatedInput node for each input.
flat_replicated_inputs = []
for i in range(0, len(flat_inputs[0])):
replicas = [flat_inputs[replica][i] for replica in range(num_replicas)]
flat_replicated_inputs.append(
tpu_ops.tpu_replicated_input(
replicas, name="input{}".format(i)))
if isinstance(graph, func_graph.FuncGraph):
# When we are in Tensorflow 2.0 function, 'graph' will be a FuncGraph
# object. If both outside graph and this function have a TPU cluster,
# they will have the same cluster name and it will cause problems (because
# we lower functional ops in Tensorflow 2.0). Append function name to
# 'cluster_name' to avoid cluster name collision.
cluster_name = graph.unique_name("cluster_" + graph.name)
else:
cluster_name = graph.unique_name("cluster")
pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
pivot._set_attr(_PIVOT_FOR_CLUSTER, # pylint: disable=protected-access
attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name)))
context = tpu_replication.TPUReplicateContext(
name=cluster_name, num_replicas=num_replicas, pivot=pivot)
try:
context.Enter()
metadata = tpu_ops.tpu_replicate_metadata(
num_replicas=num_replicas, use_tpu=use_tpu, **metadata_kwargs)
with tpu_function.tpu_shard_context(
num_replicas), ops.control_dependencies([metadata]):
if dynamic_shape_inputs and xla_options.enable_xla_dynamic_padder:
for padding_map in padding_maps:
input_shape = flat_replicated_inputs[padding_map.arg_index].shape
flat_replicated_inputs[
padding_map.arg_index] = tf2xla.set_dynamic_dimension_size(
flat_replicated_inputs[padding_map.arg_index],
padding_map.shape_index,
flat_replicated_inputs[padding_map.padding_arg_index])
flat_replicated_inputs[padding_map.arg_index].set_shape(input_shape)
# Add identity ops so even unused inputs are "consumed" by the
# computation. This is to avoid orphaned TPUReplicatedInput nodes.
# TODO(phawkins): consider instead pruning unused TPUReplicatedInput
# and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
flat_replicated_inputs = [
array_ops.identity(x, name="replicated_input_{}".format(i))
for i, x in enumerate(flat_replicated_inputs)
]
for i, composite in zip(flat_replicated_inputs, is_composite):
# pylint: disable=protected-access
        # Add an attribute to the identity node so that it can be removed in the
        # encapsulate TPU computation pass if unused. However we don't remove
# inputs when dynamic padding is enabled.
# TODO(rxsang): Use other ways except argument index in padding_map so
# outside compilation can work with dynamic padding correctly.
if not dynamic_shape_inputs or composite:
i.op._set_attr("_tpu_input_identity",
attr_value_pb2.AttrValue(b=True))
# pylint: enable=protected-access
# Clobber replicated placeholders with Nones.
computation_inputs = [
None if inp is None else replicated for replicated, inp in zip(
flat_replicated_inputs, flat_inputs_with_nones[0])
]
# Unflatten the computation inputs to match original input structure.
computation_inputs = nest.pack_sequence_as(
structure=inputs[0],
flat_sequence=computation_inputs[:flat_input_arity],
expand_composites=True)
# If there is an infeed queue, adds the dequeued values to the
# computation's inputs.
if infeed_queue is not None:
infeed_queue.set_number_of_shards(num_replicas)
for t in infeed_queue.generate_dequeue_op():
computation_inputs.append(t)
# Only resource variables work inside a TPU computation, so turn on
# resource variables for the computation.
# TODO(phawkins): consider removing this code. It will
# be less confusing to clients if they knowingly choose to use resource
# variables.
# Partitioned variables is not supported (b/112311320).
vscope = variable_scope.get_variable_scope()
saved_use_resource = vscope.use_resource
saved_custom_getter = vscope.custom_getter
def custom_getter(getter, name, *args, **kwargs):
"""Variables on TPU have a few restrictions."""
partitioner = kwargs.get("partitioner", None)
if partitioner is not None:
kwargs["partitioner"] = None
logging.warning(
"Partitioned variables are not supported on TPU. Got "
"`partitioner` that is %s for variable %s. "
"Setting `partitioner` to `None`.", partitioner, name)
if saved_custom_getter is None:
return getter(name, *args, **kwargs)
else:
return saved_custom_getter(getter, name, *args, **kwargs)
vscope.set_use_resource(True)
vscope.set_custom_getter(custom_getter)
outputs = computation(*computation_inputs)
vscope.set_use_resource(saved_use_resource)
vscope.set_custom_getter(saved_custom_getter)
outputs = variable_utils.convert_variables_to_tensors(outputs)
need_spmd_partitioning = (
xla_options.use_spmd_for_xla_partitioning and
device_assignment is not None and
device_assignment.num_cores_per_replica > 1)
outputs_is_flat = xla.is_flat(outputs)
if outputs_is_flat:
output_tensors, control_deps, pack_template = _postprocess_flat_outputs(
outputs, need_spmd_partitioning)
else:
output_tensors, control_deps, pack_template = (
_postprocess_non_flat_outputs(outputs, need_spmd_partitioning))
if tensor_tracer.TensorTracer.is_enabled():
if tf2.enabled():
logging.warn("TF API ver >= 2.0 detected. "
"Tensor Tracer v1 is not enabled.")
else:
tt = tensor_tracer.TensorTracer()
output_tensors = tt.trace_tpu(ops.get_default_graph(),
output_tensors, control_deps,
num_replicas)
context.ExitResult(output_tensors)
finally:
context.report_unsupported_operations()
context.Exit()
host_compute_core = context.HostComputeCore()
if host_compute_core:
attr_value = attr_value_pb2.AttrValue()
attr_value.list.s.extend(compat.as_bytes(x) for x in host_compute_core)
metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access
with ops.control_dependencies([metadata]):
if use_tpu:
compile_status = tpu_ops.tpu_compilation_result()
op = compile_status.op
attr_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster_name))
op._set_attr(_TPU_COMPILATION_STATUS_ATTR, attr_value) # pylint: disable=protected-access
else:
compile_status = control_flow_ops.no_op(name="compilation_status")
if not output_tensors:
# Returns a list of NoOps dependent on the replication Op, indexed by
# [replica_num].
return [
compile_status,
[
control_flow_ops.group(control_deps, name="shard_%d" % i)
for i in range(num_replicas)
]
]
# Fan-out: Builds a TPUReplicatedOutput node for each output.
replicated_outputs = [[] for i in range(num_replicas)]
for i, t in enumerate(output_tensors):
    # None values returned by the computation can't be sent to
    # tpu_ops.tpu_replicated_output(), so we handle them specially here. We can
# avoid the placeholder 0 routine required on the inputs since outputs are
# replicated per-tensor, not per-replica, so we can skip replication.
if t is None:
for replica in range(num_replicas):
replicated_outputs[replica].append(None)
continue
# Fan-out: Builds a TPUReplicatedOutput node for each output.
ys = tpu_ops.tpu_replicated_output(
t, num_replicas, name="output{}".format(i))
# Wraps the outputs in identity operators so the names of any possible
# `fetch` nodes are preserved by the replication rewrite.
with ops.control_dependencies(control_deps):
for replica in range(num_replicas):
replicated_outputs[replica].append(
array_ops.identity(
ys[replica], name="output_%d_shard_%d" % (i, replica)))
replicated_outputs = [
nest.pack_sequence_as(pack_template, replica_outs, expand_composites=True)
for replica_outs in replicated_outputs
]
return [compile_status, replicated_outputs]
def _postprocess_flat_outputs(
outputs: Any,
need_spmd_partitioning: bool
) -> Tuple[List[Optional[core_types.Tensor]], List[ops.Operation], List[Any]]:
  """Validates flat outputs, adds back device assignments and other attrs.
Args:
outputs: Output from `computation` inside `tpu.rewrite`.
need_spmd_partitioning: Whether XLA SPMD partitioning is needed.
Returns:
- Tensors extracted from outputs.
- Operations extracted from outputs.
- A pack template for use with nest.pack_sequence_as to pack the tensors.
"""
  # The following code segment preserves legacy behavior. Previously we only
  # supported flat outputs and thus, for consistency, it was nice to convert
  # even a single element into a tuple. But now that we support arbitrary
  # output structure, this is no longer necessary.
# TODO(b/121383831): Migrate all legacy use cases and delete this special
# case.
# If the computation returns `None`, make it an empty tuple.
if outputs is None:
outputs = tuple()
# For legacy / backwards compatibility reasons we return a list for "flat"
# output values (even if the user's flat return value was a different type or
# even just a scalar value) so use nest.flatten to compute a flat list pack
# template.
pack_template = nest.flatten(outputs, expand_composites=False)
# Even though outputs is already "flat", we flatten any composites so their
# component tensors can be tagged and replicated. The pack_template will be
# used by the caller to repack the composite tensors.
outputs = nest.flatten(outputs, expand_composites=True)
  # Append `no_op` here so that fetching any return value of this function
  # will trigger the TPUExecute node.
outputs += (control_flow_ops.no_op(),)
maybe_convert = lambda x: None if x is None else ops.convert_to_tensor(x)
try:
if need_spmd_partitioning: