Skip to content

Commit

Permalink
addressed comments
Browse files Browse the repository at this point in the history
  • Loading branch information
navis committed Jun 30, 2015
1 parent 143e1ef commit c5419b3
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ case class GeneratedAggregate(
}
}

// A global aggregate (groupingExpressions.isEmpty) in its final stage
// (partial == false) must still emit one row built from the empty
// aggregation buffer even when the input iterator yields nothing
// (e.g. COUNT over an empty relation must return 0, not no rows).
def needEmptyBufferForwarded: Boolean = !partial && groupingExpressions.isEmpty

override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)

protected override def doExecute(): RDD[InternalRow] = {
Expand Down Expand Up @@ -270,8 +275,7 @@ case class GeneratedAggregate(

val joinedRow = new JoinedRow3

if (!iter.hasNext && (partial || groupingExpressions.nonEmpty)) {
// even with empty input, final-global groupby should forward value of empty buffer
if (!iter.hasNext && !needEmptyBufferForwarded) {
Iterator[InternalRow]()
} else if (groupingExpressions.isEmpty) {
// TODO: Codegening anything other than the updateProjection is probably over kill.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ

protected def newProjection(
expressions: Seq[Expression],
inputSchema: Seq[Attribute], mutableRow: Boolean = false): Projection = {
inputSchema: Seq[Attribute],
mutableRow: Boolean = false): Projection = {
log.debug(
s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
if (codegenEnabled && expressions.forall(_.isThreadSafe)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,29 +27,29 @@ class AggregateSuite extends SparkPlanTest {
test("SPARK-8357 Memory leakage on unsafe aggregation path with empty input") {

val input0 = Seq.empty[(String, Int, Double)]
val input1 = Seq(("Hello", 4, 2.0))

// hack : current default parallelism of test local backend is two
// When needEmptyBufferForwarded is true, each task emits a row built from the
// empty buffer even for empty input; SparkPlanTest's default parallelism is two
// (local[2]), so two such rows are expected.
val x0 = Seq(Tuple1(0L), Tuple1(0L))
val y0 = Seq.empty[Tuple1[Long]]

val input1 = Seq(("Hello", 4, 2.0))
val x1 = Seq(Tuple1(0L), Tuple1(1L))
val y1 = Seq(Tuple1(1L))

val codegenDefault = TestSQLContext.getConf(SQLConf.CODEGEN_ENABLED)
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, true)
try {
for ((input, x, y) <- Seq((input0, x0, y0), (input1, x1, y1))) {
val df = input.toDF("a", "b", "c")
val colB = df.col("b").expr
val colC = df.col("c").expr
val aggrExpr = Alias(Count(Cast(colC, LongType)), "Count")()

for ((codegen, unsafe) <- Seq((false, false), (true, false), (true, true));
partial <- Seq(false, true); groupExpr <- Seq(colB :: Nil, Seq.empty)) {
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, codegen)
for (partial <- Seq(false, true); groupExpr <- Seq(Seq(colB), Seq.empty)) {
val aggregate = GeneratedAggregate(partial, groupExpr, Seq(aggrExpr), true, _: SparkPlan)
checkAnswer(df,
GeneratedAggregate(partial, groupExpr, aggrExpr :: Nil, unsafe, _: SparkPlan),
if (groupExpr.isEmpty && !partial) x else y)
aggregate,
if (aggregate(null).needEmptyBufferForwarded) x else y)
}
}
} finally {
Expand Down

0 comments on commit c5419b3

Please sign in to comment.