Skip to content

Commit

Permalink
Add AlgebraicAggregate
Browse files Browse the repository at this point in the history
  • Loading branch information
marmbrus committed Jul 10, 2015
1 parent dded1c5 commit f7996d0
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

package org.apache.spark.sql.catalyst.expressions.aggregate2

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.trees.{LeafNode, UnaryNode}
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -71,103 +72,116 @@ abstract class AggregateFunction2

/** Schema (names and types) of the values this function stores in the aggregation buffer. */
def bufferValueDataTypes: StructType

/** Starting values for this function's buffer slots, one entry per buffered value. */
def initialBufferValues: Array[Any]

/** Resets this function's slots in `buffer` to their starting state. */
def initialize(buffer: MutableRow): Unit

/**
 * Copies `bufferValues` into `buffer`, starting at this function's `bufferOffset`.
 *
 * Element `i` of `bufferValues` is written to ordinal `bufferOffset + i`. This helper is
 * intentionally silent: it writes nothing to stdout.
 *
 * @param buffer       row that receives the values
 * @param bufferValues values to store, in buffer-slot order
 */
def updateBuffer(buffer: MutableRow, bufferValues: Array[Any]): Unit = {
  var i = 0
  while (i < bufferValues.length) {
    buffer.update(bufferOffset + i, bufferValues(i))
    i += 1
  }
}

/** Folds one `input` row into this function's slots of `buffer`. */
def update(buffer: MutableRow, input: InternalRow): Unit

/** Combines the partial state held in `buffer2` into `buffer1`; only `buffer1` is mutable. */
def merge(buffer1: MutableRow, buffer2: InternalRow): Unit

/** Produces this function's final value from an aggregation buffer. */
override def eval(buffer: InternalRow = null): Any
}

case class Average(child: Expression)
extends AggregateFunction2 with UnaryNode[Expression] {
/**
* A helper class for aggregate functions that can be implemented in terms of catalyst expressions.
*/
abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
self: Product =>

override def nullable: Boolean = child.nullable
val bufferSchema: Seq[Attribute]
val initialValues: Seq[Expression]
val updateExpressions: Seq[Expression]
val mergeExpressions: Seq[Expression]
val evaluateExpression: Expression

/**
 * Buffer layout for the average: a running "Sum" followed by a "Count".
 *
 * The sum is kept as an unlimited decimal when `child` matches `DecimalType()`, and as a
 * double otherwise; the count is always a long.
 */
override def bufferValueDataTypes: StructType = {
  val sumType: DataType = child match {
    case DecimalType() => DecimalType.Unlimited
    case _ => DoubleType
  }
  StructType(StructField("Sum", sumType) :: StructField("Count", LongType) :: Nil)
}
/** Must be filled in by the executors */
var inputSchema: Seq[Attribute] = _

override def dataType: DataType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 4, scale + 4)
case DecimalType.Unlimited => DecimalType.Unlimited
case _ => DoubleType
// Fresh instances of the buffer attributes, created with newInstance(), in the same order as
// bufferSchema. Merge expressions are bound against bufferSchema ++ rightBufferSchema.
lazy val rightBufferSchema = bufferSchema.map(_.newInstance())
/** Adds `.left` / `.right` accessors to buffer attributes for writing merge expressions. */
implicit class RichAttribute(a: AttributeReference) {
// The attribute itself, unchanged.
def left = a
// The attribute at the same position in rightBufferSchema.
// NOTE(review): indexOf yields -1 when `a` is not part of bufferSchema, which would make
// this lookup throw — confirm it is only used on buffer attributes.
def right = rightBufferSchema(bufferSchema.indexOf(a))
}

override def initialBufferValues: Array[Any] = {
Array(
Cast(Literal(0), bufferValueDataTypes("Sum").dataType).eval(null), // Sum
0L) // Count
override def bufferValueDataTypes: StructType = StructType.fromAttributes(bufferSchema)
/**
 * Writes the evaluated `initialValues` into this function's slots of `buffer`.
 *
 * Slot `k` of the buffer schema is stored at ordinal `bufferOffset + k`.
 */
override def initialize(buffer: MutableRow): Unit = {
  var slot = 0
  while (slot < bufferSchema.size) {
    val startValue = initialValues(slot).eval()
    buffer(bufferOffset + slot) = startValue
    slot += 1
  }
}

override def initialize(buffer: MutableRow): Unit =
updateBuffer(buffer, initialBufferValues)

private val inputLiteral =
MutableLiteral(null, child.dataType)
private val bufferedSum =
MutableLiteral(null, bufferValueDataTypes("Sum").dataType)
private val bufferedCount = MutableLiteral(null, LongType)
private val updateSum =
Add(Cast(inputLiteral, bufferValueDataTypes("Sum").dataType), bufferedSum)
private val inputBufferedSum =
MutableLiteral(null, bufferValueDataTypes("Sum").dataType)
private val mergeSum = Add(inputBufferedSum, bufferedSum)
private val evaluateAvg =
Cast(Divide(bufferedSum, Cast(bufferedCount, bufferValueDataTypes("Sum").dataType)), dataType)
/**
 * `updateExpressions` with attribute references resolved to ordinals in a row laid out as
 * the input columns followed by the buffer columns.
 *
 * NOTE(review): binding uses `inputSchema ++ bufferSchema` and does not involve
 * `bufferOffset` — confirm the ordinals still line up when `bufferOffset` is non-zero.
 */
lazy val boundUpdateExpressions = {
  val inputThenBuffer = inputSchema ++ bufferSchema
  updateExpressions.map(expr => BindReferences.bindReference(expr, inputThenBuffer)).toArray
}

val joinedRow = new JoinedRow
override def update(buffer: MutableRow, input: InternalRow): Unit = {
val newInput = child.eval(input)
println("newInput " + newInput)
if (newInput != null) {
inputLiteral.value = newInput
bufferedSum.value = buffer(bufferOffset)
val newSum = updateSum.eval(null)
val newCount = buffer.getLong(bufferOffset + 1) + 1
buffer.update(bufferOffset, newSum)
buffer.update(bufferOffset + 1, newCount)
var i = 0
while (i < bufferSchema.size) {
buffer(i + bufferOffset) = boundUpdateExpressions(i).eval(joinedRow(input, buffer))
i += 1
}
}

/**
 * `mergeExpressions` with attribute references resolved to ordinals in a row laid out as
 * the left buffer columns followed by the right buffer columns.
 */
lazy val boundMergeExpressions = {
  val leftThenRight = bufferSchema ++ rightBufferSchema
  mergeExpressions.map(expr => BindReferences.bindReference(expr, leftThenRight)).toArray
}
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
if (buffer2(bufferOffset + 1) != 0L) {
inputBufferedSum.value = buffer2(bufferOffset)
bufferedSum.value = buffer1(bufferOffset)
val newSum = mergeSum.eval(null)
val newCount =
buffer1.getLong(bufferOffset + 1) + buffer2.getLong(bufferOffset + 1)
buffer1.update(bufferOffset, newSum)
buffer1.update(bufferOffset + 1, newCount)
var i = 0
while (i < bufferSchema.size) {
buffer1(i + bufferOffset) = boundMergeExpressions(i).eval(joinedRow(buffer1, buffer2))
i += 1
}
}

lazy val boundEvaluateExpression = BindReferences.bindReference(evaluateExpression, bufferSchema)
override def eval(buffer: InternalRow): Any = {
if (buffer(bufferOffset + 1) == 0L) {
null
} else {
bufferedSum.value = buffer(bufferOffset)
bufferedCount.value = buffer.getLong(bufferOffset + 1)
evaluateAvg.eval(null)
}
boundEvaluateExpression.eval(buffer)
}
}

/**
 * Arithmetic mean of `child`, expressed declaratively as an [[AlgebraicAggregate]].
 *
 * The aggregation buffer holds two values: the running sum of the non-null inputs
 * (`currentSum`) and the number of non-null inputs seen so far (`currentCount`).
 * The final result is `currentSum / currentCount`.
 *
 * @param child expression whose values are averaged
 */
case class Average(child: Expression) extends AlgebraicAggregate {
  // Type of the final result: fixed-precision decimals gain 4 digits of precision and scale,
  // unlimited decimals stay unlimited, and every other input type is averaged as a double.
  val resultType = child.dataType match {
    case DecimalType.Fixed(precision, scale) =>
      DecimalType(precision + 4, scale + 4)
    case DecimalType.Unlimited => DecimalType.Unlimited
    case _ => DoubleType
  }

  // Type used to accumulate the sum: unlimited decimal for decimal inputs, double otherwise.
  val intermediateType = child.dataType match {
    case _ @ DecimalType() => DecimalType.Unlimited
    case _ => DoubleType
  }

  // The sum attribute must carry the same type as the values written into it
  // (see initialValues / updateExpressions), i.e. intermediateType rather than a fixed double.
  val currentSum = AttributeReference("currentSum", intermediateType)()
  val currentCount = AttributeReference("currentCount", LongType)()

  val bufferSchema = currentSum :: currentCount :: Nil

  val initialValues = Seq(
    /* currentSum = */ Cast(Literal(0), intermediateType),
    /* currentCount = */ Literal(0L)
  )

  val updateExpressions = Seq(
    /* currentSum = */
    // A null input contributes 0 to the sum.
    Add(
      currentSum,
      Coalesce(Cast(child, intermediateType) :: Cast(Literal(0), intermediateType) :: Nil)),
    // Only non-null inputs are counted; a null input leaves the count untouched.
    /* currentCount = */ If(IsNotNull(child), currentCount + 1L, currentCount)
  )

  val mergeExpressions = Seq(
    /* currentSum = */ currentSum.left + currentSum.right,
    /* currentCount = */ currentCount.left + currentCount.right
  )

  // average = sum / count, both cast to the result type before dividing.
  val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType)

  // When no non-null input was seen the count is 0; the division is then expected to
  // evaluate to null (assumes Divide yields null for a zero divisor — confirm), so the
  // result has to be declared nullable.
  override def nullable: Boolean = true
  override def dataType: DataType = resultType
  override def children: Seq[Expression] = child :: Nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ case class Aggregate2Sort(
child: SparkPlan)
extends UnaryNode {


override def requiredChildDistribution: List[Distribution] = {
if (preShuffle) {
UnspecifiedDistribution :: Nil
Expand All @@ -51,6 +52,7 @@ case class Aggregate2Sort(

protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
child.execute().mapPartitions { iter =>

new Iterator[InternalRow] {
private val aggregateFunctions: Array[AggregateFunction2] = {
var bufferOffset =
Expand All @@ -74,6 +76,10 @@ case class Aggregate2Sort(
i += 1
}

functions.foreach {
case ae: AlgebraicAggregate => ae.inputSchema = child.output
case _ =>
}
functions
}

Expand Down

0 comments on commit f7996d0

Please sign in to comment.