Add "collect" aggregation support to dask-cudf #15593

Merged: 13 commits, May 1, 2024
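For context: with this change, dask-cudf's query-planning (dask-expr) backend supports "collect" groupby aggregations, which gather each group's values into a list column. A minimal usage sketch follows; the column names and data are illustrative, not taken from the PR:

```python
import cudf
import dask_cudf

# Two-partition dask-cudf DataFrame built from a small cudf frame.
gdf = cudf.DataFrame({"key": [0, 0, 1, 1], "value": [1, 2, 3, 4]})
ddf = dask_cudf.from_cudf(gdf, npartitions=2)

# "collect" (also spelled "list", or the `list` builtin) gathers each
# group's values into a list column.
result = ddf.groupby("key").agg({"value": "collect"}).compute()
# Roughly:
#      value
# key
# 0    [1, 2]
# 1    [3, 4]
```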
24 changes: 24 additions & 0 deletions python/dask_cudf/dask_cudf/expr/_collection.py
@@ -1,5 +1,6 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import warnings
from functools import cached_property

from dask_expr import (
@@ -17,6 +18,15 @@

import cudf

_LEGACY_WORKAROUND = (
    "To disable query planning, set the global "
    "'dataframe.query-planning' config to `False` "
    "before dask is imported. This can also be done "
    "by setting an environment variable: "
    "`DASK_DATAFRAME__QUERY_PLANNING=False` "
)


##
## Custom collection classes
##
@@ -71,6 +81,20 @@ def groupby(
f"`by` must be a column name or list of columns, got {by}."
)

if "as_index" in kwargs:
warnings.warn(
"The `as_index` argument is no longer supported in "
"dask-cudf when query-planning is enabled.",
FutureWarning,
)

if kwargs.pop("as_index", True) is not True:
raise NotImplementedError(
f"`as_index=False` is not supported. Please disable "
"query planning, or reset the index after aggregating.\n"
f"{_LEGACY_WORKAROUND}"
)
Member Author:
This essentially does what @charlesbluca suggested: the message is still slightly confusing, but so is the behavior, I guess.

Member:
Maybe something to the effect of "as_index is not supported with query-planning enabled; this will behave consistently with as_index=True" for both messages, with an additional blurb in the error showing the non-True value passed by the user?


        return GroupBy(
            self,
            by,
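To illustrate the `as_index` handling added above: with query planning enabled, passing `as_index` at all now emits a `FutureWarning`, and any value other than `True` raises `NotImplementedError`. A sketch of the two workarounds named in the error message, with illustrative data:

```python
# Option 1: keep query planning and reset the index after aggregating.
import cudf
import dask_cudf

ddf = dask_cudf.from_cudf(
    cudf.DataFrame({"key": [0, 0, 1], "value": [1, 2, 3]}),
    npartitions=1,
)
# ddf.groupby("key", as_index=False)  # FutureWarning; non-True raises
result = ddf.groupby("key").sum().reset_index().compute()

# Option 2: disable query planning entirely. This must happen before
# dask.dataframe is imported, e.g. at the very top of the script:
#   import dask
#   dask.config.set({"dataframe.query-planning": False})
# or via the environment: DASK_DATAFRAME__QUERY_PLANNING=False
```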
54 changes: 54 additions & 0 deletions python/dask_cudf/dask_cudf/expr/_groupby.py
@@ -3,13 +3,55 @@
from dask_expr._groupby import (
    GroupBy as DXGroupBy,
    SeriesGroupBy as DXSeriesGroupBy,
    SingleAggregation,
)
from dask_expr._util import is_scalar

from dask.dataframe.groupby import Aggregation

##
## Custom groupby classes
##


class Collect(SingleAggregation):
    @staticmethod
    def groupby_chunk(arg):
        return arg.agg("collect")

    @staticmethod
    def groupby_aggregate(arg):
        gb = arg.agg("collect")
        if gb.ndim > 1:
            for col in gb.columns:
                gb[col] = gb[col].list.concat()
            return gb
        else:
            return gb.list.concat()


collect_aggregation = Aggregation(
    name="collect",
    chunk=Collect.groupby_chunk,
    agg=Collect.groupby_aggregate,
)


def _translate_arg(arg):
    # Helper function to translate args so that
    # they can be processed correctly by upstream
    # dask & dask-expr. Right now, the only necessary
    # translation is for "collect" aggregations.
    if isinstance(arg, dict):
        return {k: _translate_arg(v) for k, v in arg.items()}
    elif isinstance(arg, list):
        return [_translate_arg(x) for x in arg]
    elif arg in ("collect", "list", list):
        return collect_aggregation
    else:
        return arg


# TODO: These classes are mostly a work-around for missing
# `observed=False` support.
# See: https://github.com/rapidsai/cudf/issues/15173
@@ -41,8 +83,20 @@ def __getitem__(self, key):
        )
        return g

    def collect(self, **kwargs):
        return self._single_agg(Collect, **kwargs)

    def aggregate(self, arg, **kwargs):
        return super().aggregate(_translate_arg(arg), **kwargs)


class SeriesGroupBy(DXSeriesGroupBy):
    def __init__(self, *args, observed=None, **kwargs):
        observed = observed if observed is not None else True
        super().__init__(*args, observed=observed, **kwargs)

    def collect(self, **kwargs):
        return self._single_agg(Collect, **kwargs)

    def aggregate(self, arg, **kwargs):
        return super().aggregate(_translate_arg(arg), **kwargs)
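To make the new pieces concrete: `Collect` performs a tree reduction in which each partition first collects its groups into lists, and the combine step collects those partial lists and flattens them with `.list.concat()`; `_translate_arg` then maps the user-facing spellings `"collect"`, `"list"`, and the `list` builtin onto `collect_aggregation`. A sketch of the reduction using plain cudf, with illustrative data:

```python
import cudf

# Per-partition step (mirrors Collect.groupby_chunk): each group's
# values become a list within that partition.
part1 = cudf.DataFrame({"key": [0, 0, 1], "value": [1, 2, 3]})
part2 = cudf.DataFrame({"key": [0, 1, 1], "value": [4, 5, 6]})
chunk1 = part1.groupby("key").agg("collect")  # 0 -> [1, 2]; 1 -> [3]
chunk2 = part2.groupby("key").agg("collect")  # 0 -> [4];    1 -> [5, 6]

# Combine step (mirrors Collect.groupby_aggregate): collecting the
# partial lists yields lists-of-lists per group, and `.list.concat()`
# flattens each row back into a single list.
stacked = cudf.concat([chunk1, chunk2]).reset_index()
combined = stacked.groupby("key").agg("collect")
flat = combined["value"].list.concat()  # 0 -> [1, 2, 4]; 1 -> [3, 5, 6]
```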
26 changes: 12 additions & 14 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
@@ -14,16 +14,6 @@
from dask_cudf.groupby import OPTIMIZED_AGGS, _aggs_optimized
from dask_cudf.tests.utils import QUERY_PLANNING_ON, xfail_dask_expr

# XFAIL "collect" tests for now
agg_params = [agg for agg in OPTIMIZED_AGGS if agg != "collect"]
if QUERY_PLANNING_ON:
    agg_params.append(
        # TODO: "collect" not supported with dask-expr yet
        pytest.param("collect", marks=pytest.mark.xfail)
    )
else:
    agg_params.append("collect")


def assert_cudf_groupby_layers(ddf):
    for prefix in ("cudf-aggregate-chunk", "cudf-aggregate-agg"):
@@ -57,7 +47,7 @@ def pdf(request):
    return pdf


@pytest.mark.parametrize("aggregation", agg_params)
@pytest.mark.parametrize("aggregation", OPTIMIZED_AGGS)
@pytest.mark.parametrize("series", [False, True])
def test_groupby_basic(series, aggregation, pdf):
    gdf = cudf.DataFrame.from_pandas(pdf)
@@ -110,7 +100,7 @@ def test_groupby_cumulative(aggregation, pdf, series):
    dd.assert_eq(a, b)


@pytest.mark.parametrize("aggregation", agg_params)
@pytest.mark.parametrize("aggregation", OPTIMIZED_AGGS)
@pytest.mark.parametrize(
"func",
[
@@ -579,8 +569,16 @@ def test_groupby_categorical_key():
    dd.assert_eq(expect, got)


@xfail_dask_expr("as_index not supported in dask-expr")
@pytest.mark.parametrize("as_index", [True, False])
@pytest.mark.parametrize(
    "as_index",
    [
        True,
        pytest.param(
            False,
            marks=xfail_dask_expr("as_index not supported in dask-expr"),
        ),
    ],
)
@pytest.mark.parametrize("split_out", ["use_dask_default", 1, 2])
@pytest.mark.parametrize("split_every", [False, 4])
@pytest.mark.parametrize("npartitions", [1, 10])