(feat): Support for pandas ExtensionArray #8723

Merged: 101 commits, Apr 18, 2024

Commits
b2712f1
(feat): first pass supporting extension arrays
ilan-gold Feb 2, 2024
47bddd2
(feat): categorical tests + functionality
ilan-gold Feb 2, 2024
dc8b788
(feat): use multiple dispatch for unimplemented ops
ilan-gold Feb 5, 2024
75524c8
(feat): implement (not really) broadcasting
ilan-gold Feb 5, 2024
c9ab452
(chore): add more `groupby` tests
ilan-gold Feb 5, 2024
1f3d0fa
(fix): fix more groupby incompatibility
ilan-gold Feb 5, 2024
8a70e3c
(bug): fix unused categories
ilan-gold Feb 5, 2024
f5a6505
(chore): refactor dispatched methods + tests
ilan-gold Feb 5, 2024
08a4feb
(fix): shared type should check for extension arrays first and then f…
ilan-gold Feb 6, 2024
d5b218b
(refactor): tests moved
ilan-gold Feb 6, 2024
00256fa
(chore): more higher level tests
ilan-gold Feb 6, 2024
b7ddbd6
(feat): to/from dataframe
ilan-gold Feb 8, 2024
a165851
(chore): check for plum import
ilan-gold Feb 8, 2024
a826edd
(fix): `__setitem__`/`__getitem__`
ilan-gold Feb 8, 2024
fde19ea
(chore): disallow stacking
ilan-gold Feb 8, 2024
4c55707
(fix): `pyproject.toml`
ilan-gold Feb 8, 2024
58ba17d
(fix): `as_shared_type` fix
ilan-gold Feb 8, 2024
a255310
(chore): add variable tests
ilan-gold Feb 8, 2024
4e78b7e
(fix): dask + categoricals
ilan-gold Feb 8, 2024
d9cedf5
(chore): notes/docs
ilan-gold Feb 8, 2024
426664d
(chore): remove old testing file
ilan-gold Feb 8, 2024
22ca77d
(chore): remove ocmmented out code
ilan-gold Feb 8, 2024
f32cfdf
Merge branch 'main' into extension_arrays
ilan-gold Feb 8, 2024
60f8927
(fix): import plum dispatch
ilan-gold Feb 8, 2024
ff22d76
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Feb 8, 2024
2153e81
Merge branch 'main' into extension_arrays
ilan-gold Feb 9, 2024
b6d0b31
(refactor): use `is_extension_array_dtype` as much as possible
ilan-gold Feb 9, 2024
d285871
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Feb 9, 2024
d847277
Merge branch 'main' into extension_arrays
ilan-gold Feb 9, 2024
8238c64
(refactor): `extension_array`->`array` + move to `indexing`
ilan-gold Feb 10, 2024
1260cd4
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Feb 10, 2024
b04ef98
(refactor): change order of classes
ilan-gold Feb 10, 2024
b9937bf
(chore): add small pyarrow test
ilan-gold Feb 12, 2024
0bba03f
(fix): fix some mypy issues
ilan-gold Feb 12, 2024
b714549
(fix): don't register unregisterable method
ilan-gold Feb 12, 2024
a3a678c
(fix): appease mypy
ilan-gold Feb 12, 2024
e521844
(fix): more sensible default implemetations allow most use without `p…
ilan-gold Feb 12, 2024
2d3e930
(fix): handling `pyarrow` tests
ilan-gold Feb 12, 2024
04c9969
(fix): actually do import correctly
ilan-gold Feb 12, 2024
5514539
Merge branch 'main' into extension_arrays
ilan-gold Feb 12, 2024
bedfa5c
(fix): `reduce` condition
ilan-gold Feb 13, 2024
e6c2690
Merge branch 'main' into extension_arrays
ilan-gold Feb 13, 2024
82dbda9
(fix): column ordering for dataframes
ilan-gold Feb 13, 2024
12217ed
(refactor): remove encoding business
ilan-gold Feb 13, 2024
dd5b87d
(refactor): raise error for dask + extension array
ilan-gold Feb 13, 2024
761a874
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Feb 13, 2024
52cabc8
Merge branch 'main' into extension_arrays
ilan-gold Feb 13, 2024
e0d58fa
(fix): only wrap `ExtensionDuckArray` that has a `.array` which is a …
ilan-gold Feb 15, 2024
c1e0e64
(fix): use duck array equality method, not pandas
ilan-gold Feb 15, 2024
17e3390
(refactor): bye plum!
ilan-gold Feb 15, 2024
dd2ef39
Merge branch 'main' into extension_arrays
ilan-gold Feb 15, 2024
c8e6bfe
(fix): `and` to `or` for casting to `ExtensionDuckArray`
ilan-gold Feb 15, 2024
b2a9517
(fix): check for class, not type
ilan-gold Feb 16, 2024
f5e1bd0
Merge branch 'main' into extension_arrays
ilan-gold Feb 16, 2024
407fad1
(fix): only support native endianness
ilan-gold Feb 19, 2024
3a47f09
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Feb 19, 2024
fdd3de4
Merge branch 'main' into extension_arrays
ilan-gold Feb 19, 2024
6b23629
Merge branch 'main' into extension_arrays
ilan-gold Feb 20, 2024
1c9047f
(refactor): no need for superfluous checks in `_maybe_wrap_data`
ilan-gold Feb 22, 2024
9be6b03
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Feb 22, 2024
d9304f1
(chore): clean up docs to no longer reference `plum`
ilan-gold Feb 22, 2024
6ec6725
(fix): no longer allow `ExtensionDuckArray` to wrap `ExtensionDuckArray`
ilan-gold Feb 22, 2024
bc9ac4c
(refactor): move `implements` logic to `indexing`
ilan-gold Feb 22, 2024
1e906db
Merge branch 'main' into extension_arrays
ilan-gold Feb 29, 2024
6fb8668
(refactor): `indexing.py` -> `extension_array.py`
ilan-gold Feb 29, 2024
8f034b4
(refactor): `ExtensionDuckArray` -> `PandasExtensionArray`
ilan-gold Feb 29, 2024
90a6de6
Merge branch 'main' into extension_arrays
dcherian Mar 3, 2024
2bd422a
Merge branch 'main' into extension_arrays
ilan-gold Mar 18, 2024
ff67943
Merge branch 'main' into extension_arrays
ilan-gold Mar 25, 2024
661d9f2
(fix): add writeable property
ilan-gold Mar 25, 2024
caee1c6
(fix): don't check writeable for `PandasExtensionArray`
ilan-gold Mar 25, 2024
1d12f5e
(fix): move check eariler
ilan-gold Mar 25, 2024
31dfbb5
Merge branch 'main' into extension_arrays
ilan-gold Mar 26, 2024
23b347f
Merge branch 'main' into extension_arrays
ilan-gold Mar 28, 2024
902c74b
(refactor): correct guard clause
ilan-gold Mar 28, 2024
0b64506
(chore): remove unnecessary `AttributeError`
ilan-gold Mar 28, 2024
0c7e023
(feat): singleton wrapped as array
ilan-gold Mar 28, 2024
dd7fe98
(feat): remove shared dtype casting
ilan-gold Mar 28, 2024
f0df768
(feat): loop once over `dataframe.items`
ilan-gold Mar 28, 2024
e2f0487
(feat): add `__len__` attribute
ilan-gold Mar 28, 2024
1eb6741
(fix): ensure constructor recieves `pd.Categorical`
ilan-gold Mar 28, 2024
2a7300a
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Mar 28, 2024
9cceadc
Update xarray/core/extension_array.py
ilan-gold Mar 28, 2024
f2588c1
Update xarray/core/extension_array.py
ilan-gold Mar 28, 2024
a0a63bd
(fix): drop condition for categorical corrected
ilan-gold Mar 28, 2024
5bb2bde
Merge branch 'main' into extension_arrays
ilan-gold Mar 28, 2024
f85f166
Merge branch 'main' into extension_arrays
ilan-gold Apr 3, 2024
7ecdeba
Merge branch 'main' into extension_arrays
ilan-gold Apr 4, 2024
6bc40fc
Merge branch 'main' into extension_arrays
ilan-gold Apr 11, 2024
e9dc53f
Apply suggestions from code review
dcherian Apr 13, 2024
4791799
(chore): test `chunk` behavior
ilan-gold Apr 16, 2024
c649362
Merge branch 'extension_arrays' of github.com:ilan-gold/xarray into e…
ilan-gold Apr 16, 2024
fc60dcf
Merge branch 'main' into extension_arrays
ilan-gold Apr 16, 2024
0374086
Update xarray/core/variable.py
dcherian Apr 16, 2024
b9515a6
Merge branch 'main' into extension_arrays
dcherian Apr 16, 2024
72bf807
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2024
63b6c42
(fix): bring back error
ilan-gold Apr 17, 2024
1d18439
(chore): add test for dropping cat for mean
ilan-gold Apr 17, 2024
17f05da
Update whats-new.rst
dcherian Apr 17, 2024
c906c81
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
e6db83b
Merge branch 'main' into extension_arrays
ilan-gold Apr 18, 2024
4 changes: 4 additions & 0 deletions doc/whats-new.rst
@@ -149,6 +149,10 @@ New Features
(:issue:`7377`, :pull:`8684`)
By `Marco Wolsza <https://github.com/maawoo>`_.

- Xarray now makes a best attempt not to coerce :py:class:`pandas.api.extensions.ExtensionArray` to a numpy array
  by supporting 1D `ExtensionArray` objects internally where possible. Thus, a `Dataset` initialized with a `pd.Categorical`,
  for example, will retain the object. However, operations that the `ExtensionArray` does not support, such as broadcasting, are not possible.
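To illustrate the entry above, a minimal sketch (not from the diff; exact reprs may differ):

    import pandas as pd
    import xarray as xr

    # the pd.Categorical is stored as-is instead of being coerced to a numpy object array
    ds = xr.Dataset({"cat": ("x", pd.Categorical(["a", "b", "a"]))})
    print(ds["cat"].dtype)  # category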

Breaking changes
~~~~~~~~~~~~~~~~

4 changes: 3 additions & 1 deletion properties/test_pandas_roundtrip.py
@@ -17,7 +17,9 @@
from hypothesis import given # isort:skip

numeric_dtypes = st.one_of(
npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes()
npst.unsigned_integer_dtypes(endianness="="),
npst.integer_dtypes(endianness="="),
npst.floating_dtypes(endianness="="),
)

numeric_series = numeric_dtypes.flatmap(lambda dt: pdst.series(dtype=dt))
1 change: 1 addition & 0 deletions pyproject.toml
@@ -129,6 +129,7 @@ module = [
"opt_einsum.*",
"pandas.*",
"pooch.*",
"pyarrow.*",
"pydap.*",
"pytest.*",
"scipy.*",
60 changes: 47 additions & 13 deletions xarray/core/dataset.py
@@ -24,6 +24,7 @@
from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload

import numpy as np
from pandas.api.types import is_extension_array_dtype

# remove once numpy 2.0 is the oldest supported version
try:
@@ -6852,10 +6853,13 @@ def reduce(
if (
# Some reduction functions (e.g. std, var) need to run on variables
# that don't have the reduce dims: PR5393
not reduce_dims
or not numeric_only
or np.issubdtype(var.dtype, np.number)
or (var.dtype == np.bool_)
not is_extension_array_dtype(var.dtype)
and (
not reduce_dims
or not numeric_only
or np.issubdtype(var.dtype, np.number)
or (var.dtype == np.bool_)
)
):
# prefer to aggregate over axis=None rather than
# axis=(0, 1) if they will be equivalent, because
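With the guard above, reductions such as mean() drop extension-array variables rather than coercing them; a sketch of the intended behavior (illustrative, not from the diff):

    import pandas as pd
    import xarray as xr

    ds = xr.Dataset(
        {"cat": ("x", pd.Categorical(["a", "b", "a"])), "num": ("x", [1.0, 2.0, 3.0])}
    )
    reduced = ds.mean()      # "cat" is dropped, "num" is reduced
    list(reduced.data_vars)  # presumably ["num"]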
@@ -7168,13 +7172,37 @@ def to_pandas(self) -> pd.Series | pd.DataFrame:
)

def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
columns = [k for k in self.variables if k not in self.dims]
columns_in_order = [k for k in self.variables if k not in self.dims]
non_extension_array_columns = [
k
for k in columns_in_order
if not is_extension_array_dtype(self.variables[k].data)
]
extension_array_columns = [
k
for k in columns_in_order
if is_extension_array_dtype(self.variables[k].data)
]
data = [
self._variables[k].set_dims(ordered_dims).values.reshape(-1)
for k in columns
for k in non_extension_array_columns
]
index = self.coords.to_index([*ordered_dims])
return pd.DataFrame(dict(zip(columns, data)), index=index)
broadcasted_df = pd.DataFrame(
dict(zip(non_extension_array_columns, data)), index=index
)
for extension_array_column in extension_array_columns:
extension_array = self.variables[extension_array_column].data.array
index = self[self.variables[extension_array_column].dims[0]].data
extension_array_df = pd.DataFrame(
{extension_array_column: extension_array},
index=self[self.variables[extension_array_column].dims[0]].data,
)
extension_array_df.index.name = self.variables[extension_array_column].dims[
0
]
broadcasted_df = broadcasted_df.join(extension_array_df)
Review thread on lines +7194 to +7204:

Reviewer (member): Calling .join() in a loop will make this method take quadratic time. Can you rewrite this to join all the extension arrays together once, e.g., with pd.concat?

ilan-gold (author): pandas-dev/pandas#57676. Not sure what to do. I don't think concat is meant for this? In any case, very open to other ideas!

ilan-gold (author, Feb 29, 2024): Also not sure join with a list is faster, now that I think of it. I couldn't figure out how to do concat though... maybe I should make the index on the extension_array_df the correct multi-index, but this seems tricky?

Reviewer (contributor): It'd be good to sort this out.

ilan-gold (author): @shoyer Could you maybe give some details on using concat here? I think we truly do want a join, no?

Reviewer (contributor): Let's open an issue to remind ourselves to make this more efficient. I guess the core problem is that extension arrays cannot be broadcast to nD with .set_dims? Maybe we could raise an error if len(ordered_dims) > 1?

ilan-gold (author): #8950, done!

ilan-gold (author): "I guess the core problem is that extension arrays cannot be broadcast to nD with .set_dims?" I think this is true. "Maybe we could raise an error if len(ordered_dims) > 1?" I think this currently handles the case where it is >1, so why error out? I think join is acceptable here, IMO.
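A minimal sketch of the "combine once" idea discussed in this thread, for the simple case where every extension-array column lives on the single dimension whose values already form index (the MultiIndex alignment question is what issue #8950 tracks, and this is not the code that landed):

    extension_df = pd.DataFrame(
        {name: self.variables[name].data.array for name in extension_array_columns},
        index=index,
    )
    broadcasted_df = broadcasted_df.join(extension_df)  # one join instead of one per column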

return broadcasted_df[columns_in_order]

def to_dataframe(self, dim_order: Sequence[Hashable] | None = None) -> pd.DataFrame:
"""Convert this dataset into a pandas.DataFrame.
@@ -7321,11 +7349,13 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
"cannot convert a DataFrame with a non-unique MultiIndex into xarray"
)

# Cast to a NumPy array first, in case the Series is a pandas Extension
# array (which doesn't have a valid NumPy dtype)
# TODO: allow users to control how this casting happens, e.g., by
# forwarding arguments to pandas.Series.to_numpy?
arrays = [(k, np.asarray(v)) for k, v in dataframe.items()]
arrays = []
extension_arrays = []
for k, v in dataframe.items():
if not is_extension_array_dtype(v):
arrays.append((k, np.asarray(v)))
else:
extension_arrays.append((k, v))

indexes: dict[Hashable, Index] = {}
index_vars: dict[Hashable, Variable] = {}
@@ -7339,6 +7369,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
xr_idx = PandasIndex(lev, dim)
indexes[dim] = xr_idx
index_vars.update(xr_idx.create_variables())
arrays += [(k, np.asarray(v)) for k, v in extension_arrays]
extension_arrays = []
else:
index_name = idx.name if idx.name is not None else "index"
dims = (index_name,)
@@ -7352,7 +7384,9 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
obj._set_sparse_data_from_dataframe(idx, arrays, dims)
else:
obj._set_numpy_data_from_dataframe(idx, arrays, dims)
return obj
for name, extension_array in extension_arrays:
obj[name] = (dims, extension_array)
return obj[dataframe.columns] if len(dataframe.columns) else obj

def to_dask_dataframe(
self, dim_order: Sequence[Hashable] | None = None, set_index: bool = False
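To close out this file's changes, an illustrative round trip through the updated from_dataframe/to_dataframe (a sketch; applies to a flat index, since a MultiIndex still coerces extension arrays to numpy per the branch above):

    import pandas as pd
    import xarray as xr

    df = pd.DataFrame(
        {"cat": pd.Categorical(["a", "b", "a"])},
        index=pd.Index([0, 1, 2], name="x"),
    )
    ds = xr.Dataset.from_dataframe(df)
    ds["cat"].dtype     # stays a pandas CategoricalDtype
    ds.to_dataframe()   # extension-array columns joined back in, original column order preserved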
19 changes: 15 additions & 4 deletions xarray/core/duck_array_ops.py
@@ -32,6 +32,7 @@
from numpy import concatenate as _concatenate
from numpy.lib.stride_tricks import sliding_window_view # noqa
from packaging.version import Version
from pandas.api.types import is_extension_array_dtype

from xarray.core import dask_array_ops, dtypes, nputils
from xarray.core.options import OPTIONS
@@ -156,7 +157,7 @@ def isnull(data):
return full_like(data, dtype=bool, fill_value=False)
else:
# at this point, array should have dtype=object
if isinstance(data, np.ndarray):
if isinstance(data, np.ndarray) or is_extension_array_dtype(data):
return pandas_isnull(data)
else:
# Not reachable yet, but intended for use with other duck array
@@ -221,9 +222,19 @@ def asarray(data, xp=np):

def as_shared_dtype(scalars_or_arrays, xp=np):
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
array_type_cupy = array_type("cupy")
if array_type_cupy and any(
isinstance(x, array_type_cupy) for x in scalars_or_arrays
if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
extension_array_types = [
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
]
if len(extension_array_types) == len(scalars_or_arrays) and all(
isinstance(x, type(extension_array_types[0])) for x in extension_array_types
):
return scalars_or_arrays
raise ValueError(
f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}"
)
elif array_type_cupy := array_type("cupy") and any( # noqa: F841
isinstance(x, array_type_cupy) for x in scalars_or_arrays # noqa: F821
Review thread on lines +236 to +237:

Reviewer (contributor): Is this kind of syntax allowed? I suspect the CI didn't run this code:

$ pytest xarray/tests/test_array_api.py -x
================================= test session starts =================================
platform linux -- Python 3.10.14, pytest-8.2.0, pluggy-1.5.0
rootdir: /home/mark/git/xarray
configfile: pyproject.toml
collected 13 items

xarray/tests/test_array_api.py .F

====================================== FAILURES =======================================
__________________________________ test_aggregation ___________________________________

arrays = (<xarray.DataArray (x: 2, y: 3)> Size: 48B
array([[ 1.,  2.,  3.],
       [ 4.,  5., nan]])
Coordinates:
  * x        ...      [ 4.,  5., nan]], dtype=float64)
Coordinates:
  * x        (x) int64 16B 10 20
Dimensions without coordinates: y)

    def test_aggregation(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
        np_arr, xp_arr = arrays
>       expected = np_arr.sum()

/home/mark/git/xarray/xarray/tests/test_array_api.py:51:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/home/mark/git/xarray/xarray/core/_aggregations.py:1857: in sum
    return self.reduce(
/home/mark/git/xarray/xarray/core/dataarray.py:3805: in reduce
    var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)
/home/mark/git/xarray/xarray/core/variable.py:1663: in reduce
    result = super().reduce(
/home/mark/git/xarray/xarray/namedarray/core.py:889: in reduce
    data = func(self.data, **kwargs)
/home/mark/git/xarray/xarray/core/duck_array_ops.py:427: in f
    return func(values, axis=axis, **kwargs)
/home/mark/git/xarray/xarray/core/nanops.py:99: in nansum
    result = sum_where(a, axis=axis, dtype=dtype, where=mask)
/home/mark/git/xarray/xarray/core/duck_array_ops.py:333: in sum_where
    a = where_method(xp.zeros_like(data), where, data)
/home/mark/git/xarray/xarray/core/duck_array_ops.py:349: in where_method
    return where(cond, data, other)
/home/mark/git/xarray/xarray/core/duck_array_ops.py:343: in where
    return xp.where(condition, *as_shared_dtype([x, y], xp=xp))
/home/mark/git/xarray/xarray/core/duck_array_ops.py:236: in as_shared_dtype
    elif array_type_cupy := array_type("cupy") and any(  # noqa: F841
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

.0 = <list_iterator object at 0x7e404174b6d0>

    elif array_type_cupy := array_type("cupy") and any(  # noqa: F841
>       isinstance(x, array_type_cupy) for x in scalars_or_arrays  # noqa: F821
    ):
E   NameError: free variable 'array_type_cupy' referenced before assignment in enclosing scope

/home/mark/git/xarray/xarray/core/duck_array_ops.py:237: NameError
=============================== short test summary info ===============================
FAILED xarray/tests/test_array_api.py::test_aggregation - NameError: free variable 'array_type_cupy' referenced before assignment in enclosi...
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
============================= 1 failed, 1 passed in 0.33s =============================

ilan-gold (author): @hmaarrfk Interesting. I see you fixed this. I must have done this when testing, because I do specifically remember testing this.

Reviewer (contributor): No issues. It was a straightforward fix.
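For reference, a sketch of one way to avoid the precedence trap shown in the traceback above; the walrus assignment needs its own parentheses so that array_type_cupy is bound before the generator that references it runs (not necessarily the fix that was actually applied):

    # `a := f() and any(...)` parses as `a := (f() and any(...))`, so the generator
    # closure can look up `array_type_cupy` before the name is ever bound
    elif (array_type_cupy := array_type("cupy")) and any(
        isinstance(x, array_type_cupy) for x in scalars_or_arrays
    ):
        ...  # body unchanged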

):
import cupy as cp

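Rounding out this file, a sketch of how the new extension-array branch of as_shared_dtype behaves (illustrative only; error message abbreviated):

    import numpy as np
    import pandas as pd
    from xarray.core.duck_array_ops import as_shared_dtype

    # all inputs are extension arrays of the same dtype class: returned unchanged
    as_shared_dtype([pd.Categorical(["a"]), pd.Categorical(["b"])])

    # mixing an extension array with anything else raises
    # as_shared_dtype([pd.Categorical(["a"]), np.array([1.0])])  # ValueError: Cannot cast arrays to shared type, ...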
136 changes: 136 additions & 0 deletions xarray/core/extension_array.py
@@ -0,0 +1,136 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Callable, Generic

import numpy as np
import pandas as pd
from pandas.api.types import is_extension_array_dtype

from xarray.core.types import DTypeLikeSave, T_ExtensionArray

HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}


def implements(numpy_function):
"""Register an __array_function__ implementation for MyArray objects."""

def decorator(func):
HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func
return func

return decorator


@implements(np.issubdtype)
def __extension_duck_array__issubdtype(
extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave
) -> bool:
return False # never want a function to think a pandas extension dtype is a subtype of numpy


@implements(np.broadcast_to)
def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple):
if shape[0] == len(arr) and len(shape) == 1:
return arr
raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.")


@implements(np.stack)
def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int):
raise NotImplementedError("Cannot stack 1d-only pandas categorical array.")


@implements(np.concatenate)
def __extension_duck_array__concatenate(
arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None
) -> T_ExtensionArray:
return type(arrays[0])._concat_same_type(arrays)


@implements(np.where)
def __extension_duck_array__where(
condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray
) -> T_ExtensionArray:
if (
isinstance(x, pd.Categorical)
and isinstance(y, pd.Categorical)
and x.dtype != y.dtype
):
x = x.add_categories(set(y.categories).difference(set(x.categories)))
y = y.add_categories(set(x.categories).difference(set(y.categories)))
return pd.Series(x).where(condition, pd.Series(y)).array


class PandasExtensionArray(Generic[T_ExtensionArray]):
array: T_ExtensionArray

def __init__(self, array: T_ExtensionArray):
"""NEP-18 compliant wrapper for pandas extension arrays.

Parameters
----------
array : T_ExtensionArray
The array to be wrapped upon, e.g., :py:class:`xarray.Variable` creation.
"""
if not isinstance(array, pd.api.extensions.ExtensionArray):
raise TypeError(f"{array} is not an pandas ExtensionArray.")
self.array = array

def __array_function__(self, func, types, args, kwargs):
def replace_duck_with_extension_array(args) -> list:
args_as_list = list(args)
for index, value in enumerate(args_as_list):
if isinstance(value, PandasExtensionArray):
args_as_list[index] = value.array
elif isinstance(
value, tuple
): # should handle more than just tuple? iterable?
args_as_list[index] = tuple(
replace_duck_with_extension_array(value)
)
elif isinstance(value, list):
args_as_list[index] = replace_duck_with_extension_array(value)
return args_as_list

args = tuple(replace_duck_with_extension_array(args))
if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS:
return func(*args, **kwargs)
res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs)
if is_extension_array_dtype(res):
return type(self)[type(res)](res)
return res

def __array_ufunc__(ufunc, method, *inputs, **kwargs):
return ufunc(*inputs, **kwargs)

def __repr__(self):
return f"{type(self)}(array={repr(self.array)})"

def __getattr__(self, attr: str) -> object:
return getattr(self.array, attr)

def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]:
item = self.array[key]
if is_extension_array_dtype(item):
return type(self)(item)
if np.isscalar(item):
return type(self)(type(self.array)([item]))
return item

def __setitem__(self, key, val):
self.array[key] = val

def __eq__(self, other):
if np.isscalar(other):
other = type(self)(type(self.array)([other]))
if isinstance(other, PandasExtensionArray):
return self.array == other.array
return self.array == other

def __ne__(self, other):
return ~(self == other)

def __len__(self):
return len(self.array)
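A rough sketch of how the wrapper's NEP-18 dispatch is meant to be exercised (illustrative only, not part of the diff):

    import numpy as np
    import pandas as pd
    from xarray.core.extension_array import PandasExtensionArray

    wrapped = PandasExtensionArray(pd.Categorical(["a", "b", "a"]))

    wrapped.dtype                         # attribute access falls through to the wrapped Categorical
    np.concatenate([wrapped, wrapped])    # handled: dispatches to Categorical._concat_same_type
    np.broadcast_to(wrapped, (3,))        # handled: a 1-D "broadcast" to the same length is a no-op
    np.stack([wrapped, wrapped], axis=0)  # raises NotImplementedError, since the arrays are 1-D only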
3 changes: 3 additions & 0 deletions xarray/core/types.py
@@ -167,6 +167,9 @@ def copy(
# hopefully in the future we can narrow this down more:
T_DuckArray = TypeVar("T_DuckArray", bound=Any, covariant=True)

# For typing pandas extension arrays.
T_ExtensionArray = TypeVar("T_ExtensionArray", bound=pd.api.extensions.ExtensionArray)


ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"]
VarCompatible = Union["Variable", "ScalarOrArray"]
10 changes: 10 additions & 0 deletions xarray/core/variable.py
@@ -13,11 +13,13 @@
import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
from pandas.api.types import is_extension_array_dtype

import xarray as xr # only for Dataset and DataArray
from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils
from xarray.core.arithmetic import VariableArithmetic
from xarray.core.common import AbstractArray
from xarray.core.extension_array import PandasExtensionArray
from xarray.core.indexing import (
BasicIndexer,
OuterIndexer,
@@ -47,6 +49,7 @@
NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
indexing.ExplicitlyIndexed,
pd.Index,
pd.api.extensions.ExtensionArray,
)
# https://github.com/python/mypy/issues/224
BASIC_INDEXING_TYPES = integer_types + (slice,)
@@ -184,6 +187,8 @@ def _maybe_wrap_data(data):
"""
if isinstance(data, pd.Index):
return PandasIndexingAdapter(data)
if isinstance(data, pd.api.extensions.ExtensionArray):
return PandasExtensionArray[type(data)](data)
return data


@@ -2570,6 +2575,11 @@ def chunk( # type: ignore[override]
dask.array.from_array
"""

if is_extension_array_dtype(self):
raise ValueError(
f"{self} was found to be a Pandas ExtensionArray. Please convert to numpy first."
)

if from_array_kwargs is None:
from_array_kwargs = {}

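Illustrative behavior of the new guard in chunk above (a sketch, not from the diff; the chunked call requires dask to be installed):

    import numpy as np
    import pandas as pd
    import xarray as xr

    ds = xr.Dataset({"cat": ("x", pd.Categorical(["a", "b", "a"]))})
    # ds.chunk()  # raises ValueError, asking for a conversion to numpy first

    ds_np = xr.Dataset({"cat": ("x", np.asarray(pd.Categorical(["a", "b", "a"])))})
    ds_np.chunk({"x": 1})  # fine once the data is a plain numpy array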
10 changes: 5 additions & 5 deletions xarray/testing/strategies.py
@@ -45,7 +45,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
Generates only those numpy dtypes which xarray can handle.

Use instead of hypothesis.extra.numpy.scalar_dtypes in order to exclude weirder dtypes such as unicode, byte_string, array, or nested dtypes.
Also excludes datetimes, which dodges bugs with pandas non-nanosecond datetime overflows.
Also excludes datetimes, which dodges bugs with pandas non-nanosecond datetime overflows. Checks only native endianness.

Requires the hypothesis package to be installed.

@@ -56,10 +56,10 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]:
# TODO should this be exposed publicly?
# We should at least decide what the set of numpy dtypes that xarray officially supports is.
return (
npst.integer_dtypes()
| npst.unsigned_integer_dtypes()
| npst.floating_dtypes()
| npst.complex_number_dtypes()
npst.integer_dtypes(endianness="=")
| npst.unsigned_integer_dtypes(endianness="=")
| npst.floating_dtypes(endianness="=")
| npst.complex_number_dtypes(endianness="=")
# | npst.datetime64_dtypes()
# | npst.timedelta64_dtypes()
# | npst.unicode_string_dtypes()
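A small sketch of what the native-endianness restriction means in practice (assumes hypothesis is installed and that supported_dtypes is importable from xarray.testing.strategies):

    from hypothesis import given
    from xarray.testing.strategies import supported_dtypes

    @given(supported_dtypes())
    def test_only_native_endianness(dtype) -> None:
        # numpy reports native byte order as "=" and single-byte dtypes as "|"
        assert dtype.byteorder in ("=", "|")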