Skip to content

Commit

Permalink
Bug fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
yhuai committed Jul 16, 2015
1 parent aff9534 commit 5b46d41
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ private[sql] case class AggregateExpression2(
override def children: Seq[Expression] = aggregateFunction :: Nil

override def dataType: DataType = aggregateFunction.dataType
override def foldable: Boolean = aggregateFunction.foldable
override def foldable: Boolean = false
override def nullable: Boolean = aggregateFunction.nullable

override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"
Expand All @@ -75,17 +75,16 @@ abstract class AggregateFunction2

var bufferOffset: Int = 0

def withBufferOffset(newBufferOffset: Int): AggregateFunction2 = {
bufferOffset = newBufferOffset
this
}
override def foldable: Boolean = false

/** The schema of the aggregation buffer. */
def bufferSchema: StructType

/** Attributes of fields in bufferSchema. */
def bufferAttributes: Seq[Attribute]

def rightBufferSchema: Seq[Attribute]

def initialize(buffer: MutableRow): Unit

def update(buffer: MutableRow, input: InternalRow): Unit
Expand All @@ -100,7 +99,7 @@ case class MyDoubleSum(child: Expression) extends AggregateFunction2 {
StructType(StructField("currentSum", DoubleType, true) :: Nil)

override val bufferAttributes: Seq[Attribute] = bufferSchema.toAttributes

override lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
override def initialize(buffer: MutableRow): Unit = {
buffer.update(bufferOffset, null)
}
Expand Down Expand Up @@ -152,17 +151,7 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable {
val mergeExpressions: Seq[Expression]
val evaluateExpression: Expression

/** Must be filled in by the executors */
var inputSchema: Seq[Attribute] = _

override def withBufferOffset(newBufferOffset: Int): AlgebraicAggregate = {
bufferOffset = newBufferOffset
this
}

def offsetExpressions: Seq[Attribute] = Seq.fill(bufferOffset)(AttributeReference("offset", NullType)())

lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
override lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
implicit class RichAttribute(a: AttributeReference) {
def left = a
def right = rightBufferSchema(bufferAttributes.indexOf(a))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ class AggregateExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {

test("Average") {
val inputValues = Array(Int.MaxValue, null, 1000, Int.MinValue, 2)
val avg = Average(child = BoundReference(0, IntegerType, true)).withBufferOffset(2)
val avg = Average(child = BoundReference(0, IntegerType, true))
avg.bufferOffset = 2
val inputRow = new GenericMutableRow(1)
val buffer = new GenericMutableRow(4)
avg.initialize(buffer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Aggregate(groupingExpressions, resultExpressions, child)
if sqlContext.conf.useSqlAggregate2 =>
// 0. Make sure we can convert.
resultExpressions.foreach {
case agg1: AggregateExpression =>
sys.error(s"$agg1 is not supported. Please set spark.sql.useAggregate2 to false.")
case _ => // ok
}
// 1. Extracts all distinct aggregate expressions from the resultExpressions.
val aggregateExpressions = resultExpressions.flatMap { expr =>
expr.collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,34 +71,35 @@ case class Aggregate2Sort(
while (i < aggregateExpressions.length) {
val func = aggregateExpressions(i).aggregateFunction
bufferOffsets += bufferOffset
bufferOffset = aggregateExpressions(i).mode match {
case Partial | PartialMerge => bufferOffset + func.bufferSchema.length
case Final | Complete => bufferOffset + 1
}
bufferOffset += func.bufferSchema.length
i += 1
}
aggregateExpressions.zip(bufferOffsets)
}

private val algebraicAggregateFunctions: Array[AlgebraicAggregate] = {
aggregateExprsWithBufferOffset.collect {
case (AggregateExpression2(agg: AlgebraicAggregate, mode, isDistinct), offset) =>
agg.inputSchema = child.output
agg.withBufferOffset(offset)
// println("aggregateExprsWithBufferOffset " + aggregateExprsWithBufferOffset)

private val aggregateFunctions: Array[AggregateFunction2] = {
aggregateExprsWithBufferOffset.map {
case (aggExpr, bufferOffset) =>
val func = aggExpr.aggregateFunction
func.bufferOffset = bufferOffset
func
}.toArray
}

private val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
aggregateExprsWithBufferOffset.collect {
case (AggregateExpression2(agg: AggregateFunction2, mode, isDistinct), offset)
if !agg.isInstanceOf[AlgebraicAggregate] =>
val func = agg.withBufferOffset(offset)
mode match {
case Partial | Complete =>
// Only need to bind reference when the function is not an AlgebraicAggregate
// and the mode is Partial or Complete.
BindReferences.bindReference(func, child.output)
case _ => func
val func = BindReferences.bindReference(agg, child.output)
// Need to set it again since BindReference will create a new instance.
func.bufferOffset = offset
func
case _ => agg
}
}.toArray
}
Expand All @@ -119,13 +120,8 @@ case class Aggregate2Sort(
private val bufferSize: Int = {
var size = 0
var i = 0
while (i < algebraicAggregateFunctions.length) {
size += algebraicAggregateFunctions(i).bufferSchema.length
i += 1
}
i = 0
while (i < nonAlgebraicAggregateFunctions.length) {
size += nonAlgebraicAggregateFunctions(i).bufferSchema.length
while (i < aggregateFunctions.length) {
size += aggregateFunctions(i).bufferSchema.length
i += 1
}
if (preShuffle) {
Expand Down Expand Up @@ -160,20 +156,23 @@ case class Aggregate2Sort(
val offsetExpressions = if (preShuffle) Nil else Seq.fill(groupingExpressions.length)(NoOp)

val algebraicInitialProjection = {
val initExpressions = offsetExpressions ++ algebraicAggregateFunctions.flatMap {
val initExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.initialValues
case agg: AggregateFunction2 => NoOp :: Nil
}
// println(initExpressions.mkString(","))

newMutableProjection(initExpressions, Nil)().target(buffer)
}

lazy val algebraicUpdateProjection = {
val bufferSchema = algebraicAggregateFunctions.flatMap {
val bufferSchema = aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.bufferAttributes
case agg: AggregateFunction2 => agg.bufferAttributes
}
val updateExpressions = algebraicAggregateFunctions.flatMap {
val updateExpressions = aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.updateExpressions
case agg: AggregateFunction2 => NoOp :: Nil
}

// println(updateExpressions.mkString(","))
Expand All @@ -182,27 +181,33 @@ case class Aggregate2Sort(

lazy val algebraicMergeProjection = {
val bufferSchemata =
offsetAttributes ++ algebraicAggregateFunctions.flatMap {
offsetAttributes ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.bufferAttributes
} ++ offsetAttributes ++ algebraicAggregateFunctions.flatMap {
case agg: AggregateFunction2 => agg.bufferAttributes
} ++ offsetAttributes ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.rightBufferSchema
case agg: AggregateFunction2 => agg.rightBufferSchema
}
val mergeExpressions = offsetExpressions ++ algebraicAggregateFunctions.flatMap {
val mergeExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.mergeExpressions
case agg: AggregateFunction2 => NoOp :: Nil
}

newMutableProjection(mergeExpressions, bufferSchemata)()
}

lazy val algebraicEvalProjection = {
val bufferSchemata =
offsetAttributes ++ algebraicAggregateFunctions.flatMap {
offsetAttributes ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.bufferAttributes
} ++ offsetAttributes ++ algebraicAggregateFunctions.flatMap {
case agg: AggregateFunction2 => agg.bufferAttributes
} ++ offsetAttributes ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.rightBufferSchema
case agg: AggregateFunction2 => agg.rightBufferSchema
}
val evalExpressions = algebraicAggregateFunctions.map {
val evalExpressions = aggregateFunctions.map {
case ae: AlgebraicAggregate => ae.evaluateExpression
case agg: AggregateFunction2 => NoOp
}

newMutableProjection(evalExpressions, bufferSchemata)()
Expand Down Expand Up @@ -251,6 +256,7 @@ case class Aggregate2Sort(
nonAlgebraicAggregateFunctions(i).merge(buffer, row)
i += 1
}
// println("buffer merge " + buffer + " " + row)
}
}

Expand Down Expand Up @@ -293,6 +299,7 @@ case class Aggregate2Sort(
val outputRow =
if (preShuffle) {
// If it is preShuffle, we just output the grouping columns and the buffer.
// println("buffer " + buffer)
joinedRow(currentGroupingKey, buffer).copy()
} else {
algebraicEvalProjection.target(aggregateResult)(buffer)
Expand All @@ -304,7 +311,6 @@ case class Aggregate2Sort(
i += 1
}
resultProjection(joinedRow(currentGroupingKey, aggregateResult))

}
initializeBuffer()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

package org.apache.spark.sql.execution.aggregate2

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.{Average => Average1}
import org.apache.spark.sql.{SQLConf, AnalysisException, SQLContext}
import org.apache.spark.sql.catalyst.expressions.{Average => Average1, AggregateExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate2.{Average => Average2, AggregateExpression2, Complete}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
Expand All @@ -32,3 +32,18 @@ case class ConvertAggregateFunction(context: SQLContext) extends Rule[LogicalPla
}
}
}

case class CheckAggregateFunction(context: SQLContext) extends (LogicalPlan => Unit) {
def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg) }

def apply(plan: LogicalPlan): Unit = plan.foreachUp {
case p if context.conf.useSqlAggregate2 => p.transformExpressionsUp {
case agg: AggregateExpression =>
failAnalysis(s"${SQLConf.USE_SQL_AGGREGATE2} is enabled. Please disable it to use $agg.")
}
case p if !context.conf.useSqlAggregate2 => p.transformExpressionsUp {
case agg: AggregateExpression2 =>
failAnalysis(s"${SQLConf.USE_SQL_AGGREGATE2} is disabled. Please enable it to use $agg.")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,34 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
Row(null) :: Nil)

}

test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") {
checkAnswer(
ctx.sql(
"""
|SELECT mydoublesum(cast(value as double)), key, avg(value)
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(60.0, 1, 20.0) :: Row(-1.0, 2, -0.5) :: Row(null, 3, null) :: Nil)

checkAnswer(
ctx.sql(
"""
|SELECT
| mydoublesum(cast(value as double) + 1.5 * key),
| avg(value - key),
| key,
| mydoublesum(cast(value as double) - 1.5 * key),
| avg(value)
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(64.5, 19.0, 1, 55.5, 20.0) ::
Row(5.0, -2.5, 2, -7.0, -0.5) ::
Row(null, null, 3, null, null) :: Nil)
}

override def afterAll(): Unit = {
ctx.sql(s"set spark.sql.useAggregate2=$originalUseAggregate2")
}
Expand Down

0 comments on commit 5b46d41

Please sign in to comment.