From 3486ce4b04a99df8a81846dccbad1c2b97c41d31 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Mon, 20 Jul 2015 23:33:52 -0700
Subject: [PATCH] Some minor cleanup

---
 .../sql/execution/GeneratedAggregate.scala    | 21 +++++----
 .../org/apache/spark/sql/SQLQuerySuite.scala  |  9 ++++
 .../spark/sql/execution/AggregateSuite.scala  | 43 +++++++------------
 3 files changed, 37 insertions(+), 36 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index c35062cee3513..ecde9c57139a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -65,11 +65,6 @@ case class GeneratedAggregate(
       }
     }
 
-  // even with empty input iterator, if this group-by operator is for
-  // global(groupingExpression.isEmpty) and final(partial=false),
-  // we still need to make a row from empty buffer.
-  def needEmptyBufferForwarded: Boolean = groupingExpressions.isEmpty && !partial
-
   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
 
   protected override def doExecute(): RDD[InternalRow] = {
@@ -247,7 +242,7 @@ case class GeneratedAggregate(
     child.execute().mapPartitions { iter =>
       // Builds a new custom class for holding the results of aggregation for a group.
       val initialValues = computeFunctions.flatMap(_.initialValues)
-      val newAggregationBuffer = newProjection(initialValues, child.output, mutableRow = true)
+      val newAggregationBuffer = newProjection(initialValues, child.output)
       log.info(s"Initial values: ${initialValues.mkString(",")}")
 
       // A projection that computes the group given an input tuple.
@@ -271,8 +266,17 @@ case class GeneratedAggregate(
 
       val joinedRow = new JoinedRow3
 
-      if (!iter.hasNext && !needEmptyBufferForwarded) {
-        Iterator[InternalRow]()
+      if (!iter.hasNext) {
+        // This is an empty input, so return early so that we do not allocate data structures
+        // that won't be cleaned up (see SPARK-8357).
+        if (groupingExpressions.isEmpty) {
+          // This is a global aggregate, so return an empty aggregation buffer.
+          val resultProjection = resultProjectionBuilder()
+          Iterator(resultProjection(newAggregationBuffer(EmptyRow)))
+        } else {
+          // This is a grouped aggregate, so return an empty iterator.
+          Iterator[InternalRow]()
+        }
       } else if (groupingExpressions.isEmpty) {
         // TODO: Codegening anything other than the updateProjection is probably over kill.
         val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
@@ -287,7 +291,6 @@ case class GeneratedAggregate(
         val resultProjection = resultProjectionBuilder()
         Iterator(resultProjection(buffer))
       } else if (unsafeEnabled) {
-        // unsafe aggregation buffer is not released if input is empty (see SPARK-8357)
         assert(iter.hasNext, "There should be at least one row for this path")
         log.info("Using Unsafe-based aggregator")
         val aggregationMap = new UnsafeFixedWidthAggregationMap(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 61d5f2061ae18..beee10173fbc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -648,6 +648,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       Row(2, 1, 2, 2, 1))
   }
 
+  test("count of empty table") {
+    withTempTable("t") {
+      Seq.empty[(Int, Int)].toDF("a", "b").registerTempTable("t")
+      checkAnswer(
+        sql("select count(a) from t"),
+        Row(0))
+    }
+  }
+
   test("inner join where, one match per row") {
     checkAnswer(
       sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala
index 7c87024ce85bb..20def6bef0c17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala
@@ -20,40 +20,29 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.types.DataTypes._
 
 class AggregateSuite extends SparkPlanTest {
 
-  test("SPARK-8357 Memory leakage on unsafe aggregation path with empty input") {
-
-    val input0 = Seq.empty[(String, Int, Double)]
-    // in the case of needEmptyBufferForwarded=true, task makes a row from empty buffer
-    // even with empty input. And current default parallelism of SparkPlanTest is two (local[2])
-    val x0 = Seq(Tuple1(0L), Tuple1(0L))
-    val y0 = Seq.empty[Tuple1[Long]]
-
-    val input1 = Seq(("Hello", 4, 2.0))
-    val x1 = Seq(Tuple1(0L), Tuple1(1L))
-    val y1 = Seq(Tuple1(1L))
-
+  test("SPARK-8357 unsafe aggregation path should not leak memory with empty input") {
     val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED)
-    TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true)
+    val unsafeDefault = TestSQLContext.getConf(SQLConf.UNSAFE_ENABLED)
     try {
-      for ((input, x, y) <- Seq((input0, x0, y0), (input1, x1, y1))) {
-        val df = input.toDF("a", "b", "c")
-        val colB = df.col("b").expr
-        val colC = df.col("c").expr
-        val aggrExpr = Alias(Count(Cast(colC, LongType)), "Count")()
-
-        for (partial <- Seq(false, true); groupExpr <- Seq(Seq(colB), Seq.empty)) {
-          val aggregate = GeneratedAggregate(partial, groupExpr, Seq(aggrExpr), true, _: SparkPlan)
-          checkAnswer(df,
-            aggregate,
-            if (aggregate(null).needEmptyBufferForwarded) x else y)
-        }
-      }
+      TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true)
+      TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, true)
+      val df = Seq.empty[(Int, Int)].toDF("a", "b")
+      checkAnswer(
+        df,
+        GeneratedAggregate(
+          partial = true,
+          Seq(df.col("b").expr),
+          Seq(Alias(Count(df.col("a").expr), "cnt")()),
+          unsafeEnabled = true,
+          _: SparkPlan),
+        Seq.empty
+      )
    } finally {
       TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
+      TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeDefault)
     }
   }
 }
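
Note for reviewers: the contract the new empty-input branch encodes (a global
aggregate over empty input still emits one row built from the initial buffer;
a grouped aggregate emits nothing; neither path allocates the unsafe map) can
be summarized with the minimal, self-contained Scala sketch below. This is not
Spark code: Buffer and aggregate are hypothetical stand-ins for the aggregation
buffer and the per-partition function in doExecute().

object EmptyInputAggregateSketch {
  // Simplified stand-in for an aggregation buffer: a single running count.
  final case class Buffer(var count: Long)

  def aggregate(input: Iterator[Long], grouped: Boolean): Iterator[Long] = {
    if (!input.hasNext) {
      // Empty partition: return early so that no hash map (the native-memory
      // structure that leaked in SPARK-8357) is ever allocated.
      if (!grouped) {
        // Global aggregate: one row built from the initial ("empty") buffer.
        Iterator(Buffer(0L).count)
      } else {
        // Grouped aggregate: no input rows means no groups and no output rows.
        Iterator.empty
      }
    } else {
      // Non-empty partition: fold the rows into one buffer (global case only,
      // to keep the sketch short).
      val buffer = Buffer(0L)
      input.foreach(_ => buffer.count += 1)
      Iterator(buffer.count)
    }
  }

  def main(args: Array[String]): Unit = {
    assert(aggregate(Iterator.empty, grouped = false).toSeq == Seq(0L))
    assert(aggregate(Iterator.empty, grouped = true).isEmpty)
    assert(aggregate(Iterator(1L, 2L, 3L), grouped = false).toSeq == Seq(3L))
  }
}

The first assertion is also why the new "count of empty table" test expects
Row(0): the final (non-partial) global aggregate forwards a row built from the
empty buffer even when every partition is empty, while the grouped plan in
AggregateSuite correctly returns no rows.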