From c5419b37f2ece4842d8c2bd7463a50b61245fbcf Mon Sep 17 00:00:00 2001
From: "navis.ryu"
Date: Tue, 30 Jun 2015 11:33:15 +0900
Subject: [PATCH] addressed comments

---
 .../spark/sql/execution/GeneratedAggregate.scala |  8 ++++++--
 .../apache/spark/sql/execution/SparkPlan.scala   |  3 ++-
 .../spark/sql/execution/AggregateSuite.scala     | 16 ++++++++--------
 3 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 10cd29f6f7bc9..5f69a5e19b2be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -64,6 +64,11 @@ case class GeneratedAggregate(
       }
     }
 
+  // Even with an empty input iterator, if this group-by operator is
+  // global (groupingExpressions.isEmpty) and final (partial = false),
+  // we still need to make a row from the empty buffer.
+  def needEmptyBufferForwarded: Boolean = groupingExpressions.isEmpty && !partial
+
   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
 
   protected override def doExecute(): RDD[InternalRow] = {
@@ -270,8 +275,7 @@ case class GeneratedAggregate(
 
       val joinedRow = new JoinedRow3
 
-      if (!iter.hasNext && (partial || groupingExpressions.nonEmpty)) {
-        // even with empty input, final-global groupby should forward value of empty buffer
+      if (!iter.hasNext && !needEmptyBufferForwarded) {
         Iterator[InternalRow]()
       } else if (groupingExpressions.isEmpty) {
         // TODO: Codegening anything other than the updateProjection is probably over kill.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 99f8e9433c919..d15ee93bd7aea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -154,7 +154,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
 
   protected def newProjection(
       expressions: Seq[Expression],
-      inputSchema: Seq[Attribute], mutableRow: Boolean = false): Projection = {
+      inputSchema: Seq[Attribute],
+      mutableRow: Boolean = false): Projection = {
     log.debug(
       s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
     if (codegenEnabled && expressions.forall(_.isThreadSafe)) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala
index b8ee523eb9c6d..7c87024ce85bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala
@@ -27,16 +27,17 @@ class AggregateSuite extends SparkPlanTest {
 
   test("SPARK-8357 Memory leakage on unsafe aggregation path with empty input") {
     val input0 = Seq.empty[(String, Int, Double)]
-    val input1 = Seq(("Hello", 4, 2.0))
-
-    // hack : current default parallelism of test local backend is two
+    // When needEmptyBufferForwarded=true, each task makes a row from the empty buffer even
+    // with empty input, and the current default parallelism of SparkPlanTest is two (local[2]).
     val x0 = Seq(Tuple1(0L), Tuple1(0L))
     val y0 = Seq.empty[Tuple1[Long]]
 
+    val input1 = Seq(("Hello", 4, 2.0))
     val x1 = Seq(Tuple1(0L), Tuple1(1L))
     val y1 = Seq(Tuple1(1L))
 
     val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED)
+    TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true)
     try {
       for ((input, x, y) <- Seq((input0, x0, y0), (input1, x1, y1))) {
         val df = input.toDF("a", "b", "c")
@@ -44,12 +45,11 @@ class AggregateSuite extends SparkPlanTest {
         val colB = df.col("b").expr
         val colC = df.col("c").expr
         val aggrExpr = Alias(Count(Cast(colC, LongType)), "Count")()
-        for ((codegen, unsafe) <- Seq((false, false), (true, false), (true, true));
-             partial <- Seq(false, true); groupExpr <- Seq(colB :: Nil, Seq.empty)) {
-          TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegen)
+        for (partial <- Seq(false, true); groupExpr <- Seq(Seq(colB), Seq.empty)) {
+          val aggregate = GeneratedAggregate(partial, groupExpr, Seq(aggrExpr), true, _: SparkPlan)
           checkAnswer(df,
-            GeneratedAggregate(partial, groupExpr, aggrExpr :: Nil, unsafe, _: SparkPlan),
-            if (groupExpr.isEmpty && !partial) x else y)
+            aggregate,
+            if (aggregate(null).needEmptyBufferForwarded) x else y)
         }
       }
     } finally {
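
For reference, the expected answers in the suite follow directly from the forwarding
rule above. Below is a minimal sketch in plain Scala (not Spark's API; the object
EmptyBufferForwarding and the function aggregatePartition are invented here for
illustration), assuming a global COUNT over two partitions as with local[2]:

  // Sketch of the rule behind needEmptyBufferForwarded: a final (partial = false),
  // global (no grouping keys) COUNT must still emit one row per task when its input
  // partition is empty, because COUNT over empty input is 0 rather than "no row".
  object EmptyBufferForwarding {

    // One COUNT buffer per partition; `grouping` and `partial` mirror the operator's flags.
    def aggregatePartition(rows: Iterator[Long], grouping: Boolean, partial: Boolean): Iterator[Long] = {
      val needEmptyBufferForwarded = !grouping && !partial
      if (!rows.hasNext && !needEmptyBufferForwarded) Iterator.empty
      else Iterator(rows.size.toLong) // the COUNT value built from the (possibly empty) buffer
    }

    def main(args: Array[String]): Unit = {
      // Two empty partitions, mirroring the suite's default parallelism of local[2].
      val partitions = Seq(Iterator[Long](), Iterator[Long]())
      val out = partitions.flatMap(p => aggregatePartition(p, grouping = false, partial = false))
      println(out) // List(0, 0), matching x0 = Seq(Tuple1(0L), Tuple1(0L))
    }
  }

With the same two empty partitions, setting grouping = true or partial = true yields
no rows at all, which is the y0 case the test checks against.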