Skip to content

Commit

Permalink
Fix issue with expand_dims on 0d arrays of tuples (#867)
Browse files Browse the repository at this point in the history
* Fix issue with expand_dims on 0d arrays of tuples

This was originally reported on the mailing list:
https://groups.google.com/forum/#!topic/xarray/fz7HHgpgwk0

Unfortunately, the fix requires a backwards compatibility break, because it
changes how 0d object arrays are handled:

In [4]: xr.Variable([], object()).values
Out[4]: array(<object object at 0x10072e2e0>, dtype=object)

Previously, we just returned the original object.

* Add another test

* Fix issue with squeeze() on object arrays of lists

* another test

* another test
  • Loading branch information
shoyer committed Jul 29, 2016
1 parent 6aedf98 commit 5433d9d
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 45 deletions.
15 changes: 12 additions & 3 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,16 @@ Breaking changes
~~~~~~~~~~~~~~~~

- Dropped support for Python 2.6 (:issue:`855`).
- Indexing on multi-index now drop levels, which is consitent with pandas.
- Indexing on multi-index now drops levels, which is consistent with pandas.
It also changes the name of the dimension / coordinate when the multi-index is
reduced to a single index.
- Contour plots no longer add a colorbar per default (:issue:`866`).
reduced to a single index (:issue:`802`).
- Contour plots no longer add a colorbar per default (:issue:`866`). Filled
contour plots are unchanged.
- ``DataArray.values`` and ``.data`` now always return a NumPy array-like
object, even for 0-dimensional arrays with object dtype (:issue:`867`).
Previously, ``.values`` returned native Python objects in such cases. To
convert the values of scalar arrays to Python objects, use the ``.item()``
method.

Enhancements
~~~~~~~~~~~~
Expand Down Expand Up @@ -102,6 +108,9 @@ Bug fixes
- ``Variable.copy(deep=True)`` no longer converts MultiIndex into a base Index
(:issue:`769`). By `Benoit Bovy <https://github.com/benbovy>`_.

- Fixes for groupby on dimensions with a multi-index (:issue:`873`). By
`Stephan Hoyer <https://github.com/shoyer>`_.

- Fix printing datasets with unicode attributes on Python 2 (:issue:`892`). By
`Stephan Hoyer <https://github.com/shoyer>`_.

Expand Down
39 changes: 27 additions & 12 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _asarray_tuplesafe(values):
Adapted from pandas.core.common._asarray_tuplesafe
"""
if isinstance(values, tuple):
result = utils.tuple_to_0darray(values)
result = utils.to_0d_object_array(values)
else:
result = np.asarray(values)
if result.ndim == 2:
Expand Down Expand Up @@ -396,9 +396,18 @@ def _convert_key(self, key):
key = orthogonal_indexer(key, self.shape)
return key

def _ensure_ndarray(self, value):
# We always want the result of indexing to be a NumPy array. If it's
# not, then it really should be a 0d array. Doing the coercion here
# instead of inside variable.as_compatible_data makes it less error
# prone.
if not isinstance(value, np.ndarray):
value = utils.to_0d_array(value)
return value

def __getitem__(self, key):
key = self._convert_key(key)
return self.array[key]
return self._ensure_ndarray(self.array[key])

def __setitem__(self, key, value):
key = self._convert_key(key)
Expand Down Expand Up @@ -469,16 +478,22 @@ def __getitem__(self, key):

if isinstance(result, pd.Index):
result = PandasIndexAdapter(result, dtype=self.dtype)
elif result is pd.NaT:
# work around the impossibility of casting NaT with asarray
# note: it probably would be better in general to return
# pd.Timestamp rather than np.datetime64 but this is easier
# (for now)
result = np.datetime64('NaT', 'ns')
elif isinstance(result, timedelta):
result = np.timedelta64(getattr(result, 'value', result), 'ns')
elif self.dtype != object:
result = np.asarray(result, dtype=self.dtype)
else:
# result is a scalar
if result is pd.NaT:
# work around the impossibility of casting NaT with asarray
# note: it probably would be better in general to return
# pd.Timestamp rather than np.datetime64 but this is easier
# (for now)
result = np.datetime64('NaT', 'ns')
elif isinstance(result, timedelta):
result = np.timedelta64(getattr(result, 'value', result), 'ns')
elif self.dtype != object:
result = np.asarray(result, dtype=self.dtype)

# as for numpy.ndarray indexing, we always want the result to be
# a NumPy array.
result = utils.to_0d_array(result)

return result

Expand Down
12 changes: 11 additions & 1 deletion xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,23 @@ def is_valid_numpy_dtype(dtype):
return True


def tuple_to_0darray(value):
def to_0d_object_array(value):
    """Wrap *value* in a 0-dimensional numpy.ndarray with ``dtype=object``.

    Assigning through a length-1 list makes numpy treat *value* as a
    single opaque object instead of iterating over it, so tuples and
    lists survive intact rather than being unpacked into a new axis.
    """
    wrapped = np.empty((1,), dtype=object)
    wrapped[:] = [value]
    # reshape to a scalar (0d) array; this is a view on the same data
    return wrapped.reshape(())


def to_0d_array(value):
    """Wrap *value* in a 0-dimensional numpy.ndarray.

    Scalars and already-0d arrays go through ``np.array`` directly;
    anything else (tuples, lists, arbitrary objects) is stored as a
    single element of an object-dtype array.
    """
    is_0d_ndarray = isinstance(value, np.ndarray) and value.ndim == 0
    if np.isscalar(value) or is_0d_ndarray:
        return np.array(value)
    return to_0d_object_array(value)


def dict_equiv(first, second, compat=equivalent):
"""Test equivalence of two dict-like objects. If any of the values are
numpy arrays, compare them correctly.
Expand Down
18 changes: 10 additions & 8 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def as_compatible_data(data, fastpath=False):
return _maybe_wrap_data(data)

if isinstance(data, tuple):
data = utils.tuple_to_0darray(data)
data = utils.to_0d_object_array(data)

if isinstance(data, pd.Timestamp):
# TODO: convert, handle datetime objects, too
Expand Down Expand Up @@ -159,19 +159,21 @@ def as_compatible_data(data, fastpath=False):

def _as_array_or_item(data):
"""Return the given values as a numpy array, or as an individual item if
it's a 0-dimensional object array or datetime64.
it's a 0d datetime64 or timedelta64 array.
Importantly, this function does not copy data if it is already an ndarray -
otherwise, it will not be possible to update Variable values in place.
This function mostly exists because 0-dimensional ndarrays with
dtype=datetime64 are broken :(
https://github.com/numpy/numpy/issues/4337
https://github.com/numpy/numpy/issues/7619
TODO: remove this (replace with np.asarray) once these issues are fixed
"""
data = np.asarray(data)
if data.ndim == 0:
if data.dtype.kind == 'O':
# unpack 0d object arrays to be consistent with numpy
data = data.item()
elif data.dtype.kind == 'M':
# convert to a np.datetime64 object, because 0-dimensional ndarrays
# with dtype=datetime64 are broken :(
if data.dtype.kind == 'M':
data = np.datetime64(data, 'ns')
elif data.dtype.kind == 'm':
data = np.timedelta64(data, 'ns')
Expand Down
42 changes: 34 additions & 8 deletions xarray/test/test_groupby.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,46 @@
import numpy as np
import xarray as xr
from xarray.core.groupby import _consolidate_slices

import pytest


def test_consolidate_slices():

assert _consolidate_slices([slice(3), slice(3, 5)]) == [slice(5)]
assert _consolidate_slices([slice(2, 3), slice(3, 6)]) == [slice(2, 6)]
assert (_consolidate_slices([slice(2, 3, 1), slice(3, 6, 1)])
assert _consolidate_slices([slice(3), slice(3, 5)]) == [slice(5)]
assert _consolidate_slices([slice(2, 3), slice(3, 6)]) == [slice(2, 6)]
assert (_consolidate_slices([slice(2, 3, 1), slice(3, 6, 1)])
== [slice(2, 6, 1)])

slices = [slice(2, 3), slice(5, 6)]
assert _consolidate_slices(slices) == slices

with pytest.raises(ValueError):
_consolidate_slices([slice(3), 4])
slices = [slice(2, 3), slice(5, 6)]
assert _consolidate_slices(slices) == slices

with pytest.raises(ValueError):
_consolidate_slices([slice(3), 4])


def test_multi_index_groupby_apply():
    # regression test for GH873: groupby over a stacked (multi-index)
    # dimension should round-trip through apply/unstack unchanged.
    ds = xr.Dataset({'foo': (('x', 'y'), np.random.randn(3, 4))},
                    {'x': ['a', 'b', 'c'], 'y': [1, 2, 3, 4]})
    expected = 2 * ds
    actual = (ds.stack(space=['x', 'y'])
              .groupby('space')
              .apply(lambda x: 2 * x)
              .unstack('space'))
    assert expected.equals(actual)


def test_multi_index_groupby_sum():
    # regression test for GH873: reducing within groups on a stacked
    # (multi-index) dimension should match the plain reduction.
    ds = xr.Dataset({'foo': (('x', 'y', 'z'), np.ones((3, 4, 2)))},
                    {'x': ['a', 'b', 'c'], 'y': [1, 2, 3, 4]})
    expected = ds.sum('z')
    stacked = ds.stack(space=['x', 'y'])
    actual = stacked.groupby('space').sum('z').unstack('space')
    assert expected.equals(actual)


# TODO: move other groupby tests from test_dataset and test_dataarray over here
56 changes: 43 additions & 13 deletions xarray/test/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def test_getitem_dict(self):
expected = v[0]
self.assertVariableIdentical(expected, actual)

def assertIndexedLikeNDArray(self, variable, expected_value0,
expected_dtype=None):
def _assertIndexedLikeNDArray(self, variable, expected_value0,
expected_dtype=None):
"""Given a 1-dimensional variable, verify that the variable is indexed
like a numpy.ndarray.
"""
Expand All @@ -66,52 +66,52 @@ def assertIndexedLikeNDArray(self, variable, expected_value0,
# check output type instead of array dtype
self.assertEqual(type(variable.values[0]), type(expected_value0))
self.assertEqual(type(variable[0].values), type(expected_value0))
else:
elif expected_dtype is not False:
self.assertEqual(variable.values[0].dtype, expected_dtype)
self.assertEqual(variable[0].values.dtype, expected_dtype)

def test_index_0d_int(self):
for value, dtype in [(0, np.int_),
(np.int32(0), np.int32)]:
x = self.cls(['x'], [value])
self.assertIndexedLikeNDArray(x, value, dtype)
self._assertIndexedLikeNDArray(x, value, dtype)

def test_index_0d_float(self):
for value, dtype in [(0.5, np.float_),
(np.float32(0.5), np.float32)]:
x = self.cls(['x'], [value])
self.assertIndexedLikeNDArray(x, value, dtype)
self._assertIndexedLikeNDArray(x, value, dtype)

def test_index_0d_string(self):
for value, dtype in [('foo', np.dtype('U3' if PY3 else 'S3')),
(u'foo', np.dtype('U3'))]:
x = self.cls(['x'], [value])
self.assertIndexedLikeNDArray(x, value, dtype)
self._assertIndexedLikeNDArray(x, value, dtype)

def test_index_0d_datetime(self):
d = datetime(2000, 1, 1)
x = self.cls(['x'], [d])
self.assertIndexedLikeNDArray(x, np.datetime64(d))
self._assertIndexedLikeNDArray(x, np.datetime64(d))

x = self.cls(['x'], [np.datetime64(d)])
self.assertIndexedLikeNDArray(x, np.datetime64(d), 'datetime64[ns]')
self._assertIndexedLikeNDArray(x, np.datetime64(d), 'datetime64[ns]')

x = self.cls(['x'], pd.DatetimeIndex([d]))
self.assertIndexedLikeNDArray(x, np.datetime64(d), 'datetime64[ns]')
self._assertIndexedLikeNDArray(x, np.datetime64(d), 'datetime64[ns]')

def test_index_0d_timedelta64(self):
td = timedelta(hours=1)

x = self.cls(['x'], [np.timedelta64(td)])
self.assertIndexedLikeNDArray(x, np.timedelta64(td), 'timedelta64[ns]')
self._assertIndexedLikeNDArray(x, np.timedelta64(td), 'timedelta64[ns]')

x = self.cls(['x'], pd.to_timedelta([td]))
self.assertIndexedLikeNDArray(x, np.timedelta64(td), 'timedelta64[ns]')
self._assertIndexedLikeNDArray(x, np.timedelta64(td), 'timedelta64[ns]')

def test_index_0d_not_a_time(self):
d = np.datetime64('NaT', 'ns')
x = self.cls(['x'], [d])
self.assertIndexedLikeNDArray(x, d, None)
self._assertIndexedLikeNDArray(x, d)

def test_index_0d_object(self):

Expand All @@ -130,7 +130,15 @@ def __repr__(self):

item = HashableItemWrapper((1, 2, 3))
x = self.cls('x', [item])
self.assertIndexedLikeNDArray(x, item)
self._assertIndexedLikeNDArray(x, item, expected_dtype=False)

def test_0d_object_array_with_list(self):
    # A list stored as the single element of an object-dtype array must
    # stay wrapped (not be unpacked) through indexing and squeeze().
    listarray = np.empty((1,), dtype=object)
    listarray[0] = [1, 2, 3]
    x = self.cls('x', listarray)
    assert x.data == listarray
    assert x[0].data == listarray.squeeze()
    assert x.squeeze().data == listarray.squeeze()

def test_index_and_concat_datetime(self):
# regression test for #125
Expand Down Expand Up @@ -729,6 +737,19 @@ def test_transpose(self):
w3 = Variable(['b', 'c', 'd', 'a'], np.einsum('abcd->bcda', x))
self.assertVariableIdentical(w, w3.transpose('a', 'b', 'c', 'd'))

def test_transpose_0d(self):
    # transpose() on a 0d variable must be a no-op for every kind of
    # scalar, including objects that only fit in a 0d object array.
    scalar_values = [
        3.5,
        ('a', 1),
        np.datetime64('2000-01-01'),
        np.timedelta64(1, 'h'),
        None,
        object(),
    ]
    for value in scalar_values:
        variable = Variable([], value)
        assert variable.transpose().identical(variable)

def test_squeeze(self):
v = Variable(['x', 'y'], [[1]])
self.assertVariableIdentical(Variable([], 1), v.squeeze())
Expand Down Expand Up @@ -773,6 +794,15 @@ def test_expand_dims(self):
with self.assertRaisesRegexp(ValueError, 'must be a superset'):
v.expand_dims(['z'])

def test_expand_dims_object_dtype(self):
    # expand_dims on a 0d object variable holding a tuple must replicate
    # the tuple along the new axis, not unpack it (GH867).
    v = Variable([], ('a', 1))
    actual = v.expand_dims(('x',), (3,))
    expected_data = np.empty((3,), dtype=object)
    for idx in range(expected_data.size):
        expected_data[idx] = ('a', 1)
    expected = Variable(['x'], expected_data)
    assert actual.identical(expected)

def test_stack(self):
v = Variable(['x', 'y'], [[0, 1], [2, 3]], {'foo': 'bar'})
actual = v.stack(z=('x', 'y'))
Expand Down

0 comments on commit 5433d9d

Please sign in to comment.