Skip to content

Commit

Permalink
update groups in count() agg and correctly update state (#2963)
Browse files Browse the repository at this point in the history
* update groups in count() agg and correctly update state

* don't materialize groups for simple count expr
  • Loading branch information
ritchie46 committed Mar 24, 2022
1 parent a28dfa7 commit 69e7151
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 5 deletions.
61 changes: 58 additions & 3 deletions polars/polars-lazy/src/physical_plan/expressions/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use crate::physical_plan::PhysicalAggregation;
use crate::prelude::*;
use polars_arrow::export::arrow::{array::*, compute::concatenate::concatenate};
use polars_arrow::prelude::QuantileInterpolOptions;
use polars_arrow::utils::CustomIterTools;
use polars_core::frame::groupby::{GroupByMethod, GroupsProxy};
use polars_core::utils::NoNull;
use polars_core::{prelude::*, POOL};
use std::borrow::Cow;
use std::sync::Arc;
Expand Down Expand Up @@ -89,9 +91,62 @@ impl PhysicalAggregation for AggregationExpr {
Ok(rename_option_series(agg_s, &keep_name))
}
GroupByMethod::Count => {
let mut ca = ac.groups.group_count();
ca.rename(&keep_name);
Ok(Some(ca.into_series()))
// a few fast paths that prevent materializing new groups
match ac.update_groups {
UpdateGroups::WithSeriesLen => {
let list = ac
.series()
.list()
.expect("impl error, should be a list at this point");

let mut s = match list.chunks().len() {
1 => {
let arr = list.downcast_iter().next().unwrap();
let offsets = arr.offsets().as_slice();

let mut previous = 0i64;
let counts: NoNull<IdxCa> = offsets[1..]
.iter()
.map(|&o| {
let len = (o - previous) as IdxSize;
previous = o;
len
})
.collect_trusted();
counts.into_inner()
}
_ => {
let counts: NoNull<IdxCa> = list
.amortized_iter()
.map(|s| {
if let Some(s) = s {
s.as_ref().len() as IdxSize
} else {
1
}
})
.collect_trusted();
counts.into_inner()
}
};
s.rename(&keep_name);
Ok(Some(s.into_series()))
}
UpdateGroups::WithGroupsLen => {
// no need to update the groups
// we can just get the attribute, because we only need the length,
// not the correct order
let mut ca = ac.groups.group_count();
ca.rename(&keep_name);
Ok(Some(ca.into_series()))
}
// materialize groups
_ => {
let mut ca = ac.groups().group_count();
ca.rename(&keep_name);
Ok(Some(ca.into_series()))
}
}
}
GroupByMethod::First => {
let mut agg_s = ac.flat_naive().into_owned().agg_first(ac.groups());
Expand Down
11 changes: 9 additions & 2 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ impl PhysicalExpr for BinaryExpr {
opt_s
.map(|s| {
let r = s.as_ref();
// TODO: optimize this? Its slow and unsafe.
// TODO: optimize this?

// Safety:
// we are in bounds
Expand Down Expand Up @@ -221,7 +221,14 @@ impl PhysicalExpr for BinaryExpr {
ca.rename(l.name());

ac_l.with_series(ca.into_series(), true);
ac_l.with_update_groups(UpdateGroups::WithGroupsLen);
// Todo! maybe always update with groups len here?
if matches!(ac_l.update_groups, UpdateGroups::WithSeriesLen)
|| matches!(ac_r.update_groups, UpdateGroups::WithSeriesLen)
{
ac_l.with_update_groups(UpdateGroups::WithSeriesLen);
} else {
ac_l.with_update_groups(UpdateGroups::WithGroupsLen);
}
Ok(ac_l)
}
(AggState::AggregatedList(_), AggState::NotAggregated(_) | AggState::Literal(_), _)
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ pub(crate) enum UpdateGroups {
/// don't update groups
No,
/// use the length of the current groups to determine new sorted indexes, preferred
/// for performance
WithGroupsLen,
/// use the series list offsets to determine the new group lengths
/// this one should be used when the length has changed. Note that
Expand Down
30 changes: 30 additions & 0 deletions py-polars/tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,3 +1195,33 @@ class SubClassedLazyFrame(pl.LazyFrame):
extended_ldf = ldf.with_column(pl.lit(1).alias("column_2"))
assert isinstance(extended_ldf, pl.LazyFrame)
assert isinstance(extended_ldf, SubClassedLazyFrame)


def test_group_lengths() -> None:
    """Check per-group lengths after expressions that change the group size.

    Group "A" holds the ids 1, 1, 2 (two distinct values) and group "B" holds
    3, 4, 3, 5 (three distinct values). For each group, the unique counts
    divided by the group length must sum to 1.0, and the length of the unique
    ids must equal the number of distinct ids.
    """
    frame = pl.DataFrame(
        {
            "group": ["A", "A", "A", "B", "B", "B", "B"],
            "id": ["1", "1", "2", "3", "4", "3", "5"],
        }
    )

    # Fraction of the group taken by each distinct id, summed per group.
    unique_counts_sum = (
        (pl.col("id").unique_counts() / pl.col("id").len())
        .sum()
        .alias("unique_counts_sum")
    )
    # Number of distinct ids per group.
    unique_len = pl.col("id").unique().len().alias("unique_len")

    aggregated = frame.groupby(["group"], maintain_order=True).agg(
        [unique_counts_sum, unique_len]
    )

    expected = pl.DataFrame(
        {
            "group": ["A", "B"],
            "unique_counts_sum": [1.0, 1.0],
            "unique_len": [2, 3],
        }
    )

    assert aggregated.frame_equal(expected)

0 comments on commit 69e7151

Please sign in to comment.