Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid coercing to numpy in as_shared_dtype #8714

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
15 changes: 15 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,21 @@ Bug fixes
- Ensure :py:meth:`DataArray.unstack` works when wrapping array API-compliant
classes. (:issue:`8666`, :pull:`8668`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- Avoid coercing to numpy arrays inside :py:func:`~xarray.core.duck_array_ops.as_shared_dtype`. (:pull:`8714`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- Preserve chunks when writing time-like variables to zarr by enabling lazy CF
encoding of time-like variables (:issue:`7132`, :issue:`8230`, :issue:`8432`,
:pull:`8575`). By `Spencer Clark <https://github.com/spencerkclark>`_ and
`Mattia Almansi <https://github.com/malmans2>`_.
- Preserve chunks when writing time-like variables to zarr by enabling their
lazy encoding (:issue:`7132`, :issue:`8230`, :issue:`8432`, :pull:`8253`,
:pull:`8575`; see also discussion in :pull:`8253`). By `Spencer Clark
<https://github.com/spencerkclark>`_ and `Mattia Almansi
<https://github.com/malmans2>`_.
- Raise an informative error if dtype encoding of time-like variables would
lead to integer overflow or unsafe conversion from floating point to integer
values (:issue:`8542`, :pull:`8575`). By `Spencer Clark
<https://github.com/spencerkclark>`_.
- Fix negative slicing of Zarr arrays without dask installed. (:issue:`8252`)
By `Deepak Cherian <https://github.com/dcherian>`_.
- Preserve chunks when writing time-like variables to zarr by enabling lazy CF encoding of time-like
Expand Down
26 changes: 7 additions & 19 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@

from xarray.core import dask_array_ops, dtypes, nputils
from xarray.core.options import OPTIONS
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
from xarray.core.utils import is_duck_array, module_available
from xarray.namedarray import pycompat
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import array_type, is_chunked_array
from xarray.namedarray.parallelcompat import get_chunked_array_type, is_chunked_array
from xarray.namedarray.pycompat import is_duck_dask_array, to_lazy_duck_array

# remove once numpy 2.0 is the oldest supported version
if module_available("numpy", minversion="2.0.0.dev0"):
Expand Down Expand Up @@ -220,22 +220,10 @@ def asarray(data, xp=np):


def as_shared_dtype(scalars_or_arrays, xp=np):
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
array_type_cupy = array_type("cupy")
if array_type_cupy and any(
isinstance(x, array_type_cupy) for x in scalars_or_arrays
):
import cupy as cp

arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
else:
arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously this asarray call would coerce to numpy unnecessarily, when all we really wanted was an array type that we could examine the .dtype attribute of.

# Pass arrays directly instead of dtypes to result_type so scalars
# get handled properly.
# Note that result_type() safely gets the dtype from dask arrays without
# evaluating them.
out_type = dtypes.result_type(*arrays)
return [astype(x, out_type, copy=False) for x in arrays]
"""Cast arrays to a shared dtype using xarray's type promotion rules."""
duckarrays = [to_lazy_duck_array(obj, xp=xp) for obj in scalars_or_arrays]
out_type = dtypes.result_type(*duckarrays)
return [astype(x, out_type, copy=False) for x in duckarrays]


def broadcast_to(array, shape):
Expand Down
34 changes: 30 additions & 4 deletions xarray/namedarray/pycompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from packaging.version import Version

from xarray.core.utils import is_scalar
from xarray.core.utils import is_scalar, module_available
from xarray.namedarray.utils import is_duck_array, is_duck_dask_array

integer_types = (int, np.integer)
Expand Down Expand Up @@ -88,6 +88,14 @@ def mod_version(mod: ModType) -> Version:
return _get_cached_duck_array_module(mod).version


def is_dask_collection(x) -> bool:
    """Report whether ``x`` is a dask collection.

    Returns ``False`` without importing dask when ``module_available("dask")``
    is false; otherwise defers to ``dask.base.is_dask_collection``.
    """
    if not module_available("dask"):
        return False

    # Aliased so the imported helper does not shadow this function's own name.
    from dask.base import is_dask_collection as _dask_is_dask_collection

    return _dask_is_dask_collection(x)


def is_chunked_array(x: duckarray[Any, Any]) -> bool:
    """Report whether ``x`` is a chunked array.

    True when ``x`` is a duck dask array, or when it is a duck array that
    exposes a ``chunks`` attribute.
    """
    dask_backed = is_duck_dask_array(x)
    return dask_backed or (is_duck_array(x) and hasattr(x, "chunks"))

Expand Down Expand Up @@ -121,18 +129,36 @@ def to_numpy(
return data


def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType, _DType]:
from xarray.core.indexing import ExplicitlyIndexed
def to_duck_array(
    data: Any, xp=np, **kwargs: Any
) -> duckarray[_ShapeType, _DType]:
    """Return ``data`` as a duck array, computing it first if it is chunked.

    Parameters
    ----------
    data
        Object to convert.
    xp
        Array namespace handed to ``to_lazy_duck_array`` for input that is
        not chunked. Defaults to numpy.
    **kwargs
        Forwarded to the chunk manager's ``compute`` when ``data`` is chunked.

    Returns
    -------
    duckarray
        The first result of ``chunkmanager.compute`` for chunked input,
        otherwise whatever ``to_lazy_duck_array`` returns.
    """
    from xarray.namedarray.parallelcompat import get_chunked_array_type

    if is_chunked_array(data):
        chunkmanager = get_chunked_array_type(data)
        loaded_data, *_ = chunkmanager.compute(data, **kwargs)  # type: ignore[var-annotated]
        return loaded_data

    # Forward ``xp``: it was previously accepted here but silently dropped,
    # so callers asking for a non-numpy namespace always got the default.
    return to_lazy_duck_array(data, xp=xp)


def to_lazy_duck_array(
    data: Any, xp=np, **kwargs: Any
) -> duckarray[_ShapeType, _DType]:
    """Return ``data`` as a duck array without computing chunked data.

    Parameters
    ----------
    data
        Object to convert.
    xp
        Array namespace passed to ``duck_array_ops.asarray`` when ``data``
        is neither ``ExplicitlyIndexed`` nor already a duck array.
    **kwargs
        Accepted for signature parity with ``to_duck_array``; unused here.

    Returns
    -------
    duckarray
        ``data.get_duck_array()`` for ``ExplicitlyIndexed`` input, ``data``
        itself when it is already a duck array, otherwise the result of
        ``asarray``.
    """
    from xarray.core.indexing import ExplicitlyIndexed

    if isinstance(data, ExplicitlyIndexed):
        return data.get_duck_array()  # type: ignore[no-untyped-call, no-any-return]
    if is_duck_array(data):
        return data

    from xarray.core.duck_array_ops import asarray

    array_type_cupy = array_type("cupy")
    # ``isinstance`` already yields a bool. The previous ``any(isinstance(...))``
    # raised ``TypeError: 'bool' object is not iterable`` whenever this branch
    # was evaluated.
    # NOTE(review): a cupy array presumably already satisfies ``is_duck_array``
    # above, so this branch may be unreachable for real cupy input -- confirm.
    if array_type_cupy and isinstance(data, array_type_cupy):
        import cupy as cp

        return asarray(data, xp=cp)
    return asarray(data, xp=xp)
101 changes: 101 additions & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import platform
import string
import warnings
from collections.abc import Iterable
from contextlib import contextmanager, nullcontext
from typing import Any, Callable
from unittest import mock # noqa: F401

import numpy as np
Expand Down Expand Up @@ -208,6 +210,105 @@ def __getitem__(self, key):
raise UnexpectedDataAccess("Tried accessing data.")


# Registry mapping a NumPy function object (e.g. ``np.concatenate``) to the
# implementation that should handle it. Keys are the functions themselves,
# not their names.
HANDLED_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}


def implements(numpy_function):
    """Register an __array_function__ implementation for ConcatenatableArray objects.

    Parameters
    ----------
    numpy_function
        The NumPy function object to use as the registry key.

    Returns
    -------
    Callable
        A decorator that stores the decorated function in
        ``HANDLED_ARRAY_FUNCTIONS`` under ``numpy_function`` and returns it
        unchanged.
    """

    def decorator(func):
        HANDLED_ARRAY_FUNCTIONS[numpy_function] = func
        return func

    return decorator


@implements(np.concatenate)
def concatenate(
    arrays: Iterable[ConcatenatableArray], /, *, axis=0
) -> ConcatenatableArray:
    """Concatenate the wrapped arrays and re-wrap the result.

    Parameters
    ----------
    arrays
        ``ConcatenatableArray`` objects to join.
    axis
        Axis passed through to ``np.concatenate``.

    Raises
    ------
    TypeError
        If any element of ``arrays`` is not a ``ConcatenatableArray``.
    """
    # Materialise once: ``arrays`` is only promised to be an Iterable, and a
    # one-shot iterator would be exhausted by the type check below, leaving
    # nothing to concatenate.
    arrays = list(arrays)
    if any(not isinstance(arr, ConcatenatableArray) for arr in arrays):
        raise TypeError("concatenate only supports ConcatenatableArray inputs")

    result = np.concatenate([arr.array for arr in arrays], axis=axis)
    return ConcatenatableArray(result)


@implements(np.stack)
def stack(arrays: Iterable[ConcatenatableArray], /, *, axis=0) -> ConcatenatableArray:
    """Stack the wrapped arrays along a new axis and re-wrap the result.

    Parameters
    ----------
    arrays
        ``ConcatenatableArray`` objects to stack.
    axis
        Axis passed through to ``np.stack``.

    Raises
    ------
    TypeError
        If any element of ``arrays`` is not a ``ConcatenatableArray``.
    """
    # Materialise once so a one-shot iterator is not consumed by the type
    # check before the arrays are actually stacked.
    arrays = list(arrays)
    if any(not isinstance(arr, ConcatenatableArray) for arr in arrays):
        raise TypeError("stack only supports ConcatenatableArray inputs")

    result = np.stack([arr.array for arr in arrays], axis=axis)
    return ConcatenatableArray(result)


@implements(np.result_type)
def result_type(*arrays_and_dtypes) -> np.dtype:
    """Return the one dtype shared by every argument.

    Each argument is passed through ``np.dtype``. No promotion is attempted:
    a ``ValueError`` is raised as soon as any dtype differs from the first.
    """
    dtypes = [np.dtype(obj) for obj in arrays_and_dtypes]
    first_dtype, *other_dtypes = dtypes
    if any(candidate != first_dtype for candidate in other_dtypes):
        raise ValueError("dtypes not all consistent")
    return first_dtype


@implements(np.broadcast_to)
def broadcast_to(
    x: ConcatenatableArray, /, shape: tuple[int, ...]
) -> ConcatenatableArray:
    """Broadcast the wrapped array to ``shape`` and re-wrap the result.

    Parameters
    ----------
    x
        Array to broadcast.
    shape
        Target shape, passed through to ``np.broadcast_to``.

    Raises
    ------
    TypeError
        If ``x`` is not a ``ConcatenatableArray``.
    """
    if not isinstance(x, ConcatenatableArray):
        raise TypeError("broadcast_to only supports ConcatenatableArray inputs")

    result = np.broadcast_to(x.array, shape=shape)
    return ConcatenatableArray(result)


class ConcatenatableArray(utils.NDArrayMixin):
    """Disallows loading or coercing to an index but does support concatenation / stacking.

    Wraps ``array`` and raises ``UnexpectedDataAccess`` from ``get_duck_array``,
    ``__array__`` and ``__getitem__``. NumPy functions registered in the
    module-level ``HANDLED_ARRAY_FUNCTIONS`` mapping are dispatched through
    ``__array_function__``.
    """

    # TODO only reason this is different from InaccessibleArray is to avoid it being a subclass of ExplicitlyIndexed

    # NOTE(review): no method of this class reads this list. Inside the methods
    # the bare name ``HANDLED_ARRAY_FUNCTIONS`` resolves to the module-level
    # registry, because class attributes are not in scope there. It is kept so
    # existing attribute access keeps working, and now lists every registered
    # handler (``broadcast_to`` was missing).
    HANDLED_ARRAY_FUNCTIONS = [concatenate, stack, result_type, broadcast_to]

    def __init__(self, array):
        self.array = array

    def get_duck_array(self):
        raise UnexpectedDataAccess("Tried accessing data")

    def __array__(self, dtype: np.typing.DTypeLike = None, copy: bool | None = None):
        # ``copy`` is accepted (and ignored) so that callers passing the
        # keyword reach the intended error instead of a signature TypeError.
        raise UnexpectedDataAccess("Tried accessing data")

    def __getitem__(self, key):
        raise UnexpectedDataAccess("Tried accessing data.")

    def __array_function__(self, func, types, args, kwargs) -> Any:
        """Dispatch ``func`` to its registered handler, or decline.

        Returns ``NotImplemented`` when ``func`` has no entry in the
        module-level registry or when any participating type is not a
        ``ConcatenatableArray`` subclass.
        """
        if func not in HANDLED_ARRAY_FUNCTIONS:
            return NotImplemented

        # Note: this allows subclasses that don't override
        # __array_function__ to handle ConcatenatableArray objects
        if not all(issubclass(t, ConcatenatableArray) for t in types):
            return NotImplemented

        return HANDLED_ARRAY_FUNCTIONS[func](*args, **kwargs)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> Any:
        """We have to define this in order to convince xarray that this class is a duckarray, even though we will never support ufuncs."""
        return NotImplemented

    def astype(self, dtype: np.dtype, /, *, copy: bool = True) -> ConcatenatableArray:
        """Needed because xarray will call this even when it's a no-op.

        Only the no-op case is supported: ``self`` is returned when ``dtype``
        equals the current dtype (``copy`` is ignored), and any other dtype
        raises ``NotImplementedError``.
        """
        if dtype != self.dtype:
            raise NotImplementedError(
                f"ConcatenatableArray cannot be cast from {self.dtype} to {dtype}"
            )
        return self


class FirstElementAccessibleArray(InaccessibleArray):
def __getitem__(self, key):
tuple_idxr = key.tuple
Expand Down
12 changes: 12 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Generic

import numpy as np
import numpy.testing as npt
import pandas as pd
import pytest
import pytz
Expand All @@ -31,6 +32,7 @@
from xarray.core.variable import as_compatible_data, as_variable
from xarray.namedarray.pycompat import array_type
from xarray.tests import (
ConcatenatableArray,
assert_allclose,
assert_array_equal,
assert_equal,
Expand Down Expand Up @@ -551,6 +553,16 @@
assert_identical(expected, actual)
assert actual.dtype == object

def test_concat_without_access(self):
a = self.cls("x", ConcatenatableArray(np.array([0, 1])))

Check failure on line 557 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.9 bare-minimum

TestIndexVariable.test_concat_without_access xarray.tests.UnexpectedDataAccess: Tried accessing data

Check failure on line 557 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 all-but-dask

TestIndexVariable.test_concat_without_access xarray.tests.UnexpectedDataAccess: Tried accessing data

Check failure on line 557 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.9 min-all-deps

TestIndexVariable.test_concat_without_access xarray.tests.UnexpectedDataAccess: Tried accessing data

Check failure on line 557 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.12

TestIndexVariable.test_concat_without_access xarray.tests.UnexpectedDataAccess: Tried accessing data

Check failure on line 557 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.9

TestIndexVariable.test_concat_without_access xarray.tests.UnexpectedDataAccess: Tried accessing data

Check failure on line 557 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.11

TestIndexVariable.test_concat_without_access xarray.tests.UnexpectedDataAccess: Tried accessing data

Check failure on line 557 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 flaky

TestIndexVariable.test_concat_without_access xarray.tests.UnexpectedDataAccess: Tried accessing data

Check failure on line 557 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.11

TestIndexVariable.test_concat_without_access xarray.tests.UnexpectedDataAccess: Tried accessing data

Check failure on line 557 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.9

TestIndexVariable.test_concat_without_access xarray.tests.UnexpectedDataAccess: Tried accessing data

Check failure on line 557 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.12

TestIndexVariable.test_concat_without_access xarray.tests.UnexpectedDataAccess: Tried accessing data
b = self.cls("x", ConcatenatableArray(np.array([2, 3])))
actual = Variable.concat([a, b], dim="x")
expected_arr = np.array([0, 1, 2, 3])
expected = Variable("x", ConcatenatableArray(expected_arr))
assert isinstance(actual.data, ConcatenatableArray)

Check failure on line 562 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.9 min-all-deps

TestVariableWithDask.test_concat_without_access assert False + where False = isinstance(dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>, ConcatenatableArray) + where dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray> = <xarray.Variable (x: 4)> Size: 32B\ndask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>.data

Check failure on line 562 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.12

TestVariableWithDask.test_concat_without_access assert False + where False = isinstance(dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>, ConcatenatableArray) + where dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray> = <xarray.Variable (x: 4)> Size: 32B\ndask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>.data

Check failure on line 562 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.9

TestVariableWithDask.test_concat_without_access assert False + where False = isinstance(dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>, ConcatenatableArray) + where dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray> = <xarray.Variable (x: 4)> Size: 32B\ndask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>.data

Check failure on line 562 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.11

TestVariableWithDask.test_concat_without_access assert False + where False = isinstance(dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>, ConcatenatableArray) + where dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray> = <xarray.Variable (x: 4)> Size: 32B\ndask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>.data

Check failure on line 562 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10 flaky

TestVariableWithDask.test_concat_without_access assert False + where False = isinstance(dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>, ConcatenatableArray) + where dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray> = <xarray.Variable (x: 4)> Size: 32B\ndask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>.data

Check failure on line 562 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.11

TestVariableWithDask.test_concat_without_access assert False + where False = isinstance(dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>, ConcatenatableArray) + where dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray> = <xarray.Variable (x: 4)> Size: 32B\ndask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>.data

Check failure on line 562 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.9

TestVariableWithDask.test_concat_without_access assert False + where False = isinstance(dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>, ConcatenatableArray) + where dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray> = <xarray.Variable (x: 4)> Size: 32B\ndask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>.data

Check failure on line 562 in xarray/tests/test_variable.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.12

TestVariableWithDask.test_concat_without_access assert False + where False = isinstance(dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>, ConcatenatableArray) + where dask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray> = <xarray.Variable (x: 4)> Size: 32B\ndask.array<concatenate, shape=(4,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>.data
assert expected.dims == ("x",)
npt.assert_equal(actual.data.array, expected_arr)

@pytest.mark.parametrize("deep", [True, False])
@pytest.mark.parametrize("astype", [float, int, str])
def test_copy(self, deep: bool, astype: type[object]) -> None:
Expand Down