Skip to content

Commit

Permalink
now with correct answers\!
Browse files Browse the repository at this point in the history
  • Loading branch information
marmbrus committed Jul 10, 2015
1 parent f7996d0 commit 6bbc6ba
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
/** Must be filled in by the executors */
var inputSchema: Seq[Attribute] = _

def offsetExpressions: Seq[Attribute] = Seq.fill(bufferOffset)(AttributeReference("offset", NullType)())

lazy val rightBufferSchema = bufferSchema.map(_.newInstance())
implicit class RichAttribute(a: AttributeReference) {
def left = a
Expand All @@ -112,8 +114,11 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
}

lazy val boundUpdateExpressions = {
val updateSchema = inputSchema ++ bufferSchema
updateExpressions.map(BindReferences.bindReference(_, updateSchema)).toArray
val updateSchema = inputSchema ++ offsetExpressions ++ bufferSchema
val bound = updateExpressions.map(BindReferences.bindReference(_, updateSchema)).toArray
println(s"update: ${updateExpressions.mkString(",")}")
println(s"update: ${bound.mkString(",")}")
bound
}

val joinedRow = new JoinedRow
Expand All @@ -126,20 +131,27 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
}

lazy val boundMergeExpressions = {
val mergeSchema = bufferSchema ++ rightBufferSchema
val mergeSchema = offsetExpressions ++ bufferSchema ++ offsetExpressions ++ rightBufferSchema
mergeExpressions.map(BindReferences.bindReference(_, mergeSchema)).toArray
}
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
var i = 0
println(s"Merging: $buffer1 $buffer2 with ${boundMergeExpressions.mkString(",")}")
joinedRow(buffer1, buffer2)
while (i < bufferSchema.size) {
buffer1(i + bufferOffset) = boundMergeExpressions(i).eval(joinedRow(buffer1, buffer2))
println(s"$i + $bufferOffset: ${boundMergeExpressions(i).eval(joinedRow)}")
buffer1(i + bufferOffset) = boundMergeExpressions(i).eval(joinedRow)
i += 1
}
}

lazy val boundEvaluateExpression = BindReferences.bindReference(evaluateExpression, bufferSchema)
lazy val boundEvaluateExpression =
BindReferences.bindReference(evaluateExpression, offsetExpressions ++ bufferSchema)
override def eval(buffer: InternalRow): Any = {
boundEvaluateExpression.eval(buffer)
println(s"eval: $buffer")
val res = boundEvaluateExpression.eval(buffer)
println(s"eval: $buffer with $boundEvaluateExpression => $res")
res
}
}

Expand Down Expand Up @@ -171,15 +183,15 @@ case class Average(child: Expression) extends AlgebraicAggregate {
Add(
currentSum,
Coalesce(Cast(child, intermediateType) :: Cast(Literal(0), intermediateType) :: Nil)),
/* currentCount = */ If(IsNotNull(child), currentCount, currentCount + 1L)
/* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
)

val mergeExpressions = Seq(
/* currentSum = */ currentSum.left + currentSum.right,
/* currentCount = */ currentCount.left + currentCount.right
)

val evaluateExpression = Cast(currentCount, resultType) / Cast(currentSum, resultType)
val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType)

override def nullable: Boolean = false
override def dataType: DataType = resultType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,15 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
|GROUP BY key
""".stripMargin).queryExecution.executedPlan(3).execute().collect().foreach(println)

ctx.sql(
"""
|SELECT key, avg2(value)
|FROM agg2
|GROUP BY key
""".stripMargin).show()
checkAnswer(
ctx.sql(
"""
|SELECT key, avg2(value)
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Nil)

}

override def afterAll(): Unit = {
Expand Down

0 comments on commit 6bbc6ba

Please sign in to comment.