Rolled-back test-conf cleanup & fixed possible CCE & added more tests
navis committed Jun 29, 2015
1 parent 51178e8 commit 1a02a55
Showing 8 changed files with 33 additions and 49 deletions.
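In short, the commit does three things: it rolls back the test-conf backup/restore machinery previously added to TestSQLContext and SparkPlanTest, it fixes a possible ClassCastException (the "CCE" of the title) by letting projections build a mutable row when the caller needs one, and it broadens the SPARK-8357 test in AggregateSuite. The sketch below reproduces the failure mode and the fix in miniature; the class and method names are illustrative stand-ins, not Spark's real internals.

```scala
// Stand-ins for Spark's internal row types (illustrative, not the real API).
class GenericRow(val values: Array[Any])                  // immutable result row
class GenericMutableRow(values: Array[Any]) extends GenericRow(values) {
  def update(i: Int, v: Any): Unit = values(i) = v        // in-place mutation
}

object CceSketch extends App {
  // The projection now takes a flag choosing which row variant to build.
  def project(input: Array[Any], mutableRow: Boolean = false): GenericRow =
    if (mutableRow) new GenericMutableRow(input.clone()) else new GenericRow(input.clone())

  // Fixed path: ask for the mutable variant up front, then mutate safely.
  val buffer = project(Array[Any](0L), mutableRow = true).asInstanceOf[GenericMutableRow]
  buffer.update(0, 1L)
  println(buffer.values.mkString(","))                    // prints: 1

  // Pre-fix failure mode: an immutable row downcast at the call site.
  try project(Array[Any](0L)).asInstanceOf[GenericMutableRow]
  catch { case e: ClassCastException => println(s"the possible CCE: $e") }
}
```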
@@ -22,9 +22,11 @@ package org.apache.spark.sql.catalyst.expressions
  * @param expressions a sequence of expressions that determine the value of each column of the
  *                    output row.
  */
-class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
-  def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+class InterpretedProjection(expressions: Seq[Expression], mutableRow: Boolean = false)
+  extends Projection {
+  def this(expressions: Seq[Expression],
+      inputSchema: Seq[Attribute], mutableRow: Boolean = false) =
+    this(expressions.map(BindReferences.bindReference(_, inputSchema)), mutableRow)
 
   // null check is required for when Kryo invokes the no-arg constructor.
   protected val exprArray = if (expressions != null) expressions.toArray else null
@@ -36,7 +38,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
       outputArray(i) = exprArray(i).eval(input)
       i += 1
     }
-    new GenericInternalRow(outputArray)
+    if (mutableRow) new GenericMutableRow(outputArray) else new GenericInternalRow(outputArray)
   }
 
   override def toString: String = s"Row => [${exprArray.mkString(",")}]"
4 changes: 0 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -550,10 +550,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
     props.foreach { case (k, v) => setConfString(k, v) }
   }
 
-  def setConf(props: Map[String, String]): Unit = settings.synchronized {
-    props.foreach { case (k, v) => setConfString(k, v) }
-  }
-
   /** Set the given Spark SQL configuration property using a `string` value. */
   def setConfString(key: String, value: String): Unit = {
     require(key != null, "key cannot be null")
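This Map overload merely duplicated the Properties overload directly above it, and its only caller visible in this commit was the rolled-back reset() path in TestSQLContext further down. Any caller holding a Map[String, String] can iterate it and call setConfString(k, v) directly, which is exactly what the deleted body did.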
@@ -246,7 +246,7 @@ case class GeneratedAggregate(
     child.execute().mapPartitions { iter =>
       // Builds a new custom class for holding the results of aggregation for a group.
       val initialValues = computeFunctions.flatMap(_.initialValues)
-      val newAggregationBuffer = newProjection(initialValues, child.output)
+      val newAggregationBuffer = newProjection(initialValues, child.output, mutableRow = true)
       log.info(s"Initial values: ${initialValues.mkString(",")}")
 
       // A projection that computes the group given an input tuple.
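GeneratedAggregate updates this buffer in place once per input row, which is why it must now be projected as a mutable row; building it immutably and downcasting later is exactly the CCE scenario sketched above. A schematic of the per-partition update loop, with illustrative names rather than GeneratedAggregate's actual code:

```scala
// Schematic of an aggregate update loop over one partition (illustrative names).
object BufferSketch extends App {
  final class MutableBuffer(val values: Array[Any]) {
    def update(i: Int, v: Any): Unit = values(i) = v
  }

  val buffer = new MutableBuffer(Array[Any](0L))  // initialValues, projected mutably
  val partition = Iterator("a", "b", "c")         // stand-in for the row iterator
  partition.foreach { _ =>
    // e.g. COUNT: read the current value, write the incremented one back in place
    buffer.update(0, buffer.values(0).asInstanceOf[Long] + 1L)
  }
  println(buffer.values(0))                       // prints: 3
}
```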
@@ -153,13 +153,14 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   }
 
   protected def newProjection(
-      expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
+      expressions: Seq[Expression],
+      inputSchema: Seq[Attribute], mutableRow: Boolean = false): Projection = {
     log.debug(
       s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
     if (codegenEnabled && expressions.forall(_.isThreadSafe)) {
       GenerateProjection.generate(expressions, inputSchema)
     } else {
-      new InterpretedProjection(expressions, inputSchema)
+      new InterpretedProjection(expressions, inputSchema, mutableRow)
     }
   }
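Note that the new flag is threaded only into the interpreted fallback; the codegen branch via GenerateProjection.generate is left untouched, presumably because the generated row class can already be updated in place. Callers that need a mutable buffer, such as GeneratedAggregate above, opt in with mutableRow = true, while every existing call site keeps its old behavior through the default value.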
@@ -49,7 +49,7 @@ import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InternalRow, _}
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.{Logging, TaskContext}
+import org.apache.spark.{SparkContext, Logging, TaskContext}
 import org.apache.spark.util.SerializableConfiguration
 
 /**
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.test
 
-import scala.collection.immutable
 import scala.language.implicitConversions
 
 import org.apache.spark.{SparkConf, SparkContext}
@@ -37,18 +36,9 @@ class LocalSQLContext
   }
 
   protected[sql] class SQLSession extends super.SQLSession {
-    var backup: immutable.Map[String, String] = null
     protected[sql] override lazy val conf: SQLConf = new SQLConf {
       /** Fewer partitions to speed up testing. */
       override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5)
-      backup = getAllConfs
     }
-
-    protected[sql] def reset() = {
-      if (backup != null) {
-        conf.clear()
-        conf.setConf(backup)
-      }
-    }
   }
 
@@ -60,11 +50,6 @@
     DataFrame(this, plan)
   }
 
-  /**
-   * Reset session conf to initial state
-   */
-  protected[sql] def resetConf(): Unit = currentSession().asInstanceOf[SQLSession].reset
-
 }
 
 object TestSQLContext extends LocalSQLContext
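With the backup/reset machinery gone, a test that mutates the session conf is expected to save and restore the affected keys itself. A fragment of the pattern (mirroring the updated AggregateSuite below, not a new API):

```scala
// Save, mutate, restore: the per-test pattern that replaces the global reset hook.
val codegenDefault = TestSQLContext.conf.getConfString("spark.sql.codegen")
try {
  TestSQLContext.conf.setConfString("spark.sql.codegen", "true")
  // ... exercise plans under the modified configuration ...
} finally {
  TestSQLContext.conf.setConfString("spark.sql.codegen", codegenDefault)
}
```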
@@ -25,18 +25,28 @@ class AggregateSuite extends SparkPlanTest {
 
   test("SPARK-8357 Memory leakage on unsafe aggregation path with empty input") {
 
-    val df = Seq.empty[(String, Int, Double)].toDF("a", "b", "c")
-
-    val groupExpr = df.col("b").expr
-    val aggrExpr = Alias(Count(Cast(groupExpr, LongType)), "Count")()
-
-    for ((codegen, unsafe) <- Seq((false, false), (true, false), (true, true));
-         partial <- Seq(false, true)) {
-      TestSQLContext.conf.setConfString("spark.sql.codegen", String.valueOf(codegen))
-      checkAnswer(
-        df,
-        GeneratedAggregate(partial, groupExpr :: Nil, aggrExpr :: Nil, unsafe, _: SparkPlan),
-        Seq.empty[(String, Int, Double)])
-    }
+    val input = Seq.empty[(String, Int, Double)]
+    val df = input.toDF("a", "b", "c")
+
+    val colB = df.col("b").expr
+    val colC = df.col("c").expr
+    val aggrExpr = Alias(Count(Cast(colC, LongType)), "Count")()
+
+    // hack: current default parallelism of the test local backend is two
+    val two = Seq(Tuple1(0L), Tuple1(0L))
+    val empty = Seq.empty[Tuple1[Long]]
+
+    val codegenDefault = TestSQLContext.conf.getConfString("spark.sql.codegen")
+    try {
+      for ((codegen, unsafe) <- Seq((false, false), (true, false), (true, true));
+           partial <- Seq(false, true); groupExpr <- Seq(colB :: Nil, Seq.empty)) {
+        TestSQLContext.conf.setConfString("spark.sql.codegen", String.valueOf(codegen))
+        checkAnswer(df,
+          GeneratedAggregate(partial, groupExpr, aggrExpr :: Nil, unsafe, _: SparkPlan),
+          if (groupExpr.isEmpty && !partial) two else empty)
+      }
+    } finally {
+      TestSQLContext.conf.setConfString("spark.sql.codegen", codegenDefault)
+    }
   }
 }
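The new groupExpr <- Seq(colB :: Nil, Seq.empty) dimension exercises both grouped and global aggregation. Per the in-test comment, the test-local backend defaults to a parallelism of two, and a final (non-partial) global aggregate must emit a result row even for empty input; running here without an exchange in front of it, it does so once per partition, hence the two zero-count rows expected for the groupExpr.isEmpty && !partial case. Every other combination yields no rows. The try/finally also restores spark.sql.codegen itself, which is what lets the commit drop the SparkPlanTest-wide reset hook below.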
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution
 
-import org.scalatest.Tag
-
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
@@ -45,14 +43,6 @@ class SparkPlanTest extends SparkFunSuite {
     TestSQLContext.implicits.localSeqToDataFrameHolder(data)
   }
 
-  protected override def test(testName: String, testTags: Tag*)(testFun: => Unit): Unit = {
-    try {
-      super.test(testName, testTags: _*)(testFun)
-    } finally {
-      TestSQLContext.resetConf()
-    }
-  }
-
   /**
    * Runs the plan and makes sure the answer matches the expected result.
    * @param input the input data to be used.
