Skip to content

Commit

Permalink
fix: Resolve function names and prune all aliases. (#15522)
Browse files Browse the repository at this point in the history
Co-authored-by: ritchie <ritchie46@gmail.com>
  • Loading branch information
reswqa and ritchie46 committed Apr 8, 2024
1 parent c211fad commit dcee934
Show file tree
Hide file tree
Showing 18 changed files with 349 additions and 171 deletions.
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/bitwise.rs
Expand Up @@ -161,7 +161,7 @@ impl BitAnd for &BooleanChunked {
(1, 1) => {},
(1, _) => {
return match self.get(0) {
Some(true) => rhs.clone(),
Some(true) => rhs.clone().with_name(self.name()),
Some(false) => BooleanChunked::full(self.name(), false, rhs.len()),
None => &self.new_from_index(0, rhs.len()) & rhs,
};
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/physical_plan/planner/expr.rs
Expand Up @@ -533,7 +533,7 @@ fn create_physical_expr_inner(
options.returns_scalar && matches!(options.collect_groups, ApplyOptions::GroupWise);
// will be reset in the function so get that here
let has_window = state.local.has_window;
let input = create_physical_expressions_from_nodes_check_state(
let input = create_physical_expressions_check_state(
&input,
ctxt,
expr_arena,
Expand Down Expand Up @@ -564,7 +564,7 @@ fn create_physical_expr_inner(
options.returns_scalar && matches!(options.collect_groups, ApplyOptions::GroupWise);
// will be reset in the function so get that here
let has_window = state.local.has_window;
let input = create_physical_expressions_from_nodes_check_state(
let input = create_physical_expressions_check_state(
&input,
ctxt,
expr_arena,
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/physical_plan/planner/lp.rs
Expand Up @@ -81,7 +81,7 @@ fn partitionable_gb(
},
Function {input, options, ..} => {
matches!(options.collect_groups, ApplyOptions::ElementWise) && input.len() == 1 &&
!has_aggregation(input[0])
!has_aggregation(input[0].node())
}
BinaryExpr {left, right, ..} => {
!has_aggregation(*left) && !has_aggregation(*right)
Expand Down
20 changes: 13 additions & 7 deletions crates/polars-plan/src/logical_plan/aexpr/mod.rs
Expand Up @@ -161,14 +161,17 @@ pub enum AExpr {
falsy: Node,
},
AnonymousFunction {
input: Vec<Node>,
input: Vec<ExprIR>,
function: SpecialEq<Arc<dyn SeriesUdf>>,
output_type: GetOutput,
options: FunctionOptions,
},
Function {
/// function arguments
input: Vec<Node>,
/// Function arguments.
/// Some functions rely on aliases,
/// for instance assignment of struct fields.
/// Therefore we need [`ExprIR`].
input: Vec<ExprIR>,
/// function to apply
function: FunctionExpr,
options: FunctionOptions,
Expand Down Expand Up @@ -292,8 +295,7 @@ impl AExpr {
input
.iter()
.rev()
.copied()
.for_each(|node| container.push_node(node))
.for_each(|e| container.push_node(e.node()))
},
Explode(e) => container.push_node(*e),
Window {
Expand Down Expand Up @@ -372,8 +374,12 @@ impl AExpr {
return self;
},
AnonymousFunction { input, .. } | Function { input, .. } => {
input.clear();
input.extend(inputs.iter().rev().copied());
debug_assert_eq!(input.len(), inputs.len());

// Assign in reverse order as that was the order in which nodes were extracted.
for (e, node) in input.iter_mut().zip(inputs.iter().rev()) {
e.set_node(*node);
}
return self;
},
Slice {
Expand Down
32 changes: 22 additions & 10 deletions crates/polars-plan/src/logical_plan/aexpr/schema.rs
Expand Up @@ -206,22 +206,14 @@ impl AExpr {
} => {
let tmp = function.get_output();
let output_type = tmp.as_ref().unwrap_or(output_type);
let fields = input
.iter()
// default context because `col()` would return a list in aggregation context
.map(|node| arena.get(*node).to_field(schema, Context::Default, arena))
.collect::<PolarsResult<Vec<_>>>()?;
let fields = func_args_to_fields(input, schema, arena)?;
polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", options.fmt_str);
Ok(output_type.get_field(schema, ctxt, &fields))
},
Function {
function, input, ..
} => {
let fields = input
.iter()
// default context because `col()` would return a list in aggregation context
.map(|node| arena.get(*node).to_field(schema, Context::Default, arena))
.collect::<PolarsResult<Vec<_>>>()?;
let fields = func_args_to_fields(input, schema, arena)?;
polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function);
function.get_field(schema, ctxt, &fields)
},
Expand All @@ -236,6 +228,26 @@ impl AExpr {
}
}

/// Resolves the [`Field`] of every function argument in `input`.
///
/// Each argument's expression is looked up in `arena` and converted to a field
/// against `schema`; the resulting field is then renamed to the argument's
/// output name. The first conversion error aborts the whole resolution and is
/// returned to the caller.
fn func_args_to_fields(
    input: &[ExprIR],
    schema: &Schema,
    arena: &Arena<AExpr>,
) -> PolarsResult<Vec<Field>> {
    let mut fields = Vec::with_capacity(input.len());
    for arg in input {
        // Default context because `col()` would return a list in aggregation context
        let mut field = arena
            .get(arg.node())
            .to_field(schema, Context::Default, arena)?;
        // The field takes the argument's output name, not the name derived from the node.
        field.name = arg.output_name().into();
        fields.push(field);
    }
    Ok(fields)
}

fn get_arithmetic_field(
left: Node,
right: Node,
Expand Down
32 changes: 26 additions & 6 deletions crates/polars-plan/src/logical_plan/conversion.rs
@@ -1,3 +1,5 @@
use std::borrow::Cow;

use polars_core::prelude::*;
use polars_utils::vec::ConvertVec;
use recursive::recursive;
Expand Down Expand Up @@ -63,6 +65,19 @@ fn to_aexprs(input: Vec<Expr>, arena: &mut Arena<AExpr>, state: &mut ConversionS
.collect()
}

/// Sets `state.output_name` for a function expression if it is not set yet.
///
/// When the function has arguments, the output name of the first argument is
/// used. Without arguments, the name is built lazily from `function_fmt` and
/// stored as a literal left-hand-side name. If `state.output_name` is already
/// set, nothing happens and `function_fmt` is never called.
fn set_function_output_name<F>(e: &[ExprIR], state: &mut ConversionState, function_fmt: F)
where
    F: FnOnce() -> Cow<'static, str>,
{
    // An already resolved output name takes precedence.
    if !state.output_name.is_none() {
        return;
    }

    state.output_name = match e.first() {
        Some(first_arg) => first_arg.output_name_inner().clone(),
        None => OutputName::LiteralLhs(ColumnName::from(function_fmt().as_ref())),
    };
}

/// Converts expression to AExpr and adds it to the arena, which uses an arena (Vec) for allocation.
#[recursive]
fn to_aexpr_impl(expr: Expr, arena: &mut Arena<AExpr>, state: &mut ConversionState) -> Node {
Expand Down Expand Up @@ -211,9 +226,10 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena<AExpr>, state: &mut ConversionSta
output_type,
options,
} => {
state.prune_alias = false;
let e = to_expr_irs(input, arena);
set_function_output_name(&e, state, || Cow::Borrowed(options.fmt_str));
AExpr::AnonymousFunction {
input: to_aexprs(input, arena, state),
input: e,
function,
output_type,
options,
Expand All @@ -224,14 +240,18 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena<AExpr>, state: &mut ConversionSta
function,
options,
} => {
let e = to_expr_irs(input, arena);

if state.output_name.is_none() {
// Handles special case functions like `struct.field`.
if let Some(name) = function.output_name() {
state.output_name = OutputName::ColumnLhs(name.clone())
} else {
set_function_output_name(&e, state, || Cow::Owned(format!("{}", &function)));
}
}
state.prune_alias = false;
AExpr::Function {
input: to_aexprs(input, arena, state),
input: e,
function,
options,
}
Expand Down Expand Up @@ -672,7 +692,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena<AExpr>) -> Expr {
output_type,
options,
} => Expr::AnonymousFunction {
input: nodes_to_exprs(&input, expr_arena),
input: expr_irs_to_exprs(input, expr_arena),
function,
output_type,
options,
Expand All @@ -682,7 +702,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena<AExpr>) -> Expr {
function,
options,
} => Expr::Function {
input: nodes_to_exprs(&input, expr_arena),
input: expr_irs_to_exprs(input, expr_arena),
function,
options,
},
Expand Down
34 changes: 30 additions & 4 deletions crates/polars-plan/src/logical_plan/expr_ir.rs
Expand Up @@ -3,7 +3,7 @@ use std::hash::Hash;
use std::hash::Hasher;

use super::*;
use crate::constants::LITERAL_NAME;
use crate::constants::{get_len_name, LITERAL_NAME};

#[derive(Default, Debug, Clone, Hash, PartialEq, Eq)]
pub enum OutputName {
Expand Down Expand Up @@ -64,11 +64,33 @@ impl ExprIR {
}
break;
},
AExpr::Alias(node, name) => {
out.output_name = OutputName::ColumnLhs(name.clone());
out.node = *node;
AExpr::Function {
input, function, ..
} => {
if input.is_empty() {
out.output_name =
OutputName::LiteralLhs(ColumnName::from(format!("{}", function)));
} else {
out.output_name = input[0].output_name.clone();
}
break;
},
AExpr::AnonymousFunction { input, options, .. } => {
if input.is_empty() {
out.output_name = OutputName::LiteralLhs(ColumnName::from(options.fmt_str));
} else {
out.output_name = input[0].output_name.clone();
}
break;
},
AExpr::Len => out.output_name = OutputName::LiteralLhs(get_len_name()),
AExpr::Alias(_, _) => {
// Should be removed during conversion.
#[cfg(debug_assertions)]
{
unreachable!()
}
},
_ => {},
}
}
Expand All @@ -90,6 +112,10 @@ impl ExprIR {
self.output_name = OutputName::Alias(name)
}

/// Returns a reference to the `output_name` field of this expression.
pub(crate) fn output_name_inner(&self) -> &OutputName {
&self.output_name
}

pub(crate) fn output_name_arc(&self) -> &Arc<str> {
self.output_name.unwrap()
}
Expand Down
35 changes: 26 additions & 9 deletions crates/polars-plan/src/logical_plan/optimizer/fused.rs
Expand Up @@ -2,7 +2,12 @@ use super::*;

pub struct FusedArithmetic {}

fn get_expr(input: Vec<Node>, op: FusedOperator) -> AExpr {
fn get_expr(input: &[Node], op: FusedOperator, expr_arena: &Arena<AExpr>) -> AExpr {
let input = input
.iter()
.copied()
.map(|n| ExprIR::from_node(n, expr_arena))
.collect();
let mut options = FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
cast_to_supertypes: true,
Expand Down Expand Up @@ -86,8 +91,8 @@ impl OptimizationRule for FusedArithmetic {
} => match check_eligible(left, right, lp_node, expr_arena, lp_arena)? {
(None, _) | (Some(false), _) => Ok(None),
(Some(true), Some(output_field)) => {
let input = vec![*right, *a, *b];
let fma = get_expr(input, FusedOperator::MultiplyAdd);
let input = &[*right, *a, *b];
let fma = get_expr(input, FusedOperator::MultiplyAdd, expr_arena);
let node = expr_arena.add(fma);
// we reordered the arguments, so we don't obey the left expression output name
// rule anymore, that's why we alias
Expand All @@ -109,8 +114,12 @@ impl OptimizationRule for FusedArithmetic {
} => match check_eligible(left, right, lp_node, expr_arena, lp_arena)? {
(None, _) | (Some(false), _) => Ok(None),
(Some(true), _) => {
let input = vec![*left, *a, *b];
Ok(Some(get_expr(input, FusedOperator::MultiplyAdd)))
let input = &[*left, *a, *b];
Ok(Some(get_expr(
input,
FusedOperator::MultiplyAdd,
expr_arena,
)))
},
},
_ => Ok(None),
Expand All @@ -135,8 +144,12 @@ impl OptimizationRule for FusedArithmetic {
} => match check_eligible(left, right, lp_node, expr_arena, lp_arena)? {
(None, _) | (Some(false), _) => Ok(None),
(Some(true), _) => {
let input = vec![*left, *a, *b];
Ok(Some(get_expr(input, FusedOperator::SubMultiply)))
let input = &[*left, *a, *b];
Ok(Some(get_expr(
input,
FusedOperator::SubMultiply,
expr_arena,
)))
},
},
_ => {
Expand All @@ -153,8 +166,12 @@ impl OptimizationRule for FusedArithmetic {
match check_eligible(left, right, lp_node, expr_arena, lp_arena)? {
(None, _) | (Some(false), _) => Ok(None),
(Some(true), _) => {
let input = vec![*a, *b, *right];
Ok(Some(get_expr(input, FusedOperator::MultiplySub)))
let input = &[*a, *b, *right];
Ok(Some(get_expr(
input,
FusedOperator::MultiplySub,
expr_arena,
)))
},
}
},
Expand Down
Expand Up @@ -164,11 +164,12 @@ fn check_and_extend_predicate_pd_nodes(
// being projected.
let mut transferred_local_nodes = false;
if let Some(rhs) = input.get(1) {
if matches!(expr_arena.get(*rhs), AExpr::Literal { .. }) {
let rhs = rhs.node();
if matches!(expr_arena.get(rhs), AExpr::Literal { .. }) {
let mut local_nodes = Vec::<Node>::with_capacity(4);
ae.nodes(&mut local_nodes);

stack.extend(local_nodes.into_iter().filter(|node| node != rhs));
stack.extend(local_nodes.into_iter().filter(|node| *node != rhs));
transferred_local_nodes = true;
}
};
Expand Down

0 comments on commit dcee934

Please sign in to comment.