diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 6bf495713fe..4d07bffedd5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,10 @@ Breaking changes New Features ~~~~~~~~~~~~ +- Added the ``sparse`` option to :py:meth:`~xarray.DataArray.unstack`, + :py:meth:`~xarray.Dataset.unstack`, :py:meth:`~xarray.DataArray.reindex`, + :py:meth:`~xarray.Dataset.reindex` (:issue:`3518`). + By `Keisuke Fujii `_. - Added the ``fill_value`` option to :py:meth:`~xarray.DataArray.unstack` and :py:meth:`~xarray.Dataset.unstack` (:issue:`3518`). diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 41ff5a3b32d..749de6c13e2 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -466,6 +466,7 @@ def reindex_variables( tolerance: Any = None, copy: bool = True, fill_value: Optional[Any] = dtypes.NA, + sparse: bool = False, ) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, pd.Index]]: """Conform a dictionary of aligned variables onto a new set of variables, filling in missing values with NaN. @@ -503,6 +504,8 @@ def reindex_variables( the input. In either case, new xarray objects are always returned. fill_value : scalar, optional Value to use for newly missing values + sparse: bool, optional + Use an sparse-array Returns ------- @@ -571,6 +574,8 @@ def reindex_variables( for name, var in variables.items(): if name not in indexers: + if sparse: + var = var._as_sparse(fill_value=fill_value) key = tuple( slice(None) if d in unchanged_dims else int_indexers.get(d, slice(None)) for d in var.dims diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 23342fc5e0d..1ed4b5566d7 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1729,6 +1729,7 @@ def unstack( self, dim: Union[Hashable, Sequence[Hashable], None] = None, fill_value: Any = dtypes.NA, + sparse: bool = False, ) -> "DataArray": """ Unstack existing dimensions corresponding to MultiIndexes into @@ -1742,6 +1743,7 @@ def unstack( Dimension(s) over which to unstack. By default unstacks all MultiIndexes. fill_value: value to be filled. By default, np.nan + sparse: use sparse-array if True Returns ------- @@ -1773,7 +1775,7 @@ def unstack( -------- DataArray.stack """ - ds = self._to_temp_dataset().unstack(dim, fill_value) + ds = self._to_temp_dataset().unstack(dim, fill_value, sparse) return self._from_temp_dataset(ds) def to_unstacked_dataset(self, dim, level=0): diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 371e0d6bf26..71288757cb7 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2286,6 +2286,7 @@ def reindex( the input. In either case, a new xarray object is always returned. fill_value : scalar, optional Value to use for newly missing values + sparse: use sparse-array. By default, False **indexers_kwarg : {dim: indexer, ...}, optional Keyword arguments in the same form as ``indexers``. One of indexers or indexers_kwargs must be provided. @@ -2428,6 +2429,29 @@ def reindex( the original and desired indexes. If you do want to fill in the `NaN` values present in the original dataset, use the :py:meth:`~Dataset.fillna()` method. + """ + return self._reindex( + indexers, + method, + tolerance, + copy, + fill_value, + sparse=False, + **indexers_kwargs, + ) + + def _reindex( + self, + indexers: Mapping[Hashable, Any] = None, + method: str = None, + tolerance: Number = None, + copy: bool = True, + fill_value: Any = dtypes.NA, + sparse: bool = False, + **indexers_kwargs: Any, + ) -> "Dataset": + """ + same to _reindex but support sparse option """ indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex") @@ -2444,6 +2468,7 @@ def reindex( tolerance, copy=copy, fill_value=fill_value, + sparse=sparse, ) coord_names = set(self._coord_names) coord_names.update(indexers) @@ -3333,7 +3358,7 @@ def ensure_stackable(val): return data_array - def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset": + def _unstack_once(self, dim: Hashable, fill_value, sparse) -> "Dataset": index = self.get_index(dim) index = index.remove_unused_levels() full_idx = pd.MultiIndex.from_product(index.levels, names=index.names) @@ -3342,7 +3367,9 @@ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset": if index.equals(full_idx): obj = self else: - obj = self.reindex({dim: full_idx}, copy=False, fill_value=fill_value) + obj = self._reindex( + {dim: full_idx}, copy=False, fill_value=fill_value, sparse=sparse + ) new_dim_names = index.names new_dim_sizes = [lev.size for lev in index.levels] @@ -3372,6 +3399,7 @@ def unstack( self, dim: Union[Hashable, Iterable[Hashable]] = None, fill_value: Any = dtypes.NA, + sparse: bool = False, ) -> "Dataset": """ Unstack existing dimensions corresponding to MultiIndexes into @@ -3385,6 +3413,7 @@ def unstack( Dimension(s) over which to unstack. By default unstacks all MultiIndexes. fill_value: value to be filled. By default, np.nan + sparse: use sparse-array if True Returns ------- @@ -3422,7 +3451,7 @@ def unstack( result = self.copy(deep=False) for dim in dims: - result = result._unstack_once(dim, fill_value) + result = result._unstack_once(dim, fill_value, sparse) return result def update(self, other: "CoercibleMapping", inplace: bool = None) -> "Dataset": diff --git a/xarray/core/variable.py b/xarray/core/variable.py index e630dc4b457..55e8f64d56c 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -993,6 +993,36 @@ def chunk(self, chunks=None, name=None, lock=False): return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True) + def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA): + """ + use sparse-array as backend. + """ + import sparse + + # TODO what to do if dask-backended? + if fill_value is dtypes.NA: + dtype, fill_value = dtypes.maybe_promote(self.dtype) + else: + dtype = dtypes.result_type(self.dtype, fill_value) + + if sparse_format is _default: + sparse_format = "coo" + try: + as_sparse = getattr(sparse, "as_{}".format(sparse_format.lower())) + except AttributeError: + raise ValueError("{} is not a valid sparse format".format(sparse_format)) + + data = as_sparse(self.data.astype(dtype), fill_value=fill_value) + return self._replace(data=data) + + def _to_dense(self): + """ + Change backend from sparse to np.array + """ + if hasattr(self._data, "todense"): + return self._replace(data=self._data.todense()) + return self.copy(deep=False) + def isel( self: VariableType, indexers: Mapping[Hashable, Any] = None, @@ -2021,6 +2051,14 @@ def chunk(self, chunks=None, name=None, lock=False): # Dummy - do not chunk. This method is invoked e.g. by Dataset.chunk() return self.copy(deep=False) + def _as_sparse(self, sparse_format=_default, fill_value=_default): + # Dummy + return self.copy(deep=False) + + def _to_dense(self): + # Dummy + return self.copy(deep=False) + def _finalize_indexing_result(self, dims, data): if getattr(data, "ndim", 0) != 1: # returns Variable rather than IndexVariable if multi-dimensional diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index be40ce7c6e8..b09203f91a2 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2811,6 +2811,25 @@ def test_unstack_fill_value(self): expected = ds["var"].unstack("index").fillna(-1).astype(np.int) assert actual.equals(expected) + @requires_sparse + def test_unstack_sparse(self): + ds = xr.Dataset( + {"var": (("x",), np.arange(6))}, + coords={"x": [0, 1, 2] * 2, "y": (("x",), ["a"] * 3 + ["b"] * 3)}, + ) + # make ds incomplete + ds = ds.isel(x=[0, 2, 3, 4]).set_index(index=["x", "y"]) + # test fill_value + actual = ds.unstack("index", sparse=True) + expected = ds.unstack("index") + assert actual["var"].variable._to_dense().equals(expected["var"].variable) + assert actual["var"].data.density < 1.0 + + actual = ds["var"].unstack("index", sparse=True) + expected = ds["var"].unstack("index") + assert actual.variable._to_dense().equals(expected.variable) + assert actual.data.density < 1.0 + def test_stack_unstack_fast(self): ds = Dataset( { diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index d92a68729b5..ee8d54e567e 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -33,6 +33,7 @@ assert_identical, raises_regex, requires_dask, + requires_sparse, source_ndarray, ) @@ -1862,6 +1863,17 @@ def test_getitem_with_mask_nd_indexer(self): ) +@requires_sparse +class TestVariableWithSparse: + # TODO inherit VariableSubclassobjects to cover more tests + + def test_as_sparse(self): + data = np.arange(12).reshape(3, 4) + var = Variable(("x", "y"), data)._as_sparse(fill_value=-1) + actual = var._to_dense() + assert_identical(var, actual) + + class TestIndexVariable(VariableSubclassobjects): cls = staticmethod(IndexVariable)