# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU specific APIs to be used in conjunction with TPU Strategy."""
import gc

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import topology
from tensorflow.python.tpu import tpu
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


_INITIALIZED_TPU_SYSTEMS = {}
_LOCAL_MASTERS = ("", "local")


@tf_export("tpu.experimental.initialize_tpu_system")
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
      which provides information about the TPU cluster.

  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster. If called
    inside tf.function, it returns the serialized topology object instead.

  Raises:
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices found in eager mode.
  """
  # Deallocate all TPU buffers by clearing out eager context caches and
  # triggering garbage collection, to avoid keeping invalid TPU buffers around
  # after the TPU system is reinitialized.
  logging.info("Deallocate tpu buffers before initializing tpu system.")
  context.context()._clear_caches()  # pylint: disable=protected-access
  context.context().clear_kernel_cache()
  gc.collect()

  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified, and running eagerly, execute the
    # init ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(
          context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "TPU system %s has already been initialized. "
        "Reinitializing the TPU can cause previously created "
        "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system: %s", tpu_name)

  # This function is structured the way it is for the following non-intuitive
  # reasons: tpu.initialize_system creates a dummy op whose sole purpose is to
  # trigger DistributedTPURewritePass. This pass actually adds real ops that
  # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
  # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
  if tpu_name not in _LOCAL_MASTERS:
    # Explicitly place the tpu.initialize_system in the first worker to
    # avoid the "output node matches multiple devices" error.
    job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

  if context.executing_eagerly():

    @function.defun
    def _tpu_init_fn():
      # In TF1, we usually close chips when compilation fails to clear the
      # data in infeed. In TF2, we don't need to do this because infeed is no
      # longer used, so users can recover from TPU compilation failures more
      # smoothly. The same applies to the cancellation of a TPU execution.
      return tpu.initialize_system(
          job=job,
          compilation_failure_closes_chips=False,
          tpu_cancellation_closes_chips=False)

    # The TPU_SYSTEM device must match the device used in tpu.initialize_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    try:
      with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
        output = _tpu_init_fn()
      context.async_wait()
    except errors.InvalidArgumentError as e:
      raise errors.NotFoundError(
          None, None,
          "TPUs not found in the cluster. Failed in initialization: "
          + str(e))

    # Refresh the eager context's logical devices, since the previously cached
    # device list is invalid after (re)initializing the TPU system.
    context.context()._initialize_logical_devices()  # pylint: disable=protected-access

    serialized_topology = output.numpy()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())
  else:
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      serialized_topology = tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)
      # If initialize_tpu_system is called inside tf.function, we only return
      # the serialized topology object, as the tf.tpu.Topology object has to
      # be constructed in eager mode.
      return serialized_topology

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  cluster_resolver.set_tpu_topology(serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology
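

# A minimal usage sketch for the API above, assuming eager mode and a
# reachable TPU worker (the TPU name "my-tpu" below is hypothetical):
#
#   resolver = TPUClusterResolver(tpu="my-tpu")
#   tpu_topology = initialize_tpu_system(resolver)
#   logging.info("TPU tasks: %d, cores per task: %d",
#                tpu_topology.num_tasks, tpu_topology.num_tpus_per_task)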


def get_initialized_tpu_systems():
  """Returns all currently initialized tpu systems.

  Returns:
    A dictionary, with tpu name as the key and the tpu topology as the value.
  """
  return _INITIALIZED_TPU_SYSTEMS.copy()
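

# For example, after one or more TPU systems have been initialized, the
# returned copy maps each TPU name to its tf.tpu.Topology (a sketch; the loop
# below only logs the recorded mesh shapes):
#
#   for name, topo in get_initialized_tpu_systems().items():
#     logging.info("TPU %s has mesh shape %s", name, topo.mesh_shape)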


@tf_export("tpu.experimental.shutdown_tpu_system")
def shutdown_tpu_system(cluster_resolver=None):
  """Shuts down the TPU devices.

  This will clear all caches, even those that are maintained through sequential
  calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
  cache.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
      which provides information about the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices found for eager execution or if run in a
      tf.function.
  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified, and running eagerly, execute the
    # shutdown ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(
          context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "You are shutting down a TPU system %s that has not been "
        "initialized.", tpu_name)

  logging.info("Shutting down the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function is structured the way it is for the following non-intuitive
    # reasons: tpu.shutdown_system creates a dummy op whose sole purpose is to
    # trigger DistributedTPURewritePass. This pass actually adds real ops that
    # shut down the TPU system. Thus, we can't simply run tpu.shutdown_system
    # eagerly. We need to wrap it in defun and trigger the rewrite passes on
    # it.
    if tpu_name not in _LOCAL_MASTERS:
      # Explicitly place the tpu.shutdown_system in the first worker to
      # avoid the "output node matches multiple devices" error.
      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    @function.defun
    def _tpu_shutdown_fn():
      tpu.shutdown_system(job=job)

    # The TPU_SYSTEM device must match the device used in tpu.shutdown_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      _tpu_shutdown_fn()

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access
    context.context().clear_kernel_cache()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        sess.run(tpu.shutdown_system())
  else:
    raise RuntimeError(
        "shutdown_tpu_system is not supported within tf.functions. You "
        "should call shutdown_tpu_system outside of your tf.function.")

  logging.info("Finished shutting down TPU system.")
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    del _INITIALIZED_TPU_SYSTEMS[tpu_name]
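

# A minimal sketch of a full TPU reset in eager mode, assuming the same
# resolver that was used for initialization (here the local/default resolver):
#
#   resolver = TPUClusterResolver("")
#   shutdown_tpu_system(resolver)
#   initialize_tpu_system(resolver)  # Previously created TPU variables are lost.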