"""Internal code for saving and loading :py:class:`tf.data.Dataset` pipelines."""
import os
from typing import Callable, Optional
import tensorflow as tf
import scalarstop.pickle
from scalarstop._constants import _ELEMENT_SPEC_FILENAME, _TFDATA_DIRECTORY_NAME
from scalarstop._cpu import num_usable_virtual_cpu_cores
from scalarstop.exceptions import (
DataBlobShardingNotSupported,
DataBlobShardingValueError,
ElementSpecNotFound,
TensorFlowDatasetNotFound,
UnsupportedDataBlobSaveLoadVersion,
)
from scalarstop.warnings import warn_deprecated
_ENUMERATE_IDX_ELEMENT_SPEC = tf.TensorSpec(shape=(), dtype=tf.int64, name=None)


def _undo_enumerate(_idx, row):
    """
    A :py:meth:`tf.data.Dataset.map` function for reversing the
    :py:meth:`tf.data.Dataset.enumerate` transformation.

    This function takes an element of shape ``(idx, row)`` from an
    enumerated :py:class:`tf.data.Dataset` and returns just the ``row``.
    """
    return row
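
# For illustration: ``dataset.enumerate().map(_undo_enumerate)`` yields
# the same elements as ``dataset`` itself. A minimal sketch:
#
#     ds = tf.data.Dataset.range(3)
#     restored = ds.enumerate().map(_undo_enumerate)
#     list(restored.as_numpy_iterator())  # [0, 1, 2]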


def make_num_shards_on_save(num_shards: int) -> Callable:
    """
    Generates a sharding function for :py:func:`tf.data.experimental.save`.

    Args:
        num_shards: The number of distinct files to save to the filesystem.

    Returns:
        Returns a function that accepts an element from an *enumerated*
        :py:class:`tf.data.Dataset` and returns the enumerated
        index modulo ``num_shards``.
    """

    def shard(idx, _row):
        return idx % num_shards

    return shard
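
# For illustration: with ``num_shards=3``, enumerated indices
# 0, 1, 2, 3, 4, 5 land in shards 0, 1, 2, 0, 1, 2 (a simple
# round-robin split). A minimal sketch:
#
#     shard_func = make_num_shards_on_save(num_shards=3)
#     shard_func(tf.constant(4, dtype=tf.int64), None)  # -> 1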


def _load_v1(
    *,
    tfdata_path: str,
    element_spec,
    shard_offset: Optional[int],
    shard_quantity: int,
    total_num_shards: int,
) -> tf.data.Dataset:
    """
    tfdata loader v1.

    This version does not support sharding. To shard your
    :py:class:`DataBlob`, use version 3 or newer.
    """
if shard_offset is not None or total_num_shards != 1:
raise DataBlobShardingNotSupported(
version=1,
offset=shard_offset,
quantity=shard_quantity,
total_num_shards=total_num_shards,
)
return tf.data.experimental.load(
path=tfdata_path,
element_spec=element_spec,
)


def _select_these_shards_v2(
    offset: Optional[int], quantity: int, total_num_shards: int
):
    """
    Generates a sharding function for :py:func:`tf.data.experimental.load`.

    This function does not set the correct ``cycle_length`` when
    interleaving datasets. Please use :py:func:`_select_these_shards_v3`
    or a higher version.

    Args:
        offset: The first shard index to select from the filesystem. If
            this value is ``None``, then we will load all shards.
        quantity: The number of consecutive shards to load,
            starting from (and including) the shard located
            at ``offset``. This argument has no effect
            if ``offset`` is ``None``.
        total_num_shards: The total number of shards in the
            saved :py:class:`tf.data.Dataset`.

    Returns:
        Returns a function that selects the requested shards and
        interleaves them into a single :py:class:`tf.data.Dataset`.
    """
if total_num_shards < 1:
raise DataBlobShardingValueError(
f"`{total_num_shards=}` cannot be less than 1."
)
if offset is not None:
if offset >= total_num_shards:
raise DataBlobShardingValueError(
f"{offset=} is a shard index that cannot be >= {total_num_shards=}."
)
if quantity > total_num_shards:
raise DataBlobShardingValueError(
f"{quantity=} cannot be greater than > {total_num_shards=}."
)
offset_quantity_sum = offset + quantity
if offset_quantity_sum > total_num_shards:
raise DataBlobShardingValueError(
f"The sum of {offset=} and {quantity=} ({offset_quantity_sum}) "
f"cannot be greater than {total_num_shards=}."
)

    def select(datasets):
        if offset is not None:
            retval = datasets.skip(offset).take(quantity)
        else:
            retval = datasets
        return retval.interleave(lambda x: x)

    return select
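
# For illustration: with ``offset=2`` and ``quantity=3``, the
# ``skip(2).take(3)`` window above selects shards 2, 3, and 4,
# i.e. the half-open range [offset, offset + quantity).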


def _load_v2(
    *,
    tfdata_path: str,
    element_spec,
    shard_offset: Optional[int],
    shard_quantity: int,
    total_num_shards: int,
) -> tf.data.Dataset:
    """
    tfdata loader v2.

    This version is DEPRECATED because it does not return
    elements in order when reading from multiple shards at once.
    Only use this version if you have existing code and
    trained models that depend on this function's broken
    behavior.
    """
    warn_deprecated(
        "You are loading a DataBlob with ScalarStop Load/Save version 2. "
        "This version is DEPRECATED because it does not load elements "
        "in order from the saved DataBlob when attempting to load "
        "multiple shards at once. Please recreate your saved DataBlobs "
        "to migrate to Load/Save version 3."
    )
# The v2 `save()` function calls `enumerate()` on the dataset
# before saving. This changes the `element_spec`, and we have
# to account for it here.
element_spec_after_enumerate = (_ENUMERATE_IDX_ELEMENT_SPEC, element_spec)
return tf.data.experimental.load(
path=tfdata_path,
element_spec=element_spec_after_enumerate,
reader_func=_select_these_shards_v2(
offset=shard_offset,
quantity=shard_quantity,
total_num_shards=total_num_shards,
),
).map(_undo_enumerate)


def _select_these_shards_v3(
    offset: Optional[int], quantity: int, total_num_shards: int
):
    """
    Generates a sharding function for :py:func:`tf.data.experimental.load`.

    Args:
        offset: The first shard index to select from the filesystem. If
            this value is ``None``, then we will load all shards.
        quantity: The number of consecutive shards to load,
            starting from (and including) the shard located
            at ``offset``. This argument has no effect
            if ``offset`` is ``None``.
        total_num_shards: The total number of shards in the
            saved :py:class:`tf.data.Dataset`.

    Returns:
        Returns a function that selects the requested shards and
        interleaves them into a single :py:class:`tf.data.Dataset`.
    """
if total_num_shards < 1:
raise DataBlobShardingValueError(
f"`{total_num_shards=}` cannot be less than 1."
)
if offset is not None:
if offset >= total_num_shards:
raise DataBlobShardingValueError(
f"{offset=} is a shard index that cannot be >= {total_num_shards=}."
)
if quantity > total_num_shards:
raise DataBlobShardingValueError(
f"{quantity=} cannot be greater than > {total_num_shards=}."
)
offset_quantity_sum = offset + quantity
if offset_quantity_sum > total_num_shards:
raise DataBlobShardingValueError(
f"The sum of {offset=} and {quantity=} ({offset_quantity_sum}) "
f"cannot be greater than {total_num_shards=}."
)
cycle_length = quantity
else:
cycle_length = total_num_shards
# In our tests, we found that setting num_parallel_calls to
# tf.data.experimental.AUTOTUNE makes loading datasets twice as slow.
# Instead, we'll set the number of parallel calls to the number
# of hyperthreaded CPU cores available to the current process on
# this machine--unless the number of shards that we are loading
# is less than the number of CPU cores.
num_cpus = num_usable_virtual_cpu_cores()
if num_cpus:
num_parallel_calls = min(cycle_length, num_cpus)
else:
# If we were unable to probe the number of virtual
# CPU cores on the current machine, then we'll let
# TensorFlow deal with the problem.
num_parallel_calls = tf.data.experimental.AUTOTUNE

    def select(datasets):
"""The actual TensorFlow function for selecting datasets."""
if offset is not None:
retval = datasets.skip(offset).take(quantity)
else:
retval = datasets
return retval.interleave(
lambda x: x,
cycle_length=cycle_length,
num_parallel_calls=num_parallel_calls,
deterministic=True,
)
return select
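
# For illustration: with two selected shards holding enumerated rows
# (0, 2, 4, ...) and (1, 3, 5, ...), a deterministic interleave with
# ``cycle_length=2`` pulls one element from each shard in turn,
# yielding 0, 1, 2, 3, 4, 5, ... in the original order.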


def _load_v3(
    *,
    tfdata_path: str,
    element_spec,
    shard_offset: Optional[int],
    shard_quantity: int,
    total_num_shards: int,
) -> tf.data.Dataset:
    """
    tfdata loader v3.

    This fixes an issue with tfdata loader v2, where rows from
    shards were returned in the incorrect order because we did not
    set the correct ``cycle_length`` when interleaving datasets.
    """
    # The v2 and v3 `save()` functions call `enumerate()` on the dataset
    # before saving. This changes the `element_spec`, and we have
    # to account for it here.
element_spec_after_enumerate = (_ENUMERATE_IDX_ELEMENT_SPEC, element_spec)
return tf.data.experimental.load(
path=tfdata_path,
element_spec=element_spec_after_enumerate,
reader_func=_select_these_shards_v3(
offset=shard_offset,
quantity=shard_quantity,
total_num_shards=total_num_shards,
),
).map(
_undo_enumerate,
num_parallel_calls=tf.data.experimental.AUTOTUNE,
deterministic=True,
)


def _load(
    *,
    tfdata_path: str,
    element_spec,
    shard_offset: Optional[int],
    shard_quantity: int,
    save_load_version: int,
    total_num_shards: int,
):
    """Load a :py:class:`tf.data.Dataset` with the requested save/load version."""
if save_load_version == 1:
return _load_v1(
tfdata_path=tfdata_path,
element_spec=element_spec,
shard_offset=shard_offset,
shard_quantity=shard_quantity,
total_num_shards=total_num_shards,
)
if save_load_version == 2:
return _load_v2(
tfdata_path=tfdata_path,
element_spec=element_spec,
shard_offset=shard_offset,
shard_quantity=shard_quantity,
total_num_shards=total_num_shards,
)
if save_load_version == 3:
return _load_v3(
tfdata_path=tfdata_path,
element_spec=element_spec,
shard_offset=shard_offset,
shard_quantity=shard_quantity,
total_num_shards=total_num_shards,
)
raise UnsupportedDataBlobSaveLoadVersion(
version=save_load_version,
)


def tfdata_load(
    *,
    path: str,
    save_load_version: int,
    total_num_shards: int = 1,
    element_spec=None,
    shard_offset: Optional[int] = None,
    shard_quantity: int = 1,
) -> tf.data.Dataset:
    """
    Load a :py:class:`tf.data.Dataset` from a filesystem path.

    This is a little different from
    :py:func:`tf.data.experimental.load` because we save the
    ``element_spec`` in a pickled file one level above the
    :py:class:`tf.data.Dataset` directory.

    If you want to read a dataset that doesn't have the
    ``element_spec`` saved on disk, then pass your own value
    via the ``element_spec`` keyword argument.
    """
# Load the element spec.
if element_spec is None:
element_spec_path = os.path.join(path, _ELEMENT_SPEC_FILENAME)
try:
with open(element_spec_path, "rb") as fp:
element_spec = scalarstop.pickle.load(file=fp)
except FileNotFoundError as exc:
raise ElementSpecNotFound(path) from exc
# Load the tf.data Dataset.
tfdata_path = os.path.join(path, _TFDATA_DIRECTORY_NAME)
try:
loaded_tfdata = _load(
tfdata_path=tfdata_path,
element_spec=element_spec,
shard_offset=shard_offset,
shard_quantity=shard_quantity,
save_load_version=save_load_version,
total_num_shards=total_num_shards,
)
except tf.errors.NotFoundError as exc:
raise TensorFlowDatasetNotFound(tfdata_path) from exc
    # Tell TensorFlow to automatically shard by data rather than by
    # filename. Our shard files are not laid out in a way that is
    # useful to TensorFlow if the user wants to shard the dataset
    # again later.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.DATA
)
loaded_tfdata = loaded_tfdata.with_options(options)
return loaded_tfdata
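
# For illustration: loading only the second of four saved shards
# (a sketch, assuming the DataBlob was saved with ``num_shards=4``
# and ``save_load_version=3``; the path is hypothetical):
#
#     shard = tfdata_load(
#         path="/tmp/my_datablob",
#         save_load_version=3,
#         total_num_shards=4,
#         shard_offset=1,
#         shard_quantity=1,
#     )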


def _save_v1(
*,
dataset: tf.data.Dataset,
tfdata_path: str,
num_shards: int,
):
"""tfdata saver v1."""
if num_shards != 1:
raise DataBlobShardingValueError(
"The ScalarStop DataBlob Persistence Protocol v1 only supports "
f"num_shards=1. You passed {num_shards=}. Try saving with a "
"higher protocol version."
)
return tf.data.experimental.save(
dataset=dataset,
path=tfdata_path,
compression=None,
)


def _save_v2(
*,
dataset: tf.data.Dataset,
tfdata_path: str,
num_shards: int,
):
"""tfdata saver v2."""
return tf.data.experimental.save(
dataset=dataset.enumerate(),
path=tfdata_path,
shard_func=make_num_shards_on_save(num_shards=num_shards),
compression=None,
)


def _save_v3(
    *,
    dataset: tf.data.Dataset,
    tfdata_path: str,
    num_shards: int,
):
    """
    tfdata saver v3.

    This implementation is identical to tfdata saver v2
    because the backwards-compatible changes are on the
    loading side.
    """
    return _save_v2(dataset=dataset, tfdata_path=tfdata_path, num_shards=num_shards)


def _save(
dataset: tf.data.Dataset,
tfdata_path: str,
num_shards: int,
save_load_version: int,
):
"""tfdata saver."""
if save_load_version == 1:
return _save_v1(
dataset=dataset,
tfdata_path=tfdata_path,
num_shards=num_shards,
)
if save_load_version == 2:
return _save_v2(
dataset=dataset,
tfdata_path=tfdata_path,
num_shards=num_shards,
)
if save_load_version == 3:
return _save_v3(
dataset=dataset,
tfdata_path=tfdata_path,
num_shards=num_shards,
)
raise UnsupportedDataBlobSaveLoadVersion(version=save_load_version)


def tfdata_save(
    *,
    dataset: tf.data.Dataset,
    path: str,
    num_shards: int,
    save_load_version: int,
):
    """Save a :py:class:`tf.data.Dataset` to a filesystem path."""
os.mkdir(path)
# Save the tf.data Dataset.
tfdata_path = os.path.join(path, _TFDATA_DIRECTORY_NAME)
_save(
dataset=dataset,
tfdata_path=tfdata_path,
num_shards=num_shards,
save_load_version=save_load_version,
)
# Save the element spec.
element_spec_path = os.path.join(path, _ELEMENT_SPEC_FILENAME)
with open(element_spec_path, "wb") as fh: # type: ignore
scalarstop.pickle.dump(
obj=dataset.element_spec,
file=fh,
)
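
# A minimal end-to-end sketch (hypothetical path and dataset):
#
#     ds = tf.data.Dataset.range(10)
#     tfdata_save(
#         dataset=ds,
#         path="/tmp/my_datablob",
#         num_shards=2,
#         save_load_version=3,
#     )
#     restored = tfdata_load(
#         path="/tmp/my_datablob",
#         save_load_version=3,
#         total_num_shards=2,
#     )
#     assert list(restored.as_numpy_iterator()) == list(range(10))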