diff --git a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs index 12ff13f8f6ae..a136285806ca 100644 --- a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs +++ b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs @@ -24,11 +24,14 @@ use datafusion_common::error::Result; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::ExecutionPlan; use crate::PhysicalOptimizerRule; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr::{physical_exprs_equal, PhysicalExpr}; @@ -85,7 +88,10 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { input_agg_exec.aggr_expr(), input_agg_exec.filter_expr(), ), - ) { + ) + // Don't combine if input has multiple partitions - preserve distributed aggregation + && !has_multi_partition_coalesce(input_agg_exec.input()) + { let mode = if agg_exec.mode() == &AggregateMode::Final { AggregateMode::Single } else { @@ -161,4 +167,26 @@ fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool { ) } +/// Check if the plan subtree contains a CoalescePartitionsExec with multiple input partitions. +fn has_multi_partition_coalesce(plan: &Arc) -> bool { + plan.apply(|node| { + // Check if this node is CoalescePartitionsExec with multiple inputs + if node.as_any().is::() { + let partition_count = node + .children() + .first() + .map(|child| child.properties().partitioning.partition_count()) + .unwrap_or(0); + if partition_count > 1 { + // Found a multi-partition coalesce, stop traversal + return Ok(TreeNodeRecursion::Stop); + } + } + // Continue traversing children + Ok(TreeNodeRecursion::Continue) + }) + .map(|recursion| matches!(recursion, TreeNodeRecursion::Stop)) + .unwrap_or(false) +} + // See tests in datafusion/core/tests/physical_optimizer diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 97e719111a5b..ec45b2d9486f 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -279,15 +279,14 @@ impl<'a> DFParser<'a> { sql: &str, dialect: &'a dyn Dialect, ) -> Result { - let tokens = Tokenizer::new(dialect, sql).into_tokens().collect::>()?; + let tokens = Tokenizer::new(dialect, sql) + .into_tokens() + .collect::>()?; Ok(Self::from_dialect_and_tokens(dialect, tokens)) } /// Create a new parser from specified dialect and tokens. - pub fn from_dialect_and_tokens( - dialect: &'a dyn Dialect, - tokens: Vec, - ) -> Self { + pub fn from_dialect_and_tokens(dialect: &'a dyn Dialect, tokens: Vec) -> Self { let parser = Parser::new(dialect).with_tokens(tokens); DFParser { parser } }