From f7996d0fa9c4c9f3b40f6630130cb92c38d1eefd Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Fri, 10 Jul 2015 14:47:45 -0700
Subject: [PATCH] Add AlgebraicAggregate

---
 .../expressions/aggregate2/aggregates.scala   | 160 ++++++++++--------
 .../spark/sql/execution/Aggregate2Sort.scala  |   6 +
 2 files changed, 93 insertions(+), 73 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala
index bd87008164f76..b18fe40855673 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate2
 
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.trees.{LeafNode, UnaryNode}
 import org.apache.spark.sql.types._
@@ -71,19 +72,8 @@ abstract class AggregateFunction2
 
   def bufferValueDataTypes: StructType
 
-  def initialBufferValues: Array[Any]
-
   def initialize(buffer: MutableRow): Unit
 
-  def updateBuffer(buffer: MutableRow, bufferValues: Array[Any]): Unit = {
-    var i = 0
-    println("bufferOffset in average2 " + bufferOffset)
-    while (i < bufferValues.length) {
-      buffer.update(bufferOffset + i, bufferValues(i))
-      i += 1
-    }
-  }
-
   def update(buffer: MutableRow, input: InternalRow): Unit
 
   def merge(buffer1: MutableRow, buffer2: InternalRow): Unit
@@ -91,83 +81,107 @@ abstract class AggregateFunction2
 
   override def eval(buffer: InternalRow = null): Any
 }
 
-case class Average(child: Expression)
-  extends AggregateFunction2 with UnaryNode[Expression] {
+/**
+ * A helper class for aggregate functions that can be implemented in terms of catalyst expressions.
+ */
+abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable {
+  self: Product =>
 
-  override def nullable: Boolean = child.nullable
+  val bufferSchema: Seq[Attribute]
+  val initialValues: Seq[Expression]
+  val updateExpressions: Seq[Expression]
+  val mergeExpressions: Seq[Expression]
+  val evaluateExpression: Expression
 
-  override def bufferValueDataTypes: StructType = child match {
-    case e @ DecimalType() =>
-      StructType(
-        StructField("Sum", DecimalType.Unlimited) ::
-        StructField("Count", LongType) :: Nil)
-    case _ =>
-      StructType(
-        StructField("Sum", DoubleType) ::
-        StructField("Count", LongType) :: Nil)
-  }
+  /** Must be filled in by the executors */
+  var inputSchema: Seq[Attribute] = _
 
-  override def dataType: DataType = child.dataType match {
-    case DecimalType.Fixed(precision, scale) =>
-      DecimalType(precision + 4, scale + 4)
-    case DecimalType.Unlimited => DecimalType.Unlimited
-    case _ => DoubleType
+  lazy val rightBufferSchema = bufferSchema.map(_.newInstance())
+  implicit class RichAttribute(a: AttributeReference) {
+    def left = a
+    def right = rightBufferSchema(bufferSchema.indexOf(a))
   }
 
-  override def initialBufferValues: Array[Any] = {
-    Array(
-      Cast(Literal(0), bufferValueDataTypes("Sum").dataType).eval(null), // Sum
-      0L) // Count
+  override def bufferValueDataTypes: StructType = StructType.fromAttributes(bufferSchema)
+  override def initialize(buffer: MutableRow): Unit = {
+    var i = 0
+    while (i < bufferSchema.size) {
+      buffer(i + bufferOffset) = initialValues(i).eval()
+      i += 1
+    }
   }
 
-  override def initialize(buffer: MutableRow): Unit =
-    updateBuffer(buffer, initialBufferValues)
-
-  private val inputLiteral =
-    MutableLiteral(null, child.dataType)
-  private val bufferedSum =
-    MutableLiteral(null, bufferValueDataTypes("Sum").dataType)
-  private val bufferedCount = MutableLiteral(null, LongType)
-  private val updateSum =
-    Add(Cast(inputLiteral, bufferValueDataTypes("Sum").dataType), bufferedSum)
-  private val inputBufferedSum =
-    MutableLiteral(null, bufferValueDataTypes("Sum").dataType)
-  private val mergeSum = Add(inputBufferedSum, bufferedSum)
-  private val evaluateAvg =
-    Cast(Divide(bufferedSum, Cast(bufferedCount, bufferValueDataTypes("Sum").dataType)), dataType)
+  lazy val boundUpdateExpressions = {
+    val updateSchema = inputSchema ++ bufferSchema
+    updateExpressions.map(BindReferences.bindReference(_, updateSchema)).toArray
+  }
 
+  val joinedRow = new JoinedRow
   override def update(buffer: MutableRow, input: InternalRow): Unit = {
-    val newInput = child.eval(input)
-    println("newInput " + newInput)
-    if (newInput != null) {
-      inputLiteral.value = newInput
-      bufferedSum.value = buffer(bufferOffset)
-      val newSum = updateSum.eval(null)
-      val newCount = buffer.getLong(bufferOffset + 1) + 1
-      buffer.update(bufferOffset, newSum)
-      buffer.update(bufferOffset + 1, newCount)
+    var i = 0
+    while (i < bufferSchema.size) {
+      buffer(i + bufferOffset) = boundUpdateExpressions(i).eval(joinedRow(input, buffer))
+      i += 1
     }
   }
 
+  lazy val boundMergeExpressions = {
+    val mergeSchema = bufferSchema ++ rightBufferSchema
+    mergeExpressions.map(BindReferences.bindReference(_, mergeSchema)).toArray
+  }
 
   override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
-    if (buffer2(bufferOffset + 1) != 0L) {
-      inputBufferedSum.value = buffer2(bufferOffset)
-      bufferedSum.value = buffer1(bufferOffset)
-      val newSum = mergeSum.eval(null)
-      val newCount =
-        buffer1.getLong(bufferOffset + 1) + buffer2.getLong(bufferOffset + 1)
-      buffer1.update(bufferOffset, newSum)
-      buffer1.update(bufferOffset + 1, newCount)
+    var i = 0
+    while (i < bufferSchema.size) {
+      buffer1(i + bufferOffset) = boundMergeExpressions(i).eval(joinedRow(buffer1, buffer2))
+      i += 1
     }
   }
 
+  lazy val boundEvaluateExpression = BindReferences.bindReference(evaluateExpression, bufferSchema)
   override def eval(buffer: InternalRow): Any = {
-    if (buffer(bufferOffset + 1) == 0L) {
-      null
-    } else {
-      bufferedSum.value = buffer(bufferOffset)
-      bufferedCount.value = buffer.getLong(bufferOffset + 1)
-      evaluateAvg.eval(null)
-    }
+    boundEvaluateExpression.eval(buffer)
   }
 }
+
+case class Average(child: Expression) extends AlgebraicAggregate {
+  val resultType = child.dataType match {
+    case DecimalType.Fixed(precision, scale) =>
+      DecimalType(precision + 4, scale + 4)
+    case DecimalType.Unlimited => DecimalType.Unlimited
+    case _ => DoubleType
+  }
+
+  val intermediateType = child.dataType match {
+    case _ @ DecimalType() => DecimalType.Unlimited
+    case _ => DoubleType
+  }
+
+  val currentSum = AttributeReference("currentSum", DoubleType)()
+  val currentCount = AttributeReference("currentCount", LongType)()
+
+  val bufferSchema = currentSum :: currentCount :: Nil
+
+  val initialValues = Seq(
+    /* currentSum = */ Cast(Literal(0), intermediateType),
+    /* currentCount = */ Literal(0L)
+  )
+
+  val updateExpressions = Seq(
+    /* currentSum = */
+    Add(
+      currentSum,
+      Coalesce(Cast(child, intermediateType) :: Cast(Literal(0), intermediateType) :: Nil)),
+    /* currentCount = */ If(IsNotNull(child), currentCount + 1L, currentCount)
+  )
+
+  val mergeExpressions = Seq(
+    /* currentSum = */ currentSum.left + currentSum.right,
+    /* currentCount = */ currentCount.left + currentCount.right
+  )
+
+  val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType)
+
+  override def nullable: Boolean = false
+  override def dataType: DataType = resultType
+  override def children: Seq[Expression] = child :: Nil
+}
\ No newline at end of file
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate2Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate2Sort.scala
index 5dddf09b09fd4..29a5561eca86e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate2Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate2Sort.scala
@@ -32,6 +32,7 @@ case class Aggregate2Sort(
     child: SparkPlan)
   extends UnaryNode {
 
+
   override def requiredChildDistribution: List[Distribution] = {
     if (preShuffle) {
       UnspecifiedDistribution :: Nil
@@ -51,6 +52,7 @@ case class Aggregate2Sort(
 
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     child.execute().mapPartitions { iter =>
+
       new Iterator[InternalRow] {
         private val aggregateFunctions: Array[AggregateFunction2] = {
          var bufferOffset =
@@ -74,6 +76,10 @@ case class Aggregate2Sort(
            i += 1
          }
 
+          functions.foreach {
+            case ae: AlgebraicAggregate => ae.inputSchema = child.output
+            case _ =>
+          }
          functions
        }
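
Illustrative note, not part of the patch: to see how the AlgebraicAggregate contract introduced above is meant to be used, a minimal Count aggregate could be sketched as below. The sketch assumes the same catalyst DSL import and the RichAttribute left/right helpers defined in AlgebraicAggregate; the Count class itself is hypothetical and does not appear in this commit.

case class Count(child: Expression) extends AlgebraicAggregate {
  // One buffer slot holding the running count of non-null inputs.
  val currentCount = AttributeReference("currentCount", LongType)()
  val bufferSchema = currentCount :: Nil

  // The count starts at zero.
  val initialValues = Seq(/* currentCount = */ Literal(0L))

  // Increment only when the input value is non-null.
  val updateExpressions = Seq(
    /* currentCount = */ If(IsNotNull(child), currentCount + 1L, currentCount)
  )

  // Partial counts from the left and right buffers are added together.
  val mergeExpressions = Seq(
    /* currentCount = */ currentCount.left + currentCount.right
  )

  // The final result is the buffer slot itself.
  val evaluateExpression = currentCount

  override def nullable: Boolean = false
  override def dataType: DataType = LongType
  override def children: Seq[Expression] = child :: Nil
}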