Skip to content

Commit

Permalink
improve predicate pushdown (#3313)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 5, 2022
1 parent b97ba6f commit 3bc74b5
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 14 deletions.
24 changes: 23 additions & 1 deletion polars/polars-lazy/src/dsl/function_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
#[derive(Clone, PartialEq, Debug)]
pub enum FunctionExpr {
NullCount,
Pow(f64),
}

impl FunctionExpr {
Expand All @@ -16,9 +17,23 @@ impl FunctionExpr {
_cntxt: Context,
fields: &[Field],
) -> Result<Field> {
let with_dtype = |dtype: DataType| Ok(Field::new(fields[0].name(), dtype));
let map_dtype = |func: &dyn Fn(&DataType) -> DataType| {
let dtype = func(fields[0].data_type());
Ok(Field::new(fields[0].name(), dtype))
};

let float_dtype = || {
map_dtype(&|dtype| match dtype {
DataType::Float32 => DataType::Float32,
_ => DataType::Float64,
})
};

use FunctionExpr::*;
match self {
NullCount => Ok(Field::new(fields[0].name(), IDX_DTYPE)),
NullCount => with_dtype(IDX_DTYPE),
Pow(_) => float_dtype(),
}
}
}
Expand All @@ -40,6 +55,13 @@ impl From<FunctionExpr> for NoEq<Arc<dyn SeriesUdf>> {
};
wrap!(f)
}
Pow(exponent) => {
let f = move |s: &mut [Series]| {
let s = &s[0];
s.pow(exponent)
};
wrap!(f)
}
}
}
}
21 changes: 15 additions & 6 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,19 @@ impl Expr {
}
}

fn map_private(self, function_expr: FunctionExpr, fmt_str: &'static str) -> Self {
Expr::Function {
input: vec![self],
function: function_expr,
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: false,
auto_explode: false,
fmt_str,
},
}
}

/// Apply a function/closure once the logical plan get executed with many arguments
///
/// See the [`Expr::map`] function for the differences between [`map`](Expr::map) and [`apply`](Expr::apply).
Expand Down Expand Up @@ -631,7 +644,7 @@ impl Expr {
collect_groups: ApplyOptions::ApplyList,
input_wildcard_expansion: false,
auto_explode: false,
fmt_str: "",
fmt_str: "map_list",
},
}
}
Expand Down Expand Up @@ -1077,11 +1090,7 @@ impl Expr {

/// Raise expression to the power `exponent`
pub fn pow(self, exponent: f64) -> Self {
self.map(
move |s: Series| s.pow(exponent),
GetOutput::from_type(DataType::Float64),
)
.with_fmt("pow")
self.map_private(FunctionExpr::Pow(exponent), "pow")
}

/// Filter a single column
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ pub(super) fn other_column_is_pushdown_boundary(node: Node, expr_arena: &Arena<A
AExpr::Shift { .. } | AExpr::Sort { .. } | AExpr::SortBy { .. }
| AExpr::Agg(_) // an aggregation needs all rows
| AExpr::Reverse(_)
// everything that works on groups likely changes to order of elements w/r/t the other columns
// Apply groups can be something like shift, sort, or an aggregation like skew
// both need all values
| AExpr::AnonymousFunction {options: FunctionOptions { collect_groups: ApplyOptions::ApplyGroups, .. }, ..}
| AExpr::AnonymousFunction {options: FunctionOptions { collect_groups: ApplyOptions::ApplyList, .. }, ..}
| AExpr::BinaryExpr {..}
| AExpr::Cast {data_type: DataType::Float32 | DataType::Float64, ..}
// cast may create nulls
Expand Down Expand Up @@ -145,6 +145,7 @@ pub(super) fn predicate_column_is_pushdown_boundary(node: Node, expr_arena: &Are
| AExpr::Reverse(_)
// everything that works on groups likely changes to order of elements w/r/t the other columns
| AExpr::AnonymousFunction {..}
| AExpr::Function {..}
| AExpr::BinaryExpr {..}
// cast may change precision.
| AExpr::Cast {data_type: DataType::Float32 | DataType::Float64 | DataType::Utf8 | DataType::Boolean, ..}
Expand Down Expand Up @@ -175,22 +176,30 @@ pub(super) fn rewrite_projection_node(
where
{
let mut local_predicates = Vec::with_capacity(acc_predicates.len());
let input_schema = lp_arena.get(input).schema(lp_arena);

// maybe update predicate name if a projection is an alias
// aliases change the column names and because we push the predicates downwards
// this may be problematic as the aliased column may not yet exist.
for projection_node in &projections {
let projection_is_boundary =
// only if a predicate refers to this projection's output column.
let projection_maybe_boundary =
predicate_column_is_pushdown_boundary(*projection_node, expr_arena);

let projection_expr = expr_arena.get(*projection_node);
let output_field = projection_expr
.to_field(input_schema, Context::Default, expr_arena)
.unwrap();
let projection_roots = aexpr_to_root_names(*projection_node, expr_arena);

{
let projection_aexpr = expr_arena.get(*projection_node);
if let AExpr::Alias(_, name) = projection_aexpr {
// if this alias refers to one of the predicates in the upper nodes
// we rename the column of the predicate before we push it downwards.

if let Some(predicate) = acc_predicates.remove(&*name) {
if projection_is_boundary {
if projection_maybe_boundary {
local_predicates.push(predicate);
continue;
}
Expand All @@ -214,8 +223,6 @@ where
}
}

let input_schema = lp_arena.get(input).schema(lp_arena);

// we check if predicates can be done on the input above
// this can only be done if the current projection is not a projection boundary
let is_boundary = other_column_is_pushdown_boundary(*projection_node, expr_arena);
Expand All @@ -234,7 +241,7 @@ where
// checks 1.
if check_input_node(*predicate, input_schema, expr_arena)
// checks 2.
&& !(projection_roots.contains(name) && projection_is_boundary)
&& !(output_field.name().as_str() == &**name && projection_maybe_boundary)
// checks 3.
&& !is_boundary
{
Expand Down
19 changes: 19 additions & 0 deletions polars/polars-lazy/src/tests/predicate_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,22 @@ fn test_filter_null_creation_by_cast() -> Result<()> {

Ok(())
}

#[test]
fn test_predicate_pd_apply() -> Result<()> {
let q = df![
"a" => [1, 2, 3],
]?
.lazy()
.select([
// map_list is use in python `col().apply`
col("a"),
col("a")
.map_list(|s| Ok(s), GetOutput::same_type())
.alias("a_applied"),
])
.filter(col("a").lt(lit(3)));

assert!(predicate_at_scan(q.clone()));
Ok(())
}

0 comments on commit 3bc74b5

Please sign in to comment.