Skip to content

Commit

Permalink
ternary literal predicates (#3747)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 20, 2022
1 parent 3fc26cb commit 969ff8f
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 23 deletions.
15 changes: 10 additions & 5 deletions polars/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,11 +540,6 @@ impl LazyFrame {
let opt = StackOptimizer {};
let mut rules: Vec<Box<dyn OptimizationRule>> = Vec::with_capacity(8);

if simplify_expr {
rules.push(Box::new(SimplifyExprRule {}));
rules.push(Box::new(SimplifyBooleanRule {}));
}

// during debug we check if the optimizations have not modified the final schema
#[cfg(debug_assertions)]
let prev_schema = logical_plan.schema().clone();
Expand All @@ -556,6 +551,11 @@ impl LazyFrame {
// this optimization will run twice because optimizer may create dumb expressions
lp_top = opt.optimize_loop(&mut rules, expr_arena, lp_arena, lp_top);

// we do simplification
if simplify_expr {
rules.push(Box::new(SimplifyExprRule {}));
}

if projection_pushdown {
let projection_pushdown_opt = ProjectionPushDown {};
let alp = lp_arena.take(lp_top);
Expand Down Expand Up @@ -593,6 +593,11 @@ impl LazyFrame {
if type_coercion {
rules.push(Box::new(TypeCoercionRule {}))
}
// this optimization removes branches, so we must do it when type coercion
// is completed
if simplify_expr {
rules.push(Box::new(SimplifyBooleanRule {}));
}

if aggregate_pushdown {
rules.push(Box::new(AggregatePushdown::new()))
Expand Down
26 changes: 25 additions & 1 deletion polars/polars-lazy/src/logical_plan/optimizer/simplify_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,31 @@ impl OptimizationRule for SimplifyBooleanRule {
{
Some(AExpr::Literal(LiteralValue::Boolean(false)))
}

AExpr::Ternary {
truthy, predicate, ..
} if matches!(
expr_arena.get(*predicate),
AExpr::Literal(LiteralValue::Boolean(true))
) =>
{
Some(expr_arena.get(*truthy).clone())
}
AExpr::Ternary {
truthy,
falsy,
predicate,
} if matches!(
expr_arena.get(*predicate),
AExpr::Literal(LiteralValue::Boolean(false))
) =>
{
let names = aexpr_to_root_names(*truthy, expr_arena);
if names.is_empty() {
None
} else {
Some(AExpr::Alias(*falsy, names[0].clone()))
}
}
AExpr::Not(x) => {
let y = expr_arena.get(*x);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,20 @@ fn run_partitions(
let agg_expr = expr.as_partitioned_aggregator().unwrap();
let agg = agg_expr.evaluate_partitioned(&df, groups, state)?;
if agg.len() != groups.len() {
Err(PolarsError::ComputeError(
format!("returned aggregation is a different length: {} than the group lengths: {}",
agg.len(),
groups.len()).into()
))

if agg.len() == 1 {
Ok(match groups.len() {
0 => agg.slice(0, 0),
len => agg.expand_at_index(0, len)
})
} else {
Err(PolarsError::ComputeError(
format!("returned aggregation is a different length: {} than the group lengths: {}",
agg.len(),
groups.len()).into()
))
}

} else {
Ok(agg)
}
Expand Down
16 changes: 5 additions & 11 deletions polars/polars-lazy/src/physical_plan/expressions/ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ impl PhysicalExpr for TernaryExpr {
groups: &'a GroupsProxy,
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let required_height = df.height();

let op_mask = || self.predicate.evaluate_on_groups(df, groups, state);
let op_truthy = || self.truthy.evaluate_on_groups(df, groups, state);
let op_falsy = || self.falsy.evaluate_on_groups(df, groups, state);
Expand Down Expand Up @@ -165,15 +163,11 @@ impl PhysicalExpr for TernaryExpr {
// Both are either a flat series or aggregated into a list,
// so we can flatten the Series and apply the operators
_ => {
let mask = mask_s.bool()?;
let out = ac_truthy
.flat_naive()
.zip_with(mask, ac_falsy.flat_naive().as_ref())?;

assert!((out.len() == required_height), "The output of the `when -> then -> otherwise-expr` is of a different length than the groups.\
The expr produced {} values. Where the original DataFrame has {} values",
out.len(),
required_height);
let mut mask = mask_s.bool()?.clone();
let mut truthy = ac_truthy.flat_naive().into_owned();
let mut falsy = ac_falsy.flat_naive().into_owned();
expand_lengths(&mut truthy, &mut falsy, &mut mask);
let out = truthy.zip_with(&mask, &falsy)?;

ac_truthy.with_series(out, false);

Expand Down
47 changes: 47 additions & 0 deletions polars/polars-lazy/src/tests/optimization_checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,50 @@ fn test_with_row_count_opts() -> Result<()> {

Ok(())
}

/// Checks that a `when(lit(..)).then(..).otherwise(..)` aggregation with a
/// literal boolean predicate is no longer a ternary expression in the
/// optimized plan, and that the collected result is still correct:
/// `true` keeps the `then` branch, `false` yields the all-null `otherwise` branch.
#[test]
fn test_groupby_ternary_literal_predicate() -> Result<()> {
    let df = df![
        "a" => [1, 2, 3],
        "b" => [1, 2, 3]
    ]?;

    for predicate in [true, false] {
        let q = df
            .clone()
            .lazy()
            .groupby(["a"])
            .agg([when(lit(predicate))
                .then(col("b").sum())
                .otherwise(NULL.lit())])
            .sort("a", Default::default());

        let (mut expr_arena, mut lp_arena) = get_arenas();
        // `q` is cloned here because it is collected further down.
        let lp = q.clone().optimize(&mut lp_arena, &mut expr_arena).unwrap();

        // Visit every node of the optimized plan; a plain loop is used because
        // the traversal only exists to run the assertions below.
        for (_, plan) in (&lp_arena).iter(lp) {
            if let ALogicalPlan::Aggregate { aggs, .. } = plan {
                for node in aggs {
                    // we should not have a ternary expression anymore
                    assert!(!matches!(expr_arena.get(*node), AExpr::Ternary { .. }));
                }
            }
        }

        let out = q.collect()?;
        let b = out.column("b")?;
        let b = b.i32()?;
        if predicate {
            // the `then` branch (sum of "b" per group) is kept
            assert_eq!(Vec::from(b), &[Some(1), Some(2), Some(3)]);
        } else {
            // the `otherwise` branch is a null literal for every group
            assert_eq!(b.null_count(), 3);
        }
    }

    Ok(())
}
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1958,7 +1958,7 @@ fn test_is_in() -> Result<()> {
}

#[test]
fn test_partitioned_gb() -> Result<()> {
fn test_partitioned_gb_1() -> Result<()> {
// don't move these to integration tests
// keep these dtypes
let out = df![
Expand Down

0 comments on commit 969ff8f

Please sign in to comment.