diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 9a10a23937fbb..b3fb097fb02cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -550,6 +550,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
     props.foreach { case (k, v) => setConfString(k, v) }
   }
 
+  def setConf(props: Map[String, String]): Unit = settings.synchronized {
+    props.foreach { case (k, v) => setConfString(k, v) }
+  }
+
   /** Set the given Spark SQL configuration property using a `string` value. */
   def setConfString(key: String, value: String): Unit = {
     require(key != null, "key cannot be null")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 8be63c677845a..9b4ddc3ac2d94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -270,8 +270,10 @@ case class GeneratedAggregate(
 
       val joinedRow = new JoinedRow3
 
-      if (groupingExpressions.isEmpty) {
-        // even with the empty input, value of empty buffer should be forwarded
+      if (!iter.hasNext && (partial || groupingExpressions.nonEmpty)) {
+        // even with empty input, final-global groupby should forward value of empty buffer
+        Iterator[InternalRow]()
+      } else if (groupingExpressions.isEmpty) {
         // TODO: Codegening anything other than the updateProjection is probably over kill.
         val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
         var currentRow: InternalRow = null
@@ -284,8 +286,6 @@
 
         val resultProjection = resultProjectionBuilder()
         Iterator(resultProjection(buffer))
-      } else if (!iter.hasNext) {
-        Iterator[InternalRow]()
       } else if (unsafeEnabled && schemaSupportsUnsafe) {
         // unsafe aggregation buffer is not released if input is empty (see SPARK-8357)
         assert(iter.hasNext, "There should be at least one row for this path")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 9fa394525d65c..1e8ec9ab81b60 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.test
 
+import scala.collection.immutable
 import scala.language.implicitConversions
 
 import org.apache.spark.{SparkConf, SparkContext}
@@ -36,9 +37,18 @@ class LocalSQLContext
   }
 
   protected[sql] class SQLSession extends super.SQLSession {
+    var backup: immutable.Map[String, String] = null
     protected[sql] override lazy val conf: SQLConf = new SQLConf {
       /** Fewer partitions to speed up testing. */
       override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5)
+      backup = getAllConfs
+    }
+
+    protected[sql] def reset() = {
+      if (backup != null) {
+        conf.clear()
+        conf.setConf(backup)
+      }
     }
   }
 
@@ -50,6 +60,11 @@ class LocalSQLContext
     DataFrame(this, plan)
   }
 
+  /**
+   * Reset session conf to initial state
+   */
+  protected[sql] def resetConf(): Unit = currentSession().asInstanceOf[SQLSession].reset
+
 }
 
 object TestSQLContext extends LocalSQLContext
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala
index 88620baddb520..fb088028f9a3b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.SparkEnv
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.types.DataTypes._
@@ -31,8 +30,8 @@ class AggregateSuite extends SparkPlanTest {
     val groupExpr = df.col("b").expr
     val aggrExpr = Alias(Count(Cast(groupExpr, LongType)), "Count")()
 
-    SparkEnv.get.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true")
-    for (codegen <- Seq(false, true); partial <- Seq(false, true); unsafe <- Seq(false, true)) {
+    for ((codegen, unsafe) <- Seq((false, false), (true, false), (true, true));
+         partial <- Seq(false, true)) {
       TestSQLContext.conf.setConfString("spark.sql.codegen", String.valueOf(codegen))
       checkAnswer(
         df,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 13f3be8ca28d6..d454ca24dcd0a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution
 
+import org.scalatest.Tag
+
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
@@ -43,6 +45,14 @@ class SparkPlanTest extends SparkFunSuite {
     TestSQLContext.implicits.localSeqToDataFrameHolder(data)
   }
 
+  protected override def test(testName: String, testTags: Tag*)(testFun: => Unit): Unit = {
+    try {
+      super.test(testName, testTags: _*)(testFun)
+    } finally {
+      TestSQLContext.resetConf()
+    }
+  }
+
   /**
    * Runs the plan and makes sure the answer matches the expected result.
    * @param input the input data to be used.
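
Not part of the patch above: a minimal sketch of how the pieces are intended to interact, assuming the API exactly as shown in the diff. The suite and test names here are hypothetical and exist only to illustrate the conf backup/restore flow.

    package org.apache.spark.sql.execution

    import org.apache.spark.sql.test.TestSQLContext

    // Hypothetical suite, for illustration only. SparkPlanTest.test() wraps every
    // test body and calls TestSQLContext.resetConf() in a finally block; reset()
    // clears the session SQLConf and replays the snapshot ("backup") captured when
    // the session conf was first built, via the new setConf(Map[String, String]).
    class ConfResetIllustrationSuite extends SparkPlanTest {
      test("conf changes made inside a test do not leak into later tests") {
        // Mutate the shared session conf, as AggregateSuite does for codegen.
        TestSQLContext.conf.setConfString("spark.sql.codegen", "true")
      }
      // Once this test returns (or throws), the overridden test() has already
      // restored the initial values, e.g. numShufflePartitions is 5 again.
    }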