typing fixes (#235)

* Update core.py

* Update xarray.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* avoid renaming

* Update xarray.py

* Update xarray.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update xarray.py

* Update xarray.py

* Update xarray.py

* Update xarray.py

* Update xarray.py

* split to optional

* Update xarray.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update xarray.py

* convert to pd.Index instead of ndarray

* Handled different slicer types?

* not supported instead?

* specify type for simple_combine

* Handle None in agg.min_count

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* add overloads and rename

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more overloads

* ignore

* Update core.py

* Update xarray.py

* Update core.py

* Update core.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update flox/core.py

* Update flox/core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* Update core.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update xarray.py

* Have to add another type here because of xarray not supporting IntervalIndex

* Update xarray.py

* test ex instead of e

* Revert "test ex instead of e"

This reverts commit 8e55d3a.

* check reveal_type

* without e

* try no redefinition

* If redefining ex, mypy always takes the first definition of ex, even if it has been narrowed down (see the sketch at the end of this list).

* test min_count=0

* test min_count=0

* test min_count=0

* test min_count=0

* test min_count = 0

* test min_count=0

* test min_count=0

* test min_count=0

* test min_count=0

* test min_count=0

* test min_count=0

* test min_count=0

* test min_count=0

* Update asv_bench/benchmarks/combine.py
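
Below is a minimal sketch of the mypy behavior noted in the "redefining ex" item above; `to_sorted_index` is a hypothetical stand-in, not code from this PR. Mypy pins a variable's type at its first binding, so a value narrowed by `isinstance` and then rebound to the same name is widened back to the declared type. Using the narrowed result directly, as the diff does with `out.append(ex.sort_values())`, sidesteps the redefinition.

    import numpy as np
    import pandas as pd

    def to_sorted_index(ex: "pd.Index | np.ndarray") -> pd.Index:
        if isinstance(ex, pd.Index):
            # use the narrowed value without rebinding `ex`
            return ex.sort_values()
        return pd.Index(np.sort(ex))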

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
3 people committed Jul 3, 2023
1 parent 6cf315a commit 66f152b

Showing 4 changed files with 105 additions and 59 deletions.
5 changes: 3 additions & 2 deletions asv_bench/benchmarks/combine.py
@@ -1,4 +1,5 @@
 from functools import partial
+from typing import Any

 import numpy as np

@@ -43,8 +44,8 @@ class Combine1d(Combine):
     this is for reducting along a single dimension
     """

-    def setup(self, *args, **kwargs):
-        def construct_member(groups):
+    def setup(self, *args, **kwargs) -> None:
+        def construct_member(groups) -> dict[str, Any]:
             return {
                 "groups": groups,
                 "intermediates": [
11 changes: 4 additions & 7 deletions flox/aggregations.py
@@ -185,7 +185,7 @@ def __init__(
         # how to aggregate results after first round of reduction
         self.combine: FuncTuple = _atleast_1d(combine)
         # simpler reductions used with the "simple combine" algorithm
-        self.simple_combine = None
+        self.simple_combine: tuple[Callable, ...] = ()
         # final aggregation
         self.aggregate: Callable | str = aggregate if aggregate else self.combine[0]
         # finalize results (see mean)
@@ -207,7 +207,7 @@ def __init__(

         # The following are set by _initialize_aggregation
         self.finalize_kwargs: dict[Any, Any] = {}
-        self.min_count: int | None = None
+        self.min_count: int = 0

     def _normalize_dtype_fill_value(self, value, name):
         value = _atleast_1d(value)
@@ -504,7 +504,7 @@ def _initialize_aggregation(
     dtype,
     array_dtype,
     fill_value,
-    min_count: int | None,
+    min_count: int,
     finalize_kwargs: dict[Any, Any] | None,
 ) -> Aggregation:
     if not isinstance(func, Aggregation):
@@ -559,9 +559,6 @@
         assert isinstance(finalize_kwargs, dict)
         agg.finalize_kwargs = finalize_kwargs

-    if min_count is None:
-        min_count = 0
-
     # This is needed for the dask pathway.
     # Because we use intermediate fill_value since a group could be
     # absent in one block, but present in another block
@@ -579,7 +576,7 @@
     else:
         agg.min_count = 0

-    simple_combine = []
+    simple_combine: list[Callable] = []
     for combine in agg.combine:
         if isinstance(combine, str):
             if combine in ["nanfirst", "nanlast"]:
102 changes: 69 additions & 33 deletions flox/core.py
@@ -51,11 +51,15 @@
 T_DuckArray = Union[np.ndarray, DaskArray]  # Any ?
 T_By = T_DuckArray
 T_Bys = tuple[T_By, ...]
-T_ExpectIndex = Union[pd.Index, None]
-T_Expect = Union[Sequence, np.ndarray, T_ExpectIndex]
+T_ExpectIndex = Union[pd.Index]
+T_ExpectIndexTuple = tuple[T_ExpectIndex, ...]
+T_ExpectIndexOpt = Union[T_ExpectIndex, None]
+T_ExpectIndexOptTuple = tuple[T_ExpectIndexOpt, ...]
+T_Expect = Union[Sequence, np.ndarray, T_ExpectIndex]
 T_ExpectTuple = tuple[T_Expect, ...]
-T_ExpectedGroups = Union[T_Expect, T_ExpectTuple]
+T_ExpectOpt = Union[Sequence, np.ndarray, T_ExpectIndexOpt]
+T_ExpectOptTuple = tuple[T_ExpectOpt, ...]
+T_ExpectedGroups = Union[T_Expect, T_ExpectOptTuple]
 T_ExpectedGroupsOpt = Union[T_ExpectedGroups, None]
 T_Func = Union[str, Callable]
 T_Funcs = Union[T_Func, Sequence[T_Func]]
@@ -98,7 +102,7 @@ def _is_first_last_reduction(func: T_Agg) -> bool:
     return isinstance(func, str) and func in ["nanfirst", "nanlast", "first", "last"]


-def _get_expected_groups(by: T_By, sort: bool) -> pd.Index:
+def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
     if is_duck_dask_array(by):
         raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
     flatby = by.reshape(-1)
@@ -219,8 +223,13 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
     raveled = labels.reshape(-1)
     # these are chunks where a label is present
     label_chunks = pd.Series(which_chunk).groupby(raveled).unique()

     # These invert the label_chunks mapping so we know which labels occur together.
-    chunks_cohorts = tlz.groupby(lambda x: tuple(label_chunks.get(x)), label_chunks.keys())
+    def invert(x) -> tuple[np.ndarray, ...]:
+        arr = label_chunks.get(x)
+        return tuple(arr)  # type: ignore [arg-type] # pandas issue?
+
+    chunks_cohorts = tlz.groupby(invert, label_chunks.keys())

     if merge:
         # First sort by number of chunks occupied by cohort
@@ -459,7 +468,7 @@ def factorize_(
     axes: T_Axes,
     *,
     fastpath: Literal[True],
-    expected_groups: tuple[pd.Index, ...] | None = None,
+    expected_groups: T_ExpectIndexOptTuple | None = None,
     reindex: bool = False,
     sort: bool = True,
 ) -> tuple[np.ndarray, tuple[np.ndarray, ...], tuple[int, ...], int, int, None]:
@@ -471,7 +480,7 @@
     by: T_Bys,
     axes: T_Axes,
     *,
-    expected_groups: tuple[pd.Index, ...] | None = None,
+    expected_groups: T_ExpectIndexOptTuple | None = None,
     reindex: bool = False,
     sort: bool = True,
     fastpath: Literal[False] = False,
@@ -484,7 +493,7 @@
     by: T_Bys,
     axes: T_Axes,
     *,
-    expected_groups: tuple[pd.Index, ...] | None = None,
+    expected_groups: T_ExpectIndexOptTuple | None = None,
     reindex: bool = False,
     sort: bool = True,
     fastpath: bool = False,
@@ -496,7 +505,7 @@
     by: T_Bys,
     axes: T_Axes,
     *,
-    expected_groups: tuple[pd.Index, ...] | None = None,
+    expected_groups: T_ExpectIndexOptTuple | None = None,
     reindex: bool = False,
     sort: bool = True,
     fastpath: bool = False,
@@ -546,7 +555,7 @@ def factorize_(
             else:
                 idx = np.zeros_like(flat, dtype=np.intp) - 1

-            found_groups.append(expect)
+            found_groups.append(np.array(expect))
         else:
             if expect is not None and reindex:
                 sorter = np.argsort(expect)
@@ -560,7 +569,7 @@
                     idx = sorter[(idx,)]
                 idx[mask] = -1
             else:
-                idx, groups = pd.factorize(flat, sort=sort)
+                idx, groups = pd.factorize(flat, sort=sort)  # type: ignore # pandas issue?

             found_groups.append(np.array(groups))
         factorized.append(idx.reshape(groupvar.shape))
@@ -853,7 +862,8 @@ def _finalize_results(
     """
     squeezed = _squeeze_results(results, axis)

-    if agg.min_count > 0:
+    min_count = agg.min_count
+    if min_count > 0:
         counts = squeezed["intermediates"][-1]
         squeezed["intermediates"] = squeezed["intermediates"][:-1]

@@ -864,8 +874,8 @@
     else:
         finalized[agg.name] = agg.finalize(*squeezed["intermediates"], **agg.finalize_kwargs)

-    if agg.min_count > 0:
-        count_mask = counts < agg.min_count
+    if min_count > 0:
+        count_mask = counts < min_count
         if count_mask.any():
             # For one count_mask.any() prevents promoting bool to dtype(fill_value) unless
             # necessary
@@ -1283,7 +1293,7 @@ def dask_groupby_agg(
     array: DaskArray,
     by: T_By,
     agg: Aggregation,
-    expected_groups: pd.Index | None,
+    expected_groups: T_ExpectIndexOpt,
     axis: T_Axes = (),
     fill_value: Any = None,
     method: T_Method = "map-reduce",
@@ -1423,9 +1433,11 @@ def dask_groupby_agg(
             group_chunks = ((np.nan,),)
         else:
             if expected_groups is None:
-                expected_groups = _get_expected_groups(by_input, sort=sort)
-            groups = (expected_groups.to_numpy(),)
-            group_chunks = ((len(expected_groups),),)
+                expected_groups_ = _get_expected_groups(by_input, sort=sort)
+            else:
+                expected_groups_ = expected_groups
+            groups = (expected_groups_.to_numpy(),)
+            group_chunks = ((len(expected_groups_),),)

     elif method == "cohorts":
         chunks_cohorts = find_group_cohorts(
@@ -1569,7 +1581,7 @@ def _validate_reindex(
     return reindex


-def _assert_by_is_aligned(shape: tuple[int, ...], by: T_Bys):
+def _assert_by_is_aligned(shape: tuple[int, ...], by: T_Bys) -> None:
     assert all(b.ndim == by[0].ndim for b in by[1:])
     for idx, b in enumerate(by):
         if not all(j in [i, 1] for i, j in zip(shape[-b.ndim :], b.shape)):
@@ -1584,18 +1596,33 @@ def _assert_by_is_aligned(shape: tuple[int, ...], by: T_Bys):
             )


+@overload
+def _convert_expected_groups_to_index(
+    expected_groups: tuple[None, ...], isbin: Sequence[bool], sort: bool
+) -> tuple[None, ...]:
+    ...
+
+
+@overload
 def _convert_expected_groups_to_index(
     expected_groups: T_ExpectTuple, isbin: Sequence[bool], sort: bool
 ) -> T_ExpectIndexTuple:
-    out: list[pd.Index | None] = []
+    ...
+
+
+def _convert_expected_groups_to_index(
+    expected_groups: T_ExpectOptTuple, isbin: Sequence[bool], sort: bool
+) -> T_ExpectIndexOptTuple:
+    out: list[T_ExpectIndexOpt] = []
     for ex, isbin_ in zip(expected_groups, isbin):
         if isinstance(ex, pd.IntervalIndex) or (isinstance(ex, pd.Index) and not isbin_):
             if sort:
-                ex = ex.sort_values()
-            out.append(ex)
+                out.append(ex.sort_values())
+            else:
+                out.append(ex)
         elif ex is not None:
             if isbin_:
-                out.append(pd.IntervalIndex.from_breaks(ex))
+                out.append(pd.IntervalIndex.from_breaks(ex))  # type: ignore [arg-type] # TODO: what do we want here?
             else:
                 if sort:
                     ex = np.sort(ex)
@@ -1613,7 +1640,7 @@ def _lazy_factorize_wrapper(*by: T_By, **kwargs) -> np.ndarray:


 def _factorize_multiple(
     by: T_Bys,
-    expected_groups: T_ExpectIndexTuple,
+    expected_groups: T_ExpectIndexOptTuple,
     any_by_dask: bool,
     reindex: bool,
     sort: bool = True,
@@ -1668,7 +1695,17 @@ def _factorize_multiple(
     return (group_idx,), found_groups, grp_shape


-def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> T_ExpectTuple:
+@overload
+def _validate_expected_groups(nby: int, expected_groups: None) -> tuple[None, ...]:
+    ...
+
+
+@overload
+def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroups) -> T_ExpectTuple:
+    ...
+
+
+def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> T_ExpectOptTuple:
     if expected_groups is None:
         return (None,) * nby
@@ -1935,21 +1972,20 @@ def groupby_reduce(
     # Consider np.sum([np.nan]) = np.nan, np.nansum([np.nan]) = 0
     if min_count is None:
         if nax < by_.ndim or fill_value is not None:
-            min_count = 1
+            min_count_: int = 1
+        else:
+            min_count_ = 0
+    else:
+        min_count_ = min_count

     # TODO: set in xarray?
-    if (
-        min_count is not None
-        and min_count > 0
-        and func in ["nansum", "nanprod"]
-        and fill_value is None
-    ):
+    if min_count_ > 0 and func in ["nansum", "nanprod"] and fill_value is None:
         # nansum, nanprod have fill_value=0, 1
         # overwrite than when min_count is set
         fill_value = np.nan

     kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
-    agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count, finalize_kwargs)
+    agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count_, finalize_kwargs)

     groups: tuple[np.ndarray | DaskArray, ...]
     if not has_dask:
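
A standalone sketch of the @overload pattern the flox/core.py changes apply to _validate_expected_groups and _convert_expected_groups_to_index; `expand` is a hypothetical stand-in, not flox API. The overloads let mypy map a `None` argument to `tuple[None, ...]` and a real value to a concrete tuple, so None-ness is tracked through the call instead of collapsing into a single Optional return type.

    from typing import overload

    @overload
    def expand(nby: int, groups: None) -> tuple[None, ...]: ...
    @overload
    def expand(nby: int, groups: str) -> tuple[str, ...]: ...

    def expand(nby: int, groups: "str | None") -> "tuple[str | None, ...]":
        # single runtime implementation behind the two typed signatures
        return (groups,) * nby

    nothing = expand(2, None)  # mypy infers tuple[None, ...]
    labels = expand(2, "x")    # mypy infers tuple[str, ...]
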
(diff for the fourth changed file not shown)
