From c6f4e3afb362b18e79004b58272a7c4028ce68e5 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Tue, 6 Feb 2024 04:33:29 -0500 Subject: [PATCH 1/9] extract dtypes from underlying duck arrays without coercing to numpy --- xarray/core/duck_array_ops.py | 44 ++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index b30ba4c3a78..6a6c39f6bcc 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -214,26 +214,32 @@ def astype(data, dtype, **kwargs): def asarray(data, xp=np): + print(data) + print(type(data)) return data if is_duck_array(data) else xp.asarray(data) -def as_shared_dtype(scalars_or_arrays, xp=np): - """Cast a arrays to a shared dtype using xarray's type promotion rules.""" - array_type_cupy = array_type("cupy") - if array_type_cupy and any( - isinstance(x, array_type_cupy) for x in scalars_or_arrays - ): - import cupy as cp - - arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] +def as_duck_array(data, xp=np): + if is_duck_array(data): + return data + elif hasattr(data, "get_duck_array"): + # must be a lazy indexing class wrapping a duck array + return data.get_duck_array() else: - arrays = [asarray(x, xp=xp) for x in scalars_or_arrays] - # Pass arrays directly instead of dtypes to result_type so scalars - # get handled properly. - # Note that result_type() safely gets the dtype from dask arrays without - # evaluating them. - out_type = dtypes.result_type(*arrays) - return [astype(x, out_type, copy=False) for x in arrays] + array_type_cupy = array_type("cupy") + if array_type_cupy and any(isinstance(data, array_type_cupy)): + import cupy as cp + + return asarray(data, xp=cp) + else: + return asarray(data, xp=xp) + + +def as_shared_dtype(scalars_or_arrays, xp=np): + """Cast arrays to a shared dtype using xarray's type promotion rules.""" + duckarrays = [as_duck_array(obj, xp=xp) for obj in scalars_or_arrays] + out_type = dtypes.result_type(*duckarrays) + return [astype(x, out_type, copy=False) for x in duckarrays] def broadcast_to(array, shape): @@ -327,7 +333,7 @@ def sum_where(data, axis=None, dtype=None, where=None): def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" xp = get_array_namespace(condition) - return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) + return xp.where(condition, *as_shared_dtype([x, y])) def where_method(data, cond, other=dtypes.NA): @@ -350,14 +356,14 @@ def concatenate(arrays, axis=0): arrays[0], np.ndarray ): xp = get_array_namespace(arrays[0]) - return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis) + return xp.concat(as_shared_dtype(arrays), axis=axis) return _concatenate(as_shared_dtype(arrays), axis=axis) def stack(arrays, axis=0): """stack() with better dtype promotion rules.""" xp = get_array_namespace(arrays[0]) - return xp.stack(as_shared_dtype(arrays, xp=xp), axis=axis) + return xp.stack(as_shared_dtype(arrays), axis=axis) def reshape(array, shape): From 1467c4c38ccf0a7ea7e82261b9f83ae936a0c117 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Tue, 6 Feb 2024 04:36:06 -0500 Subject: [PATCH 2/9] remove print statements --- xarray/core/duck_array_ops.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 6a6c39f6bcc..91161bd3510 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -214,8 +214,6 @@ def astype(data, dtype, **kwargs): def asarray(data, xp=np): - print(data) - print(type(data)) return data if is_duck_array(data) else xp.asarray(data) From d9931ef78879352e35b8d8c7afa716678f1adae4 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Tue, 6 Feb 2024 04:39:21 -0500 Subject: [PATCH 3/9] use array namespace again --- xarray/core/duck_array_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 91161bd3510..035255aa619 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -331,7 +331,7 @@ def sum_where(data, axis=None, dtype=None, where=None): def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" xp = get_array_namespace(condition) - return xp.where(condition, *as_shared_dtype([x, y])) + return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) def where_method(data, cond, other=dtypes.NA): @@ -354,14 +354,14 @@ def concatenate(arrays, axis=0): arrays[0], np.ndarray ): xp = get_array_namespace(arrays[0]) - return xp.concat(as_shared_dtype(arrays), axis=axis) + return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis) return _concatenate(as_shared_dtype(arrays), axis=axis) def stack(arrays, axis=0): """stack() with better dtype promotion rules.""" xp = get_array_namespace(arrays[0]) - return xp.stack(as_shared_dtype(arrays), axis=axis) + return xp.stack(as_shared_dtype(arrays, xp=xp), axis=axis) def reshape(array, shape): From c067f7d69c5316657146d097bf0696aa6a6aea77 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Tue, 6 Feb 2024 11:49:18 -0500 Subject: [PATCH 4/9] use pycompat.to_duck_array instead --- xarray/core/duck_array_ops.py | 20 ++------------------ xarray/core/pycompat.py | 13 +++++++++++-- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 035255aa619..c7899105507 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -35,7 +35,7 @@ from xarray.core import dask_array_ops, dtypes, nputils, pycompat from xarray.core.options import OPTIONS from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array -from xarray.core.pycompat import array_type, is_duck_dask_array +from xarray.core.pycompat import is_duck_dask_array, to_duck_array from xarray.core.utils import is_duck_array, module_available # remove once numpy 2.0 is the oldest supported version @@ -217,25 +217,9 @@ def asarray(data, xp=np): return data if is_duck_array(data) else xp.asarray(data) -def as_duck_array(data, xp=np): - if is_duck_array(data): - return data - elif hasattr(data, "get_duck_array"): - # must be a lazy indexing class wrapping a duck array - return data.get_duck_array() - else: - array_type_cupy = array_type("cupy") - if array_type_cupy and any(isinstance(data, array_type_cupy)): - import cupy as cp - - return asarray(data, xp=cp) - else: - return asarray(data, xp=xp) - - def as_shared_dtype(scalars_or_arrays, xp=np): """Cast arrays to a shared dtype using xarray's type promotion rules.""" - duckarrays = [as_duck_array(obj, xp=xp) for obj in scalars_or_arrays] + duckarrays = [to_duck_array(obj, xp=xp) for obj in scalars_or_arrays] out_type = dtypes.result_type(*duckarrays) return [astype(x, out_type, copy=False) for x in duckarrays] diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 32ef408f7cc..cde565837b0 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -7,6 +7,7 @@ import numpy as np from packaging.version import Version +from xarray.core.types import T_DuckArray from xarray.core.utils import is_duck_array, is_scalar, module_available integer_types = (int, np.integer) @@ -129,7 +130,7 @@ def to_numpy(data) -> np.ndarray: return data -def to_duck_array(data): +def to_duck_array(data, xp=np) -> T_DuckArray: from xarray.core.indexing import ExplicitlyIndexed if isinstance(data, ExplicitlyIndexed): @@ -137,4 +138,12 @@ def to_duck_array(data): elif is_duck_array(data): return data else: - return np.asarray(data) + from xarray.core.duck_array_ops import asarray + + array_type_cupy = array_type("cupy") + if array_type_cupy and any(isinstance(data, array_type_cupy)): + import cupy as cp + + return asarray(data, xp=cp) + else: + return asarray(data, xp=xp) From 5092aaaf874d2fdaa4447de5b6d59943f7b9cc17 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Tue, 6 Feb 2024 11:50:05 -0500 Subject: [PATCH 5/9] a sprinkle of type hints --- xarray/core/pycompat.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index cde565837b0..d188fac6b42 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -87,7 +87,7 @@ def mod_version(mod: ModType) -> Version: return _get_cached_duck_array_module(mod).version -def is_dask_collection(x): +def is_dask_collection(x) -> bool: if module_available("dask"): from dask.base import is_dask_collection @@ -95,7 +95,7 @@ def is_dask_collection(x): return False -def is_duck_dask_array(x): +def is_duck_dask_array(x) -> bool: return is_duck_array(x) and is_dask_collection(x) @@ -103,7 +103,7 @@ def is_chunked_array(x) -> bool: return is_duck_dask_array(x) or (is_duck_array(x) and hasattr(x, "chunks")) -def is_0d_dask_array(x): +def is_0d_dask_array(x) -> bool: return is_duck_dask_array(x) and is_scalar(x) From a884ba8a77b01cb816608960a4ed3a0c1269b80d Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Tue, 6 Feb 2024 11:54:22 -0500 Subject: [PATCH 6/9] whatsnew --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1b0f2f18efb..c9044911092 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -53,6 +53,8 @@ Bug fixes By `Tom Nicholas `_. - Ensure :py:meth:`DataArray.unstack` works when wrapping array API-compliant classes. (:issue:`8666`, :pull:`8668`) By `Tom Nicholas `_. +- Avoid coercing to numpy arrays inside :py:func:`~xarray.core.duck_array_ops.as_shared_dtype`. (:pull:`8714`) + By `Tom Nicholas `_. - Preserve chunks when writing time-like variables to zarr by enabling lazy CF encoding of time-like variables (:issue:`7132`, :issue:`8230`, :issue:`8432`, :pull:`8575`). By `Spencer Clark `_ and From 45808d8afef8527b8e172ba7422bee47248d62c8 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Sun, 24 Mar 2024 20:50:44 -0400 Subject: [PATCH 7/9] don't compute in as_shared_dtype --- xarray/core/duck_array_ops.py | 4 ++-- xarray/namedarray/pycompat.py | 10 +++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index a5a3b70ee57..692e74efa23 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -38,7 +38,7 @@ from xarray.core.utils import is_duck_array, module_available from xarray.namedarray import pycompat from xarray.namedarray.parallelcompat import get_chunked_array_type, is_chunked_array -from xarray.namedarray.pycompat import is_duck_dask_array, to_duck_array +from xarray.namedarray.pycompat import is_duck_dask_array, to_lazy_duck_array # remove once numpy 2.0 is the oldest supported version if module_available("numpy", minversion="2.0.0.dev0"): @@ -221,7 +221,7 @@ def asarray(data, xp=np): def as_shared_dtype(scalars_or_arrays, xp=np): """Cast arrays to a shared dtype using xarray's type promotion rules.""" - duckarrays = [to_duck_array(obj, xp=xp) for obj in scalars_or_arrays] + duckarrays = [to_lazy_duck_array(obj, xp=xp) for obj in scalars_or_arrays] out_type = dtypes.result_type(*duckarrays) return [astype(x, out_type, copy=False) for x in duckarrays] diff --git a/xarray/namedarray/pycompat.py b/xarray/namedarray/pycompat.py index 75c34bf0573..418e80bd180 100644 --- a/xarray/namedarray/pycompat.py +++ b/xarray/namedarray/pycompat.py @@ -132,7 +132,6 @@ def to_numpy( def to_duck_array( data: Any, xp=np, **kwargs: dict[str, Any] ) -> duckarray[_ShapeType, _DType]: - from xarray.core.indexing import ExplicitlyIndexed from xarray.namedarray.parallelcompat import get_chunked_array_type if is_chunked_array(data): @@ -140,6 +139,15 @@ def to_duck_array( loaded_data, *_ = chunkmanager.compute(data, **kwargs) # type: ignore[var-annotated] return loaded_data + return to_lazy_duck_array(data) + + +def to_lazy_duck_array( + data: Any, xp=np, **kwargs: dict[str, Any] +) -> duckarray[_ShapeType, _DType]: + """Doesn't compute chunked data.""" + from xarray.core.indexing import ExplicitlyIndexed + if isinstance(data, ExplicitlyIndexed): return data.get_duck_array() # type: ignore[no-untyped-call, no-any-return] elif is_duck_array(data): From 6833d66080ebe2fa559d10bc8d7762119ee6d015 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Sun, 24 Mar 2024 21:24:10 -0400 Subject: [PATCH 8/9] ConcatenatableArray class --- xarray/tests/__init__.py | 101 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 5007db9eeb2..81beba80cf1 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -4,7 +4,9 @@ import platform import string import warnings +from collections.abc import Iterable from contextlib import contextmanager, nullcontext +from typing import Any, Callable from unittest import mock # noqa: F401 import numpy as np @@ -208,6 +210,105 @@ def __getitem__(self, key): raise UnexpectedDataAccess("Tried accessing data.") +HANDLED_ARRAY_FUNCTIONS: dict[str, Callable] = {} + + +def implements(numpy_function): + """Register an __array_function__ implementation for ManifestArray objects.""" + + def decorator(func): + HANDLED_ARRAY_FUNCTIONS[numpy_function] = func + return func + + return decorator + + +@implements(np.concatenate) +def concatenate( + arrays: Iterable[ConcatenatableArray], /, *, axis=0 +) -> ConcatenatableArray: + if any(not isinstance(arr, ConcatenatableArray) for arr in arrays): + raise TypeError + + result = np.concatenate([arr.array for arr in arrays], axis=axis) + return ConcatenatableArray(result) + + +@implements(np.stack) +def stack(arrays: Iterable[ConcatenatableArray], /, *, axis=0) -> ConcatenatableArray: + if any(not isinstance(arr, ConcatenatableArray) for arr in arrays): + raise TypeError + + result = np.stack([arr.array for arr in arrays], axis=axis) + return ConcatenatableArray(result) + + +@implements(np.result_type) +def result_type(*arrays_and_dtypes) -> np.dtype: + """Called by xarray to ensure all arguments to concat have the same dtype.""" + first_dtype, *other_dtypes = (np.dtype(obj) for obj in arrays_and_dtypes) + for other_dtype in other_dtypes: + if other_dtype != first_dtype: + raise ValueError("dtypes not all consistent") + return first_dtype + + +@implements(np.broadcast_to) +def broadcast_to( + x: ConcatenatableArray, /, shape: tuple[int, ...] +) -> ConcatenatableArray: + """ + Broadcasts an array to a specified shape, by either manipulating chunk keys or copying chunk manifest entries. + """ + if not isinstance(x, ConcatenatableArray): + raise TypeError + + result = np.broadcast_to(x.array, shape=shape) + return ConcatenatableArray(result) + + +class ConcatenatableArray(utils.NDArrayMixin): + """Disallows loading or coercing to an index but does support concatenation / stacking.""" + + # TODO only reason this is different from InaccessibleArray is to avoid it being a subclass of ExplicitlyIndexed + + HANDLED_ARRAY_FUNCTIONS = [concatenate, stack, result_type] + + def __init__(self, array): + self.array = array + + def get_duck_array(self): + raise UnexpectedDataAccess("Tried accessing data") + + def __array__(self, dtype: np.typing.DTypeLike = None): + raise UnexpectedDataAccess("Tried accessing data") + + def __getitem__(self, key): + raise UnexpectedDataAccess("Tried accessing data.") + + def __array_function__(self, func, types, args, kwargs) -> Any: + if func not in HANDLED_ARRAY_FUNCTIONS: + return NotImplemented + + # Note: this allows subclasses that don't override + # __array_function__ to handle ManifestArray objects + if not all(issubclass(t, ConcatenatableArray) for t in types): + return NotImplemented + + return HANDLED_ARRAY_FUNCTIONS[func](*args, **kwargs) + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> Any: + """We have to define this in order to convince xarray that this class is a duckarray, even though we will never support ufuncs.""" + return NotImplemented + + def astype(self, dtype: np.dtype, /, *, copy: bool = True) -> ConcatenatableArray: + """Needed because xarray will call this even when it's a no-op""" + if dtype != self.dtype: + raise NotImplementedError() + else: + return self + + class FirstElementAccessibleArray(InaccessibleArray): def __getitem__(self, key): tuple_idxr = key.tuple From bcd02bf5d1822c3ac032edba790896449fa75568 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Thu, 28 Mar 2024 14:17:26 -0400 Subject: [PATCH 9/9] test that no coercion occurs using ConcatenatableArray --- xarray/tests/test_variable.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index d9289aa6674..42c014364dd 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -8,6 +8,7 @@ from typing import Generic import numpy as np +import numpy.testing as npt import pandas as pd import pytest import pytz @@ -31,6 +32,7 @@ from xarray.core.variable import as_compatible_data, as_variable from xarray.namedarray.pycompat import array_type from xarray.tests import ( + ConcatenatableArray, assert_allclose, assert_array_equal, assert_equal, @@ -551,6 +553,16 @@ def test_concat_mixed_dtypes(self): assert_identical(expected, actual) assert actual.dtype == object + def test_concat_without_access(self): + a = self.cls("x", ConcatenatableArray(np.array([0, 1]))) + b = self.cls("x", ConcatenatableArray(np.array([2, 3]))) + actual = Variable.concat([a, b], dim="x") + expected_arr = np.array([0, 1, 2, 3]) + expected = Variable("x", ConcatenatableArray(expected_arr)) + assert isinstance(actual.data, ConcatenatableArray) + assert expected.dims == ("x",) + npt.assert_equal(actual.data.array, expected_arr) + @pytest.mark.parametrize("deep", [True, False]) @pytest.mark.parametrize("astype", [float, int, str]) def test_copy(self, deep: bool, astype: type[object]) -> None: