-
Notifications
You must be signed in to change notification settings - Fork 281
/
serial_ops.py
201 lines (159 loc) · 7.02 KB
/
serial_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Easily save tf.data.Datasets as tfrecord files, and restore tfrecords as Datasets.
The goal of this module is to create a SIMPLE api to tfrecords that can be used without
learning all of the underlying mechanics.
Users only need to deal with 2 functions:
save_dataset(dataset, tfrecord_path, header_path)
dataset = load_dataset(tfrecord_path, header_path)
It really is that easy!
To make this work, we create a .header file for each tfrecord which encodes metadata
needed to reconstruct the original dataset.
Note that PyYAML (yaml) package must be installed to make use of this module.
Saving must be done in eager mode, but loading is compatible with both eager and
graph execution modes.
GOTCHAS:
- This module is only compatible with "dictionary-style" datasets {key: val, key2:val2,..., keyN: valN}.
- The restored dataset will have the TFRecord dtypes {float32, int64, string} instead of the original
tensor dtypes. This is always the case with TFRecord datasets, whether you use this module or not.
The original dtypes are stored in the headers if you want to restore them after loading."""
import functools
import os
import tempfile
import numpy as np
import tensorflow as tf
# The three encoding functions.
def _bytes_feature(value):
    """Wrap ``value`` (a list of byte strings) in a BytesList-backed ``tf.train.Feature``."""
    bytes_list = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=bytes_list)
def _float_feature(value):
    """Wrap ``value`` (a list of floats) in a FloatList-backed ``tf.train.Feature``."""
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Wrap ``value`` (a list of integers) in an Int64List-backed ``tf.train.Feature``."""
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)
# TODO use base_type() to ensure consistent conversion.
def np_value_to_feature(value):
    """Map a single dataset value (as returned by ``Tensor.numpy()``) to a ``tf.train.Feature``.

    Only numpy values (plus ``bytes``) are supported since Datasets only contain
    tensors. Each datatype has exactly one serialization:

    * bool arrays/scalars    -> Int64List (stored as 0/1)
    * integer arrays/scalars -> Int64List
    * float arrays/scalars   -> FloatList
    * ``bytes`` scalars and arrays of bytes -> BytesList

    Arrays are flattened; their shape is recorded separately by ``build_header``.

    Args:
        value: a ``np.ndarray``, a numpy scalar, or a ``bytes`` object.

    Returns:
        A ``tf.train.Feature`` holding the flattened value.

    Raises:
        TypeError: if the value's type/dtype has no TFRecord encoding.
    """
    if isinstance(value, np.ndarray):
        flat = value.flatten()
        # Use the abstract numpy types: the old np.bool / np.float aliases were
        # removed in NumPy 1.24, and issubdtype(x, float) only matches float64.
        if np.issubdtype(value.dtype, np.bool_):
            # Cast explicitly so Int64List receives real integers, not np.bool_.
            return _int64_feature(flat.astype(np.int64))
        if np.issubdtype(value.dtype, np.integer):
            return _int64_feature(flat)
        if np.issubdtype(value.dtype, np.floating):
            return _float_feature(flat)
        if value.dtype == object or np.issubdtype(value.dtype, np.bytes_):
            # tf.string tensors of rank >= 1 come back from .numpy() as arrays
            # of bytes; BytesList itself raises TypeError on non-bytes items.
            return _bytes_feature(flat.tolist())
        raise TypeError(f"value dtype: {value.dtype} is not recognized.")
    if isinstance(value, bytes):
        return _bytes_feature([value])
    if isinstance(value, (bool, np.bool_)):
        # bool is not an np.integer subtype, so handle it before the int check.
        return _int64_feature([int(value)])
    if np.issubdtype(type(value), np.integer):
        return _int64_feature([value])
    if np.issubdtype(type(value), np.floating):
        return _float_feature([value])
    raise TypeError(
        f"value type: {type(value)} is not recognized. value must be a valid Numpy object."
    )
def base_type(dtype):
    """Return the TFRecord storage dtype that ``dtype`` is serialized as.

    TFRecord features only carry int64, float32 or bytes data, so every
    supported input dtype collapses onto one of ``tf.int64``, ``tf.float32``
    or ``tf.string``.

    Args:
        dtype: a ``tf.DType`` or anything that compares equal to one (e.g. the
            integer dtype enum that ``build_feature_desc`` reads from a header,
            or ``bytes``).

    Returns:
        ``tf.int64``, ``tf.float32`` or ``tf.string``.

    Raises:
        ValueError: if ``dtype`` has no TFRecord equivalent.
    """
    conversions = (
        (
            [
                tf.int8, tf.int16, tf.int32, tf.int64,
                tf.uint8, tf.uint16, tf.uint32, tf.uint64,
                tf.qint8, tf.qint16, tf.qint32,
                tf.bool,
            ],
            tf.int64,
        ),
        ([tf.float16, tf.float32, tf.float64], tf.float32),
        ([tf.string, bytes], tf.string),
    )
    for candidates, stored_as in conversions:
        if dtype in candidates:
            return stored_as
    raise ValueError(f"dtype {dtype} is not a recognized/supported type!")
def build_header(dataset):
    """Describe every tensor of a dictionary-style dataset so it can be reloaded.

    Returns a dict with one entry per element key. Each entry holds the tensor's
    shape as a plain list and its dtype as TensorFlow's integer DataType enum
    (``as_datatype_enum``). When the TFRecord file is loaded, this metadata is
    used to rebuild the original tensors from the raw data.
    """
    spec = dataset.element_spec
    return {
        name: {
            "shape": list(tensor_spec.shape),
            "dtype": tensor_spec.dtype.as_datatype_enum,
        }
        for name, tensor_spec in spec.items()
    }
def build_feature_desc(header):
    """Turn a header mapping into a feature spec for ``tf.io.parse_single_example``.

    Every header entry becomes a ``tf.io.FixedLenFeature`` with the recorded
    shape and the TFRecord storage dtype chosen by ``base_type``. Variable-length
    features are not supported.

    Args:
        header: mapping of feature name -> {"shape": list, "dtype": int enum},
            in the format produced by ``build_header``.

    Returns:
        dict mapping feature name -> ``tf.io.FixedLenFeature``.
    """
    return {
        name: tf.io.FixedLenFeature(
            shape=params["shape"], dtype=base_type(int(params["dtype"]))
        )
        for name, params in header.items()
    }
def dataset_to_examples(ds):
    """Yield every element of ``ds`` as a serialized ``tf.train.Example``.

    One Example is produced per dataset element (one observation); each tensor
    in the element becomes a Feature under the same key.

    WARNING: Only compatible with "dictionary-style" datasets {key: val, key2: val2, ..., keyN: valN}.
    WARNING: Must run in eager mode (it calls ``Tensor.numpy()``).
    """
    # TODO handle tuples and flat datasets as well.
    for element in ds:
        feature_map = {}
        for name, tensor in element.items():
            # Convert each tensor to one of the TFRecord-serializable types.
            feature_map[name] = np_value_to_feature(tensor.numpy())
        # Bundle this element's features into a single Example proto.
        proto = tf.train.Example(features=tf.train.Features(feature=feature_map))
        yield proto.SerializeToString()
def save_dataset(dataset, tfrecord_path, header_path):
    """Save a dictionary-style dataset as a TFRecord file plus a YAML header file.

    The header stores each feature's shape and dtype so the TFRecord file can
    later be reloaded as a dataset with ``load_dataset``.

    Must run in eager mode because it iterates the dataset and reads its
    element_spec.

    Args:
        dataset: a ``tf.data.Dataset`` whose elements are dicts of tensors.
        tfrecord_path: destination path for the TFRecord file.
        header_path: destination path for the YAML header file.

    Raises:
        ValueError: if called outside eager mode.
    """
    import yaml

    if not tf.executing_eagerly():
        raise ValueError("save_dataset() must run in eager mode!")

    # Header: the context manager closes (and flushes) the file, so the header
    # is complete on disk even if the caller reloads it immediately.
    header = build_header(dataset)
    with open(header_path, "w", encoding="utf-8") as header_file:
        yaml.dump(header, stream=header_file)

    # Dataset: we are already eager, so write the serialized Examples directly.
    # This avoids from_generator(output_types=...) and the deprecated
    # tf.data.experimental.TFRecordWriter; the output is the same uncompressed
    # TFRecord file.
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for serialized in dataset_to_examples(dataset):
            writer.write(serialized)
def load_dataset(tfrecord_path, header_path):
    """Load a TFRecord file written by ``save_dataset`` back into a dataset.

    The YAML header supplies the shape and dtype of every feature, so each
    record can be parsed with ``tf.io.parse_single_example``. Works in both
    eager and graph mode.

    Args:
        tfrecord_path: path to the TFRecord file.
        header_path: path to the matching YAML header file.

    Returns:
        A ``tf.data.Dataset`` of dicts mapping feature name to tensor, using the
        TFRecord storage dtypes (int64, float32, string).
    """
    import yaml

    # A header written by save_dataset only holds plain mappings, lists, ints
    # and nulls, so safe_load is sufficient. It also refuses to build arbitrary
    # Python objects from a possibly untrusted file; FullLoader has had
    # code-execution CVEs.
    with open(header_path, encoding="utf-8") as header_file:
        header = yaml.safe_load(header_file)

    feature_desc = build_feature_desc(header)
    parse_func = functools.partial(tf.io.parse_single_example, features=feature_desc)
    return tf.data.TFRecordDataset(tfrecord_path).map(parse_func)