Skip to content

Commit

Permalink
improve ternary in groupby context (#3248)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 28, 2022
1 parent 1b4a516 commit 4d02a46
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 13 deletions.
18 changes: 16 additions & 2 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,25 @@ impl PhysicalExpr for BinaryExpr {
Ok(ac_l)
}
// One of the two exprs is aggregated with flat aggregation, e.g. `e.min(), e.max(), e.first()`,
// and the other is a literal value. In that case it is unlikely that we want to expand
// the result to the group sizes.
(AggState::AggregatedFlat(_), AggState::Literal(_), _)
| (AggState::Literal(_), AggState::AggregatedFlat(_), _) => {
let l = ac_l.series();
let r = ac_r.series();
let mut s = apply_operator(l, r, self.op)?;
s.rename(l.name());

ac_l.with_series(s, true);
Ok(ac_l)
}
// One of the two exprs is aggregated with flat aggregation, e.g. `e.min(), e.max(), e.first()`

// if the groups_len == df.len we can just apply all flat.
// within an aggregation a `col().first() - lit(0)` must still produce a boolean array of group length,
// that's why a literal also takes this branch
(AggState::AggregatedFlat(s), AggState::NotAggregated(_) | AggState::Literal(_), _)
(AggState::AggregatedFlat(s), AggState::NotAggregated(_), _)
if s.len() != df.height() =>
{
// this is a flat series of len eq to group tuples
Expand Down Expand Up @@ -179,7 +193,7 @@ impl PhysicalExpr for BinaryExpr {
}
// if the groups_len == df.len we can just apply all flat.
(
AggState::Literal(_) | AggState::AggregatedList(_) | AggState::NotAggregated(_),
AggState::AggregatedList(_) | AggState::NotAggregated(_),
AggState::AggregatedFlat(s),
_,
) if s.len() != df.height() => {
Expand Down
29 changes: 20 additions & 9 deletions polars/polars-lazy/src/physical_plan/expressions/ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,6 @@ impl PhysicalExpr for TernaryExpr {

let mask_s = ac_mask.flat_naive();

assert!(
(mask_s.len() == required_height),
"The predicate is of a different length than the groups.\
The predicate produced {} values. Where the original DataFrame has {} values",
mask_s.len(),
required_height
);

assert!(
ac_truthy.can_combine(&ac_falsy),
"cannot combine this ternary expression, the groups do not match"
Expand Down Expand Up @@ -161,6 +153,24 @@ The predicate produced {} values. Where the original DataFrame has {} values",
ac_truthy.with_series(ca.into_series(), true);
Ok(ac_truthy)
}
// all aggregated or literal
// simply align lengths and zip
(
AggState::Literal(truthy) | AggState::AggregatedFlat(truthy),
AggState::AggregatedFlat(falsy) | AggState::Literal(falsy),
)
| (AggState::AggregatedList(truthy), AggState::AggregatedList(falsy))
if matches!(ac_mask.agg_state(), AggState::AggregatedFlat(_)) =>
{
let mut truthy = truthy.clone();
let mut falsy = falsy.clone();
let mut mask = ac_mask.series().bool()?.clone();
expand_lengths(&mut truthy, &mut falsy, &mut mask);
let mut out = truthy.zip_with(&mask, &falsy).unwrap();
out.rename(truthy.name());
ac_truthy.with_series(out, true);
Ok(ac_truthy)
}
// if the groups_len == df.len we can just apply all flat.
(AggState::NotAggregated(_) | AggState::Literal(_), AggState::AggregatedFlat(s))
if s.len() != df.height() =>
Expand Down Expand Up @@ -220,7 +230,8 @@ The predicate produced {} values. Where the original DataFrame has {} values",
ac_truthy.with_series(ca.into_series(), true);
Ok(ac_truthy)
}
// Both are or a flat series or aggreagated into a list

// Both are either a flat series or aggregated into a list,
// so we can flatten the Series and apply the operators
_ => {
let mask = mask_s.bool()?;
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/tests/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ fn test_binary_agg_context_0() -> Result<()> {
.lazy()
.groupby_stable([col("groups")])
.agg([when(col("vals").first().neq(lit(1)))
.then(lit("a"))
.otherwise(lit("b"))
.then(repeat("a", count()))
.otherwise(repeat("b", count()))
.alias("foo")])
.collect()
.unwrap();
Expand Down
26 changes: 26 additions & 0 deletions polars/tests/it/lazy/expressions/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,29 @@ fn test_when_then_otherwise_cats() -> Result<()> {

Ok(())
}

/// Checks a `when/then/otherwise` inside a groupby aggregation whose predicate
/// is built from an aggregate (`null_count() > 0`).
///
/// The frame is grouped on `key`. When the predicate holds the expression takes
/// the null literal, otherwise it takes `val.sum()`. The result is compared
/// against a hand-written expected frame.
#[test]
fn test_when_then_otherwise_single_bool() -> Result<()> {
    let input = df![
        "key" => ["a", "b", "b"],
        "val" => [Some(1), Some(2), None]
    ]?;

    // Predicate: the null count of `val` is greater than zero.
    let group_has_null = col("val").null_count().gt(lit(0));

    // Null literal when the predicate holds, otherwise the sum of `val`.
    let sum_null_prop = when(group_has_null)
        .then(Null {}.lit())
        .otherwise(col("val").sum())
        .alias("sum_null_prop");

    let result = input
        .lazy()
        .groupby_stable([col("key")])
        .agg([sum_null_prop])
        .collect()?;

    // Key "a" holds only the value 1; key "b" holds 2 and a null.
    let expected = df![
        "key" => ["a", "b"],
        "sum_null_prop" => [Some(1), None]
    ]?;

    assert!(result.frame_equal_missing(&expected));

    Ok(())
}

0 comments on commit 4d02a46

Please sign in to comment.