Skip to content

Commit

Permalink
allow final aggregations in ternary expression (#1876)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 24, 2021
1 parent a8d0dd2 commit 4e04d02
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 21 deletions.
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ impl PhysicalExpr for BinaryExpr {
ac_l.with_series(ca.into_series(), true);
Ok(ac_l)
}
// Both are or a flat series or aggreagated into a list
// so we can flatten the Series an apply the operators
// Both are or a flat series or aggregated into a list
// so we can flatten the Series and apply the operators
_ => {
let out = apply_operator(ac_l.flat().as_ref(), ac_r.flat().as_ref(), self.op)?;
ac_l.combine_groups(ac_r).with_series(out, false);
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ impl<'a> AggregationContext<'a> {
}

pub(crate) fn aggregated(&mut self) -> Cow<'_, Series> {
// we do this here instead of the patter match because of mutable borrow overlaps.
// we do this here instead of the pattern match because of mutable borrow overlaps.
//
// The groups are determined lazily and in case of a flat/non-aggregated
// series we use the groups to aggregate the list
Expand Down
115 changes: 108 additions & 7 deletions polars/polars-lazy/src/physical_plan/expressions/ternary.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::physical_plan::state::ExecutionState;
use crate::prelude::*;
use polars_arrow::arrow::array::ArrayRef;
use polars_core::frame::groupby::GroupTuples;
use polars_core::prelude::*;
use std::convert::TryFrom;
use std::sync::Arc;

pub struct TernaryExpr {
Expand Down Expand Up @@ -47,7 +49,7 @@ The predicate produced {} values. Where the original DataFrame has {} values",

let mask = mask_s.bool()?;
let mut ac_truthy = self.truthy.evaluate_on_groups(df, groups, state)?;
let ac_falsy = self.falsy.evaluate_on_groups(df, groups, state)?;
let mut ac_falsy = self.falsy.evaluate_on_groups(df, groups, state)?;

if !ac_truthy.can_combine(&ac_falsy) {
return Err(PolarsError::InvalidOperation(
Expand All @@ -57,15 +59,114 @@ The predicate produced {} values. Where the original DataFrame has {} values",
));
}

let out = ac_truthy.flat().zip_with(mask, ac_falsy.flat().as_ref())?;
match (ac_truthy.agg_state(), ac_falsy.agg_state()) {
(AggState::AggregatedFlat(_), AggState::NotAggregated(_)) => {
// this is a flat series of len eq to group tuples
let truthy = ac_truthy.aggregated();
let truthy = truthy.as_ref();
let arr_truthy = &truthy.chunks()[0];
assert_eq!(truthy.len(), groups.len());

assert!(!(out.len() != required_height), "The output of the `when -> then -> otherwise-expr` is of a different length than the groups.\
// we create a dummy Series that is not cloned nor moved
// so we can swap the ArrayRef during the hot loop
// this prevents a series Arc alloc and a vec alloc per iteration
let dummy = Series::try_from(("dummy", vec![arr_truthy.clone()])).unwrap();
let chunks = unsafe {
let chunks = dummy.chunks();
let ptr = chunks.as_ptr() as *mut ArrayRef;
let len = chunks.len();
std::slice::from_raw_parts_mut(ptr, len)
};

// this is now a list
let falsy = ac_falsy.aggregated();
let falsy = falsy.list().unwrap();

let mut ca: ListChunked = falsy
.amortized_iter()
.enumerate()
.map(|(idx, opt_s)| {
opt_s
.map(|s| {
let falsy = s.as_ref();

// Safety:
// we are in bounds
let mut arr =
unsafe { Arc::from(arr_truthy.slice_unchecked(idx, 1)) };
std::mem::swap(&mut chunks[0], &mut arr);
let truthy = &dummy;

truthy.zip_with(mask, falsy)
})
.transpose()
})
.collect::<Result<_>>()?;
ca.rename(truthy.name());

ac_truthy.with_series(ca.into_series(), true);
Ok(ac_truthy)
}
(AggState::NotAggregated(_), AggState::AggregatedFlat(_)) => {
// this is now a list
let truthy = ac_truthy.aggregated();
let truthy = truthy.list().unwrap();

// this is a flat series of len eq to group tuples
let falsy = ac_falsy.aggregated();
assert_eq!(falsy.len(), groups.len());
let falsy = falsy.as_ref();
let arr_falsy = &falsy.chunks()[0];

// we create a dummy Series that is not cloned nor moved
// so we can swap the ArrayRef during the hot loop
// this prevents a series Arc alloc and a vec alloc per iteration
let dummy = Series::try_from(("dummy", vec![arr_falsy.clone()])).unwrap();
let chunks = unsafe {
let chunks = dummy.chunks();
let ptr = chunks.as_ptr() as *mut ArrayRef;
let len = chunks.len();
std::slice::from_raw_parts_mut(ptr, len)
};

let mut ca: ListChunked = truthy
.amortized_iter()
.enumerate()
.map(|(idx, opt_s)| {
opt_s
.map(|s| {
let truthy = s.as_ref();
// Safety:
// we are in bounds
let mut arr =
unsafe { Arc::from(arr_falsy.slice_unchecked(idx, 1)) };
std::mem::swap(&mut chunks[0], &mut arr);
let falsy = &dummy;

truthy.zip_with(mask, falsy)
})
.transpose()
})
.collect::<Result<_>>()?;
ca.rename(truthy.name());

ac_truthy.with_series(ca.into_series(), true);
Ok(ac_truthy)
}
// Both are either a flat series or aggregated into a list,
// so we can flatten the Series and apply the operators
_ => {
let out = ac_truthy.flat().zip_with(mask, ac_falsy.flat().as_ref())?;

assert!(!(out.len() != required_height), "The output of the `when -> then -> otherwise-expr` is of a different length than the groups.\
The expr produced {} values. Where the original DataFrame has {} values",
out.len(),
required_height);
out.len(),
required_height);

ac_truthy.with_series(out, false);
ac_truthy.with_series(out, false);

Ok(ac_truthy)
Ok(ac_truthy)
}
}
}
}
57 changes: 46 additions & 11 deletions polars/polars-lazy/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2245,7 +2245,7 @@ fn test_literal_window_fn() -> Result<()> {
}

#[test]
fn test_binary_agg_context_1() -> Result<()> {
fn test_binary_agg_context_0() -> Result<()> {
let df = df![
"groups" => [1, 1, 2, 2, 3, 3],
"vals" => [1, 2, 3, 4, 5, 6]
Expand Down Expand Up @@ -2279,23 +2279,58 @@ fn test_binary_agg_context_1() -> Result<()> {
Ok(())
}

// just like binary expression, this must ben changed. This can work
// just like binary expression, this must be changed. This can work
#[test]
#[should_panic]
fn test_invalid_ternary_in_agg2() {
fn test_binary_agg_context_1() -> Result<()> {
let df = df![
"groups" => [1, 1, 2, 2, 3, 3],
"vals" => [1, 2, 3, 4, 5, 6]
]
.unwrap();
]?;

df.lazy()
.groupby([col("groups")])
let out = df
.clone()
.lazy()
.stable_groupby([col("groups")])
.agg([when(col("vals").neq(lit(1)))
.then(col("vals").first())
.otherwise(lit("b"))])
.collect()
.unwrap();
.otherwise(lit(90))
.alias("vals")])
.collect()?;

// [90, 1]
// [90, 3]
// [90, 5]
let out = out.column("vals")?;
let out = out.explode()?;
let out = out.i32()?;
assert_eq!(
Vec::from(out),
&[Some(90), Some(1), Some(90), Some(3), Some(90), Some(5)]
);

let out = df
.lazy()
.stable_groupby([col("groups")])
.agg([when(col("vals").neq(lit(1)))
.then(lit(90))
.otherwise(col("vals").first())
.alias("vals")])
.collect()?;

dbg!(&out);

// [1, 90]
// [3, 90]
// [5, 90]
let out = out.column("vals")?;
let out = out.explode()?;
let out = out.i32()?;
assert_eq!(
Vec::from(out),
&[Some(1), Some(90), Some(3), Some(90), Some(5), Some(90)]
);

Ok(())
}

#[test]
Expand Down

0 comments on commit 4e04d02

Please sign in to comment.