Skip to content

Commit

Permalink
fix groups state after apply (#2992)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Mar 28, 2022
1 parent c0f184b commit 5fabae0
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 18 deletions.
21 changes: 10 additions & 11 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ impl ApplyExpr {
&self,
mut ac: AggregationContext<'a>,
ca: ListChunked,
all_unit_len: bool,
) -> AggregationContext<'a> {
let all_unit_len = all_unit_length(&ca);
if all_unit_len && self.auto_explode {
ac.with_series(ca.explode().unwrap().into_series(), true);
ac.update_groups = UpdateGroups::No;
} else {
ac.with_series(ca.into_series(), true);
ac.with_update_groups(UpdateGroups::WithSeriesLen);
}
ac.with_all_unit_len(all_unit_len);
ac.with_update_groups(UpdateGroups::WithSeriesLen);
ac
}
}
Expand Down Expand Up @@ -101,13 +101,16 @@ impl PhysicalExpr for ApplyExpr {
})
.collect();

let all_unit_len = all_unit_length(&ca);

ca.rename(&name);
let ac = self.finish_apply_groups(ac, ca, all_unit_len);
let ac = self.finish_apply_groups(ac, ca);
Ok(ac)
}
ApplyOptions::ApplyFlat => {
// make sure the groups are updated because we are about to throw away
// the series length information
if let UpdateGroups::WithSeriesLen = ac.update_groups {
ac.groups();
}
let input = ac.flat_naive().into_owned();
let input_len = input.len();
let s = self.function.call_udf(&mut [input])?;
Expand All @@ -116,9 +119,6 @@ impl PhysicalExpr for ApplyExpr {
return Err(PolarsError::ComputeError("A map function may never return a Series of a different length than its input".into()));
}

if ac.is_aggregated() {
ac.with_update_groups(UpdateGroups::WithGroupsLen);
}
ac.with_series(s, false);
Ok(ac)
}
Expand Down Expand Up @@ -166,9 +166,8 @@ impl PhysicalExpr for ApplyExpr {
})
.collect_trusted();
ca.rename(&name);
let all_unit_len = all_unit_length(&ca);
let ac = acs.pop().unwrap();
let ac = self.finish_apply_groups(ac, ca, all_unit_len);
let ac = self.finish_apply_groups(ac, ca);
Ok(ac)
}
ApplyOptions::ApplyFlat => {
Expand Down
7 changes: 0 additions & 7 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,9 @@ pub struct AggregationContext<'a> {
/// This is true when the Series and GroupsProxy still have all
/// their original values. Not the case when filtered
original_len: bool,
all_unit_len: bool,
}

impl<'a> AggregationContext<'a> {
pub(crate) fn with_all_unit_len(&mut self, toggle: bool) {
self.all_unit_len = toggle
}

pub(crate) fn groups(&mut self) -> &Cow<'a, GroupsProxy> {
match self.update_groups {
UpdateGroups::No => {}
Expand Down Expand Up @@ -264,7 +259,6 @@ impl<'a> AggregationContext<'a> {
sorted: false,
update_groups: UpdateGroups::No,
original_len: true,
all_unit_len: false,
}
}

Expand All @@ -275,7 +269,6 @@ impl<'a> AggregationContext<'a> {
sorted: false,
update_groups: UpdateGroups::No,
original_len: true,
all_unit_len: false,
}
}

Expand Down
23 changes: 23 additions & 0 deletions polars/tests/it/lazy/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,26 @@ fn test_arange_agg() -> Result<()> {

Ok(())
}

/// Checks that group state stays consistent when an elementwise expression
/// (`log`) is applied on top of a per-group expression (`unique_counts`)
/// inside a stable group-by aggregation, and the result is then exploded.
#[test]
#[cfg(all(feature = "unique_counts", feature = "log"))]
fn test_groups_update() -> Result<()> {
    // Group "A" holds ids [1, 1, 2]; group "B" holds ids [3, 4, 3, 5].
    let frame = df![
        "group" => ["A", "A", "A", "B", "B", "B", "B"],
        "id" => [1, 1, 2, 3, 4, 3, 5]
    ]?;

    let exploded = frame
        .lazy()
        .groupby_stable([col("group")])
        .agg([col("id").unique_counts().log(2.0)])
        .explode([col("id")])
        .collect()?;

    // Materialise the exploded "id" column as plain f64 values for comparison.
    let ids: Vec<_> = exploded
        .column("id")?
        .f64()?
        .into_no_null_iter()
        .collect();

    // Five values are expected in total: two for group "A" followed by three
    // for group "B" (presumably log2 of the per-group unique counts).
    assert_eq!(ids, &[1.0, 0.0, 1.0, 0.0, 0.0]);
    Ok(())
}

0 comments on commit 5fabae0

Please sign in to comment.