Skip to content

(fix): handle internal type promotion and NAs for extension arrays properly #10423


Draft · wants to merge 14 commits into main
14 changes: 8 additions & 6 deletions properties/test_pandas_roundtrip.py
@@ -3,12 +3,14 @@
"""

from functools import partial
from typing import cast

import numpy as np
import pandas as pd
import pytest

import xarray as xr
from xarray.core.dataset import Dataset

pytest.importorskip("hypothesis")
import hypothesis.extra.numpy as npst # isort:skip
@@ -88,10 +90,10 @@ def test_roundtrip_dataarray(data, arr) -> None:


@given(datasets_1d_vars())
def test_roundtrip_dataset(dataset) -> None:
def test_roundtrip_dataset(dataset: Dataset) -> None:
df = dataset.to_dataframe()
assert isinstance(df, pd.DataFrame)
roundtripped = xr.Dataset(df)
roundtripped = xr.Dataset.from_dataframe(df)
xr.testing.assert_identical(dataset, roundtripped)


@@ -101,7 +103,7 @@ def test_roundtrip_pandas_series(ser, ix_name) -> None:
ser.index.name = ix_name
arr = xr.DataArray(ser)
roundtripped = arr.to_pandas()
pd.testing.assert_series_equal(ser, roundtripped)
pd.testing.assert_series_equal(ser, roundtripped) # type: ignore[arg-type]
xr.testing.assert_identical(arr, roundtripped.to_xarray())


@@ -119,7 +121,7 @@ def test_roundtrip_pandas_dataframe(df) -> None:
df.columns.name = "cols"
arr = xr.DataArray(df)
roundtripped = arr.to_pandas()
pd.testing.assert_frame_equal(df, roundtripped)
pd.testing.assert_frame_equal(df, cast(pd.DataFrame, roundtripped))
xr.testing.assert_identical(arr, roundtripped.to_xarray())


@@ -143,8 +145,8 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None:
pd.arrays.IntervalArray(
[pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)]
),
pd.arrays.TimedeltaArray._from_sequence(pd.TimedeltaIndex(["1h", "2h", "3h"])),
pd.arrays.DatetimeArray._from_sequence(
pd.arrays.TimedeltaArray._from_sequence(pd.TimedeltaIndex(["1h", "2h", "3h"])), # type: ignore[attr-defined]
pd.arrays.DatetimeArray._from_sequence( # type: ignore[attr-defined]
pd.DatetimeIndex(["2023-01-01", "2023-01-02", "2023-01-03"], freq="D")
),
np.array([1, 2, 3], dtype="int64"),
106 changes: 81 additions & 25 deletions xarray/core/dtypes.py
@@ -1,15 +1,21 @@
from __future__ import annotations

import functools
from typing import Any
from collections.abc import Iterable
from typing import TYPE_CHECKING, TypeVar, cast

import numpy as np
from pandas.api.extensions import ExtensionDtype
from pandas.api.types import is_extension_array_dtype

from xarray.compat import array_api_compat, npcompat
from xarray.compat.npcompat import HAS_STRING_DTYPE
from xarray.core import utils

if TYPE_CHECKING:
from typing import Any


# Use as a sentinel value to indicate a dtype appropriate NA value.
NA = utils.ReprObject("<NA>")

@@ -47,8 +53,10 @@ def __eq__(self, other):
(np.bytes_, np.str_), # numpy promotes to unicode
)

T_dtype = TypeVar("T_dtype", np.dtype, ExtensionDtype)

def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:

def maybe_promote(dtype: T_dtype) -> tuple[T_dtype, Any]:
"""Simpler equivalent of pandas.core.common._maybe_promote

Parameters
@@ -63,7 +71,13 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
# N.B. these casting rules should match pandas
dtype_: np.typing.DTypeLike
fill_value: Any
if HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
if is_extension_array_dtype(dtype):
return dtype, cast(ExtensionDtype, dtype).na_value # type: ignore[redundant-cast]
if not isinstance(dtype, np.dtype):
raise TypeError(
f"dtype {dtype} must be one of an extension array dtype or numpy dtype"
)
elif HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
# for now, we always promote string dtypes to object for consistency with existing behavior
# TODO: refactor this once we have a better way to handle numpy vlen-string dtypes
dtype_ = object
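
For context, a rough sketch of the behavior the new extension-dtype branch is aiming for (illustrative only; the exact na_value depends on the pandas dtype):

# Sketch, not part of the diff: extension dtypes pass through unchanged
# together with their own NA sentinel, while numpy dtypes keep following
# the existing promotion table.
import numpy as np
import pandas as pd

from xarray.core.dtypes import maybe_promote

dtype, fill = maybe_promote(pd.Int64Dtype())
# dtype == pd.Int64Dtype(), fill is pd.NA (pandas' own missing-value marker)

dtype, fill = maybe_promote(np.dtype("int64"))
# dtype == np.dtype("float64"), fill is np.nan (ints promote to float for NA)
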
@@ -222,23 +236,66 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
return xp.isdtype(dtype, kind)


def preprocess_types(t):
if isinstance(t, str | bytes):
return type(t)
elif isinstance(dtype := getattr(t, "dtype", t), np.dtype) and (
np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_)
):
def maybe_promote_to_variable_width(
array_or_dtype: np.typing.ArrayLike
| np.typing.DTypeLike
| ExtensionDtype
| str
| bytes,
*,
should_return_str_or_bytes: bool = False,
) -> np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype:
if isinstance(array_or_dtype, str | bytes):
if should_return_str_or_bytes:
return array_or_dtype
return type(array_or_dtype)
elif isinstance(
dtype := getattr(array_or_dtype, "dtype", array_or_dtype), np.dtype
) and (np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_)):
# drop the length from numpy's fixed-width string dtypes; it is better to
# recalculate it
# TODO(keewis): remove once the minimum version of `numpy.result_type` does this
# for us
return dtype.type
else:
return t
return array_or_dtype
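
Roughly, this helper strips the width from numpy's fixed-size string dtypes so result_type can recompute it; a small sketch of the intended mapping:

import numpy as np

from xarray.core.dtypes import maybe_promote_to_variable_width

maybe_promote_to_variable_width(np.dtype("U3"))  # np.str_, width dropped
maybe_promote_to_variable_width(b"abc")  # bytes
maybe_promote_to_variable_width("abc", should_return_str_or_bytes=True)  # "abc"
maybe_promote_to_variable_width(np.dtype("int64"))  # dtype('int64'), unchanged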


def should_promote_to_object(
arrays_and_dtypes: Iterable[
np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype
],
xp,
) -> bool:
"""
Test whether the given arrays_and_dtypes, when evaluated individually, match the
type promotion rules found in PROMOTE_TO_OBJECT.
"""
np_result_types = set()
for arr_or_dtype in arrays_and_dtypes:
try:
result_type = array_api_compat.result_type(
maybe_promote_to_variable_width(arr_or_dtype), xp=xp
)
if isinstance(result_type, np.dtype):
np_result_types.add(result_type)
except TypeError:
# passing individual objects to xp.result_type means NEP-18 implementations won't have
# a chance to intercept special values (such as NA) that numpy core cannot handle
pass

if np_result_types:
for left, right in PROMOTE_TO_OBJECT:
if any(np.issubdtype(t, left) for t in np_result_types) and any(
np.issubdtype(t, right) for t in np_result_types
):
return True

return False
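
As a sketch, the bytes/str pair from PROMOTE_TO_OBJECT is one case this catches:

import numpy as np

from xarray.core.dtypes import should_promote_to_object

should_promote_to_object([np.dtype("S3"), np.dtype("U3")], xp=np)  # True
should_promote_to_object([np.dtype("int64"), np.dtype("float64")], xp=np)  # False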


def result_type(
*arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
*arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike | ExtensionDtype,
xp=None,
) -> np.dtype:
"""Like np.result_type, but with type promotion rules matching pandas.
@@ -263,19 +320,18 @@ def result_type(
if xp is None:
xp = get_array_namespace(arrays_and_dtypes)

types = {
array_api_compat.result_type(preprocess_types(t), xp=xp)
for t in arrays_and_dtypes
}
if any(isinstance(t, np.dtype) for t in types):
# only check if there's numpy dtypes – the array API does not
# define the types we're checking for
for left, right in PROMOTE_TO_OBJECT:
if any(np.issubdtype(t, left) for t in types) and any(
np.issubdtype(t, right) for t in types
):
return np.dtype(object)

if should_promote_to_object(arrays_and_dtypes, xp):
return np.dtype(object)
return array_api_compat.result_type(
*map(preprocess_types, arrays_and_dtypes), xp=xp
*map(
functools.partial(
maybe_promote_to_variable_width,
# let extension arrays handle their own str/bytes
should_return_str_or_bytes=any(
map(is_extension_array_dtype, arrays_and_dtypes) # type: ignore[arg-type]
),
),
arrays_and_dtypes,
),
xp=xp,
)
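
The net effect, as a sketch (numpy on its own would promote the first pair to a fixed-width unicode dtype rather than object):

import numpy as np

from xarray.core import dtypes

dtypes.result_type(np.dtype("int64"), np.dtype("U3"))  # dtype('O')
dtypes.result_type(np.dtype("int32"), np.dtype("float64"))  # dtype('float64')
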
51 changes: 26 additions & 25 deletions xarray/core/duck_array_ops.py
@@ -28,7 +28,11 @@
from xarray.compat import dask_array_compat, dask_array_ops
from xarray.compat.array_api_compat import get_array_namespace
from xarray.core import dtypes, nputils
from xarray.core.extension_array import PandasExtensionArray
from xarray.core.extension_array import (
PandasExtensionArray,
as_extension_array,
is_scalar,
)
from xarray.core.options import OPTIONS
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
from xarray.namedarray.parallelcompat import get_chunked_array_type
@@ -250,12 +254,19 @@

if xp == np:
# numpy currently doesn't have an astype:
return data.astype(dtype, **kwargs)

[CI annotation, line 257 of xarray/core/duck_array_ops.py: "invalid value encountered in cast" (GitHub Actions, py3.10 on macos/ubuntu/windows)]
return xp.astype(data, dtype, **kwargs)


def asarray(data, xp=np, dtype=None):
converted = data if is_duck_array(data) else xp.asarray(data)
if is_duck_array(data):
converted = data
elif is_extension_array_dtype(dtype):
# data may or may not be an ExtensionArray, so we can't rely on
# np.asarray to call our NEP-18 handler, so we have to hook it in ourselves
converted = PandasExtensionArray(as_extension_array(data, dtype))
else:
converted = xp.asarray(data, dtype=dtype)

if dtype is None or converted.dtype == dtype:
return converted
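
A sketch of the new extension-dtype path (assuming as_extension_array accepts plain sequences, as its use here suggests):

import pandas as pd

from xarray.core.duck_array_ops import asarray

wrapped = asarray([1, None, 3], dtype=pd.Int64Dtype())
# expected: a PandasExtensionArray backed by an IntegerArray [1, <NA>, 3],
# keeping pandas NA semantics instead of coercing through np.asarray
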
@@ -267,27 +278,7 @@


def as_shared_dtype(scalars_or_arrays, xp=None):
"""Cast arrays to a shared dtype using xarray's type promotion rules."""
if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
extension_array_types = [
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
]
non_nans = [x for x in scalars_or_arrays if not isna(x)]
if len(extension_array_types) == len(non_nans) and all(
isinstance(x, type(extension_array_types[0])) for x in extension_array_types
):
return [
x
if not isna(x)
else PandasExtensionArray(
type(non_nans[0].array)._from_sequence([x], dtype=non_nans[0].dtype)
)
for x in scalars_or_arrays
]
raise ValueError(
f"Cannot cast values to shared type, found values: {scalars_or_arrays}"
)

"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
# Avoid calling array_type("cupy") repeatedly in the any check
array_type_cupy = array_type("cupy")
if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
Expand All @@ -296,7 +287,12 @@
xp = cp
elif xp is None:
xp = get_array_namespace(scalars_or_arrays)

scalars_or_arrays = [
PandasExtensionArray(s_or_a)
if isinstance(s_or_a, pd.api.extensions.ExtensionArray)
else s_or_a
for s_or_a in scalars_or_arrays
]
# Pass arrays directly instead of dtypes to result_type so scalars
# get handled properly.
# Note that result_type() safely gets the dtype from dask arrays without
# evaluating them.
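
A rough sketch of what the wrapping step buys (assuming same-dtype inputs promote to themselves, which is the behavior this PR is after):

import pandas as pd

from xarray.core.duck_array_ops import as_shared_dtype

cat = pd.Categorical(["a", "b", "a"])
shared = as_shared_dtype([cat, cat])
# expected: both entries come back wrapped in PandasExtensionArray with the
# categorical dtype intact, so NEP-18 dispatch sees them during promotion
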
@@ -407,7 +403,12 @@
else:
condition = astype(condition, dtype=dtype, xp=xp)

return xp.where(condition, *as_shared_dtype([x, y], xp=xp))
promoted_x, promoted_y = as_shared_dtype([x, y], xp=xp)

# pd.where won't broadcast 0-dim arrays across a series; a scalar y must be preserved
maybe_promoted_y = y if is_extension_array_dtype(x) and is_scalar(y) else promoted_y

return xp.where(condition, promoted_x, maybe_promoted_y)
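
A sketch of the scalar-preserving case the comment above describes (values are hypothetical; dispatch through the PandasExtensionArray wrapper is assumed):

import numpy as np
import pandas as pd

from xarray.core.duck_array_ops import where

x = pd.array([1, 2, 3], dtype="Int64")
cond = np.array([True, False, True])
result = where(cond, x, 0)
# expected: an Int64-backed [1, 0, 3]; the scalar 0 is passed through
# unpromoted so pandas can broadcast it across the series itself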


def where_method(data, cond, other=dtypes.NA):