Skip to content

Commit

Permalink
Fixed formatting & added a test for the CCE (ClassCastException) case
Browse files Browse the repository at this point in the history
  • Loading branch information
navis committed Jun 29, 2015
1 parent 735972f commit 143e1ef
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@ package org.apache.spark.sql.catalyst.expressions
*/
class InterpretedProjection(expressions: Seq[Expression], mutableRow: Boolean = false)
extends Projection {
def this(expressions: Seq[Expression],
inputSchema: Seq[Attribute], mutableRow: Boolean = false) =

// Auxiliary constructor: resolves each expression's attribute references
// against `inputSchema` (via BindReferences) before delegating to the
// primary constructor, so the projection can run on unresolved expressions.
def this(
expressions: Seq[Expression],
inputSchema: Seq[Attribute],
mutableRow: Boolean = false) = {
this(expressions.map(BindReferences.bindReference(_, inputSchema)), mutableRow)
}

// null check is required for when Kryo invokes the no-arg constructor.
protected val exprArray = if (expressions != null) expressions.toArray else null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InternalRow, _}
import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
import org.apache.spark.sql.types.StructType
import org.apache.spark.{SparkContext, Logging, TaskContext}
import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.util.SerializableConfiguration

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,31 @@ class AggregateSuite extends SparkPlanTest {

test("SPARK-8357 Memory leakage on unsafe aggregation path with empty input") {

val input = Seq.empty[(String, Int, Double)]
val df = input.toDF("a", "b", "c")

val colB = df.col("b").expr
val colC = df.col("c").expr
val aggrExpr = Alias(Count(Cast(colC, LongType)), "Count")()
// Two input datasets: an empty relation (SPARK-8357 leak case) and a
// single-row relation (the newly added CCE regression case).
val input0 = Seq.empty[(String, Int, Double)]
val input1 = Seq(("Hello", 4, 2.0))

// hack : current default parallelism of test local backend is two
// Expected answers for the empty input: x0 is used for a global (no group-by)
// non-partial Count — one zero per partition, presumably two partitions given
// the parallelism hack above (TODO confirm); y0 (no rows) otherwise.
val x0 = Seq(Tuple1(0L), Tuple1(0L))
val y0 = Seq.empty[Tuple1[Long]]

// Expected answers for the single-row input — NOTE(review): x1's (0, 1) split
// looks like one empty partition and one holding the row; verify against the
// local backend's partitioning.
val x1 = Seq(Tuple1(0L), Tuple1(1L))
val y1 = Seq(Tuple1(1L))

val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED)
try {
for ((codegen, unsafe) <- Seq((false, false), (true, false), (true, true));
partial <- Seq(false, true); groupExpr <- Seq(colB :: Nil, Seq.empty)) {
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegen)
checkAnswer(df,
GeneratedAggregate(partial, groupExpr, aggrExpr :: Nil, unsafe, _: SparkPlan),
if (groupExpr.isEmpty && !partial) two else empty)
// Run the same aggregation matrix over both datasets. Each triple pairs an
// input with its two expected answers: `x` when the aggregate is global
// (no grouping) and non-partial, `y` for every other combination.
for ((input, x, y) <- Seq((input0, x0, y0), (input1, x1, y1))) {
val df = input.toDF("a", "b", "c")
val colB = df.col("b").expr
val colC = df.col("c").expr
val aggrExpr = Alias(Count(Cast(colC, LongType)), "Count")()

// Exercise every supported (codegen, unsafe) mode — unsafe requires codegen,
// so (false, true) is omitted — crossed with partial/final aggregation and
// grouped/global expressions.
for ((codegen, unsafe) <- Seq((false, false), (true, false), (true, true));
partial <- Seq(false, true); groupExpr <- Seq(colB :: Nil, Seq.empty)) {
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegen)
checkAnswer(df,
GeneratedAggregate(partial, groupExpr, aggrExpr :: Nil, unsafe, _: SparkPlan),
if (groupExpr.isEmpty && !partial) x else y)
}
}
} finally {
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
Expand Down

0 comments on commit 143e1ef

Please sign in to comment.