(feat): Support for pandas ExtensionArray #8723
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -130,6 +130,7 @@ module = [ | |
"opt_einsum.*", | ||
"pandas.*", | ||
"pooch.*", | ||
"pyarrow.*", | ||
"pydap.*", | ||
"pytest.*", | ||
"scipy.*", | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,7 @@ | |
from numpy import concatenate as _concatenate | ||
from numpy.lib.stride_tricks import sliding_window_view # noqa | ||
from packaging.version import Version | ||
from pandas.api.types import is_extension_array_dtype | ||
|
||
from xarray.core import dask_array_ops, dtypes, nputils | ||
from xarray.core.options import OPTIONS | ||
|
@@ -156,7 +157,7 @@ def isnull(data): | |
return full_like(data, dtype=bool, fill_value=False) | ||
else: | ||
# at this point, array should have dtype=object | ||
if isinstance(data, np.ndarray): | ||
if isinstance(data, np.ndarray) or is_extension_array_dtype(data): | ||
return pandas_isnull(data) | ||
else: | ||
# Not reachable yet, but intended for use with other duck array | ||
|
@@ -221,9 +222,19 @@ def asarray(data, xp=np): | |
|
||
def as_shared_dtype(scalars_or_arrays, xp=np): | ||
"""Cast arrays to a shared dtype using xarray's type promotion rules.""" | ||
array_type_cupy = array_type("cupy") | ||
if array_type_cupy and any( | ||
isinstance(x, array_type_cupy) for x in scalars_or_arrays | ||
if any(is_extension_array_dtype(x) for x in scalars_or_arrays): | ||
extension_array_types = [ | ||
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x) | ||
] | ||
if len(extension_array_types) == len(scalars_or_arrays) and all( | ||
isinstance(x, type(extension_array_types[0])) for x in extension_array_types | ||
): | ||
return scalars_or_arrays | ||
|
||
raise ValueError( | ||
f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}" | ||
) | ||
elif array_type_cupy := array_type("cupy") and any( # noqa: F841 | ||
isinstance(x, array_type_cupy) for x in scalars_or_arrays # noqa: F821 | ||
Comment on lines +236 to +237
Is this kind of syntax allowed? I suspect the CI didn't run this code:
@hmaarrfk Interesting. I see you fixed this. I must have done this when testing because I do specifically remember testing this.
No issues. It was a straightforward fix.
||
): | ||
import cupy as cp | ||
|
||
|
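The question raised on the `elif array_type_cupy := ...` line comes down to operator precedence: without parentheses, the walrus operator binds the result of the whole `and` expression rather than only its first operand. The following is a minimal sketch of that difference, not the code from the PR; `lookup_type` is a hypothetical stand-in for `array_type("cupy")`.

```python
def lookup_type():
    """Hypothetical stand-in for array_type("cupy"); returns a type."""
    return int


def bind_unparenthesized(values):
    # Parsed as: bound := (lookup_type() and any(...))
    # so `bound` ends up holding the boolean result of any(...).
    if bound := lookup_type() and any(isinstance(v, int) for v in values):
        return bound
    return None


def bind_parenthesized(values):
    # The parentheses bind the type first, so it can be reused inside any(...).
    if (bound := lookup_type()) and any(isinstance(v, bound) for v in values):
        return bound
    return None


print(bind_unparenthesized([1, "a"]))
print(bind_parenthesized([1, "a"]))
```

In the unparenthesized form the name is also not yet bound while `any(...)` is being evaluated, which is presumably why the diff line carries `noqa: F821`.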
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
from __future__ import annotations | ||
|
||
from collections.abc import Sequence | ||
from typing import Callable, Generic | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from pandas.api.types import is_extension_array_dtype | ||
|
||
from xarray.core.types import DTypeLikeSave, T_ExtensionArray | ||
|
||
HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {} | ||
|
||
|
||
def implements(numpy_function): | ||
"""Register an __array_function__ implementation for pandas extension arrays.""" | ||
|
||
def decorator(func): | ||
HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func | ||
return func | ||
|
||
return decorator | ||
|
||
|
||
@implements(np.issubdtype) | ||
def __extension_duck_array__issubdtype( | ||
extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave | ||
) -> bool: | ||
return False # never want a function to think a pandas extension dtype is a subtype of numpy | ||
|
||
|
||
@implements(np.broadcast_to) | ||
def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): | ||
if shape[0] == len(arr) and len(shape) == 1: | ||
return arr | ||
raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.") | ||
|
||
|
||
@implements(np.stack) | ||
def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): | ||
raise NotImplementedError("Cannot stack 1d-only pandas categorical array.") | ||
|
||
|
||
@implements(np.concatenate) | ||
def __extension_duck_array__concatenate( | ||
arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None | ||
) -> T_ExtensionArray: | ||
return type(arrays[0])._concat_same_type(arrays) | ||
|
||
|
||
@implements(np.where) | ||
def __extension_duck_array__where( | ||
condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray | ||
) -> T_ExtensionArray: | ||
if ( | ||
isinstance(x, pd.Categorical) | ||
and isinstance(y, pd.Categorical) | ||
and x.dtype != y.dtype | ||
): | ||
x = x.add_categories(set(y.categories).difference(set(x.categories))) | ||
y = y.add_categories(set(x.categories).difference(set(y.categories))) | ||
return pd.Series(x).where(condition, pd.Series(y)).array | ||
|
||
|
||
class PandasExtensionArray(Generic[T_ExtensionArray]): | ||
array: T_ExtensionArray | ||
|
||
def __init__(self, array: T_ExtensionArray): | ||
"""NEP-18 compliant wrapper for pandas extension arrays. | ||
|
||
Parameters | ||
---------- | ||
array : T_ExtensionArray | ||
The array to be wrapped upon, e.g., :py:class:`xarray.Variable` creation. | ||
""" | ||
if not isinstance(array, pd.api.extensions.ExtensionArray): | ||
raise TypeError(f"{array} is not a pandas ExtensionArray.") | ||
self.array = array | ||
|
||
def __array_function__(self, func, types, args, kwargs): | ||
def replace_duck_with_extension_array(args) -> list: | ||
args_as_list = list(args) | ||
for index, value in enumerate(args_as_list): | ||
if isinstance(value, PandasExtensionArray): | ||
args_as_list[index] = value.array | ||
elif isinstance( | ||
value, tuple | ||
): # should handle more than just tuple? iterable? | ||
args_as_list[index] = tuple( | ||
replace_duck_with_extension_array(value) | ||
) | ||
elif isinstance(value, list): | ||
args_as_list[index] = replace_duck_with_extension_array(value) | ||
return args_as_list | ||
|
||
args = tuple(replace_duck_with_extension_array(args)) | ||
if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS: | ||
return func(*args, **kwargs) | ||
res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) | ||
if is_extension_array_dtype(res): | ||
return type(self)[type(res)](res) | ||
return res | ||
|
||
def __array_ufunc__(ufunc, method, *inputs, **kwargs): | ||
return ufunc(*inputs, **kwargs) | ||
|
||
def __repr__(self): | ||
return f"{type(self)}(array={repr(self.array)})" | ||
|
||
def __getattr__(self, attr: str) -> object: | ||
return getattr(self.array, attr) | ||
|
||
def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: | ||
item = self.array[key] | ||
if is_extension_array_dtype(item): | ||
return type(self)(item) | ||
if np.isscalar(item): | ||
return type(self)(type(self.array)([item])) | ||
return item | ||
|
||
def __setitem__(self, key, val): | ||
self.array[key] = val | ||
|
||
def __eq__(self, other): | ||
if np.isscalar(other): | ||
other = type(self)(type(self.array)([other])) | ||
if isinstance(other, PandasExtensionArray): | ||
return self.array == other.array | ||
return self.array == other | ||
|
||
|
||
def __ne__(self, other): | ||
return ~(self == other) | ||
|
||
def __len__(self): | ||
return len(self.array) |
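The new file follows the NEP-18 pattern: a registry maps NumPy functions to replacements, and the wrapper's `__array_function__` unwraps its arguments before dispatching. The sketch below is a stripped-down illustration of that pattern under some assumptions. `HANDLED`, `Wrapper`, and `_unwrap` are hypothetical names, and unlike the diff above it returns `NotImplemented` for unregistered functions instead of calling them directly.

```python
import numpy as np
import pandas as pd

# Hypothetical registry name; maps NumPy functions to replacements.
HANDLED = {}


def implements(numpy_function):
    """Register the decorated function as the handler for numpy_function."""

    def decorator(func):
        HANDLED[numpy_function] = func
        return func

    return decorator


@implements(np.concatenate)
def _concatenate(arrays, axis=0, out=None):
    # _concat_same_type is the pandas ExtensionArray hook for joining
    # arrays of a single extension type.
    return type(arrays[0])._concat_same_type(arrays)


def _unwrap(value):
    """Recursively replace Wrapper instances with the arrays they hold."""
    if isinstance(value, Wrapper):
        return value.array
    if isinstance(value, (list, tuple)):
        return type(value)(_unwrap(v) for v in value)
    return value


class Wrapper:
    """Hypothetical, simplified stand-in for the wrapper class in the diff."""

    def __init__(self, array):
        self.array = array

    def __array_function__(self, func, types, args, kwargs):
        if func not in HANDLED:
            return NotImplemented
        result = HANDLED[func](*_unwrap(args), **kwargs)
        return Wrapper(result)


left = Wrapper(pd.array([1, 2], dtype="Int64"))
right = Wrapper(pd.array([3, None], dtype="Int64"))

# NumPy sees __array_function__ on the inputs and hands control to Wrapper.
combined = np.concatenate([left, right])
print(combined.array)
```

Calling `np.concatenate` on the wrappers therefore never converts the nullable integers to a NumPy array; the result stays a pandas extension array inside a new wrapper.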
Calling `.join()` in a loop will make this method take quadratic time. Can you rewrite this to join all the extension arrays together once, e.g., with `pd.concat`?

pandas-dev/pandas#57676 Not sure what to do. I don't think `concat` is meant for this? In any case very open to other ideas!

Also not sure `join` with a list is faster now that I think of it. I couldn't figure out how to do `concat` though...maybe I should make the index on the `extension_array_df` the correct multi-index but this seems tricky?

It'd be good to sort this out.

@shoyer Could you maybe give some details on using `concat` here? I think we truly do want a join, no?

Let's open an issue to remind ourselves to make this more efficient.
I guess the core problem is that extension arrays cannot be broadcast to nD with `.set_dims`? Maybe we could raise an error if `len(ordered_dims) > 1`?

#8950 done!

I think this is true.
I think this currently handles the case where this is >1 so why error out? I think `join` is acceptable here IMO
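The three options discussed in this thread can be compared side by side. This is a minimal sketch that assumes every frame shares the same simple index, which may not match the PR's actual code (the thread mentions a multi-index); all frame and column names here are hypothetical, and no timing claims are made.

```python
import pandas as pd

index = pd.Index([0, 1, 2], name="x")
base = pd.DataFrame({"base": [10, 20, 30]}, index=index)

# Hypothetical one-column frames holding nullable-integer extension arrays.
frames = [
    pd.DataFrame({f"col{i}": pd.array([i, i + 1, None], dtype="Int64")}, index=index)
    for i in range(3)
]


def join_in_loop(base, frames):
    # One join per frame: each iteration builds a new, wider DataFrame.
    result = base
    for frame in frames:
        result = result.join(frame)
    return result


def join_once(base, frames):
    # DataFrame.join also accepts a list of frames and joins them on the index.
    return base.join(frames)


def concat_once(base, frames):
    # Column-wise concatenation; aligns on the index in a single call.
    return pd.concat([base, *frames], axis=1)


looped = join_in_loop(base, frames)
joined = join_once(base, frames)
concatenated = concat_once(base, frames)
print(list(concatenated.columns))
```

When the indexes differ, `join` and `concat` no longer behave the same way by default (left join versus outer alignment), which is the kind of difference the thread is weighing.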