align groups in binary when they do not align (#3033)
ritchie46 committed Apr 1, 2022
1 parent 3e7631f commit 15a3cdf
Showing 5 changed files with 104 additions and 70 deletions.
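
For context, the scenario this commit addresses can be reproduced with a query like the one in the test added below (test_groups_update_binary_shift_log). The following sketch is illustrative and not part of this diff; it assumes a polars build with the lazy and log features enabled.

use polars::prelude::*;

fn main() -> Result<()> {
    // `col("a").shift(1).log(2.0)` changes the group state of the right-hand
    // side, so the subtraction used to be applied on misaligned group states.
    let out = df![
        "a" => [1, 2, 3, 5],
        "b" => [1, 2, 1, 2],
    ]?
    .lazy()
    .groupby([col("b")])
    .agg([col("a") - col("a").shift(1).log(2.0)])
    .sort("b", Default::default())
    .explode([col("a")])
    .collect()?;

    // With aligned groups the exploded column is [null, 3.0, null, 4.0].
    println!("{}", out);
    Ok(())
}
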
2 changes: 1 addition & 1 deletion polars/polars-core/src/frame/groupby/mod.rs
@@ -951,7 +951,7 @@ impl<'df> GroupBy<'df> {
}
}

#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
pub enum GroupByMethod {
Min,
Max,
@@ -268,18 +268,6 @@ impl ProjectionPushDown {
}
}
}
// if has_aexpr(*e, expr_arena, |ae| matches!(ae, AExpr::Alias(_, _))) {}
//
// if let AExpr::Alias(_, name) = expr_arena.get(*e) {
// if projected_names.remove(name) {
// acc_projections = acc_projections
// .into_iter()
// .filter(|expr| {
// !aexpr_to_root_names(*expr, expr_arena).contains(name)
// })
// .collect();
// }
// }
}

add_expr_to_accumulated(
28 changes: 21 additions & 7 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
@@ -263,14 +263,28 @@ impl PhysicalExpr for BinaryExpr {
// Both are or a flat series
// so we can flatten the Series and apply the operators
_ => {
let out = apply_operator(
ac_l.flat_naive().as_ref(),
ac_r.flat_naive().as_ref(),
self.op,
)?;
// the groups state differs, so we aggregate both and flatten again to make them align
if ac_l.update_groups != UpdateGroups::No || ac_r.update_groups != UpdateGroups::No
{
// use the aggregated state to determine the new groups
let lhs = ac_l.aggregated();
ac_l.with_update_groups(UpdateGroups::WithSeriesLenOwned(lhs.clone()));

let out =
apply_operator(&lhs.explode()?, &ac_r.aggregated().explode()?, self.op)?;
ac_l.with_series(out, false);
Ok(ac_l)
} else {
let out = apply_operator(
ac_l.flat_naive().as_ref(),
ac_r.flat_naive().as_ref(),
self.op,
)?;

ac_l.combine_groups(ac_r).with_series(out, false);
Ok(ac_l)
ac_l.combine_groups(ac_r).with_series(out, false);

Ok(ac_l)
}
}
}
}
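
The new branch above aggregates both sides per group and explodes them, so that element i on the left lines up with element i on the right before the operator is applied. Below is a self-contained model of that idea, with plain vectors standing in for the aggregated list Series (names are illustrative, not Polars internals) and the data from the test added in this commit:

// Conceptual sketch only: per-group values modelled as Vec<Option<f64>>.
fn explode(groups: &[Vec<Option<f64>>]) -> Vec<Option<f64>> {
    // Lay the groups out back to back, group-major, like Series::explode.
    groups.iter().flat_map(|g| g.iter().copied()).collect()
}

fn main() {
    // Groups for b == 1 and b == 2: col("a") aggregated per group ...
    let lhs = vec![vec![Some(1.0), Some(3.0)], vec![Some(2.0), Some(5.0)]];
    // ... and col("a").shift(1).log(2.0) per group: a leading null, then log2.
    let rhs = vec![vec![None, Some(0.0)], vec![None, Some(1.0)]];

    // Once both sides are exploded, the subtraction pairs values that belong
    // to the same group and the same row.
    let out: Vec<Option<f64>> = explode(&lhs)
        .into_iter()
        .zip(explode(&rhs))
        .map(|(l, r)| match (l, r) {
            (Some(l), Some(r)) => Some(l - r),
            _ => None,
        })
        .collect();

    assert_eq!(out, vec![None, Some(3.0), None, Some(4.0)]);
}
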
111 changes: 61 additions & 50 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
@@ -60,6 +60,7 @@ impl AggState {

// lazy update strategy
#[cfg_attr(debug_assertions, derive(Debug))]
#[derive(PartialEq)]
pub(crate) enum UpdateGroups {
/// don't update groups
No,
@@ -70,6 +71,8 @@ pub(crate) enum UpdateGroups {
/// this one should be used when the length has changed. Note that
/// the series should be aggregated state or else it will panic.
WithSeriesLen,
// Same as WithSeriesLen, but now take a series given by the caller
WithSeriesLenOwned(Series),
}

#[cfg_attr(debug_assertions, derive(Debug))]
@@ -124,56 +127,12 @@ impl<'a> AggregationContext<'a> {
self.update_groups = UpdateGroups::No;
}
UpdateGroups::WithSeriesLen => {
let mut offset = 0 as IdxSize;
let list = self
.series()
.list()
.expect("impl error, should be a list at this point");

match list.chunks().len() {
1 => {
let arr = list.downcast_iter().next().unwrap();
let offsets = arr.offsets().as_slice();

let mut previous = 0i64;
let groups = offsets[1..]
.iter()
.map(|&o| {
let len = (o - previous) as IdxSize;
// explode will fill empty rows with null, so we must increment the group
// offset accordingly
let new_offset = offset + len + (len == 0) as IdxSize;

previous = o;
let out = [offset, len];
offset = new_offset;
out
})
.collect_trusted();
self.groups = Cow::Owned(GroupsProxy::Slice(groups));
}
_ => {
let groups = self
.series()
.list()
.expect("impl error, should be a list at this point")
.amortized_iter()
.map(|s| {
if let Some(s) = s {
let len = s.as_ref().len() as IdxSize;
let new_offset = offset + len;
let out = [offset, len];
offset = new_offset;
out
} else {
[offset, 0]
}
})
.collect_trusted();
self.groups = Cow::Owned(GroupsProxy::Slice(groups));
}
}
self.update_groups = UpdateGroups::No;
let s = self.series().clone();
self.det_groups_from_list(&s);
}
UpdateGroups::WithSeriesLenOwned(ref s) => {
let s = s.clone();
self.det_groups_from_list(&s);
}
}
&self.groups
@@ -289,6 +248,58 @@ impl<'a> AggregationContext<'a> {
self
}

pub(crate) fn det_groups_from_list(&mut self, s: &Series) {
let mut offset = 0 as IdxSize;
let list = s
.list()
.expect("impl error, should be a list at this point");

match list.chunks().len() {
1 => {
let arr = list.downcast_iter().next().unwrap();
let offsets = arr.offsets().as_slice();

let mut previous = 0i64;
let groups = offsets[1..]
.iter()
.map(|&o| {
let len = (o - previous) as IdxSize;
// explode will fill empty rows with null, so we must increment the group
// offset accordingly
let new_offset = offset + len + (len == 0) as IdxSize;

previous = o;
let out = [offset, len];
offset = new_offset;
out
})
.collect_trusted();
self.groups = Cow::Owned(GroupsProxy::Slice(groups));
}
_ => {
let groups = self
.series()
.list()
.expect("impl error, should be a list at this point")
.amortized_iter()
.map(|s| {
if let Some(s) = s {
let len = s.as_ref().len() as IdxSize;
let new_offset = offset + len;
let out = [offset, len];
offset = new_offset;
out
} else {
[offset, 0]
}
})
.collect_trusted();
self.groups = Cow::Owned(GroupsProxy::Slice(groups));
}
}
self.update_groups = UpdateGroups::No;
}

/// In a binary expression one state can be aggregated and the other not.
/// If both would be flattened naively one would be sorted and the other not.
/// Calling this function will ensure both are sorted. This will be a no-op
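
The new det_groups_from_list helper above rebuilds slice groups ([offset, len] pairs) from the lengths of a list Series; the subtle part is that an empty group still advances the offset by one, because explode fills empty rows with a null. A minimal, self-contained sketch of that offset arithmetic follows, with plain slices instead of Polars list arrays and IdxSize defined locally (an assumption; it is u32 in default builds):

type IdxSize = u32; // assumed default (non-bigidx) index size

/// Rebuild [offset, len] slice groups from per-group lengths, mirroring the
/// offset logic of the single-chunk branch above.
fn groups_from_lengths(lengths: &[IdxSize]) -> Vec<[IdxSize; 2]> {
    let mut offset = 0 as IdxSize;
    lengths
        .iter()
        .map(|&len| {
            let out = [offset, len];
            // An empty group still occupies one (null) row after explode,
            // so it advances the offset by one instead of zero.
            offset += len + (len == 0) as IdxSize;
            out
        })
        .collect()
}

fn main() {
    // Groups of length 2, 0 and 3: the empty group keeps a one-row slot,
    // so the third group starts at offset 3, not 2.
    assert_eq!(groups_from_lengths(&[2, 0, 3]), vec![[0, 2], [2, 0], [3, 3]]);
}
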
21 changes: 21 additions & 0 deletions polars/tests/it/lazy/expressions/apply.rs
@@ -41,3 +41,24 @@ fn test_groups_update() -> Result<()> {
);
Ok(())
}

#[test]
#[cfg(feature = "log")]
fn test_groups_update_binary_shift_log() -> Result<()> {
let out = df![
"a" => [1, 2, 3, 5],
"b" => [1, 2, 1, 2],
]?
.lazy()
.groupby([col("b")])
.agg([col("a") - col("a").shift(1).log(2.0)])
.sort("b", Default::default())
.explode([col("a")])
.collect()?;
assert_eq!(
Vec::from(out.column("a")?.f64()?),
&[None, Some(3.0), None, Some(4.0)]
);

Ok(())
}
