# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data structures and utilities for checkpoint sharding."""
import abc
import dataclasses
import inspect
from typing import Hashable, MutableMapping, Sequence
from tensorflow.python.framework import device as device_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import variables
from tensorflow.python.trackable import base
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.util import tf_export
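
# Type aliases for the shard structure a `ShardingCallback` returns: each
# shard maps a checkpoint key to a mapping from slice spec to tensor.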
TensorSlice = MutableMapping[tensor_spec.TensorSpec, tensor_lib.Tensor]
TensorSliceDict = MutableMapping[str, TensorSlice]
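

# `ShardableTensor` is the wrapper passed to sharding callbacks; it is a
# frozen dataclass, so callbacks cannot mutate the tensors they receive.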
@tf_export.tf_export("train.experimental.ShardableTensor")
@dataclasses.dataclass(frozen=True)
class ShardableTensor:
"""Tensor wrapper containing data necessary for sharding.
The tensor representation used as inputs to pre-made and custom
`tf.train.experiemental.ShardingCallback`s, which can be specified using the
`experimental_sharding_callback` option in `tf.train.CheckpointOptions`.
"""
_tensor_save_spec: saveable_object.SaveSpec
tensor: tensor_lib.Tensor
dtype: dtypes.DType
device: device_lib.DeviceSpec
name: str
shape: tensor_shape.TensorShape
slice_spec: variables.Variable.SaveSliceInfo
checkpoint_key: str
trackable: base.Trackable
def __hash__(self) -> int:
return hash((self.name, self.dtype, str(self.device), self.checkpoint_key))
def __repr__(self) -> str:
return (f"\n{self.__class__.__name__}:\n"
f" _tensor_save_spec={self._tensor_save_spec!r}\n"
f" tensor={self.tensor!r}\n"
f" dtype={self.dtype!r}\n"
f" device={self.device!r}\n"
f" name={self.name!r}\n"
f" shape={self.shape!r}\n"
f" slice_spec={self.slice_spec!r}\n"
f" checkpoint_key={self.checkpoint_key!r}\n"
f" trackable={self.trackable!r}")
@tf_export.tf_export("train.experimental.ShardingCallback")
class ShardingCallback(abc.ABC):
"""Checkpoint sharding callback function, along with a text description.
A callback function wrapper that will be executed to determine how tensors
will be split into shards when the saver writes the checkpoint shards to disk.
The callback takes a list of `tf.train.experimental.ShardableTensor`s as input
(as well as any kwargs defined by the `tf.train.experimental.ShardingCallback`
subclass), and organizes the input tensors into different shards. Tensors are
first organized by device task (see `tf.DeviceSpec`), then the callback will
be called for each collection of tensors.
There are a few restrictions to keep in mind when creating a custom callback:
- Tensors must not be removed from the checkpoint.
- Tensors must not be reshaped.
- Tensor dtypes must not change.
- Tensors within a shard must belong to the same task.
Validation checks will be performed after the callback function is executed to
ensure these restrictions aren't violated.
Here's an example of a simple custom callback:
```
# Place all tensors in a single shard.
class AllInOnePolicy(tf.train.experimental.ShardingCallback):
@property
def description(self):
return "Place all tensors in a single shard."
def __call__(self, shardable_tensors):
tensors = {}
for shardable_tensor in shardable_tensors:
tensor = shardable_tensor.tensor_save_spec.tensor
checkpoint_key = shardable_tensor.checkpoint_key
slice_spec = shardable_tensor.slice_spec
tensors.set_default(checkpoint_key, {})[slice_spec] = tensor
return [tensors]
ckpt.save(
"path",
options=tf.train.CheckpointOptions(
experimental_sharding_callback=AllInOnePolicy()))
```
The `description` attribute is used to identify the callback and to aid
debugging during saving and restoration.
To take in kwargs, simply define the constructor and pass them in:
```
class ParameterPolicy(tf.train.experimental.ShardingCallback):
def __init__(self, custom_param):
self.custom_param = custom_param
...
ckpt.save(
"path",
options=tf.train.CheckpointOptions(
experimental_sharding_callback=ParameterPolicy(custom_param=...)))
```
"""
description: str
@property
@abc.abstractmethod
def description(self) -> str:
pass
@abc.abstractmethod
def __call__(
self, shardable_tensors: Sequence[ShardableTensor]
) -> Sequence[TensorSliceDict]:
pass
def __hash__(self) -> int:
hash_val = hash(self.description)
# vars() only includes user-defined attributes.
for attr_name, attr_val in vars(self).items():
if not (inspect.ismethod(attr_val) or inspect.isfunction(attr_val)):
hash_val ^= hash(attr_name)
if isinstance(attr_val, Hashable):
hash_val ^= hash(attr_val)
return hash_val
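

# `validate_shards` runs after the sharding callback executes and enforces the
# restrictions listed in the `ShardingCallback` docstring: no tensors removed,
# reshaped, retyped, or moved to a different task, and no duplicate tensors
# within or across shards.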
def validate_shards(
    shards: Sequence[TensorSliceDict],
    shardable_tensors: Sequence[ShardableTensor],
    callback_description: str
) -> None:
  """Validates the shards generated by the sharding callback.

  Checks that the callback did not remove, reshape, or retype any tensor, did
  not add tensors that were not in the object graph, and did not place tensors
  from different tasks in the same shard.
  """
  unseen_tensor_dict = {}
  for shardable_tensor in shardable_tensors:
    unseen_tensor_dict.setdefault(
        shardable_tensor.checkpoint_key, {}
    )[shardable_tensor.slice_spec] = shardable_tensor.tensor
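  # `seen_tensor_set` records pairs already encountered in the shards so that
  # duplicates can be reported.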
  seen_tensor_set = set()

  for shard_tensors in shards:
    task_tensor = None
    for checkpoint_key, tensor_slice_dict in shard_tensors.items():
      for slice_spec, shard_tensor in tensor_slice_dict.items():
        slice_spec = slice_spec.strip()

        # Validate uniqueness.
        if (checkpoint_key, slice_spec) in seen_tensor_set:
          raise RuntimeError(
              "After executing the checkpoint sharding callback, multiple "
              "tensors with the same checkpoint key and slice spec were "
              "found:\n"
              f" callback_description: {callback_description}\n"
              f" checkpoint_key: {checkpoint_key}\n"
              f" slice_spec: {slice_spec}\n")

        # Validate no added tensors.
        if checkpoint_key not in unseen_tensor_dict:
          raise RuntimeError(
              "After executing the checkpoint sharding callback, a tensor "
              "not originally in the object graph was found in the "
              "checkpoint shards:\n"
              f" callback_description: {callback_description}\n"
              f" checkpoint_key: {checkpoint_key}\n"
              f" slice_spec: {slice_spec}\n")

        # Validate no shape change.
        target_shape = unseen_tensor_dict[checkpoint_key][slice_spec].shape
        if shard_tensor.shape != target_shape:
          raise RuntimeError(
              "After executing the checkpoint sharding callback, a tensor "
              "was found with an altered shape:\n"
              f" callback_description: {callback_description}\n"
              f" checkpoint_key: {checkpoint_key}\n"
              f" slice_spec: {slice_spec}\n"
              f" original tensor_shape: {target_shape}\n"
              f" new tensor_shape: {shard_tensor.shape}\n")

        # Validate no dtype change.
        target_dtype = unseen_tensor_dict[checkpoint_key][slice_spec].dtype
        if shard_tensor.dtype != target_dtype:
          raise RuntimeError(
              "After executing the checkpoint sharding callback, a tensor "
              "was found with an altered dtype:\n"
              f" callback_description: {callback_description}\n"
              f" checkpoint_key: {checkpoint_key}\n"
              f" slice_spec: {slice_spec}\n"
              f" original tensor_dtype: {target_dtype}\n"
              f" new tensor_dtype: {shard_tensor.dtype}\n")

        # Validate no task change.
        target_task = device_lib.DeviceSpec.from_string(
            unseen_tensor_dict[checkpoint_key][slice_spec].device).task
        shard_tensor_task = device_lib.DeviceSpec.from_string(
            shard_tensor.device).task
        if shard_tensor_task != target_task:
          raise RuntimeError(
              "After executing the checkpoint sharding callback, a tensor "
              "was found with an altered task:\n"
              f" callback_description: {callback_description}\n"
              f" checkpoint_key: {checkpoint_key}\n"
              f" slice_spec: {slice_spec}\n"
              f" original tensor_task: {target_task}\n"
              f" new tensor_task: {shard_tensor_task}\n")

        # Validate tensors in shard have the same task.
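        # Record the first tensor seen in this shard; its device is used as
        # the reference when comparing the tasks of the remaining tensors.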
        if task_tensor is None:
          task_tensor = ShardableTensor(
              _tensor_save_spec=None,
              tensor=None,
              dtype=None,
              device=shard_tensor.device,
              name=None,
              shape=None,
              slice_spec=slice_spec,
              checkpoint_key=checkpoint_key,
              trackable=None)
        else:
          task1 = device_lib.DeviceSpec.from_string(task_tensor.device).task
          task2 = device_lib.DeviceSpec.from_string(shard_tensor.device).task
          if task1 is not None and task2 is not None and task1 != task2:
            raise RuntimeError(
                "After executing the checkpoint sharding callback, tensors "
                "with different tasks were found in the same shard:\n"
                f" callback_description: {callback_description}\n"
                " tensor #1:\n"
                f" checkpoint_key: {task_tensor.checkpoint_key}\n"
                f" slice_spec: {task_tensor.slice_spec}\n"
                f" task: {task1}\n"
                " tensor #2:\n"
                f" checkpoint_key: {checkpoint_key}\n"
                f" slice_spec: {slice_spec}\n"
                f" task: {task2}\n")

        del unseen_tensor_dict[checkpoint_key][slice_spec]
        if not unseen_tensor_dict[checkpoint_key]:
          del unseen_tensor_dict[checkpoint_key]
        seen_tensor_set.add((checkpoint_key, slice_spec))
  # Validate no tensor removal.
  if unseen_tensor_dict:
    tensors_info = ""
    for ckpt_key, tensor_slice_dict in unseen_tensor_dict.items():
      for slice_spec in tensor_slice_dict:
        tensors_info += " tensor:\n"
        tensors_info += f" checkpoint_key: {ckpt_key}\n"
        tensors_info += f" slice_spec: {slice_spec}\n"
    raise RuntimeError(
        "After executing the checkpoint sharding callback, tensors in the "
        "object graph were not found in the checkpoint shards:\n"
        f" callback_description: {callback_description}\n"
        f"{tensors_info}")