Skip to content

Commit

Permalink
Do not bind references in AlgebraicAggregate and use code gen for all…
Browse files Browse the repository at this point in the history
… places.
  • Loading branch information
yhuai committed Jul 14, 2015
1 parent 072209f commit 1b0bb3f
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,45 +128,19 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
}
}

lazy val boundUpdateExpressions = {
val updateSchema = inputSchema ++ offsetExpressions ++ bufferAttributes
val bound = updateExpressions.map(BindReferences.bindReference(_, updateSchema)).toArray
println(s"update: ${updateExpressions.mkString(",")}")
println(s"update: ${bound.mkString(",")}")
bound
}

val joinedRow = new JoinedRow
override def update(buffer: MutableRow, input: InternalRow): Unit = {
var i = 0
while (i < bufferAttributes.size) {
buffer(i + bufferOffset) = boundUpdateExpressions(i).eval(joinedRow(input, buffer))
i += 1
}
throw new UnsupportedOperationException(
"AlgebraicAggregate's update should not be called directly")
}

lazy val boundMergeExpressions = {
val mergeSchema = offsetExpressions ++ bufferAttributes ++ offsetExpressions ++ rightBufferSchema
mergeExpressions.map(BindReferences.bindReference(_, mergeSchema)).toArray
}
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
var i = 0
println(s"Merging: $buffer1 $buffer2 with ${boundMergeExpressions.mkString(",")}")
joinedRow(buffer1, buffer2)
while (i < bufferAttributes.size) {
println(s"$i + $bufferOffset: ${boundMergeExpressions(i).eval(joinedRow)}")
buffer1(i + bufferOffset) = boundMergeExpressions(i).eval(joinedRow)
i += 1
}
throw new UnsupportedOperationException(
"AlgebraicAggregate's merge should not be called directly")
}

lazy val boundEvaluateExpression =
BindReferences.bindReference(evaluateExpression, offsetExpressions ++ bufferAttributes)
override def eval(buffer: InternalRow): Any = {
println(s"eval: $buffer")
val res = boundEvaluateExpression.eval(buffer)
println(s"eval: $buffer with $boundEvaluateExpression => $res")
res
throw new UnsupportedOperationException(
"AlgebraicAggregate's eval should not be called directly")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ case class Aggregate2Sort(
private val buffer: MutableRow = new GenericMutableRow(bufferSize)
private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length)
private val joinedRow = new JoinedRow4
private val resultProjection =
new InterpretedMutableProjection(
resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)
private lazy val resultProjection =
newMutableProjection(
resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)()

val offsetAttributes = if (preShuffle) Nil else Seq.fill(groupingExpressions.length)(AttributeReference("offset", NullType)())
val offsetExpressions = if (preShuffle) Nil else Seq.fill(groupingExpressions.length)(NoOp)
Expand All @@ -128,7 +128,7 @@ case class Aggregate2Sort(
val initExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.initialValues
}
println(initExpressions.mkString(","))
// println(initExpressions.mkString(","))
newMutableProjection(initExpressions, Nil)().target(buffer)
}

Expand All @@ -140,24 +140,38 @@ case class Aggregate2Sort(
case ae: AlgebraicAggregate => ae.updateExpressions
}

println(updateExpressions.mkString(","))
// println(updateExpressions.mkString(","))
newMutableProjection(updateExpressions, bufferSchema ++ child.output)().target(buffer)
}

val mergeProjection = {
lazy val mergeProjection = {
val bufferSchemata =
offsetAttributes ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.bufferAttributes
} ++ offsetAttributes ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.rightBufferSchema
}
val mergeExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.mergeExpressions
}
val mergeExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.mergeExpressions
}

newMutableProjection(mergeExpressions, bufferSchemata)()
}

lazy val evalProjection = {
val bufferSchemata =
offsetAttributes ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.bufferAttributes
} ++ offsetAttributes ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.rightBufferSchema
}
val evalExpressions = aggregateFunctions.map {
case ae: AlgebraicAggregate => ae.evaluateExpression
}

newMutableProjection(evalExpressions, bufferSchemata)()
}

// Initialize this iterator.
initialize()

Expand All @@ -177,7 +191,7 @@ case class Aggregate2Sort(

private def initializeBuffer(): Unit = {
initialProjection(EmptyRow)
println("initilized: " + buffer)
// println("initilized: " + buffer)
}

private def processRow(row: InternalRow): Unit = {
Expand Down Expand Up @@ -230,16 +244,20 @@ case class Aggregate2Sort(
// If it is preShuffle, we just output the grouping columns and the buffer.
joinedRow(currentGroupingKey, buffer).copy()
} else {
/*
var i = 0
while (i < aggregateFunctions.length) {
aggregateResult.update(i, aggregateFunctions(i).eval(buffer))
i += 1
}
resultProjection(joinedRow(currentGroupingKey, aggregateResult)).copy()
*/
resultProjection(joinedRow(currentGroupingKey, evalProjection.target(aggregateResult)(buffer)))

}
initializeBuffer()

println(s"outputRow $preShuffle " + outputRow)
// println(s"outputRow $preShuffle " + outputRow)
outputRow
} else {
// no more result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,24 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
""".stripMargin),
Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Nil)

checkAnswer(
ctx.sql(
"""
|SELECT avg(value), key
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Nil)

checkAnswer(
ctx.sql(
"""
|SELECT avg(value) + 1.5, key + 10
|FROM agg2
|GROUP BY key + 10
""".stripMargin),
Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Nil)

}

override def afterAll(): Unit = {
Expand Down

0 comments on commit 1b0bb3f

Please sign in to comment.