Skip to content
/ cudf Public
forked from rapidsai/cudf

Commit

Permalink
Translate round and unique unary operations
Browse files Browse the repository at this point in the history
And add evaluation handlers.
  • Loading branch information
wence- committed Jun 17, 2024
1 parent e12e9b3 commit acf0e2f
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 3 deletions.
86 changes: 86 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,92 @@ def do_evaluate(
) # pragma: no cover; init trips first


class UnaryFunction(Expr):
    """
    Expression node applying a named function to a single child expression.

    Only the function names ``"round"`` and ``"unique"`` are accepted;
    construction with any other name raises ``NotImplementedError``.
    """

    __slots__ = ("name", "options", "children")
    _non_child = ("dtype", "name", "options")
    children: tuple[Expr, ...]

    def __init__(
        self, dtype: plc.DataType, name: str, options: tuple[Any, ...], *children: Expr
    ) -> None:
        """
        Store the function name, its options and its child expressions.

        Parameters
        ----------
        dtype
            Data type forwarded to the base-class initialiser.
        name
            Name of the function; must be ``"round"`` or ``"unique"``.
        options
            Function-specific options. Evaluation unpacks exactly one
            entry: the number of decimal places for ``"round"``, the
            ``maintain_order`` flag for ``"unique"``.
        children
            Child expressions. Evaluation expects exactly one.

        Raises
        ------
        NotImplementedError
            If ``name`` is not one of the supported function names.
        """
        super().__init__(dtype)
        self.name = name
        self.options = options
        self.children = children
        # Validation runs after the attributes are populated.
        if self.name != "round" and self.name != "unique":
            raise NotImplementedError(f"Unary function {name=}")

    def do_evaluate(
        self,
        df: DataFrame,
        *,
        context: ExecutionContext = ExecutionContext.FRAME,
        mapping: Mapping[Expr, Column] | None = None,
    ) -> Column:
        """
        Evaluate this expression given a dataframe for context.

        Parameters
        ----------
        df
            Dataframe passed through to the child's ``evaluate``.
        context
            Execution context passed through to the child's ``evaluate``.
        mapping
            Optional expression-to-column mapping passed through to the
            child's ``evaluate``.

        Returns
        -------
        Column holding the rounded, or de-duplicated, child values.

        Raises
        ------
        NotImplementedError
            If ``self.name`` is not handled (``__init__`` rejects such
            names first).
        """
        # Lazy on purpose: nothing is evaluated until a branch below
        # unpacks it, which happens only after that branch has unpacked
        # its options.
        evaluated_children = (
            child.evaluate(df, context=context, mapping=mapping)
            for child in self.children
        )

        if self.name == "round":
            (ndigits,) = self.options
            (operand,) = evaluated_children
            rounded = plc.round.round(
                operand.obj, ndigits, plc.round.RoundingMethod.HALF_UP
            )
            return Column(rounded).sorted_like(operand)

        if self.name == "unique":
            (keep_order,) = self.options
            (operand,) = evaluated_children
            # With a single column, keeping any duplicate is equivalent
            # to keeping the first one for the stable distinct.
            keep_policy = plc.stream_compaction.DuplicateKeepOption.KEEP_ANY
            if operand.is_sorted:
                # Input flagged as sorted: use ``unique`` and treat the
                # result as order-preserving regardless of the option.
                keep_order = True
                deduplicated = plc.stream_compaction.unique(
                    plc.Table([operand.obj]),
                    [0],
                    keep_policy,
                    plc.types.NullEquality.EQUAL,
                )
            else:
                if keep_order:
                    dedup_fn = plc.stream_compaction.stable_distinct
                else:
                    dedup_fn = plc.stream_compaction.distinct
                deduplicated = dedup_fn(
                    plc.Table([operand.obj]),
                    [0],
                    keep_policy,
                    plc.types.NullEquality.EQUAL,
                    plc.types.NanEquality.ALL_EQUAL,
                )
            (only_column,) = deduplicated.columns()
            wrapped = Column(only_column)
            # NOTE(review): ``sorted_like`` presumably carries over the
            # input's sortedness metadata -- confirm against Column.
            return wrapped.sorted_like(operand) if keep_order else wrapped

        raise NotImplementedError(
            f"Unimplemented unary function {self.name=}"
        )  # pragma: no cover; init trips first

    def collect_agg(self, *, depth: int) -> AggInfo:
        """
        Collect information about aggregations in groupbys.

        Parameters
        ----------
        depth
            Aggregation nesting depth. At depth 1 this node returns
            itself as a pre-evaluation request; at any other depth the
            request is delegated to the single child.

        Returns
        -------
        Aggregation information for this node or for its child.
        """
        if depth != 1:
            (operand,) = self.children
            return operand.collect_agg(depth=depth)
        # Inside an aggregation this node must be pre-evaluated.
        # Groupby construction has already checked that there are no
        # nested aggregations, so recursion stops here and the node
        # returns itself for pre-evaluation.
        request = (self, plc.aggregation.collect_list(), self)
        return AggInfo([request])


class Sort(Expr):
__slots__ = ("options", "children")
_non_child = ("dtype", "options")
Expand Down
2 changes: 1 addition & 1 deletion python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def check_agg(agg: expr.Expr) -> int:
NotImplementedError
For unsupported expression nodes.
"""
if isinstance(agg, (expr.BinOp, expr.Cast)):
if isinstance(agg, (expr.BinOp, expr.Cast, expr.UnaryFunction)):
return max(GroupBy.check_agg(child) for child in agg.children)
elif isinstance(agg, expr.Agg):
return 1 + max(GroupBy.check_agg(child) for child in agg.children)
Expand Down
12 changes: 10 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,16 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex
options,
*(translate_expr(visitor, n=n) for n in node.input),
)
else:
raise NotImplementedError(f"No handler for Expr function node with {name=}")
elif isinstance(name, str):
return expr.UnaryFunction(
dtype,
name,
options,
*(translate_expr(visitor, n=n) for n in node.input),
)
raise NotImplementedError(
f"No handler for Expr function node with {name=}"
) # pragma: no cover; polars raises on the rust side for now


@_translate_expr.register
Expand Down
37 changes: 37 additions & 0 deletions python/cudf_polars/tests/expressions/test_round.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import math

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.fixture(params=[pl.Float32, pl.Float64])
def dtype(request):
    """Return the polars float dtype chosen by the current parametrization."""
    selected_dtype = request.param
    return selected_dtype


@pytest.fixture(params=[False, True], ids=["no_nulls", "nulls"])
def with_nulls(request):
    """Return the flag saying whether the test data should contain nulls."""
    include_nulls = request.param
    return include_nulls


@pytest.fixture
def df(dtype, with_nulls):
    """
    Build a single-column lazy frame of floats for the rounding tests.

    Column ``a`` is created with the requested ``dtype``. When
    ``with_nulls`` is true, the third and the last entries are nulled.
    """
    column_values = [-math.e, 10, 22.5, 1.5, 2.5, -1.5, math.pi, 8]
    if with_nulls:
        for position in (2, -1):
            column_values[position] = None
    return pl.LazyFrame({"a": column_values}, schema={"a": dtype})


@pytest.mark.parametrize("decimals", [0, 2, 4])
def test_round(df, decimals):
    """Check the GPU result of rounding column ``a`` (inexact comparison)."""
    rounded = pl.col("a").round(decimals=decimals)
    query = df.select(rounded)
    assert_gpu_result_equal(query, check_exact=False)
24 changes: 24 additions & 0 deletions python/cudf_polars/tests/expressions/test_unique.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.mark.parametrize("maintain_order", [False, True], ids=["unstable", "stable"])
@pytest.mark.parametrize("pre_sorted", [False, True], ids=["unsorted", "sorted"])
def test_unique(maintain_order, pre_sorted):
    """
    Check the GPU result of ``unique`` on a float column.

    The column holds duplicates, a null and a NaN. Row order is only
    compared when ``maintain_order`` is requested.
    """
    column_values = [1.5, 2.5, None, 1.5, 3, float("nan"), 3]
    frame = pl.DataFrame({"b": column_values}).lazy()
    if pre_sorted:
        frame = frame.sort("b")

    deduplicated = pl.col("b").unique(maintain_order=maintain_order)
    assert_gpu_result_equal(
        frame.select(deduplicated), check_row_order=maintain_order
    )
2 changes: 2 additions & 0 deletions python/cudf_polars/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def keys(request):
[pl.col("float").max() - pl.col("int").min()],
[pl.col("float").mean(), pl.col("int").std()],
[(pl.col("float") - pl.lit(2)).max()],
[pl.col("float").sum().round(decimals=1)],
[pl.col("float").round(decimals=1).sum()],
],
ids=lambda aggs: "-".join(map(str, aggs)),
)
Expand Down

0 comments on commit acf0e2f

Please sign in to comment.