Skip to content

Commit

Permalink
Incorporated bmcfee + cjacoby feedback.
Browse files Browse the repository at this point in the history
  • Loading branch information
Eric Humphrey committed Apr 17, 2017
1 parent eb91dba commit b65d43f
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 8 deletions.
2 changes: 1 addition & 1 deletion examples/frameworks/keras_example.py
Expand Up @@ -139,7 +139,7 @@ def sampler(X, y):
'''
X = np.atleast_2d(X)
# y's are binary vectors, and should be of shape (10,) after this.
y = np.atleast_2d(y)
y = np.atleast_1d(y)

n = X.shape[0]

Expand Down
2 changes: 1 addition & 1 deletion examples/mux_files_example.py
Expand Up @@ -61,7 +61,7 @@ def npz_generator(npz_path):
"""Generate data from an npz file."""
npz_data = np.load(npz_path)
X = npz_data['X']
# y's are binary vectors, and should be of shape (10,) after this.
# Y is a binary maxtrix with shape=(n, k), each y will have shape=(k,)
y = npz_data['Y']

n = X.shape[0]
Expand Down
2 changes: 1 addition & 1 deletion pescador/buffered.py
Expand Up @@ -74,7 +74,7 @@ def iterate(self, max_iter=None, partial=True):
not the number of individual samples.
partial : bool, default=True
Return buffers smaller than the requested size.
If True, will return a final batch smaller than the requested size.
"""
with core.StreamActivator(self):
for n, batch in enumerate(buffer_stream(self.stream_,
Expand Down
7 changes: 5 additions & 2 deletions pescador/maps.py
Expand Up @@ -134,10 +134,13 @@ def keras_tuples(stream, inputs=None, outputs=None):
------
DataError if the stream contains items that are not data-like.
"""
flatten_inputs, flatten_outputs = False, False
if inputs and isinstance(inputs, six.string_types):
inputs = [inputs]
flatten_inputs = True
if outputs and isinstance(outputs, six.string_types):
outputs = [outputs]
flatten_outputs = True

inputs, outputs = (inputs or []), (outputs or [])
if not inputs + outputs:
Expand All @@ -147,11 +150,11 @@ def keras_tuples(stream, inputs=None, outputs=None):
for data in stream:
try:
x = list(data[key] for key in inputs) or None
if len(inputs) == 1:
if len(inputs) == 1 and flatten_inputs:
x = x[0]

y = list(data[key] for key in outputs) or None
if len(outputs) == 1:
if len(outputs) == 1 and flatten_outputs:
y = y[0]

yield (x, y)
Expand Down
2 changes: 1 addition & 1 deletion pescador/mux.py
Expand Up @@ -48,7 +48,7 @@ def __init__(self, streamers, k,
k : int > 0
The number of streams to keep active at any time.
lam : float > 0 or None
rate : float > 0 or None
Rate parameter for the Poisson distribution governing sample counts
for individual streams.
If ``None``, sample infinitely from each stream.
Expand Down
5 changes: 3 additions & 2 deletions tests/test_buffered.py
@@ -1,4 +1,5 @@
#!/usr/bin/env python
# TODO: Remove these tests with the `buffered.py` submodule at 2.0 release.

import pytest
import numpy as np
Expand All @@ -11,7 +12,7 @@
@pytest.mark.parametrize('dimension', [1, 2, 3])
@pytest.mark.parametrize('batch_size', [1, 2, 5, 17])
@pytest.mark.parametrize('buf_size', [1, 2, 5, 17, 100])
def test_buffer_streamer(dimension, batch_size, buf_size):
def test_BufferedStreamer(dimension, batch_size, buf_size):

key = 'X'

Expand Down Expand Up @@ -44,7 +45,7 @@ def __unpack_stream(stream):
@pytest.mark.parametrize('dimension', [1, 2, 3])
@pytest.mark.parametrize('batch_size', [1, 2, 5, 17])
@pytest.mark.parametrize('buf_size', [1, 2, 5, 17, 100])
def test_buffer_streamer_tuple(dimension, batch_size, buf_size, items):
def test_BufferedStreamer_tuples(dimension, batch_size, buf_size, items):

gen_stream = pescador.Streamer(T.md_generator, dimension, 50,
size=batch_size, items=items)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_maps.py
Expand Up @@ -84,6 +84,11 @@ def test_keras_tuples(sample_data):
assert n == x ** 0.5
assert y is None

stream = pescador.maps.keras_tuples(sample_data, inputs=["bang"])
for n, (x, y) in enumerate(stream):
assert n == x[0] ** 0.5
assert y is None

stream = pescador.maps.keras_tuples(sample_data, inputs=["foo", "bang"],
outputs=["bar", "whiz"])
for n, (x, y) in enumerate(stream):
Expand Down

0 comments on commit b65d43f

Please sign in to comment.