Skip to content

Commit

Permalink
fix bug in function input wildcard
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 5, 2021
1 parent 7ec2cd3 commit 7938b71
Show file tree
Hide file tree
Showing 4 changed files with 358 additions and 305 deletions.
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1112,4 +1112,4 @@ impl JoinBuilder {
.build();
LazyFrame::from_logical_plan(lp, opt_state)
}
}
}
228 changes: 147 additions & 81 deletions polars/polars-lazy/src/logical_plan/iterator.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,151 @@
use crate::prelude::*;

macro_rules! push_expr {
($current_expr:expr, $push:ident) => {{
match $current_expr {
Column(_) | Literal(_) | Wildcard => {}
Alias(e, _) => $push(e),
Not(e) => $push(e),
BinaryExpr { left, op: _, right } => {
$push(left);
$push(right);
}
IsNull(e) => $push(e),
IsNotNull(e) => $push(e),
Cast { expr, .. } => $push(expr),
Sort { expr, .. } => $push(expr),
Take { expr, idx } => {
$push(expr);
$push(idx);
}
Filter { input, by } => {
$push(input);
$push(by)
}
SortBy { expr, by, .. } => {
$push(expr);
$push(by)
}
Agg(agg_e) => {
use AggExpr::*;
match agg_e {
Max(e) => $push(e),
Min(e) => $push(e),
Mean(e) => $push(e),
Median(e) => $push(e),
NUnique(e) => $push(e),
First(e) => $push(e),
Last(e) => $push(e),
List(e) => $push(e),
Count(e) => $push(e),
Quantile { expr, .. } => $push(expr),
Sum(e) => $push(e),
AggGroups(e) => $push(e),
Std(e) => $push(e),
Var(e) => $push(e),
}
}
Ternary {
truthy,
falsy,
predicate,
} => {
$push(truthy);
$push(falsy);
$push(predicate)
}
Function { input, .. } => input.iter().for_each(|e| $push(e)),
Shift { input, .. } => $push(input),
Reverse(e) => $push(e),
Duplicated(e) => $push(e),
IsUnique(e) => $push(e),
Explode(e) => $push(e),
Window {
function,
partition_by,
order_by,
} => {
$push(function);
for e in partition_by {
$push(e)
}
if let Some(e) = order_by {
$push(e);
}
}
Slice { input, .. } => $push(input),
BinaryFunction {
input_a, input_b, ..
} => {
$push(input_a);
$push(input_b)
}
Exclude(e, _) => $push(e),
KeepName(e) => $push(e),
}
}};
}

impl Expr {
pub(crate) fn iter_mut(&mut self) -> ExprIterMut<'_> {
let mut stack = Vec::with_capacity(8);
stack.push(self as *mut Expr);
ExprIterMut { root: self, stack }
}
}

pub(crate) struct ExprIterMut<'a> {
#[allow(dead_code)]
// SAFETY: Don't ever access this!
root: &'a mut Expr,
stack: Vec<*mut Expr>,
}

impl<'a> ExprIterMut<'a> {
/// # Safety
///
/// This is a mutable iterator over the Expr. However iterating this is unsafe. Once you
/// update a node in the Expr tree, (For instance its children). The old children are on this
/// stack and are still dereferenced leading to UB.
///
/// Its the callers responsibility to stop iteration once an element is mutably accessed.
pub unsafe fn next_unsafe(&mut self) -> Option<&'a mut Expr> {
self.stack.pop().map(|current_expr| {
use Expr::*;
// we don't use a &mut Expr, but a &Expr, so that we can reuse the macro. This should
// not matter for safety.
let mut push = |e: &'a Expr| self.stack.push(e as *const Expr as *mut Expr);

let current_expr = &mut *current_expr;

push_expr!(current_expr, push);

current_expr
})
}
}

impl<'a> Iterator for ExprIterMut<'a> {
type Item = &'a mut Expr;

fn next(&mut self) -> Option<Self::Item> {
self.stack.pop().map(|current_expr| {
use Expr::*;
// we don't use a &mut Expr, but a &Expr, so that we can reuse the macro. This should
// not matter for safety.
let mut push = |e: &'a Expr| self.stack.push(e as *const Expr as *mut Expr);

unsafe {
let current_expr = &mut *current_expr;

push_expr!(current_expr, push);

current_expr
}
})
}
}

pub struct ExprIter<'a> {
stack: Vec<&'a Expr>,
}
Expand All @@ -12,87 +158,7 @@ impl<'a> Iterator for ExprIter<'a> {
use Expr::*;
let mut push = |e: &'a Expr| self.stack.push(e);

match current_expr {
Column(_) | Literal(_) | Wildcard => {}
Alias(e, _) => push(e),
Not(e) => push(e),
BinaryExpr { left, op: _, right } => {
push(left);
push(right);
}
IsNull(e) => push(e),
IsNotNull(e) => push(e),
Cast { expr, .. } => push(expr),
Sort { expr, .. } => push(expr),
Take { expr, idx } => {
push(expr);
push(idx);
}
Filter { input, by } => {
push(input);
push(by)
}
SortBy { expr, by, .. } => {
push(expr);
push(by)
}
Agg(agg_e) => {
use AggExpr::*;
match agg_e {
Max(e) => push(e),
Min(e) => push(e),
Mean(e) => push(e),
Median(e) => push(e),
NUnique(e) => push(e),
First(e) => push(e),
Last(e) => push(e),
List(e) => push(e),
Count(e) => push(e),
Quantile { expr, .. } => push(expr),
Sum(e) => push(e),
AggGroups(e) => push(e),
Std(e) => push(e),
Var(e) => push(e),
}
}
Ternary {
truthy,
falsy,
predicate,
} => {
push(truthy);
push(falsy);
push(predicate)
}
Function { input, .. } => input.iter().for_each(|e| push(e)),
Shift { input, .. } => push(input),
Reverse(e) => push(e),
Duplicated(e) => push(e),
IsUnique(e) => push(e),
Explode(e) => push(e),
Window {
function,
partition_by,
order_by,
} => {
push(function);
for e in partition_by {
push(e)
}
if let Some(e) = order_by {
push(e);
}
}
Slice { input, .. } => push(input),
BinaryFunction {
input_a, input_b, ..
} => {
push(input_a);
push(input_b)
}
Exclude(e, _) => push(e),
KeepName(e) => push(e),
}
push_expr!(current_expr, push);
current_expr
})
}
Expand Down
86 changes: 36 additions & 50 deletions polars/polars-lazy/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ use polars_io::csv_core::utils::infer_file_schema;
#[cfg(feature = "parquet")]
use polars_io::{parquet::ParquetReader, SerReader};

use crate::logical_plan::iterator::ArenaExprIter;
use crate::logical_plan::LogicalPlan::DataFrameScan;
use crate::utils::{
combine_predicates_expr, expr_to_root_column_names, has_expr, has_wildcard,
Expand Down Expand Up @@ -851,11 +850,25 @@ fn rewrite_keep_name(expr: Expr) -> Expr {
}
}

/// Take an expression with a root: col("*") and copies that expression for all columns in the schema,
/// with the exclusion of the `names` in the exclude expression.
/// The resulting expressions are written to result.
fn replace_wilcard(expr: &Expr, result: &mut Vec<Expr>, exclude: &[Arc<String>], schema: &Schema) {
for field in schema.fields() {
let name = field.name();
if !exclude.iter().any(|exluded| &**exluded == name) {
let new_expr = replace_wildcard_with_column(expr.clone(), Arc::new(name.clone()));
let new_expr = rewrite_keep_name(new_expr);
result.push(new_expr)
}
}
}

/// In case of single col(*) -> do nothing, no selection is the same as select all
/// In other cases replace the wildcard with an expression with all columns
fn rewrite_projections(exprs: Vec<Expr>, schema: &Schema) -> Vec<Expr> {
let mut result = Vec::with_capacity(exprs.len() + schema.fields().len());
for expr in exprs {
for mut expr in exprs {
if has_wildcard(&expr) {
// keep track of column excluded from the wildcard
let mut exclude = vec![];
Expand All @@ -880,62 +893,35 @@ fn rewrite_projections(exprs: Vec<Expr>, schema: &Schema) -> Vec<Expr> {
continue;
}
// this path prepares the wildcard as input for the Function Expr
// To deal with the borrow checker we first create an arena from this expression.
// Then we clone that arena to `new_arena` because we cannot borrow and have mutable access,
// We iterate the old_arena mutable.
// * Replace the wildcard column with new column names (assign nodes into new arena)
// * Swap the inputs vec, from Expr::Function in the old arena to Expr::Function in the new arena
// * convert from arena expr to boxed expr
if has_expr(
&expr,
|e| matches!(e, Expr::Function { input, options, .. } if options.input_wildcard_expansion && input.iter().any(has_wildcard)),
) {
let mut arena = Arena::with_capacity(16);
let root = to_aexpr(expr, &mut arena);
let mut new_arena = arena.clone();
let iter = (&arena).iter(root);
let mut function_node = Default::default();
let mut function_inputs = Default::default();

for (node, ae) in iter {
if let AExpr::Function { .. } = ae {
function_node = node;
break;
}
}

if let AExpr::Function { input, .. } = arena.get_mut(function_node) {
let (idx, _) = input
.iter()
.find_position(|&node| matches!(new_arena.get(*node), AExpr::Wildcard))
.expect("should have wildcard");
input.remove(idx);

for field in schema.fields() {
let node = new_arena.add(AExpr::Column(Arc::new(field.name().clone())));
input.push(node);
let mut unsafe_iter = expr.iter_mut();

// Safety: this is safe because we directly stop iteration once the children are mutated.
unsafe {
while let Some(e) = unsafe_iter.next_unsafe() {
if let Expr::Function { input, .. } = e {
let mut new_inputs = Vec::with_capacity(input.len());

input.iter_mut().for_each(|e| {
if has_wildcard(e) {
replace_wilcard(e, &mut new_inputs, &[], schema)
} else {
new_inputs.push(e.clone())
}
});

*input = new_inputs;
break;
}
}
function_inputs = std::mem::take(input);
}

if let AExpr::Function { input, .. } = new_arena.get_mut(function_node) {
*input = function_inputs;
}

let new_expr = node_to_exp(root, &new_arena);
result.push(new_expr);
result.push(expr);
continue;
}

for field in schema.fields() {
let name = field.name();
if !exclude.iter().any(|exluded| &**exluded == name) {
let new_expr =
replace_wildcard_with_column(expr.clone(), Arc::new(name.clone()));
let new_expr = rewrite_keep_name(new_expr);
result.push(new_expr)
}
}
replace_wilcard(&expr, &mut result, &exclude, schema);
} else {
let expr = rewrite_keep_name(expr);
result.push(expr)
Expand Down

0 comments on commit 7938b71

Please sign in to comment.