-
Notifications
You must be signed in to change notification settings - Fork 12
/
maps.py
162 lines (128 loc) · 4.46 KB
/
maps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python
'''Map functions'''
import numpy as np
import six
from .exceptions import DataError, PescadorError
from . import util
__all__ = ['buffer_stream', 'tuples', 'keras_tuples']
def __stack_data(data):
    '''Stack a sequence of data objects into a single batch.

    Parameters
    ----------
    data : non-empty sequence of dict-like
        The data objects to combine.  The key set is taken from the
        first element; every element is indexed with each of those keys.

    Returns
    -------
    batch : dict
        Maps each key of ``data[0]`` to an ``np.array`` built from the
        corresponding values of all elements, in order.
    '''
    first = data[0]
    return {
        field: np.array([item[field] for item in data])
        for field in first.keys()
    }
def buffer_stream(stream, buffer_size, partial=False,
                  generator=util.Deprecated()):
    '''Buffer data from a stream into batches.

    Items are collected from `stream` until `buffer_size` of them are
    available; they are then stacked key-by-key into a single data
    object, which is yielded as one batch.

    Parameters
    ----------
    stream : stream
        The stream to buffer

    buffer_size : int > 0
        The number of examples to retain per batch.

    partial : bool, default=False
        If True, yield a final partial batch on under-run.

    generator : stream
        .. warning:: This parameter name was deprecated in pescador 1.1
            Use the `stream` parameter instead.
            The `generator` parameter will be removed in pescador 2.0.

    Yields
    ------
    batch
        A batch of size at most `buffer_size`

    Raises
    ------
    DataError
        If the stream contains items that are not data-like.  This
        applies to full batches and to the final partial batch alike.
    '''
    stream = util.rename_kw('generator', generator,
                            'stream', stream,
                            '1.1', '2.0')

    def _stack(items):
        # Single conversion point so that full and partial batches report
        # malformed data the same way.
        try:
            return __stack_data(items)
        except (TypeError, AttributeError):
            raise DataError("Malformed data stream: {}".format(items))

    data = []
    for x in stream:
        data.append(x)
        if len(data) < buffer_size:
            continue

        # Stack *before* yielding: the try block must never span a `yield`,
        # otherwise an exception thrown into this generator by the consumer
        # would be misreported as malformed data.
        batch = _stack(data)
        data = []
        yield batch

    # Under-run: fewer than `buffer_size` items were left at exhaustion.
    if data and partial:
        yield _stack(data)
def tuples(stream, *keys):
    """Reformat data as tuples.

    Parameters
    ----------
    stream : iterable
        Stream of data objects.

    *keys : strings
        Keys to use for ordering data.

    Yields
    ------
    items : tuple
        Data object reformatted as a tuple: the values of `keys`,
        in the order the keys were given.

    Raises
    ------
    PescadorError
        If no keys are given.  Because this function is a generator,
        the check runs when iteration starts, not at call time.
    DataError
        If the stream contains items that are not data-like.
    KeyError
        If a data object does not contain the requested key.
    """
    if not keys:
        raise PescadorError('Unable to generate tuples from '
                            'an empty item set')

    for data in stream:
        # Only the lookup is guarded.  Keeping the `yield` outside the try
        # block ensures a TypeError thrown into the generator by the
        # consumer is not mistaken for malformed data.
        try:
            item = tuple(data[key] for key in keys)
        except TypeError:
            raise DataError("Malformed data stream: {}".format(data))
        yield item
def keras_tuples(stream, inputs=None, outputs=None):
    """Reformat data objects as keras-compatible tuples.

    For more detail: https://keras.io/models/model/#fit

    Parameters
    ----------
    stream : iterable
        Stream of data objects.

    inputs : string or iterable of strings, None
        Keys to use for ordered input data.
        If not specified, returns `None` in its place.

    outputs : string or iterable of strings, default=None
        Keys to use for ordered output data.
        If not specified, returns `None` in its place.

    Yields
    ------
    x : np.ndarray, list of np.ndarray, or None
        If `inputs` is a string, `x` is a single np.ndarray.
        If `inputs` is an iterable of strings, `x` is a list of np.ndarrays.
        If `inputs` is a null type, `x` is None.

    y : np.ndarray, list of np.ndarray, or None
        If `outputs` is a string, `y` is a single np.ndarray.
        If `outputs` is an iterable of strings, `y` is a list of np.ndarrays.
        If `outputs` is a null type, `y` is None.

    Raises
    ------
    PescadorError
        If neither `inputs` nor `outputs` provides at least one key.
        Because this function is a generator, the check runs when
        iteration starts, not at call time.
    DataError
        If the stream contains items that are not data-like.
    """
    def _as_key_list(keys):
        # Normalize a key spec to ``(list_of_keys, flatten)``.  `flatten`
        # is True only when a single string was given, in which case the
        # matching output is unwrapped from its one-element list.
        if not keys:
            return [], False
        if isinstance(keys, six.string_types):
            return [keys], True
        # Materialize as a list so that tuples, sets and one-shot iterables
        # all work and can be reused for every item of the stream.
        return list(keys), False

    inputs, flatten_inputs = _as_key_list(inputs)
    outputs, flatten_outputs = _as_key_list(outputs)

    if not (inputs or outputs):
        raise PescadorError('At least one key must be given for '
                            '`inputs` or `outputs`')

    for data in stream:
        # Only the lookups are guarded; the `yield` stays outside the try
        # block so a TypeError thrown into the generator by the consumer is
        # not mistaken for malformed data.
        try:
            x = [data[key] for key in inputs] or None
            y = [data[key] for key in outputs] or None
        except TypeError:
            raise DataError("Malformed data stream: {}".format(data))

        if flatten_inputs:
            x = x[0]
        if flatten_outputs:
            y = y[0]

        yield (x, y)