
Avoid coercing to numpy in as_shared_dtypes #8714

Open · wants to merge 12 commits into base: main
44 changes: 25 additions & 19 deletions xarray/core/duck_array_ops.py
@@ -214,26 +214,32 @@ def astype(data, dtype, **kwargs):


def asarray(data, xp=np):
return data if is_duck_array(data) else xp.asarray(data)


def as_shared_dtype(scalars_or_arrays, xp=np):
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
array_type_cupy = array_type("cupy")
if array_type_cupy and any(
isinstance(x, array_type_cupy) for x in scalars_or_arrays
):
import cupy as cp

arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
def as_duck_array(data, xp=np):
TomNicholas marked this conversation as resolved.
if is_duck_array(data):
return data
elif hasattr(data, "get_duck_array"):
# must be a lazy indexing class wrapping a duck array
return data.get_duck_array()
Contributor Author:
Will this idea always work? What if it steps down through a lazy decoder class that changes the dtype...

Contributor:
Those should be going through

class _ElementwiseFunctionArray(indexing.ExplicitlyIndexedNDArrayMixin):
"""Lazily computed array holding values of elemwise-function.
Do not construct this object directly: call lazy_elemwise_func instead.
Values are computed upon indexing or coercion to a NumPy array.
"""
def __init__(self, array, func: Callable, dtype: np.typing.DTypeLike):
assert not is_chunked_array(array)
self.array = indexing.as_indexable(array)
self.func = func
self._dtype = dtype

so you should be fine.
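The point above — that decoder layers like `_ElementwiseFunctionArray` record their output dtype up front — can be sketched with a hypothetical stand-in (class and method names here are illustrative, not xarray's actual API):

```python
import numpy as np

class LazyElemwiseArray:
    """Hypothetical stand-in for _ElementwiseFunctionArray: wraps an
    array, defers func, but declares the output dtype eagerly."""

    def __init__(self, array, func, dtype):
        self.array = array
        self.func = func
        self._dtype = np.dtype(dtype)

    @property
    def dtype(self):
        # The dtype is known without ever computing func(array).
        return self._dtype

    def get_duck_array(self):
        # Computing resolves to the declared dtype.
        return self.func(np.asarray(self.array)).astype(self._dtype)

wrapped = LazyElemwiseArray(np.arange(3, dtype="int16"),
                            func=lambda a: a * 0.5,
                            dtype="float64")
assert wrapped.dtype == np.dtype("float64")
assert wrapped.get_duck_array().tolist() == [0.0, 0.5, 1.0]
```

Because the declared dtype matches what the computation produces, stepping down to the inner duck array does not change the promotion result.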

Contributor Author:
I think I'm getting confused as to how this all works now... Don't I want to be computing as_shared_dtype using the dtype of the outermost wrapped class? Whereas this will step through all the way to the innermost duckarray, which may have a different dtype?

@dcherian (Contributor), Feb 6, 2024:
As of now, as_shared_dtype is expected to return pure duck arrays for stack, concatenate, and where.

So that means we need to read from disk, which you do with to_duck_array and all these wrapper layers will be resolved.

It will get more complicated when we do lazy concatenation in Xarray, then we'd need to lazily infer dtypes and apply a lazy astype.
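How reading resolves the wrapper layers can be sketched with a hypothetical nested wrapper whose `get_duck_array()` recurses down to a real duck array (the `Wrapper` class is illustrative, not xarray code):

```python
import numpy as np

class Wrapper:
    """Hypothetical lazy-indexing wrapper; get_duck_array() resolves
    through any nested wrappers down to a real duck array."""

    def __init__(self, inner):
        self.inner = inner

    def get_duck_array(self):
        inner = self.inner
        # Step through nested wrapper layers until a plain array remains.
        if hasattr(inner, "get_duck_array"):
            inner = inner.get_duck_array()
        return np.asarray(inner)

nested = Wrapper(Wrapper(np.arange(4)))
resolved = nested.get_duck_array()
assert isinstance(resolved, np.ndarray)
assert resolved.tolist() == [0, 1, 2, 3]
```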

else:
arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
Contributor Author:
Previously this asarray call would coerce to numpy unnecessarily, when all we really wanted was an array type that we could examine the .dtype attribute of.

# Pass arrays directly instead of dtypes to result_type so scalars
# get handled properly.
# Note that result_type() safely gets the dtype from dask arrays without
# evaluating them.
out_type = dtypes.result_type(*arrays)
return [astype(x, out_type, copy=False) for x in arrays]
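The comment above can be illustrated with `np.result_type`, standing in here for xarray's `dtypes.result_type` (whose promotion rules differ in some cases): passing the values themselves lets scalars participate in promotion, instead of forcing them to their default dtypes first.

```python
import numpy as np

arr = np.arange(3, dtype="int8")

# A Python int passed as a value does not widen the result dtype...
assert np.result_type(arr, 1) == np.dtype("int8")

# ...whereas mixing two concrete array dtypes promotes as usual.
b = np.arange(3, dtype="float64")
assert np.result_type(arr, b) == np.dtype("float64")
```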
array_type_cupy = array_type("cupy")
    if array_type_cupy and isinstance(data, array_type_cupy):
import cupy as cp

return asarray(data, xp=cp)
else:
return asarray(data, xp=xp)


def as_shared_dtype(scalars_or_arrays, xp=np):
"""Cast arrays to a shared dtype using xarray's type promotion rules."""
duckarrays = [as_duck_array(obj, xp=xp) for obj in scalars_or_arrays]
out_type = dtypes.result_type(*duckarrays)
return [astype(x, out_type, copy=False) for x in duckarrays]
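For reference, the new code path can be approximated in plain NumPy (`np.result_type` stands in for xarray's `dtypes.result_type`, and this sketch ignores cupy dispatch and lazy wrappers):

```python
import numpy as np

def as_shared_dtype_sketch(values):
    """Minimal sketch of the new as_shared_dtype path, assuming every
    input can be turned into a duck array via np.asarray."""
    arrays = [np.asarray(v) for v in values]
    out_type = np.result_type(*arrays)
    return [a.astype(out_type, copy=False) for a in arrays]

cast = as_shared_dtype_sketch([np.arange(3, dtype="int32"), 2.5])
assert all(a.dtype == np.dtype("float64") for a in cast)
assert cast[0].tolist() == [0.0, 1.0, 2.0]
```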


def broadcast_to(array, shape):
@@ -327,7 +333,7 @@ def sum_where(data, axis=None, dtype=None, where=None):
def where(condition, x, y):
"""Three argument where() with better dtype promotion rules."""
xp = get_array_namespace(condition)
return xp.where(condition, *as_shared_dtype([x, y], xp=xp))
return xp.where(condition, *as_shared_dtype([x, y]))
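A rough NumPy-only illustration of what `where` does after `as_shared_dtype` — both branches are cast to the shared dtype before dispatching (`np.result_type` approximates xarray's promotion rules):

```python
import numpy as np

cond = np.array([True, False, True])
x = np.arange(3, dtype="int32")
y = np.asarray(0.5)  # scalar branch with a different dtype

# Cast both branches to the shared dtype before calling where.
out_type = np.result_type(x, y)
result = np.where(cond, x.astype(out_type), y.astype(out_type))

assert result.dtype == np.dtype("float64")
assert result.tolist() == [0.0, 0.5, 2.0]
```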


def where_method(data, cond, other=dtypes.NA):
@@ -350,14 +356,14 @@ def concatenate(arrays, axis=0):
arrays[0], np.ndarray
):
xp = get_array_namespace(arrays[0])
return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis)
return xp.concat(as_shared_dtype(arrays), axis=axis)
return _concatenate(as_shared_dtype(arrays), axis=axis)
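The same pattern in `concatenate`, sketched with plain NumPy (again using `np.result_type` in place of xarray's promotion rules):

```python
import numpy as np

a = np.arange(2, dtype="int16")
b = np.arange(2, dtype="float32")

# Promote both inputs to the shared dtype, then concatenate.
out_type = np.result_type(a, b)
joined = np.concatenate([a.astype(out_type), b.astype(out_type)])

assert joined.dtype == np.dtype("float32")
assert joined.tolist() == [0.0, 1.0, 0.0, 1.0]
```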


def stack(arrays, axis=0):
"""stack() with better dtype promotion rules."""
xp = get_array_namespace(arrays[0])
return xp.stack(as_shared_dtype(arrays, xp=xp), axis=axis)
return xp.stack(as_shared_dtype(arrays), axis=axis)


def reshape(array, shape):