Skip to content

Commit

Permalink
fix when then with literal (#3009)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Mar 30, 2022
1 parent 5387319 commit 7adf768
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 11 deletions.
10 changes: 10 additions & 0 deletions polars/polars-core/src/chunked_array/bitwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,16 @@ impl BitAnd for &BooleanChunked {
.downcast_iter()
.zip(rhs.downcast_iter())
.map(|(lhs, rhs)| {
// early return from all `false` paths
if lhs.null_count() == 0 && rhs.null_count() == 0 {
if lhs.values().null_count() == lhs.len() {
return Arc::new(lhs.clone()) as ArrayRef;
}
if rhs.values().null_count() == rhs.len() {
return Arc::new(rhs.clone()) as ArrayRef;
}
}

Arc::new(compute::boolean_kleene::and(lhs, rhs).expect("should be same size"))
as ArrayRef
})
Expand Down
54 changes: 46 additions & 8 deletions polars/polars-lazy/src/physical_plan/expressions/ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,41 @@ use std::convert::TryFrom;
use std::sync::Arc;

pub struct TernaryExpr {
pub predicate: Arc<dyn PhysicalExpr>,
pub truthy: Arc<dyn PhysicalExpr>,
pub falsy: Arc<dyn PhysicalExpr>,
pub expr: Expr,
predicate: Arc<dyn PhysicalExpr>,
truthy: Arc<dyn PhysicalExpr>,
falsy: Arc<dyn PhysicalExpr>,
expr: Expr,
}

impl TernaryExpr {
pub fn new(
predicate: Arc<dyn PhysicalExpr>,
truthy: Arc<dyn PhysicalExpr>,
falsy: Arc<dyn PhysicalExpr>,
expr: Expr,
) -> Self {
Self {
predicate,
truthy,
falsy,
expr,
}
}
}

fn expand_lengths(truthy: &mut Series, falsy: &mut Series, mask: &mut BooleanChunked) {
let len = std::cmp::max(std::cmp::max(truthy.len(), falsy.len()), mask.len());
if len > 1 {
if falsy.len() == 1 {
*falsy = falsy.expand_at_index(0, len);
}
if truthy.len() == 1 {
*truthy = truthy.expand_at_index(0, len);
}
if mask.len() == 1 {
*mask = mask.expand_at_index(0, len);
}
}
}

impl PhysicalExpr for TernaryExpr {
Expand All @@ -20,10 +51,17 @@ impl PhysicalExpr for TernaryExpr {
}
fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> Result<Series> {
let mask_series = self.predicate.evaluate(df, state)?;
let mask = mask_series.bool()?;
let truthy = self.truthy.evaluate(df, state)?;
let falsy = self.falsy.evaluate(df, state)?;
truthy.zip_with(mask, &falsy)
let mut mask = mask_series.bool()?.clone();

let op_truthy = || self.truthy.evaluate(df, state);
let op_falsy = || self.falsy.evaluate(df, state);

let (truthy, falsy) = POOL.install(|| rayon::join(op_truthy, op_falsy));
let mut truthy = truthy?;
let mut falsy = falsy?;
expand_lengths(&mut truthy, &mut falsy, &mut mask);

truthy.zip_with(&mask, &falsy)
}
fn to_field(&self, input_schema: &Schema) -> Result<Field> {
self.truthy.to_field(input_schema)
Expand Down
6 changes: 3 additions & 3 deletions polars/polars-lazy/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -960,12 +960,12 @@ impl DefaultPlanner {
let predicate = self.create_physical_expr(predicate, ctxt, expr_arena)?;
let truthy = self.create_physical_expr(truthy, ctxt, expr_arena)?;
let falsy = self.create_physical_expr(falsy, ctxt, expr_arena)?;
Ok(Arc::new(TernaryExpr {
Ok(Arc::new(TernaryExpr::new(
predicate,
truthy,
falsy,
expr: node_to_expr(expression, expr_arena),
}))
node_to_expr(expression, expr_arena),
)))
}
Function {
input,
Expand Down
28 changes: 28 additions & 0 deletions polars/tests/it/lazy/expressions/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,31 @@ fn test_list_broadcast() {
.collect()
.unwrap();
}

#[test]
fn ternary_expand_sizes() -> Result<()> {
let df = df! {
"a" => [Some("a1"), None, None],
"b" => [Some("b1"), Some("b2"), None]
}?;
let out = df
.lazy()
.with_column(
when(not(lit(true)))
.then(lit("unexpected"))
.when(not(col("a").is_null()))
.then(col("a"))
.when(not(col("b").is_null()))
.then(col("b"))
.otherwise(lit("otherwise"))
.alias("c"),
)
.collect()?;
let vals = out
.column("c")?
.utf8()?
.into_no_null_iter()
.collect::<Vec<_>>();
assert_eq!(vals, &["a1", "b2", "otherwise"]);
Ok(())
}

0 comments on commit 7adf768

Please sign in to comment.