Skip to content

Commit

Permalink
[SPARK-26626][SQL] Maximum size for repeatedly substituted aliases in…
Browse files Browse the repository at this point in the history
… SQL expressions

We have internal applications (BS and C) that are prone to out-of-memory errors (OOMs)
when the same aliases are substituted repeatedly. See ticket [1] and upstream PR [2].

[1] https://issues.apache.org/jira/browse/SPARK-26626
[2] apache#23556

Co-authored-by: j-esse <j-esse@users.noreply.github.com>
Co-authored-by: Josh Casale <jcasale@palantir.com>
Co-authored-by: Will Raschkowski <wraschkowski@palantir.com>
  • Loading branch information
3 people committed Feb 23, 2021
1 parent 4d2d123 commit 0f54fb0
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,8 @@ object CollapseProject extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p1 @ Project(_, p2: Project) =>
if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) ||
hasOversizedRepeatedAliases(p1.projectList, p2.projectList)) {
p1
} else {
p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
Expand Down Expand Up @@ -753,6 +754,28 @@ object CollapseProject extends Rule[LogicalPlan] {
}.exists(!_.deterministic))
}

/**
 * Checks whether `upper` references, more than once, an alias defined in `lower` whose
 * expression tree is larger than `SQLConf.get.maxRepeatedAliasSize`.
 *
 * @param upper project list whose attribute references are counted
 * @param lower project list whose aliases are the substitution candidates
 * @return true if at least one alias is both referenced more than once and oversized
 */
private def hasOversizedRepeatedAliases(
    upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
  val aliases = collectAliases(lower)

  // Count how many times each alias is used in the upper Project.
  // If an alias is only used once, we can safely substitute it without increasing the overall
  // tree size.
  // Usages are tallied per exprId rather than per structurally-equal Attribute, so that
  // references to the same alias which differ only cosmetically (e.g. in qualifier) are
  // counted together instead of each looking like a single use.
  val referenceCounts = upper
    .flatMap(_.collect { case a: Attribute => a.exprId })
    .groupBy(identity)
    .map { case (exprId, uses) => exprId -> uses.size }

  // The limit does not change while scanning the aliases, so read it once.
  val maxAliasSize = SQLConf.get.maxRepeatedAliasSize

  // Check for any aliases that are used more than once, and are larger than the configured
  // maximum size.
  aliases.exists { case (attribute, expression) =>
    referenceCounts.getOrElse(attribute.exprId, 0) > 1 && expression.treeSize > maxAliasSize
  }
}

private def buildCleanedProjectList(
upper: Seq[NamedExpression],
lower: Seq[NamedExpression]): Seq[NamedExpression] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf

trait OperationHelper {
type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan)
Expand Down Expand Up @@ -62,6 +63,26 @@ object PhysicalOperation extends OperationHelper with PredicateHelper {
Some((fields.getOrElse(child.output), filters, child))
}

/**
 * Checks whether `fields` reference, more than once, an alias in `aliases` whose expression
 * tree is larger than `SQLConf.get.maxRepeatedAliasSize`.
 *
 * @param fields expressions whose attribute references are counted
 * @param aliases alias attribute -> aliased expression, i.e. the substitution candidates
 * @return true if at least one alias is both referenced more than once and oversized
 */
private def hasOversizedRepeatedAliases(fields: Seq[Expression],
    aliases: Map[Attribute, Expression]): Boolean = {
  // Count how many times each alias is used in the fields.
  // If an alias is only used once, we can safely substitute it without increasing the overall
  // tree size.
  // Usages are tallied per exprId rather than per structurally-equal Attribute, so that
  // references to the same alias which differ only cosmetically (e.g. in qualifier) are
  // counted together instead of each looking like a single use.
  val referenceCounts = fields
    .flatMap(_.collect { case a: Attribute => a.exprId })
    .groupBy(identity)
    .map { case (exprId, uses) => exprId -> uses.size }

  // The limit does not change while scanning the aliases, so read it once.
  val maxAliasSize = SQLConf.get.maxRepeatedAliasSize

  // Check for any aliases that are used more than once, and are larger than the configured
  // maximum size.
  aliases.exists { case (attribute, expression) =>
    referenceCounts.getOrElse(attribute.exprId, 0) > 1 && expression.treeSize > maxAliasSize
  }
}

/**
* Collects all deterministic projects and filters, in-lining/substituting aliases if necessary.
* Here are two examples for alias in-lining/substitution.
Expand All @@ -81,8 +102,13 @@ object PhysicalOperation extends OperationHelper with PredicateHelper {
plan match {
case Project(fields, child) if fields.forall(_.deterministic) =>
val (_, filters, other, aliases) = collectProjectsAndFilters(child)
val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
(Some(substitutedFields), filters, other, collectAliases(substitutedFields))
if (hasOversizedRepeatedAliases(fields, aliases)) {
// Skip substitution if it could overly increase the overall tree size and risk OOMs
(None, Nil, plan, AttributeMap(Nil))
} else {
val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
(Some(substitutedFields), filters, other, collectAliases(substitutedFields))
}

case Filter(condition, child) if condition.deterministic =>
val (fields, filters, other, aliases) = collectProjectsAndFilters(child)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {

lazy val containsChild: Set[TreeNode[_]] = children.toSet

/** Number of nodes in the tree rooted at this node: the node itself plus all descendants. */
lazy val treeSize: Long = 1L + children.foldLeft(0L)((total, child) => total + child.treeSize)

// Copied from Scala 2.13.1
// github.com/scala/scala/blob/v2.13.1/src/library/scala/util/hashing/MurmurHash3.scala#L56-L73
// to prevent the issue https://github.com/scala/bug/issues/10495
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2340,6 +2340,15 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

// Internal setting: size limit, measured in expression-tree nodes, for an alias expression
// that would be substituted more than once. Defaults to 100.
val MAX_REPEATED_ALIAS_SIZE =
buildConf("spark.sql.maxRepeatedAliasSize")
.internal()
.doc("The maximum size of alias expression that will be substituted multiple times " +
"(size defined by the number of nodes in the expression tree). " +
"Used by the CollapseProject optimizer, and PhysicalOperation.")
.intConf
.createWithDefault(100)

val SOURCES_BINARY_FILE_MAX_LENGTH = buildConf("spark.sql.sources.binaryFile.maxLength")
.doc("The max length of a file that can be read by the binary file data source. " +
"Spark will fail fast and not attempt to read the file if its length exceeds this value. " +
Expand Down Expand Up @@ -3154,6 +3163,8 @@ class SQLConf extends Serializable with Logging {
def setCommandRejectsSparkCoreConfs: Boolean =
getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS)

/** Returns the configured value of [[SQLConf.MAX_REPEATED_ALIAS_SIZE]]. */
def maxRepeatedAliasSize: Int = getConf(SQLConf.MAX_REPEATED_ALIAS_SIZE)

def castDatetimeToString: Boolean = getConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING)

def ignoreDataLocality: Boolean = getConf(SQLConf.IGNORE_DATA_LOCALITY)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,23 @@ class CollapseProjectSuite extends PlanTest {
val expected = Sample(0.0, 0.6, false, 11L, relation.select('a as 'c)).analyze
comparePlans(optimized, expected)
}


test("ensure oversize aliases are not repeatedly substituted") {
  // Stack 100 Projects; at every level both 'a and 'b are referenced twice by the aliases
  // defined on top of them.
  val stacked = (1 to 100).foldLeft[LogicalPlan](testRelation) { (plan, _) =>
    plan.select(('a + 'b).as('a), ('a - 'b).as('b))
  }
  val remainingProjects = Optimize.execute(stacked.analyze).collect { case p: Project => p }
  // The stack must not be collapsed all the way down to a single Project.
  assert(remainingProjects.size >= 12)
}

test("ensure oversize aliases are still substituted once") {
  // Stack 20 Projects; the only alias at each level, 'a, is referenced once by the level
  // above it.
  val stacked = (1 to 20).foldLeft[LogicalPlan](testRelation) { (plan, _) =>
    plan.select(('a + 'b).as('a), 'b)
  }
  val remainingProjects = Optimize.execute(stacked.analyze).collect { case p: Project => p }
  // Every level should have been collapsed, leaving exactly one Project.
  assert(remainingProjects.size === 1)
}
}

0 comments on commit 0f54fb0

Please sign in to comment.