Error unstacking array API compliant class #8666

Closed
5 tasks done
TomNicholas opened this issue Jan 25, 2024 · 0 comments · Fixed by #8668
Labels
array API standard (Support for the Python array API standard), bug


@TomNicholas
Contributor

What happened?

Unstacking fails for array types that strictly follow the array API standard.

What did you expect to happen?

Unstacking should succeed, as it does when the same data is held in a regular NumPy array.

Minimal Complete Verifiable Example

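import numpy.array_api as nxp  # experimental module, available in NumPy < 2.0 only
import xarray as xr

arr = nxp.asarray([[1, 2, 3], [4, 5, 6]], dtype=nxp.float32)

da = xr.DataArray(
    arr,
    coords=[("x", ["a", "b"]), ("y", [0, 1, 2])],
)
da
stacked = da.stack(z=("x", "y"))
stacked.indexes["z"]
stacked.unstack()  # raises AttributeError: 'Array' object has no attribute 'reshape'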

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[65], line 8
      6 stacked = da.stack(z=("x", "y"))
      7 stacked.indexes["z"]
----> 8 roundtripped = stacked.unstack()
      9 arr.identical(roundtripped)

File ~/Documents/Work/Code/xarray/xarray/util/deprecation_helpers.py:115, in _deprecate_positional_args.<locals>._decorator.<locals>.inner(*args, **kwargs)
    111     kwargs.update({name: arg for name, arg in zip_args})
    113     return func(*args[:-n_extra_args], **kwargs)
--> 115 return func(*args, **kwargs)

File ~/Documents/Work/Code/xarray/xarray/core/dataarray.py:2913, in DataArray.unstack(self, dim, fill_value, sparse)
   2851 @_deprecate_positional_args("v2023.10.0")
   2852 def unstack(
   2853     self,
   (...)
   2857     sparse: bool = False,
   2858 ) -> Self:
   2859     """
   2860     Unstack existing dimensions corresponding to MultiIndexes into
   2861     multiple new dimensions.
   (...)
   2911     DataArray.stack
   2912     """
-> 2913     ds = self._to_temp_dataset().unstack(dim, fill_value=fill_value, sparse=sparse)
   2914     return self._from_temp_dataset(ds)

File ~/Documents/Work/Code/xarray/xarray/util/deprecation_helpers.py:115, in _deprecate_positional_args.<locals>._decorator.<locals>.inner(*args, **kwargs)
    111     kwargs.update({name: arg for name, arg in zip_args})
    113     return func(*args[:-n_extra_args], **kwargs)
--> 115 return func(*args, **kwargs)

File ~/Documents/Work/Code/xarray/xarray/core/dataset.py:5581, in Dataset.unstack(self, dim, fill_value, sparse)
   5579 for d in dims:
   5580     if needs_full_reindex:
-> 5581         result = result._unstack_full_reindex(
   5582             d, stacked_indexes[d], fill_value, sparse
   5583         )
   5584     else:
   5585         result = result._unstack_once(d, stacked_indexes[d], fill_value, sparse)

File ~/Documents/Work/Code/xarray/xarray/core/dataset.py:5474, in Dataset._unstack_full_reindex(self, dim, index_and_vars, fill_value, sparse)
   5472 if name not in index_vars:
   5473     if dim in var.dims:
-> 5474         variables[name] = var.unstack({dim: new_dim_sizes})
   5475     else:
   5476         variables[name] = var

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:1684, in Variable.unstack(self, dimensions, **dimensions_kwargs)
   1682 result = self
   1683 for old_dim, dims in dimensions.items():
-> 1684     result = result._unstack_once_full(dims, old_dim)
   1685 return result

File ~/Documents/Work/Code/xarray/xarray/core/variable.py:1574, in Variable._unstack_once_full(self, dim, old_dim)
   1571 reordered = self.transpose(*dim_order)
   1573 new_shape = reordered.shape[: len(other_dims)] + new_dim_sizes
-> 1574 new_data = reordered.data.reshape(new_shape)
   1575 new_dims = reordered.dims[: len(other_dims)] + new_dim_names
   1577 return type(self)(
   1578     new_dims, new_data, self._attrs, self._encoding, fastpath=True
   1579 )

AttributeError: 'Array' object has no attribute 'reshape'

MVCE confirmation

  • Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • Complete example — the example is self-contained, including all data and the text of any traceback.
  • Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • New issue — a search of GitHub Issues suggests this is not a duplicate.
  • Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

No response

Anything else we need to know?

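It fails on the `reordered.data.reshape(new_shape)` call in `Variable._unstack_once_full`, because the array API standard defines `reshape` as a function in the array's namespace (`xp.reshape(x, shape)`), not as a method on the array.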
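To illustrate the distinction, here is a minimal pure-Python sketch. `StrictArray`, `_reshape` and `strict_namespace` are hypothetical stand-ins for a strict array API class such as `numpy.array_api.Array` and its module; like that class, the toy exposes `reshape` only through its namespace.

```python
import math
from types import SimpleNamespace


class StrictArray:
    """Hypothetical array that follows the array API convention:
    reshape is a namespace function, not a method."""

    def __init__(self, values, shape):
        if math.prod(shape) != len(values):
            raise ValueError("shape does not match number of elements")
        self._values = list(values)  # flat, row-major storage
        self.shape = tuple(shape)

    def __array_namespace__(self, api_version=None):
        # Array API entry point: return the module-like namespace.
        return strict_namespace


def _reshape(x, shape):
    # Namespace-level reshape: same flat data, new shape.
    return StrictArray(x._values, shape)


# Hypothetical namespace standing in for e.g. numpy.array_api.
strict_namespace = SimpleNamespace(reshape=_reshape)

arr = StrictArray([1, 2, 3, 4, 5, 6], (2, 3))

try:
    arr.reshape((6,))  # method call, as in Variable._unstack_once_full
except AttributeError as err:
    print("method call fails:", err)

xp = arr.__array_namespace__()
flat = xp.reshape(arr, (6,))  # array API style: function in the namespace
print(flat.shape)  # (6,)
```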

We do in fact have an array API-compatible version of `reshape` defined in `duck_array_ops.py`; it just apparently isn't used yet everywhere we call `reshape`.

def reshape(array, shape):
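One plausible way such a helper can work is sketched below. This is not the actual code in `duck_array_ops.py`: the name `reshape_any` and the two toy classes are hypothetical, and the sketch assumes only that an array either provides `__array_namespace__` (array API style) or has a `.reshape` method (NumPy style). The toys track shape only, to keep the focus on the dispatch.

```python
from types import SimpleNamespace


def reshape_any(array, shape):
    """Hypothetical helper: reshape via the array API namespace when the
    array provides one, otherwise fall back to a ``.reshape`` method."""
    if hasattr(array, "__array_namespace__"):
        xp = array.__array_namespace__()
        return xp.reshape(array, shape)
    return array.reshape(shape)


class LegacyArray:
    """Toy NumPy-like array that only offers a reshape *method*."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    def reshape(self, shape):
        return LegacyArray(shape)


class StrictArray:
    """Toy array API-style array: no reshape method, only a namespace."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    def __array_namespace__(self, api_version=None):
        return SimpleNamespace(reshape=lambda x, shape: StrictArray(shape))


print(reshape_any(LegacyArray((2, 3)), (6,)).shape)  # (6,)
print(reshape_any(StrictArray((2, 3)), (3, 2)).shape)  # (3, 2)
```

Routing every internal `reshape` call through one helper like this keeps call sites such as `_unstack_once_full` agnostic to whether the wrapped array uses the method or the function convention.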

Environment

main branch of xarray, numpy 1.26.0

@TomNicholas added the bug, needs triage and array API standard labels and removed the needs triage label on Jan 25, 2024.
@TomNicholas added this to To do in Duck Array Wrapping via automation on Jan 25, 2024.
The Duck Array Wrapping automation moved this from To do to Done on Jan 26, 2024.