diff --git a/asv_bench/benchmarks/combine.py b/asv_bench/benchmarks/combine.py
index 286746c75..dd3d7a178 100644
--- a/asv_bench/benchmarks/combine.py
+++ b/asv_bench/benchmarks/combine.py
@@ -1,4 +1,5 @@
 from functools import partial
+from typing import Any
 
 import numpy as np
 
@@ -43,8 +44,8 @@ class Combine1d(Combine):
     this is for reducting along a single dimension
     """
 
-    def setup(self, *args, **kwargs):
-        def construct_member(groups):
+    def setup(self, *args, **kwargs) -> None:
+        def construct_member(groups) -> dict[str, Any]:
             return {
                 "groups": groups,
                 "intermediates": [
diff --git a/flox/aggregations.py b/flox/aggregations.py
index 13b23fafe..e5013032a 100644
--- a/flox/aggregations.py
+++ b/flox/aggregations.py
@@ -185,7 +185,7 @@ def __init__(
         # how to aggregate results after first round of reduction
         self.combine: FuncTuple = _atleast_1d(combine)
         # simpler reductions used with the "simple combine" algorithm
-        self.simple_combine = None
+        self.simple_combine: tuple[Callable, ...] = ()
         # final aggregation
         self.aggregate: Callable | str = aggregate if aggregate else self.combine[0]
         # finalize results (see mean)
@@ -207,7 +207,7 @@ def __init__(
 
         # The following are set by _initialize_aggregation
         self.finalize_kwargs: dict[Any, Any] = {}
-        self.min_count: int | None = None
+        self.min_count: int = 0
 
     def _normalize_dtype_fill_value(self, value, name):
         value = _atleast_1d(value)
@@ -504,7 +504,7 @@ def _initialize_aggregation(
     dtype,
     array_dtype,
     fill_value,
-    min_count: int | None,
+    min_count: int,
     finalize_kwargs: dict[Any, Any] | None,
 ) -> Aggregation:
     if not isinstance(func, Aggregation):
@@ -559,9 +559,6 @@ def _initialize_aggregation(
         assert isinstance(finalize_kwargs, dict)
         agg.finalize_kwargs = finalize_kwargs
 
-    if min_count is None:
-        min_count = 0
-
     # This is needed for the dask pathway.
     # Because we use intermediate fill_value since a group could be
     # absent in one block, but present in another block
@@ -579,7 +576,7 @@ def _initialize_aggregation(
     else:
         agg.min_count = 0
 
-    simple_combine = []
+    simple_combine: list[Callable] = []
     for combine in agg.combine:
         if isinstance(combine, str):
            if combine in ["nanfirst", "nanlast"]:
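Note: the aggregations.py hunks above replace `None` sentinels with typed, falsy defaults (`()` for `simple_combine`, `0` for `min_count`), so downstream code can keep using plain truthiness checks without `Optional` narrowing. A minimal sketch of the pattern; the `Acc` class and `finalize` function below are illustrative stand-ins, not flox code:

```python
from __future__ import annotations

from typing import Callable


class Acc:
    def __init__(self) -> None:
        # Falsy typed defaults instead of None sentinels: no Optional,
        # so later code needs no None-checks to satisfy mypy.
        self.simple_combine: tuple[Callable, ...] = ()
        self.min_count: int = 0


def finalize(acc: Acc, values: list[int]) -> list[int]:
    if acc.min_count > 0:  # same truthiness test a None-check allowed
        values = [v for v in values if v >= acc.min_count]
    for func in acc.simple_combine:  # an empty tuple simply skips the loop
        values = [func(v) for v in values]
    return values
```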
diff --git a/flox/core.py b/flox/core.py
index 02f53b837..f821df9bc 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -51,11 +51,15 @@
     T_DuckArray = Union[np.ndarray, DaskArray]  # Any ?
     T_By = T_DuckArray
     T_Bys = tuple[T_By, ...]
-    T_ExpectIndex = Union[pd.Index, None]
-    T_Expect = Union[Sequence, np.ndarray, T_ExpectIndex]
+    T_ExpectIndex = Union[pd.Index]
     T_ExpectIndexTuple = tuple[T_ExpectIndex, ...]
+    T_ExpectIndexOpt = Union[T_ExpectIndex, None]
+    T_ExpectIndexOptTuple = tuple[T_ExpectIndexOpt, ...]
+    T_Expect = Union[Sequence, np.ndarray, T_ExpectIndex]
     T_ExpectTuple = tuple[T_Expect, ...]
-    T_ExpectedGroups = Union[T_Expect, T_ExpectTuple]
+    T_ExpectOpt = Union[Sequence, np.ndarray, T_ExpectIndexOpt]
+    T_ExpectOptTuple = tuple[T_ExpectOpt, ...]
+    T_ExpectedGroups = Union[T_Expect, T_ExpectOptTuple]
     T_ExpectedGroupsOpt = Union[T_ExpectedGroups, None]
     T_Func = Union[str, Callable]
     T_Funcs = Union[T_Func, Sequence[T_Func]]
@@ -98,7 +102,7 @@ def _is_first_last_reduction(func: T_Agg) -> bool:
     return isinstance(func, str) and func in ["nanfirst", "nanlast", "first", "last"]
 
 
-def _get_expected_groups(by: T_By, sort: bool) -> pd.Index:
+def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
     if is_duck_dask_array(by):
         raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
     flatby = by.reshape(-1)
@@ -219,8 +223,13 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
     raveled = labels.reshape(-1)
     # these are chunks where a label is present
     label_chunks = pd.Series(which_chunk).groupby(raveled).unique()
+
     # These invert the label_chunks mapping so we know which labels occur together.
-    chunks_cohorts = tlz.groupby(lambda x: tuple(label_chunks.get(x)), label_chunks.keys())
+    def invert(x) -> tuple[np.ndarray, ...]:
+        arr = label_chunks.get(x)
+        return tuple(arr)  # type: ignore [arg-type] # pandas issue?
+
+    chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
 
     if merge:
         # First sort by number of chunks occupied by cohort
@@ -459,7 +468,7 @@ def factorize_(
     axes: T_Axes,
     *,
     fastpath: Literal[True],
-    expected_groups: tuple[pd.Index, ...] | None = None,
+    expected_groups: T_ExpectIndexOptTuple | None = None,
     reindex: bool = False,
     sort: bool = True,
 ) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, None]:
@@ -471,7 +480,7 @@ def factorize_(
     by: T_Bys,
     axes: T_Axes,
     *,
-    expected_groups: tuple[pd.Index, ...] | None = None,
+    expected_groups: T_ExpectIndexOptTuple | None = None,
     reindex: bool = False,
     sort: bool = True,
     fastpath: Literal[False] = False,
@@ -484,7 +493,7 @@ def factorize_(
     by: T_Bys,
     axes: T_Axes,
     *,
-    expected_groups: tuple[pd.Index, ...] | None = None,
+    expected_groups: T_ExpectIndexOptTuple | None = None,
     reindex: bool = False,
     sort: bool = True,
     fastpath: bool = False,
@@ -496,7 +505,7 @@ def factorize_(
     by: T_Bys,
     axes: T_Axes,
     *,
-    expected_groups: tuple[pd.Index, ...] | None = None,
+    expected_groups: T_ExpectIndexOptTuple | None = None,
     reindex: bool = False,
     sort: bool = True,
     fastpath: bool = False,
@@ -546,7 +555,7 @@ def factorize_(
             else:
                 idx = np.zeros_like(flat, dtype=np.intp) - 1
 
-            found_groups.append(expect)
+            found_groups.append(np.array(expect))
         else:
             if expect is not None and reindex:
                 sorter = np.argsort(expect)
@@ -560,7 +569,7 @@ def factorize_(
                     idx = sorter[(idx,)]
                 idx[mask] = -1
             else:
-                idx, groups = pd.factorize(flat, sort=sort)
+                idx, groups = pd.factorize(flat, sort=sort)  # type: ignore # pandas issue?
 
            found_groups.append(np.array(groups))
 
        factorized.append(idx.reshape(groupvar.shape))
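The `find_group_cohorts` hunk above swaps an inline lambda for a named `invert` function: a lambda cannot carry annotations, while a named function gives the grouping key a return type and a precise target for the `type: ignore`. A self-contained sketch of the same pattern, with a plain dict of lists standing in for the `pd.Series` of chunk arrays (`label_chunks` here is made-up data):

```python
import toolz as tlz

# label -> chunks in which that label occurs (stand-in for label_chunks)
label_chunks = {"a": [0, 1], "b": [0, 1], "c": [2]}


def invert(label: str) -> tuple[int, ...]:
    # A named function can be annotated; the lambda it replaces could not.
    return tuple(label_chunks[label])


# Labels that occupy the same set of chunks end up in one cohort.
chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
print(chunks_cohorts)  # {(0, 1): ['a', 'b'], (2,): ['c']}
```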
@@ -853,7 +862,8 @@ def _finalize_results(
     """
     squeezed = _squeeze_results(results, axis)
 
-    if agg.min_count > 0:
+    min_count = agg.min_count
+    if min_count > 0:
         counts = squeezed["intermediates"][-1]
         squeezed["intermediates"] = squeezed["intermediates"][:-1]
 
@@ -864,8 +874,8 @@ def _finalize_results(
     else:
         finalized[agg.name] = agg.finalize(*squeezed["intermediates"], **agg.finalize_kwargs)
 
-    if agg.min_count > 0:
-        count_mask = counts < agg.min_count
+    if min_count > 0:
+        count_mask = counts < min_count
         if count_mask.any():
             # For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
             # necessary
@@ -1283,7 +1293,7 @@ def dask_groupby_agg(
     array: DaskArray,
     by: T_By,
     agg: Aggregation,
-    expected_groups: pd.Index | None,
+    expected_groups: T_ExpectIndexOpt,
     axis: T_Axes = (),
     fill_value: Any = None,
     method: T_Method = "map-reduce",
@@ -1423,9 +1433,11 @@ def dask_groupby_agg(
             group_chunks = ((np.nan,),)
         else:
             if expected_groups is None:
-                expected_groups = _get_expected_groups(by_input, sort=sort)
-            groups = (expected_groups.to_numpy(),)
-            group_chunks = ((len(expected_groups),),)
+                expected_groups_ = _get_expected_groups(by_input, sort=sort)
+            else:
+                expected_groups_ = expected_groups
+            groups = (expected_groups_.to_numpy(),)
+            group_chunks = ((len(expected_groups_),),)
 
     elif method == "cohorts":
         chunks_cohorts = find_group_cohorts(
@@ -1569,7 +1581,7 @@ def _validate_reindex(
     return reindex
 
 
-def _assert_by_is_aligned(shape: tuple[int, ...], by: T_Bys):
+def _assert_by_is_aligned(shape: tuple[int, ...], by: T_Bys) -> None:
     assert all(b.ndim == by[0].ndim for b in by[1:])
     for idx, b in enumerate(by):
         if not all(j in [i, 1] for i, j in zip(shape[-b.ndim :], b.shape)):
@@ -1584,18 +1596,33 @@ def _assert_by_is_aligned(shape: tuple[int, ...], by: T_Bys):
             )
 
 
+@overload
+def _convert_expected_groups_to_index(
+    expected_groups: tuple[None, ...], isbin: Sequence[bool], sort: bool
+) -> tuple[None, ...]:
+    ...
+
+
+@overload
 def _convert_expected_groups_to_index(
     expected_groups: T_ExpectTuple, isbin: Sequence[bool], sort: bool
 ) -> T_ExpectIndexTuple:
-    out: list[pd.Index | None] = []
+    ...
+
+
+def _convert_expected_groups_to_index(
+    expected_groups: T_ExpectOptTuple, isbin: Sequence[bool], sort: bool
+) -> T_ExpectIndexOptTuple:
+    out: list[T_ExpectIndexOpt] = []
     for ex, isbin_ in zip(expected_groups, isbin):
         if isinstance(ex, pd.IntervalIndex) or (isinstance(ex, pd.Index) and not isbin_):
             if sort:
-                ex = ex.sort_values()
-            out.append(ex)
+                out.append(ex.sort_values())
+            else:
+                out.append(ex)
         elif ex is not None:
             if isbin_:
-                out.append(pd.IntervalIndex.from_breaks(ex))
+                out.append(pd.IntervalIndex.from_breaks(ex))  # type: ignore [arg-type] # TODO: what do we want here?
             else:
                 if sort:
                     ex = np.sort(ex)
@@ -1613,7 +1640,7 @@ def _lazy_factorize_wrapper(*by: T_By, **kwargs) -> np.ndarray:
 
 def _factorize_multiple(
     by: T_Bys,
-    expected_groups: T_ExpectIndexTuple,
+    expected_groups: T_ExpectIndexOptTuple,
     any_by_dask: bool,
     reindex: bool,
     sort: bool = True,
@@ -1668,7 +1695,17 @@ def _factorize_multiple(
     return (group_idx,), found_groups, grp_shape
 
 
-def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> T_ExpectTuple:
+@overload
+def _validate_expected_groups(nby: int, expected_groups: None) -> tuple[None, ...]:
+    ...
+
+
+@overload
+def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroups) -> T_ExpectTuple:
+    ...
+
+
+def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> T_ExpectOptTuple:
     if expected_groups is None:
         return (None,) * nby
 
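The `@overload` pairs added above all encode the same contract: `None` in, tuple of `None` out; concrete expected groups in, concrete tuple out. That lets call sites keep precise types from a single implementation, with no casts. A self-contained sketch of the shape, using simplified `str`-based signatures rather than the flox aliases:

```python
from __future__ import annotations

from typing import overload


@overload
def validate(nby: int, expected: None) -> tuple[None, ...]:
    ...


@overload
def validate(nby: int, expected: str | tuple[str, ...]) -> tuple[str, ...]:
    ...


def validate(nby: int, expected: str | tuple[str, ...] | None) -> tuple[str | None, ...]:
    # One implementation; the overloads above narrow the return type
    # mypy reports at each call site.
    if expected is None:
        return (None,) * nby
    if isinstance(expected, str):
        return (expected,) * nby
    return expected


nothing: tuple[None, ...] = validate(2, None)      # narrowed to tuple[None, ...]
names: tuple[str, ...] = validate(2, ("a", "b"))   # narrowed to tuple[str, ...]
```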
@@ -1935,21 +1972,20 @@ def groupby_reduce(
         # Consider np.sum([np.nan]) = np.nan, np.nansum([np.nan]) = 0
         if min_count is None:
             if nax < by_.ndim or fill_value is not None:
-                min_count = 1
+                min_count_: int = 1
+            else:
+                min_count_ = 0
+        else:
+            min_count_ = min_count
 
         # TODO: set in xarray?
-        if (
-            min_count is not None
-            and min_count > 0
-            and func in ["nansum", "nanprod"]
-            and fill_value is None
-        ):
+        if min_count_ > 0 and func in ["nansum", "nanprod"] and fill_value is None:
             # nansum, nanprod have fill_value=0, 1
             # overwrite than when min_count is set
             fill_value = np.nan
 
     kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
-    agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count, finalize_kwargs)
+    agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count_, finalize_kwargs)
 
     groups: tuple[np.ndarray | DaskArray, ...]
     if not has_dask:
diff --git a/flox/xarray.py b/flox/xarray.py
index df8f773df..acd3f2d6c 100644
--- a/flox/xarray.py
+++ b/flox/xarray.py
@@ -22,6 +22,8 @@
 if TYPE_CHECKING:
     from xarray.core.types import T_DataArray, T_Dataset
 
+    from .core import T_ExpectedGroupsOpt, T_ExpectIndex, T_ExpectOpt
+
     Dims = Union[str, Iterable[Hashable], None]
 
 
@@ -63,7 +65,7 @@ def xarray_reduce(
     obj: T_Dataset | T_DataArray,
     *by: T_DataArray | Hashable,
     func: str | Aggregation,
-    expected_groups=None,
+    expected_groups: T_ExpectedGroupsOpt = None,
     isbin: bool | Sequence[bool] = False,
     sort: bool = True,
     dim: Dims | ellipsis = None,
@@ -215,7 +217,7 @@ def xarray_reduce(
     else:
         isbins = (isbin,) * nby
 
-    expected_groups = _validate_expected_groups(nby, expected_groups)
+    expected_groups_valid = _validate_expected_groups(nby, expected_groups)
 
     if not sort:
         raise NotImplementedError("sort must be True for xarray_reduce")
@@ -313,10 +315,10 @@ def xarray_reduce(
 
     # Set expected_groups and convert to index since we need coords, sizes
     # for output xarray objects
-    expected_groups = list(expected_groups)
+    expected_groups_valid_list: list[T_ExpectIndex] = []
     group_names: tuple[Any, ...] = ()
     group_sizes: dict[Any, int] = {}
-    for idx, (b_, expect, isbin_) in enumerate(zip(by_da, expected_groups, isbins)):
+    for idx, (b_, expect, isbin_) in enumerate(zip(by_da, expected_groups_valid, isbins)):
         group_name = (
             f"{b_.name}_bins" if isbin_ or isinstance(expect, pd.IntervalIndex) else b_.name
         )
@@ -326,21 +328,23 @@ def xarray_reduce(
             raise NotImplementedError(
                 "flox does not support binning into an integer number of bins yet."
             )
+
+        expect1: T_ExpectOpt
         if expect is None:
             if isbin_:
                 raise ValueError(
                     f"Please provided bin edges for group variable {idx} "
                     f"named {group_name} in expected_groups."
                 )
-            expect_ = _get_expected_groups(b_.data, sort=sort)
+            expect1 = _get_expected_groups(b_.data, sort=sort)
         else:
-            expect_ = expect
-        expect_index = _convert_expected_groups_to_index((expect_,), (isbin_,), sort=sort)[0]
+            expect1 = expect
+        expect_index = _convert_expected_groups_to_index((expect1,), (isbin_,), sort=sort)[0]
 
         # The if-check is for type hinting mainly, it narrows down the return
         # type of _convert_expected_groups_to_index to pure pd.Index:
         if expect_index is not None:
-            expected_groups[idx] = expect_index
+            expected_groups_valid_list.append(expect_index)
             group_sizes[group_name] = len(expect_index)
         else:
             # This will never be reached
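The `min_count_`, `expected_groups_valid`, and `expect1` renames above follow one recipe: rather than re-binding a name whose declared type is a union, which leaves the broad declared type in place and makes later narrowing fragile, each narrowing step gets a fresh, precisely typed name. A toy illustration of the pattern; the names are made up:

```python
from __future__ import annotations


def resolve_min_count(min_count: int | None, has_fill: bool) -> int:
    # Re-binding `min_count` would keep its declared `int | None` type;
    # a new `int`-typed name settles the question once and for all.
    if min_count is None:
        min_count_: int = 1 if has_fill else 0
    else:
        min_count_ = min_count
    return min_count_  # plainly int from here on
```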
@@ -426,7 +430,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
         "skipna": skipna,
         "engine": engine,
         "reindex": reindex,
-        "expected_groups": tuple(expected_groups),
+        "expected_groups": tuple(expected_groups_valid_list),
         "isbin": isbins,
         "finalize_kwargs": finalize_kwargs,
         "dtype": dtype,
@@ -440,10 +444,15 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
         if all(d not in ds_broad[var].dims for d in dim_tuple):
             actual[var] = ds_broad[var]
 
-    for name, expect, by_ in zip(group_names, expected_groups, by_da):
-        # Can't remove this till xarray handles IntervalIndex
-        if isinstance(expect, pd.IntervalIndex):
-            expect = expect.to_numpy()
+    expect3: T_ExpectIndex | np.ndarray
+    for name, expect2, by_ in zip(group_names, expected_groups_valid_list, by_da):
+        # Can't remove this until xarray handles IntervalIndex:
+        if isinstance(expect2, pd.IntervalIndex):
+            # TODO: Only place where expect3 is an ndarray, remove the type if xarray
+            # starts supporting IntervalIndex.
+            expect3 = expect2.to_numpy()
+        else:
+            expect3 = expect2
         if isinstance(actual, xr.Dataset) and name in actual:
             actual = actual.drop_vars(name)
         # When grouping by MultiIndex, expect is an pd.Index wrapping
@@ -451,15 +460,18 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):
         if (
             name in ds_broad.indexes
             and isinstance(ds_broad.indexes[name], pd.MultiIndex)
-            and not isinstance(expect, pd.RangeIndex)
+            and not isinstance(expect3, pd.RangeIndex)
         ):
             levelnames = ds_broad.indexes[name].names
-            expect = pd.MultiIndex.from_tuples(expect.values, names=levelnames)
-            actual[name] = expect
+            if isinstance(expect3, np.ndarray):
+                # TODO: workaround for IntervalIndex issue.
+                raise NotImplementedError
+            expect3 = pd.MultiIndex.from_tuples(expect3.values, names=levelnames)
+            actual[name] = expect3
             if Version(xr.__version__) > Version("2022.03.0"):
                 actual = actual.set_coords(levelnames)
         else:
-            actual[name] = expect
+            actual[name] = expect3
 
         if keep_attrs:
             actual[name].attrs = by_.attrs
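With `expected_groups` now typed end to end, public calls should type-check cleanly. A hedged usage sketch against flox's public `xarray_reduce` (requires flox and xarray installed; the data here is made up):

```python
import numpy as np
import pandas as pd
import xarray as xr

from flox.xarray import xarray_reduce

da = xr.DataArray(
    np.arange(4.0),
    dims="x",
    coords={"labels": ("x", ["a", "b", "a", "b"])},
)

# A pd.Index is one of the accepted T_ExpectedGroupsOpt forms.
result = xarray_reduce(da, "labels", func="mean", expected_groups=pd.Index(["a", "b"]))
print(result)  # mean per label, with a "labels" dimension of size 2
```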