Skip to content

Commit

Permalink
fix(rust, python): fix nested aggregation in when then and window expr… (
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 16, 2022
1 parent 5d096fb commit 49058bb
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 1 deletion.
43 changes: 42 additions & 1 deletion polars/polars-lazy/src/physical_plan/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,41 @@ impl WindowExpr {
agg_col
}

/// Reports whether a window function combines branching/binary arithmetic
/// with an aggregation, e.g.
///
/// ```text
/// when(a > sum)
///     .then(foo)
///     .otherwise(bar - sum)
/// ```
///
/// Every `Expr::Window` node reachable from `self.expr` is inspected, and the
/// nodes of its `function` sub-expression are scanned for:
///
/// * an "arity" node: `Expr::Ternary` or `Expr::BinaryExpr`;
/// * an aggregating node: `Expr::Agg`, or a (anonymous) function whose options
///   have `auto_explode` set and `collect_groups == ApplyOptions::ApplyGroups`.
///
/// Returns `true` only when both kinds were seen. The two flags accumulate
/// across all window nodes that are visited.
fn has_different_group_sources(&self) -> bool {
    let mut arity_seen = false;
    let mut aggregation_seen = false;

    for outer in &self.expr {
        // Only window nodes are of interest; skip everything else.
        let function = match outer {
            Expr::Window { function, .. } => function,
            _ => continue,
        };

        for inner in &**function {
            match inner {
                Expr::Ternary { .. } | Expr::BinaryExpr { .. } => arity_seen = true,
                Expr::Agg(_) => aggregation_seen = true,
                // A function only counts as an aggregation when it is applied
                // per group and its result is auto-exploded.
                Expr::Function { options, .. } | Expr::AnonymousFunction { options, .. }
                    if options.auto_explode
                        && matches!(options.collect_groups, ApplyOptions::ApplyGroups) =>
                {
                    aggregation_seen = true
                }
                // Aliases and all remaining node kinds do not affect the result.
                _ => {}
            }
        }
    }

    arity_seen && aggregation_seen
}

fn determine_map_strategy(
&self,
agg_state: &AggState,
Expand Down Expand Up @@ -225,7 +260,7 @@ impl PhysicalExpr for WindowExpr {
let explicit_list_agg = self.is_explicit_list_agg();

// if we flatten this column we need to make sure the groups are sorted.
let sort_groups = self.options.explode ||
let mut sort_groups = self.options.explode ||
// if not
// `col().over()`
// and not
Expand All @@ -236,6 +271,12 @@ impl PhysicalExpr for WindowExpr {
// we may optimize with explode call
(!self.is_simple_column_expr() && !explicit_list_agg && sorted_keys && !self.is_aggregation());

// overwrite sort_groups for some expressions
// TODO: fully understand what the rationale is here.
if self.has_different_group_sources() {
sort_groups = true
}

let create_groups = || {
let gb = df.groupby_with_series(groupby_columns.clone(), true, sort_groups)?;
let out: PolarsResult<GroupsProxy> = Ok(gb.take_groups());
Expand Down
23 changes: 23 additions & 0 deletions py-polars/tests/unit/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,26 @@ def test_sorted_window_expression() -> None:
out2 = df.with_column(expr)

assert out1.frame_equal(out2)


def test_nested_aggregation_window_expression() -> None:
    """A when/then/otherwise whose predicate contains an aggregation
    (``quantile``) is evaluated over the window ``y``.

    The expected ``foo`` column is 1 for every row except the two rows where
    ``x`` is null, which are null.
    """
    x_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 2, 13, 4, 15, 6, None, None, 19]
    y_values = [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    df = pl.DataFrame({"x": x_values, "y": y_values})

    # The predicate compares each value against an aggregate of the same column.
    at_or_above_quantile = pl.col("x") >= pl.col("x").quantile(0.1)
    windowed = (
        pl.when(at_or_above_quantile).then(1).otherwise(None).over("y").alias("foo")
    )

    result = df.with_columns([windowed]).to_dict(False)

    expected = {
        "x": [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 2, 13, 4, 15, 6, None, None, 19],
        "y": [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        "foo": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, None, None, 1],
    }
    assert result == expected

0 comments on commit 49058bb

Please sign in to comment.