Skip to content

Commit

Permalink
ensure stack optimizations run at scan nodes (#4138)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 24, 2022
1 parent 3668df5 commit 3805d8f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 10 deletions.
30 changes: 29 additions & 1 deletion polars/polars-lazy/src/logical_plan/alp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,25 @@ impl Default for ALogicalPlan {
}

impl ALogicalPlan {
/// Get the schema of a scan node, without taking projections at the scan level into account.
///
/// Using the schema stored on the scan itself ensures that a predicate can still be applied
/// at the scan.
///
/// # Panics
///
/// Panics if `self` is not one of the scan variants handled below (`PythonScan`, `CsvScan`,
/// `ParquetScan`, `IpcScan` or `AnonymousScan`). All of these arms except `AnonymousScan`
/// are compiled only when the corresponding cargo feature is enabled.
pub(crate) fn scan_schema(&self) -> &SchemaRef {
    use ALogicalPlan::*;
    match self {
        // The python scan keeps its schema inside its options rather than as a direct field.
        #[cfg(feature = "python")]
        PythonScan { options } => &options.schema,
        #[cfg(feature = "csv-file")]
        CsvScan { schema, .. } => schema,
        #[cfg(feature = "parquet")]
        ParquetScan { schema, .. } => schema,
        #[cfg(feature = "ipc")]
        IpcScan { schema, .. } => schema,
        AnonymousScan { schema, .. } => schema,
        // Calling this on any other node is a logic error in the caller; say so explicitly
        // instead of panicking without context.
        _ => unreachable!("`scan_schema` must only be called on scan nodes"),
    }
}

/// Get the schema of the logical plan node.
pub(crate) fn schema<'a>(&'a self, arena: &'a Arena<ALogicalPlan>) -> &'a SchemaRef {
use ALogicalPlan::*;
match self {
Expand Down Expand Up @@ -493,7 +512,16 @@ impl ALogicalPlan {
}
#[cfg(feature = "python")]
PythonScan { .. } => {}
AnonymousScan { .. } => {}
AnonymousScan {
predicate,
aggregate,
..
} => {
container.extend_from_slice(aggregate);
if let Some(node) = predicate {
container.push(*node)
}
}
}
}

Expand Down
22 changes: 13 additions & 9 deletions polars/polars-lazy/src/logical_plan/optimizer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,12 @@ fn get_input(lp_arena: &Arena<ALogicalPlan>, lp_node: Node) -> [Option<Node>; 2]
inputs
}

fn get_schema(lp_arena: &Arena<ALogicalPlan>, lp_node: Node) -> Option<&SchemaRef> {
get_input(lp_arena, lp_node)[0].map(|input| lp_arena.get(input).schema(lp_arena))
/// Look up the schema for the node `lp_node`.
///
/// When `get_input` reports a first input for the node, that input's schema is returned.
/// Otherwise the node's own `scan_schema()` is returned, which panics for nodes that are
/// not scans.
fn get_schema(lp_arena: &Arena<ALogicalPlan>, lp_node: Node) -> &SchemaRef {
    // Only the first input slot decides which schema to use; the second is ignored.
    if let Some(input) = get_input(lp_arena, lp_node)[0] {
        lp_arena.get(input).schema(lp_arena)
    } else {
        // files don't have an input, so we must take their schema
        lp_arena.get(lp_node).scan_schema()
    }
}

fn get_aexpr_and_type<'a>(
Expand Down Expand Up @@ -132,7 +136,7 @@ impl OptimizationRule for TypeCoercionRule {
falsy: falsy_node,
predicate,
} => {
let input_schema = get_schema(lp_arena, lp_node)?;
let input_schema = get_schema(lp_arena, lp_node);
let (truthy, type_true) =
get_aexpr_and_type(expr_arena, truthy_node, input_schema)?;
let (falsy, type_false) = get_aexpr_and_type(expr_arena, falsy_node, input_schema)?;
Expand Down Expand Up @@ -179,7 +183,7 @@ impl OptimizationRule for TypeCoercionRule {
op,
right: node_right,
} => {
let input_schema = get_schema(lp_arena, lp_node)?;
let input_schema = get_schema(lp_arena, lp_node);
let (left, type_left) = get_aexpr_and_type(expr_arena, node_left, input_schema)?;
let (right, type_right) = get_aexpr_and_type(expr_arena, node_right, input_schema)?;

Expand Down Expand Up @@ -348,15 +352,15 @@ impl OptimizationRule for TypeCoercionRule {
ref input,
options,
} => {
let input_schema = get_schema(lp_arena, lp_node)?;
let input_schema = get_schema(lp_arena, lp_node);
let other_node = input[1];
let (_, type_left) = get_aexpr_and_type(expr_arena, input[0], input_schema)?;
let (_, type_other) = get_aexpr_and_type(expr_arena, other_node, input_schema)?;

match (&type_left, type_other) {
(DataType::Categorical(Some(rev_map)), DataType::Utf8)
if rev_map.is_global() =>
{
// cast both local and global string cache
// note that there might not yet be a rev-map
(DataType::Categorical(_), DataType::Utf8) => {
let mut input = input.clone();

let casted_expr = AExpr::Cast {
Expand Down Expand Up @@ -386,7 +390,7 @@ impl OptimizationRule for TypeCoercionRule {
ref input,
options,
} => {
let input_schema = get_schema(lp_arena, lp_node)?;
let input_schema = get_schema(lp_arena, lp_node);
let other_node = input[1];
let (left, type_left) = get_aexpr_and_type(expr_arena, input[0], input_schema)?;
let (fill_value, type_fill_value) =
Expand Down
16 changes: 16 additions & 0 deletions py-polars/tests/io/test_lazy_ipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,19 @@ def test_row_count(foods_ipc: str) -> None:
)

assert df["foo"].to_list() == [10, 16, 21, 23, 24, 30, 35]


def test_is_in_type_coercion(foods_ipc: str) -> None:
    """Check that an ``is_in`` filter on a lazily scanned IPC file collects.

    The predicate is exercised twice: once directly on the ``category``
    column of the scan, and once after the column has been selected under
    the alias ``cat``.
    """
    # Predicate applied directly on top of the scan.
    direct = pl.scan_ipc(foods_ipc).filter(
        pl.col("category").is_in(["vegetables"])
    )
    direct_df = direct.collect()
    assert direct_df.shape == (7, 4)

    # Same predicate, but behind a projection that renames the column.
    renamed = pl.scan_ipc(foods_ipc).select(pl.col("category").alias("cat"))
    renamed_df = renamed.filter(pl.col("cat").is_in(["vegetables"])).collect()
    assert renamed_df.shape == (7, 1)

0 comments on commit 3805d8f

Please sign in to comment.