From b2712f1c0d59fd36dcbfbd9326a6e1e7c8d1c3ae Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 2 Feb 2024 14:40:24 +0100 Subject: [PATCH 01/70] (feat): first pass supporting extension arrays --- xarray/coding/strings.py | 8 ++-- xarray/coding/test.py | 35 ++++++++++++++++ xarray/coding/times.py | 12 ++++-- xarray/coding/variables.py | 20 +++++---- xarray/conventions.py | 12 ++++-- xarray/core/dataset.py | 2 + xarray/core/duck_array_ops.py | 78 ++++++++++++++++++++++++++++++++++- xarray/core/variable.py | 4 ++ 8 files changed, 150 insertions(+), 21 deletions(-) create mode 100644 xarray/coding/test.py diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index fa57bffb8d5..c5654556ad4 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -4,6 +4,7 @@ from functools import partial import numpy as np +from pandas.api.types import is_extension_array_dtype from xarray.coding.variables import ( VariableCoder, @@ -26,11 +27,10 @@ def create_vlen_dtype(element_type): def check_vlen_dtype(dtype): - if dtype.kind != "O" or dtype.metadata is None: + if is_extension_array_dtype(dtype) or dtype.kind != "O" or dtype.metadata is None: return None - else: - # check xarray (element_type) as well as h5py (vlen) - return dtype.metadata.get("element_type", dtype.metadata.get("vlen")) + # check xarray (element_type) as well as h5py (vlen) + return dtype.metadata.get("element_type", dtype.metadata.get("vlen")) def is_unicode_dtype(dtype): diff --git a/xarray/coding/test.py b/xarray/coding/test.py new file mode 100644 index 00000000000..1243cc75d04 --- /dev/null +++ b/xarray/coding/test.py @@ -0,0 +1,35 @@ +import netCDF4 as nc +import numpy as np + +import xarray as xr + +xr.get_options() + +ds = nc.Dataset("mre.nc", "w", format="NETCDF4") +cloud_type_enum = ds.createEnumType(int, "cloud_type", {"clear": 0, "cloudy": 1}) # +ds.createDimension("time", size=(10)) +x = np.arange(10) +ds.createVariable("x", np.int32, dimensions=("time",)) +ds.variables["x"][:] = x +# {'cloud_type': : name = 'cloud_type', numpy dtype = int64, fields/values ={'clear': 0, 'cloudy': 1}} +ds.createVariable("cloud", cloud_type_enum, dimensions=("time",)) +ds["cloud"][:] = [1, 0, 1, 0, 1, 0, 1, 0, 0, 1] +ds.close() + +# -- Open dataset with xarray +xr_ds = xr.open_dataset("./mre.nc") +xr_ds.to_netcdf("./mre_new.nc") +xr_ds = xr.open_dataset("./mre_new.nc") +xr_ds +ds_re_read = nc.Dataset("./mre_new.nc", "r", format="NETCDF4") +ds_re_read + +# import numpy as np +# import xarray as xr + +# codes = np.array([0, 1, 2, 1, 0]) +# categories = {0: 'foo', 1: 'jazz', 2: 'bar'} +# cat_arr = xr.coding.variables.CategoricalArray(codes=codes, categories=categories) +# v = xr.Variable(("time,"), cat_arr, fastpath=True) +# ds = xr.Dataset({'cloud': v}) +# ds.to_zarr('test.zarr') diff --git a/xarray/coding/times.py b/xarray/coding/times.py index f54966dc39a..0337e463c3a 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd +from pandas.api.types import is_extension_array_dtype from pandas.errors import OutOfBoundsDatetime, OutOfBoundsTimedelta from xarray.coding.variables import ( @@ -967,9 +968,10 @@ def __init__(self, use_cftime: bool | None = None) -> None: self.use_cftime = use_cftime def encode(self, variable: Variable, name: T_Name = None) -> Variable: - if np.issubdtype( - variable.data.dtype, np.datetime64 - ) or contains_cftime_datetimes(variable): + if (not is_extension_array_dtype(variable.data)) and ( + np.issubdtype(variable.data.dtype, 
np.datetime64) + or contains_cftime_datetimes(variable) + ): dims, data, attrs, encoding = unpack_for_encoding(variable) units = encoding.pop("units", None) @@ -1007,7 +1009,9 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: class CFTimedeltaCoder(VariableCoder): def encode(self, variable: Variable, name: T_Name = None) -> Variable: - if np.issubdtype(variable.data.dtype, np.timedelta64): + if (not is_extension_array_dtype(variable.data)) and np.issubdtype( + variable.data.dtype, np.timedelta64 + ): dims, data, attrs, encoding = unpack_for_encoding(variable) data, units = encode_cf_timedelta( diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index c3d57ad1903..1d27084b97f 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd +from pandas.api.types import is_extension_array_dtype from xarray.core import dtypes, duck_array_ops, indexing from xarray.core.parallelcompat import get_chunked_array_type @@ -249,7 +250,6 @@ class CFMaskCoder(VariableCoder): def encode(self, variable: Variable, name: T_Name = None): dims, data, attrs, encoding = unpack_for_encoding(variable) - dtype = np.dtype(encoding.get("dtype", data.dtype)) fv = encoding.get("_FillValue") mv = encoding.get("missing_value") @@ -267,6 +267,8 @@ def encode(self, variable: Variable, name: T_Name = None): # special case DateTime to properly handle NaT is_time_like = _is_time_like(attrs.get("units")) + dtype = np.dtype(encoding.get("dtype", data.dtype)) + if fv_exists: # Ensure _FillValue is cast to same dtype as data's encoding["_FillValue"] = dtype.type(fv) @@ -471,16 +473,18 @@ class DefaultFillvalueCoder(VariableCoder): def encode(self, variable: Variable, name: T_Name = None) -> Variable: dims, data, attrs, encoding = unpack_for_encoding(variable) + has_no_fill = "_FillValue" not in attrs and "_FillValue" not in encoding # make NaN the fill value for float types - if ( - "_FillValue" not in attrs - and "_FillValue" not in encoding - and np.issubdtype(variable.dtype, np.floating) - ): + if is_extension_array_dtype(data): + if not has_no_fill: + raise ValueError( + "Found _FillValue encoding or attr on extension array." 
+ ) + return variable + if has_no_fill and np.issubdtype(variable.dtype, np.floating): attrs["_FillValue"] = variable.dtype.type(np.nan) return Variable(dims, data, attrs, encoding, fastpath=True) - else: - return variable + return variable def decode(self, variable: Variable, name: T_Name = None) -> Variable: raise NotImplementedError() diff --git a/xarray/conventions.py b/xarray/conventions.py index 1d8e81e1bf2..46c7270ad9d 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd +from pandas.api.types import is_extension_array_dtype from xarray.coding import strings, times, variables from xarray.coding.variables import SerializationWarning, pop_to @@ -110,7 +111,10 @@ def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable: dims, data, attrs, encoding = variables.unpack_for_encoding(var) # leave vlen dtypes unchanged - if strings.check_vlen_dtype(data.dtype) is not None: + if ( + is_extension_array_dtype(data) + or strings.check_vlen_dtype(data.dtype) is not None + ): return var if is_duck_dask_array(data): @@ -352,9 +356,9 @@ def _update_bounds_encoding(variables: T_Variables) -> None: attrs = v.attrs encoding = v.encoding has_date_units = "units" in encoding and "since" in encoding["units"] - is_datetime_type = np.issubdtype( - v.dtype, np.datetime64 - ) or contains_cftime_datetimes(v) + is_datetime_type = (not is_extension_array_dtype(v)) and ( + contains_cftime_datetimes(v) or np.issubdtype(v.dtype, np.datetime64) + ) if ( is_datetime_type diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 7fc689eef7b..c7ae4935e34 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -24,6 +24,7 @@ from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload import numpy as np +from pandas.api.types import is_extension_array_dtype # remove once numpy 2.0 is the oldest supported version try: @@ -6843,6 +6844,7 @@ def reduce( # that don't have the reduce dims: PR5393 not reduce_dims or not numeric_only + or not is_extension_array_dtype(var.dtype) or np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_) ): diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index b30ba4c3a78..c9b92095cc4 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -11,6 +11,7 @@ import warnings from functools import partial from importlib import import_module +from pandas.api.types import is_extension_array_dtype import numpy as np import pandas as pd @@ -52,6 +53,71 @@ dask_available = module_available("dask") + +HANDLED_FUNCTIONS = {} + + +class ExtensionDuckArray: + def __init__(self, array: pd.api.extensions.ExtensionArray | ExtensionDuckArray): + if isinstance(array, ExtensionDuckArray): + self.extension_array = array.extension_array + elif is_extension_array_dtype(array): + self.extension_array = array + else: + raise TypeError(f"{array} is not an pandas ExtensionArray.") + + def __array_function__(self, func, types, args, kwargs): + # if not all(issubclass(t, ExtensionDuckArray) for t in types): + # return NotImplemented + args_as_list = list(args) + for index, value in enumerate(args_as_list): + if isinstance(value, ExtensionDuckArray): + args_as_list[index] = value.extension_array + if isinstance(value, list) and index == 0: + for sub_index, sub_value in enumerate(args_as_list[index]): + if isinstance(sub_value, ExtensionDuckArray): + args_as_list[index][sub_index] = sub_value.extension_array + args_as_list[0] = 
tuple(args_as_list[0]) + if func not in HANDLED_FUNCTIONS: + return func(*tuple(args_as_list), **kwargs) + return HANDLED_FUNCTIONS[func](*tuple(args_as_list), **kwargs) + + def __array_ufunc__(ufunc, method, *inputs, **kwargs): + return ufunc(*inputs, **kwargs) + + def __repr__(self): + return f"ExtensionDuckArray(array={repr(self.extension_array)})" + + def __getattr__(self, attr: str) -> object: + if hasattr(self.extension_array, attr): + return getattr(self.extension_array, attr) + raise AttributeError(f"{attr} not found.") + + def __getitem__(self, key): + return self.extension_array[key] + + def __setitem__(self, key): + return self.extension_array[key] + + def __eq__(self, other): + if isinstance(other, ExtensionDuckArray): + return self.extension_array == other.extension_array + return self.extension_array == other + +def implements(numpy_function): + """Register an __array_function__ implementation for MyArray objects.""" + + def decorator(func): + HANDLED_FUNCTIONS[numpy_function] = func + return func + + return decorator + + +@implements(np.concatenate) +def concatenate(arrays, axis=0, out=None): + return ExtensionDuckArray(type(arrays[0])._concat_same_type(arrays)) + def get_array_namespace(x): if hasattr(x, "__array_namespace__"): return x.__array_namespace__() @@ -154,7 +220,7 @@ def isnull(data): return full_like(data, dtype=bool, fill_value=False) else: # at this point, array should have dtype=object - if isinstance(data, np.ndarray): + if isinstance(data, np.ndarray) or isinstance(data, ExtensionDuckArray): return pandas_isnull(data) else: # Not reachable yet, but intended for use with other duck array @@ -226,6 +292,16 @@ def as_shared_dtype(scalars_or_arrays, xp=np): import cupy as cp arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] + are_extension_arrays = ( + isinstance(x, ExtensionDuckArray) for x in scalars_or_arrays + ) + if all(are_extension_arrays): + first_type = type(scalars_or_arrays[0].extension_array) + if all( + isinstance(x.extension_array, first_type) + for x in scalars_or_arrays + ): + return scalars_or_arrays else: arrays = [asarray(x, xp=xp) for x in scalars_or_arrays] # Pass arrays directly instead of dtypes to result_type so scalars diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 32323bb2e64..e6f670ec045 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -13,6 +13,7 @@ import numpy as np import pandas as pd from numpy.typing import ArrayLike +from pandas.api.types import is_extension_array_dtype import xarray as xr # only for Dataset and DataArray from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils @@ -51,6 +52,7 @@ NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( indexing.ExplicitlyIndexed, pd.Index, + pd.api.extensions.ExtensionArray, ) # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) @@ -171,6 +173,8 @@ def _maybe_wrap_data(data): """ if isinstance(data, pd.Index): return PandasIndexingAdapter(data) + if is_extension_array_dtype(data): + return duck_array_ops.ExtensionDuckArray(data) return data From 47bddd2719f61749496722eac07a49033a13f0de Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 2 Feb 2024 14:41:59 +0100 Subject: [PATCH 02/70] (feat): categorical tests + functionality --- xarray/core/duck_array_ops.py | 51 ++++++++++++++++++++++++++++------- xarray/tests/__init__.py | 16 +++++++++-- xarray/tests/test_concat.py | 28 ++++++++++++++++--- xarray/tests/test_groupby.py | 1 + xarray/tests/test_merge.py | 2 +- 5 files changed, 82 
insertions(+), 16 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index c9b92095cc4..80fefd09d30 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -11,7 +11,6 @@ import warnings from functools import partial from importlib import import_module -from pandas.api.types import is_extension_array_dtype import numpy as np import pandas as pd @@ -32,6 +31,7 @@ from numpy import concatenate as _concatenate from numpy.lib.stride_tricks import sliding_window_view # noqa from packaging.version import Version +from pandas.api.types import is_extension_array_dtype from xarray.core import dask_array_ops, dtypes, nputils, pycompat from xarray.core.options import OPTIONS @@ -53,7 +53,6 @@ dask_available = module_available("dask") - HANDLED_FUNCTIONS = {} @@ -92,18 +91,19 @@ def __getattr__(self, attr: str) -> object: if hasattr(self.extension_array, attr): return getattr(self.extension_array, attr) raise AttributeError(f"{attr} not found.") - + def __getitem__(self, key): return self.extension_array[key] - + def __setitem__(self, key): return self.extension_array[key] - + def __eq__(self, other): if isinstance(other, ExtensionDuckArray): return self.extension_array == other.extension_array return self.extension_array == other + def implements(numpy_function): """Register an __array_function__ implementation for MyArray objects.""" @@ -115,9 +115,43 @@ def decorator(func): @implements(np.concatenate) -def concatenate(arrays, axis=0, out=None): +def _(arrays, axis=0, out=None): return ExtensionDuckArray(type(arrays[0])._concat_same_type(arrays)) + +@implements(np.where) +def _(condition, x, y): + if all(is_extension_array_dtype(array) for array in [x, y]): + # set up new codes array + new_codes = np.where(condition, x.codes, y.codes) + + # Remap shared categories to have same codes + shared_categories = [ + category for category in x.categories if category in set(y.categories) + ] + for shared_category in shared_categories: + new_codes[~condition & (y == shared_category)] = x.codes[ + x == shared_category + ][0] + + # map non-shared y codes to start from the lowest possible number + y_only_categories = [ + category for category in y.categories if category not in shared_categories + ] + new_y_code = len(x.categories) + len(shared_categories) + for y_only_category in y_only_categories: + new_codes[~condition & (y == y_only_category)] = new_y_code + new_y_code += 1 + new_categories = shared_categories + y_only_categories # preserve order + # TODO: think about ordering? 
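[editor's note] For orientation, the categorical branch above aims to reproduce the element-wise selection that plain np.where gives on materialised values, while keeping a categorical result whose categories cover both inputs. The inputs and expected values below mirror the test fixtures added later in this series; plain np.where drops the categorical dtype, which is what the handler is meant to fix:

    import numpy as np
    import pandas as pd

    cond = np.array([True, False, True, False, False])
    x = pd.Categorical(["cat1", "cat2", "cat2", "cat1", "cat2"])
    y = pd.Categorical(["cat2", "cat1", "cat2", "cat3", "cat1"])

    # np.where on the materialised values gives the right elements but an
    # object array; the handler above should return these same values as a
    # pd.Categorical with merged categories.
    values = np.where(cond, np.asarray(x), np.asarray(y))
    print(values)  # ['cat1' 'cat1' 'cat2' 'cat3' 'cat1']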
+ return ExtensionDuckArray( + pd.Categorical.from_codes( + new_codes, categories=new_categories, ordered=False + ) + ) + return np.where(condition, x, y) + + def get_array_namespace(x): if hasattr(x, "__array_namespace__"): return x.__array_namespace__() @@ -297,10 +331,7 @@ def as_shared_dtype(scalars_or_arrays, xp=np): ) if all(are_extension_arrays): first_type = type(scalars_or_arrays[0].extension_array) - if all( - isinstance(x.extension_array, first_type) - for x in scalars_or_arrays - ): + if all(isinstance(x.extension_array, first_type) for x in scalars_or_arrays): return scalars_or_arrays else: arrays = [asarray(x, xp=xp) for x in scalars_or_arrays] diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 207caba48f0..f9320801f48 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -289,6 +289,7 @@ def create_test_data( seed: int | None = None, add_attrs: bool = True, dim_sizes: tuple[int, int, int] = _DEFAULT_TEST_DIM_SIZES, + use_extension_array: bool = False, ) -> Dataset: rs = np.random.RandomState(seed) _vars = { @@ -311,14 +312,25 @@ def create_test_data( obj[v] = (dims, data) if add_attrs: obj[v].attrs = {"foo": "variable"} - + if use_extension_array: + obj["var4"] = ( + "dim1", + pd.Categorical( + np.random.choice( + list(string.ascii_lowercase[: np.random.randint(5)]), + size=dim_sizes[0], + ) + ), + ) if dim_sizes == _DEFAULT_TEST_DIM_SIZES: numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64") else: numbers_values = np.random.randint(0, 3, _dims["dim3"], dtype="int64") obj.coords["numbers"] = ("dim3", numbers_values) obj.encoding = {"foo": "bar"} - assert all(obj.data.flags.writeable for obj in obj.variables.values()) + assert all( + obj.data.flags.writeable for k, obj in obj.variables.items() if k != "var4" + ) return obj diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 0cf4cc03a09..885692cc918 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -152,6 +152,22 @@ def test_concat_missing_var() -> None: assert_identical(actual, expected) +def test_concat_categorical() -> None: + data1 = create_test_data(use_extension_array=True) + data2 = create_test_data(use_extension_array=True) + data2["var4"] + concatenated = concat([data1, data2], dim="dim1") + assert ( + concatenated["var4"] + == type(data2["var4"].variable.data.extension_array)._concat_same_type( + [ + data1["var4"].variable.data.extension_array, + data2["var4"].variable.data.extension_array, + ] + ) + ).all() + + def test_concat_missing_multiple_consecutive_var() -> None: datasets = create_concat_datasets(3, seed=123) expected = concat(datasets, dim="day") @@ -451,8 +467,11 @@ def test_concat_fill_missing_variables( class TestConcatDataset: @pytest.fixture - def data(self) -> Dataset: - return create_test_data().drop_dims("dim3") + def data(self, request) -> Dataset: + use_extension_array = request.param if hasattr(request, "param") else False + return create_test_data(use_extension_array=use_extension_array).drop_dims( + "dim3" + ) def rectify_dim_order(self, data, dataset) -> Dataset: # return a new dataset with all variable dimensions transposed into @@ -464,7 +483,9 @@ def rectify_dim_order(self, data, dataset) -> Dataset: ) @pytest.mark.parametrize("coords", ["different", "minimal"]) - @pytest.mark.parametrize("dim", ["dim1", "dim2"]) + @pytest.mark.parametrize( + "dim,data", [["dim1", True], ["dim2", False]], indirect=["data"] + ) def test_concat_simple(self, data, dim, coords) -> None: datasets = [g 
for _, g in data.groupby(dim, squeeze=False)] assert_identical(data, concat(datasets, dim, coords=coords)) @@ -492,6 +513,7 @@ def test_concat_merge_variables_present_in_some_datasets(self, data) -> None: expected = data.copy().assign(foo=(["dim1", "bar"], foo)) assert_identical(expected, actual) + @pytest.mark.parametrize("data", [False], indirect=["data"]) def test_concat_2(self, data) -> None: dim = "dim2" datasets = [g.squeeze(dim) for _, g in data.groupby(dim, squeeze=False)] diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 25fabd5e2b9..1e39266032d 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -35,6 +35,7 @@ def dataset(): { "foo": (("x", "y", "z"), np.random.randn(3, 4, 2)), "baz": ("x", ["e", "f", "g"]), + "cat": ("x", pd.Categorical(["cat1", "cat2", "cat2"])), }, {"x": ("x", ["a", "b", "c"], {"name": "x"}), "y": [1, 2, 3, 4], "z": [1, 2]}, ) diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index c6597d5abb0..52935e9714e 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -37,7 +37,7 @@ def test_merge_arrays(self): assert_identical(actual, expected) def test_merge_datasets(self): - data = create_test_data(add_attrs=False) + data = create_test_data(add_attrs=False, use_extension_array=True) actual = xr.merge([data[["var1"]], data[["var2"]]]) expected = data[["var1", "var2"]] From dc8b788e0ff080965a4b227ef037873114c17252 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 5 Feb 2024 13:53:29 +0100 Subject: [PATCH 03/70] (feat): use multiple dispatch for unimplemented ops --- pyproject.toml | 1 + xarray/core/duck_array_ops.py | 111 ++++++++++++++++++++-------------- xarray/core/variable.py | 7 ++- xarray/tests/test_concat.py | 1 - 4 files changed, 74 insertions(+), 46 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 26132031da1..13f67f2767e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dev = [ io = ["netCDF4", "h5netcdf", "scipy", 'pydap; python_version<"3.10"', "zarr", "fsspec", "cftime", "pooch"] parallel = ["dask[complete]"] viz = ["matplotlib", "seaborn", "nc-time-axis"] +extension_arrays = ["plum"] [project.urls] Documentation = "https://docs.xarray.dev" diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 80fefd09d30..fcf8e5bf5dc 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -9,8 +9,10 @@ import datetime import inspect import warnings +from collections.abc import Iterable from functools import partial from importlib import import_module +from typing import Generic, TypeVar import numpy as np import pandas as pd @@ -32,6 +34,7 @@ from numpy.lib.stride_tricks import sliding_window_view # noqa from packaging.version import Version from pandas.api.types import is_extension_array_dtype +from plum import dispatch # noqa from xarray.core import dask_array_ops, dtypes, nputils, pycompat from xarray.core.options import OPTIONS @@ -55,9 +58,13 @@ HANDLED_FUNCTIONS = {} +T_ExtensionArray = TypeVar("T_ExtensionArray", bound=pd.api.extensions.ExtensionArray) -class ExtensionDuckArray: - def __init__(self, array: pd.api.extensions.ExtensionArray | ExtensionDuckArray): + +class ExtensionDuckArray(Generic[T_ExtensionArray]): + extension_array: T_ExtensionArray + + def __init__(self, array: T_ExtensionArray | ExtensionDuckArray): if isinstance(array, ExtensionDuckArray): self.extension_array = array.extension_array elif is_extension_array_dtype(array): @@ -68,18 +75,25 @@ def __init__(self, 
array: pd.api.extensions.ExtensionArray | ExtensionDuckArray) def __array_function__(self, func, types, args, kwargs): # if not all(issubclass(t, ExtensionDuckArray) for t in types): # return NotImplemented - args_as_list = list(args) - for index, value in enumerate(args_as_list): - if isinstance(value, ExtensionDuckArray): - args_as_list[index] = value.extension_array - if isinstance(value, list) and index == 0: - for sub_index, sub_value in enumerate(args_as_list[index]): - if isinstance(sub_value, ExtensionDuckArray): - args_as_list[index][sub_index] = sub_value.extension_array - args_as_list[0] = tuple(args_as_list[0]) + def replace_duck_with_extension_array(args) -> list: + args_as_list = list(args) + for index, value in enumerate(args_as_list): + if isinstance(value, ExtensionDuckArray): + args_as_list[index] = value.extension_array + elif isinstance( + value, tuple + ): # should handle more than just tuple? iterable? + args_as_list[index] = tuple( + replace_duck_with_extension_array(value) + ) + elif isinstance(value, list): + args_as_list[index] = replace_duck_with_extension_array(value) + return args_as_list + + args = tuple(replace_duck_with_extension_array(args)) if func not in HANDLED_FUNCTIONS: - return func(*tuple(args_as_list), **kwargs) - return HANDLED_FUNCTIONS[func](*tuple(args_as_list), **kwargs) + return func(*args, **kwargs) + return HANDLED_FUNCTIONS[func](*args, **kwargs) def __array_ufunc__(ufunc, method, *inputs, **kwargs): return ufunc(*inputs, **kwargs) @@ -115,43 +129,52 @@ def decorator(func): @implements(np.concatenate) -def _(arrays, axis=0, out=None): - return ExtensionDuckArray(type(arrays[0])._concat_same_type(arrays)) +@dispatch +def __extension_duck_array__concatenate( + arrays: Iterable[T_ExtensionArray], axis: int = 0, out=None +): + return ExtensionDuckArray[type(arrays[0])]( + type(arrays[0])._concat_same_type(arrays) + ) @implements(np.where) -def _(condition, x, y): - if all(is_extension_array_dtype(array) for array in [x, y]): - # set up new codes array - new_codes = np.where(condition, x.codes, y.codes) - - # Remap shared categories to have same codes - shared_categories = [ - category for category in x.categories if category in set(y.categories) - ] - for shared_category in shared_categories: - new_codes[~condition & (y == shared_category)] = x.codes[ - x == shared_category - ][0] - - # map non-shared y codes to start from the lowest possible number - y_only_categories = [ - category for category in y.categories if category not in shared_categories - ] - new_y_code = len(x.categories) + len(shared_categories) - for y_only_category in y_only_categories: - new_codes[~condition & (y == y_only_category)] = new_y_code - new_y_code += 1 - new_categories = shared_categories + y_only_categories # preserve order - # TODO: think about ordering? 
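[editor's note] The @dispatch decorators introduced in this patch come from plum. As a rough standalone sketch (assuming the plum-dispatch package is installed; describe() is a hypothetical name, not xarray API), plum picks an implementation from the runtime types of the annotated arguments, which is how the generic fallback and the pd.Categorical specialisation above can share one function name:

    import numpy as np
    import pandas as pd
    from plum import dispatch

    @dispatch
    def describe(arr: np.ndarray):
        # generic fallback for plain numpy arrays
        return f"ndarray with dtype {arr.dtype}"

    @dispatch
    def describe(arr: pd.Categorical):
        # specialised implementation selected when a Categorical is passed
        return f"categorical with categories {list(arr.categories)}"

    print(describe(np.arange(3)))                     # e.g. "ndarray with dtype int64"
    print(describe(pd.Categorical(["a", "b", "a"])))  # "categorical with categories ['a', 'b']"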
- return ExtensionDuckArray( - pd.Categorical.from_codes( - new_codes, categories=new_categories, ordered=False - ) - ) +@dispatch +def __extension_duck_array__where(condition: np.ndarray, x, y): return np.where(condition, x, y) +@dispatch +def __extension_duck_array__where( + condition: np.ndarray, x: pd.Categorical, y: pd.Categorical +): + # set up new codes array + new_codes = np.where(condition, x.codes, y.codes) + + # Remap shared categories to have same codes + shared_categories = [ + category for category in x.categories if category in set(y.categories) + ] + for shared_category in shared_categories: + new_codes[~condition & (y == shared_category)] = x.codes[x == shared_category][ + 0 + ] + + # map non-shared y codes to start from the lowest possible number + y_only_categories = [ + category for category in y.categories if category not in shared_categories + ] + new_y_code = len(x.categories) + len(shared_categories) + for y_only_category in y_only_categories: + new_codes[~condition & (y == y_only_category)] = new_y_code + new_y_code += 1 + new_categories = shared_categories + y_only_categories # preserve order + # TODO: think about ordering? + return ExtensionDuckArray( + pd.Categorical.from_codes(new_codes, categories=new_categories, ordered=False) + ) + + def get_array_namespace(x): if hasattr(x, "__array_namespace__"): return x.__array_namespace__() diff --git a/xarray/core/variable.py b/xarray/core/variable.py index e6f670ec045..c775049125d 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -174,7 +174,12 @@ def _maybe_wrap_data(data): if isinstance(data, pd.Index): return PandasIndexingAdapter(data) if is_extension_array_dtype(data): - return duck_array_ops.ExtensionDuckArray(data) + data_type = ( + type(data.extension_array) + if isinstance(data, duck_array_ops.ExtensionDuckArray) + else type(data) + ) + return duck_array_ops.ExtensionDuckArray[data_type](data) return data diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 885692cc918..6ec1ab3b93a 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -155,7 +155,6 @@ def test_concat_missing_var() -> None: def test_concat_categorical() -> None: data1 = create_test_data(use_extension_array=True) data2 = create_test_data(use_extension_array=True) - data2["var4"] concatenated = concat([data1, data2], dim="dim1") assert ( concatenated["var4"] From 75524c8649d7685929b829597240a92650af9f0d Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 5 Feb 2024 14:31:40 +0100 Subject: [PATCH 04/70] (feat): implement (not really) broadcasting --- xarray/core/duck_array_ops.py | 6 ++++++ xarray/tests/test_groupby.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index fcf8e5bf5dc..cb3c0d1b239 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -128,6 +128,12 @@ def decorator(func): return decorator +@implements(np.broadcast_to) +@dispatch +def __extension_duck_array__broadcast(arr: pd.Categorical, shape: tuple): + raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.") + + @implements(np.concatenate) @dispatch def __extension_duck_array__concatenate( diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 1e39266032d..373a2f97126 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -35,7 +35,7 @@ def dataset(): { "foo": (("x", "y", "z"), np.random.randn(3, 4, 2)), "baz": ("x", ["e", "f", 
"g"]), - "cat": ("x", pd.Categorical(["cat1", "cat2", "cat2"])), + "cat": ("y", pd.Categorical(["cat1", "cat2", "cat2", "cat1"])), }, {"x": ("x", ["a", "b", "c"], {"name": "x"}), "y": [1, 2, 3, 4], "z": [1, 2]}, ) From c9ab4525963aa54baf0fb9ea8ac37ac4cb4f6431 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 5 Feb 2024 14:31:52 +0100 Subject: [PATCH 05/70] (chore): add more `groupby` tests --- xarray/tests/test_groupby.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 373a2f97126..8224752b20a 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -714,6 +714,8 @@ def test_groupby_getitem(dataset) -> None: assert_identical(dataset.foo.sel(x="a"), dataset.foo.groupby("x")["a"]) with pytest.warns(UserWarning, match="The `squeeze` kwarg"): assert_identical(dataset.foo.sel(z=1), dataset.foo.groupby("z")[1]) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(dataset.cat.sel(y=1), dataset.cat.groupby("y")[1]) assert_identical(dataset.sel(x=["a"]), dataset.groupby("x", squeeze=False)["a"]) assert_identical(dataset.sel(z=[1]), dataset.groupby("z", squeeze=False)[1]) @@ -723,6 +725,12 @@ def test_groupby_getitem(dataset) -> None: ) assert_identical(dataset.foo.sel(z=[1]), dataset.foo.groupby("z", squeeze=False)[1]) + assert_identical(dataset.cat.sel(y=[1]), dataset.cat.groupby("y", squeeze=False)[1]) + with pytest.raises( + NotImplementedError, match="Cannot broadcast 1d-only pandas categorical array." + ): + dataset.groupby("boo", squeeze=False) + dataset = dataset.drop_vars(["cat"]) actual = ( dataset.groupby("boo", squeeze=False)["f"].unstack().transpose("x", "y", "z") ) From 1f3d0fa9ad1354e1c0eff3b518e5ad1ec3a5bfe8 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 5 Feb 2024 15:30:12 +0100 Subject: [PATCH 06/70] (fix): fix more groupby incompatibility --- xarray/tests/test_groupby.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 8224752b20a..5cbd847c790 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -78,6 +78,7 @@ def test_groupby_dims_property(dataset, recwarn) -> None: ) assert len(recwarn) == 0 + dataset = dataset.drop_vars(["cat"]) stacked = dataset.stack({"xy": ("x", "y")}) assert tuple(stacked.groupby("xy", squeeze=False).dims) == tuple( stacked.isel(xy=[0]).dims @@ -90,7 +91,7 @@ def test_groupby_sizes_property(dataset) -> None: assert dataset.groupby("x").sizes == dataset.isel(x=1).sizes with pytest.warns(UserWarning, match="The `squeeze` kwarg"): assert dataset.groupby("y").sizes == dataset.isel(y=1).sizes - + dataset = dataset.drop("cat") stacked = dataset.stack({"xy": ("x", "y")}) with pytest.warns(UserWarning, match="The `squeeze` kwarg"): assert stacked.groupby("xy").sizes == stacked.isel(xy=0).sizes From 8a70e3cda505b44a972441cf80f84d5d35ba8242 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 5 Feb 2024 16:54:03 +0100 Subject: [PATCH 07/70] (bug): fix unused categories --- xarray/core/duck_array_ops.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index cb3c0d1b239..8a42034ed36 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -170,7 +170,10 @@ def __extension_duck_array__where( y_only_categories = [ category for category in y.categories if category not in shared_categories ] - new_y_code = len(x.categories) + 
len(shared_categories) + used_x_only_categories = ( + x[~x.isin(shared_categories)].remove_unused_categories().categories + ) + new_y_code = len(used_x_only_categories) + len(shared_categories) for y_only_category in y_only_categories: new_codes[~condition & (y == y_only_category)] = new_y_code new_y_code += 1 From f5a65050f36605f59fd73c6c88e35f4246aa0654 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 5 Feb 2024 16:54:38 +0100 Subject: [PATCH 08/70] (chore): refactor dispatched methods + tests --- xarray/core/duck_array_ops.py | 30 ++++---- xarray/tests/__init__.py | 1 + xarray/tests/test_extension_array.py | 110 +++++++++++++++++++++++++++ 3 files changed, 128 insertions(+), 13 deletions(-) create mode 100644 xarray/tests/test_extension_array.py diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 8a42034ed36..6d986e27109 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -9,7 +9,7 @@ import datetime import inspect import warnings -from collections.abc import Iterable +from collections.abc import Sequence from functools import partial from importlib import import_module from typing import Generic, TypeVar @@ -93,7 +93,10 @@ def replace_duck_with_extension_array(args) -> list: args = tuple(replace_duck_with_extension_array(args)) if func not in HANDLED_FUNCTIONS: return func(*args, **kwargs) - return HANDLED_FUNCTIONS[func](*args, **kwargs) + res = HANDLED_FUNCTIONS[func](*args, **kwargs) + if is_extension_array_dtype(res): + return ExtensionDuckArray[type(res)](res) + return res def __array_ufunc__(ufunc, method, *inputs, **kwargs): return ufunc(*inputs, **kwargs) @@ -117,6 +120,9 @@ def __eq__(self, other): return self.extension_array == other.extension_array return self.extension_array == other + def __ne__(self, other): + return ~(self == other) + def implements(numpy_function): """Register an __array_function__ implementation for MyArray objects.""" @@ -130,18 +136,18 @@ def decorator(func): @implements(np.broadcast_to) @dispatch -def __extension_duck_array__broadcast(arr: pd.Categorical, shape: tuple): +def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): + if shape[0] == len(arr) and len(shape) == 1: + return arr raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.") @implements(np.concatenate) @dispatch def __extension_duck_array__concatenate( - arrays: Iterable[T_ExtensionArray], axis: int = 0, out=None + arrays: Sequence[pd.Categorical], axis: int = 0, out=None ): - return ExtensionDuckArray[type(arrays[0])]( - type(arrays[0])._concat_same_type(arrays) - ) + return type(arrays[0])._concat_same_type(arrays) @implements(np.where) @@ -150,10 +156,8 @@ def __extension_duck_array__where(condition: np.ndarray, x, y): return np.where(condition, x, y) -@dispatch -def __extension_duck_array__where( - condition: np.ndarray, x: pd.Categorical, y: pd.Categorical -): +@__extension_duck_array__where.dispatch +def _(condition: np.ndarray, x: pd.Categorical, y: pd.Categorical): # set up new codes array new_codes = np.where(condition, x.codes, y.codes) @@ -179,8 +183,8 @@ def __extension_duck_array__where( new_y_code += 1 new_categories = shared_categories + y_only_categories # preserve order # TODO: think about ordering? 
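[editor's note] The __array_function__ hook reworked earlier in this patch is the NumPy override protocol (NEP 18) that lets calls such as np.concatenate and np.where be rerouted to the handlers registered with @implements. A minimal standalone illustration of that protocol follows; the WrappedList class is purely hypothetical and not part of xarray:

    import numpy as np

    class WrappedList:
        """Toy container that intercepts np.concatenate via NEP 18."""

        def __init__(self, values):
            self.values = list(values)

        def __array_function__(self, func, types, args, kwargs):
            if func is np.concatenate:
                merged = []
                for part in args[0]:
                    merged.extend(part.values if isinstance(part, WrappedList) else list(part))
                return WrappedList(merged)
            return NotImplemented  # defer anything we do not handle

    print(np.concatenate([WrappedList([1, 2]), WrappedList([3])]).values)  # [1, 2, 3]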
- return ExtensionDuckArray( - pd.Categorical.from_codes(new_codes, categories=new_categories, ordered=False) + return pd.Categorical.from_codes( + new_codes, categories=new_categories, ordered=False ) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index f9320801f48..045ea5e7599 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -95,6 +95,7 @@ def _importorskip( has_fsspec, requires_fsspec = _importorskip("fsspec") has_iris, requires_iris = _importorskip("iris") has_numbagg, requires_numbagg = _importorskip("numbagg", "0.4.0") +has_plum, requires_plum = _importorskip("plum") with warnings.catch_warnings(): warnings.filterwarnings( "ignore", diff --git a/xarray/tests/test_extension_array.py b/xarray/tests/test_extension_array.py new file mode 100644 index 00000000000..cf7aca2eb37 --- /dev/null +++ b/xarray/tests/test_extension_array.py @@ -0,0 +1,110 @@ +from collections.abc import Sequence + +import numpy as np +import pandas as pd +import pytest + +from xarray.core.duck_array_ops import ( + ExtensionDuckArray, + __extension_duck_array__broadcast, + __extension_duck_array__concatenate, + __extension_duck_array__where, +) +from xarray.tests import requires_plum + + +@pytest.fixture +def categorical1(): + return pd.Categorical(["cat1", "cat2", "cat2", "cat1", "cat2"]) + + +@pytest.fixture +def categorical2(): + return pd.Categorical(["cat2", "cat1", "cat2", "cat3", "cat1"]) + + +@pytest.fixture +def int1(): + return pd.arrays.IntegerArray( + np.array([1, 2, 3, 4, 5]), np.array([True, False, False, True, True]) + ) + + +@pytest.fixture +def int2(): + return pd.arrays.IntegerArray( + np.array([6, 7, 8, 9, 10]), np.array([True, True, False, True, False]) + ) + + +@__extension_duck_array__concatenate.dispatch +def _(arrays: Sequence[pd.arrays.IntegerArray], axis: int = 0, out=None): + values = np.concatenate(arrays) + mask = np.isnan(values) + values = values.astype("int8") + return pd.arrays.IntegerArray(values, mask) + + +@requires_plum +def test_where_all_categoricals(categorical1, categorical2): + assert ( + __extension_duck_array__where( + np.array([True, False, True, False, False]), categorical1, categorical2 + ) + == pd.Categorical(["cat1", "cat1", "cat2", "cat3", "cat1"]) + ).all() + + +@requires_plum +def test_where_drop_categoricals(categorical1, categorical2): + assert ( + __extension_duck_array__where( + np.array([False, True, True, False, True]), categorical1, categorical2 + ).remove_unused_categories() + == pd.Categorical(["cat2", "cat2", "cat2", "cat3", "cat2"]) + ).all() + + +@requires_plum +def test_broadcast_to_categorical(categorical1): + with pytest.raises(NotImplementedError): + __extension_duck_array__broadcast(categorical1, (5, 2)) + + +@requires_plum +def test_broadcast_to_same_categorical(categorical1): + assert (__extension_duck_array__broadcast(categorical1, (5,)) == categorical1).all() + + +@requires_plum +def test_concategorical_categorical(categorical1, categorical2): + assert ( + __extension_duck_array__concatenate([categorical1, categorical2]) + == type(categorical1)._concat_same_type((categorical1, categorical2)) + ).all() + + +@requires_plum +def test_integer_array_register_concatenate(int1, int2): + assert ( + __extension_duck_array__concatenate([int1, int2]) + == type(int1)._concat_same_type((int1, int2)) + ).all() + + +def test_duck_extension_array_equality(categorical1, int1): + int_duck_array = ExtensionDuckArray(int1) + categorical_duck_array = ExtensionDuckArray(categorical1) + assert (int_duck_array != 
categorical_duck_array).all() + assert (categorical_duck_array == categorical1).all() + assert (int1[0:2] == int_duck_array[0:2]).all() + + +def test_duck_extension_array_repr(int1): + int_duck_array = ExtensionDuckArray(int1) + assert repr(int1) in repr(int_duck_array) + + +def test_duck_extension_array_attr(int1): + int_duck_array = ExtensionDuckArray(int1) + assert (~int_duck_array.fillna(10)).all() From 08a4feb6e9140fa4a1b6dd32ca10fc5eceb77cd9 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 6 Feb 2024 10:04:01 +0100 Subject: [PATCH 09/70] (fix): shared type should check for extension arrays first and then fall back to numpy --- xarray/core/duck_array_ops.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 6d986e27109..82d44d3dc8b 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -355,20 +355,23 @@ def asarray(data, xp=np): def as_shared_dtype(scalars_or_arrays, xp=np): """Cast a arrays to a shared dtype using xarray's type promotion rules.""" - array_type_cupy = array_type("cupy") - if array_type_cupy and any( - isinstance(x, array_type_cupy) for x in scalars_or_arrays + if any(isinstance(x, ExtensionDuckArray) for x in scalars_or_arrays): + extension_array_types = [ + type(x.extension_array) + for x in scalars_or_arrays + if isinstance(x, ExtensionDuckArray) + ] + if len(extension_array_types) == len(scalars_or_arrays) and all( + isinstance(x, extension_array_types[0]) for x in extension_array_types + ): + return scalars_or_arrays + arrays = [asarray(np.array(x), xp=xp) for x in scalars_or_arrays] + elif array_type_cupy := array_type("cupy") and any( # noqa: F841 + isinstance(x, array_type_cupy) for x in scalars_or_arrays # noqa: F821 ): import cupy as cp arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] - are_extension_arrays = ( - isinstance(x, ExtensionDuckArray) for x in scalars_or_arrays - ) - if all(are_extension_arrays): - first_type = type(scalars_or_arrays[0].extension_array) - if all(isinstance(x.extension_array, first_type) for x in scalars_or_arrays): - return scalars_or_arrays else: arrays = [asarray(x, xp=xp) for x in scalars_or_arrays] # Pass arrays directly instead of dtypes to result_type so scalars From d5b218bc05f047d29fd11bf3e7efe8ead7a0b752 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 6 Feb 2024 10:07:07 +0100 Subject: [PATCH 10/70] (refactor): tests moved --- xarray/tests/test_duck_array_ops.py | 103 +++++++++++++++++++++++++ xarray/tests/test_extension_array.py | 110 --------------------------- 2 files changed, 103 insertions(+), 110 deletions(-) delete mode 100644 xarray/tests/test_extension_array.py diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 7757ec58edc..c134a0d2372 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -2,6 +2,7 @@ import datetime as dt import warnings +from collections.abc import Sequence import numpy as np import pandas as pd @@ -11,6 +12,10 @@ from xarray import DataArray, Dataset, cftime_range, concat from xarray.core import dtypes, duck_array_ops from xarray.core.duck_array_ops import ( + ExtensionDuckArray, + __extension_duck_array__broadcast, + __extension_duck_array__concatenate, + __extension_duck_array__where, array_notnull_equiv, concatenate, count, @@ -38,11 +43,44 @@ requires_bottleneck, requires_cftime, requires_dask, + requires_plum, ) dask_array_type = array_type("dask") 
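[editor's note] A note on the IntegerArray fixtures added below: in pd.arrays.IntegerArray(values, mask), a True entry in the mask marks a missing value, so both int fixtures deliberately contain pd.NA alongside real integers:

    import numpy as np
    import pandas as pd

    arr = pd.arrays.IntegerArray(
        np.array([1, 2, 3]), np.array([True, False, False])
    )
    print(arr)        # [<NA>, 2, 3]
    print(arr.dtype)  # Int64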
+@pytest.fixture +def categorical1(): + return pd.Categorical(["cat1", "cat2", "cat2", "cat1", "cat2"]) + + +@pytest.fixture +def categorical2(): + return pd.Categorical(["cat2", "cat1", "cat2", "cat3", "cat1"]) + + +@pytest.fixture +def int1(): + return pd.arrays.IntegerArray( + np.array([1, 2, 3, 4, 5]), np.array([True, False, False, True, True]) + ) + + +@pytest.fixture +def int2(): + return pd.arrays.IntegerArray( + np.array([6, 7, 8, 9, 10]), np.array([True, True, False, True, False]) + ) + + +@__extension_duck_array__concatenate.dispatch +def _(arrays: Sequence[pd.arrays.IntegerArray], axis: int = 0, out=None): + values = np.concatenate(arrays) + mask = np.isnan(values) + values = values.astype("int8") + return pd.arrays.IntegerArray(values, mask) + + class TestOps: @pytest.fixture(autouse=True) def setUp(self): @@ -932,3 +970,68 @@ def test_push_dask(): dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), axis=0, n=n ) np.testing.assert_equal(actual, expected) + + +@requires_plum +def test_where_all_categoricals(categorical1, categorical2): + assert ( + __extension_duck_array__where( + np.array([True, False, True, False, False]), categorical1, categorical2 + ) + == pd.Categorical(["cat1", "cat1", "cat2", "cat3", "cat1"]) + ).all() + + +@requires_plum +def test_where_drop_categoricals(categorical1, categorical2): + assert ( + __extension_duck_array__where( + np.array([False, True, True, False, True]), categorical1, categorical2 + ).remove_unused_categories() + == pd.Categorical(["cat2", "cat2", "cat2", "cat3", "cat2"]) + ).all() + + +@requires_plum +def test_broadcast_to_categorical(categorical1): + with pytest.raises(NotImplementedError): + __extension_duck_array__broadcast(categorical1, (5, 2)) + + +@requires_plum +def test_broadcast_to_same_categorical(categorical1): + assert (__extension_duck_array__broadcast(categorical1, (5,)) == categorical1).all() + + +@requires_plum +def test_concategorical_categorical(categorical1, categorical2): + assert ( + __extension_duck_array__concatenate([categorical1, categorical2]) + == type(categorical1)._concat_same_type((categorical1, categorical2)) + ).all() + + +@requires_plum +def test_integer_array_register_concatenate(int1, int2): + assert ( + __extension_duck_array__concatenate([int1, int2]) + == type(int1)._concat_same_type((int1, int2)) + ).all() + + +def test_duck_extension_array_equality(categorical1, int1): + int_duck_array = ExtensionDuckArray(int1) + categorical_duck_array = ExtensionDuckArray(categorical1) + assert (int_duck_array != categorical_duck_array).all() + assert (categorical_duck_array == categorical1).all() + assert (int1[0:2] == int_duck_array[0:2]).all() + + +def test_duck_extension_array_repr(int1): + int_duck_array = ExtensionDuckArray(int1) + assert repr(int1) in repr(int_duck_array) + + +def test_duck_extension_array_attr(int1): + int_duck_array = ExtensionDuckArray(int1) + assert (~int_duck_array.fillna(10)).all() diff --git a/xarray/tests/test_extension_array.py b/xarray/tests/test_extension_array.py deleted file mode 100644 index cf7aca2eb37..00000000000 --- a/xarray/tests/test_extension_array.py +++ /dev/null @@ -1,110 +0,0 @@ -from collections.abc import Sequence - -import numpy as np -import pandas as pd -import pytest - -from xarray.core.duck_array_ops import ( - ExtensionDuckArray, - __extension_duck_array__broadcast, - __extension_duck_array__concatenate, - __extension_duck_array__where, -) -from xarray.tests import requires_plum - - -@pytest.fixture -def categorical1(): - return 
pd.Categorical(["cat1", "cat2", "cat2", "cat1", "cat2"]) - - -@pytest.fixture -def categorical2(): - return pd.Categorical(["cat2", "cat1", "cat2", "cat3", "cat1"]) - - -@pytest.fixture -def int1(): - return pd.arrays.IntegerArray( - np.array([1, 2, 3, 4, 5]), np.array([True, False, False, True, True]) - ) - - -@pytest.fixture -def int2(): - return pd.arrays.IntegerArray( - np.array([6, 7, 8, 9, 10]), np.array([True, True, False, True, False]) - ) - - -@__extension_duck_array__concatenate.dispatch -def _(arrays: Sequence[pd.arrays.IntegerArray], axis: int = 0, out=None): - values = np.concatenate(arrays) - mask = np.isnan(values) - values = values.astype("int8") - return pd.arrays.IntegerArray(values, mask) - - -@requires_plum -def test_where_all_categoricals(categorical1, categorical2): - assert ( - __extension_duck_array__where( - np.array([True, False, True, False, False]), categorical1, categorical2 - ) - == pd.Categorical(["cat1", "cat1", "cat2", "cat3", "cat1"]) - ).all() - - -@requires_plum -def test_where_drop_categoricals(categorical1, categorical2): - assert ( - __extension_duck_array__where( - np.array([False, True, True, False, True]), categorical1, categorical2 - ).remove_unused_categories() - == pd.Categorical(["cat2", "cat2", "cat2", "cat3", "cat2"]) - ).all() - - -@requires_plum -def test_broadcast_to_categorical(categorical1): - with pytest.raises(NotImplementedError): - __extension_duck_array__broadcast(categorical1, (5, 2)) - - -@requires_plum -def test_broadcast_to_same_categorical(categorical1): - assert (__extension_duck_array__broadcast(categorical1, (5,)) == categorical1).all() - - -@requires_plum -def test_concategorical_categorical(categorical1, categorical2): - assert ( - __extension_duck_array__concatenate([categorical1, categorical2]) - == type(categorical1)._concat_same_type((categorical1, categorical2)) - ).all() - - -@requires_plum -def test_integer_array_register_concatenate(int1, int2): - assert ( - __extension_duck_array__concatenate([int1, int2]) - == type(int1)._concat_same_type((int1, int2)) - ).all() - - -def test_duck_extension_array_equality(categorical1, int1): - int_duck_array = ExtensionDuckArray(int1) - categorical_duck_array = ExtensionDuckArray(categorical1) - assert (int_duck_array != categorical_duck_array).all() - assert (categorical_duck_array == categorical1).all() - assert (int1[0:2] == int_duck_array[0:2]).all() - - -def test_duck_extension_array_repr(int1): - int_duck_array = ExtensionDuckArray(int1) - assert repr(int1) in repr(int_duck_array) - - -def test_duck_extension_array_attr(int1): - int_duck_array = ExtensionDuckArray(int1) - assert (~int_duck_array.fillna(10)).all() From 00256fae40cc0417e0830e1439215d694f320e45 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 6 Feb 2024 10:07:23 +0100 Subject: [PATCH 11/70] (chore): more higher level tests --- xarray/tests/test_duck_array_ops.py | 46 +++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index c134a0d2372..54178a736ac 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -157,6 +157,52 @@ def test_where_type_promotion(self): assert result.dtype == np.float32 assert_array_equal(result, np.array([1, np.nan], dtype=np.float32)) + @requires_plum + def test_where_extension_duck_array(self, categorical1, categorical2): + where_res = where( + np.array([True, False, True, False, False]), + ExtensionDuckArray(categorical1), + 
ExtensionDuckArray(categorical2), + ) + assert isinstance(where_res, ExtensionDuckArray) + assert ( + where_res == pd.Categorical(["cat1", "cat1", "cat2", "cat3", "cat1"]) + ).all() + + def test_where_extension_duck_array_fallback(self, categorical1, categorical2): + where_res = where( + np.array([True, False, True, False, False]), + ExtensionDuckArray(categorical1), + np.array(categorical2), + ) + assert isinstance(where_res, np.ndarray) + assert (where_res == np.array(["cat1", "cat1", "cat2", "cat3", "cat1"])).all() + + @requires_plum + def test_concatenate_extension_duck_array(self, categorical1, categorical2): + concate_res = concatenate( + [ExtensionDuckArray(categorical1), ExtensionDuckArray(categorical2)] + ) + assert isinstance(concate_res, ExtensionDuckArray) + assert ( + concate_res + == type(categorical1)._concat_same_type((categorical1, categorical2)) + ).all() + + def test_concatenate_extension_duck_array_fallback( + self, categorical1, categorical2 + ): + concate_res = concatenate( + [ExtensionDuckArray(categorical1), np.array(categorical2)] + ) + assert isinstance(concate_res, np.ndarray) + assert ( + concate_res + == np.array( + type(categorical1)._concat_same_type((categorical1, categorical2)) + ) + ).all() + def test_stack_type_promotion(self): result = stack([1, "b"]) assert_array_equal(result, np.array([1, "b"], dtype=object)) From b7ddbd6cc1e158789ab538d9fcba93315ecc2f47 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 8 Feb 2024 08:58:40 +0100 Subject: [PATCH 12/70] (feat): to/from dataframe --- xarray/core/dataset.py | 46 +++++++++++++++++++++++++++++------- xarray/tests/test_dataset.py | 30 +++++++++++++++++------ 2 files changed, 61 insertions(+), 15 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c7ae4935e34..b6678875944 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -65,7 +65,7 @@ create_coords_with_default_indexes, ) from xarray.core.daskmanager import DaskManager -from xarray.core.duck_array_ops import datetime_to_numeric +from xarray.core.duck_array_ops import ExtensionDuckArray, datetime_to_numeric from xarray.core.indexes import ( Index, Indexes, @@ -7159,13 +7159,36 @@ def to_pandas(self) -> pd.Series | pd.DataFrame: ) def _to_dataframe(self, ordered_dims: Mapping[Any, int]): - columns = [k for k in self.variables if k not in self.dims] + columns = [ + k + for k in self.variables + if k not in self.dims + and not isinstance(self.variables[k].data, ExtensionDuckArray) + ] + extension_array_columns = [ + k + for k in self.variables + if k not in self.dims + and isinstance(self.variables[k].data, ExtensionDuckArray) + ] data = [ self._variables[k].set_dims(ordered_dims).values.reshape(-1) for k in columns ] index = self.coords.to_index([*ordered_dims]) - return pd.DataFrame(dict(zip(columns, data)), index=index) + broadcasted_df = pd.DataFrame(dict(zip(columns, data)), index=index) + for extension_array_column in extension_array_columns: + extension_array = self.variables[ + extension_array_column + ].data.extension_array + index = self[self.variables[extension_array_column].dims[0]].data + cat_df = pd.DataFrame( + {extension_array_column: extension_array}, + index=self[self.variables[extension_array_column].dims[0]].data, + ) + cat_df.index.name = self.variables[extension_array_column].dims[0] + broadcasted_df = broadcasted_df.join(cat_df) + return broadcasted_df def to_dataframe(self, dim_order: Sequence[Hashable] | None = None) -> pd.DataFrame: """Convert this dataset into a pandas.DataFrame. 
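[editor's note] For context, this is the behaviour the to_dataframe/from_dataframe changes are after, mirrored by the test updates further down: a categorical variable keeps its pandas dtype through the round trip instead of being coerced to object. A sketch assuming a build with this branch (and plum) installed; on releases without the change the dtype comes back as object:

    import numpy as np
    import pandas as pd
    import xarray as xr

    ds = xr.Dataset(
        {"a": ("t", np.arange(4.0)), "cat": ("t", pd.Categorical(["x", "y", "x", "y"]))}
    )
    df = ds.to_dataframe()
    print(df["cat"].dtype)  # category

    roundtripped = xr.Dataset.from_dataframe(df)
    print(roundtripped["cat"].dtype)  # category with this branch applied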
@@ -7311,11 +7334,14 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: "cannot convert a DataFrame with a non-unique MultiIndex into xarray" ) - # Cast to a NumPy array first, in case the Series is a pandas Extension - # array (which doesn't have a valid NumPy dtype) - # TODO: allow users to control how this casting happens, e.g., by - # forwarding arguments to pandas.Series.to_numpy? - arrays = [(k, np.asarray(v)) for k, v in dataframe.items()] + arrays = [ + (k, np.asarray(v)) + for k, v in dataframe.items() + if not is_extension_array_dtype(v) + ] + extension_arrays = [ + (k, v) for k, v in dataframe.items() if is_extension_array_dtype(v) + ] indexes: dict[Hashable, Index] = {} index_vars: dict[Hashable, Variable] = {} @@ -7329,6 +7355,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: xr_idx = PandasIndex(lev, dim) indexes[dim] = xr_idx index_vars.update(xr_idx.create_variables()) + arrays += [(k, np.asarray(v)) for k, v in extension_arrays] + extension_arrays = [] else: index_name = idx.name if idx.name is not None else "index" dims = (index_name,) @@ -7342,6 +7370,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: obj._set_sparse_data_from_dataframe(idx, arrays, dims) else: obj._set_numpy_data_from_dataframe(idx, arrays, dims) + for name, extension_array in extension_arrays: + obj[name] = (dims, extension_array) return obj def to_dask_dataframe( diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index fa9448f2f41..93b97d3b204 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -61,6 +61,7 @@ requires_dask, requires_numexpr, requires_pint, + requires_plum, requires_scipy, requires_sparse, source_ndarray, @@ -4613,14 +4614,17 @@ def test_to_dataarray(self) -> None: expected = expected.rename({"variable": "abc"}).rename("foo") assert_identical(expected, actual) + @requires_plum def test_to_and_from_dataframe(self) -> None: x = np.random.randn(10) y = np.random.randn(10) t = list("abcdefghij") - ds = Dataset({"a": ("t", x), "b": ("t", y), "t": ("t", t)}) + cat = pd.Categorical(["a", "b"] * 5) + ds = Dataset({"a": ("t", x), "b": ("t", y), "t": ("t", t), "cat": ("t", cat)}) expected = pd.DataFrame( np.array([x, y]).T, columns=["a", "b"], index=pd.Index(t, name="t") ) + expected["cat"] = cat actual = ds.to_dataframe() # use the .equals method to check all DataFrame metadata assert expected.equals(actual), (expected, actual) @@ -4631,23 +4635,31 @@ def test_to_and_from_dataframe(self) -> None: # check roundtrip assert_identical(ds, Dataset.from_dataframe(actual)) - + assert isinstance(ds["cat"].variable.data.dtype, pd.CategoricalDtype) # test a case with a MultiIndex w = np.random.randn(2, 3) - ds = Dataset({"w": (("x", "y"), w)}) + cat = pd.Categorical(["a", "a", "c"]) + ds = Dataset({"w": (("x", "y"), w), "cat": ("y", cat)}) ds["y"] = ("y", list("abc")) exp_index = pd.MultiIndex.from_arrays( [[0, 0, 0, 1, 1, 1], ["a", "b", "c", "a", "b", "c"]], names=["x", "y"] ) - expected = pd.DataFrame(w.reshape(-1), columns=["w"], index=exp_index) + expected = pd.DataFrame( + {"w": w.reshape(-1), "cat": pd.Categorical(["a", "a", "c", "a", "a", "c"])}, + index=exp_index, + ) actual = ds.to_dataframe() assert expected.equals(actual) # check roundtrip + # from_dataframe attempts to broadcast across because it doesn't know better, so cat must be converted + ds["cat"] = (("x", "y"), np.stack((ds["cat"].to_numpy(), ds["cat"].to_numpy()))) 
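[editor's note] Pandas extension arrays are strictly one-dimensional, which is why the test below has to materialise the categorical with to_numpy() before building a 2-D variable; a small illustration:

    import numpy as np
    import pandas as pd

    cat = pd.Categorical(["a", "a", "c"])
    print(cat.ndim)  # 1 -- an extension array cannot be reshaped to 2-D

    stacked = np.stack((cat.to_numpy(), cat.to_numpy()))
    print(stacked.shape, stacked.dtype)  # (2, 3) object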
assert_identical(ds.assign_coords(x=[0, 1]), Dataset.from_dataframe(actual)) # Check multiindex reordering new_order = ["x", "y"] + # revert broadcasting fix above for 1d arrays + ds["cat"] = ("y", cat) actual = ds.to_dataframe(dim_order=new_order) assert expected.equals(actual) @@ -4656,7 +4668,11 @@ def test_to_and_from_dataframe(self) -> None: [["a", "a", "b", "b", "c", "c"], [0, 1, 0, 1, 0, 1]], names=["y", "x"] ) expected = pd.DataFrame( - w.transpose().reshape(-1), columns=["w"], index=exp_index + { + "w": w.transpose().reshape(-1), + "cat": pd.Categorical(["a", "a", "a", "a", "c", "c"]), + }, + index=exp_index, ) actual = ds.to_dataframe(dim_order=new_order) assert expected.equals(actual) @@ -4709,7 +4725,7 @@ def test_to_and_from_dataframe(self) -> None: expected = pd.DataFrame([[]], index=idx) assert expected.equals(actual), (expected, actual) - def test_from_dataframe_categorical(self) -> None: + def test_from_dataframe_categorical_index(self) -> None: cat = pd.CategoricalDtype( categories=["foo", "bar", "baz", "qux", "quux", "corge"] ) @@ -4724,7 +4740,7 @@ def test_from_dataframe_categorical(self) -> None: assert len(ds["i1"]) == 2 assert len(ds["i2"]) == 2 - def test_from_dataframe_categorical_string_categories(self) -> None: + def test_from_dataframe_categorical_index_string_categories(self) -> None: cat = pd.CategoricalIndex( pd.Categorical.from_codes( np.array([1, 1, 0, 2]), From a16585119ff4dba74bdf8d7aaf5b87b4dcfae163 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 8 Feb 2024 08:59:27 +0100 Subject: [PATCH 13/70] (chore): check for plum import --- xarray/core/variable.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index c775049125d..8cf90afdc26 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -8,6 +8,7 @@ from collections.abc import Hashable, Mapping, Sequence from datetime import timedelta from functools import partial +from importlib.util import find_spec from typing import TYPE_CHECKING, Any, Callable, NoReturn, cast import numpy as np @@ -173,7 +174,7 @@ def _maybe_wrap_data(data): """ if isinstance(data, pd.Index): return PandasIndexingAdapter(data) - if is_extension_array_dtype(data): + if is_extension_array_dtype(data) and find_spec("plum"): data_type = ( type(data.extension_array) if isinstance(data, duck_array_ops.ExtensionDuckArray) From a826edde6a6ed382216d9a22861172ef96876a2d Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 8 Feb 2024 09:31:40 +0100 Subject: [PATCH 14/70] (fix): `__setitem__`/`__getitem__` --- xarray/core/duck_array_ops.py | 13 ++++++++----- xarray/tests/test_duck_array_ops.py | 17 +++++++++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 82d44d3dc8b..2baf5650eab 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -109,11 +109,14 @@ def __getattr__(self, attr: str) -> object: return getattr(self.extension_array, attr) raise AttributeError(f"{attr} not found.") - def __getitem__(self, key): - return self.extension_array[key] - - def __setitem__(self, key): - return self.extension_array[key] + def __getitem__(self, key) -> ExtensionDuckArray[T_ExtensionArray]: + item = self.extension_array[key] + if is_extension_array_dtype(item): # not a singleton - better way to check? 
+ return ExtensionDuckArray(item) + return item + + def __setitem__(self, key, val): + self.extension_array[key] = val def __eq__(self, other): if isinstance(other, ExtensionDuckArray): diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 54178a736ac..98850026649 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -203,6 +203,23 @@ def test_concatenate_extension_duck_array_fallback( ) ).all() + @requires_plum + def test___getitem__extension_duck_array(self, categorical1): + extension_duck_array = ExtensionDuckArray(categorical1) + assert (extension_duck_array[0:2] == categorical1[0:2]).all() + assert isinstance(extension_duck_array[0:2], ExtensionDuckArray) + assert extension_duck_array[0] == categorical1[0] + mask = [True, False, True, False, True] + assert (extension_duck_array[mask] == categorical1[mask]).all() + + @requires_plum + def test__setitem__extension_duck_array(self, categorical1): + extension_duck_array = ExtensionDuckArray(categorical1) + extension_duck_array[2] = "cat1" # already existing category + assert extension_duck_array[2] == "cat1" + with pytest.raises(TypeError, match="Cannot setitem on a Categorical"): + extension_duck_array[2] = "cat4" # new category + def test_stack_type_promotion(self): result = stack([1, "b"]) assert_array_equal(result, np.array([1, "b"], dtype=object)) From fde19eaa300f0df02ade2ce130ed9d13a37ac413 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 8 Feb 2024 09:33:39 +0100 Subject: [PATCH 15/70] (chore): disallow stacking --- xarray/core/duck_array_ops.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 2baf5650eab..3d979af3fa6 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -145,6 +145,12 @@ def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.") +@implements(np.stack) +@dispatch +def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): + raise NotImplementedError("Cannot stack 1d-only pandas categorical array.") + + @implements(np.concatenate) @dispatch def __extension_duck_array__concatenate( From 4c557074097b2942012b7b7ca022c04f7d236b88 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 8 Feb 2024 09:49:39 +0100 Subject: [PATCH 16/70] (fix): `pyproject.toml` --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 13f67f2767e..d713cb11b67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ [project.optional-dependencies] accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"] -complete = ["xarray[accel,io,parallel,viz,dev]"] +complete = ["xarray[accel,io,parallel,viz,dev,extension_arrays]"] dev = [ "hypothesis", "pre-commit", @@ -45,7 +45,7 @@ dev = [ io = ["netCDF4", "h5netcdf", "scipy", 'pydap; python_version<"3.10"', "zarr", "fsspec", "cftime", "pooch"] parallel = ["dask[complete]"] viz = ["matplotlib", "seaborn", "nc-time-axis"] -extension_arrays = ["plum"] +extension-arrays = ["plum-dispatch"] [project.urls] Documentation = "https://docs.xarray.dev" From 58ba17df163ae7ee787250625b5256da1070bda4 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 8 Feb 2024 10:03:25 +0100 Subject: [PATCH 17/70] (fix): `as_shared_type` fix --- xarray/core/duck_array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 3d979af3fa6..7d48f201c33 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -371,7 +371,7 @@ def as_shared_dtype(scalars_or_arrays, xp=np): if isinstance(x, ExtensionDuckArray) ] if len(extension_array_types) == len(scalars_or_arrays) and all( - isinstance(x, extension_array_types[0]) for x in extension_array_types + x == extension_array_types[0] for x in extension_array_types ): return scalars_or_arrays arrays = [asarray(np.array(x), xp=xp) for x in scalars_or_arrays] From a255310f49855b8a3ff93e7a41d584395d88746b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 8 Feb 2024 10:04:59 +0100 Subject: [PATCH 18/70] (chore): add variable tests --- xarray/tests/test_variable.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index fb083586415..6dc4f25072b 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -43,6 +43,7 @@ requires_dask, requires_pandas_version_two, requires_pint, + requires_plum, requires_sparse, source_ndarray, ) @@ -1557,6 +1558,13 @@ def test_transpose_0d(self): actual = variable.transpose() assert_identical(actual, variable) + @requires_plum + def test_pandas_cateogrical_dtype(self): + data = pd.Categorical(np.arange(10, dtype="int64")) + v = self.cls("x", data) + print(v) # should not error + assert pd.api.types.is_extension_array_dtype(v.dtype) + def test_squeeze(self): v = Variable(["x", "y"], [[1]]) assert_identical(Variable([], 1), v.squeeze()) From 4e78b7eb841119eaca74e71b1e6a95e805352284 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 8 Feb 2024 10:26:31 +0100 Subject: [PATCH 19/70] (fix): dask + categoricals --- xarray/core/variable.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 8cf90afdc26..5a31b473744 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1068,6 +1068,8 @@ def chunk( if chunkmanager.is_chunked_array(data_old): data_chunked = chunkmanager.rechunk(data_old, chunks) else: + if is_extension_array_dtype(data_old): # dask cannot handle pandas types + data_old = np.asarray(data_old) if not isinstance(data_old, indexing.ExplicitlyIndexed): ndata = data_old else: From d9cedf5347458f0143e558292912e59214bb50f0 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 8 Feb 2024 16:15:28 +0100 Subject: [PATCH 20/70] (chore): notes/docs --- doc/whats-new.rst | 4 ++++ xarray/core/duck_array_ops.py | 30 ++++++++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index f5e1efadef5..b6b3b79b212 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -30,6 +30,10 @@ New Features to set default `method` for groupby problems. This only applies to ``flox>=0.9``. By `Deepak Cherian `_. +- Xarray now makes a best attempt not to coerce :py:class:`pandas.api.extensions.ExtensionArray` to a numpy array +by supporting 1D `ExtensionArray` objects internally where possible. Thus, `Dataset`s initialized with a `pd.Catgeorical`, for example, +will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray` then, such as broadcasting. 
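Editorial note: a small, hypothetical illustration of the whats-new entry above (variable names are made up; assumes this branch):

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical data for the note above.
ds = xr.Dataset({"cloud": ("time", pd.Categorical(["clear", "cloudy", "clear"]))})

# The pd.Categorical is retained rather than cast to a numpy object array.
assert isinstance(ds["cloud"].dtype, pd.CategoricalDtype)

# Broadcasting a 1D-only ExtensionArray is not supported and should raise.
try:
    np.broadcast_to(ds["cloud"].variable.data, (2, 3))
except NotImplementedError as err:
    print(err)
```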
+ Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 7d48f201c33..30d5a4808f4 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -65,6 +65,30 @@ class ExtensionDuckArray(Generic[T_ExtensionArray]): extension_array: T_ExtensionArray def __init__(self, array: T_ExtensionArray | ExtensionDuckArray): + """NEP-18 compliant wrapper for pandas extension arrays. + To override behavior of common applicable numpy `__array_function__` methods, + please use :py:module:`plum` to dispatch the corresponding function with the prefix + `__extension_duck_array__` i.e., for :py:func:`numpy.where`. The only function that is + necessary to override at the moment is :py:func:`np.where` since :py:module:`pandas` provides + concatenation, and `broadcast`/`stack` are not supported. + + Parameters + ---------- + array : T_ExtensionArray | ExtensionDuckArray + The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation. + If also an ExtensionDuckArray, it's `extension_array` attribute will be extracted and assigned to `self.extension_array`. + + + Examples + -------- + The following is an example of how you might support some different kinds of arrays. + ```python + from plum import dispatch + @dispatch + def ___extension_duck_array__where__(cond: np.ndarray, x: pd.arrays.IntegerArray, y: pd.arrays.BooleanArray) -> pd.arrays.IntegerArray: + pass + ``` + """ if isinstance(array, ExtensionDuckArray): self.extension_array = array.extension_array elif is_extension_array_dtype(array): @@ -154,14 +178,16 @@ def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): @implements(np.concatenate) @dispatch def __extension_duck_array__concatenate( - arrays: Sequence[pd.Categorical], axis: int = 0, out=None + arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None ): return type(arrays[0])._concat_same_type(arrays) @implements(np.where) @dispatch -def __extension_duck_array__where(condition: np.ndarray, x, y): +def __extension_duck_array__where( + condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray +): return np.where(condition, x, y) From 426664d313a3c3624462b7f0a5fe93bf68b0d5eb Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 8 Feb 2024 16:15:57 +0100 Subject: [PATCH 21/70] (chore): remove old testing file --- xarray/coding/test.py | 35 ----------------------------------- 1 file changed, 35 deletions(-) delete mode 100644 xarray/coding/test.py diff --git a/xarray/coding/test.py b/xarray/coding/test.py deleted file mode 100644 index 1243cc75d04..00000000000 --- a/xarray/coding/test.py +++ /dev/null @@ -1,35 +0,0 @@ -import netCDF4 as nc -import numpy as np - -import xarray as xr - -xr.get_options() - -ds = nc.Dataset("mre.nc", "w", format="NETCDF4") -cloud_type_enum = ds.createEnumType(int, "cloud_type", {"clear": 0, "cloudy": 1}) # -ds.createDimension("time", size=(10)) -x = np.arange(10) -ds.createVariable("x", np.int32, dimensions=("time",)) -ds.variables["x"][:] = x -# {'cloud_type': : name = 'cloud_type', numpy dtype = int64, fields/values ={'clear': 0, 'cloudy': 1}} -ds.createVariable("cloud", cloud_type_enum, dimensions=("time",)) -ds["cloud"][:] = [1, 0, 1, 0, 1, 0, 1, 0, 0, 1] -ds.close() - -# -- Open dataset with xarray -xr_ds = xr.open_dataset("./mre.nc") -xr_ds.to_netcdf("./mre_new.nc") -xr_ds = xr.open_dataset("./mre_new.nc") -xr_ds -ds_re_read = nc.Dataset("./mre_new.nc", "r", format="NETCDF4") -ds_re_read - -# import numpy as np -# import xarray as xr - -# codes = np.array([0, 
1, 2, 1, 0]) -# categories = {0: 'foo', 1: 'jazz', 2: 'bar'} -# cat_arr = xr.coding.variables.CategoricalArray(codes=codes, categories=categories) -# v = xr.Variable(("time,"), cat_arr, fastpath=True) -# ds = xr.Dataset({'cloud': v}) -# ds.to_zarr('test.zarr') From 22ca77d22604783792ce3293f4ea0ea3b4ed498b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 8 Feb 2024 16:39:53 +0100 Subject: [PATCH 22/70] (chore): remove ocmmented out code --- xarray/core/duck_array_ops.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 30d5a4808f4..dac78582e39 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -40,6 +40,7 @@ from xarray.core.options import OPTIONS from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array from xarray.core.pycompat import array_type, is_duck_dask_array +from xarray.core.types import DTypeLikeSave from xarray.core.utils import is_duck_array, module_available # remove once numpy 2.0 is the oldest supported version @@ -97,8 +98,6 @@ def ___extension_duck_array__where__(cond: np.ndarray, x: pd.arrays.IntegerArray raise TypeError(f"{array} is not an pandas ExtensionArray.") def __array_function__(self, func, types, args, kwargs): - # if not all(issubclass(t, ExtensionDuckArray) for t in types): - # return NotImplemented def replace_duck_with_extension_array(args) -> list: args_as_list = list(args) for index, value in enumerate(args_as_list): @@ -161,6 +160,14 @@ def decorator(func): return decorator +@implements(np.issubdtype) +@dispatch +def __extension_duck_array__issubdtype( + extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave +): + return False # never want a function to think a categorical is a subtype of numpy + + @implements(np.broadcast_to) @dispatch def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): From 60f8927c535601b753ce02ac1f10d561444226ee Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 8 Feb 2024 16:51:02 +0100 Subject: [PATCH 23/70] (fix): import plum dispatch --- xarray/core/duck_array_ops.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index dac78582e39..c549a19a5ee 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -34,7 +34,14 @@ from numpy.lib.stride_tricks import sliding_window_view # noqa from packaging.version import Version from pandas.api.types import is_extension_array_dtype -from plum import dispatch # noqa + +try: + from plum import dispatch +except ImportError: + + def dispatch(*args, **kwargs): + pass + from xarray.core import dask_array_ops, dtypes, nputils, pycompat from xarray.core.options import OPTIONS From b6d0b318f02236fc93f8020cef7af6588a7b20de Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 9 Feb 2024 16:16:27 +0100 Subject: [PATCH 24/70] (refactor): use `is_extension_array_dtype` as much as possible --- xarray/core/dataset.py | 7 +++---- xarray/core/duck_array_ops.py | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e8c7fb5e4d9..48b5f972afe 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -65,7 +65,7 @@ create_coords_with_default_indexes, ) from xarray.core.daskmanager import DaskManager -from xarray.core.duck_array_ops import ExtensionDuckArray, datetime_to_numeric +from xarray.core.duck_array_ops import datetime_to_numeric 
from xarray.core.indexes import ( Index, Indexes, @@ -7155,13 +7155,12 @@ def _to_dataframe(self, ordered_dims: Mapping[Any, int]): k for k in self.variables if k not in self.dims - and not isinstance(self.variables[k].data, ExtensionDuckArray) + and not is_extension_array_dtype(self.variables[k].data) ] extension_array_columns = [ k for k in self.variables - if k not in self.dims - and isinstance(self.variables[k].data, ExtensionDuckArray) + if k not in self.dims and is_extension_array_dtype(self.variables[k].data) ] data = [ self._variables[k].set_dims(ordered_dims).values.reshape(-1) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index ebb50c4d947..24bab40c423 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -340,7 +340,7 @@ def isnull(data): return full_like(data, dtype=bool, fill_value=False) else: # at this point, array should have dtype=object - if isinstance(data, np.ndarray) or isinstance(data, ExtensionDuckArray): + if isinstance(data, np.ndarray) or is_extension_array_dtype(data): return pandas_isnull(data) else: # Not reachable yet, but intended for use with other duck array @@ -405,11 +405,11 @@ def asarray(data, xp=np): def as_shared_dtype(scalars_or_arrays, xp=np): """Cast a arrays to a shared dtype using xarray's type promotion rules.""" - if any(isinstance(x, ExtensionDuckArray) for x in scalars_or_arrays): + if any(is_extension_array_dtype(x) for x in scalars_or_arrays): extension_array_types = [ type(x.extension_array) for x in scalars_or_arrays - if isinstance(x, ExtensionDuckArray) + if is_extension_array_dtype(x) ] if len(extension_array_types) == len(scalars_or_arrays) and all( x == extension_array_types[0] for x in extension_array_types From 8238c64081da7e849f0e862ec365e4e8abbc89cc Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Sat, 10 Feb 2024 07:49:08 +0100 Subject: [PATCH 25/70] (refactor): `extension_array`->`array` + move to `indexing` --- xarray/core/dataset.py | 4 +- xarray/core/duck_array_ops.py | 104 ++-------------------------- xarray/core/indexing.py | 93 ++++++++++++++++++++++++- xarray/core/types.py | 3 + xarray/core/variable.py | 7 +- xarray/tests/test_concat.py | 6 +- xarray/tests/test_duck_array_ops.py | 2 +- 7 files changed, 108 insertions(+), 111 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 48b5f972afe..a84d3aa7dc3 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7169,9 +7169,7 @@ def _to_dataframe(self, ordered_dims: Mapping[Any, int]): index = self.coords.to_index([*ordered_dims]) broadcasted_df = pd.DataFrame(dict(zip(columns, data)), index=index) for extension_array_column in extension_array_columns: - extension_array = self.variables[ - extension_array_column - ].data.extension_array + extension_array = self.variables[extension_array_column].data.array index = self[self.variables[extension_array_column].dims[0]].data cat_df = pd.DataFrame( {extension_array_column: extension_array}, diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 24bab40c423..a982fb44c3a 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -13,7 +13,7 @@ from collections.abc import Sequence from functools import partial from importlib import import_module -from typing import Generic, TypeVar +from typing import Callable import numpy as np import pandas as pd @@ -40,7 +40,7 @@ from plum import dispatch except ImportError: - def dispatch(*args, **kwargs): + def dispatch(*args, **kwargs): # type: ignore 
pass @@ -48,7 +48,7 @@ def dispatch(*args, **kwargs): from xarray.core.options import OPTIONS from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array from xarray.core.pycompat import array_type, is_duck_dask_array -from xarray.core.types import DTypeLikeSave +from xarray.core.types import DTypeLikeSave, T_ExtensionArray from xarray.core.utils import is_duck_array, module_available # remove once numpy 2.0 is the oldest supported version @@ -65,104 +65,14 @@ def dispatch(*args, **kwargs): dask_available = module_available("dask") -HANDLED_FUNCTIONS = {} - -T_ExtensionArray = TypeVar("T_ExtensionArray", bound=pd.api.extensions.ExtensionArray) - - -class ExtensionDuckArray(Generic[T_ExtensionArray]): - extension_array: T_ExtensionArray - - def __init__(self, array: T_ExtensionArray | ExtensionDuckArray): - """NEP-18 compliant wrapper for pandas extension arrays. - To override behavior of common applicable numpy `__array_function__` methods, - please use :py:module:`plum` to dispatch the corresponding function with the prefix - `__extension_duck_array__` i.e., for :py:func:`numpy.where`. The only function that is - necessary to override at the moment is :py:func:`np.where` since :py:module:`pandas` provides - concatenation, and `broadcast`/`stack` are not supported. - - Parameters - ---------- - array : T_ExtensionArray | ExtensionDuckArray - The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation. - If also an ExtensionDuckArray, it's `extension_array` attribute will be extracted and assigned to `self.extension_array`. - - - Examples - -------- - The following is an example of how you might support some different kinds of arrays. - ```python - from plum import dispatch - @dispatch - def ___extension_duck_array__where__(cond: np.ndarray, x: pd.arrays.IntegerArray, y: pd.arrays.BooleanArray) -> pd.arrays.IntegerArray: - pass - ``` - """ - if isinstance(array, ExtensionDuckArray): - self.extension_array = array.extension_array - elif is_extension_array_dtype(array): - self.extension_array = array - else: - raise TypeError(f"{array} is not an pandas ExtensionArray.") - - def __array_function__(self, func, types, args, kwargs): - def replace_duck_with_extension_array(args) -> list: - args_as_list = list(args) - for index, value in enumerate(args_as_list): - if isinstance(value, ExtensionDuckArray): - args_as_list[index] = value.extension_array - elif isinstance( - value, tuple - ): # should handle more than just tuple? iterable? - args_as_list[index] = tuple( - replace_duck_with_extension_array(value) - ) - elif isinstance(value, list): - args_as_list[index] = replace_duck_with_extension_array(value) - return args_as_list - - args = tuple(replace_duck_with_extension_array(args)) - if func not in HANDLED_FUNCTIONS: - return func(*args, **kwargs) - res = HANDLED_FUNCTIONS[func](*args, **kwargs) - if is_extension_array_dtype(res): - return ExtensionDuckArray[type(res)](res) - return res - - def __array_ufunc__(ufunc, method, *inputs, **kwargs): - return ufunc(*inputs, **kwargs) - - def __repr__(self): - return f"ExtensionDuckArray(array={repr(self.extension_array)})" - - def __getattr__(self, attr: str) -> object: - if hasattr(self.extension_array, attr): - return getattr(self.extension_array, attr) - raise AttributeError(f"{attr} not found.") - - def __getitem__(self, key) -> ExtensionDuckArray[T_ExtensionArray]: - item = self.extension_array[key] - if is_extension_array_dtype(item): # not a singleton - better way to check? 
- return ExtensionDuckArray(item) - return item - - def __setitem__(self, key, val): - self.extension_array[key] = val - - def __eq__(self, other): - if isinstance(other, ExtensionDuckArray): - return self.extension_array == other.extension_array - return self.extension_array == other - - def __ne__(self, other): - return ~(self == other) +HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {} def implements(numpy_function): """Register an __array_function__ implementation for MyArray objects.""" def decorator(func): - HANDLED_FUNCTIONS[numpy_function] = func + HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func return func return decorator @@ -407,9 +317,7 @@ def as_shared_dtype(scalars_or_arrays, xp=np): """Cast a arrays to a shared dtype using xarray's type promotion rules.""" if any(is_extension_array_dtype(x) for x in scalars_or_arrays): extension_array_types = [ - type(x.extension_array) - for x in scalars_or_arrays - if is_extension_array_dtype(x) + type(x.array) for x in scalars_or_arrays if is_extension_array_dtype(x) ] if len(extension_array_types) == len(scalars_or_arrays) and all( x == extension_array_types[0] for x in extension_array_types diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index cf8c2a3d259..787973aaf0c 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -9,10 +9,11 @@ from dataclasses import dataclass, field from datetime import timedelta from html import escape -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Generic import numpy as np import pandas as pd +from pandas.api.types import is_extension_array_dtype from xarray.core import duck_array_ops from xarray.core.nputils import NumpyVIndexAdapter @@ -24,7 +25,7 @@ is_duck_array, is_duck_dask_array, ) -from xarray.core.types import T_Xarray +from xarray.core.types import T_ExtensionArray, T_Xarray from xarray.core.utils import ( NDArrayMixin, either_dict_or_kwargs, @@ -1573,6 +1574,94 @@ def copy(self, deep: bool = True) -> PandasIndexingAdapter: return type(self)(array, self._dtype) +class ExtensionDuckArray(Generic[T_ExtensionArray]): + array: T_ExtensionArray + + def __init__(self, array: T_ExtensionArray | ExtensionDuckArray): + """NEP-18 compliant wrapper for pandas extension arrays. + To override behavior of common applicable numpy `__array_function__` methods, + please use :py:module:`plum` to dispatch the corresponding function with the prefix + `__extension_duck_array__` i.e., for :py:func:`numpy.where`. The only function that is + necessary to override at the moment is :py:func:`np.where` since :py:module:`pandas` provides + concatenation, and `broadcast`/`stack` are not supported. + + Parameters + ---------- + array : T_ExtensionArray | ExtensionDuckArray + The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation. + If also an ExtensionDuckArray, it's `array` attribute will be extracted and assigned to `self.array`. + + + Examples + -------- + The following is an example of how you might support some different kinds of arrays. 
+ ```python + from plum import dispatch + @dispatch + def ___extension_duck_array__where__(cond: np.ndarray, x: pd.arrays.IntegerArray, y: pd.arrays.BooleanArray) -> pd.arrays.IntegerArray: + pass + ``` + """ + if isinstance(array, ExtensionDuckArray): + self.array = array.array + elif is_extension_array_dtype(array): + self.array = array + else: + raise TypeError(f"{array} is not an pandas ExtensionArray.") + + def __array_function__(self, func, types, args, kwargs): + def replace_duck_with_extension_array(args) -> list: + args_as_list = list(args) + for index, value in enumerate(args_as_list): + if isinstance(value, ExtensionDuckArray): + args_as_list[index] = value.array + elif isinstance( + value, tuple + ): # should handle more than just tuple? iterable? + args_as_list[index] = tuple( + replace_duck_with_extension_array(value) + ) + elif isinstance(value, list): + args_as_list[index] = replace_duck_with_extension_array(value) + return args_as_list + + args = tuple(replace_duck_with_extension_array(args)) + if func not in duck_array_ops.HANDLED_EXTENSION_ARRAY_FUNCTIONS: + return func(*args, **kwargs) + res = duck_array_ops.HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) + if is_extension_array_dtype(res): + return type(self)[type(res)](res) + return res + + def __array_ufunc__(ufunc, method, *inputs, **kwargs): + return ufunc(*inputs, **kwargs) + + def __repr__(self): + return f"{type(self)}(array={repr(self.array)})" + + def __getattr__(self, attr: str) -> object: + if hasattr(self.array, attr): + return getattr(self.array, attr) + raise AttributeError(f"{attr} not found.") + + def __getitem__(self, key) -> ExtensionDuckArray[T_ExtensionArray]: + item = self.array[key] + if is_extension_array_dtype(item): # not a singleton - better way to check? + return type(self)(item) + return item + + def __setitem__(self, key, val): + self.array[key] = val + + def __eq__(self, other): + if isinstance(other, ExtensionDuckArray): + return self.array == other.array + return self.array == other + + def __ne__(self, other): + return ~(self == other) + + class PandasMultiIndexingAdapter(PandasIndexingAdapter): """Handles explicit indexing for a pandas.MultiIndex. diff --git a/xarray/core/types.py b/xarray/core/types.py index 410cf3de00b..8f58e54d8cf 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -167,6 +167,9 @@ def copy( # hopefully in the future we can narrow this down more: T_DuckArray = TypeVar("T_DuckArray", bound=Any, covariant=True) +# For typing pandas extension arrays. 
+T_ExtensionArray = TypeVar("T_ExtensionArray", bound=pd.api.extensions.ExtensionArray) + ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] VarCompatible = Union["Variable", "ScalarOrArray"] diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 4c3d12645ac..45063b4ead6 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -22,6 +22,7 @@ from xarray.core.common import AbstractArray from xarray.core.indexing import ( BasicIndexer, + ExtensionDuckArray, OuterIndexer, PandasIndexingAdapter, VectorizedIndexer, @@ -176,11 +177,9 @@ def _maybe_wrap_data(data): return PandasIndexingAdapter(data) if is_extension_array_dtype(data) and find_spec("plum"): data_type = ( - type(data.extension_array) - if isinstance(data, duck_array_ops.ExtensionDuckArray) - else type(data) + type(data.array) if isinstance(data, ExtensionDuckArray) else type(data) ) - return duck_array_ops.ExtensionDuckArray[data_type](data) + return ExtensionDuckArray[data_type](data) return data diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 6ec1ab3b93a..1ddb5a569bd 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -158,10 +158,10 @@ def test_concat_categorical() -> None: concatenated = concat([data1, data2], dim="dim1") assert ( concatenated["var4"] - == type(data2["var4"].variable.data.extension_array)._concat_same_type( + == type(data2["var4"].variable.data.array)._concat_same_type( [ - data1["var4"].variable.data.extension_array, - data2["var4"].variable.data.extension_array, + data1["var4"].variable.data.array, + data2["var4"].variable.data.array, ] ) ).all() diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 98850026649..e09d2d5e6fd 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -12,7 +12,6 @@ from xarray import DataArray, Dataset, cftime_range, concat from xarray.core import dtypes, duck_array_ops from xarray.core.duck_array_ops import ( - ExtensionDuckArray, __extension_duck_array__broadcast, __extension_duck_array__concatenate, __extension_duck_array__where, @@ -32,6 +31,7 @@ timedelta_to_numeric, where, ) +from xarray.core.indexing import ExtensionDuckArray from xarray.core.pycompat import array_type from xarray.testing import assert_allclose, assert_equal, assert_identical from xarray.tests import ( From b04ef9817057262d66cba91606abba46110e04f4 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Sat, 10 Feb 2024 07:53:29 +0100 Subject: [PATCH 26/70] (refactor): change order of classes --- xarray/core/indexing.py | 176 ++++++++++++++++++++-------------------- 1 file changed, 88 insertions(+), 88 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 787973aaf0c..8325fd3de93 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1475,6 +1475,94 @@ def transpose(self, order): return self.array.transpose(order) +class ExtensionDuckArray(Generic[T_ExtensionArray]): + array: T_ExtensionArray + + def __init__(self, array: T_ExtensionArray | ExtensionDuckArray): + """NEP-18 compliant wrapper for pandas extension arrays. + To override behavior of common applicable numpy `__array_function__` methods, + please use :py:module:`plum` to dispatch the corresponding function with the prefix + `__extension_duck_array__` i.e., for :py:func:`numpy.where`. 
The only function that is + necessary to override at the moment is :py:func:`np.where` since :py:module:`pandas` provides + concatenation, and `broadcast`/`stack` are not supported. + + Parameters + ---------- + array : T_ExtensionArray | ExtensionDuckArray + The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation. + If also an ExtensionDuckArray, it's `array` attribute will be extracted and assigned to `self.array`. + + + Examples + -------- + The following is an example of how you might support some different kinds of arrays. + ```python + from plum import dispatch + @dispatch + def ___extension_duck_array__where__(cond: np.ndarray, x: pd.arrays.IntegerArray, y: pd.arrays.BooleanArray) -> pd.arrays.IntegerArray: + pass + ``` + """ + if isinstance(array, ExtensionDuckArray): + self.array = array.array + elif is_extension_array_dtype(array): + self.array = array + else: + raise TypeError(f"{array} is not an pandas ExtensionArray.") + + def __array_function__(self, func, types, args, kwargs): + def replace_duck_with_extension_array(args) -> list: + args_as_list = list(args) + for index, value in enumerate(args_as_list): + if isinstance(value, ExtensionDuckArray): + args_as_list[index] = value.array + elif isinstance( + value, tuple + ): # should handle more than just tuple? iterable? + args_as_list[index] = tuple( + replace_duck_with_extension_array(value) + ) + elif isinstance(value, list): + args_as_list[index] = replace_duck_with_extension_array(value) + return args_as_list + + args = tuple(replace_duck_with_extension_array(args)) + if func not in duck_array_ops.HANDLED_EXTENSION_ARRAY_FUNCTIONS: + return func(*args, **kwargs) + res = duck_array_ops.HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) + if is_extension_array_dtype(res): + return type(self)[type(res)](res) + return res + + def __array_ufunc__(ufunc, method, *inputs, **kwargs): + return ufunc(*inputs, **kwargs) + + def __repr__(self): + return f"{type(self)}(array={repr(self.array)})" + + def __getattr__(self, attr: str) -> object: + if hasattr(self.array, attr): + return getattr(self.array, attr) + raise AttributeError(f"{attr} not found.") + + def __getitem__(self, key) -> ExtensionDuckArray[T_ExtensionArray]: + item = self.array[key] + if is_extension_array_dtype(item): # not a singleton - better way to check? + return type(self)(item) + return item + + def __setitem__(self, key, val): + self.array[key] = val + + def __eq__(self, other): + if isinstance(other, ExtensionDuckArray): + return self.array == other.array + return self.array == other + + def __ne__(self, other): + return ~(self == other) + + class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): """Wrap a pandas.Index to preserve dtypes and handle explicit indexing.""" @@ -1574,94 +1662,6 @@ def copy(self, deep: bool = True) -> PandasIndexingAdapter: return type(self)(array, self._dtype) -class ExtensionDuckArray(Generic[T_ExtensionArray]): - array: T_ExtensionArray - - def __init__(self, array: T_ExtensionArray | ExtensionDuckArray): - """NEP-18 compliant wrapper for pandas extension arrays. - To override behavior of common applicable numpy `__array_function__` methods, - please use :py:module:`plum` to dispatch the corresponding function with the prefix - `__extension_duck_array__` i.e., for :py:func:`numpy.where`. The only function that is - necessary to override at the moment is :py:func:`np.where` since :py:module:`pandas` provides - concatenation, and `broadcast`/`stack` are not supported. 
- - Parameters - ---------- - array : T_ExtensionArray | ExtensionDuckArray - The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation. - If also an ExtensionDuckArray, it's `array` attribute will be extracted and assigned to `self.array`. - - - Examples - -------- - The following is an example of how you might support some different kinds of arrays. - ```python - from plum import dispatch - @dispatch - def ___extension_duck_array__where__(cond: np.ndarray, x: pd.arrays.IntegerArray, y: pd.arrays.BooleanArray) -> pd.arrays.IntegerArray: - pass - ``` - """ - if isinstance(array, ExtensionDuckArray): - self.array = array.array - elif is_extension_array_dtype(array): - self.array = array - else: - raise TypeError(f"{array} is not an pandas ExtensionArray.") - - def __array_function__(self, func, types, args, kwargs): - def replace_duck_with_extension_array(args) -> list: - args_as_list = list(args) - for index, value in enumerate(args_as_list): - if isinstance(value, ExtensionDuckArray): - args_as_list[index] = value.array - elif isinstance( - value, tuple - ): # should handle more than just tuple? iterable? - args_as_list[index] = tuple( - replace_duck_with_extension_array(value) - ) - elif isinstance(value, list): - args_as_list[index] = replace_duck_with_extension_array(value) - return args_as_list - - args = tuple(replace_duck_with_extension_array(args)) - if func not in duck_array_ops.HANDLED_EXTENSION_ARRAY_FUNCTIONS: - return func(*args, **kwargs) - res = duck_array_ops.HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) - if is_extension_array_dtype(res): - return type(self)[type(res)](res) - return res - - def __array_ufunc__(ufunc, method, *inputs, **kwargs): - return ufunc(*inputs, **kwargs) - - def __repr__(self): - return f"{type(self)}(array={repr(self.array)})" - - def __getattr__(self, attr: str) -> object: - if hasattr(self.array, attr): - return getattr(self.array, attr) - raise AttributeError(f"{attr} not found.") - - def __getitem__(self, key) -> ExtensionDuckArray[T_ExtensionArray]: - item = self.array[key] - if is_extension_array_dtype(item): # not a singleton - better way to check? - return type(self)(item) - return item - - def __setitem__(self, key, val): - self.array[key] = val - - def __eq__(self, other): - if isinstance(other, ExtensionDuckArray): - return self.array == other.array - return self.array == other - - def __ne__(self, other): - return ~(self == other) - - class PandasMultiIndexingAdapter(PandasIndexingAdapter): """Handles explicit indexing for a pandas.MultiIndex. 
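Editorial note: the docstring above only sketches the extension point. Assuming `plum-dispatch` is installed, a registration in the spirit of the overload the test suite adds for nullable integers could look roughly like the following; the implementation details here are illustrative and not part of the patch:

```python
from collections.abc import Sequence

import numpy as np
import pandas as pd

from xarray.core.duck_array_ops import __extension_duck_array__concatenate


@__extension_duck_array__concatenate.dispatch
def _(arrays: Sequence[pd.arrays.IntegerArray], axis: int = 0, out=None):
    # Concatenate the integer values and the missing-value masks separately
    # so the result stays a nullable IntegerArray instead of a float ndarray.
    values = np.concatenate(
        [arr.to_numpy(dtype="int64", na_value=0) for arr in arrays]
    )
    mask = np.concatenate([np.asarray(arr.isna()) for arr in arrays])
    return pd.arrays.IntegerArray(values, mask)
```

With such an overload registered, concatenation of variables backed by `IntegerArray` is intended to route through it via `__array_function__` and return an `IntegerArray` rather than a floating-point NumPy array.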
From b9937bf0627ba76d1d62f88ceb2011fb7dc6eda5 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 Feb 2024 11:32:20 +0100 Subject: [PATCH 27/70] (chore): add small pyarrow test --- xarray/tests/test_duck_array_ops.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index e09d2d5e6fd..08979e269ff 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd +import pyarrow as pa import pytest from numpy import array, nan @@ -59,6 +60,20 @@ def categorical2(): return pd.Categorical(["cat2", "cat1", "cat2", "cat3", "cat1"]) +@pytest.fixture +def arrow1(): + return pd.arrays.ArrowExtensionArray( + pa.array([{"x": 1, "y": True}, {"x": 2, "y": False}]) + ) + + +@pytest.fixture +def arrow2(): + return pd.arrays.ArrowExtensionArray( + pa.array([{"x": 3, "y": False}, {"x": 4, "y": True}]) + ) + + @pytest.fixture def int1(): return pd.arrays.IntegerArray( @@ -189,6 +204,13 @@ def test_concatenate_extension_duck_array(self, categorical1, categorical2): == type(categorical1)._concat_same_type((categorical1, categorical2)) ).all() + def test_duck_extension_array_pyarrow_concatenate(arrow1, arrow2): + concatenated = concatenate( + (ExtensionDuckArray(arrow1), ExtensionDuckArray(arrow2)) + ) + assert concatenated[2]["x"] == 3 + assert concatenated[3]["y"] + def test_concatenate_extension_duck_array_fallback( self, categorical1, categorical2 ): From 0bba03fba6e9f3f6a57ce7087a57ca2776a893f1 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 Feb 2024 11:45:42 +0100 Subject: [PATCH 28/70] (fix): fix some mypy issues --- pyproject.toml | 1 + xarray/core/duck_array_ops.py | 10 ++++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 59c5ee0b06a..17b4d1b5acf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,6 +125,7 @@ module = [ "opt_einsum.*", "pandas.*", "pooch.*", + "pyarrow.*", "pydap.*", "pytest.*", "scipy.*", diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index a982fb44c3a..cb414f8a164 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -37,10 +37,10 @@ from pandas.api.types import is_extension_array_dtype try: - from plum import dispatch + from plum import dispatch # type: ignore[import-not-found] except ImportError: - def dispatch(*args, **kwargs): # type: ignore + def dispatch(*args, **kwargs): # type: ignore[misc] pass @@ -116,8 +116,10 @@ def __extension_duck_array__where( return np.where(condition, x, y) -@__extension_duck_array__where.dispatch -def _(condition: np.ndarray, x: pd.Categorical, y: pd.Categorical): +@dispatch # type: ignore[no-redef] +def __extension_duck_array__where( + condition: np.ndarray, x: pd.Categorical, y: pd.Categorical +): # set up new codes array new_codes = np.where(condition, x.codes, y.codes) From b714549d1d83052ab2cae86072dc7825d8e0372a Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 Feb 2024 12:27:41 +0100 Subject: [PATCH 29/70] (fix): don't register unregisterable method --- xarray/tests/test_duck_array_ops.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 08979e269ff..c983cc7d14a 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -3,6 +3,7 @@ import datetime as dt import warnings from 
collections.abc import Sequence +from importlib.util import find_spec import numpy as np import pandas as pd @@ -88,12 +89,14 @@ def int2(): ) -@__extension_duck_array__concatenate.dispatch -def _(arrays: Sequence[pd.arrays.IntegerArray], axis: int = 0, out=None): - values = np.concatenate(arrays) - mask = np.isnan(values) - values = values.astype("int8") - return pd.arrays.IntegerArray(values, mask) +if find_spec("plum"): + + @__extension_duck_array__concatenate.dispatch + def _(arrays: Sequence[pd.arrays.IntegerArray], axis: int = 0, out=None): + values = np.concatenate(arrays) + mask = np.isnan(values) + values = values.astype("int8") + return pd.arrays.IntegerArray(values, mask) class TestOps: @@ -204,7 +207,7 @@ def test_concatenate_extension_duck_array(self, categorical1, categorical2): == type(categorical1)._concat_same_type((categorical1, categorical2)) ).all() - def test_duck_extension_array_pyarrow_concatenate(arrow1, arrow2): + def test_duck_extension_array_pyarrow_concatenate(self, arrow1, arrow2): concatenated = concatenate( (ExtensionDuckArray(arrow1), ExtensionDuckArray(arrow2)) ) From a3a678cd58f16c8f5b8dc7d4edcc46ee7e21df1a Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 Feb 2024 12:31:04 +0100 Subject: [PATCH 30/70] (fix): appease mypy --- xarray/core/duck_array_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index cb414f8a164..e1d800cfd07 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -40,8 +40,8 @@ from plum import dispatch # type: ignore[import-not-found] except ImportError: - def dispatch(*args, **kwargs): # type: ignore[misc] - pass + def dispatch(func): + return func from xarray.core import dask_array_ops, dtypes, nputils, pycompat From e5218449671774a262e7d7d4b5b18aff58ea35d4 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 Feb 2024 16:40:58 +0100 Subject: [PATCH 31/70] (fix): more sensible default implemetations allow most use without `plum` --- xarray/core/duck_array_ops.py | 55 ++++++++--------------------- xarray/core/variable.py | 3 +- xarray/tests/test_dataset.py | 2 -- xarray/tests/test_duck_array_ops.py | 9 ----- xarray/tests/test_variable.py | 2 -- 5 files changed, 15 insertions(+), 56 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index e1d800cfd07..af285df2d7d 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -82,8 +82,8 @@ def decorator(func): @dispatch def __extension_duck_array__issubdtype( extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave -): - return False # never want a function to think a categorical is a subtype of numpy +) -> bool: + return False # never want a function to think a pandas extension dtype is a subtype of numpy @implements(np.broadcast_to) @@ -104,7 +104,7 @@ def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): @dispatch def __extension_duck_array__concatenate( arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None -): +) -> T_ExtensionArray: return type(arrays[0])._concat_same_type(arrays) @@ -112,42 +112,15 @@ def __extension_duck_array__concatenate( @dispatch def __extension_duck_array__where( condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray -): - return np.where(condition, x, y) - - -@dispatch # type: ignore[no-redef] -def __extension_duck_array__where( - condition: np.ndarray, x: pd.Categorical, y: pd.Categorical -): - # set up new codes array - new_codes = 
np.where(condition, x.codes, y.codes) - - # Remap shared categories to have same codes - shared_categories = [ - category for category in x.categories if category in set(y.categories) - ] - for shared_category in shared_categories: - new_codes[~condition & (y == shared_category)] = x.codes[x == shared_category][ - 0 - ] - - # map non-shared y codes to start from the lowest possible number - y_only_categories = [ - category for category in y.categories if category not in shared_categories - ] - used_x_only_categories = ( - x[~x.isin(shared_categories)].remove_unused_categories().categories - ) - new_y_code = len(used_x_only_categories) + len(shared_categories) - for y_only_category in y_only_categories: - new_codes[~condition & (y == y_only_category)] = new_y_code - new_y_code += 1 - new_categories = shared_categories + y_only_categories # preserve order - # TODO: think about ordering? - return pd.Categorical.from_codes( - new_codes, categories=new_categories, ordered=False - ) +) -> T_ExtensionArray: + if ( + isinstance(x, pd.Categorical) + and isinstance(y, pd.Categorical) + and x.dtype != y.dtype + ): + x = x.add_categories(set(y.categories).difference(set(x.categories))) + y = y.add_categories(set(x.categories).difference(set(y.categories))) + return pd.Series(x).where(condition, pd.Series(y)).array def get_array_namespace(x): @@ -319,10 +292,10 @@ def as_shared_dtype(scalars_or_arrays, xp=np): """Cast a arrays to a shared dtype using xarray's type promotion rules.""" if any(is_extension_array_dtype(x) for x in scalars_or_arrays): extension_array_types = [ - type(x.array) for x in scalars_or_arrays if is_extension_array_dtype(x) + x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x) ] if len(extension_array_types) == len(scalars_or_arrays) and all( - x == extension_array_types[0] for x in extension_array_types + isinstance(x, type(extension_array_types[0])) for x in extension_array_types ): return scalars_or_arrays arrays = [asarray(np.array(x), xp=xp) for x in scalars_or_arrays] diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 45063b4ead6..a08c71621b3 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -8,7 +8,6 @@ from collections.abc import Hashable, Mapping, Sequence from datetime import timedelta from functools import partial -from importlib.util import find_spec from typing import TYPE_CHECKING, Any, Callable, NoReturn, cast import numpy as np @@ -175,7 +174,7 @@ def _maybe_wrap_data(data): """ if isinstance(data, pd.Index): return PandasIndexingAdapter(data) - if is_extension_array_dtype(data) and find_spec("plum"): + if is_extension_array_dtype(data): data_type = ( type(data.array) if isinstance(data, ExtensionDuckArray) else type(data) ) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c8050828b04..b298201ab09 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -61,7 +61,6 @@ requires_dask, requires_numexpr, requires_pint, - requires_plum, requires_scipy, requires_sparse, source_ndarray, @@ -4610,7 +4609,6 @@ def test_to_dataarray(self) -> None: expected = expected.rename({"variable": "abc"}).rename("foo") assert_identical(expected, actual) - @requires_plum def test_to_and_from_dataframe(self) -> None: x = np.random.randn(10) y = np.random.randn(10) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index c983cc7d14a..ae417fea070 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -175,7 +175,6 
@@ def test_where_type_promotion(self): assert result.dtype == np.float32 assert_array_equal(result, np.array([1, np.nan], dtype=np.float32)) - @requires_plum def test_where_extension_duck_array(self, categorical1, categorical2): where_res = where( np.array([True, False, True, False, False]), @@ -196,7 +195,6 @@ def test_where_extension_duck_array_fallback(self, categorical1, categorical2): assert isinstance(where_res, np.ndarray) assert (where_res == np.array(["cat1", "cat1", "cat2", "cat3", "cat1"])).all() - @requires_plum def test_concatenate_extension_duck_array(self, categorical1, categorical2): concate_res = concatenate( [ExtensionDuckArray(categorical1), ExtensionDuckArray(categorical2)] @@ -228,7 +226,6 @@ def test_concatenate_extension_duck_array_fallback( ) ).all() - @requires_plum def test___getitem__extension_duck_array(self, categorical1): extension_duck_array = ExtensionDuckArray(categorical1) assert (extension_duck_array[0:2] == categorical1[0:2]).all() @@ -237,7 +234,6 @@ def test___getitem__extension_duck_array(self, categorical1): mask = [True, False, True, False, True] assert (extension_duck_array[mask] == categorical1[mask]).all() - @requires_plum def test__setitem__extension_duck_array(self, categorical1): extension_duck_array = ExtensionDuckArray(categorical1) extension_duck_array[2] = "cat1" # already existing category @@ -1060,7 +1056,6 @@ def test_push_dask(): np.testing.assert_equal(actual, expected) -@requires_plum def test_where_all_categoricals(categorical1, categorical2): assert ( __extension_duck_array__where( @@ -1070,7 +1065,6 @@ def test_where_all_categoricals(categorical1, categorical2): ).all() -@requires_plum def test_where_drop_categoricals(categorical1, categorical2): assert ( __extension_duck_array__where( @@ -1080,18 +1074,15 @@ def test_where_drop_categoricals(categorical1, categorical2): ).all() -@requires_plum def test_broadcast_to_categorical(categorical1): with pytest.raises(NotImplementedError): __extension_duck_array__broadcast(categorical1, (5, 2)) -@requires_plum def test_broadcast_to_same_categorical(categorical1): assert (__extension_duck_array__broadcast(categorical1, (5,)) == categorical1).all() -@requires_plum def test_concategorical_categorical(categorical1, categorical2): assert ( __extension_duck_array__concatenate([categorical1, categorical2]) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 71006d5e050..097ca20367b 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -43,7 +43,6 @@ requires_dask, requires_pandas_version_two, requires_pint, - requires_plum, requires_sparse, source_ndarray, ) @@ -1559,7 +1558,6 @@ def test_transpose_0d(self): actual = variable.transpose() assert_identical(actual, variable) - @requires_plum def test_pandas_cateogrical_dtype(self): data = pd.Categorical(np.arange(10, dtype="int64")) v = self.cls("x", data) From 2d3e930a6dddcba780dd52292628daa48790dbf6 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 Feb 2024 16:57:31 +0100 Subject: [PATCH 32/70] (fix): handling `pyarrow` tests --- xarray/tests/__init__.py | 1 + xarray/tests/test_duck_array_ops.py | 23 +++++++++++++---------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 045ea5e7599..1d46d2ba7e7 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -96,6 +96,7 @@ def _importorskip( has_iris, requires_iris = _importorskip("iris") has_numbagg, requires_numbagg = _importorskip("numbagg", "0.4.0") 
has_plum, requires_plum = _importorskip("plum") +has_pyarrow, requires_pyarrow = _importorskip("pyarrow") with warnings.catch_warnings(): warnings.filterwarnings( "ignore", diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index ae417fea070..0ce4dca0210 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -46,6 +46,7 @@ requires_cftime, requires_dask, requires_plum, + requires_pyarrow, ) dask_array_type = array_type("dask") @@ -61,18 +62,19 @@ def categorical2(): return pd.Categorical(["cat2", "cat1", "cat2", "cat3", "cat1"]) -@pytest.fixture -def arrow1(): - return pd.arrays.ArrowExtensionArray( - pa.array([{"x": 1, "y": True}, {"x": 2, "y": False}]) - ) +if find_spec("arrow"): + @pytest.fixture + def arrow1(): + return pd.arrays.ArrowExtensionArray( + pa.array([{"x": 1, "y": True}, {"x": 2, "y": False}]) + ) -@pytest.fixture -def arrow2(): - return pd.arrays.ArrowExtensionArray( - pa.array([{"x": 3, "y": False}, {"x": 4, "y": True}]) - ) + @pytest.fixture + def arrow2(): + return pd.arrays.ArrowExtensionArray( + pa.array([{"x": 3, "y": False}, {"x": 4, "y": True}]) + ) @pytest.fixture @@ -205,6 +207,7 @@ def test_concatenate_extension_duck_array(self, categorical1, categorical2): == type(categorical1)._concat_same_type((categorical1, categorical2)) ).all() + @requires_pyarrow def test_duck_extension_array_pyarrow_concatenate(self, arrow1, arrow2): concatenated = concatenate( (ExtensionDuckArray(arrow1), ExtensionDuckArray(arrow2)) From 04c9969b84b75b72b8c7a9730e96c2d81d943041 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 12 Feb 2024 17:02:14 +0100 Subject: [PATCH 33/70] (fix): actually do import correctly --- xarray/tests/test_duck_array_ops.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 0ce4dca0210..58ef3bbee16 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -7,7 +7,6 @@ import numpy as np import pandas as pd -import pyarrow as pa import pytest from numpy import array, nan @@ -62,7 +61,8 @@ def categorical2(): return pd.Categorical(["cat2", "cat1", "cat2", "cat3", "cat1"]) -if find_spec("arrow"): +try: + import pyarrow as pa @pytest.fixture def arrow1(): @@ -76,6 +76,9 @@ def arrow2(): pa.array([{"x": 3, "y": False}, {"x": 4, "y": True}]) ) +except ImportError: + pass + @pytest.fixture def int1(): From bedfa5ce49595f29f0f642771020965497ead943 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 13 Feb 2024 09:34:44 +0100 Subject: [PATCH 34/70] (fix): `reduce` condition --- xarray/core/dataset.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a84d3aa7dc3..2e1d4c0cf80 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6834,11 +6834,13 @@ def reduce( if ( # Some reduction functions (e.g. 
std, var) need to run on variables # that don't have the reduce dims: PR5393 - not reduce_dims - or not numeric_only - or not is_extension_array_dtype(var.dtype) - or np.issubdtype(var.dtype, np.number) - or (var.dtype == np.bool_) + ( + not reduce_dims + or not numeric_only + or np.issubdtype(var.dtype, np.number) + or (var.dtype == np.bool_) + ) + and not is_extension_array_dtype(var.dtype) ): # prefer to aggregate over axis=None rather than # axis=(0, 1) if they will be equivalent, because From 82dbda924bdf3f5cdfbfe7c3e3656d11f74b5e04 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 13 Feb 2024 10:49:44 +0100 Subject: [PATCH 35/70] (fix): column ordering for dataframes --- xarray/core/dataset.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 03ea60f6bb3..b727c914bc9 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7155,33 +7155,37 @@ def to_pandas(self) -> pd.Series | pd.DataFrame: ) def _to_dataframe(self, ordered_dims: Mapping[Any, int]): - columns = [ + columns_in_order = [k for k in self.variables if k not in self.dims] + non_extension_array_columns = [ k - for k in self.variables - if k not in self.dims - and not is_extension_array_dtype(self.variables[k].data) + for k in columns_in_order + if not is_extension_array_dtype(self.variables[k].data) ] extension_array_columns = [ k - for k in self.variables - if k not in self.dims and is_extension_array_dtype(self.variables[k].data) + for k in columns_in_order + if is_extension_array_dtype(self.variables[k].data) ] data = [ self._variables[k].set_dims(ordered_dims).values.reshape(-1) - for k in columns + for k in non_extension_array_columns ] index = self.coords.to_index([*ordered_dims]) - broadcasted_df = pd.DataFrame(dict(zip(columns, data)), index=index) + broadcasted_df = pd.DataFrame( + dict(zip(non_extension_array_columns, data)), index=index + ) for extension_array_column in extension_array_columns: extension_array = self.variables[extension_array_column].data.array index = self[self.variables[extension_array_column].dims[0]].data - cat_df = pd.DataFrame( + extension_array_df = pd.DataFrame( {extension_array_column: extension_array}, index=self[self.variables[extension_array_column].dims[0]].data, ) - cat_df.index.name = self.variables[extension_array_column].dims[0] - broadcasted_df = broadcasted_df.join(cat_df) - return broadcasted_df + extension_array_df.index.name = self.variables[extension_array_column].dims[ + 0 + ] + broadcasted_df = broadcasted_df.join(extension_array_df) + return broadcasted_df[columns_in_order] def to_dataframe(self, dim_order: Sequence[Hashable] | None = None) -> pd.DataFrame: """Convert this dataset into a pandas.DataFrame. 
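Editorial note: the reindexing with `columns_in_order` above is meant to keep DataFrame columns in the Dataset's variable order even though extension-array columns are joined on separately. A rough sketch of that expectation, with illustrative data and assuming the patch is applied:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Illustrative mix of numpy-backed and extension-array variables.
ds = xr.Dataset(
    {
        "a": ("t", np.arange(3.0)),
        "cat": ("t", pd.Categorical(list("xyz"))),
        "b": ("t", np.arange(3, 6)),
    },
    coords={"t": [0, 1, 2]},
)

# Columns should come back in declaration order, not "numpy columns first".
assert list(ds.to_dataframe().columns) == ["a", "cat", "b"]
```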
@@ -7365,7 +7369,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: obj._set_numpy_data_from_dataframe(idx, arrays, dims) for name, extension_array in extension_arrays: obj[name] = (dims, extension_array) - return obj + return obj[dataframe.columns] if len(dataframe.columns) else obj def to_dask_dataframe( self, dim_order: Sequence[Hashable] | None = None, set_index: bool = False From 12217ed1c725f3dea1f35f486a15af1adf32201d Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 13 Feb 2024 10:53:05 +0100 Subject: [PATCH 36/70] (refactor): remove encoding business --- xarray/coding/strings.py | 8 ++++---- xarray/coding/times.py | 12 ++++-------- xarray/coding/variables.py | 20 ++++++++------------ xarray/conventions.py | 12 ++++-------- 4 files changed, 20 insertions(+), 32 deletions(-) diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index 9dd31c444c8..2c4501dfb0f 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -5,7 +5,6 @@ from functools import partial import numpy as np -from pandas.api.types import is_extension_array_dtype from xarray.coding.variables import ( VariableCoder, @@ -29,10 +28,11 @@ def create_vlen_dtype(element_type): def check_vlen_dtype(dtype): - if is_extension_array_dtype(dtype) or dtype.kind != "O" or dtype.metadata is None: + if dtype.kind != "O" or dtype.metadata is None: return None - # check xarray (element_type) as well as h5py (vlen) - return dtype.metadata.get("element_type", dtype.metadata.get("vlen")) + else: + # check xarray (element_type) as well as h5py (vlen) + return dtype.metadata.get("element_type", dtype.metadata.get("vlen")) def is_unicode_dtype(dtype): diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 1afa19676af..92bce0abeaa 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -9,7 +9,6 @@ import numpy as np import pandas as pd -from pandas.api.types import is_extension_array_dtype from pandas.errors import OutOfBoundsDatetime, OutOfBoundsTimedelta from xarray.coding.variables import ( @@ -969,10 +968,9 @@ def __init__(self, use_cftime: bool | None = None) -> None: self.use_cftime = use_cftime def encode(self, variable: Variable, name: T_Name = None) -> Variable: - if (not is_extension_array_dtype(variable.data)) and ( - np.issubdtype(variable.data.dtype, np.datetime64) - or contains_cftime_datetimes(variable) - ): + if np.issubdtype( + variable.data.dtype, np.datetime64 + ) or contains_cftime_datetimes(variable): dims, data, attrs, encoding = unpack_for_encoding(variable) units = encoding.pop("units", None) @@ -1010,9 +1008,7 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: class CFTimedeltaCoder(VariableCoder): def encode(self, variable: Variable, name: T_Name = None) -> Variable: - if (not is_extension_array_dtype(variable.data)) and np.issubdtype( - variable.data.dtype, np.timedelta64 - ): + if np.issubdtype(variable.data.dtype, np.timedelta64): dims, data, attrs, encoding = unpack_for_encoding(variable) data, units = encode_cf_timedelta( diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 526fe971655..b5e4167f2b2 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -9,7 +9,6 @@ import numpy as np import pandas as pd -from pandas.api.types import is_extension_array_dtype from xarray.core import dtypes, duck_array_ops, indexing from xarray.core.variable import Variable @@ -251,6 +250,7 @@ class CFMaskCoder(VariableCoder): def encode(self, variable: Variable, name: T_Name = 
None): dims, data, attrs, encoding = unpack_for_encoding(variable) + dtype = np.dtype(encoding.get("dtype", data.dtype)) fv = encoding.get("_FillValue") mv = encoding.get("missing_value") @@ -268,8 +268,6 @@ def encode(self, variable: Variable, name: T_Name = None): # special case DateTime to properly handle NaT is_time_like = _is_time_like(attrs.get("units")) - dtype = np.dtype(encoding.get("dtype", data.dtype)) - if fv_exists: # Ensure _FillValue is cast to same dtype as data's encoding["_FillValue"] = dtype.type(fv) @@ -474,18 +472,16 @@ class DefaultFillvalueCoder(VariableCoder): def encode(self, variable: Variable, name: T_Name = None) -> Variable: dims, data, attrs, encoding = unpack_for_encoding(variable) - has_no_fill = "_FillValue" not in attrs and "_FillValue" not in encoding # make NaN the fill value for float types - if is_extension_array_dtype(data): - if not has_no_fill: - raise ValueError( - "Found _FillValue encoding or attr on extension array." - ) - return variable - if has_no_fill and np.issubdtype(variable.dtype, np.floating): + if ( + "_FillValue" not in attrs + and "_FillValue" not in encoding + and np.issubdtype(variable.dtype, np.floating) + ): attrs["_FillValue"] = variable.dtype.type(np.nan) return Variable(dims, data, attrs, encoding, fastpath=True) - return variable + else: + return variable def decode(self, variable: Variable, name: T_Name = None) -> Variable: raise NotImplementedError() diff --git a/xarray/conventions.py b/xarray/conventions.py index 548f1ec7241..6eff45c5b2d 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd -from pandas.api.types import is_extension_array_dtype from xarray.coding import strings, times, variables from xarray.coding.variables import SerializationWarning, pop_to @@ -115,10 +114,7 @@ def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable: dims, data, attrs, encoding = variables.unpack_for_encoding(var) # leave vlen dtypes unchanged - if ( - is_extension_array_dtype(data) - or strings.check_vlen_dtype(data.dtype) is not None - ): + if strings.check_vlen_dtype(data.dtype) is not None: return var if is_duck_dask_array(data): @@ -360,9 +356,9 @@ def _update_bounds_encoding(variables: T_Variables) -> None: attrs = v.attrs encoding = v.encoding has_date_units = "units" in encoding and "since" in encoding["units"] - is_datetime_type = (not is_extension_array_dtype(v)) and ( - contains_cftime_datetimes(v) or np.issubdtype(v.dtype, np.datetime64) - ) + is_datetime_type = np.issubdtype( + v.dtype, np.datetime64 + ) or contains_cftime_datetimes(v) if ( is_datetime_type From dd5b87df89ed4005f4e558498fe17c751a82a100 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 13 Feb 2024 10:53:26 +0100 Subject: [PATCH 37/70] (refactor): raise error for dask + extension array --- xarray/core/variable.py | 5 +++++ xarray/tests/test_variable.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 8fc88c0fee1..8598d57f548 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2543,6 +2543,11 @@ def chunk( # type: ignore[override] dask.array.from_array """ + if is_extension_array_dtype(self): + raise ValueError( + f"{self} was found to be a Pandas ExtensionArray. Please convert to numpy first." 
+ ) + if from_array_kwargs is None: from_array_kwargs = {} diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index f25d7db4efa..3f573a6fe0a 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -2361,6 +2361,11 @@ def test_multiindex(self): def test_pad(self, mode, xr_arg, np_arg): super().test_pad(mode, xr_arg, np_arg) + def test_pandas_cateogrical_dtype(self): + data = pd.Categorical(np.arange(10, dtype="int64")) + with pytest.raises(ValueError, match="was found to be a Pandas ExtensionArray"): + self.cls("x", data) + @requires_sparse class TestVariableWithSparse: From e0d58fa86652b603a25ef4ea046902c863a0e96f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Feb 2024 16:43:06 +0100 Subject: [PATCH 38/70] (fix): only wrap `ExtensionDuckArray` that has a `.array` which is a pandas extension array --- xarray/core/variable.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 8598d57f548..28feed8f44d 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -174,7 +174,11 @@ def _maybe_wrap_data(data): """ if isinstance(data, pd.Index): return PandasIndexingAdapter(data) - if is_extension_array_dtype(data): + if ( + is_extension_array_dtype(data) + and isinstance(data, ExtensionDuckArray) + and isinstance(data.array, pd.api.extensions.ExtensionArray) + ): data_type = ( type(data.array) if isinstance(data, ExtensionDuckArray) else type(data) ) From c1e0e64e7150b7bb3ecaea7994a4af7ec69d57fb Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Feb 2024 16:43:21 +0100 Subject: [PATCH 39/70] (fix): use duck array equality method, not pandas --- xarray/tests/test_duck_array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 15ddb48d457..166aa2f182a 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -1109,7 +1109,7 @@ def test_duck_extension_array_equality(categorical1, int1): categorical_duck_array = ExtensionDuckArray(categorical1) assert (int_duck_array != categorical_duck_array).all() assert (categorical_duck_array == categorical1).all() - assert (int1[0:2] == int_duck_array[0:2]).all() + assert (int_duck_array[0:2] == int1[0:2]).all() def test_duck_extension_array_repr(int1): From 17e3390294d27b24d56826061be47d1bbc5d56b6 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Feb 2024 17:30:05 +0100 Subject: [PATCH 40/70] (refactor): bye plum! 
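
With plum gone, NumPy functions reach pandas extension arrays only through the plain `implements` registry consulted in `__array_function__`: pandas' own `_concat_same_type` already covers concatenation for every extension dtype (including the pyarrow-backed arrays exercised in the tests), and broadcast/stack remain unsupported for these 1d-only arrays, so per-type multiple dispatch buys nothing. A rough end-to-end sketch of the behaviour this keeps, assuming the wrapping added earlier in this series (illustrative only; the exact round-trip behaviour is assumed from the rest of the branch, not asserted by any test in this commit):

    import pandas as pd
    import xarray as xr

    # A categorical column round-trips through a Dataset and a concat
    # without any per-type dispatch having been registered.
    df = pd.DataFrame(
        {"cat": pd.Categorical(["a", "b", "a"])},
        index=pd.Index([0, 1, 2], name="x"),
    )
    ds = xr.Dataset.from_dataframe(df)
    combined = xr.concat([ds, ds], dim="x")
    print(combined["cat"].dtype)  # expected to stay a CategoricalDtype

The dedicated `extension-arrays` optional-dependency group (and its plum-dispatch requirement) is dropped from pyproject.toml accordingly.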
--- pyproject.toml | 3 +- xarray/core/duck_array_ops.py | 13 ------- xarray/tests/test_duck_array_ops.py | 58 ----------------------------- 3 files changed, 1 insertion(+), 73 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cce10aef4ab..1cef1fd24cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ [project.optional-dependencies] accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"] -complete = ["xarray[accel,io,parallel,viz,dev,extension_arrays]"] +complete = ["xarray[accel,io,parallel,viz,dev]"] dev = [ "hypothesis", "pre-commit", @@ -45,7 +45,6 @@ dev = [ io = ["netCDF4", "h5netcdf", "scipy", 'pydap; python_version<"3.10"', "zarr", "fsspec", "cftime", "pooch"] parallel = ["dask[complete]"] viz = ["matplotlib", "seaborn", "nc-time-axis"] -extension-arrays = ["plum-dispatch"] [project.urls] Documentation = "https://docs.xarray.dev" diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 2cb12695b12..af6088f9a9e 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -36,14 +36,6 @@ from packaging.version import Version from pandas.api.types import is_extension_array_dtype -try: - from plum import dispatch # type: ignore[import-not-found] -except ImportError: - - def dispatch(func): - return func - - from xarray.core import dask_array_ops, dtypes, nputils from xarray.core.options import OPTIONS from xarray.core.types import DTypeLikeSave, T_ExtensionArray @@ -80,7 +72,6 @@ def decorator(func): @implements(np.issubdtype) -@dispatch def __extension_duck_array__issubdtype( extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave ) -> bool: @@ -88,7 +79,6 @@ def __extension_duck_array__issubdtype( @implements(np.broadcast_to) -@dispatch def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): if shape[0] == len(arr) and len(shape) == 1: return arr @@ -96,13 +86,11 @@ def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): @implements(np.stack) -@dispatch def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): raise NotImplementedError("Cannot stack 1d-only pandas categorical array.") @implements(np.concatenate) -@dispatch def __extension_duck_array__concatenate( arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None ) -> T_ExtensionArray: @@ -110,7 +98,6 @@ def __extension_duck_array__concatenate( @implements(np.where) -@dispatch def __extension_duck_array__where( condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray ) -> T_ExtensionArray: diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 166aa2f182a..fc68c06e661 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -2,8 +2,6 @@ import datetime as dt import warnings -from collections.abc import Sequence -from importlib.util import find_spec import numpy as np import pandas as pd @@ -13,9 +11,6 @@ from xarray import DataArray, Dataset, cftime_range, concat from xarray.core import dtypes, duck_array_ops from xarray.core.duck_array_ops import ( - __extension_duck_array__broadcast, - __extension_duck_array__concatenate, - __extension_duck_array__where, array_notnull_equiv, concatenate, count, @@ -44,7 +39,6 @@ requires_bottleneck, requires_cftime, requires_dask, - requires_plum, requires_pyarrow, ) @@ -94,16 +88,6 @@ def int2(): ) -if find_spec("plum"): - - @__extension_duck_array__concatenate.dispatch - def _(arrays: Sequence[pd.arrays.IntegerArray], axis: int = 0, out=None): - 
values = np.concatenate(arrays) - mask = np.isnan(values) - values = values.astype("int8") - return pd.arrays.IntegerArray(values, mask) - - class TestOps: @pytest.fixture(autouse=True) def setUp(self): @@ -1062,48 +1046,6 @@ def test_push_dask(): np.testing.assert_equal(actual, expected) -def test_where_all_categoricals(categorical1, categorical2): - assert ( - __extension_duck_array__where( - np.array([True, False, True, False, False]), categorical1, categorical2 - ) - == pd.Categorical(["cat1", "cat1", "cat2", "cat3", "cat1"]) - ).all() - - -def test_where_drop_categoricals(categorical1, categorical2): - assert ( - __extension_duck_array__where( - np.array([False, True, True, False, True]), categorical1, categorical2 - ).remove_unused_categories() - == pd.Categorical(["cat2", "cat2", "cat2", "cat3", "cat2"]) - ).all() - - -def test_broadcast_to_categorical(categorical1): - with pytest.raises(NotImplementedError): - __extension_duck_array__broadcast(categorical1, (5, 2)) - - -def test_broadcast_to_same_categorical(categorical1): - assert (__extension_duck_array__broadcast(categorical1, (5,)) == categorical1).all() - - -def test_concategorical_categorical(categorical1, categorical2): - assert ( - __extension_duck_array__concatenate([categorical1, categorical2]) - == type(categorical1)._concat_same_type((categorical1, categorical2)) - ).all() - - -@requires_plum -def test_integer_array_register_concatenate(int1, int2): - assert ( - __extension_duck_array__concatenate([int1, int2]) - == type(int1)._concat_same_type((int1, int2)) - ).all() - - def test_duck_extension_array_equality(categorical1, int1): int_duck_array = ExtensionDuckArray(int1) categorical_duck_array = ExtensionDuckArray(categorical1) From c8e6bfedc2760dbb065cd8b98206f2469a4063c6 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 15 Feb 2024 17:38:58 +0100 Subject: [PATCH 41/70] (fix): `and` to `or` for casting to `ExtensionDuckArray` --- xarray/core/variable.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 28feed8f44d..f529edf773e 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -174,9 +174,8 @@ def _maybe_wrap_data(data): """ if isinstance(data, pd.Index): return PandasIndexingAdapter(data) - if ( - is_extension_array_dtype(data) - and isinstance(data, ExtensionDuckArray) + if is_extension_array_dtype(data) or ( + isinstance(data, ExtensionDuckArray) and isinstance(data.array, pd.api.extensions.ExtensionArray) ): data_type = ( From b2a9517d0ed53808be1ee1ab7b6518f5340247d6 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 16 Feb 2024 13:40:20 +0100 Subject: [PATCH 42/70] (fix): check for class, not type --- xarray/core/variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index f529edf773e..7180c8bc409 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -174,7 +174,7 @@ def _maybe_wrap_data(data): """ if isinstance(data, pd.Index): return PandasIndexingAdapter(data) - if is_extension_array_dtype(data) or ( + if isinstance(data, pd.api.extensions.ExtensionArray) or ( isinstance(data, ExtensionDuckArray) and isinstance(data.array, pd.api.extensions.ExtensionArray) ): From 407fad1e2912b4b64cc5699203c5c9be27f1f013 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 19 Feb 2024 15:39:08 +0100 Subject: [PATCH 43/70] (fix): only support native endianness --- properties/test_pandas_roundtrip.py | 4 +++- xarray/testing/strategies.py | 10 
+++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index 5c0f14976e6..3d87fcce1d9 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -17,7 +17,9 @@ from hypothesis import given # isort:skip numeric_dtypes = st.one_of( - npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes() + npst.unsigned_integer_dtypes(endianness="="), + npst.integer_dtypes(endianness="="), + npst.floating_dtypes(endianness="="), ) numeric_series = numeric_dtypes.flatmap(lambda dt: pdst.series(dtype=dt)) diff --git a/xarray/testing/strategies.py b/xarray/testing/strategies.py index c5a7afdf54e..79385bc5226 100644 --- a/xarray/testing/strategies.py +++ b/xarray/testing/strategies.py @@ -44,7 +44,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]: Generates only those numpy dtypes which xarray can handle. Use instead of hypothesis.extra.numpy.scalar_dtypes in order to exclude weirder dtypes such as unicode, byte_string, array, or nested dtypes. - Also excludes datetimes, which dodges bugs with pandas non-nanosecond datetime overflows. + Also excludes datetimes, which dodges bugs with pandas non-nanosecond datetime overflows. Checks only native endianness. Requires the hypothesis package to be installed. @@ -55,10 +55,10 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]: # TODO should this be exposed publicly? # We should at least decide what the set of numpy dtypes that xarray officially supports is. return ( - npst.integer_dtypes() - | npst.unsigned_integer_dtypes() - | npst.floating_dtypes() - | npst.complex_number_dtypes() + npst.integer_dtypes(endianness="=") + | npst.unsigned_integer_dtypes(endianness="=") + | npst.floating_dtypes(endianness="=") + | npst.complex_number_dtypes(endianness="=") ) From 1c9047fefb1eb71534c6fa55abf59203b75e1ff2 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 22 Feb 2024 11:02:26 +0100 Subject: [PATCH 44/70] (refactor): no need for superfluous checks in `_maybe_wrap_data` --- xarray/core/variable.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 7180c8bc409..679c439cc53 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -174,14 +174,8 @@ def _maybe_wrap_data(data): """ if isinstance(data, pd.Index): return PandasIndexingAdapter(data) - if isinstance(data, pd.api.extensions.ExtensionArray) or ( - isinstance(data, ExtensionDuckArray) - and isinstance(data.array, pd.api.extensions.ExtensionArray) - ): - data_type = ( - type(data.array) if isinstance(data, ExtensionDuckArray) else type(data) - ) - return ExtensionDuckArray[data_type](data) + if isinstance(data, pd.api.extensions.ExtensionArray): + return ExtensionDuckArray[type(data)](data) return data From d9304f105e0359934d956f474d6660f75a0ce3a4 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 22 Feb 2024 17:51:37 +0100 Subject: [PATCH 45/70] (chore): clean up docs to no longer reference `plum` --- xarray/core/indexing.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 2b0c4450d1f..095047d8bb6 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1460,27 +1460,11 @@ class ExtensionDuckArray(Generic[T_ExtensionArray]): def __init__(self, array: T_ExtensionArray | ExtensionDuckArray): """NEP-18 compliant wrapper for pandas extension arrays. 
- To override behavior of common applicable numpy `__array_function__` methods, - please use :py:module:`plum` to dispatch the corresponding function with the prefix - `__extension_duck_array__` i.e., for :py:func:`numpy.where`. The only function that is - necessary to override at the moment is :py:func:`np.where` since :py:module:`pandas` provides - concatenation, and `broadcast`/`stack` are not supported. Parameters ---------- array : T_ExtensionArray | ExtensionDuckArray The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation. - If also an ExtensionDuckArray, it's `array` attribute will be extracted and assigned to `self.array`. - - - Examples - -------- - The following is an example of how you might support some different kinds of arrays. - ```python - from plum import dispatch - @dispatch - def ___extension_duck_array__where__(cond: np.ndarray, x: pd.arrays.IntegerArray, y: pd.arrays.BooleanArray) -> pd.arrays.IntegerArray: - pass ``` """ if isinstance(array, ExtensionDuckArray): From 6ec6725ac5f8d77aaed58ced2b4cff474af28284 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 22 Feb 2024 17:52:10 +0100 Subject: [PATCH 46/70] (fix): no longer allow `ExtensionDuckArray` to wrap `ExtensionDuckArray` --- xarray/core/indexing.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 095047d8bb6..f70a431666d 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1458,18 +1458,16 @@ def transpose(self, order): class ExtensionDuckArray(Generic[T_ExtensionArray]): array: T_ExtensionArray - def __init__(self, array: T_ExtensionArray | ExtensionDuckArray): + def __init__(self, array: T_ExtensionArray): """NEP-18 compliant wrapper for pandas extension arrays. Parameters ---------- - array : T_ExtensionArray | ExtensionDuckArray + array : T_ExtensionArray The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation. 
``` """ - if isinstance(array, ExtensionDuckArray): - self.array = array.array - elif is_extension_array_dtype(array): + if isinstance(array, pd.api.extensions.ExtensionArray): self.array = array else: raise TypeError(f"{array} is not an pandas ExtensionArray.") From bc9ac4ce7db94110516913a74d2e439def65a69e Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 22 Feb 2024 18:04:00 +0100 Subject: [PATCH 47/70] (refactor): move `implements` logic to `indexing` --- xarray/core/duck_array_ops.py | 56 -------------------------------- xarray/core/indexing.py | 61 ++++++++++++++++++++++++++++++++--- 2 files changed, 57 insertions(+), 60 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index af6088f9a9e..69b9d64df51 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -10,10 +10,8 @@ import datetime import inspect import warnings -from collections.abc import Sequence from functools import partial from importlib import import_module -from typing import Callable import numpy as np import pandas as pd @@ -38,7 +36,6 @@ from xarray.core import dask_array_ops, dtypes, nputils from xarray.core.options import OPTIONS -from xarray.core.types import DTypeLikeSave, T_ExtensionArray from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available from xarray.namedarray import pycompat from xarray.namedarray.parallelcompat import get_chunked_array_type @@ -58,59 +55,6 @@ dask_available = module_available("dask") -HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {} - - -def implements(numpy_function): - """Register an __array_function__ implementation for MyArray objects.""" - - def decorator(func): - HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func - return func - - return decorator - - -@implements(np.issubdtype) -def __extension_duck_array__issubdtype( - extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave -) -> bool: - return False # never want a function to think a pandas extension dtype is a subtype of numpy - - -@implements(np.broadcast_to) -def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): - if shape[0] == len(arr) and len(shape) == 1: - return arr - raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.") - - -@implements(np.stack) -def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): - raise NotImplementedError("Cannot stack 1d-only pandas categorical array.") - - -@implements(np.concatenate) -def __extension_duck_array__concatenate( - arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None -) -> T_ExtensionArray: - return type(arrays[0])._concat_same_type(arrays) - - -@implements(np.where) -def __extension_duck_array__where( - condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray -) -> T_ExtensionArray: - if ( - isinstance(x, pd.Categorical) - and isinstance(y, pd.Categorical) - and x.dtype != y.dtype - ): - x = x.add_categories(set(y.categories).difference(set(x.categories))) - y = y.add_categories(set(x.categories).difference(set(y.categories))) - return pd.Series(x).where(condition, pd.Series(y)).array - - def get_array_namespace(x): if hasattr(x, "__array_namespace__"): return x.__array_namespace__() diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index f70a431666d..49e02917ba0 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -4,7 +4,7 @@ import functools import operator from collections import Counter, defaultdict -from collections.abc import Hashable, Mapping 
+from collections.abc import Hashable, Mapping, Sequence from contextlib import suppress from dataclasses import dataclass, field from datetime import timedelta @@ -18,7 +18,7 @@ from xarray.core import duck_array_ops from xarray.core.nputils import NumpyVIndexAdapter from xarray.core.options import OPTIONS -from xarray.core.types import T_ExtensionArray, T_Xarray +from xarray.core.types import DTypeLikeSave, T_ExtensionArray, T_Xarray from xarray.core.utils import ( NDArrayMixin, either_dict_or_kwargs, @@ -1455,6 +1455,59 @@ def transpose(self, order): return self.array.transpose(order) +HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {} + + +def implements(numpy_function): + """Register an __array_function__ implementation for MyArray objects.""" + + def decorator(func): + HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func + return func + + return decorator + + +@implements(np.issubdtype) +def __extension_duck_array__issubdtype( + extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave +) -> bool: + return False # never want a function to think a pandas extension dtype is a subtype of numpy + + +@implements(np.broadcast_to) +def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): + if shape[0] == len(arr) and len(shape) == 1: + return arr + raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.") + + +@implements(np.stack) +def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): + raise NotImplementedError("Cannot stack 1d-only pandas categorical array.") + + +@implements(np.concatenate) +def __extension_duck_array__concatenate( + arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None +) -> T_ExtensionArray: + return type(arrays[0])._concat_same_type(arrays) + + +@implements(np.where) +def __extension_duck_array__where( + condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray +) -> T_ExtensionArray: + if ( + isinstance(x, pd.Categorical) + and isinstance(y, pd.Categorical) + and x.dtype != y.dtype + ): + x = x.add_categories(set(y.categories).difference(set(x.categories))) + y = y.add_categories(set(x.categories).difference(set(y.categories))) + return pd.Series(x).where(condition, pd.Series(y)).array + + class ExtensionDuckArray(Generic[T_ExtensionArray]): array: T_ExtensionArray @@ -1489,9 +1542,9 @@ def replace_duck_with_extension_array(args) -> list: return args_as_list args = tuple(replace_duck_with_extension_array(args)) - if func not in duck_array_ops.HANDLED_EXTENSION_ARRAY_FUNCTIONS: + if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS: return func(*args, **kwargs) - res = duck_array_ops.HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) + res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) if is_extension_array_dtype(res): return type(self)[type(res)](res) return res From 6fb866815508ef0a259a0be7ed3ae563acf007b5 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 29 Feb 2024 09:11:15 +0100 Subject: [PATCH 48/70] (refactor): `indexing.py` -> `extension_array.py` --- xarray/core/extension_array.py | 132 ++++++++++++++++++++++++++++ xarray/core/indexing.py | 130 +-------------------------- xarray/core/variable.py | 2 +- xarray/tests/test_duck_array_ops.py | 2 +- 4 files changed, 137 insertions(+), 129 deletions(-) create mode 100644 xarray/core/extension_array.py diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py new file mode 100644 index 00000000000..c6c95d2cb4f --- /dev/null +++ b/xarray/core/extension_array.py @@ -0,0 +1,132 @@ +from 
__future__ import annotations + +from collections.abc import Sequence +from typing import Callable, Generic + +import numpy as np +import pandas as pd +from pandas.api.types import is_extension_array_dtype + +from xarray.core.types import DTypeLikeSave, T_ExtensionArray + +HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {} + + +def implements(numpy_function): + """Register an __array_function__ implementation for MyArray objects.""" + + def decorator(func): + HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func + return func + + return decorator + + +@implements(np.issubdtype) +def __extension_duck_array__issubdtype( + extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave +) -> bool: + return False # never want a function to think a pandas extension dtype is a subtype of numpy + + +@implements(np.broadcast_to) +def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): + if shape[0] == len(arr) and len(shape) == 1: + return arr + raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.") + + +@implements(np.stack) +def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): + raise NotImplementedError("Cannot stack 1d-only pandas categorical array.") + + +@implements(np.concatenate) +def __extension_duck_array__concatenate( + arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None +) -> T_ExtensionArray: + return type(arrays[0])._concat_same_type(arrays) + + +@implements(np.where) +def __extension_duck_array__where( + condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray +) -> T_ExtensionArray: + if ( + isinstance(x, pd.Categorical) + and isinstance(y, pd.Categorical) + and x.dtype != y.dtype + ): + x = x.add_categories(set(y.categories).difference(set(x.categories))) + y = y.add_categories(set(x.categories).difference(set(y.categories))) + return pd.Series(x).where(condition, pd.Series(y)).array + + +class ExtensionDuckArray(Generic[T_ExtensionArray]): + array: T_ExtensionArray + + def __init__(self, array: T_ExtensionArray): + """NEP-18 compliant wrapper for pandas extension arrays. + + Parameters + ---------- + array : T_ExtensionArray + The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation. + ``` + """ + if isinstance(array, pd.api.extensions.ExtensionArray): + self.array = array + else: + raise TypeError(f"{array} is not an pandas ExtensionArray.") + + def __array_function__(self, func, types, args, kwargs): + def replace_duck_with_extension_array(args) -> list: + args_as_list = list(args) + for index, value in enumerate(args_as_list): + if isinstance(value, ExtensionDuckArray): + args_as_list[index] = value.array + elif isinstance( + value, tuple + ): # should handle more than just tuple? iterable? 
+ args_as_list[index] = tuple( + replace_duck_with_extension_array(value) + ) + elif isinstance(value, list): + args_as_list[index] = replace_duck_with_extension_array(value) + return args_as_list + + args = tuple(replace_duck_with_extension_array(args)) + if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS: + return func(*args, **kwargs) + res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) + if is_extension_array_dtype(res): + return type(self)[type(res)](res) + return res + + def __array_ufunc__(ufunc, method, *inputs, **kwargs): + return ufunc(*inputs, **kwargs) + + def __repr__(self): + return f"{type(self)}(array={repr(self.array)})" + + def __getattr__(self, attr: str) -> object: + if hasattr(self.array, attr): + return getattr(self.array, attr) + raise AttributeError(f"{attr} not found.") + + def __getitem__(self, key) -> ExtensionDuckArray[T_ExtensionArray]: + item = self.array[key] + if is_extension_array_dtype(item): # not a singleton - better way to check? + return type(self)(item) + return item + + def __setitem__(self, key, val): + self.array[key] = val + + def __eq__(self, other): + if isinstance(other, ExtensionDuckArray): + return self.array == other.array + return self.array == other + + def __ne__(self, other): + return ~(self == other) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 845ba1f5ec1..62889e03861 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -4,21 +4,20 @@ import functools import operator from collections import Counter, defaultdict -from collections.abc import Hashable, Mapping, Sequence +from collections.abc import Hashable, Mapping from contextlib import suppress from dataclasses import dataclass, field from datetime import timedelta from html import escape -from typing import TYPE_CHECKING, Any, Callable, Generic +from typing import TYPE_CHECKING, Any, Callable import numpy as np import pandas as pd -from pandas.api.types import is_extension_array_dtype from xarray.core import duck_array_ops from xarray.core.nputils import NumpyVIndexAdapter from xarray.core.options import OPTIONS -from xarray.core.types import DTypeLikeSave, T_ExtensionArray, T_Xarray +from xarray.core.types import T_Xarray from xarray.core.utils import ( NDArrayMixin, either_dict_or_kwargs, @@ -1526,129 +1525,6 @@ def transpose(self, order): return self.array.transpose(order) -HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {} - - -def implements(numpy_function): - """Register an __array_function__ implementation for MyArray objects.""" - - def decorator(func): - HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func - return func - - return decorator - - -@implements(np.issubdtype) -def __extension_duck_array__issubdtype( - extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave -) -> bool: - return False # never want a function to think a pandas extension dtype is a subtype of numpy - - -@implements(np.broadcast_to) -def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): - if shape[0] == len(arr) and len(shape) == 1: - return arr - raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.") - - -@implements(np.stack) -def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): - raise NotImplementedError("Cannot stack 1d-only pandas categorical array.") - - -@implements(np.concatenate) -def __extension_duck_array__concatenate( - arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None -) -> T_ExtensionArray: - return 
type(arrays[0])._concat_same_type(arrays) - - -@implements(np.where) -def __extension_duck_array__where( - condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray -) -> T_ExtensionArray: - if ( - isinstance(x, pd.Categorical) - and isinstance(y, pd.Categorical) - and x.dtype != y.dtype - ): - x = x.add_categories(set(y.categories).difference(set(x.categories))) - y = y.add_categories(set(x.categories).difference(set(y.categories))) - return pd.Series(x).where(condition, pd.Series(y)).array - - -class ExtensionDuckArray(Generic[T_ExtensionArray]): - array: T_ExtensionArray - - def __init__(self, array: T_ExtensionArray): - """NEP-18 compliant wrapper for pandas extension arrays. - - Parameters - ---------- - array : T_ExtensionArray - The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation. - ``` - """ - if isinstance(array, pd.api.extensions.ExtensionArray): - self.array = array - else: - raise TypeError(f"{array} is not an pandas ExtensionArray.") - - def __array_function__(self, func, types, args, kwargs): - def replace_duck_with_extension_array(args) -> list: - args_as_list = list(args) - for index, value in enumerate(args_as_list): - if isinstance(value, ExtensionDuckArray): - args_as_list[index] = value.array - elif isinstance( - value, tuple - ): # should handle more than just tuple? iterable? - args_as_list[index] = tuple( - replace_duck_with_extension_array(value) - ) - elif isinstance(value, list): - args_as_list[index] = replace_duck_with_extension_array(value) - return args_as_list - - args = tuple(replace_duck_with_extension_array(args)) - if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS: - return func(*args, **kwargs) - res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) - if is_extension_array_dtype(res): - return type(self)[type(res)](res) - return res - - def __array_ufunc__(ufunc, method, *inputs, **kwargs): - return ufunc(*inputs, **kwargs) - - def __repr__(self): - return f"{type(self)}(array={repr(self.array)})" - - def __getattr__(self, attr: str) -> object: - if hasattr(self.array, attr): - return getattr(self.array, attr) - raise AttributeError(f"{attr} not found.") - - def __getitem__(self, key) -> ExtensionDuckArray[T_ExtensionArray]: - item = self.array[key] - if is_extension_array_dtype(item): # not a singleton - better way to check? 
- return type(self)(item) - return item - - def __setitem__(self, key, val): - self.array[key] = val - - def __eq__(self, other): - if isinstance(other, ExtensionDuckArray): - return self.array == other.array - return self.array == other - - def __ne__(self, other): - return ~(self == other) - - class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): """Wrap a pandas.Index to preserve dtypes and handle explicit indexing.""" diff --git a/xarray/core/variable.py b/xarray/core/variable.py index a806b079db8..9a033375fe3 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -19,9 +19,9 @@ from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils from xarray.core.arithmetic import VariableArithmetic from xarray.core.common import AbstractArray +from xarray.core.extension_array import ExtensionDuckArray from xarray.core.indexing import ( BasicIndexer, - ExtensionDuckArray, OuterIndexer, PandasIndexingAdapter, VectorizedIndexer, diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index fc68c06e661..25bb8fb341a 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -27,7 +27,7 @@ timedelta_to_numeric, where, ) -from xarray.core.indexing import ExtensionDuckArray +from xarray.core.extension_array import ExtensionDuckArray from xarray.namedarray.pycompat import array_type from xarray.testing import assert_allclose, assert_equal, assert_identical from xarray.tests import ( From 8f034b47ab335be5dc0f39fa4d9fc224520bcbe5 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 29 Feb 2024 09:12:50 +0100 Subject: [PATCH 49/70] (refactor): `ExtensionDuckArray` -> `PandasExtensionArray` --- xarray/core/extension_array.py | 8 ++++---- xarray/core/variable.py | 4 ++-- xarray/tests/test_duck_array_ops.py | 32 ++++++++++++++--------------- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index c6c95d2cb4f..25b79e4d384 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -62,7 +62,7 @@ def __extension_duck_array__where( return pd.Series(x).where(condition, pd.Series(y)).array -class ExtensionDuckArray(Generic[T_ExtensionArray]): +class PandasExtensionArray(Generic[T_ExtensionArray]): array: T_ExtensionArray def __init__(self, array: T_ExtensionArray): @@ -83,7 +83,7 @@ def __array_function__(self, func, types, args, kwargs): def replace_duck_with_extension_array(args) -> list: args_as_list = list(args) for index, value in enumerate(args_as_list): - if isinstance(value, ExtensionDuckArray): + if isinstance(value, PandasExtensionArray): args_as_list[index] = value.array elif isinstance( value, tuple @@ -114,7 +114,7 @@ def __getattr__(self, attr: str) -> object: return getattr(self.array, attr) raise AttributeError(f"{attr} not found.") - def __getitem__(self, key) -> ExtensionDuckArray[T_ExtensionArray]: + def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: item = self.array[key] if is_extension_array_dtype(item): # not a singleton - better way to check? 
return type(self)(item) @@ -124,7 +124,7 @@ def __setitem__(self, key, val): self.array[key] = val def __eq__(self, other): - if isinstance(other, ExtensionDuckArray): + if isinstance(other, PandasExtensionArray): return self.array == other.array return self.array == other diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 9a033375fe3..74dd2ae5b6b 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -19,7 +19,7 @@ from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils from xarray.core.arithmetic import VariableArithmetic from xarray.core.common import AbstractArray -from xarray.core.extension_array import ExtensionDuckArray +from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import ( BasicIndexer, OuterIndexer, @@ -171,7 +171,7 @@ def _maybe_wrap_data(data): if isinstance(data, pd.Index): return PandasIndexingAdapter(data) if isinstance(data, pd.api.extensions.ExtensionArray): - return ExtensionDuckArray[type(data)](data) + return PandasExtensionArray[type(data)](data) return data diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 25bb8fb341a..196f12bc44b 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -27,7 +27,7 @@ timedelta_to_numeric, where, ) -from xarray.core.extension_array import ExtensionDuckArray +from xarray.core.extension_array import PandasExtensionArray from xarray.namedarray.pycompat import array_type from xarray.testing import assert_allclose, assert_equal, assert_identical from xarray.tests import ( @@ -167,10 +167,10 @@ def test_where_type_promotion(self): def test_where_extension_duck_array(self, categorical1, categorical2): where_res = where( np.array([True, False, True, False, False]), - ExtensionDuckArray(categorical1), - ExtensionDuckArray(categorical2), + PandasExtensionArray(categorical1), + PandasExtensionArray(categorical2), ) - assert isinstance(where_res, ExtensionDuckArray) + assert isinstance(where_res, PandasExtensionArray) assert ( where_res == pd.Categorical(["cat1", "cat1", "cat2", "cat3", "cat1"]) ).all() @@ -178,7 +178,7 @@ def test_where_extension_duck_array(self, categorical1, categorical2): def test_where_extension_duck_array_fallback(self, categorical1, categorical2): where_res = where( np.array([True, False, True, False, False]), - ExtensionDuckArray(categorical1), + PandasExtensionArray(categorical1), np.array(categorical2), ) assert isinstance(where_res, np.ndarray) @@ -186,9 +186,9 @@ def test_where_extension_duck_array_fallback(self, categorical1, categorical2): def test_concatenate_extension_duck_array(self, categorical1, categorical2): concate_res = concatenate( - [ExtensionDuckArray(categorical1), ExtensionDuckArray(categorical2)] + [PandasExtensionArray(categorical1), PandasExtensionArray(categorical2)] ) - assert isinstance(concate_res, ExtensionDuckArray) + assert isinstance(concate_res, PandasExtensionArray) assert ( concate_res == type(categorical1)._concat_same_type((categorical1, categorical2)) @@ -197,7 +197,7 @@ def test_concatenate_extension_duck_array(self, categorical1, categorical2): @requires_pyarrow def test_duck_extension_array_pyarrow_concatenate(self, arrow1, arrow2): concatenated = concatenate( - (ExtensionDuckArray(arrow1), ExtensionDuckArray(arrow2)) + (PandasExtensionArray(arrow1), PandasExtensionArray(arrow2)) ) assert concatenated[2]["x"] == 3 assert concatenated[3]["y"] @@ -206,7 +206,7 @@ def 
test_concatenate_extension_duck_array_fallback( self, categorical1, categorical2 ): concate_res = concatenate( - [ExtensionDuckArray(categorical1), np.array(categorical2)] + [PandasExtensionArray(categorical1), np.array(categorical2)] ) assert isinstance(concate_res, np.ndarray) assert ( @@ -217,15 +217,15 @@ def test_concatenate_extension_duck_array_fallback( ).all() def test___getitem__extension_duck_array(self, categorical1): - extension_duck_array = ExtensionDuckArray(categorical1) + extension_duck_array = PandasExtensionArray(categorical1) assert (extension_duck_array[0:2] == categorical1[0:2]).all() - assert isinstance(extension_duck_array[0:2], ExtensionDuckArray) + assert isinstance(extension_duck_array[0:2], PandasExtensionArray) assert extension_duck_array[0] == categorical1[0] mask = [True, False, True, False, True] assert (extension_duck_array[mask] == categorical1[mask]).all() def test__setitem__extension_duck_array(self, categorical1): - extension_duck_array = ExtensionDuckArray(categorical1) + extension_duck_array = PandasExtensionArray(categorical1) extension_duck_array[2] = "cat1" # already existing category assert extension_duck_array[2] == "cat1" with pytest.raises(TypeError, match="Cannot setitem on a Categorical"): @@ -1047,18 +1047,18 @@ def test_push_dask(): def test_duck_extension_array_equality(categorical1, int1): - int_duck_array = ExtensionDuckArray(int1) - categorical_duck_array = ExtensionDuckArray(categorical1) + int_duck_array = PandasExtensionArray(int1) + categorical_duck_array = PandasExtensionArray(categorical1) assert (int_duck_array != categorical_duck_array).all() assert (categorical_duck_array == categorical1).all() assert (int_duck_array[0:2] == int1[0:2]).all() def test_duck_extension_array_repr(int1): - int_duck_array = ExtensionDuckArray(int1) + int_duck_array = PandasExtensionArray(int1) assert repr(int1) in repr(int_duck_array) def test_duck_extension_array_attr(int1): - int_duck_array = ExtensionDuckArray(int1) + int_duck_array = PandasExtensionArray(int1) assert (~int_duck_array.fillna(10)).all() From 661d9f2cf450996ea4f07fac153e3c7175b2c73f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 25 Mar 2024 17:32:05 +0100 Subject: [PATCH 50/70] (fix): add writeable property --- xarray/core/extension_array.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 25b79e4d384..b2bb7541b9d 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -79,6 +79,10 @@ def __init__(self, array: T_ExtensionArray): else: raise TypeError(f"{array} is not an pandas ExtensionArray.") + @property + def writeable(self): + return False # TODO: creat i/o within xarray for at common types i.e., categoricals + def __array_function__(self, func, types, args, kwargs): def replace_duck_with_extension_array(args) -> list: args_as_list = list(args) From caee1c6649d65277eaea3767566c406235f799cb Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 25 Mar 2024 17:44:52 +0100 Subject: [PATCH 51/70] (fix): don't check writeable for `PandasExtensionArray` --- xarray/core/extension_array.py | 4 ---- xarray/tests/__init__.py | 5 ++++- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index b2bb7541b9d..25b79e4d384 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -79,10 +79,6 @@ def __init__(self, array: T_ExtensionArray): else: raise TypeError(f"{array} is not an pandas 
ExtensionArray.") - @property - def writeable(self): - return False # TODO: creat i/o within xarray for at common types i.e., categoricals - def __array_function__(self, func, types, args, kwargs): def replace_duck_with_extension_array(args) -> list: args_as_list = list(args) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index e4a3c2539de..4ea301f4fb8 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -18,6 +18,7 @@ from xarray import Dataset from xarray.core import utils from xarray.core.duck_array_ops import allclose_or_equiv # noqa: F401 +from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import ExplicitlyIndexed from xarray.core.options import set_options from xarray.core.variable import IndexVariable @@ -52,7 +53,9 @@ def assert_writeable(ds): readonly = [ name for name, var in ds.variables.items() - if not isinstance(var, IndexVariable) and not var.data.flags.writeable + if not isinstance(var, IndexVariable) + and not var.data.flags.writeable + and not isinstance(var.data, PandasExtensionArray) ] assert not readonly, readonly From 1d12f5e1ea545592575efd587679444a117090fc Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 25 Mar 2024 17:48:38 +0100 Subject: [PATCH 52/70] (fix): move check eariler --- xarray/tests/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 4ea301f4fb8..1685e2f4824 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -54,8 +54,8 @@ def assert_writeable(ds): name for name, var in ds.variables.items() if not isinstance(var, IndexVariable) - and not var.data.flags.writeable and not isinstance(var.data, PandasExtensionArray) + and not var.data.flags.writeable ] assert not readonly, readonly From 902c74bf8f1619ff0a5c70713c7ea14ba296d395 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 28 Mar 2024 10:10:53 +0100 Subject: [PATCH 53/70] (refactor): correct guard clause --- xarray/core/extension_array.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 25b79e4d384..c8b89314b94 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -74,10 +74,9 @@ def __init__(self, array: T_ExtensionArray): The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation. 
``` """ - if isinstance(array, pd.api.extensions.ExtensionArray): - self.array = array - else: + if not isinstance(array, pd.api.extensions.ExtensionArray): raise TypeError(f"{array} is not an pandas ExtensionArray.") + self.array = array def __array_function__(self, func, types, args, kwargs): def replace_duck_with_extension_array(args) -> list: From 0b64506b70bdd996eb5cbbfd4adbdb9f048d9f6b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 28 Mar 2024 10:16:38 +0100 Subject: [PATCH 54/70] (chore): remove unnecessary `AttributeError` --- xarray/core/extension_array.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index c8b89314b94..3523480fe6c 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -109,9 +109,7 @@ def __repr__(self): return f"{type(self)}(array={repr(self.array)})" def __getattr__(self, attr: str) -> object: - if hasattr(self.array, attr): - return getattr(self.array, attr) - raise AttributeError(f"{attr} not found.") + return getattr(self.array, attr) def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: item = self.array[key] From 0c7e0237c6ec64e91cb99ab3c9bc2eeb1f94687b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 28 Mar 2024 10:17:22 +0100 Subject: [PATCH 55/70] (feat): singleton wrapped as array --- xarray/core/extension_array.py | 4 +++- xarray/tests/test_duck_array_ops.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 3523480fe6c..0f40f6a500d 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -113,8 +113,10 @@ def __getattr__(self, attr: str) -> object: def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: item = self.array[key] - if is_extension_array_dtype(item): # not a singleton - better way to check? 
+ if is_extension_array_dtype(item): return type(self)(item) + if np.isscalar(item): + return type(self)([item]) return item def __setitem__(self, key, val): diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 196f12bc44b..f9c494fa80f 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -221,6 +221,7 @@ def test___getitem__extension_duck_array(self, categorical1): assert (extension_duck_array[0:2] == categorical1[0:2]).all() assert isinstance(extension_duck_array[0:2], PandasExtensionArray) assert extension_duck_array[0] == categorical1[0] + assert isinstance(extension_duck_array[0], PandasExtensionArray) mask = [True, False, True, False, True] assert (extension_duck_array[mask] == categorical1[mask]).all() From dd7fe98f329fde8ee7bd2d8703927e513db03123 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 28 Mar 2024 10:18:32 +0100 Subject: [PATCH 56/70] (feat): remove shared dtype casting --- xarray/core/duck_array_ops.py | 4 +++- xarray/tests/test_duck_array_ops.py | 23 ----------------------- 2 files changed, 3 insertions(+), 24 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 69b9d64df51..d95dfa566cc 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -230,7 +230,9 @@ def as_shared_dtype(scalars_or_arrays, xp=np): isinstance(x, type(extension_array_types[0])) for x in extension_array_types ): return scalars_or_arrays - arrays = [asarray(np.array(x), xp=xp) for x in scalars_or_arrays] + raise ValueError( + f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}" + ) elif array_type_cupy := array_type("cupy") and any( # noqa: F841 isinstance(x, array_type_cupy) for x in scalars_or_arrays # noqa: F821 ): diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index f9c494fa80f..26821c69495 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -175,15 +175,6 @@ def test_where_extension_duck_array(self, categorical1, categorical2): where_res == pd.Categorical(["cat1", "cat1", "cat2", "cat3", "cat1"]) ).all() - def test_where_extension_duck_array_fallback(self, categorical1, categorical2): - where_res = where( - np.array([True, False, True, False, False]), - PandasExtensionArray(categorical1), - np.array(categorical2), - ) - assert isinstance(where_res, np.ndarray) - assert (where_res == np.array(["cat1", "cat1", "cat2", "cat3", "cat1"])).all() - def test_concatenate_extension_duck_array(self, categorical1, categorical2): concate_res = concatenate( [PandasExtensionArray(categorical1), PandasExtensionArray(categorical2)] @@ -202,20 +193,6 @@ def test_duck_extension_array_pyarrow_concatenate(self, arrow1, arrow2): assert concatenated[2]["x"] == 3 assert concatenated[3]["y"] - def test_concatenate_extension_duck_array_fallback( - self, categorical1, categorical2 - ): - concate_res = concatenate( - [PandasExtensionArray(categorical1), np.array(categorical2)] - ) - assert isinstance(concate_res, np.ndarray) - assert ( - concate_res - == np.array( - type(categorical1)._concat_same_type((categorical1, categorical2)) - ) - ).all() - def test___getitem__extension_duck_array(self, categorical1): extension_duck_array = PandasExtensionArray(categorical1) assert (extension_duck_array[0:2] == categorical1[0:2]).all() From f0df768d6f270a24f2999a22d94693418ada4a7c Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 28 Mar 2024 10:19:01 +0100 
Subject: [PATCH 57/70] (feat): loop once over `dataframe.items` --- xarray/core/dataset.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2a533c45ecd..b2713416d27 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7332,14 +7332,13 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: "cannot convert a DataFrame with a non-unique MultiIndex into xarray" ) - arrays = [ - (k, np.asarray(v)) - for k, v in dataframe.items() - if not is_extension_array_dtype(v) - ] - extension_arrays = [ - (k, v) for k, v in dataframe.items() if is_extension_array_dtype(v) - ] + arrays = [] + extension_arrays = [] + for k, v in dataframe.items(): + if not is_extension_array_dtype(v): + arrays.append((k, np.asarray(v))) + else: + extension_arrays.append((k, v)) indexes: dict[Hashable, Index] = {} index_vars: dict[Hashable, Variable] = {} From e2f04877dea3935c79d14909fa84d1f62064ff0a Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 28 Mar 2024 10:20:36 +0100 Subject: [PATCH 58/70] (feat): add `__len__` attribute --- xarray/core/extension_array.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 0f40f6a500d..9adaea9b267 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -129,3 +129,6 @@ def __eq__(self, other): def __ne__(self, other): return ~(self == other) + + def __len__(self): + return len(self.array) From 1eb67416d3a1a1ac61531551a42593f8a8b63484 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 28 Mar 2024 10:28:50 +0100 Subject: [PATCH 59/70] (fix): ensure constructor recieves `pd.Categorical` --- xarray/core/extension_array.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 9adaea9b267..484b6791a74 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -116,13 +116,15 @@ def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: if is_extension_array_dtype(item): return type(self)(item) if np.isscalar(item): - return type(self)([item]) + return type(self)(pd.Categorical([item])) return item def __setitem__(self, key, val): self.array[key] = val def __eq__(self, other): + if np.isscalar(other): + other = type(self)(pd.Categorical([other])) if isinstance(other, PandasExtensionArray): return self.array == other.array return self.array == other From 9cceadc77891c1bfd1625dc23a46704396798529 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Thu, 28 Mar 2024 17:02:15 +0100 Subject: [PATCH 60/70] Update xarray/core/extension_array.py Co-authored-by: Deepak Cherian --- xarray/core/extension_array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index 484b6791a74..ba697056ae0 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -124,7 +124,7 @@ def __setitem__(self, key, val): def __eq__(self, other): if np.isscalar(other): - other = type(self)(pd.Categorical([other])) + other = type(self)(type(self.array)([other])) if isinstance(other, PandasExtensionArray): return self.array == other.array return self.array == other From f2588c1cfaefa12e2c4f278e7101f49454c29c32 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Thu, 28 Mar 2024 17:02:24 +0100 Subject: [PATCH 61/70] Update xarray/core/extension_array.py Co-authored-by: Deepak Cherian --- 
xarray/core/extension_array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index ba697056ae0..6521e425615 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -116,7 +116,7 @@ def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: if is_extension_array_dtype(item): return type(self)(item) if np.isscalar(item): - return type(self)(pd.Categorical([item])) + return type(self)(type(self.array)([item])) return item def __setitem__(self, key, val): From a0a63bde7b97a49cef40e3a20b5c55ee7c3b233b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 28 Mar 2024 17:11:55 +0100 Subject: [PATCH 62/70] (fix): drop condition for categorical corrected --- xarray/core/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b2713416d27..4866bdbf988 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6836,13 +6836,13 @@ def reduce( if ( # Some reduction functions (e.g. std, var) need to run on variables # that don't have the reduce dims: PR5393 - ( + not is_extension_array_dtype(var.dtype) + and ( not reduce_dims or not numeric_only or np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_) ) - and not is_extension_array_dtype(var.dtype) ): # prefer to aggregate over axis=None rather than # axis=(0, 1) if they will be equivalent, because From e9dc53f239e0ea7f2c7d0520f3cb12b4bc14d555 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 12 Apr 2024 21:18:07 -0600 Subject: [PATCH 63/70] Apply suggestions from code review --- xarray/tests/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 1685e2f4824..3ce788dfb7f 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -115,7 +115,6 @@ def _importorskip( has_fsspec, requires_fsspec = _importorskip("fsspec") has_iris, requires_iris = _importorskip("iris") has_numbagg, requires_numbagg = _importorskip("numbagg", "0.4.0") -has_plum, requires_plum = _importorskip("plum") has_pyarrow, requires_pyarrow = _importorskip("pyarrow") with warnings.catch_warnings(): warnings.filterwarnings( From 4791799ad31d8fd4f65a04b5ebde7c926d391e34 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 16 Apr 2024 11:01:39 +0200 Subject: [PATCH 64/70] (chore): test `chunk` behavior --- xarray/tests/test_variable.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index db60ebd33e8..8a9345e74d4 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1582,6 +1582,14 @@ def test_pandas_cateogrical_dtype(self): print(v) # should not error assert pd.api.types.is_extension_array_dtype(v.dtype) + def test_pandas_cateogrical_no_chunk(self): + data = pd.Categorical(np.arange(10, dtype="int64")) + v = self.cls("x", data) + with pytest.raises( + ValueError, match=r".*was found to be a Pandas ExtensionArray.*" + ): + v.chunk((5,)) + def test_squeeze(self): v = Variable(["x", "y"], [[1]]) assert_identical(Variable([], 1), v.squeeze()) From 037408603903627b581b07f16aa77dd47815a6a8 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 16 Apr 2024 09:09:50 -0600 Subject: [PATCH 65/70] Update xarray/core/variable.py --- xarray/core/variable.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 2229eaa2d24..6a990ff40a7 100644 --- 
a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2575,11 +2575,6 @@ def chunk( # type: ignore[override] dask.array.from_array """ - if is_extension_array_dtype(self): - raise ValueError( - f"{self} was found to be a Pandas ExtensionArray. Please convert to numpy first." - ) - if from_array_kwargs is None: from_array_kwargs = {} From 72bf807d197ac6a1c7a46fe4aadc7d3f1953e34d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Apr 2024 15:11:30 +0000 Subject: [PATCH 66/70] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/variable.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 6a990ff40a7..95d6199d460 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -13,7 +13,6 @@ import numpy as np import pandas as pd from numpy.typing import ArrayLike -from pandas.api.types import is_extension_array_dtype import xarray as xr # only for Dataset and DataArray from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils From 63b6c4288d301ad35534554d17c9cf97613a0a24 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 17 Apr 2024 10:55:04 +0200 Subject: [PATCH 67/70] (fix): bring back error --- xarray/core/variable.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 95d6199d460..2229eaa2d24 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -13,6 +13,7 @@ import numpy as np import pandas as pd from numpy.typing import ArrayLike +from pandas.api.types import is_extension_array_dtype import xarray as xr # only for Dataset and DataArray from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils @@ -2574,6 +2575,11 @@ def chunk( # type: ignore[override] dask.array.from_array """ + if is_extension_array_dtype(self): + raise ValueError( + f"{self} was found to be a Pandas ExtensionArray. Please convert to numpy first." 
+ ) + if from_array_kwargs is None: from_array_kwargs = {} From 1d18439770dec649f93253943ad001388683e85d Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 17 Apr 2024 10:59:11 +0200 Subject: [PATCH 68/70] (chore): add test for dropping cat for mean --- xarray/tests/test_dataset.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c5f9d6b157c..a948fafc815 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5463,18 +5463,22 @@ def test_reduce_cumsum_test_dims(self, reduct, expected, func) -> None: assert list(actual) == expected def test_reduce_non_numeric(self) -> None: - data1 = create_test_data(seed=44) + data1 = create_test_data(seed=44, use_extension_array=True) data2 = create_test_data(seed=44) - add_vars = {"var4": ["dim1", "dim2"], "var5": ["dim1"]} + add_vars = {"var5": ["dim1", "dim2"], "var6": ["dim1"]} for v, dims in sorted(add_vars.items()): size = tuple(data1.sizes[d] for d in dims) data = np.random.randint(0, 100, size=size).astype(np.str_) data1[v] = (dims, data, {"foo": "variable"}) - - assert "var4" not in data1.mean() and "var5" not in data1.mean() + # var4 is extension array categorical and should be dropped + assert ( + "var4" not in data1.mean() + and "var5" not in data1.mean() + and "var6" not in data1.mean() + ) assert_equal(data1.mean(), data2.mean()) assert_equal(data1.mean(dim="dim1"), data2.mean(dim="dim1")) - assert "var4" not in data1.mean(dim="dim2") and "var5" in data1.mean(dim="dim2") + assert "var5" not in data1.mean(dim="dim2") and "var6" in data1.mean(dim="dim2") @pytest.mark.filterwarnings( "ignore:Once the behaviour of DataArray:DeprecationWarning" From 17f05da8e641384629c90410d4214335546aaf2b Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 17 Apr 2024 11:14:22 -0600 Subject: [PATCH 69/70] Update whats-new.rst --- doc/whats-new.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0d15b29b421..22515398d00 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -24,7 +24,11 @@ New Features ~~~~~~~~~~~~ - New "random" method for converting to and from 360_day calendars (:pull:`8603`). By `Pascal Bourgault `_. - +- Xarray now makes a best attempt not to coerce :py:class:`pandas.api.extensions.ExtensionArray` to a numpy array + by supporting 1D `ExtensionArray` objects internally where possible. Thus, `Dataset`s initialized with a `pd.Catgeorical`, + for example, will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray` + then, such as broadcasting. + By `Ilan Gold `_. Breaking changes ~~~~~~~~~~~~~~~~ @@ -154,10 +158,6 @@ New Features (:issue:`7377`, :pull:`8684`) By `Marco Wolsza `_. -- Xarray now makes a best attempt not to coerce :py:class:`pandas.api.extensions.ExtensionArray` to a numpy array -by supporting 1D `ExtensionArray` objects internally where possible. Thus, `Dataset`s initialized with a `pd.Catgeorical`, for example, -will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray` then, such as broadcasting. 
- Breaking changes ~~~~~~~~~~~~~~~~ From c906c812ba4f35aee5c92f424aa668b770a18fd0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 17 Apr 2024 17:15:05 +0000 Subject: [PATCH 70/70] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 22515398d00..a485c3fb5c2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,7 +25,7 @@ New Features - New "random" method for converting to and from 360_day calendars (:pull:`8603`). By `Pascal Bourgault `_. - Xarray now makes a best attempt not to coerce :py:class:`pandas.api.extensions.ExtensionArray` to a numpy array - by supporting 1D `ExtensionArray` objects internally where possible. Thus, `Dataset`s initialized with a `pd.Catgeorical`, + by supporting 1D `ExtensionArray` objects internally where possible. Thus, `Dataset`s initialized with a `pd.Catgeorical`, for example, will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray` then, such as broadcasting. By `Ilan Gold `_.
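Taken together, the series gives the behaviour described in the whats-new entry above. A small usage sketch, assuming an xarray build that includes these patches; the assertions mirror the tests added in PATCH 64/70 and PATCH 68/70, and exact error wording may differ between versions.

import numpy as np
import pandas as pd
import xarray as xr

# A Categorical column survives the DataFrame -> Dataset round trip as a
# pandas extension array instead of being coerced to object dtype.
df = pd.DataFrame(
    {
        "value": np.arange(4, dtype="int64"),
        "cat": pd.Categorical(["a", "b", "a", "b"]),
    },
    index=pd.Index(range(4), name="x"),
)
ds = xr.Dataset.from_dataframe(df)
assert pd.api.types.is_extension_array_dtype(ds["cat"].dtype)

# Reductions such as mean() drop the extension-array variable, just like
# other non-numeric variables (see the reduce() change in PATCH 62/70).
assert "cat" not in ds.mean()

# Operations that need a duck array, such as chunking, raise instead of
# silently coercing to numpy (see PATCH 67/70).
try:
    ds["cat"].variable.chunk((2,))
except ValueError as err:
    print(err)  # "... was found to be a Pandas ExtensionArray ..."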