Skip to content

Commit

Permalink
Merge pull request #263 from jwittenbach/1.0.0-labels
Browse files Browse the repository at this point in the history
1.0.0 labels
  • Loading branch information
jwittenbach committed Apr 5, 2016
2 parents 08e8ece + d468976 commit ef6d3a6
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 66 deletions.
19 changes: 18 additions & 1 deletion test/test_images.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from numpy import arange, allclose, array, mean, apply_along_axis

from thunder.images.readers import fromlist
from thunder.images.readers import fromlist, fromarray
from thunder.images.images import Images
from thunder.series.series import Series

Expand Down Expand Up @@ -31,6 +31,23 @@ def test_sample(eng):
assert allclose(data.sample(1).shape, (1, 2, 2))
assert allclose(data.filter(lambda x: x.max() > 5).sample(1).toarray(), [[1, 10], [1, 10]])

def test_labels(eng):
x = arange(10).reshape(10, 1, 1)
data = fromlist(x, labels=range(10), engine=eng)

assert allclose(data.filter(lambda x: x[0, 0]%2==0).labels, array([0, 2, 4, 6, 8]))
assert allclose(data[4:6].labels, array([4, 5]))
assert allclose(data[5].labels, array([5]))
assert allclose(data[[0, 3, 8]].labels, array([0, 3, 8]))


def test_labels_setting(eng):
x = arange(10).reshape(10, 1, 1)
data = fromlist(x, engine=eng)

with pytest.raises(ValueError):
data.labels = range(8)


def test_first(eng):
data = fromlist([array([[1, 5], [1, 5]]), array([[1, 10], [1, 10]])], engine=eng)
Expand Down
30 changes: 29 additions & 1 deletion test/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from numpy import allclose, arange, array, asarray, dot, cov, corrcoef

from thunder.series.readers import fromlist, fromarray
from thunder.images.readers import fromlist as img_fromlist

pytestmark = pytest.mark.usefixtures("eng")

Expand Down Expand Up @@ -174,6 +175,33 @@ def test_min(eng):
expected = data.toarray().min(axis=0)
assert allclose(val, expected)

def test_labels(eng):
x = [array([0, 1]), array([2, 3]), array([4, 5]), array([6, 7])]
data = fromlist(x, labels=[0, 1, 2, 3], engine=eng)

assert allclose(data.filter(lambda x: x[0]>2).labels, array([2, 3]))
assert allclose(data[2:].labels, array([2, 3]))
assert allclose(data[1].labels, array([1]))
assert allclose(data[1, :].labels, array([1]))
assert allclose(data[[0, 2]].labels, array([0, 2]))
assert allclose(data.flatten().labels, array([0, 1, 2, 3]))

x = [array([[0, 1],[2, 3]]), array([[4, 5], [6, 7]])]
data = img_fromlist(x, engine=eng).toseries()
data.labels = [[0, 1], [2, 3]]

assert allclose(data.filter(lambda x: x[0]>1).labels, array([2, 3]))
assert allclose(data[0].labels, array([[0, 1]]))
assert allclose(data[:, 0].labels, array([[0], [2]]))
assert allclose(data.flatten().labels, array([0, 1, 2, 3]))

def test_labels_setting(eng):
x = [array([0, 1]), array([2, 3]), array([4, 5]), array([6, 7])]
data = fromlist(x, engine=eng)

with pytest.raises(ValueError):
data.labels = [0, 1, 2]


def test_index_setting(eng):
data = fromlist([array([1, 2, 3]), array([2, 2, 4]), array([4, 2, 1])], engine=eng)
Expand Down Expand Up @@ -410,4 +438,4 @@ def test_mean_by_window(eng):
test3 = data.mean_by_window(indices=[3, 5], window=4).toarray()
assert allclose(test3, [2, 3, 4, 5])
test4 = data.mean_by_window(indices=[3], window=4).toarray()
assert allclose(test4, [1, 2, 3, 4])
assert allclose(test4, [1, 2, 3, 4])
89 changes: 73 additions & 16 deletions thunder/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from numpy import array, asarray, ndarray, prod, ufunc, add, subtract, \
multiply, divide, isscalar, newaxis, unravel_index
multiply, divide, isscalar, newaxis, unravel_index, argsort
from bolt.utils import inshape, tupleize
from bolt.base import BoltArray
from bolt.spark.array import BoltArraySpark
Expand Down Expand Up @@ -189,14 +189,61 @@ class Data(Base):
functions along axes in a backend-specific manner.
"""
_metadata = Base._metadata
_attributes = Base._attributes + ['labels']

def __getitem__(self, item):
# handle values
if isinstance(item, int):
item = slice(item, item+1, None)
if isinstance(item, (tuple, list)):
if isinstance(item, tuple):
item = tuple([slice(i, i+1, None) if isinstance(i, int) else i for i in item])
if isinstance(item, (list, ndarray)):
item = (item,)
new = self._values.__getitem__(item)
return self._constructor(new).__finalize__(self, noprop=('index'))
result = self._constructor(new).__finalize__(self, noprop=('index', 'labels'))

# handle labels
if self.labels is not None:
if isinstance(item, int):
label_item = ([item],)
elif isinstance(item, (list, ndarray, slice)):
label_item = (item, )
elif isinstance(item, tuple):
label_item = item[:len(self.baseaxes)]
newlabels = self.labels
for (i, s) in enumerate(label_item):
if isinstance(s, slice):
newlabels = newlabels[[s if j==i else slice(None) for j in range(len(label_item))]]
else:
newlabels = newlabels.take(tupleize(s), i)
result.labels = newlabels

return result

@property
def baseaxes(self):
raise NotImplementedError

@property
def baseshape(self):
return self.shape[:len(self.baseaxes)]

@property
def labels(self):
return self._labels

@labels.setter
def labels(self, value):
if value is not None:
try:
value = asarray(value)
except:
raise ValueError("Labels must be convertible to an ndarray")
if value.shape != self.baseshape:
raise ValueError("Labels shape {} must be the same as the leading dimensions of the Series {}"\
.format(value.shape, self.baseshape))

self._labels = value

def astype(self, dtype, casting='unsafe'):
"""
Expand Down Expand Up @@ -273,12 +320,6 @@ def min(self):
"""
raise NotImplementedError

def filter(self, func):
"""
Filter elements.
"""
raise NotImplementedError

def map(self, func, **kwargs):
"""
Map a function over elements.
Expand Down Expand Up @@ -309,14 +350,14 @@ def _align(self, axes, key_shape=None):
linearized_shape = [prod(key_shape)] + remaining_shape

# compute the transpose permutation
transpose_order = axes + remaining
transpose_order = list(axes) + remaining

# transpose the array so that the keys being mapped over come first, then linearize keys
reshaped = self.values.transpose(*transpose_order).reshape(*linearized_shape)

return reshaped

def _filter(self, func, axis=(0,)):
def filter(self, func):
"""
Filter array along an axis.
Expand All @@ -333,15 +374,31 @@ def _filter(self, func, axis=(0,)):
axis : tuple or int, optional, default=(0,)
Axis or multiple axes to filter along.
"""

if self.mode == 'local':
axes = sorted(tupleize(axis))
reshaped = self._align(axes)
reshaped = self._align(self.baseaxes)
filtered = asarray(list(filter(func, reshaped)))
return self._constructor(filtered).__finalize__(self)

if self.labels is not None:
mask = asarray(list(map(func, reshaped)))

if self.mode == 'spark':
filtered = self.values.filter(func)
return self._constructor(filtered).__finalize__(self)

sort = False if self.labels is None else True
filtered = self.values.filter(func, axis=self.baseaxes, sort=sort)

if self.labels is not None:
keys, vals = zip(*self.values.map(func, axis=self.baseaxes, value_shape=(1,)).tordd().collect())
perm = sorted(range(len(keys)), key=keys.__getitem__)
mask = asarray(vals)[perm]

if self.labels is not None:
s1 = prod(self.baseshape)
newlabels = self.labels.reshape(s1, 1)[mask].squeeze()
else:
newlabels = None

return self._constructor(filtered, labels=newlabels).__finalize__(self, noprop=('labels',))

def _map(self, func, axis=(0,), value_shape=None, dtype=None, with_keys=False):
"""
Expand Down
13 changes: 6 additions & 7 deletions thunder/images/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ class Images(Data):
"""
_metadata = Data._metadata

def __init__(self, values, mode='local'):
def __init__(self, values, labels=None, mode='local'):
super(Images, self).__init__(values, mode=mode)
self.labels = labels

@property
def baseaxes(self):
return (0,)

@property
def _constructor(self):
Expand Down Expand Up @@ -166,12 +171,6 @@ def map(self, func, dims=None, with_keys=False):
"""
return self._map(func, axis=0, value_shape=dims, with_keys=with_keys)

def filter(self, func):
"""
Filter images
"""
return self._filter(func, axis=0)

def reduce(self, func):
"""
Reduce over images
Expand Down

0 comments on commit ef6d3a6

Please sign in to comment.