Skip to content

Commit

Permalink
feat: Full plan CSE (#15264)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Mar 26, 2024
1 parent 705b148 commit a9f4df3
Show file tree
Hide file tree
Showing 43 changed files with 1,882 additions and 1,155 deletions.
2 changes: 1 addition & 1 deletion crates/polars-core/src/frame/explode.rs
Expand Up @@ -21,7 +21,7 @@ fn get_exploded(series: &Series) -> PolarsResult<(Series, OffsetsBuffer<i64>)> {
}

/// Arguments for `[DataFrame::melt]` function
#[derive(Clone, Default, Debug, PartialEq)]
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))]
pub struct MeltArgs {
pub id_vars: Vec<SmartString>,
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/frame/mod.rs
Expand Up @@ -41,7 +41,7 @@ pub enum NullStrategy {
Propagate,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum UniqueKeepStrategy {
/// Keep the first unique row.
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-core/src/schema.rs
@@ -1,4 +1,5 @@
use std::fmt::{Debug, Formatter};
use std::hash::{Hash, Hasher};

use arrow::datatypes::ArrowSchemaRef;
use indexmap::map::MutableKeys;
Expand All @@ -17,6 +18,12 @@ pub struct Schema {
inner: PlIndexMap<SmartString, DataType>,
}

impl Hash for Schema {
fn hash<H: Hasher>(&self, state: &mut H) {
self.inner.iter().for_each(|v| v.hash(state))
}
}

// Schemas will only compare equal if they have the same fields in the same order. We can't use `self.inner ==
// other.inner` because [`IndexMap`] ignores order when checking equality, but we don't want to ignore it.
impl PartialEq for Schema {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-io/src/cloud/options.rs
Expand Up @@ -52,7 +52,7 @@ static BUCKET_REGION: Lazy<std::sync::Mutex<FastFixedCache<SmartString, SmartStr
#[allow(dead_code)]
type Configs<T> = Vec<(T, String)>;

#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// Options to connect to various cloud providers.
pub struct CloudOptions {
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-io/src/csv/read.rs
Expand Up @@ -5,7 +5,7 @@ use crate::csv::read_impl::{
};
use crate::csv::utils::infer_file_schema;

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum CsvEncoding {
/// Utf8 encoding
Expand All @@ -14,7 +14,7 @@ pub enum CsvEncoding {
LossyUtf8,
}

#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum NullValues {
/// A single value that's used for all columns
Expand All @@ -25,7 +25,7 @@ pub enum NullValues {
Named(Vec<(String, String)>),
}

#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum CommentPrefix {
/// A single byte character that indicates the start of a comment line.
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-io/src/options.rs
Expand Up @@ -2,7 +2,7 @@ use polars_utils::IdxSize;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RowIndex {
pub name: String,
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-io/src/parquet/read.rs
Expand Up @@ -21,7 +21,7 @@ use crate::predicates::PhysicalIoExpr;
use crate::prelude::*;
use crate::RowIndex;

#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ParallelStrategy {
/// Don't parallelize
Expand Down
2 changes: 0 additions & 2 deletions crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs
Expand Up @@ -4,7 +4,6 @@ use std::sync::RwLock;

use polars_core::config;
use polars_core::utils::accumulate_dataframes_vertical;
#[cfg(feature = "cloud")]
use polars_io::cloud::CloudOptions;
use polars_io::predicates::apply_predicate;
use polars_io::{is_cloud_url, RowIndex};
Expand All @@ -18,7 +17,6 @@ pub struct IpcExec {
pub(crate) predicate: Option<Arc<dyn PhysicalExpr>>,
pub(crate) options: IpcScanOptions,
pub(crate) file_options: FileScanOptions,
#[cfg(feature = "cloud")]
pub(crate) cloud_options: Option<CloudOptions>,
pub(crate) metadata: Option<arrow::io::ipc::read::FileMetadata>,
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/physical_plan/expressions/column.rs
Expand Up @@ -93,7 +93,7 @@ impl ColumnExpr {
&& _state.ext_contexts.is_empty()
&& std::env::var("POLARS_NO_SCHEMA_CHECK").is_err()
{
panic!("invalid schema: df {:?}; column: {}", df, &self.name)
panic!("invalid schema: df {:?};\ncolumn: {}", df, &self.name)
}
}
// in release we fallback to linear search
Expand Down
2 changes: 0 additions & 2 deletions crates/polars-lazy/src/physical_plan/planner/lp.rs
Expand Up @@ -229,7 +229,6 @@ pub fn create_physical_plan(
#[cfg(feature = "ipc")]
FileScan::Ipc {
options,
#[cfg(feature = "cloud")]
cloud_options,
metadata,
} => Ok(Box::new(executors::IpcExec {
Expand All @@ -238,7 +237,6 @@ pub fn create_physical_plan(
predicate,
options,
file_options,
#[cfg(feature = "cloud")]
cloud_options,
metadata,
})),
Expand Down
3 changes: 0 additions & 3 deletions crates/polars-lazy/src/scan/ipc.rs
Expand Up @@ -13,7 +13,6 @@ pub struct ScanArgsIpc {
pub rechunk: bool,
pub row_index: Option<RowIndex>,
pub memmap: bool,
#[cfg(feature = "cloud")]
pub cloud_options: Option<CloudOptions>,
}

Expand All @@ -25,7 +24,6 @@ impl Default for ScanArgsIpc {
rechunk: false,
row_index: None,
memmap: true,
#[cfg(feature = "cloud")]
cloud_options: Default::default(),
}
}
Expand Down Expand Up @@ -79,7 +77,6 @@ impl LazyFileListReader for LazyIpcReader {
args.cache,
args.row_index,
args.rechunk,
#[cfg(feature = "cloud")]
args.cloud_options,
)?
.build()
Expand Down
26 changes: 21 additions & 5 deletions crates/polars-lazy/src/tests/cse.rs
Expand Up @@ -45,9 +45,14 @@ fn test_cse_unions() -> PolarsResult<()> {

let (mut expr_arena, mut lp_arena) = get_arenas();
let lp = lf.clone().optimize(&mut lp_arena, &mut expr_arena).unwrap();
let mut cache_count = 0;
assert!((&lp_arena).iter(lp).all(|(_, lp)| {
use ALogicalPlan::*;
match lp {
Cache { .. } => {
cache_count += 1;
true
},
Scan { file_options, .. } => {
if let Some(columns) = &file_options.with_columns {
columns.len() == 2
Expand All @@ -58,6 +63,7 @@ fn test_cse_unions() -> PolarsResult<()> {
_ => true,
}
}));
assert_eq!(cache_count, 2);
let out = lf.collect()?;
assert_eq!(out.get_column_names(), &["category", "fats_g"]);

Expand All @@ -82,17 +88,23 @@ fn test_cse_cache_union_projection_pd() -> PolarsResult<()> {
// check that the projection of a is not done before the cache
let (mut expr_arena, mut lp_arena) = get_arenas();
let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap();
let mut cache_count = 0;
assert!((&lp_arena).iter(lp).all(|(_, lp)| {
use ALogicalPlan::*;
match lp {
Cache { .. } => {
cache_count += 1;
true
},
DataFrameScan {
projection: Some(projection),
..
} => projection.as_ref() == &vec!["a".to_string(), "b".to_string()],
} => projection.as_ref().len() <= 2,
DataFrameScan { .. } => false,
_ => true,
}
}));
assert_eq!(cache_count, 2);

Ok(())
}
Expand Down Expand Up @@ -189,7 +201,9 @@ fn test_cse_joins_4954() -> PolarsResult<()> {
.flat_map(|(_, lp)| {
use ALogicalPlan::*;
match lp {
Cache { id, count, input } => {
Cache {
id, count, input, ..
} => {
assert_eq!(*count, 1);
assert!(matches!(
lp_arena.get(*input),
Expand Down Expand Up @@ -242,11 +256,13 @@ fn test_cache_with_partial_projection() -> PolarsResult<()> {
JoinType::Semi.into(),
);

let q = q.with_comm_subplan_elim(true);

let (mut expr_arena, mut lp_arena) = get_arenas();
let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap();

// EDIT: #15264 this originally
// tested 2 caches, but we cannot do that after #15264 due to projection pushdown
// running first and the cache semantics changing, so now we test 1. Maybe we can improve later.

// ensure we get two different caches
// and ensure that every cache only has 1 hit.
let cache_ids = (&lp_arena)
Expand All @@ -259,7 +275,7 @@ fn test_cache_with_partial_projection() -> PolarsResult<()> {
}
})
.collect::<BTreeSet<_>>();
assert_eq!(cache_ids.len(), 2);
assert_eq!(cache_ids.len(), 1);

Ok(())
}
Expand Down
1 change: 0 additions & 1 deletion crates/polars-lazy/src/tests/io.rs
Expand Up @@ -416,7 +416,6 @@ fn test_ipc_globbing() -> PolarsResult<()> {
rechunk: false,
row_index: None,
memmap: true,
#[cfg(feature = "cloud")]
cloud_options: None,
},
)?
Expand Down
9 changes: 9 additions & 0 deletions crates/polars-lazy/src/tests/tpch.rs
Expand Up @@ -85,6 +85,15 @@ fn test_q2() -> PolarsResult<()> {
.limit(100)
.with_comm_subplan_elim(true);

let (node, lp_arena, _) = q.clone().to_alp_optimized().unwrap();
assert_eq!(
(&lp_arena)
.iter(node)
.filter(|(_, alp)| matches!(alp, ALogicalPlan::Cache { .. }))
.count(),
2
);

let out = q.collect()?;
let schema = Schema::from_iter([
Field::new("s_acctbal", DataType::Float64),
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-ops/src/frame/join/args.rs
Expand Up @@ -18,7 +18,7 @@ pub type ChunkJoinIds = Vec<IdxSize>;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

#[derive(Clone, PartialEq, Eq, Debug)]
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct JoinArgs {
pub how: JoinType,
Expand Down Expand Up @@ -56,7 +56,7 @@ impl JoinArgs {
}
}

#[derive(Clone, PartialEq, Eq)]
#[derive(Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinType {
Left,
Expand Down Expand Up @@ -116,7 +116,7 @@ impl Debug for JoinType {
}
}

#[derive(Copy, Clone, PartialEq, Eq, Default)]
#[derive(Copy, Clone, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum JoinValidation {
/// No unique checks
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-ops/src/frame/join/asof/mod.rs
Expand Up @@ -142,7 +142,7 @@ impl<T: NumericNative> AsofJoinState<T> for AsofJoinNearestState {
}
}

#[derive(Clone, Debug, PartialEq, Eq, Default)]
#[derive(Clone, Debug, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AsOfOptions {
pub strategy: AsofStrategy,
Expand Down Expand Up @@ -191,7 +191,7 @@ fn check_asof_columns(
Ok(())
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum AsofStrategy {
/// selects the last row in the right DataFrame whose ‘on’ key is less than or equal to the left’s key
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/options.rs
Expand Up @@ -38,7 +38,7 @@ impl Default for StrptimeOptions {
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct JoinOptions {
pub allow_parallel: bool,
Expand Down
13 changes: 13 additions & 0 deletions crates/polars-plan/src/logical_plan/aexpr/hash.rs
@@ -1,5 +1,8 @@
use std::hash::{Hash, Hasher};

use polars_utils::arena::{Arena, Node};

use crate::logical_plan::ArenaExprIter;
use crate::prelude::AExpr;

impl Hash for AExpr {
Expand Down Expand Up @@ -30,3 +33,13 @@ impl Hash for AExpr {
}
}
}

pub(crate) fn traverse_and_hash_aexpr<H: Hasher>(
node: Node,
expr_arena: &Arena<AExpr>,
state: &mut H,
) {
for (_, ae) in expr_arena.iter(node) {
ae.hash(state);
}
}
15 changes: 1 addition & 14 deletions crates/polars-plan/src/logical_plan/aexpr/mod.rs
Expand Up @@ -3,14 +3,13 @@ mod schema;

use std::hash::{Hash, Hasher};

pub(super) use hash::traverse_and_hash_aexpr;
use polars_core::prelude::*;
use polars_core::utils::{get_time_units, try_get_supertype};
use polars_utils::arena::{Arena, Node};
use strum_macros::IntoStaticStr;

use crate::constants::LEN;
#[cfg(feature = "cse")]
use crate::logical_plan::visitor::AexprNode;
use crate::logical_plan::Context;
use crate::prelude::*;

Expand Down Expand Up @@ -189,18 +188,6 @@ pub enum AExpr {
}

impl AExpr {
#[cfg(feature = "cse")]
pub(crate) fn is_equal(l: Node, r: Node, arena: &Arena<AExpr>) -> bool {
let arena = arena as *const Arena<AExpr> as *mut Arena<AExpr>;
// SAFETY: we can pass a *mut pointer
// the equality operation will not access mutable
unsafe {
let ae_node_l = AexprNode::from_raw(l, arena);
let ae_node_r = AexprNode::from_raw(r, arena);
ae_node_l == ae_node_r
}
}

#[cfg(feature = "cse")]
pub(crate) fn col(name: &str) -> Self {
AExpr::Column(ColumnName::from(name))
Expand Down

0 comments on commit a9f4df3

Please sign in to comment.