Skip to content

Commit

Permalink
use engine flox for ordered groups (#266)
Browse files Browse the repository at this point in the history
* use engine flox for ordered groups

* Add issorted helper func

* Some fixes

* In xarray too

* formatting

* simplify

* retry

* flox

* minversion numbagg

* cleanup

* fix type

* update gitignore

* add types

* Fix env?

* fix

* fix merge

* cleanup

* [skip-ci] bench

* temporarily disable numbagg

* don't cache env

* Finally!

* bugfix

* Fix doctest

* more fixes

* Fix CI

* readd numbagg

* Fix.

---------

Co-authored-by: Deepak Cherian <deepak@cherian.net>
Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
  • Loading branch information
3 people committed Oct 15, 2023
1 parent 789cf73 commit fecd9a6
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 20 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/ci-additional.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ jobs:
conda list
- name: Run doctests
run: |
python -m pytest --doctest-modules flox --ignore flox/tests --cov=./ --cov-report=xml
python -m pytest --doctest-modules \
flox/aggregations.py flox/core.py flox/xarray.py \
--ignore flox/tests \
--cov=./ --cov-report=xml
- name: Upload code coverage to Codecov
uses: codecov/codecov-action@v3.1.4
with:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
asv_bench/pkgs/
docs/source/generated/
html/
.asv/
Expand Down
2 changes: 1 addition & 1 deletion asv_bench/benchmarks/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

N = 3000
funcs = ["sum", "nansum", "mean", "nanmean", "max", "nanmax", "count"]
engines = ["flox", "numpy", "numbagg"]
engines = [None, "flox", "numpy", "numbagg"]
expected_groups = {
"None": None,
"bins": pd.IntervalIndex.from_breaks([1, 2, 4]),
Expand Down
1 change: 1 addition & 0 deletions ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ dependencies:
- pooch
- toolz
- numba
- numbagg>=0.3
- scipy
2 changes: 2 additions & 0 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class AggDtypeInit(TypedDict):


class AggDtype(TypedDict):
user: DTypeLike | None
final: np.dtype
numpy: tuple[np.dtype | type[np.intp], ...]
intermediate: tuple[np.dtype | type[np.intp], ...]
Expand Down Expand Up @@ -569,6 +570,7 @@ def _initialize_aggregation(

final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
agg.dtype = {
"user": dtype, # Save to automatically choose an engine
"final": final_dtype,
"numpy": (final_dtype,),
"intermediate": tuple(
Expand Down
45 changes: 41 additions & 4 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
generic_aggregate,
)
from .cache import memoize
from .xrutils import is_duck_array, is_duck_dask_array, isnull
from .xrutils import is_duck_array, is_duck_dask_array, isnull, module_available

HAS_NUMBAGG = module_available("numbagg", minversion="0.3.0")

if TYPE_CHECKING:
try:
Expand Down Expand Up @@ -69,6 +71,7 @@
T_Dtypes = Union[np.typing.DTypeLike, Sequence[np.typing.DTypeLike], None]
T_FillValues = Union[np.typing.ArrayLike, Sequence[np.typing.ArrayLike], None]
T_Engine = Literal["flox", "numpy", "numba", "numbagg"]
T_EngineOpt = None | T_Engine
T_Method = Literal["map-reduce", "blockwise", "cohorts"]
T_IsBins = Union[bool | Sequence[bool]]

Expand All @@ -83,6 +86,10 @@
DUMMY_AXIS = -2


def _issorted(arr: np.ndarray) -> bool:
return bool((arr[:-1] <= arr[1:]).all())


def _is_arg_reduction(func: T_Agg) -> bool:
if isinstance(func, str) and func in ["argmin", "argmax", "nanargmax", "nanargmin"]:
return True
Expand Down Expand Up @@ -632,6 +639,7 @@ def chunk_argreduce(
reindex: bool = False,
engine: T_Engine = "numpy",
sort: bool = True,
user_dtype=None,
) -> IntermediateDict:
"""
Per-chunk arg reduction.
Expand All @@ -652,6 +660,7 @@ def chunk_argreduce(
dtype=dtype,
engine=engine,
sort=sort,
user_dtype=user_dtype,
)
if not isnull(results["groups"]).all():
idx = np.broadcast_to(idx, array.shape)
Expand Down Expand Up @@ -685,6 +694,7 @@ def chunk_reduce(
engine: T_Engine = "numpy",
kwargs: Sequence[dict] | None = None,
sort: bool = True,
user_dtype=None,
) -> IntermediateDict:
"""
Wrapper for numpy_groupies aggregate that supports nD ``array`` and
Expand Down Expand Up @@ -785,6 +795,7 @@ def chunk_reduce(
group_idx = group_idx.reshape(-1)

assert group_idx.ndim == 1

empty = np.all(props.nanmask)

results: IntermediateDict = {"groups": [], "intermediates": []}
Expand Down Expand Up @@ -1100,6 +1111,7 @@ def _grouped_combine(
dtype=(np.intp,),
engine=engine,
sort=sort,
user_dtype=agg.dtype["user"],
)["intermediates"][0]
)

Expand Down Expand Up @@ -1129,6 +1141,7 @@ def _grouped_combine(
dtype=(dtype,),
engine=engine,
sort=sort,
user_dtype=agg.dtype["user"],
)
results["intermediates"].append(*_results["intermediates"])
results["groups"] = _results["groups"]
Expand Down Expand Up @@ -1174,6 +1187,7 @@ def _reduce_blockwise(
engine=engine,
sort=sort,
reindex=reindex,
user_dtype=agg.dtype["user"],
)

if _is_arg_reduction(agg):
Expand Down Expand Up @@ -1366,6 +1380,7 @@ def dask_groupby_agg(
fill_value=agg.fill_value["intermediate"],
dtype=agg.dtype["intermediate"],
reindex=reindex,
user_dtype=agg.dtype["user"],
)
if do_simple_combine:
# Add a dummy dimension that then gets reduced over
Expand Down Expand Up @@ -1757,6 +1772,23 @@ def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) ->
return expected_groups


def _choose_engine(by, agg: Aggregation):
    """Pick the engine to use when the caller did not request one.

    Parameters
    ----------
    by
        Group labels; checked for being a dask array and for sortedness.
    agg : Aggregation
        The initialized aggregation; its name and the user-supplied dtype
        (``agg.dtype["user"]``) drive the choice.

    Returns
    -------
    str
        One of ``"numbagg"``, ``"flox"`` or ``"numpy"``.
    """
    user_dtype = agg.dtype["user"]
    is_regular_reduction = not _is_arg_reduction(agg)

    # numbagg only supports nan-skipping reductions
    # without dtype specified
    numbagg_candidate = HAS_NUMBAGG and "nan" in agg.name
    if numbagg_candidate and is_regular_reduction and user_dtype is None:
        return "numbagg"

    # Arg reductions always fall back to numpy.
    if not is_regular_reduction:
        return "numpy"
    # Sortedness is only inspected for non-dask labels.
    if is_duck_dask_array(by):
        return "numpy"
    return "flox" if _issorted(by) else "numpy"


def groupby_reduce(
array: np.ndarray | DaskArray,
*by: T_By,
Expand All @@ -1769,7 +1801,7 @@ def groupby_reduce(
dtype: np.typing.DTypeLike = None,
min_count: int | None = None,
method: T_Method = "map-reduce",
engine: T_Engine = "numpy",
engine: T_EngineOpt = None,
reindex: bool | None = None,
finalize_kwargs: dict[Any, Any] | None = None,
) -> tuple[DaskArray, Unpack[tuple[np.ndarray | DaskArray, ...]]]: # type: ignore[misc] # Unpack not in mypy yet
Expand Down Expand Up @@ -2027,9 +2059,14 @@ def groupby_reduce(
# overwrite than when min_count is set
fill_value = np.nan

kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
kwargs = dict(axis=axis_, fill_value=fill_value)
agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count_, finalize_kwargs)

# Need to set this early using `agg`
# It cannot be done in the core loop of chunk_reduce
# since we "prepare" the data for flox.
kwargs["engine"] = _choose_engine(by_, agg) if engine is None else engine

groups: tuple[np.ndarray | DaskArray, ...]
if not has_dask:
results = _reduce_blockwise(
Expand Down Expand Up @@ -2080,7 +2117,7 @@ def groupby_reduce(
assert len(groups) == 1
sorted_idx = np.argsort(groups[0])
# This optimization helps specifically with resampling
if not (sorted_idx[:-1] <= sorted_idx[1:]).all():
if not _issorted(sorted_idx):
result = result[..., sorted_idx]
groups = (groups[0][sorted_idx],)

Expand Down
4 changes: 2 additions & 2 deletions flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def xarray_reduce(
fill_value=None,
dtype: np.typing.DTypeLike = None,
method: str = "map-reduce",
engine: str = "numpy",
engine: str | None = None,
keep_attrs: bool | None = True,
skipna: bool | None = None,
min_count: int | None = None,
Expand Down Expand Up @@ -369,7 +369,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs):

# Flox's count works with non-numeric data and it's faster than converting.
requires_numeric = func not in ["count", "any", "all"] or (
func == "count" and engine != "flox"
func == "count" and kwargs["engine"] != "flox"
)
if requires_numeric:
is_npdatetime = array.dtype.kind in "Mm"
Expand Down
27 changes: 26 additions & 1 deletion flox/xrutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
# defined in xarray

import datetime
import importlib
from collections.abc import Iterable
from typing import Any
from typing import Any, Optional

import numpy as np
import pandas as pd
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
from packaging.version import Version

try:
import cftime
Expand Down Expand Up @@ -317,3 +319,26 @@ def nanlast(values, axis, keepdims=False):
return np.expand_dims(result, axis=axis)
else:
return result


def module_available(module: str, minversion: Optional[str] = None) -> bool:
    """Check whether a module is installed, optionally at a minimum version.

    The module is only located, not imported, unless ``minversion`` is given;
    in that case it is imported so that its ``__version__`` can be read.

    Parameters
    ----------
    module : str
        Name of the module.
    minversion : str, optional
        Minimum acceptable version. When given, the module's ``__version__``
        must be greater than or equal to it.

    Returns
    -------
    available : bool
        Whether the module is installed (and satisfies ``minversion``).
    """
    # ``import importlib`` alone does not guarantee the ``util`` submodule
    # is loaded, so import it explicitly.
    import importlib.util

    if importlib.util.find_spec(module) is None:
        return False
    if minversion is None:
        # Keep the check lightweight: no need to import the module itself.
        return True

    from packaging.version import Version

    mod = importlib.import_module(module)
    # Available only when the installed version is at least ``minversion``.
    return Version(mod.__version__) >= Version(minversion)
49 changes: 38 additions & 11 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from numpy_groupies.aggregate_numpy import aggregate

from flox import xrutils
from flox.aggregations import Aggregation
from flox.aggregations import Aggregation, _initialize_aggregation
from flox.core import (
HAS_NUMBAGG,
_choose_engine,
_convert_expected_groups_to_index,
_get_optimal_chunks_for_groups,
_normalize_indexes,
Expand Down Expand Up @@ -600,12 +602,9 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
by = np.broadcast_to(labels2d, (3, *labels2d.shape))
rng = np.random.default_rng(12345)
array = rng.random(by.shape)
kwargs = dict(
func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value, engine=engine
)
expected, _ = groupby_reduce(array, by, **kwargs)
kwargs = dict(func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value)
expected, _ = groupby_reduce(array, by, engine=engine, **kwargs)
if engine == "flox":
kwargs.pop("engine")
expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
assert_equal(expected_npg, expected)

Expand All @@ -622,12 +621,9 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
by = np.broadcast_to(labels2d, (3, *labels2d.shape))
rng = np.random.default_rng(12345)
array = rng.random(by.shape)
kwargs = dict(
func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value, engine=engine
)
expected, _ = groupby_reduce(array, by, **kwargs)
kwargs = dict(func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value)
expected, _ = groupby_reduce(array, by, engine=engine, **kwargs)
if engine == "flox":
kwargs.pop("engine")
expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
assert_equal(expected_npg, expected)

Expand All @@ -640,6 +636,7 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
actual, _ = groupby_reduce(
da.from_array(array, chunks=(-1, 2, 3)),
da.from_array(by, chunks=(-1, 2, 2)),
engine=engine,
**kwargs,
)
assert_equal(actual, expected, tolerance)
Expand Down Expand Up @@ -1546,3 +1543,33 @@ def test_method_check_numpy():
]
)
assert_equal(actual, expected)


@pytest.mark.parametrize("dtype", [None, np.float64])
def test_choose_engine(dtype):
    """Check which engine ``_choose_engine`` picks for sorted and unsorted labels."""
    numbagg_possible = HAS_NUMBAGG and dtype is None
    default = "numbagg" if numbagg_possible else "numpy"

    # Both aggregations share every setting except the reduction name.
    shared = dict(
        dtype=dtype,
        array_dtype=np.dtype("int64"),
        fill_value=0,
        min_count=0,
        finalize_kwargs=None,
    )
    mean = _initialize_aggregation("mean", **shared)
    argmax = _initialize_aggregation("argmax", **shared)

    sorted_by = np.array([1, 1, 2, 2])
    unsorted_by = np.array([3, 1, 1])

    # sorted by -> flox (or numbagg when that is possible)
    expected_for_sorted = "numbagg" if numbagg_possible else "flox"
    assert _choose_engine(sorted_by, agg=mean) == expected_for_sorted
    # unsorted by -> numpy (or numbagg when that is possible)
    assert _choose_engine(unsorted_by, agg=mean) == default
    # argmax does not give engine="flox"
    assert _choose_engine(sorted_by, agg=argmax) == "numpy"

0 comments on commit fecd9a6

Please sign in to comment.