Skip to content

Commit

Permalink
BUG: generate proper code for custom aggregation func (#428)
Browse files Browse the repository at this point in the history
  • Loading branch information
UranusSeven committed May 16, 2023
1 parent e6b765e commit e8e7606
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 8 deletions.
Expand Up @@ -1674,8 +1674,7 @@ def test_gpu_groupby_size(data_type, chunked, as_index, sort, setup_gpu):
pd.testing.assert_series_equal(expected, actual)


# TODO: support cuda
# @support_cuda
@support_cuda
@pytest.mark.parametrize(
"as_index",
[True, False],
Expand Down Expand Up @@ -1706,16 +1705,78 @@ def g3(x):
df.groupby("a", as_index=False).agg((g1, g2, g3)),
mdf.groupby("a", as_index=False).agg((g1, g2, g3)).execute().fetch(),
)
pd.testing.assert_frame_equal(
df.groupby("a", as_index=as_index).agg((g1, g1)),
mdf.groupby("a", as_index=as_index).agg((g1, g1)).execute().fetch(),
)
if not gpu:
# cuDF doesn't support multiple columns with the same name yet.
pd.testing.assert_frame_equal(
df.groupby("a", as_index=as_index).agg((g1, g1)),
mdf.groupby("a", as_index=as_index).agg((g1, g1)).execute().fetch(),
)

pd.testing.assert_frame_equal(
df.groupby("a", as_index=as_index)["b"].agg((g1, g2, g3)),
mdf.groupby("a", as_index=as_index)["b"].agg((g1, g2, g3)).execute().fetch(),
)
if not gpu:
# cuDF doesn't support multiple columns with the same name yet.
pd.testing.assert_frame_equal(
df.groupby("a", as_index=as_index)["b"].agg((g1, g1)),
mdf.groupby("a", as_index=as_index)["b"].agg((g1, g1)).execute().fetch(),
)


@support_cuda
def test_groupby_agg_on_custom_funcs(setup_gpu, gpu):
    """Check groupby-agg with custom string-comparison functions against pandas.

    Each custom function compares its input with a string literal using one
    of ``==``, ``!=``, ``>``, ``>=``, ``<`` or ``<=`` and sums the resulting
    mask.  The distributed result must equal the plain pandas result.
    """
    rs = np.random.RandomState(0)
    choices = ["foo", "bar", "baz"]
    df = pd.DataFrame(
        {
            "a": rs.choice(choices, size=100),
            "b": rs.choice(choices, size=100),
            "c": rs.choice(choices, size=100),
        },
    )

    # chunk_size is smaller than the row count, presumably so that the frame
    # spans more than one chunk -- TODO confirm against md.DataFrame semantics.
    mdf = md.DataFrame(df, chunk_size=34, gpu=gpu)

    def g1(x):
        # The literal is deliberately on the left-hand side here.
        return ("foo" == x).sum()

    def g2(x):
        return ("foo" != x).sum()

    def g3(x):
        return (x > "bar").sum()

    def g4(x):
        return (x >= "bar").sum()

    def g5(x):
        return (x < "baz").sum()

    def g6(x):
        return (x <= "baz").sum()

    # One tuple shared by both sides so the two aggregations cannot drift apart.
    funcs = (g1, g2, g3, g4, g5, g6)

    expected = df.groupby("a", as_index=False).agg(funcs)
    actual = mdf.groupby("a", as_index=False).agg(funcs).execute().fetch()
    pd.testing.assert_frame_equal(expected, actual)
18 changes: 18 additions & 0 deletions python/xorbits/_mars/dataframe/reduction/core.py
Expand Up @@ -1164,6 +1164,24 @@ def _interpret_var(v):
axis_expr = f"axis={op_axis!r}, " if op_axis is not None else ""
op_str = _func_name_to_op[func_name]
if t.op.lhs is t.inputs[0]:
if (
(
func_name
in (
"gt",
"ge",
"lt",
"le",
"eq",
"ne",
)
)
and isinstance(t.op.lhs, DATAFRAME_TYPE)
and isinstance(t.op.rhs, str)
):
# For a cuDF dataframe, `df == 'foo'` doesn't work, so we repeat the rhs
# into a tuple with one entry per column.
rhs = f"({rhs},) * len({lhs}.columns)"
statements = [
f"try:",
f" {var_name} = {lhs}.{func_name}({rhs}, {axis_expr})",
Expand Down

0 comments on commit e8e7606

Please sign in to comment.