From a9f4df3a2b2ec2cda59504147d40ac565ed101cf Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Tue, 26 Mar 2024 08:47:37 +0100 Subject: [PATCH] feat: Full plan CSE (#15264) --- crates/polars-core/src/frame/explode.rs | 2 +- crates/polars-core/src/frame/mod.rs | 2 +- crates/polars-core/src/schema.rs | 7 + crates/polars-io/src/cloud/options.rs | 2 +- crates/polars-io/src/csv/read.rs | 6 +- crates/polars-io/src/options.rs | 2 +- crates/polars-io/src/parquet/read.rs | 2 +- .../src/physical_plan/executors/scan/ipc.rs | 2 - .../src/physical_plan/expressions/column.rs | 2 +- .../src/physical_plan/planner/lp.rs | 2 - crates/polars-lazy/src/scan/ipc.rs | 3 - crates/polars-lazy/src/tests/cse.rs | 26 +- crates/polars-lazy/src/tests/io.rs | 1 - crates/polars-lazy/src/tests/tpch.rs | 9 + crates/polars-ops/src/frame/join/args.rs | 6 +- crates/polars-ops/src/frame/join/asof/mod.rs | 4 +- crates/polars-plan/src/dsl/options.rs | 2 +- .../src/logical_plan/aexpr/hash.rs | 13 + .../polars-plan/src/logical_plan/aexpr/mod.rs | 15 +- .../logical_plan/{alp.rs => alp/inputs.rs} | 221 +------- .../polars-plan/src/logical_plan/alp/mod.rs | 136 +++++ .../src/logical_plan/alp/schema.rs | 90 +++ .../polars-plan/src/logical_plan/builder.rs | 3 +- .../polars-plan/src/logical_plan/expr_ir.rs | 13 +- .../polars-plan/src/logical_plan/file_scan.rs | 51 +- .../src/logical_plan/functions/count.rs | 1 - .../src/logical_plan/functions/mod.rs | 49 ++ .../logical_plan/optimizer/cache_states.rs | 261 ++++----- .../logical_plan/optimizer/collect_members.rs | 51 +- .../src/logical_plan/optimizer/cse.rs | 423 -------------- .../src/logical_plan/optimizer/cse/cache.rs | 45 ++ .../optimizer/{ => cse}/cse_expr.rs | 30 +- .../src/logical_plan/optimizer/cse/cse_lp.rs | 337 ++++++++++++ .../src/logical_plan/optimizer/cse/mod.rs | 18 + .../logical_plan/optimizer/file_caching.rs | 101 ++-- .../src/logical_plan/optimizer/mod.rs | 94 ++-- .../optimizer/predicate_pushdown/mod.rs | 387 ++++++++----- 
.../polars-plan/src/logical_plan/options.rs | 34 +- .../src/logical_plan/visitor/expr.rs | 54 +- .../src/logical_plan/visitor/hash.rs | 518 ++++++++++++++++++ .../src/logical_plan/visitor/lp.rs | 9 + .../src/logical_plan/visitor/mod.rs | 1 + crates/polars-time/src/group_by/dynamic.rs | 2 +- 43 files changed, 1882 insertions(+), 1155 deletions(-) rename crates/polars-plan/src/logical_plan/{alp.rs => alp/inputs.rs} (58%) create mode 100644 crates/polars-plan/src/logical_plan/alp/mod.rs create mode 100644 crates/polars-plan/src/logical_plan/alp/schema.rs delete mode 100644 crates/polars-plan/src/logical_plan/optimizer/cse.rs create mode 100644 crates/polars-plan/src/logical_plan/optimizer/cse/cache.rs rename crates/polars-plan/src/logical_plan/optimizer/{ => cse}/cse_expr.rs (97%) create mode 100644 crates/polars-plan/src/logical_plan/optimizer/cse/cse_lp.rs create mode 100644 crates/polars-plan/src/logical_plan/optimizer/cse/mod.rs create mode 100644 crates/polars-plan/src/logical_plan/visitor/hash.rs diff --git a/crates/polars-core/src/frame/explode.rs b/crates/polars-core/src/frame/explode.rs index 51fe294aa3fa..eb4fe1e7cf6d 100644 --- a/crates/polars-core/src/frame/explode.rs +++ b/crates/polars-core/src/frame/explode.rs @@ -21,7 +21,7 @@ fn get_exploded(series: &Series) -> PolarsResult<(Series, OffsetsBuffer)> { } /// Arguments for `[DataFrame::melt]` function -#[derive(Clone, Default, Debug, PartialEq)] +#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] pub struct MeltArgs { pub id_vars: Vec, diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 19db5298b088..019fb6c342d8 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -41,7 +41,7 @@ pub enum NullStrategy { Propagate, } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)] #[cfg_attr(feature = 
"serde", derive(Serialize, Deserialize))] pub enum UniqueKeepStrategy { /// Keep the first unique row. diff --git a/crates/polars-core/src/schema.rs b/crates/polars-core/src/schema.rs index 1e61a02cb3b2..569517bd37ad 100644 --- a/crates/polars-core/src/schema.rs +++ b/crates/polars-core/src/schema.rs @@ -1,4 +1,5 @@ use std::fmt::{Debug, Formatter}; +use std::hash::{Hash, Hasher}; use arrow::datatypes::ArrowSchemaRef; use indexmap::map::MutableKeys; @@ -17,6 +18,12 @@ pub struct Schema { inner: PlIndexMap, } +impl Hash for Schema { + fn hash(&self, state: &mut H) { + self.inner.iter().for_each(|v| v.hash(state)) + } +} + // Schemas will only compare equal if they have the same fields in the same order. We can't use `self.inner == // other.inner` because [`IndexMap`] ignores order when checking equality, but we don't want to ignore it. impl PartialEq for Schema { diff --git a/crates/polars-io/src/cloud/options.rs b/crates/polars-io/src/cloud/options.rs index 747ad29f374d..05cb9e65252e 100644 --- a/crates/polars-io/src/cloud/options.rs +++ b/crates/polars-io/src/cloud/options.rs @@ -52,7 +52,7 @@ static BUCKET_REGION: Lazy = Vec<(T, String)>; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Hash, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] /// Options to connect to various cloud providers. 
pub struct CloudOptions { diff --git a/crates/polars-io/src/csv/read.rs b/crates/polars-io/src/csv/read.rs index 6168fa620bb9..e6eaf9efe734 100644 --- a/crates/polars-io/src/csv/read.rs +++ b/crates/polars-io/src/csv/read.rs @@ -5,7 +5,7 @@ use crate::csv::read_impl::{ }; use crate::csv::utils::infer_file_schema; -#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum CsvEncoding { /// Utf8 encoding @@ -14,7 +14,7 @@ pub enum CsvEncoding { LossyUtf8, } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum NullValues { /// A single value that's used for all columns @@ -25,7 +25,7 @@ pub enum NullValues { Named(Vec<(String, String)>), } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum CommentPrefix { /// A single byte character that indicates the start of a comment line. 
diff --git a/crates/polars-io/src/options.rs b/crates/polars-io/src/options.rs index a4f23e9cc272..13a6b56bcdef 100644 --- a/crates/polars-io/src/options.rs +++ b/crates/polars-io/src/options.rs @@ -2,7 +2,7 @@ use polars_utils::IdxSize; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct RowIndex { pub name: String, diff --git a/crates/polars-io/src/parquet/read.rs b/crates/polars-io/src/parquet/read.rs index 5691857a4ac4..e1d482b51a35 100644 --- a/crates/polars-io/src/parquet/read.rs +++ b/crates/polars-io/src/parquet/read.rs @@ -21,7 +21,7 @@ use crate::predicates::PhysicalIoExpr; use crate::prelude::*; use crate::RowIndex; -#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum ParallelStrategy { /// Don't parallelize diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs b/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs index f5305e4a6f2c..6e2e3e012a39 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs @@ -4,7 +4,6 @@ use std::sync::RwLock; use polars_core::config; use polars_core::utils::accumulate_dataframes_vertical; -#[cfg(feature = "cloud")] use polars_io::cloud::CloudOptions; use polars_io::predicates::apply_predicate; use polars_io::{is_cloud_url, RowIndex}; @@ -18,7 +17,6 @@ pub struct IpcExec { pub(crate) predicate: Option>, pub(crate) options: IpcScanOptions, pub(crate) file_options: FileScanOptions, - #[cfg(feature = "cloud")] pub(crate) cloud_options: Option, pub(crate) metadata: Option, } diff --git a/crates/polars-lazy/src/physical_plan/expressions/column.rs b/crates/polars-lazy/src/physical_plan/expressions/column.rs index 
eda761ab9d56..bf37377a56cd 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/column.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/column.rs @@ -93,7 +93,7 @@ impl ColumnExpr { && _state.ext_contexts.is_empty() && std::env::var("POLARS_NO_SCHEMA_CHECK").is_err() { - panic!("invalid schema: df {:?}; column: {}", df, &self.name) + panic!("invalid schema: df {:?};\ncolumn: {}", df, &self.name) } } // in release we fallback to linear search diff --git a/crates/polars-lazy/src/physical_plan/planner/lp.rs b/crates/polars-lazy/src/physical_plan/planner/lp.rs index 74f26ad2f228..f44ad6805e23 100644 --- a/crates/polars-lazy/src/physical_plan/planner/lp.rs +++ b/crates/polars-lazy/src/physical_plan/planner/lp.rs @@ -229,7 +229,6 @@ pub fn create_physical_plan( #[cfg(feature = "ipc")] FileScan::Ipc { options, - #[cfg(feature = "cloud")] cloud_options, metadata, } => Ok(Box::new(executors::IpcExec { @@ -238,7 +237,6 @@ pub fn create_physical_plan( predicate, options, file_options, - #[cfg(feature = "cloud")] cloud_options, metadata, })), diff --git a/crates/polars-lazy/src/scan/ipc.rs b/crates/polars-lazy/src/scan/ipc.rs index 6d378708c8f8..41797db9a341 100644 --- a/crates/polars-lazy/src/scan/ipc.rs +++ b/crates/polars-lazy/src/scan/ipc.rs @@ -13,7 +13,6 @@ pub struct ScanArgsIpc { pub rechunk: bool, pub row_index: Option, pub memmap: bool, - #[cfg(feature = "cloud")] pub cloud_options: Option, } @@ -25,7 +24,6 @@ impl Default for ScanArgsIpc { rechunk: false, row_index: None, memmap: true, - #[cfg(feature = "cloud")] cloud_options: Default::default(), } } @@ -79,7 +77,6 @@ impl LazyFileListReader for LazyIpcReader { args.cache, args.row_index, args.rechunk, - #[cfg(feature = "cloud")] args.cloud_options, )? 
.build() diff --git a/crates/polars-lazy/src/tests/cse.rs b/crates/polars-lazy/src/tests/cse.rs index f97dee7023c4..05ac8c4f9cdf 100644 --- a/crates/polars-lazy/src/tests/cse.rs +++ b/crates/polars-lazy/src/tests/cse.rs @@ -45,9 +45,14 @@ fn test_cse_unions() -> PolarsResult<()> { let (mut expr_arena, mut lp_arena) = get_arenas(); let lp = lf.clone().optimize(&mut lp_arena, &mut expr_arena).unwrap(); + let mut cache_count = 0; assert!((&lp_arena).iter(lp).all(|(_, lp)| { use ALogicalPlan::*; match lp { + Cache { .. } => { + cache_count += 1; + true + }, Scan { file_options, .. } => { if let Some(columns) = &file_options.with_columns { columns.len() == 2 @@ -58,6 +63,7 @@ fn test_cse_unions() -> PolarsResult<()> { _ => true, } })); + assert_eq!(cache_count, 2); let out = lf.collect()?; assert_eq!(out.get_column_names(), &["category", "fats_g"]); @@ -82,17 +88,23 @@ fn test_cse_cache_union_projection_pd() -> PolarsResult<()> { // check that the projection of a is not done before the cache let (mut expr_arena, mut lp_arena) = get_arenas(); let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap(); + let mut cache_count = 0; assert!((&lp_arena).iter(lp).all(|(_, lp)| { use ALogicalPlan::*; match lp { + Cache { .. } => { + cache_count += 1; + true + }, DataFrameScan { projection: Some(projection), .. - } => projection.as_ref() == &vec!["a".to_string(), "b".to_string()], + } => projection.as_ref().len() <= 2, DataFrameScan { .. } => false, _ => true, } })); + assert_eq!(cache_count, 2); Ok(()) } @@ -189,7 +201,9 @@ fn test_cse_joins_4954() -> PolarsResult<()> { .flat_map(|(_, lp)| { use ALogicalPlan::*; match lp { - Cache { id, count, input } => { + Cache { + id, count, input, .. 
+ } => { assert_eq!(*count, 1); assert!(matches!( lp_arena.get(*input), @@ -242,11 +256,13 @@ fn test_cache_with_partial_projection() -> PolarsResult<()> { JoinType::Semi.into(), ); - let q = q.with_comm_subplan_elim(true); - let (mut expr_arena, mut lp_arena) = get_arenas(); let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap(); + // EDIT: #15264 this originally + // tested 2 caches, but we cannot do that after #15264 due to projection pushdown + // running first and the cache semantics changing, so now we test 1. Maybe we can improve later. + // ensure we get two different caches // and ensure that every cache only has 1 hit. let cache_ids = (&lp_arena) @@ -259,7 +275,7 @@ fn test_cache_with_partial_projection() -> PolarsResult<()> { } }) .collect::>(); - assert_eq!(cache_ids.len(), 2); + assert_eq!(cache_ids.len(), 1); Ok(()) } diff --git a/crates/polars-lazy/src/tests/io.rs b/crates/polars-lazy/src/tests/io.rs index 472ee803372c..b527aa18abe3 100644 --- a/crates/polars-lazy/src/tests/io.rs +++ b/crates/polars-lazy/src/tests/io.rs @@ -416,7 +416,6 @@ fn test_ipc_globbing() -> PolarsResult<()> { rechunk: false, row_index: None, memmap: true, - #[cfg(feature = "cloud")] cloud_options: None, }, )? diff --git a/crates/polars-lazy/src/tests/tpch.rs b/crates/polars-lazy/src/tests/tpch.rs index 37006f9f555d..f909a7a18b5c 100644 --- a/crates/polars-lazy/src/tests/tpch.rs +++ b/crates/polars-lazy/src/tests/tpch.rs @@ -85,6 +85,15 @@ fn test_q2() -> PolarsResult<()> { .limit(100) .with_comm_subplan_elim(true); + let (node, lp_arena, _) = q.clone().to_alp_optimized().unwrap(); + assert_eq!( + (&lp_arena) + .iter(node) + .filter(|(_, alp)| matches!(alp, ALogicalPlan::Cache { .. 
})) + .count(), + 2 + ); + let out = q.collect()?; let schema = Schema::from_iter([ Field::new("s_acctbal", DataType::Float64), diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs index b0008b7afb7c..42339e9425d1 100644 --- a/crates/polars-ops/src/frame/join/args.rs +++ b/crates/polars-ops/src/frame/join/args.rs @@ -18,7 +18,7 @@ pub type ChunkJoinIds = Vec; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -#[derive(Clone, PartialEq, Eq, Debug)] +#[derive(Clone, PartialEq, Eq, Debug, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct JoinArgs { pub how: JoinType, @@ -56,7 +56,7 @@ impl JoinArgs { } } -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum JoinType { Left, @@ -116,7 +116,7 @@ impl Debug for JoinType { } } -#[derive(Copy, Clone, PartialEq, Eq, Default)] +#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum JoinValidation { /// No unique checks diff --git a/crates/polars-ops/src/frame/join/asof/mod.rs b/crates/polars-ops/src/frame/join/asof/mod.rs index 0fc95e3bcb5a..7fd6c0a048d1 100644 --- a/crates/polars-ops/src/frame/join/asof/mod.rs +++ b/crates/polars-ops/src/frame/join/asof/mod.rs @@ -142,7 +142,7 @@ impl AsofJoinState for AsofJoinNearestState { } } -#[derive(Clone, Debug, PartialEq, Eq, Default)] +#[derive(Clone, Debug, PartialEq, Eq, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct AsOfOptions { pub strategy: AsofStrategy, @@ -191,7 +191,7 @@ fn check_asof_columns( Ok(()) } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum AsofStrategy { /// selects the last row in the right DataFrame whose ‘on’ key is less than or equal to 
the left’s key diff --git a/crates/polars-plan/src/dsl/options.rs b/crates/polars-plan/src/dsl/options.rs index f7c8d355e5a8..a1c362cbc0dd 100644 --- a/crates/polars-plan/src/dsl/options.rs +++ b/crates/polars-plan/src/dsl/options.rs @@ -38,7 +38,7 @@ impl Default for StrptimeOptions { } } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct JoinOptions { pub allow_parallel: bool, diff --git a/crates/polars-plan/src/logical_plan/aexpr/hash.rs b/crates/polars-plan/src/logical_plan/aexpr/hash.rs index 8ef7c86c5fca..f9b6bcfcfb49 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/hash.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/hash.rs @@ -1,5 +1,8 @@ use std::hash::{Hash, Hasher}; +use polars_utils::arena::{Arena, Node}; + +use crate::logical_plan::ArenaExprIter; use crate::prelude::AExpr; impl Hash for AExpr { @@ -30,3 +33,13 @@ impl Hash for AExpr { } } } + +pub(crate) fn traverse_and_hash_aexpr( + node: Node, + expr_arena: &Arena, + state: &mut H, +) { + for (_, ae) in expr_arena.iter(node) { + ae.hash(state); + } +} diff --git a/crates/polars-plan/src/logical_plan/aexpr/mod.rs b/crates/polars-plan/src/logical_plan/aexpr/mod.rs index 40b16de73380..2faf576642e7 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/mod.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/mod.rs @@ -3,14 +3,13 @@ mod schema; use std::hash::{Hash, Hasher}; +pub(super) use hash::traverse_and_hash_aexpr; use polars_core::prelude::*; use polars_core::utils::{get_time_units, try_get_supertype}; use polars_utils::arena::{Arena, Node}; use strum_macros::IntoStaticStr; use crate::constants::LEN; -#[cfg(feature = "cse")] -use crate::logical_plan::visitor::AexprNode; use crate::logical_plan::Context; use crate::prelude::*; @@ -189,18 +188,6 @@ pub enum AExpr { } impl AExpr { - #[cfg(feature = "cse")] - pub(crate) fn is_equal(l: Node, r: Node, arena: &Arena) -> bool { - let arena 
= arena as *const Arena as *mut Arena; - // SAFETY: we can pass a *mut pointer - // the equality operation will not access mutable - unsafe { - let ae_node_l = AexprNode::from_raw(l, arena); - let ae_node_r = AexprNode::from_raw(r, arena); - ae_node_l == ae_node_r - } - } - #[cfg(feature = "cse")] pub(crate) fn col(name: &str) -> Self { AExpr::Column(ColumnName::from(name)) diff --git a/crates/polars-plan/src/logical_plan/alp.rs b/crates/polars-plan/src/logical_plan/alp/inputs.rs similarity index 58% rename from crates/polars-plan/src/logical_plan/alp.rs rename to crates/polars-plan/src/logical_plan/alp/inputs.rs index 5cea2049150e..c51630fc7be0 100644 --- a/crates/polars-plan/src/logical_plan/alp.rs +++ b/crates/polars-plan/src/logical_plan/alp/inputs.rs @@ -1,211 +1,4 @@ -use std::borrow::Cow; -use std::path::PathBuf; - -use polars_core::prelude::*; -use polars_utils::idx_vec::UnitVec; -use polars_utils::unitvec; - -use super::projection_expr::*; -use crate::prelude::*; - -/// [`ALogicalPlan`] is a representation of [`LogicalPlan`] with [`Node`]s which are allocated in an [`Arena`] -#[derive(Clone, Debug, Default)] -pub enum ALogicalPlan { - #[cfg(feature = "python")] - PythonScan { - options: PythonOptions, - predicate: Option, - }, - Slice { - input: Node, - offset: i64, - len: IdxSize, - }, - Selection { - input: Node, - predicate: ExprIR, - }, - Scan { - paths: Arc<[PathBuf]>, - file_info: FileInfo, - predicate: Option, - /// schema of the projected file - output_schema: Option, - scan_type: FileScan, - /// generic options that can be used for all file types. 
- file_options: FileScanOptions, - }, - DataFrameScan { - df: Arc, - schema: SchemaRef, - // schema of the projected file - output_schema: Option, - projection: Option>>, - selection: Option, - }, - // Only selects columns - SimpleProjection { - input: Node, - columns: SchemaRef, - duplicate_check: bool, - }, - Projection { - input: Node, - expr: ProjectionExprs, - schema: SchemaRef, - options: ProjectionOptions, - }, - Sort { - input: Node, - by_column: Vec, - args: SortArguments, - }, - Cache { - input: Node, - id: usize, - count: usize, - }, - Aggregate { - input: Node, - keys: Vec, - aggs: Vec, - schema: SchemaRef, - apply: Option>, - maintain_order: bool, - options: Arc, - }, - Join { - input_left: Node, - input_right: Node, - schema: SchemaRef, - left_on: Vec, - right_on: Vec, - options: Arc, - }, - HStack { - input: Node, - exprs: ProjectionExprs, - schema: SchemaRef, - options: ProjectionOptions, - }, - Distinct { - input: Node, - options: DistinctOptions, - }, - MapFunction { - input: Node, - function: FunctionNode, - }, - Union { - inputs: Vec, - options: UnionOptions, - }, - HConcat { - inputs: Vec, - schema: SchemaRef, - options: HConcatOptions, - }, - ExtContext { - input: Node, - contexts: Vec, - schema: SchemaRef, - }, - Sink { - input: Node, - payload: SinkType, - }, - #[default] - Invalid, -} - -impl ALogicalPlan { - /// Get the schema of the logical plan node but don't take projections into account at the scan - /// level. This ensures we can apply the predicate - pub(crate) fn scan_schema(&self) -> &SchemaRef { - use ALogicalPlan::*; - match self { - Scan { file_info, .. } => &file_info.schema, - #[cfg(feature = "python")] - PythonScan { options, .. } => &options.schema, - _ => unreachable!(), - } - } - - pub fn name(&self) -> &'static str { - use ALogicalPlan::*; - match self { - Scan { scan_type, .. } => scan_type.into(), - #[cfg(feature = "python")] - PythonScan { .. } => "python_scan", - Slice { .. } => "slice", - Selection { .. 
} => "selection", - DataFrameScan { .. } => "df", - Projection { .. } => "projection", - Sort { .. } => "sort", - Cache { .. } => "cache", - Aggregate { .. } => "aggregate", - Join { .. } => "join", - HStack { .. } => "hstack", - Distinct { .. } => "distinct", - MapFunction { .. } => "map_function", - Union { .. } => "union", - HConcat { .. } => "hconcat", - ExtContext { .. } => "ext_context", - Sink { payload, .. } => match payload { - SinkType::Memory => "sink (memory)", - SinkType::File { .. } => "sink (file)", - #[cfg(feature = "cloud")] - SinkType::Cloud { .. } => "sink (cloud)", - }, - SimpleProjection { .. } => "simple_projection", - Invalid => "invalid", - } - } - - /// Get the schema of the logical plan node. - pub fn schema<'a>(&'a self, arena: &'a Arena) -> Cow<'a, SchemaRef> { - use ALogicalPlan::*; - let schema = match self { - #[cfg(feature = "python")] - PythonScan { options, .. } => options.output_schema.as_ref().unwrap_or(&options.schema), - Union { inputs, .. } => return arena.get(inputs[0]).schema(arena), - HConcat { schema, .. } => schema, - Cache { input, .. } => return arena.get(*input).schema(arena), - Sort { input, .. } => return arena.get(*input).schema(arena), - Scan { - output_schema, - file_info, - .. - } => output_schema.as_ref().unwrap_or(&file_info.schema), - DataFrameScan { - schema, - output_schema, - .. - } => output_schema.as_ref().unwrap_or(schema), - Selection { input, .. } => return arena.get(*input).schema(arena), - Projection { schema, .. } => schema, - SimpleProjection { columns, .. } => columns, - Aggregate { schema, .. } => schema, - Join { schema, .. } => schema, - HStack { schema, .. } => schema, - Distinct { input, .. } | Sink { input, .. } => return arena.get(*input).schema(arena), - Slice { input, .. 
} => return arena.get(*input).schema(arena), - MapFunction { input, function } => { - let input_schema = arena.get(*input).schema(arena); - - return match input_schema { - Cow::Owned(schema) => { - Cow::Owned(function.schema(&schema).unwrap().into_owned()) - }, - Cow::Borrowed(schema) => function.schema(schema).unwrap(), - }; - }, - ExtContext { schema, .. } => schema, - Invalid => unreachable!(), - }; - Cow::Borrowed(schema) - } -} +use super::*; impl ALogicalPlan { /// Takes the expressions of an LP node and the inputs of that node and reconstruct @@ -489,15 +282,3 @@ impl ALogicalPlan { inputs.first().copied() } } - -#[cfg(test)] -mod test { - use super::*; - - // skipped for now - #[ignore] - #[test] - fn test_alp_size() { - assert!(std::mem::size_of::() <= 152); - } -} diff --git a/crates/polars-plan/src/logical_plan/alp/mod.rs b/crates/polars-plan/src/logical_plan/alp/mod.rs new file mode 100644 index 000000000000..d1fa0e696646 --- /dev/null +++ b/crates/polars-plan/src/logical_plan/alp/mod.rs @@ -0,0 +1,136 @@ +mod inputs; +mod schema; + +use std::borrow::Cow; +use std::path::PathBuf; + +use polars_core::prelude::*; +use polars_utils::idx_vec::UnitVec; +use polars_utils::unitvec; + +use super::projection_expr::*; +use crate::prelude::*; + +/// [`ALogicalPlan`] is a representation of [`LogicalPlan`] with [`Node`]s which are allocated in an [`Arena`] +#[derive(Clone, Debug, Default)] +pub enum ALogicalPlan { + #[cfg(feature = "python")] + PythonScan { + options: PythonOptions, + predicate: Option, + }, + Slice { + input: Node, + offset: i64, + len: IdxSize, + }, + Selection { + input: Node, + predicate: ExprIR, + }, + Scan { + paths: Arc<[PathBuf]>, + file_info: FileInfo, + predicate: Option, + /// schema of the projected file + output_schema: Option, + scan_type: FileScan, + /// generic options that can be used for all file types. 
+ file_options: FileScanOptions, + }, + DataFrameScan { + df: Arc, + schema: SchemaRef, + // schema of the projected file + output_schema: Option, + projection: Option>>, + selection: Option, + }, + // Only selects columns + SimpleProjection { + input: Node, + columns: SchemaRef, + duplicate_check: bool, + }, + Projection { + input: Node, + expr: ProjectionExprs, + schema: SchemaRef, + options: ProjectionOptions, + }, + Sort { + input: Node, + by_column: Vec, + args: SortArguments, + }, + Cache { + input: Node, + // Unique ID. + id: usize, + /// How many hits the cache must be saved in memory. + count: usize, + }, + Aggregate { + input: Node, + keys: Vec, + aggs: Vec, + schema: SchemaRef, + apply: Option>, + maintain_order: bool, + options: Arc, + }, + Join { + input_left: Node, + input_right: Node, + schema: SchemaRef, + left_on: Vec, + right_on: Vec, + options: Arc, + }, + HStack { + input: Node, + exprs: ProjectionExprs, + schema: SchemaRef, + options: ProjectionOptions, + }, + Distinct { + input: Node, + options: DistinctOptions, + }, + MapFunction { + input: Node, + function: FunctionNode, + }, + Union { + inputs: Vec, + options: UnionOptions, + }, + HConcat { + inputs: Vec, + schema: SchemaRef, + options: HConcatOptions, + }, + ExtContext { + input: Node, + contexts: Vec, + schema: SchemaRef, + }, + Sink { + input: Node, + payload: SinkType, + }, + #[default] + Invalid, +} + +#[cfg(test)] +mod test { + use super::*; + + // skipped for now + #[ignore] + #[test] + fn test_alp_size() { + assert!(std::mem::size_of::() <= 152); + } +} diff --git a/crates/polars-plan/src/logical_plan/alp/schema.rs b/crates/polars-plan/src/logical_plan/alp/schema.rs new file mode 100644 index 000000000000..55b1751e24f4 --- /dev/null +++ b/crates/polars-plan/src/logical_plan/alp/schema.rs @@ -0,0 +1,90 @@ +use super::*; + +impl ALogicalPlan { + /// Get the schema of the logical plan node but don't take projections into account at the scan + /// level. 
This ensures we can apply the predicate + pub(crate) fn scan_schema(&self) -> &SchemaRef { + use ALogicalPlan::*; + match self { + Scan { file_info, .. } => &file_info.schema, + #[cfg(feature = "python")] + PythonScan { options, .. } => &options.schema, + _ => unreachable!(), + } + } + + pub fn name(&self) -> &'static str { + use ALogicalPlan::*; + match self { + Scan { scan_type, .. } => scan_type.into(), + #[cfg(feature = "python")] + PythonScan { .. } => "python_scan", + Slice { .. } => "slice", + Selection { .. } => "selection", + DataFrameScan { .. } => "df", + Projection { .. } => "projection", + Sort { .. } => "sort", + Cache { .. } => "cache", + Aggregate { .. } => "aggregate", + Join { .. } => "join", + HStack { .. } => "hstack", + Distinct { .. } => "distinct", + MapFunction { .. } => "map_function", + Union { .. } => "union", + HConcat { .. } => "hconcat", + ExtContext { .. } => "ext_context", + Sink { payload, .. } => match payload { + SinkType::Memory => "sink (memory)", + SinkType::File { .. } => "sink (file)", + #[cfg(feature = "cloud")] + SinkType::Cloud { .. } => "sink (cloud)", + }, + SimpleProjection { .. } => "simple_projection", + Invalid => "invalid", + } + } + + /// Get the schema of the logical plan node. + pub fn schema<'a>(&'a self, arena: &'a Arena) -> Cow<'a, SchemaRef> { + use ALogicalPlan::*; + let schema = match self { + #[cfg(feature = "python")] + PythonScan { options, .. } => options.output_schema.as_ref().unwrap_or(&options.schema), + Union { inputs, .. } => return arena.get(inputs[0]).schema(arena), + HConcat { schema, .. } => schema, + Cache { input, .. } => return arena.get(*input).schema(arena), + Sort { input, .. } => return arena.get(*input).schema(arena), + Scan { + output_schema, + file_info, + .. + } => output_schema.as_ref().unwrap_or(&file_info.schema), + DataFrameScan { + schema, + output_schema, + .. + } => output_schema.as_ref().unwrap_or(schema), + Selection { input, .. 
} => return arena.get(*input).schema(arena), + Projection { schema, .. } => schema, + SimpleProjection { columns, .. } => columns, + Aggregate { schema, .. } => schema, + Join { schema, .. } => schema, + HStack { schema, .. } => schema, + Distinct { input, .. } | Sink { input, .. } => return arena.get(*input).schema(arena), + Slice { input, .. } => return arena.get(*input).schema(arena), + MapFunction { input, function } => { + let input_schema = arena.get(*input).schema(arena); + + return match input_schema { + Cow::Owned(schema) => { + Cow::Owned(function.schema(&schema).unwrap().into_owned()) + }, + Cow::Borrowed(schema) => function.schema(schema).unwrap(), + }; + }, + ExtContext { schema, .. } => schema, + Invalid => unreachable!(), + }; + Cow::Borrowed(schema) + } +} diff --git a/crates/polars-plan/src/logical_plan/builder.rs b/crates/polars-plan/src/logical_plan/builder.rs index f967dd04aa1c..0115b22c66dd 100644 --- a/crates/polars-plan/src/logical_plan/builder.rs +++ b/crates/polars-plan/src/logical_plan/builder.rs @@ -237,7 +237,7 @@ impl LogicalPlanBuilder { cache: bool, row_index: Option, rechunk: bool, - #[cfg(feature = "cloud")] cloud_options: Option, + cloud_options: Option, ) -> PolarsResult { use polars_io::is_cloud_url; @@ -289,7 +289,6 @@ impl LogicalPlanBuilder { predicate: None, scan_type: FileScan::Ipc { options, - #[cfg(feature = "cloud")] cloud_options, metadata: Some(metadata), }, diff --git a/crates/polars-plan/src/logical_plan/expr_ir.rs b/crates/polars-plan/src/logical_plan/expr_ir.rs index 46ca4909b836..df7ec7835f6a 100644 --- a/crates/polars-plan/src/logical_plan/expr_ir.rs +++ b/crates/polars-plan/src/logical_plan/expr_ir.rs @@ -1,7 +1,9 @@ +use std::hash::{Hash, Hasher}; + use super::*; use crate::constants::LITERAL_NAME; -#[derive(Default, Debug, Clone)] +#[derive(Default, Debug, Clone, Hash, PartialEq, Eq)] pub enum OutputName { #[default] None, @@ -25,7 +27,7 @@ impl OutputName { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, 
PartialEq, Eq)] pub struct ExprIR { /// Output name of this expression. output_name: OutputName, @@ -120,6 +122,13 @@ impl ExprIR { pub(crate) fn has_alias(&self) -> bool { matches!(self.output_name, OutputName::Alias(_)) } + + pub(crate) fn traverse_and_hash(&self, expr_arena: &Arena, state: &mut H) { + traverse_and_hash_aexpr(self.node, expr_arena, state); + if let Some(alias) = self.get_alias() { + alias.hash(state) + } + } } impl AsRef for ExprIR { diff --git a/crates/polars-plan/src/logical_plan/file_scan.rs b/crates/polars-plan/src/logical_plan/file_scan.rs index 8f7319574c0c..a1ebf1795896 100644 --- a/crates/polars-plan/src/logical_plan/file_scan.rs +++ b/crates/polars-plan/src/logical_plan/file_scan.rs @@ -1,3 +1,5 @@ +use std::hash::{Hash, Hasher}; + #[cfg(feature = "parquet")] use polars_parquet::write::FileMetaData; @@ -18,7 +20,6 @@ pub enum FileScan { #[cfg(feature = "ipc")] Ipc { options: IpcScanOptions, - #[cfg(feature = "cloud")] cloud_options: Option, #[cfg_attr(feature = "serde", serde(skip))] metadata: Option, @@ -52,31 +53,51 @@ impl PartialEq for FileScan { ( FileScan::Ipc { options: l, - #[cfg(feature = "cloud")] - cloud_options: c_l, + cloud_options: c_l, .. }, FileScan::Ipc { options: r, - #[cfg(feature = "cloud")] - cloud_options: c_r, + cloud_options: c_r, .. 
}, - ) => { - #[cfg(not(feature = "cloud"))] - { - l == r - } - #[cfg(feature = "cloud")] - { - l == r && c_l == c_r - } - }, + ) => l == r && c_l == c_r, _ => false, } } } +impl Eq for FileScan {} + +impl Hash for FileScan { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + match self { + #[cfg(feature = "csv")] + FileScan::Csv { options } => options.hash(state), + #[cfg(feature = "parquet")] + FileScan::Parquet { + options, + cloud_options, + metadata: _, + } => { + options.hash(state); + cloud_options.hash(state) + }, + #[cfg(feature = "ipc")] + FileScan::Ipc { + options, + cloud_options, + metadata: _, + } => { + options.hash(state); + cloud_options.hash(state); + }, + FileScan::Anonymous { options, .. } => options.hash(state), + } + } +} + impl FileScan { pub(crate) fn remove_metadata(&mut self) { match self { diff --git a/crates/polars-plan/src/logical_plan/functions/count.rs b/crates/polars-plan/src/logical_plan/functions/count.rs index 77e3a9600151..a9b928e817e1 100644 --- a/crates/polars-plan/src/logical_plan/functions/count.rs +++ b/crates/polars-plan/src/logical_plan/functions/count.rs @@ -53,7 +53,6 @@ pub fn count_rows(paths: &Arc<[PathBuf]>, scan_type: &FileScan) -> PolarsResult< #[cfg(feature = "ipc")] FileScan::Ipc { options, - #[cfg(feature = "cloud")] cloud_options, metadata, } => { diff --git a/crates/polars-plan/src/logical_plan/functions/mod.rs b/crates/polars-plan/src/logical_plan/functions/mod.rs index 5a38b7a0ee5c..c9995e03d2dc 100644 --- a/crates/polars-plan/src/logical_plan/functions/mod.rs +++ b/crates/polars-plan/src/logical_plan/functions/mod.rs @@ -6,6 +6,7 @@ mod python_udf; mod rename; use std::borrow::Cow; use std::fmt::{Debug, Display, Formatter}; +use std::hash::{Hash, Hasher}; use std::path::PathBuf; use std::sync::Arc; @@ -95,6 +96,8 @@ pub enum FunctionNode { }, } +impl Eq for FunctionNode {} + impl PartialEq for FunctionNode { fn eq(&self, other: &Self) -> bool { use FunctionNode::*; @@ -117,11 
+120,57 @@ impl PartialEq for FunctionNode { (Explode { columns: l, .. }, Explode { columns: r, .. }) => l == r, (Melt { args: l, .. }, Melt { args: r, .. }) => l == r, (RowIndex { name: l, .. }, RowIndex { name: r, .. }) => l == r, + #[cfg(feature = "merge_sorted")] + (MergeSorted { column: l }, MergeSorted { column: r }) => l == r, _ => false, } } } +impl Hash for FunctionNode { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + match self { + #[cfg(feature = "python")] + FunctionNode::OpaquePython { .. } => {}, + FunctionNode::Opaque { fmt_str, .. } => fmt_str.hash(state), + FunctionNode::Count { + paths, + scan_type, + alias, + } => { + paths.hash(state); + scan_type.hash(state); + alias.hash(state); + }, + FunctionNode::Pipeline { .. } => {}, + FunctionNode::Unnest { columns } => columns.hash(state), + FunctionNode::DropNulls { subset } => subset.hash(state), + FunctionNode::Rechunk => {}, + #[cfg(feature = "merge_sorted")] + FunctionNode::MergeSorted { column } => column.hash(state), + FunctionNode::Rename { + existing, + new, + swapping: _, + } => { + existing.hash(state); + new.hash(state); + }, + FunctionNode::Explode { columns, schema: _ } => columns.hash(state), + FunctionNode::Melt { args, schema: _ } => args.hash(state), + FunctionNode::RowIndex { + name, + schema: _, + offset, + } => { + name.hash(state); + offset.hash(state); + }, + } + } +} + impl FunctionNode { /// Whether this function can run on batches of data at a time. 
pub fn is_streamable(&self) -> bool { diff --git a/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs b/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs index 416061261766..e3174cbd105a 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cache_states.rs @@ -39,159 +39,132 @@ pub(super) fn set_cache_states( expr_arena: &mut Arena, scratch: &mut Vec, has_caches: bool, -) { - let mut loop_count = 0; +) -> PolarsResult<()> { let mut stack = Vec::with_capacity(4); - // we loop because there can be nested caches and we must run the projection pushdown - // optimization between cache nodes. - loop { - scratch.clear(); - stack.clear(); - - // per cache id holds: - // a Vec: with children of the node - // a Set: with the union of projected column names. - // a Set: with the union of hstack column names. - let mut cache_schema_and_children = BTreeMap::new(); - - stack.push((root, None, None, None, 0)); - - // the depth of the caches in a single tree branch - let mut max_cache_depth = 0; - - // first traversal - // collect the union of columns per cache id. - // and find the cache parents - while let Some(( - current_node, - mut cache_id, - mut parent, - mut previous_cache, - mut caches_seen, - )) = stack.pop() - { - let lp = lp_arena.get(current_node); - lp.copy_inputs(scratch); - - use ALogicalPlan::*; - match lp { - // don't allow parallelism as caches need each others work - // also self-referencing plans can deadlock on the files they lock - Join { options, .. } if has_caches && options.allow_parallel => { - if let Join { options, .. } = lp_arena.get_mut(current_node) { - let options = Arc::make_mut(options); - options.allow_parallel = false; - } - }, - // don't allow parallelism as caches need each others work - // also self-referencing plans can deadlock on the files they lock - Union { options, .. } if has_caches && options.parallel => { - if let Union { options, .. 
} = lp_arena.get_mut(current_node) { - options.parallel = false; - } - }, - Cache { input, id, .. } => { - caches_seen += 1; - - // no need to run the same cache optimization twice - if loop_count > caches_seen { - continue; - } - - max_cache_depth = std::cmp::max(caches_seen, max_cache_depth); - if let Some(cache_id) = cache_id { - previous_cache = Some(cache_id) + scratch.clear(); + stack.clear(); + + // Per cache id holds: + // - a Vec: with children of the node + // - a Set: with the union of projected column names. + // - a Set: with the union of hstack column names. + let mut cache_schema_and_children = BTreeMap::new(); + + stack.push((root, None, None, None)); + + // # First traversal. + // Collect the union of columns per cache id. + // And find the cache parents. + while let Some((current_node, mut cache_id, mut parent, mut previous_cache)) = stack.pop() { + let lp = lp_arena.get(current_node); + lp.copy_inputs(scratch); + + use ALogicalPlan::*; + match lp { + // don't allow parallelism as caches need each others work + // also self-referencing plans can deadlock on the files they lock + Join { options, .. } if has_caches && options.allow_parallel => { + if let Join { options, .. } = lp_arena.get_mut(current_node) { + let options = Arc::make_mut(options); + options.allow_parallel = false; + } + }, + // don't allow parallelism as caches need each others work + // also self-referencing plans can deadlock on the files they lock + Union { options, .. } if has_caches && options.parallel => { + if let Union { options, .. } = lp_arena.get_mut(current_node) { + options.parallel = false; + } + }, + Cache { input, id, .. 
} => { + if let Some(cache_id) = cache_id { + previous_cache = Some(cache_id) + } + if let Some(parent_node) = parent { + // projection pushdown has already run and blocked on cache nodes + // the pushed down columns are projected just above this cache + // if there were no pushed down column, we just take the current + // nodes schema + // we never want to naively take parents, as a join or aggregate for instance + // change the schema + + let (children, union_names) = cache_schema_and_children + .entry(*id) + .or_insert_with(|| (Vec::new(), PlHashSet::new())); + children.push(*input); + + if let Some(names) = get_upper_projections(parent_node, lp_arena, expr_arena) { + union_names.extend(names); } - if let Some(parent_node) = parent { - // projection pushdown has already run and blocked on cache nodes - // the pushed down columns are projected just above this cache - // if there were no pushed down column, we just take the current - // nodes schema - // we never want to naively take parents, as a join or aggregate for instance - // change the schema - - let (children, union_names) = cache_schema_and_children - .entry(*id) - .or_insert_with(|| (Vec::new(), PlHashSet::new())); - children.push(*input); - - if let Some(names) = - get_upper_projections(parent_node, lp_arena, expr_arena) - { - union_names.extend(names); - } - // There was no explicit projection and we must take - // all columns - else { - let schema = lp.schema(lp_arena); - union_names.extend( - schema - .iter_names() - .map(|name| ColumnName::from(name.as_str())), - ); - } + // There was no explicit projection and we must take + // all columns + else { + let schema = lp.schema(lp_arena); + union_names.extend( + schema + .iter_names() + .map(|name| ColumnName::from(name.as_str())), + ); } - cache_id = Some(*id); - }, - _ => {}, - } + } + cache_id = Some(*id); + }, + _ => {}, + } - parent = Some(current_node); - for n in scratch.iter() { - stack.push((*n, cache_id, parent, previous_cache, caches_seen)) 
- } - scratch.clear(); + parent = Some(current_node); + for n in scratch.iter() { + stack.push((*n, cache_id, parent, previous_cache)) } + scratch.clear(); + } - // second pass - // we create a subtree where we project the columns - // just before the cache. Then we do another projection pushdown - // and finally remove that last projection and stitch the subplan - // back to the cache node again - if !cache_schema_and_children.is_empty() { - let mut pd = ProjectionPushDown::new(); - for (_cache_id, (children, columns)) in cache_schema_and_children { - if !columns.is_empty() { - for child in children { - let columns = &columns; - let child_lp = lp_arena.get(child).clone(); - - // make sure we project in the order of the schema - // if we don't a union may fail as we would project by the - // order we discovered all values. - let child_schema = child_lp.schema(lp_arena); - let child_schema = child_schema.as_ref(); - let projection: Vec<_> = child_schema - .iter_names() - .flat_map(|name| columns.get(name.as_str()).map(|name| name.as_ref())) - .collect(); - - let new_child = lp_arena.add(child_lp); - - let lp = ALogicalPlanBuilder::new(new_child, expr_arena, lp_arena) - .project_simple(projection.iter().copied()) - .unwrap() - .build(); - - let lp = pd.optimize(lp, lp_arena, expr_arena).unwrap(); - // Remove the projection added by the optimization. - let lp = if let ALogicalPlan::Projection { input, .. } - | ALogicalPlan::SimpleProjection { input, .. } = lp - { - lp_arena.take(input) - } else { - lp - }; - lp_arena.replace(child, lp); - } + // second pass + // we create a subtree where we project the columns + // just before the cache. 
Then we do another projection pushdown + // and finally remove that last projection and stitch the subplan + // back to the cache node again + if !cache_schema_and_children.is_empty() { + let mut proj_pd = ProjectionPushDown::new(); + let pred_pd = PredicatePushDown::new(Default::default()); + for (_cache_id, (children, columns)) in cache_schema_and_children { + if !columns.is_empty() { + for child in children { + let columns = &columns; + let child_lp = lp_arena.get(child).clone(); + + // make sure we project in the order of the schema + // if we don't a union may fail as we would project by the + // order we discovered all values. + let child_schema = child_lp.schema(lp_arena); + let child_schema = child_schema.as_ref(); + let projection: Vec<_> = child_schema + .iter_names() + .flat_map(|name| columns.get(name.as_str()).map(|name| name.as_ref())) + .collect(); + + let new_child = lp_arena.add(child_lp); + + let lp = ALogicalPlanBuilder::new(new_child, expr_arena, lp_arena) + .project_simple(projection.iter().copied()) + .unwrap() + .build(); + + let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?; + let lp = pred_pd.optimize(lp, lp_arena, expr_arena)?; + // Remove the projection added by the optimization. + let lp = if let ALogicalPlan::Projection { input, .. } + | ALogicalPlan::SimpleProjection { input, .. } = lp + { + lp_arena.take(input) + } else { + lp + }; + lp_arena.replace(child, lp); } } } - - if loop_count >= max_cache_depth { - break; - } - loop_count += 1; } + Ok(()) } diff --git a/crates/polars-plan/src/logical_plan/optimizer/collect_members.rs b/crates/polars-plan/src/logical_plan/optimizer/collect_members.rs index 17b051411cde..f64bf40b8b75 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/collect_members.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/collect_members.rs @@ -1,9 +1,33 @@ use super::*; +// Utility to cheaply check if we have duplicate sources. +// This may have false positives. 
+#[cfg(feature = "cse")] +#[derive(Default)] +struct UniqueScans { + ids: PlHashSet, + count: usize, +} + +#[cfg(feature = "cse")] +impl UniqueScans { + fn insert(&mut self, node: Node, lp_arena: &Arena, expr_arena: &Arena) { + let alp_node = unsafe { ALogicalPlanNode::from_raw(node, lp_arena as *const _ as *mut _) }; + self.ids.insert( + self.ids + .hasher() + .hash_one(alp_node.hashable_and_cmp(expr_arena)), + ); + self.count += 1; + } +} + pub(super) struct MemberCollector { pub(crate) has_joins_or_unions: bool, pub(crate) has_cache: bool, pub(crate) has_ext_context: bool, + #[cfg(feature = "cse")] + scans: UniqueScans, } impl MemberCollector { @@ -12,17 +36,40 @@ impl MemberCollector { has_joins_or_unions: false, has_cache: false, has_ext_context: false, + #[cfg(feature = "cse")] + scans: UniqueScans::default(), } } - pub fn collect(&mut self, root: Node, lp_arena: &Arena) { + pub(super) fn collect( + &mut self, + root: Node, + lp_arena: &Arena, + _expr_arena: &Arena, + ) { use ALogicalPlan::*; - for (_, alp) in lp_arena.iter(root) { + for (_node, alp) in lp_arena.iter(root) { match alp { Join { .. } | Union { .. } => self.has_joins_or_unions = true, Cache { .. } => self.has_cache = true, ExtContext { .. } => self.has_ext_context = true, + #[cfg(feature = "cse")] + Scan { .. } => { + self.scans.insert(_node, lp_arena, _expr_arena); + }, + HConcat { .. } => { + self.has_joins_or_unions = true; + }, + #[cfg(feature = "cse")] + DataFrameScan { .. } => { + self.scans.insert(_node, lp_arena, _expr_arena); + }, _ => {}, } } } + + #[cfg(feature = "cse")] + pub(super) fn has_duplicate_scans(&self) -> bool { + self.scans.count != self.scans.ids.len() + } } diff --git a/crates/polars-plan/src/logical_plan/optimizer/cse.rs b/crates/polars-plan/src/logical_plan/optimizer/cse.rs deleted file mode 100644 index 28fe5f2f2225..000000000000 --- a/crates/polars-plan/src/logical_plan/optimizer/cse.rs +++ /dev/null @@ -1,423 +0,0 @@ -//! 
Common Subplan Elimination - -use std::collections::{BTreeMap, BTreeSet}; - -use polars_core::prelude::*; -use polars_utils::idx_vec::UnitVec; -use polars_utils::unitvec; - -use crate::prelude::*; - -// nodes into an alogicalplan. -type Trail = Vec; - -// we use mutation of `id` to get a unique trail -// we traverse left first, so the `id` remains the same for an all left traversal. -// every right node may increment `id` and because it's shared mutable there will -// be no collisions as the increment is communicated upward with mutation. -pub(super) fn collect_trails( - root: Node, - lp_arena: &Arena, - // every branch gets its own trail - // note to self: - // don't use a vec, as different branches can have collisions - trails: &mut BTreeMap, - id: &mut u32, - // if trails should be collected - collect: bool, -) -> Option<()> { - // TODO! remove recursion and use a stack - if collect { - trails.get_mut(id).unwrap().push(root); - } - - use ALogicalPlan::*; - match lp_arena.get(root) { - // if we find a cache anywhere, that means the users has set caches and we don't interfere - Cache { .. } => return None, - // we start collecting from first encountered join - // later we must unions as well - Join { - input_left, - input_right, - .. - } => { - // make sure that the new branch has the same trail history - let new_trail = trails.get(id).unwrap().clone(); - collect_trails(*input_left, lp_arena, trails, id, true)?; - - *id += 1; - trails.insert(*id, new_trail); - collect_trails(*input_right, lp_arena, trails, id, true)?; - }, - Union { inputs, .. } | HConcat { inputs, .. 
} => { - if inputs.len() > 200 { - // don't even bother with cse on this many inputs - return None; - } - let new_trail = trails.get(id).unwrap().clone(); - - let last_i = inputs.len() - 1; - - for (i, input) in inputs.iter().enumerate() { - collect_trails(*input, lp_arena, trails, id, true)?; - - // don't add a trail on the last iteration as that would only add a Union - // without any inputs - if i != last_i { - *id += 1; - trails.insert(*id, new_trail.clone()); - } - } - }, - ExtContext { .. } => { - // block for now. - }, - lp => { - // other nodes have only a single input - let mut nodes: UnitVec = unitvec![]; - lp.copy_inputs(&mut nodes); - if let Some(input) = nodes.pop() { - collect_trails(input, lp_arena, trails, id, collect)? - } - }, - } - Some(()) -} - -fn expr_nodes_equal(a: &[ExprIR], b: &[ExprIR], expr_arena: &Arena) -> bool { - a.len() == b.len() - && a.iter() - .zip(b) - .all(|(a, b)| AExpr::is_equal(a.node(), b.node(), expr_arena)) -} - -fn predicate_equal(a: Option, b: Option, expr_arena: &Arena) -> bool { - match (a, b) { - (Some(l), Some(r)) => AExpr::is_equal(l, r, expr_arena), - (None, None) => true, - _ => false, - } -} - -fn lp_node_equal(a: &ALogicalPlan, b: &ALogicalPlan, expr_arena: &Arena) -> bool { - use ALogicalPlan::*; - match (a, b) { - ( - DataFrameScan { - df: left_df, - projection: None, - selection: None, - .. - }, - DataFrameScan { - df: right_df, - projection: None, - selection: None, - .. - }, - ) => Arc::ptr_eq(left_df, right_df), - ( - Scan { - paths: path_left, - predicate: predicate_left, - scan_type: scan_type_left, - .. - }, - Scan { - paths: path_right, - predicate: predicate_right, - scan_type: scan_type_right, - .. - }, - ) => { - path_left == path_right - && scan_type_left == scan_type_right - && predicate_equal( - predicate_left.as_ref().map(|e| e.node()), - predicate_right.as_ref().map(|e| e.node()), - expr_arena, - ) - }, - (Selection { predicate: l, .. }, Selection { predicate: r, .. 
}) => { - AExpr::is_equal(l.node(), r.node(), expr_arena) - }, - (SimpleProjection { columns: l, .. }, SimpleProjection { columns: r, .. }) => l == r, - (Projection { expr: l, .. }, Projection { expr: r, .. }) - | (HStack { exprs: l, .. }, HStack { exprs: r, .. }) => expr_nodes_equal(l, r, expr_arena), - ( - Slice { - offset: offset_l, - len: len_l, - .. - }, - Slice { - offset: offset_r, - len: len_r, - .. - }, - ) => offset_l == offset_r && len_l == len_r, - ( - Sort { - by_column: by_l, - args: args_l, - .. - }, - Sort { - by_column: by_r, - args: args_r, - .. - }, - ) => expr_nodes_equal(by_l, by_r, expr_arena) && args_l == args_r, - (Distinct { options: l, .. }, Distinct { options: r, .. }) => l == r, - (MapFunction { function: l, .. }, MapFunction { function: r, .. }) => l == r, - ( - Aggregate { - keys: keys_l, - aggs: agg_l, - apply: None, - maintain_order: maintain_order_l, - options: options_l, - .. - }, - Aggregate { - keys: keys_r, - aggs: agg_r, - apply: None, - maintain_order: maintain_order_r, - options: options_r, - .. - }, - ) => { - maintain_order_l == maintain_order_r - && options_l == options_r - && expr_nodes_equal(keys_l, keys_r, expr_arena) - && expr_nodes_equal(agg_l, agg_r, expr_arena) - }, - #[cfg(feature = "python")] - (PythonScan { options: l, .. }, PythonScan { options: r, .. }) => l == r, - _ => { - // joins and unions are also false - // they do not originate from a single trail - // so we would need to follow every leaf that - // is below the joining/union root - // that gets complicated quick - false - }, - } -} - -/// Iterate from two leaf location upwards and find the latest matching node. 
-/// -/// Returns the matching nodes -fn longest_subgraph( - trail_a: &Trail, - trail_b: &Trail, - lp_arena: &Arena, - expr_arena: &Arena, -) -> Option<(Node, Node, bool)> { - if trail_a.is_empty() || trail_b.is_empty() { - return None; - } - let mut prev_node_a = Node(0); - let mut prev_node_b = Node(0); - let mut is_equal; - let mut i = 0; - let mut entirely_equal = trail_a.len() == trail_b.len(); - - // iterates from the leaves upwards - for (node_a, node_b) in trail_a.iter().rev().zip(trail_b.iter().rev()) { - // we never include the root that splits a trail - // e.g. don't want to cache the join/union, but - // we want to cache the similar inputs - if *node_a == *node_b { - break; - } - let a = lp_arena.get(*node_a); - let b = lp_arena.get(*node_b); - - is_equal = lp_node_equal(a, b, expr_arena); - - if !is_equal { - entirely_equal = false; - break; - } - - prev_node_a = *node_a; - prev_node_b = *node_b; - i += 1; - } - // previous node was equal - if i > 0 { - Some((prev_node_a, prev_node_b, entirely_equal)) - } else { - None - } -} - -pub(crate) fn elim_cmn_subplans( - root: Node, - lp_arena: &mut Arena, - expr_arena: &Arena, -) -> (Node, bool) { - let mut trails = BTreeMap::new(); - let mut id = 0; - trails.insert(id, Vec::new()); - if collect_trails(root, lp_arena, &mut trails, &mut id, false).is_none() { - // early return because we encountered a cache set by the caller - // we will not interfere with those - return (root, false); - } - let trails = trails.into_values().collect::>(); - - // search from the leaf nodes upwards and find the longest shared subplans - let mut trail_ends = vec![]; - // if i matches j - // we don't need to search with j as they are equal - // this is very important as otherwise we get quadratic behavior - let mut to_skip = BTreeSet::new(); - - for i in 0..trails.len() { - if to_skip.contains(&i) { - continue; - } - let trail_i = &trails[i]; - - // we only look forwards, then we traverse all combinations - for (j, trail_j) in 
trails.iter().enumerate().skip(i + 1) { - if let Some((a, b, all_equal)) = - longest_subgraph(trail_i, trail_j, lp_arena, expr_arena) - { - // then we can skip `j` as we already searched with trail `i` which is equal - if all_equal { - to_skip.insert(j); - } - trail_ends.push((a, b)) - } - } - } - - let lp_cache = lp_arena as *const Arena as usize; - - let hb = ahash::RandomState::new(); - let mut changed = false; - - let mut cache_mapping = BTreeMap::new(); - let mut cache_counts = PlHashMap::with_capacity(trail_ends.len()); - - for combination in trail_ends.iter() { - // both are the same, but only point to a different location - // in our arena so we hash one and store the hash for both locations - // this will ensure all matches have the same hash. - let node1 = combination.0 .0; - let node2 = combination.1 .0; - - let cache_id = match (cache_mapping.get(&node1), cache_mapping.get(&node2)) { - (Some(h), _) => *h, - (_, Some(h)) => *h, - _ => { - let hash = hb.hash_one(node1); - let mut cache_id = lp_cache.wrapping_add(hash as usize); - // this ensures we can still add branch ids without overflowing - // during the dot representation - if (usize::MAX - cache_id) < 2048 { - cache_id -= 2048 - } - - cache_mapping.insert(node1, cache_id); - cache_mapping.insert(node2, cache_id); - cache_id - }, - }; - *cache_counts.entry(cache_id).or_insert(0usize) += 1; - } - - // insert cache nodes - for combination in trail_ends.iter() { - // both are the same, but only point to a different location - // in our arena so we hash one and store the hash for both locations - // this will ensure all matches have the same hash. 
- let node1 = combination.0 .0; - let node2 = combination.1 .0; - - let cache_id = match (cache_mapping.get(&node1), cache_mapping.get(&node2)) { - // (Some(_), Some(_)) => { - // continue - // } - (Some(h), _) => *h, - (_, Some(h)) => *h, - _ => { - unreachable!() - }, - }; - let cache_count = *cache_counts.get(&cache_id).unwrap(); - - // reassign old nodes to another location as we are going to replace - // them with a cache node - for inp_node in [combination.0, combination.1] { - if let ALogicalPlan::Cache { count, .. } = lp_arena.get_mut(inp_node) { - *count = cache_count; - } else { - let lp = lp_arena.get(inp_node).clone(); - - let node = lp_arena.add(lp); - - let cache_lp = ALogicalPlan::Cache { - input: node, - id: cache_id, - // remove after one cache hit. - count: cache_count, - }; - lp_arena.replace(inp_node, cache_lp.clone()); - }; - } - - changed = true; - } - - (root, changed) -} - -// ensure the file count counters are decremented with the cache counts -pub(crate) fn decrement_file_counters_by_cache_hits( - root: Node, - lp_arena: &mut Arena, - _expr_arena: &Arena, - acc_count: FileCount, - scratch: &mut Vec, -) { - use ALogicalPlan::*; - match lp_arena.get_mut(root) { - Scan { - file_options: options, - .. - } => { - if acc_count >= options.file_counter { - options.file_counter = 1; - } else { - options.file_counter -= acc_count as FileCount - } - }, - Cache { count, input, .. } => { - // we use usize::MAX for an infinite cache. 
- let new_count = if *count != usize::MAX { - acc_count + *count as FileCount - } else { - acc_count - }; - decrement_file_counters_by_cache_hits(*input, lp_arena, _expr_arena, new_count, scratch) - }, - lp => { - lp.copy_inputs(scratch); - while let Some(input) = scratch.pop() { - decrement_file_counters_by_cache_hits( - input, - lp_arena, - _expr_arena, - acc_count, - scratch, - ) - } - }, - } -} diff --git a/crates/polars-plan/src/logical_plan/optimizer/cse/cache.rs b/crates/polars-plan/src/logical_plan/optimizer/cse/cache.rs new file mode 100644 index 000000000000..c91ddd04a7fd --- /dev/null +++ b/crates/polars-plan/src/logical_plan/optimizer/cse/cache.rs @@ -0,0 +1,45 @@ +use super::*; + +// ensure the file count counters are decremented with the cache counts +pub(crate) fn decrement_file_counters_by_cache_hits( + root: Node, + lp_arena: &mut Arena, + _expr_arena: &Arena, + acc_count: FileCount, + scratch: &mut Vec, +) { + use ALogicalPlan::*; + match lp_arena.get_mut(root) { + Scan { + file_options: options, + .. + } => { + if acc_count >= options.file_counter { + options.file_counter = 1; + } else { + options.file_counter -= acc_count as FileCount + } + }, + Cache { count, input, .. } => { + // we use usize::MAX for an infinite cache. 
+ let new_count = if *count != usize::MAX { + acc_count + *count as FileCount + } else { + acc_count + }; + decrement_file_counters_by_cache_hits(*input, lp_arena, _expr_arena, new_count, scratch) + }, + lp => { + lp.copy_inputs(scratch); + while let Some(input) = scratch.pop() { + decrement_file_counters_by_cache_hits( + input, + lp_arena, + _expr_arena, + acc_count, + scratch, + ) + } + }, + } +} diff --git a/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs b/crates/polars-plan/src/logical_plan/optimizer/cse/cse_expr.rs similarity index 97% rename from crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs rename to crates/polars-plan/src/logical_plan/optimizer/cse/cse_expr.rs index 4bcbb590aebf..ab90b5946607 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cse/cse_expr.rs @@ -19,7 +19,7 @@ mod identifier_impl { /// We will do a full expression comparison to check if the /// expressions with equal identifiers are truly equal #[derive(Clone, Debug)] - pub struct Identifier { + pub(super) struct Identifier { inner: String, last_node: Option, } @@ -87,7 +87,7 @@ mod identifier_impl { /// We will do a full expression comparison to check if the /// expressions with equal identifiers are truly equal #[derive(Clone, Debug)] - pub struct Identifier { + pub(super) struct Identifier { inner: Option, last_node: Option, hb: RandomState, @@ -239,21 +239,19 @@ fn skip_pre_visit(ae: &AExpr, is_groupby: bool) -> bool { struct ExprIdentifierVisitor<'a> { se_count: &'a mut SubExprCount, identifier_array: &'a mut IdentifierArray, - // index in pre-visit traversal order + // Index in pre-visit traversal order. 
pre_visit_idx: usize, post_visit_idx: usize, visit_stack: &'a mut Vec, /// Offset in the identifier array /// this allows us to use a single `vec` on multiple expressions id_array_offset: usize, - // whether the expression replaced a subexpression + // Whether the expression replaced a subexpression. has_sub_expr: bool, // During aggregation we only identify element-wise operations is_group_by: bool, } -type Accepted = Option<(VisitRecursion, bool)>; - impl ExprIdentifierVisitor<'_> { fn new<'a>( se_count: &'a mut SubExprCount, @@ -300,14 +298,6 @@ impl ExprIdentifierVisitor<'_> { /// `Some(_, true)` don't accept this node, but can be a member of a cse. /// `Some(_, false)` don't accept this node, and don't allow as a member of a cse. fn accept_node_post_visit(&self, ae: &AExpr) -> Accepted { - // Don't allow this node in a cse. - const REFUSE_NO_MEMBER: Accepted = Some((VisitRecursion::Continue, false)); - // Don't allow this node, but allow as a member of a cse. - const REFUSE_ALLOW_MEMBER: Accepted = Some((VisitRecursion::Continue, true)); - const REFUSE_SKIP: Accepted = Some((VisitRecursion::Skip, false)); - // Accept this node. - const ACCEPT: Accepted = None; - match ae { // window expressions should `evaluate_on_groups`, not `evaluate` // so we shouldn't cache the children as they are evaluated incorrectly @@ -382,7 +372,7 @@ impl Visitor for ExprIdentifierVisitor<'_> { self.post_visit_idx += 1; let (pre_visit_idx, sub_expr_id, is_valid_accumulated) = self.pop_until_entered(); - // create the id of this node + // Create the Id of this node. let id: Identifier = sub_expr_id.add_ae_node(node); if !is_valid_accumulated { @@ -391,8 +381,8 @@ impl Visitor for ExprIdentifierVisitor<'_> { return Ok(VisitRecursion::Continue); } - // if we don't store this node - // we only push the visit_stack, so the parents know the trail + // If we don't store this node + // we only push the visit_stack, so the parents know the trail. 
if let Some((recurse, local_is_valid)) = self.accept_node_post_visit(ae) { self.identifier_array[pre_visit_idx + self.id_array_offset].0 = self.post_visit_idx; @@ -401,12 +391,12 @@ impl Visitor for ExprIdentifierVisitor<'_> { return Ok(recurse); } - // store the created id + // Store the created id. self.identifier_array[pre_visit_idx + self.id_array_offset] = (self.post_visit_idx, id.clone()); // We popped until entered, push this Id on the stack so the trail - // is available for the parent expression + // is available for the parent expression. self.visit_stack .push(VisitRecord::SubExprId(id.clone(), true)); @@ -506,7 +496,7 @@ impl RewritingVisitor for CommonSubExprRewriter<'_> { let id = &self.identifier_array[self.visited_idx + self.id_array_offset].1; - // placeholder not overwritten, so we can skip this sub-expression + // Id placeholder not overwritten, so we can skip this sub-expression. if !id.is_valid() { self.visited_idx += 1; let recurse = if ae_node.is_leaf() { diff --git a/crates/polars-plan/src/logical_plan/optimizer/cse/cse_lp.rs b/crates/polars-plan/src/logical_plan/optimizer/cse/cse_lp.rs new file mode 100644 index 000000000000..c7120c04ccae --- /dev/null +++ b/crates/polars-plan/src/logical_plan/optimizer/cse/cse_lp.rs @@ -0,0 +1,337 @@ +use super::*; +use crate::prelude::visitor::ALogicalPlanNode; + +mod identifier_impl { + use std::fmt::{Debug, Formatter}; + use std::hash::{Hash, Hasher}; + + use ahash::RandomState; + use polars_core::hashing::_boost_hash_combine; + + use super::*; + /// Identifier that shows the sub-expression path. 
+ /// Must implement hash and equality and ideally + /// have little collisions + /// We will do a full expression comparison to check if the + /// expressions with equal identifiers are truly equal + #[derive(Clone)] + pub(super) struct Identifier { + inner: Option, + last_node: Option, + hb: RandomState, + expr_arena: *const Arena, + } + + impl Debug for Identifier { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.last_node.as_ref().map(|n| n.to_alp())) + } + } + + impl PartialEq for Identifier { + fn eq(&self, other: &Self) -> bool { + self.inner == other.inner + && match (self.last_node, other.last_node) { + (None, None) => true, + (Some(l), Some(r)) => { + let expr_arena = unsafe { &*self.expr_arena }; + // We ignore caches as they are inserted on the node locations. + // In that case we don't want to cmp the cache (as we just inserted it), + // but the input node of the cache. + l.hashable_and_cmp(expr_arena).ignore_caches() + == r.hashable_and_cmp(expr_arena).ignore_caches() + }, + _ => false, + } + } + } + + impl Eq for Identifier {} + + impl Hash for Identifier { + fn hash(&self, state: &mut H) { + state.write_u64(self.inner.unwrap_or(0)) + } + } + + impl Identifier { + /// # Safety + /// + /// The arena must be a valid pointer and there should be no `&mut` to this arena. 
+ pub unsafe fn new(expr_arena: *const Arena) -> Self { + Self { + inner: None, + last_node: None, + hb: RandomState::with_seed(0), + expr_arena, + } + } + + pub fn is_valid(&self) -> bool { + self.inner.is_some() + } + + pub fn combine(&mut self, other: &Identifier) { + let inner = match (self.inner, other.inner) { + (Some(l), Some(r)) => _boost_hash_combine(l, r), + (None, Some(r)) => r, + (Some(l), None) => l, + _ => return, + }; + self.inner = Some(inner); + } + + pub fn add_alp_node(&self, alp: &ALogicalPlanNode) -> Self { + let expr_arena = unsafe { &*self.expr_arena }; + let hashed = self.hb.hash_one(alp.hashable_and_cmp(expr_arena)); + let inner = Some( + self.inner + .map_or(hashed, |l| _boost_hash_combine(l, hashed)), + ); + Self { + inner, + last_node: Some(*alp), + hb: self.hb.clone(), + expr_arena: self.expr_arena, + } + } + } +} +use identifier_impl::*; + +/// Identifier maps to Expr Node and count. +type SubPlanCount = PlHashMap; +/// (post_visit_idx, identifier); +type IdentifierArray = Vec<(usize, Identifier)>; + +/// See Expr based CSE for explanations. +enum VisitRecord { + /// Entered a new plan node + Entered(usize), + SubPlanId(Identifier), +} + +struct LpIdentifierVisitor<'a> { + sp_count: &'a mut SubPlanCount, + identifier_array: &'a mut IdentifierArray, + // Index in pre-visit traversal order. + pre_visit_idx: usize, + post_visit_idx: usize, + visit_stack: Vec, + has_subplan: bool, + expr_arena: &'a Arena, +} + +impl LpIdentifierVisitor<'_> { + fn new<'a>( + sp_count: &'a mut SubPlanCount, + identifier_array: &'a mut IdentifierArray, + expr_arena: &'a Arena, + ) -> LpIdentifierVisitor<'a> { + LpIdentifierVisitor { + sp_count, + identifier_array, + pre_visit_idx: 0, + post_visit_idx: 0, + visit_stack: vec![], + has_subplan: false, + expr_arena, + } + } + + fn pop_until_entered(&mut self) -> (usize, Identifier) { + // SAFETY: + // we keep pointer valid and will not create mutable refs. 
+ let mut id = unsafe { Identifier::new(self.expr_arena as *const _) }; + + while let Some(item) = self.visit_stack.pop() { + match item { + VisitRecord::Entered(idx) => return (idx, id), + VisitRecord::SubPlanId(s) => { + id.combine(&s); + }, + } + } + unreachable!() + } +} + +impl Visitor for LpIdentifierVisitor<'_> { + type Node = ALogicalPlanNode; + + fn pre_visit(&mut self, _node: &Self::Node) -> PolarsResult { + self.visit_stack + .push(VisitRecord::Entered(self.pre_visit_idx)); + self.pre_visit_idx += 1; + + // SAFETY: + // we keep pointer valid and will not create mutable refs. + self.identifier_array + .push((0, unsafe { Identifier::new(self.expr_arena as *const _) })); + Ok(VisitRecursion::Continue) + } + + fn post_visit(&mut self, node: &Self::Node) -> PolarsResult { + self.post_visit_idx += 1; + + let (pre_visit_idx, sub_plan_id) = self.pop_until_entered(); + + // Create the Id of this node. + let id: Identifier = sub_plan_id.add_alp_node(node); + + // Store the created id. + self.identifier_array[pre_visit_idx] = (self.post_visit_idx, id.clone()); + + // We popped until entered, push this Id on the stack so the trail + // is available for the parent plan. + self.visit_stack.push(VisitRecord::SubPlanId(id.clone())); + + let (_, sp_count) = self.sp_count.entry(id).or_insert_with(|| (node.node(), 0)); + *sp_count += 1; + self.has_subplan |= *sp_count > 1; + Ok(VisitRecursion::Continue) + } +} + +pub(super) type CacheId2Caches = PlHashMap)>; + +struct CommonSubPlanRewriter<'a> { + sp_count: &'a SubPlanCount, + identifier_array: &'a IdentifierArray, + + max_post_visit_idx: usize, + /// index in traversal order in which `identifier_array` + /// was written. This is the index in `identifier_array`. + visited_idx: usize, + /// Indicates if this expression is rewritten. 
+ rewritten: bool, + cache_id: PlHashMap, + // Maps cache_id : (cache_count and cache_nodes) + cache_id_to_caches: CacheId2Caches, +} + +impl<'a> CommonSubPlanRewriter<'a> { + fn new(sp_count: &'a SubPlanCount, identifier_array: &'a IdentifierArray) -> Self { + Self { + sp_count, + identifier_array, + max_post_visit_idx: 0, + visited_idx: 0, + rewritten: false, + cache_id: Default::default(), + cache_id_to_caches: Default::default(), + } + } +} + +impl RewritingVisitor for CommonSubPlanRewriter<'_> { + type Node = ALogicalPlanNode; + + fn pre_visit(&mut self, _lp_node: &Self::Node) -> PolarsResult { + if self.visited_idx >= self.identifier_array.len() + || self.max_post_visit_idx > self.identifier_array[self.visited_idx].0 + { + return Ok(RewriteRecursion::Stop); + } + + let id = &self.identifier_array[self.visited_idx].1; + + // Id placeholder not overwritten, so we can skip this sub-expression. + if !id.is_valid() { + self.visited_idx += 1; + return Ok(RewriteRecursion::MutateAndContinue); + } + + let Some((_, count)) = self.sp_count.get(id) else { + self.visited_idx += 1; + return Ok(RewriteRecursion::NoMutateAndContinue); + }; + + if *count > 1 { + // Rewrite this sub-plan, don't visit its children + Ok(RewriteRecursion::MutateAndStop) + } else { + // This is a unique plan + // visit its children to see if they are cse + self.visited_idx += 1; + Ok(RewriteRecursion::NoMutateAndContinue) + } + } + + fn mutate(&mut self, mut node: Self::Node) -> PolarsResult { + let (post_visit_count, id) = &self.identifier_array[self.visited_idx]; + self.visited_idx += 1; + + if *post_visit_count < self.max_post_visit_idx { + return Ok(node); + } + self.max_post_visit_idx = *post_visit_count; + while self.visited_idx < self.identifier_array.len() + && *post_visit_count > self.identifier_array[self.visited_idx].0 + { + self.visited_idx += 1; + } + + let cache_id = self.cache_id.len(); + let cache_id = *self.cache_id.entry(id.clone()).or_insert(cache_id); + let cache_count = 
self.sp_count.get(id).unwrap().1;
+
+ let cache_node = ALogicalPlan::Cache {
+ input: node.node(),
+ id: cache_id,
+ count: cache_count - 1,
+ };
+ node.assign(cache_node);
+ let (_count, nodes) = self
+ .cache_id_to_caches
+ .entry(cache_id)
+ .or_insert_with(|| (cache_count, vec![]));
+ nodes.push(node.node());
+ self.rewritten = true;
+ Ok(node)
+ }
+}
+
+pub(crate) fn elim_cmn_subplans(
+ root: Node,
+ lp_arena: &mut Arena,
+ expr_arena: &Arena,
+) -> (Node, bool, CacheId2Caches) {
+ let mut sp_count = Default::default();
+ let mut id_array = Default::default();
+
+ let (changed, cache_id_to_caches) = ALogicalPlanNode::with_context(root, lp_arena, |lp_node| {
+ let mut visitor = LpIdentifierVisitor::new(&mut sp_count, &mut id_array, expr_arena);
+
+ lp_node.visit(&mut visitor).map(|_| ()).unwrap();
+
+ let mut rewriter = CommonSubPlanRewriter::new(&sp_count, &id_array);
+ lp_node.rewrite(&mut rewriter).unwrap();
+
+ (rewriter.rewritten, rewriter.cache_id_to_caches)
+ });
+
+ (root, changed, cache_id_to_caches)
+}
+
+/// Prune unused caches.
+/// In the query below, cache 0 will be inserted with a count of 2 on `lf.select`
+/// and cache 1 with a count of 3 on `lf`. But because cache 0 is higher in the chain, cache 1
+/// will never be used. So we prune caches that don't fit their count.
+///
+/// `concat([lf.select(), lf.select(), lf])`
+///
+pub(crate) fn prune_unused_caches(lp_arena: &mut Arena, cid2c: CacheId2Caches) {
+ for (count, nodes) in cid2c.values() {
+ if *count == nodes.len() {
+ continue;
+ }
+
+ for node in nodes {
+ let ALogicalPlan::Cache { input, ..
} = lp_arena.get(*node) else { + unreachable!() + }; + lp_arena.swap(*input, *node) + } + } +} diff --git a/crates/polars-plan/src/logical_plan/optimizer/cse/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/cse/mod.rs new file mode 100644 index 000000000000..8aafa8803c38 --- /dev/null +++ b/crates/polars-plan/src/logical_plan/optimizer/cse/mod.rs @@ -0,0 +1,18 @@ +mod cache; +mod cse_expr; +mod cse_lp; + +pub(super) use cache::decrement_file_counters_by_cache_hits; +pub(super) use cse_expr::CommonSubExprOptimizer; +pub(super) use cse_lp::{elim_cmn_subplans, prune_unused_caches}; + +use super::*; + +type Accepted = Option<(VisitRecursion, bool)>; +// Don't allow this node in a cse. +const REFUSE_NO_MEMBER: Accepted = Some((VisitRecursion::Continue, false)); +// Don't allow this node, but allow as a member of a cse. +const REFUSE_ALLOW_MEMBER: Accepted = Some((VisitRecursion::Continue, true)); +const REFUSE_SKIP: Accepted = Some((VisitRecursion::Skip, false)); +// Accept this node. +const ACCEPT: Accepted = None; diff --git a/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs b/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs index e1f5cf4e195c..352ce387533b 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/file_caching.rs @@ -54,30 +54,30 @@ pub fn collect_fingerprints( expr_arena: &Arena, ) { use ALogicalPlan::*; - match lp_arena.get(root) { - Scan { - paths, - file_options: options, - predicate, - scan_type, - .. 
- } => { - let slice = (scan_type.skip_rows(), options.n_rows); - let predicate = predicate - .as_ref() - .map(|e| node_to_expr(e.node(), expr_arena)); - let fp = FileFingerPrint { - paths: paths.clone(), + + for (_node, lp) in lp_arena.iter(root) { + #[allow(clippy::single_match)] + match lp { + Scan { + paths, + file_options: options, predicate, - slice, - }; - fps.push(fp); - }, - lp => { - for input in lp.get_inputs() { - collect_fingerprints(input, fps, lp_arena, expr_arena) - } - }, + scan_type, + .. + } => { + let slice = (scan_type.skip_rows(), options.n_rows); + let predicate = predicate + .as_ref() + .map(|e| node_to_expr(e.node(), expr_arena)); + let fp = FileFingerPrint { + paths: paths.clone(), + predicate, + slice, + }; + fps.push(fp); + }, + _ => {}, + } } } @@ -93,33 +93,33 @@ pub fn find_column_union_and_fingerprints( expr_arena: &Arena, ) { use ALogicalPlan::*; - match lp_arena.get(root) { - Scan { - paths, - file_options: options, - predicate, - file_info, - scan_type, - .. - } => { - let slice = (scan_type.skip_rows(), options.n_rows); - let predicate = predicate - .as_ref() - .map(|e| node_to_expr(e.node(), expr_arena)); - process_with_columns( + + for (_node, lp) in lp_arena.iter(root) { + #[allow(clippy::single_match)] + match lp { + Scan { paths, - options.with_columns.as_deref(), + file_options: options, predicate, - slice, - columns, - &file_info.schema, - ); - }, - lp => { - for input in lp.get_inputs() { - find_column_union_and_fingerprints(input, columns, lp_arena, expr_arena) - } - }, + file_info, + scan_type, + .. 
+ } => { + let slice = (scan_type.skip_rows(), options.n_rows); + let predicate = predicate + .as_ref() + .map(|e| node_to_expr(e.node(), expr_arena)); + process_with_columns( + paths, + options.with_columns.as_deref(), + predicate, + slice, + columns, + &file_info.schema, + ); + }, + _ => {}, + } } } @@ -156,6 +156,9 @@ impl FileCacher { with_columns: Option>>, behind_cache: bool, ) -> ALogicalPlan { + if behind_cache { + return lp; + } // if the original projection is less than the new one. Also project locally if let Some(mut with_columns) = with_columns { // we cannot always find the predicates, because some have `SpecialEq` functions so for those @@ -164,7 +167,7 @@ impl FileCacher { Some((_file_count, agg_columns)) => with_columns.len() < agg_columns.len(), None => true, }; - if !behind_cache && do_projection { + if do_projection { let node = lp_arena.add(lp); let projections = std::mem::take(Arc::make_mut(&mut with_columns)) diff --git a/crates/polars-plan/src/logical_plan/optimizer/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/mod.rs index 896d5bd3458d..3e31b39ebcb3 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/mod.rs @@ -3,15 +3,13 @@ use polars_core::prelude::*; use crate::prelude::*; mod cache_states; -#[cfg(feature = "cse")] -mod cse; mod delay_rechunk; mod drop_nulls; mod collect_members; mod count_star; #[cfg(feature = "cse")] -mod cse_expr; +mod cse; #[cfg(any( feature = "ipc", feature = "parquet", @@ -35,6 +33,7 @@ mod type_coercion; use delay_rechunk::DelayRechunk; use drop_nulls::ReplaceDropNulls; +use polars_core::config::verbose; use polars_io::predicates::PhysicalIoExpr; pub use predicate_pushdown::PredicatePushDown; pub use projection_pushdown::ProjectionPushDown; @@ -48,7 +47,9 @@ use self::flatten_union::FlattenUnionRule; pub use crate::frame::{AllowedOptimizations, OptState}; use crate::logical_plan::optimizer::count_star::CountStar; #[cfg(feature = "cse")] 
-use crate::logical_plan::optimizer::cse_expr::CommonSubExprOptimizer; +use crate::logical_plan::optimizer::cse::prune_unused_caches; +#[cfg(feature = "cse")] +use crate::logical_plan::optimizer::cse::CommonSubExprOptimizer; use crate::logical_plan::optimizer::predicate_pushdown::HiveEval; #[cfg(feature = "cse")] use crate::logical_plan::visitor::*; @@ -73,6 +74,8 @@ pub fn optimize( scratch: &mut Vec, hive_partition_eval: HiveEval<'_>, ) -> PolarsResult { + #[cfg(feature = "cse")] + let verbose = verbose(); // get toggle values let predicate_pushdown = opt_state.predicate_pushdown; let projection_pushdown = opt_state.projection_pushdown; @@ -108,21 +111,9 @@ pub fn optimize( // Collect members for optimizations that need it. let mut members = MemberCollector::new(); if !eager && (comm_subexpr_elim || projection_pushdown) { - members.collect(lp_top, lp_arena) + members.collect(lp_top, lp_arena, expr_arena) } - #[cfg(feature = "cse")] - let cse_plan_changed = if comm_subplan_elim { - let (lp, changed) = cse::elim_cmn_subplans(lp_top, lp_arena, expr_arena); - lp_top = lp; - members.has_cache |= changed; - changed - } else { - false - }; - #[cfg(not(feature = "cse"))] - let cse_plan_changed = false; - // we do simplification if simplify_expr { rules.push(Box::new(SimplifyExprRule {})); @@ -130,6 +121,25 @@ pub fn optimize( rules.push(Box::new(fused::FusedArithmetic {})); } + #[cfg(feature = "cse")] + let cse_plan_changed = + if comm_subplan_elim && members.has_joins_or_unions && members.has_duplicate_scans() { + if verbose { + eprintln!("found multiple sources; run comm_subplan_elim") + } + let (lp, changed, cid2c) = cse::elim_cmn_subplans(lp_top, lp_arena, expr_arena); + + prune_unused_caches(lp_arena, cid2c); + + lp_top = lp; + members.has_cache |= changed; + changed + } else { + false + }; + #[cfg(not(feature = "cse"))] + let cse_plan_changed = false; + // should be run before predicate pushdown if projection_pushdown { let mut projection_pushdown_opt = 
ProjectionPushDown::new(); @@ -137,10 +147,6 @@ pub fn optimize( let alp = projection_pushdown_opt.optimize(alp, lp_arena, expr_arena)?; lp_arena.replace(lp_top, alp); - if members.has_joins_or_unions && members.has_cache { - cache_states::set_cache_states(lp_top, lp_arena, expr_arena, scratch, cse_plan_changed); - } - if projection_pushdown_opt.is_count_star { let mut count_star_opt = CountStar::new(); count_star_opt.optimize_plan(lp_arena, expr_arena, lp_top); @@ -182,11 +188,32 @@ pub fn optimize( rules.push(Box::new(SimplifyBooleanRule {})); } + rules.push(Box::new(ReplaceDropNulls {})); + if !eager { + rules.push(Box::new(FlattenUnionRule {})); + } + + lp_top = opt.optimize_loop(&mut rules, expr_arena, lp_arena, lp_top)?; + + if members.has_joins_or_unions && members.has_cache { + cache_states::set_cache_states(lp_top, lp_arena, expr_arena, scratch, cse_plan_changed)?; + } + + // This one should run (nearly) last as this modifies the projections + #[cfg(feature = "cse")] + if comm_subexpr_elim && !members.has_ext_context { + let mut optimizer = CommonSubExprOptimizer::new(expr_arena); + lp_top = ALogicalPlanNode::with_context(lp_top, lp_arena, |alp_node| { + alp_node.rewrite(&mut optimizer) + })? + .node() + } + // make sure that we do that once slice pushdown // and predicate pushdown are done. At that moment // the file fingerprints are finished. #[cfg(any(feature = "cse", feature = "parquet", feature = "ipc", feature = "csv"))] - if agg_scan_projection || cse_plan_changed { + if agg_scan_projection && !cse_plan_changed { // we do this so that expressions are simplified created by the pushdown optimizations // we must clean up the predicates, because the agg_scan_projection // uses them in the hashtable to determine duplicates. 
@@ -205,29 +232,12 @@ pub fn optimize( let mut file_cacher = FileCacher::new(file_predicate_to_columns_and_count); file_cacher.assign_unions(lp_top, lp_arena, expr_arena, scratch); - - #[cfg(feature = "cse")] - if cse_plan_changed { - // this must run after cse - cse::decrement_file_counters_by_cache_hits(lp_top, lp_arena, expr_arena, 0, scratch); - } } - rules.push(Box::new(ReplaceDropNulls {})); - if !eager { - rules.push(Box::new(FlattenUnionRule {})); - } - - lp_top = opt.optimize_loop(&mut rules, expr_arena, lp_arena, lp_top)?; - - // This one should run (nearly) last as this modifies the projections #[cfg(feature = "cse")] - if comm_subexpr_elim && !members.has_ext_context { - let mut optimizer = CommonSubExprOptimizer::new(expr_arena); - lp_top = ALogicalPlanNode::with_context(lp_top, lp_arena, |alp_node| { - alp_node.rewrite(&mut optimizer) - })? - .node() + if cse_plan_changed { + // this must run after cse + cse::decrement_file_counters_by_cache_hits(lp_top, lp_arena, expr_arena, 0, scratch); } // during debug we check if the optimizations have not modified the final schema diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs index 4acae0aef112..6d73d18d5de1 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs @@ -4,7 +4,6 @@ mod keys; mod rename; mod utils; -use polars_core::config::verbose; use polars_core::datatypes::PlHashMap; use polars_core::prelude::*; use utils::*; @@ -207,6 +206,18 @@ impl<'a> PredicatePushDown<'a> { Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) } + fn no_pushdown( + &self, + lp: ALogicalPlan, + acc_predicates: PlHashMap, ExprIR>, + lp_arena: &mut Arena, + expr_arena: &mut Arena, + ) -> PolarsResult { + // all predicates are done locally + let local_predicates = 
acc_predicates.into_values().collect::>(); + Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) + } + /// Predicate pushdown optimizer /// /// # Arguments @@ -228,7 +239,10 @@ impl<'a> PredicatePushDown<'a> { use ALogicalPlan::*; match lp { - Selection { ref predicate, input } => { + Selection { + ref predicate, + input, + } => { // Use a tmp_key to avoid inadvertently combining predicates that otherwise would have // been partially pushed: // @@ -240,21 +254,22 @@ impl<'a> PredicatePushDown<'a> { let tmp_key = Arc::::from(&*temporary_unique_key(&acc_predicates)); acc_predicates.insert(tmp_key.clone(), predicate.clone()); - let local_predicates = match pushdown_eligibility(&[], &acc_predicates, expr_arena)?.0 { - PushdownEligibility::Full => vec![], - PushdownEligibility::Partial { to_local } => { - let mut out = Vec::with_capacity(to_local.len()); - for key in to_local { - out.push(acc_predicates.remove(&key).unwrap()); - } - out - }, - PushdownEligibility::NoPushdown => { - let out = acc_predicates.drain().map(|t| t.1).collect(); - acc_predicates.clear(); - out - }, - }; + let local_predicates = + match pushdown_eligibility(&[], &acc_predicates, expr_arena)?.0 { + PushdownEligibility::Full => vec![], + PushdownEligibility::Partial { to_local } => { + let mut out = Vec::with_capacity(to_local.len()); + for key in to_local { + out.push(acc_predicates.remove(&key).unwrap()); + } + out + }, + PushdownEligibility::NoPushdown => { + let out = acc_predicates.drain().map(|t| t.1).collect(); + acc_predicates.clear(); + out + }, + }; if let Some(predicate) = acc_predicates.remove(&tmp_key) { insert_and_combine_predicate(&mut acc_predicates, &predicate, expr_arena); @@ -268,8 +283,15 @@ impl<'a> PredicatePushDown<'a> { // predicates, we simply don't pushdown this one passed this node // However, we can do better and let it pass but store the order of the predicates // so that we can apply them in correct order at the deepest level - 
Ok(self.optional_apply_predicate(new_input, local_predicates, lp_arena, expr_arena)) - } + Ok( + self.optional_apply_predicate( + new_input, + local_predicates, + lp_arena, + expr_arena, + ), + ) + }, DataFrameScan { df, schema, @@ -286,14 +308,14 @@ impl<'a> PredicatePushDown<'a> { selection, }; Ok(lp) - } + }, Scan { mut paths, mut file_info, ref predicate, mut scan_type, file_options: options, - output_schema + output_schema, } => { for e in acc_predicates.values() { debug_assert_aexpr_allows_predicate_pushdown(e.node(), expr_arena); @@ -309,23 +331,25 @@ impl<'a> PredicatePushDown<'a> { // not update the row index properly before applying the // predicate (e.g. FileScan::Csv doesn't). if let Some(ref row_index) = options.row_index { - let row_index_predicates = transfer_to_local_by_name(expr_arena, &mut acc_predicates, |name| { - name.as_ref() == row_index.name - }); + let row_index_predicates = transfer_to_local_by_name( + expr_arena, + &mut acc_predicates, + |name| name.as_ref() == row_index.name, + ); row_index_predicates } else { vec![] } - } + }, }; let predicate = predicate_at_scan(acc_predicates, predicate.clone(), expr_arena); if let (true, Some(predicate)) = (file_info.hive_parts.is_some(), &predicate) { - if let Some(io_expr) = self.hive_partition_eval.unwrap()(predicate, expr_arena) { + if let Some(io_expr) = self.hive_partition_eval.unwrap()(predicate, expr_arena) + { if let Some(stats_evaluator) = io_expr.as_stats_evaluator() { let mut new_paths = Vec::with_capacity(paths.len()); - for path in paths.as_ref().iter() { file_info.update_hive_partitions(path)?; let hive_part_stats = file_info.hive_parts.as_deref().ok_or_else(|| polars_err!(ComputeError: "cannot combine hive partitioned directories with non-hive partitioned ones"))?; @@ -337,7 +361,11 @@ impl<'a> PredicatePushDown<'a> { if paths.len() != new_paths.len() { if self.verbose { - eprintln!("hive partitioning: skipped {} files, first file : {}", paths.len() - new_paths.len(), 
paths[0].display()) + eprintln!( + "hive partitioning: skipped {} files, first file : {}", + paths.len() - new_paths.len(), + paths[0].display() + ) } scan_type.remove_metadata(); } @@ -350,8 +378,8 @@ impl<'a> PredicatePushDown<'a> { schema: schema.clone(), output_schema: None, projection: None, - selection: None - }) + selection: None, + }); } else { paths = Arc::from(new_paths) } @@ -361,10 +389,10 @@ impl<'a> PredicatePushDown<'a> { let mut do_optimization = match &scan_type { #[cfg(feature = "csv")] - FileScan::Csv {..} => options.n_rows.is_none(), - FileScan::Anonymous {function, ..} => function.allows_predicate_pushdown(), + FileScan::Csv { .. } => options.n_rows.is_none(), + FileScan::Anonymous { function, .. } => function.allows_predicate_pushdown(), #[allow(unreachable_patterns)] - _ => true + _ => true, }; do_optimization &= predicate.is_some(); @@ -375,7 +403,7 @@ impl<'a> PredicatePushDown<'a> { predicate, file_options: options, output_schema, - scan_type + scan_type, } } else { let lp = Scan { @@ -384,52 +412,39 @@ impl<'a> PredicatePushDown<'a> { predicate: None, file_options: options, output_schema, - scan_type + scan_type, }; if let Some(predicate) = predicate { let input = lp_arena.add(lp); - Selection { - input, - predicate - } + Selection { input, predicate } } else { lp } }; Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) - - } - Distinct { - input, - options - } => { + }, + Distinct { input, options } => { if let Some(ref subset) = options.subset { // Predicates on the subset can pass. 
let subset = subset.clone(); let mut names_set = PlHashSet::<&str>::with_capacity(subset.len()); for name in subset.iter() { names_set.insert(name.as_str()); - }; + } let condition = |name: Arc| !names_set.contains(name.as_ref()); let local_predicates = transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition); self.pushdown_and_assign(input, acc_predicates, lp_arena, expr_arena)?; - let lp = Distinct { - input, - options - }; + let lp = Distinct { input, options }; Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) } else { - let lp = Distinct { - input, - options - }; + let lp = Distinct { input, options }; self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) } - } + }, Join { input_left, input_right, @@ -437,54 +452,64 @@ impl<'a> PredicatePushDown<'a> { right_on, schema, options, - } => { - process_join( - self, - lp_arena, - expr_arena, - input_left, - input_right, - left_on, - right_on, - schema, - options, - acc_predicates - ) - } + } => process_join( + self, + lp_arena, + expr_arena, + input_left, + input_right, + left_on, + right_on, + schema, + options, + acc_predicates, + ), MapFunction { ref function, .. } => { - if function.allow_predicate_pd() - { + if function.allow_predicate_pd() { match function { - FunctionNode::Rename { - existing, - new, - .. - } => { - let local_predicates = process_rename(&mut acc_predicates, - expr_arena, - existing, - new, + FunctionNode::Rename { existing, new, .. 
} => { + let local_predicates = + process_rename(&mut acc_predicates, expr_arena, existing, new)?; + let lp = self.pushdown_and_continue( + lp, + acc_predicates, + lp_arena, + expr_arena, + false, )?; - let lp = self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false)?; - Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) + Ok(self.optional_apply_predicate( + lp, + local_predicates, + lp_arena, + expr_arena, + )) }, - FunctionNode::Explode {columns, ..} => { - - let condition = |name: Arc| columns.iter().any(|s| s.as_ref() == &*name); + FunctionNode::Explode { columns, .. } => { + let condition = + |name: Arc| columns.iter().any(|s| s.as_ref() == &*name); // first columns that refer to the exploded columns should be done here - let local_predicates = - transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition); - - let lp = self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false)?; - Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) - - } - FunctionNode::Melt { - args, - .. - } => { - + let local_predicates = transfer_to_local_by_name( + expr_arena, + &mut acc_predicates, + condition, + ); + + let lp = self.pushdown_and_continue( + lp, + acc_predicates, + lp_arena, + expr_arena, + false, + )?; + Ok(self.optional_apply_predicate( + lp, + local_predicates, + lp_arena, + expr_arena, + )) + }, + FunctionNode::Melt { args, .. 
} => { let variable_name = args.variable_name.as_deref().unwrap_or("variable"); let value_name = args.value_name.as_deref().unwrap_or("value"); @@ -495,28 +520,60 @@ impl<'a> PredicatePushDown<'a> { || name == value_name || args.value_vars.iter().any(|s| s.as_str() == name) }; - let local_predicates = - transfer_to_local_by_name(expr_arena, &mut acc_predicates, condition); - - let lp = self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false)?; - Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) - - } - _ => { - self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false) - } + let local_predicates = transfer_to_local_by_name( + expr_arena, + &mut acc_predicates, + condition, + ); + + let lp = self.pushdown_and_continue( + lp, + acc_predicates, + lp_arena, + expr_arena, + false, + )?; + Ok(self.optional_apply_predicate( + lp, + local_predicates, + lp_arena, + expr_arena, + )) + }, + _ => self.pushdown_and_continue( + lp, + acc_predicates, + lp_arena, + expr_arena, + false, + ), } - - } else { self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) } - } - Aggregate {input, keys, aggs, schema, apply, maintain_order, options, } => { - process_group_by(self, lp_arena, expr_arena, input, keys, aggs, schema, maintain_order, apply, options, acc_predicates) - }, - lp @ Union {..} => { + Aggregate { + input, + keys, + aggs, + schema, + apply, + maintain_order, + options, + } => process_group_by( + self, + lp_arena, + expr_arena, + input, + keys, + aggs, + schema, + maintain_order, + apply, + options, + acc_predicates, + ), + lp @ Union { .. 
} => { let mut local_predicates = vec![]; // a count is influenced by a Union/Vstack @@ -528,10 +585,11 @@ impl<'a> PredicatePushDown<'a> { true } }); - let lp = self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false)?; + let lp = + self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false)?; Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) - } - lp @ Sort{..} => { + }, + lp @ Sort { .. } => { let mut local_predicates = vec![]; acc_predicates.retain(|_, predicate| { if predicate_is_sort_boundary(predicate.node(), expr_arena) { @@ -541,58 +599,93 @@ impl<'a> PredicatePushDown<'a> { true } }); - let lp = self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false)?; + let lp = + self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false)?; Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena)) - - } + }, // Pushed down passed these nodes - lp@ Sink {..} => { + lp @ Sink { .. } => { self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, false) - } - lp @ HStack {..} | lp @ Projection {..} | lp @ SimpleProjection {..} | lp @ ExtContext {..} => { + }, + lp @ HStack { .. } + | lp @ Projection { .. } + | lp @ SimpleProjection { .. } + | lp @ ExtContext { .. } => { self.pushdown_and_continue(lp, acc_predicates, lp_arena, expr_arena, true) - } + }, // NOT Pushed down passed these nodes // predicates influence slice sizes - lp @ Slice { .. } - // caches will be different - | lp @ Cache { .. } - => { + lp @ Slice { .. } => { self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) - } - lp @ HConcat { .. } - => { + }, + lp @ HConcat { .. } => { self.no_pushdown_restart_opt(lp, acc_predicates, lp_arena, expr_arena) - } + }, + // Caches will run predicate push-down in the `cache_states` run. + Cache { .. 
} => self.no_pushdown(lp, acc_predicates, lp_arena, expr_arena), #[cfg(feature = "python")] - PythonScan {mut options, predicate} => { + PythonScan { + mut options, + predicate, + } => { if options.pyarrow { let predicate = predicate_at_scan(acc_predicates, predicate, expr_arena); if let Some(predicate) = predicate.clone() { // simplify expressions before we translate them to pyarrow - let lp = PythonScan {options: options.clone(), predicate: Some(predicate)}; + let lp = PythonScan { + options: options.clone(), + predicate: Some(predicate), + }; let lp_top = lp_arena.add(lp); - let stack_opt = StackOptimizer{}; - let lp_top = stack_opt.optimize_loop(&mut [Box::new(SimplifyExprRule{})], expr_arena, lp_arena, lp_top).unwrap(); - let PythonScan {options: _, predicate: Some(predicate)} = lp_arena.take(lp_top) else {unreachable!()}; - - match super::super::pyarrow::predicate_to_pa(predicate.node(), expr_arena, Default::default()) { + let stack_opt = StackOptimizer {}; + let lp_top = stack_opt + .optimize_loop( + &mut [Box::new(SimplifyExprRule {})], + expr_arena, + lp_arena, + lp_top, + ) + .unwrap(); + let PythonScan { + options: _, + predicate: Some(predicate), + } = lp_arena.take(lp_top) + else { + unreachable!() + }; + + match super::super::pyarrow::predicate_to_pa( + predicate.node(), + expr_arena, + Default::default(), + ) { // we we able to create a pyarrow string, mutate the options - Some(eval_str) => { - options.predicate = Some(eval_str) - }, + Some(eval_str) => options.predicate = Some(eval_str), // we were not able to translate the predicate // apply here None => { - let lp = PythonScan { options, predicate: None }; - return Ok(self.optional_apply_predicate(lp, vec![predicate], lp_arena, expr_arena)) - } + let lp = PythonScan { + options, + predicate: None, + }; + return Ok(self.optional_apply_predicate( + lp, + vec![predicate], + lp_arena, + expr_arena, + )); + }, } } - Ok(PythonScan {options, predicate}) + Ok(PythonScan { options, predicate }) } else { - 
self.no_pushdown_restart_opt(PythonScan {options, predicate}, acc_predicates, lp_arena, expr_arena) + self.no_pushdown_restart_opt( + PythonScan { options, predicate }, + acc_predicates, + lp_arena, + expr_arena, + ) } }, Invalid => unreachable!(), diff --git a/crates/polars-plan/src/logical_plan/options.rs b/crates/polars-plan/src/logical_plan/options.rs index cb1325104fd3..628aacc4e7b0 100644 --- a/crates/polars-plan/src/logical_plan/options.rs +++ b/crates/polars-plan/src/logical_plan/options.rs @@ -23,7 +23,7 @@ use crate::prelude::python_udf::PythonFunction; pub type FileCount = u32; #[cfg(feature = "csv")] -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct CsvParserOptions { pub separator: u8, @@ -43,7 +43,7 @@ pub struct CsvParserOptions { } #[cfg(feature = "parquet")] -#[derive(Clone, Debug, PartialEq, Eq, Copy)] +#[derive(Clone, Debug, PartialEq, Eq, Copy, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct ParquetOptions { pub parallel: polars_io::parquet::ParallelStrategy, @@ -52,7 +52,7 @@ pub struct ParquetOptions { } #[cfg(feature = "parquet")] -#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct ParquetWriteOptions { /// Data page compression @@ -68,7 +68,7 @@ pub struct ParquetWriteOptions { } #[cfg(feature = "ipc")] -#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct IpcWriterOptions { /// Data page compression @@ -78,7 +78,7 @@ pub struct IpcWriterOptions { } #[cfg(feature = "csv")] -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct 
CsvWriterOptions { pub include_bom: bool, @@ -102,20 +102,20 @@ impl Default for CsvWriterOptions { } #[cfg(feature = "json")] -#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct JsonWriterOptions { /// maintain the order the data was processed pub maintain_order: bool, } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct IpcScanOptions { pub memmap: bool, } -#[derive(Clone, Debug, PartialEq, Eq, Default)] +#[derive(Clone, Debug, PartialEq, Eq, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] /// Generic options for all file types pub struct FileScanOptions { @@ -128,7 +128,7 @@ pub struct FileScanOptions { pub hive_partitioning: bool, } -#[derive(Clone, Debug, Copy, Default, Eq, PartialEq)] +#[derive(Clone, Debug, Copy, Default, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct UnionOptions { pub slice: Option<(i64, usize)>, @@ -140,13 +140,13 @@ pub struct UnionOptions { pub rechunk: bool, } -#[derive(Clone, Debug, Copy, Default, Eq, PartialEq)] +#[derive(Clone, Debug, Copy, Default, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct HConcatOptions { pub parallel: bool, } -#[derive(Clone, Debug, PartialEq, Eq, Default)] +#[derive(Clone, Debug, PartialEq, Eq, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct GroupbyOptions { #[cfg(feature = "dynamic_group_by")] @@ -157,7 +157,7 @@ pub struct GroupbyOptions { pub slice: Option<(i64, usize)>, } -#[derive(Clone, Debug, Eq, PartialEq, Default)] +#[derive(Clone, Debug, Eq, PartialEq, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct DistinctOptions { /// Subset of columns that will be 
taken into account. @@ -295,7 +295,7 @@ pub struct LogicalPlanUdfOptions { pub fmt_str: &'static str, } -#[derive(Clone, PartialEq, Eq, Debug, Default)] +#[derive(Clone, PartialEq, Eq, Debug, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct SortArguments { pub descending: Vec, @@ -320,7 +320,7 @@ pub struct PythonOptions { pub n_rows: Option, } -#[derive(Clone, PartialEq, Eq, Debug, Default)] +#[derive(Clone, PartialEq, Eq, Debug, Default, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct AnonymousScanOptions { pub skip_rows: Option, @@ -328,7 +328,7 @@ pub struct AnonymousScanOptions { } #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum SinkType { Memory, File { @@ -351,7 +351,7 @@ pub struct FileSinkOptions { } #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum FileType { #[cfg(feature = "parquet")] Parquet(ParquetWriteOptions), @@ -364,7 +364,7 @@ pub enum FileType { } #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct ProjectionOptions { pub run_parallel: bool, pub duplicate_check: bool, diff --git a/crates/polars-plan/src/logical_plan/visitor/expr.rs b/crates/polars-plan/src/logical_plan/visitor/expr.rs index 74f48f60592e..8faf4e67fcce 100644 --- a/crates/polars-plan/src/logical_plan/visitor/expr.rs +++ b/crates/polars-plan/src/logical_plan/visitor/expr.rs @@ -125,14 +125,14 @@ impl AexprNode { }) } - // traverses all nodes and does a full equality check - fn is_equal(&self, other: &Self, scratch1: &mut Vec, scratch2: &mut Vec) -> bool { + // Check single node on equality + fn is_equal(&self, other: &Self) -> bool { self.with_arena(|arena| { let self_ae = self.to_aexpr(); let other_ae = 
arena.get(other.node()); use AExpr::*; - let this_node_equal = match (self_ae, other_ae) { + match (self_ae, other_ae) { (Alias(_, l), Alias(_, r)) => l == r, (Column(l), Column(r)) => l == r, (Literal(l), Literal(r)) => l == r, @@ -174,30 +174,6 @@ impl AexprNode { (AnonymousFunction { .. }, AnonymousFunction { .. }) => false, (BinaryExpr { op: l, .. }, BinaryExpr { op: r, .. }) => l == r, _ => false, - }; - - if !this_node_equal { - return false; - } - - self_ae.nodes(scratch1); - other_ae.nodes(scratch2); - - loop { - match (scratch1.pop(), scratch2.pop()) { - (Some(l), Some(r)) => { - // SAFETY: we can pass a *mut pointer - // the equality operation will not access mutable - let l = unsafe { AexprNode::from_raw(l, self.arena) }; - let r = unsafe { AexprNode::from_raw(r, self.arena) }; - - if !l.is_equal(&r, scratch1, scratch2) { - return false; - } - }, - (None, None) => return true, - _ => return false, - } } }) } @@ -212,7 +188,29 @@ impl PartialEq for AexprNode { fn eq(&self, other: &Self) -> bool { let mut scratch1 = vec![]; let mut scratch2 = vec![]; - self.is_equal(other, &mut scratch1, &mut scratch2) + + scratch1.push(self.node); + scratch2.push(other.node); + + loop { + match (scratch1.pop(), scratch2.pop()) { + (Some(l), Some(r)) => { + // SAFETY: we can pass a *mut pointer + // the equality operation will not access mutable + let l = unsafe { AexprNode::from_raw(l, self.arena) }; + let r = unsafe { AexprNode::from_raw(r, self.arena) }; + + if !l.is_equal(&r) { + return false; + } + + l.to_aexpr().nodes(&mut scratch1); + r.to_aexpr().nodes(&mut scratch2); + }, + (None, None) => return true, + _ => return false, + } + } } } diff --git a/crates/polars-plan/src/logical_plan/visitor/hash.rs b/crates/polars-plan/src/logical_plan/visitor/hash.rs new file mode 100644 index 000000000000..41ecb5a624f7 --- /dev/null +++ b/crates/polars-plan/src/logical_plan/visitor/hash.rs @@ -0,0 +1,518 @@ +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use 
polars_utils::arena::Arena; + +use super::*; +use crate::logical_plan::{AExpr, ALogicalPlan}; +use crate::prelude::aexpr::traverse_and_hash_aexpr; +use crate::prelude::ExprIR; + +impl ALogicalPlanNode { + pub(crate) fn hashable_and_cmp<'a>(&'a self, expr_arena: &'a Arena) -> HashableEqLP<'a> { + HashableEqLP { + node: self, + expr_arena, + ignore_cache: false, + } + } +} + +pub(crate) struct HashableEqLP<'a> { + node: &'a ALogicalPlanNode, + expr_arena: &'a Arena, + ignore_cache: bool, +} + +impl HashableEqLP<'_> { + /// When encountering a Cache node, ignore it and take the input. + #[cfg(feature = "cse")] + pub(crate) fn ignore_caches(mut self) -> Self { + self.ignore_cache = true; + self + } +} + +fn hash_option_expr(expr: &Option, expr_arena: &Arena, state: &mut H) { + if let Some(e) = expr { + e.traverse_and_hash(expr_arena, state) + } +} + +fn hash_exprs(exprs: &[ExprIR], expr_arena: &Arena, state: &mut H) { + for e in exprs { + e.traverse_and_hash(expr_arena, state); + } +} + +impl Hash for HashableEqLP<'_> { + // This hashes the variant, not the whole plan + fn hash(&self, state: &mut H) { + let alp = self.node.to_alp(); + std::mem::discriminant(alp).hash(state); + match alp { + #[cfg(feature = "python")] + ALogicalPlan::PythonScan { .. } => {}, + ALogicalPlan::Slice { + offset, + len, + input: _, + } => { + len.hash(state); + offset.hash(state); + }, + ALogicalPlan::Selection { + input: _, + predicate, + } => { + predicate.traverse_and_hash(self.expr_arena, state); + }, + ALogicalPlan::Scan { + paths, + file_info: _, + predicate, + output_schema: _, + scan_type, + file_options, + } => { + // We don't have to traverse the schema, hive partitions etc. as they are derivative from the paths. 
+ scan_type.hash(state); + paths.hash(state); + hash_option_expr(predicate, self.expr_arena, state); + file_options.hash(state); + }, + ALogicalPlan::DataFrameScan { + df, + schema: _, + output_schema: _, + projection, + selection, + } => { + (Arc::as_ptr(df) as usize).hash(state); + projection.hash(state); + hash_option_expr(selection, self.expr_arena, state); + }, + ALogicalPlan::SimpleProjection { + columns, + duplicate_check, + input: _, + } => { + columns.hash(state); + duplicate_check.hash(state); + }, + ALogicalPlan::Projection { + input: _, + expr, + schema: _, + options, + } => { + hash_exprs(expr.default_exprs(), self.expr_arena, state); + options.hash(state); + }, + ALogicalPlan::Sort { + input: _, + by_column, + args, + } => { + hash_exprs(by_column, self.expr_arena, state); + args.hash(state); + }, + ALogicalPlan::Aggregate { + input: _, + keys, + aggs, + schema: _, + apply, + maintain_order, + options, + } => { + hash_exprs(keys, self.expr_arena, state); + hash_exprs(aggs, self.expr_arena, state); + apply.is_none().hash(state); + maintain_order.hash(state); + options.hash(state); + }, + ALogicalPlan::Join { + input_left: _, + input_right: _, + schema: _, + left_on, + right_on, + options, + } => { + hash_exprs(left_on, self.expr_arena, state); + hash_exprs(right_on, self.expr_arena, state); + options.hash(state); + }, + ALogicalPlan::HStack { + input: _, + exprs, + schema: _, + options, + } => { + hash_exprs(exprs.default_exprs(), self.expr_arena, state); + options.hash(state); + }, + ALogicalPlan::Distinct { input: _, options } => { + options.hash(state); + }, + ALogicalPlan::MapFunction { input: _, function } => { + function.hash(state); + }, + ALogicalPlan::Union { inputs: _, options } => options.hash(state), + ALogicalPlan::HConcat { + inputs: _, + schema: _, + options, + } => { + options.hash(state); + }, + ALogicalPlan::ExtContext { + input: _, + contexts, + schema: _, + } => { + for node in contexts { + traverse_and_hash_aexpr(*node, 
self.expr_arena, state); + } + }, + ALogicalPlan::Sink { input: _, payload } => { + payload.hash(state); + }, + ALogicalPlan::Cache { + input: _, + id, + count, + } => { + id.hash(state); + count.hash(state); + }, + ALogicalPlan::Invalid => unreachable!(), + } + } +} + +fn expr_irs_eq(l: &[ExprIR], r: &[ExprIR], expr_arena: &Arena) -> bool { + l.len() == r.len() && l.iter().zip(r).all(|(l, r)| expr_ir_eq(l, r, expr_arena)) +} + +fn expr_ir_eq(l: &ExprIR, r: &ExprIR, expr_arena: &Arena) -> bool { + l.get_alias() == r.get_alias() && { + let expr_arena = expr_arena as *const _ as *mut _; + unsafe { + let l = AexprNode::from_raw(l.node(), expr_arena); + let r = AexprNode::from_raw(r.node(), expr_arena); + l == r + } + } +} + +fn opt_expr_ir_eq(l: &Option, r: &Option, expr_arena: &Arena) -> bool { + match (l, r) { + (None, None) => true, + (Some(l), Some(r)) => expr_ir_eq(l, r, expr_arena), + _ => false, + } +} + +impl HashableEqLP<'_> { + fn is_equal(&self, other: &Self) -> bool { + let alp_l = self.node.to_alp(); + let alp_r = other.node.to_alp(); + if std::mem::discriminant(alp_l) != std::mem::discriminant(alp_r) { + return false; + } + match (alp_l, alp_r) { + ( + ALogicalPlan::Slice { + input: _, + offset: ol, + len: ll, + }, + ALogicalPlan::Slice { + input: _, + offset: or, + len: lr, + }, + ) => ol == or && ll == lr, + ( + ALogicalPlan::Selection { + input: _, + predicate: l, + }, + ALogicalPlan::Selection { + input: _, + predicate: r, + }, + ) => expr_ir_eq(l, r, self.expr_arena), + ( + ALogicalPlan::Scan { + paths: pl, + file_info: _, + predicate: pred_l, + output_schema: _, + scan_type: stl, + file_options: ol, + }, + ALogicalPlan::Scan { + paths: pr, + file_info: _, + predicate: pred_r, + output_schema: _, + scan_type: str, + file_options: or, + }, + ) => { + pl == pr + && stl == str + && ol == or + && opt_expr_ir_eq(pred_l, pred_r, self.expr_arena) + }, + ( + ALogicalPlan::DataFrameScan { + df: dfl, + schema: _, + output_schema: _, + projection: pl, + 
selection: sl, + }, + ALogicalPlan::DataFrameScan { + df: dfr, + schema: _, + output_schema: _, + projection: pr, + selection: sr, + }, + ) => { + Arc::as_ptr(dfl) == Arc::as_ptr(dfr) + && pl == pr + && opt_expr_ir_eq(sl, sr, self.expr_arena) + }, + ( + ALogicalPlan::SimpleProjection { + input: _, + columns: cl, + duplicate_check: dl, + }, + ALogicalPlan::SimpleProjection { + input: _, + columns: cr, + duplicate_check: dr, + }, + ) => dl == dr && cl == cr, + ( + ALogicalPlan::Projection { + input: _, + expr: el, + options: ol, + schema: _, + }, + ALogicalPlan::Projection { + input: _, + expr: er, + options: or, + schema: _, + }, + ) => ol == or && expr_irs_eq(el.default_exprs(), er.default_exprs(), self.expr_arena), + ( + ALogicalPlan::Sort { + input: _, + by_column: cl, + args: al, + }, + ALogicalPlan::Sort { + input: _, + by_column: cr, + args: ar, + }, + ) => al == ar && expr_irs_eq(cl, cr, self.expr_arena), + ( + ALogicalPlan::Aggregate { + input: _, + keys: keys_l, + aggs: aggs_l, + schema: _, + apply: apply_l, + maintain_order: maintain_l, + options: ol, + }, + ALogicalPlan::Aggregate { + input: _, + keys: keys_r, + aggs: aggs_r, + schema: _, + apply: apply_r, + maintain_order: maintain_r, + options: or, + }, + ) => { + apply_l.is_none() + && apply_r.is_none() + && ol == or + && maintain_l == maintain_r + && expr_irs_eq(keys_l, keys_r, self.expr_arena) + && expr_irs_eq(aggs_l, aggs_r, self.expr_arena) + }, + ( + ALogicalPlan::Join { + input_left: _, + input_right: _, + schema: _, + left_on: ll, + right_on: rl, + options: ol, + }, + ALogicalPlan::Join { + input_left: _, + input_right: _, + schema: _, + left_on: lr, + right_on: rr, + options: or, + }, + ) => { + ol == or + && expr_irs_eq(ll, lr, self.expr_arena) + && expr_irs_eq(rl, rr, self.expr_arena) + }, + ( + ALogicalPlan::HStack { + input: _, + exprs: el, + schema: _, + options: ol, + }, + ALogicalPlan::HStack { + input: _, + exprs: er, + schema: _, + options: or, + }, + ) => ol == or && 
expr_irs_eq(el.default_exprs(), er.default_exprs(), self.expr_arena), + ( + ALogicalPlan::Distinct { + input: _, + options: ol, + }, + ALogicalPlan::Distinct { + input: _, + options: or, + }, + ) => ol == or, + ( + ALogicalPlan::MapFunction { + input: _, + function: l, + }, + ALogicalPlan::MapFunction { + input: _, + function: r, + }, + ) => l == r, + ( + ALogicalPlan::Union { + inputs: _, + options: l, + }, + ALogicalPlan::Union { + inputs: _, + options: r, + }, + ) => l == r, + ( + ALogicalPlan::HConcat { + inputs: _, + schema: _, + options: l, + }, + ALogicalPlan::HConcat { + inputs: _, + schema: _, + options: r, + }, + ) => l == r, + ( + ALogicalPlan::ExtContext { + input: _, + contexts: l, + schema: _, + }, + ALogicalPlan::ExtContext { + input: _, + contexts: r, + schema: _, + }, + ) => { + l.len() == r.len() + && l.iter().zip(r.iter()).all(|(l, r)| { + let expr_arena = self.expr_arena as *const _ as *mut _; + unsafe { + let l = AexprNode::from_raw(*l, expr_arena); + let r = AexprNode::from_raw(*r, expr_arena); + l == r + } + }) + }, + _ => false, + } + } +} + +impl PartialEq for HashableEqLP<'_> { + fn eq(&self, other: &Self) -> bool { + let mut scratch_1 = vec![]; + let mut scratch_2 = vec![]; + + scratch_1.push(self.node.node()); + scratch_2.push(other.node.node()); + + loop { + match (scratch_1.pop(), scratch_2.pop()) { + (Some(l), Some(r)) => { + // SAFETY: we can pass a *mut pointer + // the equality operation will not access mutable + let l = unsafe { ALogicalPlanNode::from_raw(l, self.node.get_arena_raw()) }; + let r = unsafe { ALogicalPlanNode::from_raw(r, self.node.get_arena_raw()) }; + let l_alp = l.to_alp(); + let r_alp = r.to_alp(); + + if self.ignore_cache { + match (l_alp, r_alp) { + ( + ALogicalPlan::Cache { input: l, .. }, + ALogicalPlan::Cache { input: r, .. }, + ) => { + scratch_1.push(*l); + scratch_2.push(*r); + continue; + }, + (ALogicalPlan::Cache { input: l, .. 
}, _) => { + scratch_1.push(*l); + scratch_2.push(r.node()); + continue; + }, + (_, ALogicalPlan::Cache { input: r, .. }) => { + scratch_1.push(l.node()); + scratch_2.push(*r); + continue; + }, + _ => {}, + } + } + + if !l + .hashable_and_cmp(self.expr_arena) + .is_equal(&r.hashable_and_cmp(self.expr_arena)) + { + return false; + } + + l.to_alp().copy_inputs(&mut scratch_1); + r.to_alp().copy_inputs(&mut scratch_2); + }, + (None, None) => return true, + _ => return false, + } + } + } +} diff --git a/crates/polars-plan/src/logical_plan/visitor/lp.rs b/crates/polars-plan/src/logical_plan/visitor/lp.rs index b8f2197d169d..2a959e489027 100644 --- a/crates/polars-plan/src/logical_plan/visitor/lp.rs +++ b/crates/polars-plan/src/logical_plan/visitor/lp.rs @@ -6,6 +6,7 @@ use polars_utils::unitvec; use super::*; use crate::prelude::*; +#[derive(Copy, Clone, Debug)] pub struct ALogicalPlanNode { node: Node, arena: *mut Arena, @@ -26,6 +27,10 @@ impl ALogicalPlanNode { Self { node, arena } } + pub(crate) fn get_arena_raw(&self) -> *mut Arena { + self.arena + } + /// Safe interface. Take the `&mut Arena` only for the duration of `op`. pub fn with_context(node: Node, arena: &mut Arena, mut op: F) -> T where @@ -63,6 +68,10 @@ impl ALogicalPlanNode { self.node = node } + pub fn replace_node(&mut self, node: Node) { + self.node = node; + } + /// Replace the current `Node` with a new `ALogicalPlan`. 
pub fn replace(&mut self, ae: ALogicalPlan) { let node = self.node; diff --git a/crates/polars-plan/src/logical_plan/visitor/mod.rs b/crates/polars-plan/src/logical_plan/visitor/mod.rs index acce08ad034e..149855970981 100644 --- a/crates/polars-plan/src/logical_plan/visitor/mod.rs +++ b/crates/polars-plan/src/logical_plan/visitor/mod.rs @@ -2,6 +2,7 @@ use arrow::legacy::error::PolarsResult; mod expr; +mod hash; mod lp; mod visitors; diff --git a/crates/polars-time/src/group_by/dynamic.rs b/crates/polars-time/src/group_by/dynamic.rs index ef449b62e2da..e67fe38c12a2 100644 --- a/crates/polars-time/src/group_by/dynamic.rs +++ b/crates/polars-time/src/group_by/dynamic.rs @@ -17,7 +17,7 @@ use crate::prelude::*; #[repr(transparent)] struct Wrap(pub T); -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct DynamicGroupOptions { /// Time or index column.