diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala
index 733b5dc0be81f..f6bcf06de84b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala
@@ -56,7 +56,7 @@ private[sql] case class AggregateExpression2(
 
   override def children: Seq[Expression] = aggregateFunction :: Nil
   override def dataType: DataType = aggregateFunction.dataType
-  override def foldable: Boolean = aggregateFunction.foldable
+  override def foldable: Boolean = false
   override def nullable: Boolean = aggregateFunction.nullable
 
   override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"
@@ -75,10 +75,7 @@ abstract class AggregateFunction2
 
   var bufferOffset: Int = 0
 
-  def withBufferOffset(newBufferOffset: Int): AggregateFunction2 = {
-    bufferOffset = newBufferOffset
-    this
-  }
+  override def foldable: Boolean = false
 
   /** The schema of the aggregation buffer. */
   def bufferSchema: StructType
@@ -86,6 +83,8 @@ abstract class AggregateFunction2
   /** Attributes of fields in bufferSchema. */
   def bufferAttributes: Seq[Attribute]
 
+  def rightBufferSchema: Seq[Attribute]
+
   def initialize(buffer: MutableRow): Unit
 
   def update(buffer: MutableRow, input: InternalRow): Unit
@@ -100,7 +99,7 @@ case class MyDoubleSum(child: Expression) extends AggregateFunction2 {
     StructType(StructField("currentSum", DoubleType, true) :: Nil)
 
   override val bufferAttributes: Seq[Attribute] = bufferSchema.toAttributes
-
+  override lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
   override def initialize(buffer: MutableRow): Unit = {
     buffer.update(bufferOffset, null)
   }
@@ -152,17 +151,7 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable {
   val mergeExpressions: Seq[Expression]
   val evaluateExpression: Expression
 
-  /** Must be filled in by the executors */
-  var inputSchema: Seq[Attribute] = _
-
-  override def withBufferOffset(newBufferOffset: Int): AlgebraicAggregate = {
-    bufferOffset = newBufferOffset
-    this
-  }
-
-  def offsetExpressions: Seq[Attribute] = Seq.fill(bufferOffset)(AttributeReference("offset", NullType)())
-
-  lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
+  override lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
   implicit class RichAttribute(a: AttributeReference) {
     def left = a
     def right = rightBufferSchema(bufferAttributes.indexOf(a))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/AggregateExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/AggregateExpressionSuite.scala
index c0d5608e63326..08095bdfa4aa9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/AggregateExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/AggregateExpressionSuite.scala
@@ -25,7 +25,8 @@ class AggregateExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("Average") {
     val inputValues = Array(Int.MaxValue, null, 1000, Int.MinValue, 2)
-    val avg = Average(child = BoundReference(0, IntegerType, true)).withBufferOffset(2)
+    val avg = Average(child = BoundReference(0, IntegerType, true))
+    avg.bufferOffset = 2
     val inputRow = new GenericMutableRow(1)
     val buffer = new GenericMutableRow(4)
     avg.initialize(buffer)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f22507b8c21e1..226eead5c67e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -205,12 +205,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Aggregate(groupingExpressions, resultExpressions, child)
         if sqlContext.conf.useSqlAggregate2 =>
-        // 0. Make sure we can convert.
-        resultExpressions.foreach {
-          case agg1: AggregateExpression =>
-            sys.error(s"$agg1 is not supported. Please set spark.sql.useAggregate2 to false.")
-          case _ => // ok
-        }
         // 1. Extracts all distinct aggregate expressions from the resultExpressions.
         val aggregateExpressions = resultExpressions.flatMap { expr =>
           expr.collect {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala
index 1d2ae6261952e..daa1d1a01078f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala
@@ -71,20 +71,19 @@ case class Aggregate2Sort(
       while (i < aggregateExpressions.length) {
         val func = aggregateExpressions(i).aggregateFunction
         bufferOffsets += bufferOffset
-        bufferOffset = aggregateExpressions(i).mode match {
-          case Partial | PartialMerge => bufferOffset + func.bufferSchema.length
-          case Final | Complete => bufferOffset + 1
-        }
+        bufferOffset += func.bufferSchema.length
         i += 1
       }
       aggregateExpressions.zip(bufferOffsets)
     }
-
-  private val algebraicAggregateFunctions: Array[AlgebraicAggregate] = {
-    aggregateExprsWithBufferOffset.collect {
-      case (AggregateExpression2(agg: AlgebraicAggregate, mode, isDistinct), offset) =>
-        agg.inputSchema = child.output
-        agg.withBufferOffset(offset)
+  // println("aggregateExprsWithBufferOffset " + aggregateExprsWithBufferOffset)
+
+  private val aggregateFunctions: Array[AggregateFunction2] = {
+    aggregateExprsWithBufferOffset.map {
+      case (aggExpr, bufferOffset) =>
+        val func = aggExpr.aggregateFunction
+        func.bufferOffset = bufferOffset
+        func
     }.toArray
   }
 
@@ -92,13 +91,15 @@
     aggregateExprsWithBufferOffset.collect {
       case (AggregateExpression2(agg: AggregateFunction2, mode, isDistinct), offset)
         if !agg.isInstanceOf[AlgebraicAggregate] =>
-        val func = agg.withBufferOffset(offset)
         mode match {
           case Partial | Complete =>
             // Only need to bind reference when the function is not an AlgebraicAggregate
             // and the mode is Partial or Complete.
-            BindReferences.bindReference(func, child.output)
-          case _ => func
+            val func = BindReferences.bindReference(agg, child.output)
+            // Need to set it again since BindReference will create a new instance.
+            func.bufferOffset = offset
+            func
+          case _ => agg
         }
     }.toArray
   }
@@ -119,13 +120,8 @@
   private val bufferSize: Int = {
     var size = 0
     var i = 0
-    while (i < algebraicAggregateFunctions.length) {
-      size += algebraicAggregateFunctions(i).bufferSchema.length
-      i += 1
-    }
-    i = 0
-    while (i < nonAlgebraicAggregateFunctions.length) {
-      size += nonAlgebraicAggregateFunctions(i).bufferSchema.length
+    while (i < aggregateFunctions.length) {
+      size += aggregateFunctions(i).bufferSchema.length
       i += 1
     }
     if (preShuffle) {
@@ -160,8 +156,9 @@
   val offsetExpressions = if (preShuffle) Nil else Seq.fill(groupingExpressions.length)(NoOp)
 
   val algebraicInitialProjection = {
-    val initExpressions = offsetExpressions ++ algebraicAggregateFunctions.flatMap {
+    val initExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
       case ae: AlgebraicAggregate => ae.initialValues
+      case agg: AggregateFunction2 => NoOp :: Nil
     }
 
     // println(initExpressions.mkString(","))
@@ -169,11 +166,13 @@
   }
 
   lazy val algebraicUpdateProjection = {
-    val bufferSchema = algebraicAggregateFunctions.flatMap {
+    val bufferSchema = aggregateFunctions.flatMap {
       case ae: AlgebraicAggregate => ae.bufferAttributes
+      case agg: AggregateFunction2 => agg.bufferAttributes
     }
-    val updateExpressions = algebraicAggregateFunctions.flatMap {
+    val updateExpressions = aggregateFunctions.flatMap {
       case ae: AlgebraicAggregate => ae.updateExpressions
+      case agg: AggregateFunction2 => NoOp :: Nil
     }
 
     // println(updateExpressions.mkString(","))
@@ -182,13 +181,16 @@
   lazy val algebraicMergeProjection = {
     val bufferSchemata =
-      offsetAttributes ++ algebraicAggregateFunctions.flatMap {
+      offsetAttributes ++ aggregateFunctions.flatMap {
         case ae: AlgebraicAggregate => ae.bufferAttributes
-      } ++ offsetAttributes ++ algebraicAggregateFunctions.flatMap {
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ offsetAttributes ++ aggregateFunctions.flatMap {
         case ae: AlgebraicAggregate => ae.rightBufferSchema
+        case agg: AggregateFunction2 => agg.rightBufferSchema
       }
 
-    val mergeExpressions = offsetExpressions ++ algebraicAggregateFunctions.flatMap {
+    val mergeExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
       case ae: AlgebraicAggregate => ae.mergeExpressions
+      case agg: AggregateFunction2 => NoOp :: Nil
     }
 
     newMutableProjection(mergeExpressions, bufferSchemata)()
@@ -196,13 +198,16 @@
   lazy val algebraicEvalProjection = {
     val bufferSchemata =
-      offsetAttributes ++ algebraicAggregateFunctions.flatMap {
+      offsetAttributes ++ aggregateFunctions.flatMap {
        case ae: AlgebraicAggregate => ae.bufferAttributes
-      } ++ offsetAttributes ++ algebraicAggregateFunctions.flatMap {
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ offsetAttributes ++ aggregateFunctions.flatMap {
         case ae: AlgebraicAggregate => ae.rightBufferSchema
+        case agg: AggregateFunction2 => agg.rightBufferSchema
       }
 
-    val evalExpressions = algebraicAggregateFunctions.map {
+    val evalExpressions = aggregateFunctions.map {
       case ae: AlgebraicAggregate => ae.evaluateExpression
+      case agg: AggregateFunction2 => NoOp
     }
 
     newMutableProjection(evalExpressions, bufferSchemata)()
@@ -251,6 +256,7 @@
         nonAlgebraicAggregateFunctions(i).merge(buffer, row)
         i += 1
       }
+      // println("buffer merge " + buffer + " " + row)
     }
   }
 
@@ -293,6 +299,7 @@
       val outputRow = if (preShuffle) {
         // If it is preShuffle, we just output the
         // grouping columns and the buffer.
+        // println("buffer " + buffer)
        joinedRow(currentGroupingKey, buffer).copy()
       } else {
         algebraicEvalProjection.target(aggregateResult)(buffer)
@@ -304,7 +311,6 @@
           i += 1
         }
         resultProjection(joinedRow(currentGroupingKey, aggregateResult))
-
       }
 
     initializeBuffer()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/ConvertAggregateFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/rules.scala
similarity index 66%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/ConvertAggregateFunction.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/rules.scala
index bd8d03bcd2f4e..ea082b5dcc5aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/ConvertAggregateFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/rules.scala
@@ -17,8 +17,8 @@
 
 package org.apache.spark.sql.execution.aggregate2
 
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.expressions.{Average => Average1}
+import org.apache.spark.sql.{SQLConf, AnalysisException, SQLContext}
+import org.apache.spark.sql.catalyst.expressions.{Average => Average1, AggregateExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate2.{Average => Average2, AggregateExpression2, Complete}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -32,3 +32,18 @@ case class ConvertAggregateFunction(context: SQLContext) extends Rule[LogicalPla
     }
   }
 }
+
+case class CheckAggregateFunction(context: SQLContext) extends (LogicalPlan => Unit) {
+  def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg) }
+
+  def apply(plan: LogicalPlan): Unit = plan.foreachUp {
+    case p if context.conf.useSqlAggregate2 => p.transformExpressionsUp {
+      case agg: AggregateExpression =>
+        failAnalysis(s"${SQLConf.USE_SQL_AGGREGATE2} is enabled. Please disable it to use $agg.")
+    }
+    case p if !context.conf.useSqlAggregate2 => p.transformExpressionsUp {
+      case agg: AggregateExpression2 =>
+        failAnalysis(s"${SQLConf.USE_SQL_AGGREGATE2} is disabled. Please enable it to use $agg.")
+    }
+  }
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/Aggregate2Suite.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/Aggregate2Suite.scala
index 2d2e83f1b1e34..48771e8f403d7 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/Aggregate2Suite.scala
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/Aggregate2Suite.scala
@@ -154,6 +154,34 @@
       Row(null) :: Nil)
   }
+
+  test("non-AlgebraicAggregate and AlgebraicAggregate aggregate function") {
+    checkAnswer(
+      ctx.sql(
+        """
+          |SELECT mydoublesum(cast(value as double)), key, avg(value)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(60.0, 1, 20.0) :: Row(-1.0, 2, -0.5) :: Row(null, 3, null) :: Nil)
+
+    checkAnswer(
+      ctx.sql(
+        """
+          |SELECT
+          |  mydoublesum(cast(value as double) + 1.5 * key),
+          |  avg(value - key),
+          |  key,
+          |  mydoublesum(cast(value as double) - 1.5 * key),
+          |  avg(value)
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(64.5, 19.0, 1, 55.5, 20.0) ::
+        Row(5.0, -2.5, 2, -7.0, -0.5) ::
+        Row(null, null, 3, null, null) :: Nil)
+  }
+
   override def afterAll(): Unit = {
     ctx.sql(s"set spark.sql.useAggregate2=$originalUseAggregate2")
   }
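
A few mechanics in this patch are easy to miss when reading the diff, so some self-contained sketches follow. They are plain Scala with toy stand-ins for the Catalyst types; none of the names below are the real Spark API. First, buffer offsets no longer depend on the aggregation mode: under the new scheme every function claims bufferSchema.length consecutive slots, so offset assignment reduces to a running sum.

// Toy model of the offset assignment in Aggregate2Sort: each function owns
// bufferSchema.length slots, and its offset is the sum of the lengths of the
// functions before it. (The old code gave Final/Complete modes a single slot.)
object BufferOffsets {
  def offsets(bufferLengths: Seq[Int]): Seq[Int] =
    bufferLengths.scanLeft(0)(_ + _).init

  def main(args: Array[String]): Unit = {
    println(offsets(Seq(2, 1, 3))) // List(0, 2, 3)
  }
}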
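
Second, the added comment "Need to set it again since BindReference will create a new instance." points at a real trap: bufferOffset is a var rather than a constructor parameter, and Catalyst transforms such as BindReferences.bindReference rebuild expression trees through constructor copies, which preserve constructor arguments but not mutable fields. A minimal sketch of the failure mode, with a plain case class standing in for the aggregate function:

// Demonstrates why func.bufferOffset must be re-assigned after binding: a
// case-class copy (which is what tree transforms do) resets mutable fields
// that are not constructor parameters.
object MutableStateLostOnCopy {
  final case class Agg(child: String) {
    var bufferOffset: Int = 0 // mutable state outside the constructor
  }

  def main(args: Array[String]): Unit = {
    val original = Agg("value")
    original.bufferOffset = 2

    val bound = original.copy(child = "input[0]") // what a transform does
    println(bound.bufferOffset)                   // 0 -- the offset is lost

    bound.bufferOffset = original.bufferOffset    // the re-assignment in the patch
    println(bound.bufferOffset)                   // 2
  }
}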
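
Third, the NoOp arms added to the projections are what let algebraic and non-algebraic (imperatively updated) functions share one buffer row: a mutable projection writes positionally, so every function still has to account for its buffer columns in the expression list, and NoOp marks the columns the projection must leave alone for the imperative update/merge calls. A sketch of the alignment idea; note it fills one placeholder per buffer column, while the patch emits a single NoOp per non-algebraic function, which suffices for the single-column buffers (such as MyDoubleSum's) exercised here.

// Toy model of NoOp slot alignment: algebraic functions contribute one update
// expression per buffer column; imperative ones contribute placeholders so
// that the columns of later functions keep their positions.
object NoOpAlignment {
  sealed trait Func { def bufferColumns: Int }
  final case class Algebraic(bufferColumns: Int, update: String) extends Func
  final case class Imperative(bufferColumns: Int) extends Func

  def updateExpressions(funcs: Seq[Func]): Seq[String] = funcs.flatMap {
    case Algebraic(n, update) => Seq.fill(n)(update)
    case Imperative(n)        => Seq.fill(n)("NoOp") // skipped, not overwritten
  }

  def main(args: Array[String]): Unit = {
    val funcs = Seq(Algebraic(2, "sum+x"), Imperative(1), Algebraic(1, "cnt+1"))
    println(updateExpressions(funcs)) // List(sum+x, sum+x, NoOp, cnt+1)
  }
}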
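
Fourth, rightBufferSchema (now pulled up so that plain AggregateFunction2 implementations such as MyDoubleSum provide it too) exists because the merge projection's input row is the left buffer concatenated with the right buffer: attributes bind by identity, not by name, so the right-hand copy of each buffer column needs a fresh instance with its own ID, which is what bufferAttributes.map(_.newInstance()) produces. A simplified model with (name, id) pairs in place of AttributeReference and its exprId:

// Toy model of rightBufferSchema: the merge input is leftBuffer ++
// rightBuffer, so each buffer column needs a distinct identity per side even
// though the names collide.
object RightBufferSchema {
  final case class Attr(name: String, id: Int)

  private var nextId = -1
  private def freshId(): Int = { nextId += 1; nextId }

  def main(args: Array[String]): Unit = {
    val bufferAttributes = Seq(Attr("sum", freshId()), Attr("count", freshId()))
    // Analogue of bufferAttributes.map(_.newInstance()): same names, new ids.
    val rightBufferSchema = bufferAttributes.map(a => Attr(a.name, freshId()))

    (bufferAttributes ++ rightBufferSchema).zipWithIndex.foreach {
      case (a, ordinal) => println(s"${a.name}#${a.id} -> input[$ordinal]")
    }
    // sum#0 -> input[0], count#1 -> input[1], sum#2 -> input[2], count#3 -> input[3]
  }
}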
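
Finally, CheckAggregateFunction moves the old sys.error from physical planning to analysis time, and now also rejects the reverse mismatch (an AggregateExpression2 while the flag is off). The shape is the standard Catalyst checker pattern: a function from LogicalPlan to Unit that walks the tree and throws AnalysisException on the first offending expression. A Spark-free sketch of that pattern, with toy plan and expression types:

// Toy model of the CheckAggregateFunction pattern: a plan check is a plain
// function from plan to Unit that visits every node and fails fast when it
// sees an expression the current configuration cannot execute.
object CheckRulePattern {
  final class AnalysisException(msg: String) extends Exception(msg)

  sealed trait Expr
  case object OldAggExpr extends Expr // stand-in for AggregateExpression
  case object NewAggExpr extends Expr // stand-in for AggregateExpression2

  final case class Plan(expressions: Seq[Expr], children: Seq[Plan] = Nil) {
    def foreachUp(f: Plan => Unit): Unit = {
      children.foreach(_.foreachUp(f))
      f(this)
    }
  }

  final case class CheckAggregateFunction(useAggregate2: Boolean) extends (Plan => Unit) {
    private def failAnalysis(msg: String): Nothing = throw new AnalysisException(msg)

    def apply(plan: Plan): Unit = plan.foreachUp { p =>
      p.expressions.foreach {
        case OldAggExpr if useAggregate2  => failAnalysis("aggregate2 is enabled; found an old-style aggregate")
        case NewAggExpr if !useAggregate2 => failAnalysis("aggregate2 is disabled; found a new-style aggregate")
        case _                            => ()
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val plan = Plan(Seq(OldAggExpr))
    try CheckAggregateFunction(useAggregate2 = true)(plan)
    catch { case e: AnalysisException => println(s"rejected: ${e.getMessage}") }
  }
}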