Skip to content

Commit

Permalink
fix[rust]: apply flat overlapping row groups when possible (#5039)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 30, 2022
1 parent 0e6bff2 commit 8ecba44
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 72 deletions.
18 changes: 1 addition & 17 deletions polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1187,11 +1187,7 @@ impl Expr {
function: FunctionExpr::Pow,
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: false,
auto_explode: false,
fmt_str: "pow",
cast_to_supertypes: false,
allow_rename: false,
..Default::default()
},
}
}
Expand All @@ -1204,7 +1200,6 @@ impl Expr {
function: FunctionExpr::Trigonometry(TrigonometricFunction::Sin),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
fmt_str: "sin",
..Default::default()
},
}
Expand All @@ -1218,7 +1213,6 @@ impl Expr {
function: FunctionExpr::Trigonometry(TrigonometricFunction::Cos),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
fmt_str: "cos",
..Default::default()
},
}
Expand All @@ -1232,7 +1226,6 @@ impl Expr {
function: FunctionExpr::Trigonometry(TrigonometricFunction::Tan),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
fmt_str: "tan",
..Default::default()
},
}
Expand All @@ -1246,7 +1239,6 @@ impl Expr {
function: FunctionExpr::Trigonometry(TrigonometricFunction::ArcSin),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
fmt_str: "arcsin",
..Default::default()
},
}
Expand All @@ -1260,7 +1252,6 @@ impl Expr {
function: FunctionExpr::Trigonometry(TrigonometricFunction::ArcCos),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
fmt_str: "arccos",
..Default::default()
},
}
Expand All @@ -1274,7 +1265,6 @@ impl Expr {
function: FunctionExpr::Trigonometry(TrigonometricFunction::ArcTan),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
fmt_str: "arctan",
..Default::default()
},
}
Expand All @@ -1288,7 +1278,6 @@ impl Expr {
function: FunctionExpr::Trigonometry(TrigonometricFunction::Sinh),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
fmt_str: "sinh",
..Default::default()
},
}
Expand All @@ -1302,7 +1291,6 @@ impl Expr {
function: FunctionExpr::Trigonometry(TrigonometricFunction::Cosh),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
fmt_str: "cosh",
..Default::default()
},
}
Expand All @@ -1316,7 +1304,6 @@ impl Expr {
function: FunctionExpr::Trigonometry(TrigonometricFunction::Tanh),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
fmt_str: "tanh",
..Default::default()
},
}
Expand All @@ -1330,7 +1317,6 @@ impl Expr {
function: FunctionExpr::Trigonometry(TrigonometricFunction::ArcSinh),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
fmt_str: "arcsinh",
..Default::default()
},
}
Expand All @@ -1344,7 +1330,6 @@ impl Expr {
function: FunctionExpr::Trigonometry(TrigonometricFunction::ArcCosh),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
fmt_str: "arccosh",
..Default::default()
},
}
Expand Down Expand Up @@ -1372,7 +1357,6 @@ impl Expr {
function: FunctionExpr::Sign,
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
fmt_str: "sign",
..Default::default()
},
}
Expand Down
130 changes: 76 additions & 54 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,62 +222,57 @@ impl PhysicalExpr for ApplyExpr {

// overlapping groups always take this branch as explode bloats data size
(_, ApplyOptions::ApplyGroups) | (true, _) => {
let mut container = vec![Default::default(); acs.len()];
let name = acs[0].series().name().to_string();

// aggregate representation of the aggregation contexts
// then unpack the lists and finally create iterators from this list chunked arrays.
let mut iters = acs
.iter_mut()
.map(|ac| ac.iter_groups())
.collect::<Vec<_>>();

// length of the items to iterate over
let len = iters[0].size_hint().0;

let mut ca: ListChunked = (0..len)
.map(|_| {
container.clear();
for iter in &mut iters {
match iter.next().unwrap() {
None => return None,
Some(s) => container.push(s.deep_clone()),
// if
// - there are overlapping groups
// - can do elementwise operations
// - we don't have to explode
// then apply flat
if let (
true,
ApplyOptions::ApplyFlat,
AggState::AggregatedFlat(_) | AggState::NotAggregated(_),
) = (
state.overlapping_groups(),
self.collect_groups,
acs[0].agg_state(),
) {
apply_multiple_flat(acs, self.function.as_ref())
} else {
let mut container = vec![Default::default(); acs.len()];
let name = acs[0].series().name().to_string();

// aggregate representation of the aggregation contexts
// then unpack the lists and finally create iterators from this list chunked arrays.
let mut iters = acs
.iter_mut()
.map(|ac| ac.iter_groups())
.collect::<Vec<_>>();

// length of the items to iterate over
let len = iters[0].size_hint().0;

let mut ca: ListChunked = (0..len)
.map(|_| {
container.clear();
for iter in &mut iters {
match iter.next().unwrap() {
None => return None,
Some(s) => container.push(s.deep_clone()),
}
}
}
self.function.call_udf(&mut container).ok()
})
.collect_trusted();
ca.rename(&name);
drop(iters);

// take the first aggregation context, as that is the input series
let ac = acs.swap_remove(0);
let ac = self.finish_apply_groups(ac, ca);
Ok(ac)
}
(_, ApplyOptions::ApplyFlat) => {
let mut s = acs
.iter_mut()
.map(|ac| {
// make sure the groups are updated because we are about to throw away
// the series length information
if let UpdateGroups::WithSeriesLen = ac.update_groups {
ac.groups();
}

ac.flat_naive().into_owned()
})
.collect::<Vec<_>>();

let input_len = s[0].len();
let s = self.function.call_udf(&mut s)?;
check_map_output_len(input_len, s.len())?;

// take the first aggregation context, as that is the input series
let mut ac = acs.swap_remove(0);
ac.with_series(s, false);
Ok(ac)
self.function.call_udf(&mut container).ok()
})
.collect_trusted();
ca.rename(&name);
drop(iters);

// take the first aggregation context, as that is the input series
let ac = acs.swap_remove(0);
let ac = self.finish_apply_groups(ac, ca);
Ok(ac)
}
}
(_, ApplyOptions::ApplyFlat) => apply_multiple_flat(acs, self.function.as_ref()),
}
}
}
Expand Down Expand Up @@ -310,6 +305,33 @@ impl PhysicalExpr for ApplyExpr {
}
}

/// Calls `function` once on the flat series of all aggregation contexts.
///
/// Every context in `acs` contributes one owned series obtained through
/// `flat_naive()`. The UDF is invoked a single time with all of them, and the
/// length of its output is handed to `check_map_output_len` together with the
/// length of the first input. The output is then stored in the first context
/// of `acs`, which is returned.
///
/// # Errors
///
/// Propagates any error returned by `call_udf` or `check_map_output_len`.
///
/// # Panics
///
/// Panics when `acs` is empty, because the first element is indexed and
/// removed unconditionally.
fn apply_multiple_flat<'a>(
    mut acs: Vec<AggregationContext<'a>>,
    function: &dyn SeriesUdf,
) -> PolarsResult<AggregationContext<'a>> {
    let mut inputs = Vec::with_capacity(acs.len());
    for ac in acs.iter_mut() {
        // The series length information is about to be thrown away, so the
        // groups must be brought up to date first.
        if matches!(ac.update_groups, UpdateGroups::WithSeriesLen) {
            ac.groups();
        }
        inputs.push(ac.flat_naive().into_owned());
    }

    // Remember the length of the first input before the UDF consumes the buffer.
    let expected_len = inputs[0].len();
    let out = function.call_udf(&mut inputs)?;
    check_map_output_len(expected_len, out.len())?;

    // The first aggregation context belongs to the input series; reuse it to
    // carry the result.
    let mut first = acs.swap_remove(0);
    first.with_series(out, false);
    Ok(first)
}

#[cfg(feature = "parquet")]
impl StatsEvaluator for ApplyExpr {
fn should_read(&self, stats: &BatchStats) -> PolarsResult<bool> {
Expand Down
22 changes: 21 additions & 1 deletion py-polars/tests/unit/test_groupby.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from datetime import datetime
from datetime import datetime, timedelta

import pytest

import polars as pl

Expand Down Expand Up @@ -170,3 +172,21 @@ def test_groupby_dynamic_flat_agg_4814() -> None:
"last_ratio_1": [6.0, 6.0],
"last_ratio_2": [6.0, 6.0],
}


def test_groupby_dynamic_overlapping_groups_flat_apply_multiple_5038() -> None:
    """Regression test for issue 5038.

    Runs a dynamic groupby (``every="10s"``, ``period="100s"``) over ten rows,
    aggregates ``var().sqrt()`` of column ``b``, and checks the column sums of
    the collected result against a known value.
    """
    start = datetime(2021, 1, 1)
    frame = pl.DataFrame(
        {
            "a": [start + timedelta(seconds=2**i) for i in range(10)],
            "b": [float(i) for i in range(10)],
        }
    )

    aggregated = (
        frame.lazy()
        .groupby_dynamic("a", every="10s", period="100s")
        .agg([pl.col("b").var().sqrt().alias("corr")])
    )
    summed = aggregated.collect().sum().to_dict(False)

    expected = {"a": [None], "corr": [6.988674024215477]}
    assert summed == pytest.approx(expected)

0 comments on commit 8ecba44

Please sign in to comment.