/
tf_record.py
318 lines (262 loc) · 11.4 KB
/
tf_record.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""For reading and writing TFRecords files."""
from tensorflow.python.lib.io import _pywrap_record_io
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@tf_export(
    v1=["io.TFRecordCompressionType", "python_io.TFRecordCompressionType"])
@deprecation.deprecated_endpoints("io.TFRecordCompressionType",
                                  "python_io.TFRecordCompressionType")
class TFRecordCompressionType(object):
  """The type of compression for the record.

  The members are plain integer constants. `NONE` is `0` and is therefore
  falsy, so truthiness checks treat it the same way as "no compression
  requested".
  """
  # Records are stored without compression.
  NONE = 0
  # Records are compressed with ZLIB.
  ZLIB = 1
  # Records are compressed with GZIP.
  GZIP = 2
@tf_export(
    "io.TFRecordOptions",
    v1=["io.TFRecordOptions", "python_io.TFRecordOptions"])
@deprecation.deprecated_endpoints("python_io.TFRecordOptions")
class TFRecordOptions(object):
  """Options used for manipulating TFRecord files."""

  # Translates each `TFRecordCompressionType` constant into the string name
  # returned by `get_compression_type_string`.
  compression_type_map = {
      TFRecordCompressionType.ZLIB: "ZLIB",
      TFRecordCompressionType.GZIP: "GZIP",
      TFRecordCompressionType.NONE: ""
  }

  def __init__(self,
               compression_type=None,
               flush_mode=None,
               input_buffer_size=None,
               output_buffer_size=None,
               window_bits=None,
               compression_level=None,
               compression_method=None,
               mem_level=None,
               compression_strategy=None):
    # pylint: disable=line-too-long
    """Creates a `TFRecordOptions` instance.

    Options only affect TFRecordWriter when compression_type is not `None`.

    Documentation, details, and defaults can be found in
    [`zlib_compression_options.h`](https://www.tensorflow.org/code/tensorflow/core/lib/io/zlib_compression_options.h)
    and in the [zlib manual](http://www.zlib.net/manual.html).
    Leaving an option as `None` allows C++ to set a reasonable default.

    Args:
      compression_type: `"GZIP"`, `"ZLIB"`, or `""` (no compression).
      flush_mode: flush mode or `None`, Default: Z_NO_FLUSH.
      input_buffer_size: int or `None`.
      output_buffer_size: int or `None`.
      window_bits: int or `None`.
      compression_level: 0 to 9, or `None`.
      compression_method: compression method or `None`.
      mem_level: 1 to 9, or `None`.
      compression_strategy: strategy or `None`. Default: Z_DEFAULT_STRATEGY.

    Returns:
      A `TFRecordOptions` object.

    Raises:
      ValueError: If compression_type is invalid.
    """
    # pylint: enable=line-too-long
    # Check compression_type is valid, but for backwards compatibility don't
    # immediately convert to a string.
    self.get_compression_type_string(compression_type)

    self.compression_type = compression_type
    self.flush_mode = flush_mode
    self.input_buffer_size = input_buffer_size
    self.output_buffer_size = output_buffer_size
    self.window_bits = window_bits
    self.compression_level = compression_level
    self.compression_method = compression_method
    self.mem_level = mem_level
    self.compression_strategy = compression_strategy

  @classmethod
  def get_compression_type_string(cls, options):
    """Convert various option types to a unified string.

    Any falsy value (`None`, `""`, or `TFRecordCompressionType.NONE`) means
    "no compression" and yields `""`.

    Args:
      options: `TFRecordOptions`, `TFRecordCompressionType`, or string.

    Returns:
      Compression type as string (e.g. `'ZLIB'`, `'GZIP'`, or `''`).

    Raises:
      ValueError: If compression_type is invalid.
    """
    if not options:
      return ""
    if isinstance(options, TFRecordOptions):
      return cls.get_compression_type_string(options.compression_type)

    # `TFRecordCompressionType` members are plain ints, so they are resolved
    # through the map like any other key. An unhashable value cannot be a key
    # at all; treat it as "not found" so that it is reported with the
    # documented ValueError instead of leaking a TypeError.
    try:
      is_known_constant = options in TFRecordOptions.compression_type_map
    except TypeError:
      is_known_constant = False
    if is_known_constant:
      return cls.compression_type_map[options]

    # Already one of the canonical string names.
    if options in TFRecordOptions.compression_type_map.values():
      return options

    raise ValueError('Not a valid compression_type: "{}"'.format(options))

  def _as_record_writer_options(self):
    """Convert to RecordWriterOptions for use with PyRecordWriter.

    Returns:
      A `_pywrap_record_io.RecordWriterOptions` built from this object's
      compression type, with every zlib option that is not `None` copied
      onto its `zlib_options`.
    """
    options = _pywrap_record_io.RecordWriterOptions(
        compat.as_bytes(
            self.get_compression_type_string(self.compression_type)))

    # Each attribute has the same name here and on `zlib_options`. Options
    # left as `None` are skipped so the defaults chosen on the C++ side stay
    # in effect.
    zlib_option_names = (
        "flush_mode",
        "input_buffer_size",
        "output_buffer_size",
        "window_bits",
        "compression_level",
        "compression_method",
        "mem_level",
        "compression_strategy",
    )
    for name in zlib_option_names:
      value = getattr(self, name)
      if value is not None:
        setattr(options.zlib_options, name, value)
    return options
@tf_export(v1=["io.tf_record_iterator", "python_io.tf_record_iterator"])
@deprecation.deprecated(
    date=None,
    instructions=("Use eager execution and: \n"
                  "`tf.data.TFRecordDataset(path)`"))
def tf_record_iterator(path, options=None):
  """An iterator that reads the records from a TFRecords file.

  The compression type is derived from `options` with
  `TFRecordOptions.get_compression_type_string`.

  Args:
    path: The path to the TFRecords file.
    options: (optional) A TFRecordOptions object.

  Returns:
    An iterator of serialized TFRecords.

  Raises:
    IOError: If `path` cannot be opened for reading.
  """
  return _pywrap_record_io.RecordIterator(
      path, TFRecordOptions.get_compression_type_string(options))
def tf_record_random_reader(path):
  """Creates a reader that allows random-access reads from a TFRecords file.

  The returned reader exposes one method:

  - `read(offset)`, which returns a tuple of `(record, ending_offset)`, where
    `record` is the TFRecord read at the offset, and
    `ending_offset` is the ending offset of the read record.

    The method throws a `tf.errors.DataLossError` if data is corrupted at
    the given offset. The method throws `IndexError` if the offset is out of
    range for the TFRecords file.

  Usage example:
  ```py
  reader = tf_record_random_reader(file_path)

  record_1, offset_1 = reader.read(0)  # 0 is the initial offset.
  # offset_1 is the ending offset of the 1st record and the starting offset of
  # the next.

  record_2, offset_2 = reader.read(offset_1)
  # offset_2 is the ending offset of the 2nd record and the starting offset of
  # the next.
  # We can jump back and read the first record again if so desired.
  reader.read(0)
  ```

  Args:
    path: The path to the TFRecords file.

  Returns:
    An object that supports random-access reading of the serialized TFRecords.

  Raises:
    IOError: If `path` cannot be opened for reading.
  """
  random_reader = _pywrap_record_io.RandomRecordReader(path)
  return random_reader
@tf_export(
    "io.TFRecordWriter", v1=["io.TFRecordWriter", "python_io.TFRecordWriter"])
@deprecation.deprecated_endpoints("python_io.TFRecordWriter")
class TFRecordWriter(_pywrap_record_io.RecordWriter):
  """A class to write records to a TFRecords file.

  [TFRecords tutorial](https://www.tensorflow.org/tutorials/load_data/tfrecord)

  TFRecords is a binary format which is optimized for high throughput data
  retrieval, generally in conjunction with `tf.data`. `TFRecordWriter` is used
  to write serialized examples to a file for later consumption. The key steps
  are:

   Ahead of time:

   - [Convert data into a serialized format](
   https://www.tensorflow.org/tutorials/load_data/tfrecord#tfexample)
   - [Write the serialized data to one or more files](
   https://www.tensorflow.org/tutorials/load_data/tfrecord#tfrecord_files_in_python)

   During training or evaluation:

   - [Read serialized examples into memory](
   https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file)
   - [Parse (deserialize) examples](
   https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file)

  A minimal example is given below:

  >>> import tempfile
  >>> example_path = os.path.join(tempfile.gettempdir(), "example.tfrecords")
  >>> np.random.seed(0)

  >>> # Write the records to a file.
  ... with tf.io.TFRecordWriter(example_path) as file_writer:
  ...   for _ in range(4):
  ...     x, y = np.random.random(), np.random.random()
  ...
  ...     record_bytes = tf.train.Example(features=tf.train.Features(feature={
  ...         "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
  ...         "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
  ...     })).SerializeToString()
  ...     file_writer.write(record_bytes)

  >>> # Read the data back out.
  >>> def decode_fn(record_bytes):
  ...   return tf.io.parse_single_example(
  ...       # Data
  ...       record_bytes,
  ...
  ...       # Schema
  ...       {"x": tf.io.FixedLenFeature([], dtype=tf.float32),
  ...        "y": tf.io.FixedLenFeature([], dtype=tf.float32)}
  ...   )

  >>> for batch in tf.data.TFRecordDataset([example_path]).map(decode_fn):
  ...   print("x = {x:.4f}, y = {y:.4f}".format(**batch))
  x = 0.5488, y = 0.7152
  x = 0.6028, y = 0.5449
  x = 0.4237, y = 0.6459
  x = 0.4376, y = 0.8918

  This class implements `__enter__` and `__exit__`, and can be used
  in `with` blocks like a normal file. (See the usage example above.)
  """

  # TODO(josh11b): Support appending?
  def __init__(self, path, options=None):
    """Opens file `path` and creates a `TFRecordWriter` writing to it.

    Anything that is not already a `TFRecordOptions` is treated as a
    compression type and wrapped in a new `TFRecordOptions`.

    Args:
      path: The path to the TFRecords file.
      options: (optional) String specifying compression type,
          `TFRecordCompressionType`, or `TFRecordOptions` object.

    Raises:
      IOError: If `path` cannot be opened for writing.
      ValueError: If valid compression_type can't be determined from `options`.
    """
    if isinstance(options, TFRecordOptions):
      record_options = options
    else:
      record_options = TFRecordOptions(compression_type=options)

    # The path is converted before the options so that, if both are invalid,
    # the path error is the one that surfaces.
    path_bytes = compat.as_bytes(path)
    # pylint: disable=protected-access
    writer_options = record_options._as_record_writer_options()
    # pylint: enable=protected-access
    super(TFRecordWriter, self).__init__(path_bytes, writer_options)

  # TODO(slebedev): The following wrapper methods are there to compensate
  # for lack of signatures in pybind11-generated classes. Switch to
  # __text_signature__ when TensorFlow drops Python 2.X support.
  # See https://github.com/pybind/pybind11/issues/945
  # pylint: disable=useless-super-delegation
  def write(self, record):
    """Write a string record to the file.

    Args:
      record: str
    """
    super(TFRecordWriter, self).write(record)

  def flush(self):
    """Flush the file."""
    super(TFRecordWriter, self).flush()

  def close(self):
    """Close the file."""
    super(TFRecordWriter, self).close()
  # pylint: enable=useless-super-delegation