# coding=utf-8
# Copyright 2024 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Adapters for file formats."""

from __future__ import annotations

import abc
from collections.abc import Iterable, Iterator
import enum
import itertools
import os
import re
from typing import Any, ClassVar, Type, TypeVar

from etils import epath
from tensorflow_datasets.core.utils import file_utils
from tensorflow_datasets.core.utils import type_utils
from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_module
from tensorflow_datasets.core.utils.lazy_imports_utils import parquet as pq
from tensorflow_datasets.core.utils.lazy_imports_utils import pyarrow as pa
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf

ExamplePositions = list[Any]
T = TypeVar('T')
class FileFormat(enum.Enum):
  """Format of the record files.

  The values of the enumeration are used as filename suffixes.
  """

  TFRECORD = 'tfrecord'
  RIEGELI = 'riegeli'
  ARRAY_RECORD = 'array_record'
  PARQUET = 'parquet'

  @property
  def file_suffix(self) -> str:
    return ADAPTER_FOR_FORMAT[self].FILE_SUFFIX

  @classmethod
  def with_random_access(cls) -> set[FileFormat]:
    """File formats that support random access."""
    return {
        file_format
        for file_format, adapter in ADAPTER_FOR_FORMAT.items()
        if adapter.SUPPORTS_RANDOM_ACCESS
    }

  @classmethod
  def with_tf_data(cls) -> set[FileFormat]:
    """File formats that support tf.data."""
    return {
        file_format
        for file_format, adapter in ADAPTER_FOR_FORMAT.items()
        if adapter.SUPPORTS_TF_DATA
    }

  @classmethod
  def with_suffix_before_shard_spec(cls) -> set[FileFormat]:
    """File formats whose suffix comes before the shard spec."""
    return {
        file_format
        for file_format, adapter in ADAPTER_FOR_FORMAT.items()
        if adapter.SUFFIX_BEFORE_SHARD_SPEC
    }

  @classmethod
  def with_suffix_after_shard_spec(cls) -> set[FileFormat]:
    """File formats whose suffix comes after the shard spec."""
    return {
        file_format
        for file_format, adapter in ADAPTER_FOR_FORMAT.items()
        if not adapter.SUFFIX_BEFORE_SHARD_SPEC
    }

  @classmethod
  def from_value(cls, file_format: str | FileFormat) -> FileFormat:
    """Returns the `FileFormat` for the given value (e.g., its string name)."""
    try:
      return cls(file_format)
    except ValueError as e:
      all_values = [f.value for f in cls]
      raise ValueError(
          f'{file_format} is not a valid FileFormat! '
          f'Valid file formats: {all_values}'
      ) from e


DEFAULT_FILE_FORMAT = FileFormat.TFRECORD
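
# Usage sketch (illustrative only, not part of this module): resolving a
# format from a string and querying its capabilities. The values shown in the
# comments follow from the definitions above.
#
#   fmt = FileFormat.from_value('tfrecord')  # -> FileFormat.TFRECORD
#   fmt.file_suffix                          # -> 'tfrecord'
#   fmt in FileFormat.with_tf_data()         # -> True
#   fmt in FileFormat.with_random_access()   # -> False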
class FileAdapter(abc.ABC):
  """Interface for Adapter objects which read and write examples in a given format."""

  FILE_SUFFIX: ClassVar[str]
  # Whether the file format suffix should go before the shard spec.
  # For example, `dataset-train.tfrecord-00000-of-00001` if `True`,
  # otherwise `dataset-train-00000-of-00001.tfrecord`.
  SUFFIX_BEFORE_SHARD_SPEC: ClassVar[bool] = True
  SUPPORTS_RANDOM_ACCESS: ClassVar[bool]
  SUPPORTS_TF_DATA: ClassVar[bool]
  BUFFER_SIZE = 8 << 20  # 8 MiB per file.

  @classmethod
  @abc.abstractmethod
  def make_tf_data(
      cls,
      filename: epath.PathLike,
      buffer_size: int | None = None,
  ) -> tf.data.Dataset:
    """Returns a TensorFlow Dataset comprising the given record file."""
    raise NotImplementedError()

  @classmethod
  @abc.abstractmethod
  def write_examples(
      cls,
      path: epath.PathLike,
      iterator: Iterable[type_utils.KeySerializedExample],
  ) -> ExamplePositions | None:
    """Writes examples from the given iterator to the given path.

    Args:
      path: Path where to write the examples.
      iterator: Iterable of examples.

    Returns:
      List of record positions for each record in the given iterator. In the
      case of TFRecords, returns None.
    """
    raise NotImplementedError()
class TfRecordFileAdapter(FileAdapter):
  """File adapter for the TFRecord file format."""

  FILE_SUFFIX = 'tfrecord'
  SUPPORTS_RANDOM_ACCESS = False
  SUPPORTS_TF_DATA = True

  @classmethod
  def make_tf_data(
      cls,
      filename: epath.PathLike,
      buffer_size: int | None = None,
  ) -> tf.data.Dataset:
    """Returns a TensorFlow Dataset comprising the given record file."""
    buffer_size = buffer_size or cls.BUFFER_SIZE
    return tf.data.TFRecordDataset(filename, buffer_size=buffer_size)

  @classmethod
  def write_examples(
      cls,
      path: epath.PathLike,
      iterator: Iterable[type_utils.KeySerializedExample],
  ) -> ExamplePositions | None:
    """Writes examples from the given iterator to the given path.

    Args:
      path: Path where to write the examples.
      iterator: Iterable of examples.

    Returns:
      None
    """
    with tf.io.TFRecordWriter(os.fspath(path)) as writer:
      for _, serialized_example in iterator:
        writer.write(serialized_example)
      writer.flush()
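
# Round-trip sketch (illustrative only): writing serialized examples with the
# adapter and reading them back through tf.data. The path and payloads are
# hypothetical.
#
#   examples = [(0, b'first'), (1, b'second')]  # (key, serialized bytes) pairs
#   TfRecordFileAdapter.write_examples('/tmp/data.tfrecord', examples)
#   ds = TfRecordFileAdapter.make_tf_data('/tmp/data.tfrecord')
#   # ds yields tf.string scalars: b'first', b'second'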
class RiegeliFileAdapter(FileAdapter):
  """File adapter for the Riegeli file format."""

  FILE_SUFFIX = 'riegeli'
  SUPPORTS_RANDOM_ACCESS = False
  SUPPORTS_TF_DATA = True

  @classmethod
  def make_tf_data(
      cls,
      filename: epath.PathLike,
      buffer_size: int | None = None,
  ) -> tf.data.Dataset:
    buffer_size = buffer_size or cls.BUFFER_SIZE
    from riegeli.tensorflow.ops import riegeli_dataset_ops as riegeli_tf  # pylint: disable=g-import-not-at-top  # pytype: disable=import-error

    return riegeli_tf.RiegeliDataset(filename, buffer_size=buffer_size)

  @classmethod
  def write_examples(
      cls,
      path: epath.PathLike,
      iterator: Iterable[type_utils.KeySerializedExample],
  ) -> ExamplePositions | None:
    """Writes examples from the given iterator to the given path.

    Args:
      path: Path where to write the examples.
      iterator: Iterable of examples.

    Returns:
      List of record positions for each record in the given iterator.
    """
    positions = []
    import riegeli  # pylint: disable=g-import-not-at-top

    with tf.io.gfile.GFile(os.fspath(path), 'wb') as f:
      with riegeli.RecordWriter(f, options='transpose') as writer:
        for _, record in iterator:
          writer.write_record(record)
          positions.append(writer.last_pos)
    return positions
class ArrayRecordFileAdapter(FileAdapter):
  """File adapter for the ArrayRecord file format."""

  FILE_SUFFIX = 'array_record'
  SUPPORTS_RANDOM_ACCESS = True
  SUPPORTS_TF_DATA = False

  @classmethod
  def make_tf_data(
      cls,
      filename: epath.PathLike,
      buffer_size: int | None = None,
  ) -> tf.data.Dataset:
    """Returns a TensorFlow Dataset comprising the given ArrayRecord file."""
    raise NotImplementedError(
        '`.as_dataset()` is not implemented for ArrayRecord files. Please use'
        ' `.as_data_source()`.'
    )

  @classmethod
  def write_examples(
      cls,
      path: epath.PathLike,
      iterator: Iterable[type_utils.KeySerializedExample],
  ) -> ExamplePositions | None:
    """Writes examples from the given iterator to the given path.

    Args:
      path: Path where to write the examples.
      iterator: Iterable of examples.

    Returns:
      None
    """
    writer = array_record_module.ArrayRecordWriter(
        os.fspath(path), 'group_size:1'
    )
    for _, serialized_example in iterator:
      writer.write(serialized_example)
    writer.close()
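
# Usage sketch (illustrative only; the path and payloads are hypothetical).
# Records are written with `group_size:1`, which keeps each record in its own
# group so it can be addressed individually later.
#
#   examples = [(0, b'first'), (1, b'second')]
#   ArrayRecordFileAdapter.write_examples('/tmp/data.array_record', examples)
#   # Reading goes through `.as_data_source()` rather than tf.data, since
#   # SUPPORTS_TF_DATA is False for this adapter.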
class ParquetFileAdapter(FileAdapter):
  """File adapter for the [Parquet](https://parquet.apache.org) file format.

  This FileAdapter requires `pyarrow` as a dependency and builds upon
  `pyarrow.parquet`.

  At the moment, the Parquet adapter doesn't leverage Parquet's columnar
  features and behaves like any other adapter. Instead of saving the features
  in the columns, we use one single `data` column where we store the
  serialized tf.Example proto.

  TODO(b/317277518): Let Parquet handle the serialization/deserialization.
  """

  FILE_SUFFIX = 'parquet'
  SUPPORTS_RANDOM_ACCESS = True
  SUPPORTS_TF_DATA = True
  _PARQUET_FIELD = 'data'
  _BATCH_SIZE = 100

  @classmethod
  def _schema(cls) -> pa.Schema:
    """Returns the Parquet schema as a one-column `data` binary field."""
    return pa.schema([pa.field(cls._PARQUET_FIELD, pa.binary())])

  @classmethod
  def make_tf_data(
      cls,
      filename: epath.PathLike,
      buffer_size: int | None = None,
  ) -> tf.data.Dataset:
    """Reads a Parquet file as a tf.data.Dataset.

    Args:
      filename: Path to the Parquet file.
      buffer_size: Unused buffer size.

    Returns:
      A tf.data.Dataset with the serialized examples.
    """
    del buffer_size  # unused

    def get_data(py_filename: bytes) -> Iterator[tf.Tensor]:
      table = pq.read_table(py_filename.decode(), schema=cls._schema())
      for batch in table.to_batches():
        for example in batch.to_pylist():
          yield tf.constant(example[cls._PARQUET_FIELD])

    return tf.data.Dataset.from_generator(
        get_data,
        args=(filename,),
        output_signature=tf.TensorSpec(shape=(), dtype=tf.string),
    )

  @classmethod
  def write_examples(
      cls,
      path: epath.PathLike,
      iterator: Iterable[type_utils.KeySerializedExample],
  ) -> None:
    """Writes the serialized tf.Example protos in a binary field named `data`.

    Args:
      path: Path to the Parquet file.
      iterator: Iterable of serialized examples.
    """
    with pq.ParquetWriter(path, schema=cls._schema()) as writer:
      for examples in _batched(iterator, cls._BATCH_SIZE):
        examples = [{cls._PARQUET_FIELD: example} for _, example in examples]
        batch = pa.RecordBatch.from_pylist(examples)
        writer.write_batch(batch)
    return None
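
# Usage sketch (illustrative only): each serialized example is stored as one
# row of the single binary `data` column, so reading the file back simply
# streams those rows as tf.string scalars. Paths and payloads are
# hypothetical.
#
#   examples = [(0, b'first'), (1, b'second')]
#   ParquetFileAdapter.write_examples('/tmp/data.parquet', examples)
#   ds = ParquetFileAdapter.make_tf_data('/tmp/data.parquet')
#   # ds yields b'first', b'second'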
def _to_bytes(key: type_utils.Key) -> bytes:
  """Converts the key to bytes."""
  if isinstance(key, int):
    return key.to_bytes(128, byteorder='big')  # Use 128 as this matches md5.
  elif isinstance(key, bytes):
    return key
  elif isinstance(key, str):
    return key.encode('utf-8')
  else:
    raise TypeError(f'Invalid key type: {type(key)}')
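
# Behavior sketch (values follow directly from the branches above):
#
#   _to_bytes(1)       # -> 128-byte big-endian encoding of 1
#   _to_bytes(b'key')  # -> b'key' (unchanged)
#   _to_bytes('key')   # -> b'key' (UTF-8 encoded)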
# Create a mapping from FileFormat -> FileAdapter.
ADAPTER_FOR_FORMAT: dict[FileFormat, Type[FileAdapter]] = {
    FileFormat.ARRAY_RECORD: ArrayRecordFileAdapter,
    FileFormat.PARQUET: ParquetFileAdapter,
    FileFormat.RIEGELI: RiegeliFileAdapter,
    FileFormat.TFRECORD: TfRecordFileAdapter,
}

_FILE_SUFFIX_TO_FORMAT = {
    adapter.FILE_SUFFIX: file_format
    for file_format, adapter in ADAPTER_FOR_FORMAT.items()
}


def file_format_from_suffix(file_suffix: str) -> FileFormat:
  """Returns the file format associated with the file extension (e.g. `tfrecord`)."""
  if file_suffix not in _FILE_SUFFIX_TO_FORMAT:
    raise ValueError(
        'Unrecognized file extension: Should be one of '
        f'{list(_FILE_SUFFIX_TO_FORMAT.keys())}'
    )
  return _FILE_SUFFIX_TO_FORMAT[file_suffix]


def is_example_file(filename: str) -> bool:
  """Returns whether the given filename is a record file."""
  return any(
      f'.{adapter.FILE_SUFFIX}' in filename
      for adapter in ADAPTER_FOR_FORMAT.values()
  )
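
# Usage sketch (illustrative only; values follow from the mappings above):
#
#   file_format_from_suffix('parquet')                        # -> FileFormat.PARQUET
#   is_example_file('dataset-train.tfrecord-00000-of-00001')  # -> True
#   is_example_file('dataset_info.json')                      # -> False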
def _batched(iterator: Iterator[T] | Iterable[T], n: int) -> Iterator[list[T]]:
  """Batches the result of an iterator into lists of length n.

  This function is built into the standard library from Python 3.12 (source:
  https://docs.python.org/3/library/itertools.html#itertools.batched). However,
  TFDS supports older versions of Python.

  Args:
    iterator: The iterator to batch.
    n: The maximal length of each batch.

  Yields:
    The next list of at most n elements.
  """
  # Use a single iterator so that each `islice` call continues where the
  # previous batch stopped. Re-slicing the input with growing offsets would
  # silently skip elements when given a one-shot iterator.
  iterator = iter(iterator)
  while True:
    batch = list(itertools.islice(iterator, n))
    if not batch:
      return
    yield batch
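
# Behavior sketch (illustrative): batching a 5-element iterator with n=2.
#
#   list(_batched(iter(range(5)), 2))  # -> [[0, 1], [2, 3], [4]]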
def convert_path_to_file_format(
    path: epath.PathLike, file_format: FileFormat
) -> epath.Path:
  """Returns the path to a specific shard for a different file format.

  TFDS can store the file format in the filename as a suffix or as an infix.
  For example:

  - `dataset-train.<FILE_FORMAT>-00000-of-00001`, a so-called infix format
    because the file format comes before the shard spec.
  - `dataset-train-00000-of-00001.<FILE_FORMAT>`, a so-called suffix format
    because the file format comes after the shard spec.

  Args:
    path: The path of a specific shard to convert. Can be a path for a
      different file format.
    file_format: The file format to which the shard path should be converted.

  Returns:
    The path of the shard for the given file format.
  """
  path = epath.Path(path)
  file_name: str = path.name
  if file_format.file_suffix in file_name:
    # Already has the right file format in the file name.
    return path
  infix_formats = FileFormat.with_suffix_before_shard_spec()
  suffix_formats = FileFormat.with_suffix_after_shard_spec()

  # Remove any existing file format from the file name.
  infix_format_concat = '|'.join(f.file_suffix for f in infix_formats)
  file_name = re.sub(rf'(\.({infix_format_concat}))', '', file_name)
  suffix_formats_concat = '|'.join(f.file_suffix for f in suffix_formats)
  file_name = re.sub(rf'(\.({suffix_formats_concat}))$', '', file_name)

  # Add back the proper file format.
  if file_format in suffix_formats:
    file_name = f'{file_name}.{file_format.file_suffix}'
  else:
    file_name = re.sub(
        r'-(\d+)-of-(\d+)',
        rf'.{file_format.file_suffix}-\1-of-\2',
        file_name,
    )
  return path.parent / file_name
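
# Conversion sketch (illustrative only; the shard path is hypothetical). Both
# tfrecord and riegeli place the suffix before the shard spec, so converting
# between them swaps the infix:
#
#   convert_path_to_file_format(
#       '/data/dataset-train.tfrecord-00000-of-00010', FileFormat.RIEGELI
#   )
#   # -> Path('/data/dataset-train.riegeli-00000-of-00010')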