Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1.0.0 labels #263

Merged
merged 15 commits into from
Apr 5, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 18 additions & 1 deletion test/test_images.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from numpy import arange, allclose, array, mean, apply_along_axis

from thunder.images.readers import fromlist
from thunder.images.readers import fromlist, fromarray
from thunder.images.images import Images
from thunder.series.series import Series

Expand Down Expand Up @@ -31,6 +31,23 @@ def test_sample(eng):
assert allclose(data.sample(1).shape, (1, 2, 2))
assert allclose(data.filter(lambda x: x.max() > 5).sample(1).toarray(), [[1, 10], [1, 10]])

def test_labels(eng):
x = arange(10).reshape(10, 1, 1)
data = fromlist(x, labels=range(10), engine=eng)

assert allclose(data.filter(lambda x: x[0, 0]%2==0).labels, array([0, 2, 4, 6, 8]))
assert allclose(data[4:6].labels, array([4, 5]))
assert allclose(data[5].labels, array([5]))
assert allclose(data[[0, 3, 8]].labels, array([0, 3, 8]))


def test_labels_setting(eng):
x = arange(10).reshape(10, 1, 1)
data = fromlist(x, engine=eng)

with pytest.raises(ValueError):
data.labels = range(8)


def test_first(eng):
data = fromlist([array([[1, 5], [1, 5]]), array([[1, 10], [1, 10]])], engine=eng)
Expand Down
30 changes: 29 additions & 1 deletion test/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from numpy import allclose, arange, array, asarray, dot, cov, corrcoef

from thunder.series.readers import fromlist, fromarray
from thunder.images.readers import fromlist as img_fromlist

pytestmark = pytest.mark.usefixtures("eng")

Expand Down Expand Up @@ -174,6 +175,33 @@ def test_min(eng):
expected = data.toarray().min(axis=0)
assert allclose(val, expected)

def test_labels(eng):
x = [array([0, 1]), array([2, 3]), array([4, 5]), array([6, 7])]
data = fromlist(x, labels=[0, 1, 2, 3], engine=eng)

assert allclose(data.filter(lambda x: x[0]>2).labels, array([2, 3]))
assert allclose(data[2:].labels, array([2, 3]))
assert allclose(data[1].labels, array([1]))
assert allclose(data[1, :].labels, array([1]))
assert allclose(data[[0, 2]].labels, array([0, 2]))
assert allclose(data.flatten().labels, array([0, 1, 2, 3]))

x = [array([[0, 1],[2, 3]]), array([[4, 5], [6, 7]])]
data = img_fromlist(x, engine=eng).toseries()
data.labels = [[0, 1], [2, 3]]

assert allclose(data.filter(lambda x: x[0]>1).labels, array([2, 3]))
assert allclose(data[0].labels, array([[0, 1]]))
assert allclose(data[:, 0].labels, array([[0], [2]]))
assert allclose(data.flatten().labels, array([0, 1, 2, 3]))

def test_labels_setting(eng):
x = [array([0, 1]), array([2, 3]), array([4, 5]), array([6, 7])]
data = fromlist(x, engine=eng)

with pytest.raises(ValueError):
data.labels = [0, 1, 2]


def test_index_setting(eng):
data = fromlist([array([1, 2, 3]), array([2, 2, 4]), array([4, 2, 1])], engine=eng)
Expand Down Expand Up @@ -410,4 +438,4 @@ def test_mean_by_window(eng):
test3 = data.mean_by_window(indices=[3, 5], window=4).toarray()
assert allclose(test3, [2, 3, 4, 5])
test4 = data.mean_by_window(indices=[3], window=4).toarray()
assert allclose(test4, [1, 2, 3, 4])
assert allclose(test4, [1, 2, 3, 4])
89 changes: 73 additions & 16 deletions thunder/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from numpy import array, asarray, ndarray, prod, ufunc, add, subtract, \
multiply, divide, isscalar, newaxis, unravel_index
multiply, divide, isscalar, newaxis, unravel_index, argsort
from bolt.utils import inshape, tupleize
from bolt.base import BoltArray
from bolt.spark.array import BoltArraySpark
Expand Down Expand Up @@ -189,14 +189,61 @@ class Data(Base):
functions along axes in a backend-specific manner.
"""
_metadata = Base._metadata
_attributes = Base._attributes + ['labels']

def __getitem__(self, item):
# handle values
if isinstance(item, int):
item = slice(item, item+1, None)
if isinstance(item, (tuple, list)):
if isinstance(item, tuple):
item = tuple([slice(i, i+1, None) if isinstance(i, int) else i for i in item])
if isinstance(item, (list, ndarray)):
item = (item,)
new = self._values.__getitem__(item)
return self._constructor(new).__finalize__(self, noprop=('index'))
result = self._constructor(new).__finalize__(self, noprop=('index', 'labels'))

# handle labels
if self.labels is not None:
if isinstance(item, int):
label_item = ([item],)
elif isinstance(item, (list, ndarray, slice)):
label_item = (item, )
elif isinstance(item, tuple):
label_item = item[:len(self.baseaxes)]
newlabels = self.labels
for (i, s) in enumerate(label_item):
if isinstance(s, slice):
newlabels = newlabels[[s if j==i else slice(None) for j in range(len(label_item))]]
else:
newlabels = newlabels.take(tupleize(s), i)
result.labels = newlabels

return result

@property
def baseaxes(self):
raise NotImplementedError

@property
def baseshape(self):
return self.shape[:len(self.baseaxes)]

@property
def labels(self):
return self._labels

@labels.setter
def labels(self, value):
if value is not None:
try:
value = asarray(value)
except:
raise ValueError("Labels must be convertible to an ndarray")
if value.shape != self.baseshape:
raise ValueError("Labels shape {} must be the same as the leading dimensions of the Series {}"\
.format(value.shape, self.baseshape))

self._labels = value

def astype(self, dtype, casting='unsafe'):
"""
Expand Down Expand Up @@ -273,12 +320,6 @@ def min(self):
"""
raise NotImplementedError

def filter(self, func):
"""
Filter elements.
"""
raise NotImplementedError

def map(self, func, **kwargs):
"""
Map a function over elements.
Expand Down Expand Up @@ -309,14 +350,14 @@ def _align(self, axes, key_shape=None):
linearized_shape = [prod(key_shape)] + remaining_shape

# compute the transpose permutation
transpose_order = axes + remaining
transpose_order = list(axes) + remaining

# transpose the array so that the keys being mapped over come first, then linearize keys
reshaped = self.values.transpose(*transpose_order).reshape(*linearized_shape)

return reshaped

def _filter(self, func, axis=(0,)):
def filter(self, func):
"""
Filter array along an axis.

Expand All @@ -333,15 +374,31 @@ def _filter(self, func, axis=(0,)):
axis : tuple or int, optional, default=(0,)
Axis or multiple axes to filter along.
"""

if self.mode == 'local':
axes = sorted(tupleize(axis))
reshaped = self._align(axes)
reshaped = self._align(self.baseaxes)
filtered = asarray(list(filter(func, reshaped)))
return self._constructor(filtered).__finalize__(self)

if self.labels is not None:
mask = asarray(list(map(func, reshaped)))

if self.mode == 'spark':
filtered = self.values.filter(func)
return self._constructor(filtered).__finalize__(self)

sort = False if self.labels is None else True
filtered = self.values.filter(func, axis=self.baseaxes, sort=sort)

if self.labels is not None:
keys, vals = zip(*self.values.map(func, axis=self.baseaxes, value_shape=(1,)).tordd().collect())
perm = sorted(range(len(keys)), key=keys.__getitem__)
mask = asarray(vals)[perm]

if self.labels is not None:
s1 = prod(self.baseshape)
newlabels = self.labels.reshape(s1, 1)[mask].squeeze()
else:
newlabels = None

return self._constructor(filtered, labels=newlabels).__finalize__(self, noprop=('labels',))

def _map(self, func, axis=(0,), value_shape=None, dtype=None, with_keys=False):
"""
Expand Down
13 changes: 6 additions & 7 deletions thunder/images/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ class Images(Data):
"""
_metadata = Data._metadata

def __init__(self, values, mode='local'):
def __init__(self, values, labels=None, mode='local'):
super(Images, self).__init__(values, mode=mode)
self.labels = labels

@property
def baseaxes(self):
return (0,)

@property
def _constructor(self):
Expand Down Expand Up @@ -166,12 +171,6 @@ def map(self, func, dims=None, with_keys=False):
"""
return self._map(func, axis=0, value_shape=dims, with_keys=with_keys)

def filter(self, func):
"""
Filter images
"""
return self._filter(func, axis=0)

def reduce(self, func):
"""
Reduce over images
Expand Down