Skip to content

Commit

Permalink
additional unit tests
Browse files Browse the repository at this point in the history
NOTE: back-compatibility issue: dimarray.align([a, b, c]) instead of align(a, b, c)
  • Loading branch information
perrette committed Nov 28, 2015
1 parent 715da85 commit 84e50bd
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 68 deletions.
73 changes: 48 additions & 25 deletions dimarray/core/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,11 @@ def _get_aligned_axes(arrays, join='outer', axis=None , sort=False, strict=False

axes.append(ax)

# assert len(axes) > 0

return axes

def align(*arrays, **kwargs):
def align(arrays, join='outer', axis=None , sort=False, strict=False):
"""Align axes of a list of DimArray arrays by reindexing
Parameters
Expand Down Expand Up @@ -241,13 +243,13 @@ def align(*arrays, **kwargs):
>>> from dimarray import DimArray, align
>>> a = DimArray([0,1,2],axes=[[0,1,2]])
>>> b = DimArray([1,2,3],axes=[[1,2,3]])
>>> align(a, b)
>>> align([a, b])
[dimarray: 3 non-null elements (1 null)
0 / x0 (4): 0 to 3
array([ 0., 1., 2., nan]), dimarray: 3 non-null elements (1 null)
0 / x0 (4): 0 to 3
array([ nan, 1., 2., 3.])]
>>> align(a, b, join='inner')
>>> align([a, b], join='inner')
[dimarray: 2 non-null elements (0 null)
0 / x0 (2): 1 to 2
array([1, 2]), dimarray: 2 non-null elements (0 null)
Expand All @@ -258,7 +260,7 @@ def align(*arrays, **kwargs):
>>> a = DimArray([0,1], axes=[[0,1]]) # on 'x0' only
>>> b = DimArray([[0,1],[2,3.],[4.,5.]], axes=[[0,1,2],[1,2]]) # one more element along the 1st dimension, 2nd dimension ignored
>>> align(a, b)
>>> align([a, b])
[dimarray: 2 non-null elements (1 null)
0 / x0 (3): 0 to 2
array([ 0., 1., nan]), dimarray: 6 non-null elements (0 null)
Expand All @@ -268,31 +270,31 @@ def align(*arrays, **kwargs):
[ 2., 3.],
[ 4., 5.]])]
"""
join = kwargs.pop('join', get_option('align.join'))
sort = kwargs.pop('sort', False)
axis = kwargs.pop('axis', None)
strict = kwargs.pop('strict', False)
if len(kwargs) > 0:
raise TypeError("align() got unexpected argument(s): "+", ".join(kwargs.keys()))
# join = kwargs.pop('join', get_option('align.join'))
# sort = kwargs.pop('sort', False)
# axis = kwargs.pop('axis', None)
# strict = kwargs.pop('strict', False)
# if len(kwargs) > 0:
# raise TypeError("align() got unexpected argument(s): "+", ".join(kwargs.keys()))
if not (isinstance(arrays, list) or isinstance(arrays, tuple)):
raise ValueError("align: only accepts list or tuple arguments. Got: {}".format(type(arrays)))

# convert any scalar to dimarray
from dimarray import DimArray, Dataset
arrays = list(arrays)
arrays = [a for a in arrays] # convert to list
for i, a in enumerate(arrays):
if not isinstance(a, DimArray):
if not isinstance(a, DimArray) and not isinstance(a, Dataset):
if np.isscalar(a):
arrays[i] = DimArray(a)
elif isinstance(a, Dataset):
pass
else:
raise TypeError("can only align DimArray and Dataset instances, got: {}".format(type(a)))

# find the common axes
axes = _get_aligned_axes(arrays, join=join, sort=sort, strict=strict)
axes = _get_aligned_axes(arrays, axis=axis, join=join, sort=sort, strict=strict)

# update arrays
for i, o in enumerate(arrays):
for ax in axes:
for ax in axes:
for i, o in enumerate(arrays):
if ax.name not in o.dims:
continue
if np.all(o.axes[ax.name] == ax):
Expand Down Expand Up @@ -392,7 +394,7 @@ def stack(arrays, axis=None, keys=None, align=False, **kwargs):
# re-index axes if needed
if align:
kwargs['strict'] = True
arrays = align_(*arrays, **kwargs)
arrays = align_(arrays, **kwargs)

# make it a numpy array
data = [a.values for a in arrays]
Expand Down Expand Up @@ -493,7 +495,21 @@ def concatenate(arrays, axis=0, _no_check=False, align=False, **kwargs):
array([[ 1, 2, 3, 4, 5, 6],
[11, 22, 33, 44, 55, 66]])
"""
assert type(arrays) in (list, tuple), "arrays must be list or tuple, got {}:{}".format(type(arrays), arrays)
# input argument check
if not type(arrays) in (list, tuple):
raise ValueError("arrays must be list or tuple, got {}:{}".format(type(arrays), arrays))
arrays = [a for a in arrays]

from dimarray import DimArray, Dataset

for i, a in enumerate(arrays):
if isinstance(a, Dataset):
msg = "\n==>Note: you may use `concatenate_ds` for Datasets"
raise ValueError("concatenate: expected DimArray. Got {}".format(type(a))+msg)
elif np.isscalar(a):
arrays[i] = DimArray(a)
if not isinstance(a, DimArray):
raise ValueError("concatenate: expected DimArray. Got {}".format(type(a)))

if type(axis) is not int:
axis = arrays[0].dims.index(axis)
Expand All @@ -506,7 +522,7 @@ def concatenate(arrays, axis=0, _no_check=False, align=False, **kwargs):
kwargs['strict'] = True
for ax in arrays[0].axes:
if ax.name != dim:
arrays = align_(arrays, **kwargs)
arrays = align_(arrays, axis=ax.name, **kwargs)

values = np.concatenate([a.values for a in arrays], axis=axis)

Expand All @@ -518,11 +534,18 @@ def concatenate(arrays, axis=0, _no_check=False, align=False, **kwargs):

if not align and not _no_check:
# check that other axes match
for i,a in enumerate(arrays[1:]):
if not _get_subaxes(a) == subaxes:
msg = "First array:\n{}\n".format(subaxes)
msg += "{}th array:\n{}\n".format(i,_get_subaxes(a))
raise ValueError("contatenate: secondary axes do not match. Align first? (`align=True`)")
for ax in subaxes:
for a in arrays:
if not np.all(a.axes[ax.name].values == ax.values):
raise ValueError("contatenate: secondary axes do not match. Align first? (`align=True`)")
# print arrays[0]
# for i,a in enumerate(arrays[1:]):
# if not _get_subaxes(a) == subaxes:
# msg = "First array:\n{}\n".format(subaxes)
# msg += "{}th array:\n{}\n".format(i,_get_subaxes(a))
# raise ValueError("contatenate: secondary axes do not match. Align first? (`align=True`)")
# print a
# print '==> arrays look ok'

newaxes = subaxes[:axis] + [newaxis] + subaxes[axis:]

Expand Down
2 changes: 1 addition & 1 deletion dimarray/core/dimarraycls.py
Original file line number Diff line number Diff line change
Expand Up @@ -1700,7 +1700,7 @@ def array(data, *args, **kwargs):
broadcast = kwargs.pop('broadcast', True)

if reindex:
data = align(*data)
data = align(data)

if broadcast:
data = broadcast_arrays(*data) # make sure the arrays have the same dimension
Expand Down
2 changes: 1 addition & 1 deletion dimarray/core/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def operation(func, o1, o2, reindex=True, broadcast=True, constructor=None):

# Align axes by re-indexing
if reindex:
o1, o2 = align_axes(o1, o2)
o1, o2 = align_axes((o1, o2))

# Align dimensions by adding new axes and transposing if necessary
if broadcast:
Expand Down
10 changes: 7 additions & 3 deletions dimarray/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(self, *args, **kwargs):
values[i] = self._constructor(values[i])

# Align objects
values = align_axes(*values)
values = align_axes(values)

# Append object (will automatically update self.axes)
for key, value in zip(keys, values):
Expand Down Expand Up @@ -823,7 +823,7 @@ def stack_ds(datasets, axis, keys=None, align=False, **kwargs):
axis = _check_stack_axis(axis, dims)

if align:
datasets = da.align(*datasets, **kwargs)
datasets = da.align(datasets, strict=True, **kwargs)

# find the list of variables common to all datasets
variables = None
Expand Down Expand Up @@ -893,7 +893,11 @@ def concatenate_ds(datasets, axis=0, align=False, **kwargs):
assert sorted(ds.keys()) == sorted(variables), "variables differ across datasets"

if align:
datasets = da.align(*datasets, **kwargs)
# all dataset axes
axis_nm = datasets[0].axes[axis].name
aligned_dims = [d for d in _get_dims(*datasets) if d != axis_nm]
for d in aligned_dims:
datasets = da.align(datasets, axis=d, strict=True, **kwargs)

# Compute concatenated dataset
dataset = Dataset()
Expand Down
4 changes: 2 additions & 2 deletions docs/_notebooks_rst/reindexing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ It is also possible to proceed to axis alignment on a sequence of arrays
>>> from dimarray import align
>>> x = DimArray([1,2,3],('x0',[1,2,3]))
>>> y = DimArray([3,4],('x0',[2,4]))
>>> xa, ya = align(x, y)
>>> xa, ya = align((x, y))
>>> ya
dimarray: 2 non-null elements (2 null)
0 / x0 (4): 1 to 4
array([ nan, 3., nan, 4.])

See :func:`dimarray.align`
See :func:`dimarray.align`
50 changes: 35 additions & 15 deletions tests/test_align.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import numpy as np
from dimarray import DimArray
from dimarray.testing import assert_equal_dimarrays
from dimarray import align, stack, concatenate
from dimarray import align, stack, concatenate, Dataset, stack_ds, concatenate_ds
import pytest

def _make_datasets(*arrays):
# extend the tests to datasets with one array
return [Dataset({'a':a}) for a in arrays]

def test_align():
" test slices in an increasing axis"
a = DimArray([1, 2, 3, 4], axes=[[1, 2, 3, 4]], dims=['dim0'])
Expand All @@ -14,7 +18,7 @@ def test_align():
a2 = DimArray([np.nan, 1, 2, 3, 4, np.nan], axes=[[0, 1, 2, 3, 4, 6]], dims=['dim0'])
b2 = DimArray([0, np.nan, 2, np.nan, 4, 6], axes=[[0, 1, 2, 3, 4, 6]], dims=['dim0'])
# got
a2_got, b2_got = align(a, b, join="outer")
a2_got, b2_got = align([a, b], join="outer")
# check
assert_equal_dimarrays(a2, a2_got)
assert_equal_dimarrays(b2, b2_got)
Expand All @@ -24,7 +28,7 @@ def test_align():
a3 = DimArray([2, 4], axes=[[2,4]], dims=['dim0'])
b3 = DimArray([2, 4], axes=[[2,4]], dims=['dim0'])
# got
a3_got, b3_got = align(a, b, join="inner")
a3_got, b3_got = align([a, b], join="inner")
# check
assert_equal_dimarrays(a3, a3_got)
assert_equal_dimarrays(b3, b3_got)
Expand All @@ -38,13 +42,13 @@ def test_align_unsorted():
b_mess = DimArray([22, 44, 11], axes=[[2, 4, 1]], dims=['dim0'])

# second array unsorted
a_got, b_got = align(a_sorted, b_mess, join="outer")
a_got, b_got = align([a_sorted, b_mess], join="outer")

assert_equal_dimarrays(a_got, a_sorted)
assert_equal_dimarrays(b_got, b_sorted)

# first array unsorted
b_got, a_got = align(b_mess, a_sorted, join="outer")
b_got, a_got = align([b_mess, a_sorted], join="outer")

assert not np.all(a_got.dim0 == a_sorted.dim0) # not equal because not ordered

Expand All @@ -55,10 +59,10 @@ def test_align_unsorted():
assert_equal_dimarrays(b_got, b_sorted)

# do the same, but pass as command line
a_got, b_got = align(a_sorted, b_mess, join="outer", sort=True)
a_got, b_got = align([a_sorted, b_mess], join="outer", sort=True)

# two arrays unsorted
a_got, b_got = align(a_mess, b_mess, join="outer", sort=True)
a_got, b_got = align([a_mess, b_mess], join="outer", sort=True)

assert_equal_dimarrays(a_got, a_sorted)
assert_equal_dimarrays(b_got, b_sorted)
Expand All @@ -71,19 +75,23 @@ def test_stack():
[11, 22, 33]], axes=[['a', 'b'], [0, 1, 2]], dims=['stackdim', 'x0'])

c_got = stack([a, b], axis='stackdim', keys=['a','b'])
c_got_ds = stack_ds(_make_datasets(a, b), axis='stackdim', keys=['a','b'])

assert_equal_dimarrays(c_got, c)
assert_equal_dimarrays(c_got_ds['a'], c)

def test_stack_align():
a = DimArray([1,2,3], axes=[[0,1,2]], dims=['x0'])
b = DimArray([33,11], axes=[[2,0]], dims=['x0'])

c_got = stack([b, a], axis='stackdim', align=True, sort=True, keys=['a','b'])
c_got_ds = stack_ds(_make_datasets(b, a), axis='stackdim', align=True, sort=True, keys=['a','b'])

c = DimArray([[11., np.nan, 33.],
[ 1., 2., 3.]], axes=[['a', 'b'], [0, 1, 2]], dims=['stackdim', 'x0'])

assert_equal_dimarrays(c_got, c)
assert_equal_dimarrays(c_got_ds['a'], c)

def test_stack_fails():
# Should use concatenate instead, because axis is not new !
Expand All @@ -102,29 +110,41 @@ def test_concatenate_1d():
# c = DimArray([1, 2, 3, 4, 5, 6], axes=[['a','b','c','d','e','f']])

c_got = concatenate((a, b))
c_got_ds = concatenate_ds(_make_datasets(a, b))

assert_equal_dimarrays(c_got, c)
assert_equal_dimarrays(c_got_ds['a'], c)

def test_concatenate_2d():

a = DimArray([[1, 2, 3],
[11, 22, 33]], axes=[[0,1],[0,1,2]])
a = DimArray([[ 1, 2, 3],
[11, 22, 33]], axes=[[0,1],[2,1,0]])

b = DimArray([[4, 5, 6],
[44, 55, 66]], axes=[[0,1],[0,1,2]])
b = DimArray([[44, 55, 66],
[4, 5, 6]], axes=[[1,0],[2,1,0]])

c0_got = concatenate((a, b), axis=0)
c0_got_ds = concatenate_ds(_make_datasets(a, b), axis=0)

c0 = DimArray([[ 1, 2, 3],
[11, 22, 33],
[ 4, 5, 6],
[44, 55, 66]], axes=[[0,1,0,1],[0,1,2]])
[44, 55, 66],
[4, 5, 6]], axes=[[0,1,1,0],[2,1,0]])

assert_equal_dimarrays(c0_got, c0)
assert_equal_dimarrays(c0_got_ds['a'], c0)

c1_got = concatenate((a, b), axis=1)
# axis "x0" is not aligned !
with pytest.raises(ValueError):
c1_got = concatenate((a, b), axis=1)
print c1_got

c1_got = concatenate((a, b), axis=1, align=True, sort=True)
c1_got_ds = concatenate_ds(_make_datasets(a, b), axis=1, align=True, sort=True)

c1 = DimArray([[ 1, 2, 3, 4, 5, 6],
[11, 22, 33, 44, 55, 66]], axes=[[0,1],[0,1,2,0,1,2]])
[11, 22, 33, 44, 55, 66]], axes=[[0,1],[2,1,0,2,1,0]])

assert_equal_dimarrays(c1_got, c1)
assert_equal_dimarrays(c1_got_ds['a'], c1)

0 comments on commit 84e50bd

Please sign in to comment.