/
io_utils.py
190 lines (158 loc) · 5.44 KB
/
io_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
"""Utilities related to disk I/O."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
import six
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.util.tf_export import keras_export
try:
import h5py
except ImportError:
h5py = None
@keras_export('keras.utils.HDF5Matrix')
class HDF5Matrix(object):
  """Representation of HDF5 dataset to be used instead of a Numpy array.

  Example:

  ```python
  x_data = HDF5Matrix('input/file.hdf5', 'data')
  model.predict(x_data)
  ```

  Providing `start` and `end` allows use of a slice of the dataset.

  Optionally, a normalizer function (or lambda) can be given. This will
  be called on every slice of data retrieved.

  Indexing is relative to `start`: index 0 is row `start` of the underlying
  dataset. Supported keys are integers, slices (with a positive step), NumPy
  integer arrays and other iterables of integers. Negative values count back
  from `end`, as with Python sequences. Out-of-range keys raise `IndexError`.
  A slice whose stop lies past the end also raises instead of being clamped.

  HDF5 files are opened read-only and cached per path in the class-level
  `refs` mapping, so matrices built on the same file share one handle. This
  class never closes those handles.

  Arguments:
      datapath: string, path to a HDF5 file
      dataset: string, name of the HDF5 dataset in the file specified
          in datapath
      start: int, start of desired slice of the specified dataset
      end: int, end of desired slice of the specified dataset
      normalizer: function to be called on data when retrieved

  Returns:
      An array-like HDF5 dataset.
  """

  # Class-level cache: datapath -> open h5py.File, shared by all instances.
  refs = collections.defaultdict(int)

  def __init__(self, datapath, dataset, start=0, end=None, normalizer=None):
    if h5py is None:
      raise ImportError('The use of HDF5Matrix requires '
                        'HDF5 and h5py installed.')

    # A membership test on the dict itself does not insert a default entry.
    if datapath not in self.refs:
      # Open read-only: this class never writes. An explicit mode also stops
      # h5py < 3 (default mode 'a') from silently creating a missing file.
      f = h5py.File(datapath, 'r')
      self.refs[datapath] = f
    else:
      f = self.refs[datapath]
    self.data = f[dataset]
    self.start = start
    if end is None:
      self.end = self.data.shape[0]
    else:
      self.end = end
    self.normalizer = normalizer

  def __len__(self):
    return self.end - self.start

  def __getitem__(self, key):
    """Returns the rows selected by `key`, passed through `normalizer` if set.

    Raises:
        IndexError: if any selected row lies outside `[start, end)`.
        ValueError: if a slice has a non-positive step.
        TypeError: if an element of an index list/array is not an integer.
    """
    idx = self._translate_key(key)
    if self.normalizer is not None:
      return self.normalizer(self.data[idx])
    return self.data[idx]

  def _translate_key(self, key):
    """Converts a key relative to this matrix into a key for `self.data`."""
    length = len(self)
    if isinstance(key, slice):
      return self._translate_slice(key, length)
    if isinstance(key, np.ndarray):
      # tolist() yields plain Python scalars (a bare scalar for a 0-d array).
      key = key.tolist()
    if isinstance(key, (int, np.integer)):
      return self._absolute_index(key, length)
    # Any other iterable (list, tuple, ...) selects individual rows.
    return [self._absolute_index(i, length) for i in key]

  def _translate_slice(self, key, length):
    """Maps a relative slice onto absolute rows of `self.data`."""
    step = key.step
    if step is not None and step < 1:
      raise ValueError(
          'HDF5Matrix slices require a positive step, got {}'.format(step))
    start = 0 if key.start is None else key.start
    stop = length if key.stop is None else key.stop
    # Negative bounds count from the end; clamp ones before the beginning.
    if start < 0:
      start = max(start + length, 0)
    if stop < 0:
      stop = max(stop + length, 0)
    if stop > length:
      raise IndexError(
          'Slice stop {} is out of range for HDF5Matrix of length {}'.format(
              key.stop, length))
    return slice(start + self.start, stop + self.start, step)

  def _absolute_index(self, index, length):
    """Maps one relative row index (negative = from the end) to a data row."""
    if not isinstance(index, (int, np.integer)):
      raise TypeError(
          'HDF5Matrix indices must be integers, got {!r}'.format(index))
    position = index + length if index < 0 else index
    if not 0 <= position < length:
      raise IndexError(
          'Index {} is out of range for HDF5Matrix of length {}'.format(
              index, length))
    return int(position) + self.start

  @property
  def shape(self):
    """Gets a numpy-style shape tuple giving the dataset dimensions.

    Returns:
        A numpy-style shape tuple: `(end - start,)` followed by the
        trailing dimensions of the underlying dataset.
    """
    return (self.end - self.start,) + self.data.shape[1:]

  @property
  def dtype(self):
    """Gets the datatype of the dataset.

    Returns:
        The `dtype` attribute of the underlying dataset.
    """
    return self.data.dtype

  @property
  def ndim(self):
    """Gets the number of dimensions (rank) of the dataset.

    Returns:
        An integer denoting the number of dimensions (rank) of the dataset.
    """
    return self.data.ndim

  @property
  def size(self):
    """Gets the total dataset size (number of elements).

    Returns:
        The product of `shape`, i.e. the number of elements in the
        `[start, end)` window.
    """
    return np.prod(self.shape)

  @staticmethod
  def _to_type_spec(value):
    """Gets the Tensorflow TypeSpec corresponding to the passed dataset.

    Args:
        value: A HDF5Matrix object.

    Returns:
        A tf.TensorSpec with the matrix's shape and dtype.

    Raises:
        TypeError: if `value` is not a HDF5Matrix.
    """
    if not isinstance(value, HDF5Matrix):
      raise TypeError('Expected value to be a HDF5Matrix, but saw: {}'.format(
          type(value)))
    return tensor_spec.TensorSpec(shape=value.shape, dtype=value.dtype)
# Register HDF5Matrix._to_type_spec with TensorFlow's type_spec module as the
# converter that produces a TypeSpec for HDF5Matrix values.
type_spec.register_type_spec_from_value_converter(
    HDF5Matrix,
    HDF5Matrix._to_type_spec,  # pylint: disable=protected-access
)
def ask_to_proceed_with_overwrite(filepath):
  """Asks the user on stdin whether an existing file may be overwritten.

  The question is repeated until the answer, stripped of surrounding
  whitespace and lower-cased, is exactly "y" or "n".

  Arguments:
      filepath: the path to the file to be overwritten.

  Returns:
      True if we can proceed with overwrite, False otherwise.
  """

  def _ask(prompt):
    # Normalise the raw reply so inputs like "Y" or " n " are accepted.
    return six.moves.input(prompt).strip().lower()

  answer = _ask('[WARNING] %s already exists - overwrite? [y/n]' % filepath)
  while answer not in ('y', 'n'):
    answer = _ask('Enter "y" (overwrite) or "n" (cancel).')

  if answer == 'y':
    print('[TIP] Next time specify overwrite=True!')
    return True
  return False