Skip to content

Commit

Permalink
refactor[rust]: add generic type coercion node for Function expressio…
Browse files Browse the repository at this point in the history
…ns (#4483)
  • Loading branch information
ritchie46 committed Aug 18, 2022
1 parent a0afdb3 commit e3da667
Showing 1 changed file with 43 additions and 24 deletions.
67 changes: 43 additions & 24 deletions polars/polars-lazy/src/logical_plan/optimizer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,25 +427,35 @@ impl OptimizationRule for TypeCoercionRule {
options,
})
}
// generic type coercion of any function.
AExpr::Function {
// only for `DataType::Unknown` as it still has to be set.
function: FunctionExpr::ShiftAndFill { periods },
ref function,
ref input,
options,
} => {
mut options,
} if options.cast_to_supertypes => {
// satisfy bchk
let function = function.clone();
let input = input.clone();

let input_schema = get_schema(lp_arena, lp_node);
let self_node = input[0];
let other_node = input[1];
let (left, type_self) = get_aexpr_and_type(expr_arena, self_node, &input_schema)?;
let (fill_value, type_other) =
get_aexpr_and_type(expr_arena, other_node, &input_schema)?;
let (self_ae, type_self) =
get_aexpr_and_type(expr_arena, self_node, &input_schema)?;

early_escape(&type_self, &type_other)?;

let super_type = get_supertype(&type_self, &type_other).ok()?;
let super_type =
modify_supertype(super_type, left, fill_value, &type_self, &type_other);
let mut super_type = type_self.clone();
for other in &input[1..] {
let (other, type_other) =
get_aexpr_and_type(expr_arena, *other, &input_schema)?;

// early return until Unknown is set
if let DataType::Unknown = &type_other {
return None;
}
early_escape(&super_type, &type_other)?;
let new_st = get_supertype(&super_type, &type_other).ok()?;
super_type = modify_supertype(new_st, self_ae, other, &type_self, &type_other)
}
// only cast if the type is not already the super type.
// this can prevent an expensive flattening and subsequent aggregation
// in a groupby context. To be able to cast the groups need to be
Expand All @@ -459,23 +469,32 @@ impl OptimizationRule for TypeCoercionRule {
} else {
self_node
};
let new_node_other = if type_other != super_type {
expr_arena.add(AExpr::Cast {
expr: other_node,
data_type: super_type,
strict: false,
})
} else {
other_node
};
let mut new_nodes = Vec::with_capacity(input.len());
new_nodes.push(new_node_self);

for other_node in &input[1..] {
let (_, type_other) =
get_aexpr_and_type(expr_arena, *other_node, &input_schema)?;
let new_node_other = if type_other != super_type {
expr_arena.add(AExpr::Cast {
expr: *other_node,
data_type: super_type.clone(),
strict: false,
})
} else {
*other_node
};

new_nodes.push(new_node_other)
}
// ensure we don't go through this on next iteration
options.cast_to_supertypes = false;
Some(AExpr::Function {
function: FunctionExpr::ShiftAndFill { periods },
input: vec![new_node_self, new_node_other],
function,
input: new_nodes,
options,
})
}

_ => None,
}
}
Expand Down

0 comments on commit e3da667

Please sign in to comment.