diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d93c4a5bc459a..8c1ee17824f91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -709,7 +709,8 @@ object CollapseProject extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case p1 @ Project(_, p2: Project) => - if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) { + if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) || + hasOversizedRepeatedAliases(p1.projectList, p2.projectList)) { p1 } else { p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList)) @@ -753,6 +754,28 @@ object CollapseProject extends Rule[LogicalPlan] { }.exists(!_.deterministic)) } + private def hasOversizedRepeatedAliases( + upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = { + val aliases = collectAliases(lower) + + // Count how many times each alias is used in the upper Project. + // If an alias is only used once, we can safely substitute it without increasing the overall + // tree size + val referenceCounts = AttributeMap( + upper + .flatMap(_.collect { case a: Attribute => a }) + .groupBy(identity) + .mapValues(_.size).toSeq + ) + + // Check for any aliases that are used more than once, and are larger than the configured + // maximum size + aliases.exists({ case (attribute, expression) => + referenceCounts.getOrElse(attribute, 0) > 1 && + expression.treeSize > SQLConf.get.maxRepeatedAliasSize + }) + } + private def buildCleanedProjectList( upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Seq[NamedExpression] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 415ce46788119..c558414d9998e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf trait OperationHelper { type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan) @@ -62,6 +63,26 @@ object PhysicalOperation extends OperationHelper with PredicateHelper { Some((fields.getOrElse(child.output), filters, child)) } + private def hasOversizedRepeatedAliases(fields: Seq[Expression], + aliases: Map[Attribute, Expression]): Boolean = { + // Count how many times each alias is used in the fields. + // If an alias is only used once, we can safely substitute it without increasing the overall + // tree size + val referenceCounts = AttributeMap( + fields + .flatMap(_.collect { case a: Attribute => a }) + .groupBy(identity) + .mapValues(_.size).toSeq + ) + + // Check for any aliases that are used more than once, and are larger than the configured + // maximum size + aliases.exists({ case (attribute, expression) => + referenceCounts.getOrElse(attribute, 0) > 1 && + expression.treeSize > SQLConf.get.maxRepeatedAliasSize + }) + } + /** * Collects all deterministic projects and filters, in-lining/substituting aliases if necessary. * Here are two examples for alias in-lining/substitution. @@ -81,8 +102,13 @@ object PhysicalOperation extends OperationHelper with PredicateHelper { plan match { case Project(fields, child) if fields.forall(_.deterministic) => val (_, filters, other, aliases) = collectProjectsAndFilters(child) - val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]] - (Some(substitutedFields), filters, other, collectAliases(substitutedFields)) + if (hasOversizedRepeatedAliases(fields, aliases)) { + // Skip substitution if it could overly increase the overall tree size and risk OOMs + (None, Nil, plan, AttributeMap(Nil)) + } else { + val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]] + (Some(substitutedFields), filters, other, collectAliases(substitutedFields)) + } case Filter(condition, child) if condition.deterministic => val (fields, filters, other, aliases) = collectProjectsAndFilters(child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index c4a106702a515..9ba385659c877 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -114,6 +114,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { lazy val containsChild: Set[TreeNode[_]] = children.toSet + lazy val treeSize: Long = children.map(_.treeSize).sum + 1 + // Copied from Scala 2.13.1 // github.com/scala/scala/blob/v2.13.1/src/library/scala/util/hashing/MurmurHash3.scala#L56-L73 // to prevent the issue https://github.com/scala/bug/issues/10495 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0dbc6d4fdcad3..a26fc4f766159 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2340,6 +2340,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val MAX_REPEATED_ALIAS_SIZE = + buildConf("spark.sql.maxRepeatedAliasSize") + .internal() + .doc("The maximum size of alias expression that will be substituted multiple times " + + "(size defined by the number of nodes in the expression tree). " + + "Used by the CollapseProject optimizer, and PhysicalOperation.") + .intConf + .createWithDefault(100) + val SOURCES_BINARY_FILE_MAX_LENGTH = buildConf("spark.sql.sources.binaryFile.maxLength") .doc("The max length of a file that can be read by the binary file data source. " + "Spark will fail fast and not attempt to read the file if its length exceeds this value. " + @@ -3154,6 +3163,8 @@ class SQLConf extends Serializable with Logging { def setCommandRejectsSparkCoreConfs: Boolean = getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS) + def maxRepeatedAliasSize: Int = getConf(SQLConf.MAX_REPEATED_ALIAS_SIZE) + def castDatetimeToString: Boolean = getConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING) def ignoreDataLocality: Boolean = getConf(SQLConf.IGNORE_DATA_LOCALITY) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala index 42bcd13ee378d..5803f35d27fc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala @@ -170,4 +170,23 @@ class CollapseProjectSuite extends PlanTest { val expected = Sample(0.0, 0.6, false, 11L, relation.select('a as 'c)).analyze comparePlans(optimized, expected) } + + + test("ensure oversize aliases are not repeatedly substituted") { + var query: LogicalPlan = testRelation + for( a <- 1 to 100) { + query = query.select(('a + 'b).as('a), ('a - 'b).as('b)) + } + val projects = Optimize.execute(query.analyze).collect { case p: Project => p } + assert(projects.size >= 12) + } + + test("ensure oversize aliases are still substituted once") { + var query: LogicalPlan = testRelation + for( a <- 1 to 20) { + query = query.select(('a + 'b).as('a), 'b) + } + val projects = Optimize.execute(query.analyze).collect { case p: Project => p } + assert(projects.size === 1) + } }