Skip to content

Commit

Permalink
Some minor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Jul 21, 2015
1 parent c649310 commit 3486ce4
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,6 @@ case class GeneratedAggregate(
}
}

// even with empty input iterator, if this group-by operator is for
// global(groupingExpression.isEmpty) and final(partial=false),
// we still need to make a row from empty buffer.
def needEmptyBufferForwarded: Boolean = groupingExpressions.isEmpty && !partial

override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)

protected override def doExecute(): RDD[InternalRow] = {
Expand Down Expand Up @@ -247,7 +242,7 @@ case class GeneratedAggregate(
child.execute().mapPartitions { iter =>
// Builds a new custom class for holding the results of aggregation for a group.
val initialValues = computeFunctions.flatMap(_.initialValues)
val newAggregationBuffer = newProjection(initialValues, child.output, mutableRow = true)
val newAggregationBuffer = newProjection(initialValues, child.output)
log.info(s"Initial values: ${initialValues.mkString(",")}")

// A projection that computes the group given an input tuple.
Expand All @@ -271,8 +266,17 @@ case class GeneratedAggregate(

val joinedRow = new JoinedRow3

if (!iter.hasNext && !needEmptyBufferForwarded) {
Iterator[InternalRow]()
if (!iter.hasNext) {
// This is an empty input, so return early so that we do not allocate data structures
// that won't be cleaned up (see SPARK-8357).
if (groupingExpressions.isEmpty) {
// This is a global aggregate, so return an empty aggregation buffer.
val resultProjection = resultProjectionBuilder()
Iterator(resultProjection(newAggregationBuffer(EmptyRow)))
} else {
// This is a grouped aggregate, so return an empty iterator.
Iterator[InternalRow]()
}
} else if (groupingExpressions.isEmpty) {
// TODO: Codegening anything other than the updateProjection is probably over kill.
val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
Expand All @@ -287,7 +291,6 @@ case class GeneratedAggregate(
val resultProjection = resultProjectionBuilder()
Iterator(resultProjection(buffer))
} else if (unsafeEnabled) {
// unsafe aggregation buffer is not released if input is empty (see SPARK-8357)
assert(iter.hasNext, "There should be at least one row for this path")
log.info("Using Unsafe-based aggregator")
val aggregationMap = new UnsafeFixedWidthAggregationMap(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
Row(2, 1, 2, 2, 1))
}

test("count of empty table") {
withTempTable("t") {
Seq.empty[(Int, Int)].toDF("a", "b").registerTempTable("t")
checkAnswer(
sql("select count(a) from t"),
Row(0))
}
}

test("inner join where, one match per row") {
checkAnswer(
sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,40 +20,29 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types.DataTypes._

class AggregateSuite extends SparkPlanTest {

test("SPARK-8357 Memory leakage on unsafe aggregation path with empty input") {

val input0 = Seq.empty[(String, Int, Double)]
// in the case of needEmptyBufferForwarded=true, task makes a row from empty buffer
// even with empty input. And current default parallelism of SparkPlanTest is two (local[2])
val x0 = Seq(Tuple1(0L), Tuple1(0L))
val y0 = Seq.empty[Tuple1[Long]]

val input1 = Seq(("Hello", 4, 2.0))
val x1 = Seq(Tuple1(0L), Tuple1(1L))
val y1 = Seq(Tuple1(1L))

test("SPARK-8357 unsafe aggregation path should not leak memory with empty input") {
val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED)
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true)
val unsafeDefault = TestSQLContext.getConf(SQLConf.UNSAFE_ENABLED)
try {
for ((input, x, y) <- Seq((input0, x0, y0), (input1, x1, y1))) {
val df = input.toDF("a", "b", "c")
val colB = df.col("b").expr
val colC = df.col("c").expr
val aggrExpr = Alias(Count(Cast(colC, LongType)), "Count")()

for (partial <- Seq(false, true); groupExpr <- Seq(Seq(colB), Seq.empty)) {
val aggregate = GeneratedAggregate(partial, groupExpr, Seq(aggrExpr), true, _: SparkPlan)
checkAnswer(df,
aggregate,
if (aggregate(null).needEmptyBufferForwarded) x else y)
}
}
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true)
TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, true)
val df = Seq.empty[(Int, Int)].toDF("a", "b")
checkAnswer(
df,
GeneratedAggregate(
partial = true,
Seq(df.col("b").expr),
Seq(Alias(Count(df.col("a").expr), "cnt")()),
unsafeEnabled = true,
_: SparkPlan),
Seq.empty
)
} finally {
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
TestSQLContext.setConf(SQLConf.UNSAFE_ENABLED, unsafeDefault)
}
}
}

0 comments on commit 3486ce4

Please sign in to comment.