Commit 51178e8: addressed comments

navis committed Jun 29, 2015
1 parent 4d326b9

Showing 5 changed files with 35 additions and 7 deletions.

sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala (4 additions & 0 deletions)

@@ -550,6 +550,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
     props.foreach { case (k, v) => setConfString(k, v) }
   }
 
+  def setConf(props: Map[String, String]): Unit = settings.synchronized {
+    props.foreach { case (k, v) => setConfString(k, v) }
+  }
+
   /** Set the given Spark SQL configuration property using a `string` value. */
   def setConfString(key: String, value: String): Unit = {
     require(key != null, "key cannot be null")
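
The Map overload above mirrors the existing Properties-based setConf and is what lets the test session below restore a saved snapshot in one call. A minimal sketch of the intended round trip, using only SQLConf methods that appear in this commit (getAllConfs, setConfString, clear, and the new setConf); the sketch is illustrative, not part of the change:

    val conf = new SQLConf
    val snapshot: Map[String, String] = conf.getAllConfs  // immutable copy of every setting
    conf.setConfString("spark.sql.codegen", "true")       // mutate during a test
    conf.clear()                                          // wipe the mutated state...
    conf.setConf(snapshot)                                // ...and restore the snapshot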

sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala

@@ -270,8 +270,10 @@ case class GeneratedAggregate(
 
       val joinedRow = new JoinedRow3
 
-      if (groupingExpressions.isEmpty) {
-        // even with the empty input, the value of the empty buffer should be forwarded
+      if (!iter.hasNext && (partial || groupingExpressions.nonEmpty)) {
+        // even with empty input, a final global group-by should still forward the empty buffer's value
+        Iterator[InternalRow]()
+      } else if (groupingExpressions.isEmpty) {
         // TODO: Codegening anything other than the updateProjection is probably over kill.
         val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
         var currentRow: InternalRow = null
@@ -284,8 +286,6 @@ case class GeneratedAggregate(
 
         val resultProjection = resultProjectionBuilder()
         Iterator(resultProjection(buffer))
-      } else if (!iter.hasNext) {
-        Iterator[InternalRow]()
       } else if (unsafeEnabled && schemaSupportsUnsafe) {
         // unsafe aggregation buffer is not released if input is empty (see SPARK-8357)
         assert(iter.hasNext, "There should be at least one row for this path")
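
The reworked guard makes the empty-input semantics explicit: a partial aggregate, or any aggregate with grouping expressions, emits no rows for no input, while a final global aggregate must still emit one row built from the empty buffer (COUNT over an empty table is 0, not absent). A standalone sketch of that decision, with partial and grouping standing in for the operator's flags:

    // Expected number of output rows when the input iterator is empty.
    def rowsOnEmptyInput(partial: Boolean, grouping: Boolean): Int =
      if (partial || grouping) 0  // no groups were seen, so nothing to forward
      else 1                      // final global aggregate: one row from the empty buffer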

sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala

@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.test
 
+import scala.collection.immutable
 import scala.language.implicitConversions
 
 import org.apache.spark.{SparkConf, SparkContext}
@@ -36,9 +37,18 @@ class LocalSQLContext
   }
 
   protected[sql] class SQLSession extends super.SQLSession {
+    var backup: immutable.Map[String, String] = null
     protected[sql] override lazy val conf: SQLConf = new SQLConf {
       /** Fewer partitions to speed up testing. */
       override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5)
+      backup = getAllConfs
     }
+
+    protected[sql] def reset() = {
+      if (backup != null) {
+        conf.clear()
+        conf.setConf(backup)
+      }
+    }
   }
 
@@ -50,6 +60,11 @@ class LocalSQLContext
     DataFrame(this, plan)
   }
 
+  /**
+   * Reset the session conf to its initial state.
+   */
+  protected[sql] def resetConf(): Unit = currentSession().asInstanceOf[SQLSession].reset
+
 }
 
 object TestSQLContext extends LocalSQLContext
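
The session snapshots its configuration as a side effect of the lazy val initializer: the first time conf is forced, backup receives a pristine copy of all settings, and reset() later clears the conf and replays that copy. A self-contained sketch of the same snapshot-on-first-use pattern (the Settings and Session classes here are invented for illustration):

    class Settings {
      private val map = scala.collection.mutable.Map[String, String]()
      def getAll: Map[String, String] = map.toMap        // immutable snapshot
      def set(kvs: Map[String, String]): Unit = map ++= kvs
      def clear(): Unit = map.clear()
    }

    class Session {
      var backup: Map[String, String] = null
      lazy val conf: Settings = { val s = new Settings; backup = s.getAll; s }
      def reset(): Unit = if (backup != null) { conf.clear(); conf.set(backup) }
    }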

sql/core/src/test/scala/org/apache/spark/sql/execution/AggregateSuite.scala

@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.SparkEnv
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.types.DataTypes._
@@ -31,8 +30,8 @@ class AggregateSuite extends SparkPlanTest {
     val groupExpr = df.col("b").expr
     val aggrExpr = Alias(Count(Cast(groupExpr, LongType)), "Count")()
 
-    SparkEnv.get.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true")
-    for (codegen <- Seq(false, true); partial <- Seq(false, true); unsafe <- Seq(false, true)) {
+    for ((codegen, unsafe) <- Seq((false, false), (true, false), (true, true));
+        partial <- Seq(false, true)) {
       TestSQLContext.conf.setConfString("spark.sql.codegen", String.valueOf(codegen))
       checkAnswer(
         df,
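
The rewritten loop enumerates only the valid (codegen, unsafe) pairs rather than the full 2x2x2 cube: the unsafe aggregation path presumably only exists inside the code-generated operator, so (codegen = false, unsafe = true) is skipped, leaving six cases instead of eight. A sketch of the resulting matrix:

    val cases = for {
      (codegen, unsafe) <- Seq((false, false), (true, false), (true, true))
      partial <- Seq(false, true)
    } yield (codegen, unsafe, partial)
    assert(cases.size == 6)  // the codegen = false, unsafe = true pair never appears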

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala

@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution
 
+import org.scalatest.Tag
+
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
@@ -43,6 +45,14 @@ class SparkPlanTest extends SparkFunSuite {
     TestSQLContext.implicits.localSeqToDataFrameHolder(data)
   }
 
+  protected override def test(testName: String, testTags: Tag*)(testFun: => Unit): Unit = {
+    // Wrap the body rather than the registration: FunSuite's test() only registers
+    // the function to run later, so the conf reset must happen after testFun itself.
+    super.test(testName, testTags: _*) {
+      try testFun finally TestSQLContext.resetConf()
+    }
+  }
+
   /**
    * Runs the plan and makes sure the answer matches the expected result.
    * @param input the input data to be used.
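
For context on the override above: ScalaTest's test(name)(body) registers the body to run later rather than running it on the spot, which is why the cleanup has to travel with the body itself. A minimal standalone illustration of the same wrapping pattern against the FunSuite API of this commit's era, with println standing in for TestSQLContext.resetConf():

    import org.scalatest.{FunSuite, Tag}

    class ExampleSuite extends FunSuite {
      protected override def test(name: String, tags: Tag*)(body: => Unit): Unit =
        super.test(name, tags: _*) {
          try body finally println(s"cleanup after: $name")  // e.g. reset shared conf
        }

      test("adds") { assert(1 + 1 == 2) }
    }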
