diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala index b18fe40855673..ec39ec68e72ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala @@ -96,6 +96,8 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{ /** Must be filled in by the executors */ var inputSchema: Seq[Attribute] = _ + def offsetExpressions: Seq[Attribute] = Seq.fill(bufferOffset)(AttributeReference("offset", NullType)()) + lazy val rightBufferSchema = bufferSchema.map(_.newInstance()) implicit class RichAttribute(a: AttributeReference) { def left = a @@ -112,8 +114,11 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{ } lazy val boundUpdateExpressions = { - val updateSchema = inputSchema ++ bufferSchema - updateExpressions.map(BindReferences.bindReference(_, updateSchema)).toArray + val updateSchema = inputSchema ++ offsetExpressions ++ bufferSchema + val bound = updateExpressions.map(BindReferences.bindReference(_, updateSchema)).toArray + println(s"update: ${updateExpressions.mkString(",")}") + println(s"update: ${bound.mkString(",")}") + bound } val joinedRow = new JoinedRow @@ -126,20 +131,27 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{ } lazy val boundMergeExpressions = { - val mergeSchema = bufferSchema ++ rightBufferSchema + val mergeSchema = offsetExpressions ++ bufferSchema ++ offsetExpressions ++ rightBufferSchema mergeExpressions.map(BindReferences.bindReference(_, mergeSchema)).toArray } override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { var i = 0 + println(s"Merging: $buffer1 $buffer2 with ${boundMergeExpressions.mkString(",")}") + joinedRow(buffer1, buffer2) while (i < bufferSchema.size) { - buffer1(i + bufferOffset) = boundMergeExpressions(i).eval(joinedRow(buffer1, buffer2)) + println(s"$i + $bufferOffset: ${boundMergeExpressions(i).eval(joinedRow)}") + buffer1(i + bufferOffset) = boundMergeExpressions(i).eval(joinedRow) i += 1 } } - lazy val boundEvaluateExpression = BindReferences.bindReference(evaluateExpression, bufferSchema) + lazy val boundEvaluateExpression = + BindReferences.bindReference(evaluateExpression, offsetExpressions ++ bufferSchema) override def eval(buffer: InternalRow): Any = { - boundEvaluateExpression.eval(buffer) + println(s"eval: $buffer") + val res = boundEvaluateExpression.eval(buffer) + println(s"eval: $buffer with $boundEvaluateExpression => $res") + res } } @@ -171,7 +183,7 @@ case class Average(child: Expression) extends AlgebraicAggregate { Add( currentSum, Coalesce(Cast(child, intermediateType) :: Cast(Literal(0), intermediateType) :: Nil)), - /* currentCount = */ If(IsNotNull(child), currentCount, currentCount + 1L) + /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) ) val mergeExpressions = Seq( @@ -179,7 +191,7 @@ case class Average(child: Expression) extends AlgebraicAggregate { /* currentCount = */ currentCount.left + currentCount.right ) - val evaluateExpression = Cast(currentCount, resultType) / Cast(currentSum, resultType) + val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType) override def nullable: Boolean = false override def dataType: DataType = resultType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/Aggregate2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/Aggregate2Suite.scala index a82cf9d7b47bd..b65872dc898d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/Aggregate2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/Aggregate2Suite.scala @@ -61,12 +61,15 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll { |GROUP BY key """.stripMargin).queryExecution.executedPlan(3).execute().collect().foreach(println) - ctx.sql( - """ - |SELECT key, avg2(value) - |FROM agg2 - |GROUP BY key - """.stripMargin).show() + checkAnswer( + ctx.sql( + """ + |SELECT key, avg2(value) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Nil) + } override def afterAll(): Unit = {