Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use engine flox for ordered groups #266

Merged
merged 34 commits into from
Oct 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
ef6bbc4
use engine flox for ordered groups
mathause Sep 29, 2023
e0ea569
Add issorted helper func
dcherian Oct 7, 2023
d4e30d8
Some fixes
dcherian Oct 7, 2023
c8000e3
In xarray too
dcherian Oct 7, 2023
68edd74
Merge remote-tracking branch 'upstream/main' into engine_none
dcherian Oct 7, 2023
c483a1a
formatting
dcherian Oct 7, 2023
3e4bae9
simplify
dcherian Oct 7, 2023
d14299a
Merge remote-tracking branch 'upstream/main' into engine_none
dcherian Oct 7, 2023
92eaf2c
retry
dcherian Oct 10, 2023
1a230eb
flox
dcherian Oct 10, 2023
22bd20c
minversion numabgg
dcherian Oct 10, 2023
699ecd9
cleanup
dcherian Oct 10, 2023
3d6413f
fix type
dcherian Oct 10, 2023
7061fda
Merge branch 'main' into engine_none
dcherian Oct 11, 2023
443fbc2
update gitignore
dcherian Oct 11, 2023
eedeb8e
add types
dcherian Oct 11, 2023
cc90509
Fix env?
dcherian Oct 11, 2023
56957a1
fix
dcherian Oct 11, 2023
16d4393
fix merge
dcherian Oct 11, 2023
f6262d2
cleanup
dcherian Oct 11, 2023
12dc816
Merge branch 'main' into engine_none
dcherian Oct 11, 2023
e13c2d9
[skip-ci] bench
dcherian Oct 11, 2023
a7cd95c
Merge branch 'main' into engine_none
dcherian Oct 11, 2023
46262f0
temporarily disable numbagg
dcherian Oct 12, 2023
ed8bd51
don't cache env
dcherian Oct 12, 2023
db0a6cc
Finally!
dcherian Oct 12, 2023
90cecf8
bugfix
dcherian Oct 14, 2023
f2e0aa6
Fix doctest
dcherian Oct 14, 2023
39c19b8
more fixes
dcherian Oct 14, 2023
7126e60
Fix CI
dcherian Oct 14, 2023
1243903
readd numbagg
dcherian Oct 14, 2023
524f540
Merge remote-tracking branch 'upstream/main' into engine_none
dcherian Oct 14, 2023
b70b0e6
Merge branch 'main' into engine_none
dcherian Oct 14, 2023
3534766
Fix.
dcherian Oct 15, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/ci-additional.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ jobs:
conda list
- name: Run doctests
run: |
python -m pytest --doctest-modules flox --ignore flox/tests --cov=./ --cov-report=xml
python -m pytest --doctest-modules \
flox/aggregations.py flox/core.py flox/xarray.py \
--ignore flox/tests \
--cov=./ --cov-report=xml
- name: Upload code coverage to Codecov
uses: codecov/codecov-action@v3.1.4
with:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
asv_bench/pkgs/
docs/source/generated/
html/
.asv/
Expand Down
2 changes: 1 addition & 1 deletion asv_bench/benchmarks/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

N = 3000
funcs = ["sum", "nansum", "mean", "nanmean", "max", "nanmax", "count"]
engines = ["flox", "numpy", "numbagg"]
engines = [None, "flox", "numpy", "numbagg"]
expected_groups = {
"None": None,
"bins": pd.IntervalIndex.from_breaks([1, 2, 4]),
Expand Down
1 change: 1 addition & 0 deletions ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ dependencies:
- pooch
- toolz
- numba
- numbagg>=0.3
- scipy
2 changes: 2 additions & 0 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class AggDtypeInit(TypedDict):


class AggDtype(TypedDict):
user: DTypeLike | None
final: np.dtype
numpy: tuple[np.dtype | type[np.intp], ...]
intermediate: tuple[np.dtype | type[np.intp], ...]
Expand Down Expand Up @@ -569,6 +570,7 @@ def _initialize_aggregation(

final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
agg.dtype = {
"user": dtype, # Save to automatically choose an engine
"final": final_dtype,
"numpy": (final_dtype,),
"intermediate": tuple(
Expand Down
45 changes: 41 additions & 4 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
generic_aggregate,
)
from .cache import memoize
from .xrutils import is_duck_array, is_duck_dask_array, isnull
from .xrutils import is_duck_array, is_duck_dask_array, isnull, module_available

HAS_NUMBAGG = module_available("numbagg", minversion="0.3.0")

if TYPE_CHECKING:
try:
Expand Down Expand Up @@ -69,6 +71,7 @@
T_Dtypes = Union[np.typing.DTypeLike, Sequence[np.typing.DTypeLike], None]
T_FillValues = Union[np.typing.ArrayLike, Sequence[np.typing.ArrayLike], None]
T_Engine = Literal["flox", "numpy", "numba", "numbagg"]
T_EngineOpt = None | T_Engine
T_Method = Literal["map-reduce", "blockwise", "cohorts"]
T_IsBins = Union[bool | Sequence[bool]]

Expand All @@ -83,6 +86,10 @@
DUMMY_AXIS = -2


def _issorted(arr: np.ndarray) -> bool:
    """Return True when ``arr`` is in non-decreasing order along its first axis.

    Each element is compared with its successor; equal neighbours count as
    sorted. Arrays with fewer than two elements have no pairs to compare and
    are therefore reported as sorted.
    """
    current, following = arr[:-1], arr[1:]
    in_order = current <= following
    return bool(in_order.all())


def _is_arg_reduction(func: T_Agg) -> bool:
if isinstance(func, str) and func in ["argmin", "argmax", "nanargmax", "nanargmin"]:
return True
Expand Down Expand Up @@ -632,6 +639,7 @@ def chunk_argreduce(
reindex: bool = False,
engine: T_Engine = "numpy",
sort: bool = True,
user_dtype=None,
) -> IntermediateDict:
"""
Per-chunk arg reduction.
Expand All @@ -652,6 +660,7 @@ def chunk_argreduce(
dtype=dtype,
engine=engine,
sort=sort,
user_dtype=user_dtype,
)
if not isnull(results["groups"]).all():
idx = np.broadcast_to(idx, array.shape)
Expand Down Expand Up @@ -685,6 +694,7 @@ def chunk_reduce(
engine: T_Engine = "numpy",
kwargs: Sequence[dict] | None = None,
sort: bool = True,
user_dtype=None,
) -> IntermediateDict:
"""
Wrapper for numpy_groupies aggregate that supports nD ``array`` and
Expand Down Expand Up @@ -785,6 +795,7 @@ def chunk_reduce(
group_idx = group_idx.reshape(-1)

assert group_idx.ndim == 1

empty = np.all(props.nanmask)

results: IntermediateDict = {"groups": [], "intermediates": []}
Expand Down Expand Up @@ -1100,6 +1111,7 @@ def _grouped_combine(
dtype=(np.intp,),
engine=engine,
sort=sort,
user_dtype=agg.dtype["user"],
)["intermediates"][0]
)

Expand Down Expand Up @@ -1129,6 +1141,7 @@ def _grouped_combine(
dtype=(dtype,),
engine=engine,
sort=sort,
user_dtype=agg.dtype["user"],
)
results["intermediates"].append(*_results["intermediates"])
results["groups"] = _results["groups"]
Expand Down Expand Up @@ -1174,6 +1187,7 @@ def _reduce_blockwise(
engine=engine,
sort=sort,
reindex=reindex,
user_dtype=agg.dtype["user"],
)

if _is_arg_reduction(agg):
Expand Down Expand Up @@ -1366,6 +1380,7 @@ def dask_groupby_agg(
fill_value=agg.fill_value["intermediate"],
dtype=agg.dtype["intermediate"],
reindex=reindex,
user_dtype=agg.dtype["user"],
)
if do_simple_combine:
# Add a dummy dimension that then gets reduced over
Expand Down Expand Up @@ -1757,6 +1772,23 @@ def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) ->
return expected_groups


def _choose_engine(by, agg: Aggregation):
    """Pick an engine name for ``agg`` when the caller did not request one.

    Parameters
    ----------
    by
        Group labels. They are only inspected for sortedness, and only when
        they are not a dask array.
    agg : Aggregation
        Initialized aggregation; ``agg.dtype["user"]`` and ``agg.name`` are read.

    Returns
    -------
    str
        One of ``"numbagg"``, ``"flox"`` or ``"numpy"``.
    """
    user_dtype = agg.dtype["user"]

    # Arg reductions never use "numbagg" or "flox" here.
    if _is_arg_reduction(agg):
        return "numpy"

    # numbagg only supports nan-skipping reductions
    # without dtype specified
    if HAS_NUMBAGG and "nan" in agg.name and user_dtype is None:
        return "numbagg"

    # "flox" is chosen only for in-memory labels that are already sorted;
    # the dask check comes first so that a dask ``by`` is never compared.
    if is_duck_dask_array(by) or not _issorted(by):
        return "numpy"
    return "flox"


def groupby_reduce(
array: np.ndarray | DaskArray,
*by: T_By,
Expand All @@ -1769,7 +1801,7 @@ def groupby_reduce(
dtype: np.typing.DTypeLike = None,
min_count: int | None = None,
method: T_Method = "map-reduce",
engine: T_Engine = "numpy",
engine: T_EngineOpt = None,
reindex: bool | None = None,
finalize_kwargs: dict[Any, Any] | None = None,
) -> tuple[DaskArray, Unpack[tuple[np.ndarray | DaskArray, ...]]]: # type: ignore[misc] # Unpack not in mypy yet
Expand Down Expand Up @@ -2027,9 +2059,14 @@ def groupby_reduce(
# overwrite than when min_count is set
fill_value = np.nan

kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
kwargs = dict(axis=axis_, fill_value=fill_value)
agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count_, finalize_kwargs)

# Need to set this early using `agg`
# It cannot be done in the core loop of chunk_reduce
# since we "prepare" the data for flox.
kwargs["engine"] = _choose_engine(by_, agg) if engine is None else engine

groups: tuple[np.ndarray | DaskArray, ...]
if not has_dask:
results = _reduce_blockwise(
Expand Down Expand Up @@ -2080,7 +2117,7 @@ def groupby_reduce(
assert len(groups) == 1
sorted_idx = np.argsort(groups[0])
# This optimization helps specifically with resampling
if not (sorted_idx[:-1] <= sorted_idx[1:]).all():
if not _issorted(sorted_idx):
result = result[..., sorted_idx]
groups = (groups[0][sorted_idx],)

Expand Down
4 changes: 2 additions & 2 deletions flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def xarray_reduce(
fill_value=None,
dtype: np.typing.DTypeLike = None,
method: str = "map-reduce",
engine: str = "numpy",
engine: str | None = None,
keep_attrs: bool | None = True,
skipna: bool | None = None,
min_count: int | None = None,
Expand Down Expand Up @@ -369,7 +369,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):

# Flox's count works with non-numeric data, and it's faster than converting.
requires_numeric = func not in ["count", "any", "all"] or (
func == "count" and engine != "flox"
func == "count" and kwargs["engine"] != "flox"
)
if requires_numeric:
is_npdatetime = array.dtype.kind in "Mm"
Expand Down
27 changes: 26 additions & 1 deletion flox/xrutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
# defined in xarray

import datetime
import importlib
from collections.abc import Iterable
from typing import Any
from typing import Any, Optional

import numpy as np
import pandas as pd
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
from packaging.version import Version

try:
import cftime
Expand Down Expand Up @@ -317,3 +319,26 @@ def nanlast(values, axis, keepdims=False):
return np.expand_dims(result, axis=axis)
else:
return result


def module_available(module: str, minversion: Optional[str] = None) -> bool:
    """Check whether a module is installed, optionally at a minimum version.

    The plain availability check only looks up the module spec and does not
    import the module, so it can be used for lightweight checks and lazy
    imports. The module is imported only when ``minversion`` is given,
    because its ``__version__`` attribute has to be read.

    Parameters
    ----------
    module : str
        Name of the module.
    minversion : str, optional
        Minimum acceptable version. When given, the installed module's
        ``__version__`` must be greater than or equal to it.

    Returns
    -------
    available : bool
        Whether the module is installed and, if ``minversion`` is given,
        whether its version is at least ``minversion``.
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule has been loaded, so import it explicitly.
    import importlib.util

    if importlib.util.find_spec(module) is None:
        return False
    if minversion is None:
        return True

    from packaging.version import Version

    mod = importlib.import_module(module)
    # The installed version must be *at least* the requested minimum.
    return Version(mod.__version__) >= Version(minversion)
49 changes: 38 additions & 11 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from numpy_groupies.aggregate_numpy import aggregate

from flox import xrutils
from flox.aggregations import Aggregation
from flox.aggregations import Aggregation, _initialize_aggregation
from flox.core import (
HAS_NUMBAGG,
_choose_engine,
_convert_expected_groups_to_index,
_get_optimal_chunks_for_groups,
_normalize_indexes,
Expand Down Expand Up @@ -600,12 +602,9 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
by = np.broadcast_to(labels2d, (3, *labels2d.shape))
rng = np.random.default_rng(12345)
array = rng.random(by.shape)
kwargs = dict(
func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value, engine=engine
)
expected, _ = groupby_reduce(array, by, **kwargs)
kwargs = dict(func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value)
expected, _ = groupby_reduce(array, by, engine=engine, **kwargs)
if engine == "flox":
kwargs.pop("engine")
expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
assert_equal(expected_npg, expected)

Expand All @@ -622,12 +621,9 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
by = np.broadcast_to(labels2d, (3, *labels2d.shape))
rng = np.random.default_rng(12345)
array = rng.random(by.shape)
kwargs = dict(
func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value, engine=engine
)
expected, _ = groupby_reduce(array, by, **kwargs)
kwargs = dict(func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value)
expected, _ = groupby_reduce(array, by, engine=engine, **kwargs)
if engine == "flox":
kwargs.pop("engine")
expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
assert_equal(expected_npg, expected)

Expand All @@ -640,6 +636,7 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
actual, _ = groupby_reduce(
da.from_array(array, chunks=(-1, 2, 3)),
da.from_array(by, chunks=(-1, 2, 2)),
engine=engine,
**kwargs,
)
assert_equal(actual, expected, tolerance)
Expand Down Expand Up @@ -1546,3 +1543,33 @@ def test_method_check_numpy():
]
)
assert_equal(actual, expected)


@pytest.mark.parametrize("dtype", [None, np.float64])
def test_choose_engine(dtype):
    """Check the engine picked by ``_choose_engine`` for ``mean`` and ``argmax``.

    Runs once without a user dtype and once with ``np.float64``, using sorted
    and unsorted group labels.
    """
    # NOTE(review): _choose_engine returns "numbagg" only when "nan" is in
    # agg.name; confirm the "numbagg" expectations below hold for "mean" when
    # numbagg is installed.
    numbagg_possible = HAS_NUMBAGG and dtype is None
    default = "numbagg" if numbagg_possible else "numpy"
    mean = _initialize_aggregation(
        "mean",
        dtype=dtype,
        array_dtype=np.dtype("int64"),
        fill_value=0,
        min_count=0,
        finalize_kwargs=None,
    )
    argmax = _initialize_aggregation(
        "argmax",
        dtype=dtype,
        array_dtype=np.dtype("int64"),
        fill_value=0,
        min_count=0,
        finalize_kwargs=None,
    )

    # sorted by -> flox (unless numbagg is expected to take precedence)
    sorted_engine = _choose_engine(np.array([1, 1, 2, 2]), agg=mean)
    assert sorted_engine == ("numbagg" if numbagg_possible else "flox")
    # unsorted by -> numpy (or numbagg when it is expected to be usable)
    assert _choose_engine(np.array([3, 1, 1]), agg=mean) == default
    # argmax does not give engine="flox", even for sorted labels
    assert _choose_engine(np.array([1, 1, 2, 2]), agg=argmax) == "numpy"
Loading