
Commit 39ee975
Code cleanup: Remove unnecessary AttributeReferences.
yhuai committed Jul 13, 2015
1 parent b7720ba commit 39ee975
Showing 3 changed files with 58 additions and 45 deletions.
@@ -63,6 +63,9 @@ private[sql] case class AggregateExpression2(

   override def eval(input: InternalRow = null): Any =
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+
+  def bufferSchema: StructType = aggregateFunction.bufferSchema
+  def bufferAttributes: Seq[Attribute] = aggregateFunction.bufferAttributes
 }

 abstract class AggregateFunction2
@@ -77,7 +80,11 @@ abstract class AggregateFunction2
     this
   }

-  def bufferValueDataTypes: StructType
+  /** The schema of the aggregation buffer. */
+  def bufferSchema: StructType
+
+  /** Attributes of fields in bufferSchema. */
+  def bufferAttributes: Seq[Attribute]

   def initialize(buffer: MutableRow): Unit

@@ -94,7 +101,6 @@ abstract class AggregateFunction2
 abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
   self: Product =>

-  val bufferSchema: Seq[Attribute]
   val initialValues: Seq[Expression]
   val updateExpressions: Seq[Expression]
   val mergeExpressions: Seq[Expression]
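
With the old attribute list gone from AlgebraicAggregate, only bufferAttributes is declared there (bufferSchema is derived, see the hunk below), while a plain AggregateFunction2 must now supply both members of the contract itself. A minimal imperative sketch of that contract (CountNotNull is hypothetical and not part of this commit; it assumes child has already been bound to the input schema, and uses only the signatures visible in this diff):

case class CountNotNull(child: Expression) extends AggregateFunction2 {
  private val count = AttributeReference("count", LongType)()

  override val bufferAttributes: Seq[Attribute] = count :: Nil
  override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)

  // bufferOffset is inherited from AggregateFunction2; its use here mirrors
  // buffer(i + bufferOffset) in AlgebraicAggregate below.
  override def initialize(buffer: MutableRow): Unit =
    buffer(bufferOffset) = 0L

  override def update(buffer: MutableRow, input: InternalRow): Unit =
    if (child.eval(input) != null) {
      buffer(bufferOffset) = buffer.getLong(bufferOffset) + 1L
    }

  // Assumes buffer2 carries the same offset layout as buffer1, matching the
  // mergeSchema built in AlgebraicAggregate.boundMergeExpressions.
  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit =
    buffer1(bufferOffset) = buffer1.getLong(bufferOffset) + buffer2.getLong(bufferOffset)

  override def eval(buffer: InternalRow): Any = buffer.getLong(bufferOffset)
}
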
@@ -105,23 +111,25 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{

   def offsetExpressions: Seq[Attribute] = Seq.fill(bufferOffset)(AttributeReference("offset", NullType)())

-  lazy val rightBufferSchema = bufferSchema.map(_.newInstance())
+  lazy val rightBufferSchema = bufferAttributes.map(_.newInstance())
   implicit class RichAttribute(a: AttributeReference) {
     def left = a
-    def right = rightBufferSchema(bufferSchema.indexOf(a))
+    def right = rightBufferSchema(bufferAttributes.indexOf(a))
   }

-  override def bufferValueDataTypes: StructType = StructType.fromAttributes(bufferSchema)
+  /** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */
+  override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)

   override def initialize(buffer: MutableRow): Unit = {
     var i = 0
-    while (i < bufferSchema.size) {
+    while (i < bufferAttributes.size) {
       buffer(i + bufferOffset) = initialValues(i).eval()
       i += 1
     }
   }

   lazy val boundUpdateExpressions = {
-    val updateSchema = inputSchema ++ offsetExpressions ++ bufferSchema
+    val updateSchema = inputSchema ++ offsetExpressions ++ bufferAttributes
     val bound = updateExpressions.map(BindReferences.bindReference(_, updateSchema)).toArray
     println(s"update: ${updateExpressions.mkString(",")}")
     println(s"update: ${bound.mkString(",")}")
@@ -131,29 +139,29 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable{
   val joinedRow = new JoinedRow
   override def update(buffer: MutableRow, input: InternalRow): Unit = {
     var i = 0
-    while (i < bufferSchema.size) {
+    while (i < bufferAttributes.size) {
       buffer(i + bufferOffset) = boundUpdateExpressions(i).eval(joinedRow(input, buffer))
       i += 1
     }
   }

   lazy val boundMergeExpressions = {
-    val mergeSchema = offsetExpressions ++ bufferSchema ++ offsetExpressions ++ rightBufferSchema
+    val mergeSchema = offsetExpressions ++ bufferAttributes ++ offsetExpressions ++ rightBufferSchema
     mergeExpressions.map(BindReferences.bindReference(_, mergeSchema)).toArray
   }
   override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
     var i = 0
     println(s"Merging: $buffer1 $buffer2 with ${boundMergeExpressions.mkString(",")}")
     joinedRow(buffer1, buffer2)
-    while (i < bufferSchema.size) {
+    while (i < bufferAttributes.size) {
       println(s"$i + $bufferOffset: ${boundMergeExpressions(i).eval(joinedRow)}")
       buffer1(i + bufferOffset) = boundMergeExpressions(i).eval(joinedRow)
       i += 1
     }
   }

   lazy val boundEvaluateExpression =
-    BindReferences.bindReference(evaluateExpression, offsetExpressions ++ bufferSchema)
+    BindReferences.bindReference(evaluateExpression, offsetExpressions ++ bufferAttributes)
   override def eval(buffer: InternalRow): Any = {
     println(s"eval: $buffer")
     val res = boundEvaluateExpression.eval(buffer)
@@ -170,26 +178,26 @@ case class Average(child: Expression) extends AlgebraicAggregate {
     case _ => DoubleType
   }

-  val intermediateType = child.dataType match {
+  val sumDataType = child.dataType match {
     case _ @ DecimalType() => DecimalType.Unlimited
     case _ => DoubleType
   }

-  val currentSum = AttributeReference("currentSum", DoubleType)()
+  val currentSum = AttributeReference("currentSum", sumDataType)()
   val currentCount = AttributeReference("currentCount", LongType)()

-  val bufferSchema = currentSum :: currentCount :: Nil
+  override val bufferAttributes = currentSum :: currentCount :: Nil

   val initialValues = Seq(
-    /* currentSum = */ Cast(Literal(0), intermediateType),
+    /* currentSum = */ Cast(Literal(0), sumDataType),
     /* currentCount = */ Literal(0L)
   )

   val updateExpressions = Seq(
     /* currentSum = */
     Add(
       currentSum,
-      Coalesce(Cast(child, intermediateType) :: Cast(Literal(0), intermediateType) :: Nil)),
+      Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
     /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
   )

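And for contrast with the imperative sketch above, the algebraic flavor of the same contract, modeled directly on the Average change in this file (MyCount2 is illustrative and not part of this commit; left and right come from the RichAttribute helper shown above):

case class MyCount2(child: Expression) extends AlgebraicAggregate {
  val currentCount = AttributeReference("currentCount", LongType)()

  // The only buffer declaration an AlgebraicAggregate needs after this
  // cleanup; bufferSchema is derived via StructType.fromAttributes.
  override val bufferAttributes = currentCount :: Nil

  val initialValues = Seq(/* currentCount = */ Literal(0L))

  val updateExpressions = Seq(
    /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L))

  // left addresses this buffer, right the incoming buffer being merged.
  val mergeExpressions = Seq(
    /* currentCount = */ Add(currentCount.left, currentCount.right))

  val evaluateExpression = currentCount
}
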
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
+import org.apache.spark.sql.execution.aggregate2.Aggregate2Sort
 import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
 import org.apache.spark.sql.parquet._
 import org.apache.spark.sql.sources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
@@ -186,67 +187,71 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     exprs.flatMap(_.collect { case a: AggregateExpression => a })
   }

+  /**
+   * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
+   */
   object AggregateOperator2 extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Aggregate(groupingExpressions, resultExpressions, child)
         if sqlContext.conf.useSqlAggregate2 =>
-        // 1. Extracts all aggregate expressions.
+        // 1. Extracts all distinct aggregate expressions from the resultExpressions.
         val aggregateExpressions = resultExpressions.flatMap { expr =>
           expr.collect {
             case agg: AggregateExpression2 => agg
           }
         }.toSet.toSeq
-        val aggregateExpressionMap = aggregateExpressions.zipWithIndex.map {
-          case (agg, index) =>
-            agg.aggregateFunction -> Alias(agg, s"_agg$index")().toAttribute
+        // For those distinct aggregate expressions, we create a map from the aggregate function
+        // to the corresponding attribute of the function.
+        val aggregateFunctionMap = aggregateExpressions.map { agg =>
+          val aggregateFunction = agg.aggregateFunction
+          aggregateFunction -> Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
         }.toMap

-        // 2. Create Pre-shuffle Aggregate Operator
-        val namedGroupingExpressions = groupingExpressions.zipWithIndex.map {
-          case (ne: NamedExpression, index) => ne
-          case (other, index) => Alias(other, s"_groupingExpr$index")()
+        // 2. Create an Aggregate Operator for partial aggregations.
+        val namedGroupingExpressions = groupingExpressions.map {
+          case ne: NamedExpression => ne
+          // If the expression is not a NamedExpressions, we add an alias.
+          // So, when we generate the result of the operator, the Aggregate Operator
+          // can directly get the Seq of attributes representing the grouping expressions.
+          case other => Alias(other, other.toString)()
         }
         val namedGroupingAttributes = namedGroupingExpressions.map(_.toAttribute)
-        val preShuffleAggregateExpressions = aggregateExpressions.map {
+        val partialAggregateExpressions = aggregateExpressions.map {
           case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
             AggregateExpression2(aggregateFunction, Partial, isDistinct)
         }
-        val preShuffleAggregateAttributes = preShuffleAggregateExpressions.zipWithIndex.flatMap {
-          case (AggregateExpression2(aggregateFunction, Partial, isDistinct), index) =>
-            aggregateFunction.bufferValueDataTypes.map {
-              case StructField(name, dataType, nullable, metadata) =>
-                AttributeReference(s"_partialAgg${index}_${name}", dataType, nullable, metadata)()
-            }
+        val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
+          agg.bufferAttributes
         }
         val partialAggregate =
           Aggregate2Sort(
             true,
             namedGroupingExpressions,
-            preShuffleAggregateExpressions,
-            preShuffleAggregateAttributes,
-            namedGroupingAttributes ++ preShuffleAggregateAttributes,
+            partialAggregateExpressions,
+            partialAggregateAttributes,
+            namedGroupingAttributes ++ partialAggregateAttributes,
             planLater(child))

-        // 3. Create post-shuffle Aggregate Operator.
-        val postShuffleAggregateExpressions = aggregateExpressions.map {
+        // 3. Create an Aggregate Operator for final aggregations.
+        val finalAggregateExpressions = aggregateExpressions.map {
           case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
             AggregateExpression2(aggregateFunction, Final, isDistinct)
         }
-        val postShuffleAggregateAttributes =
-          postShuffleAggregateExpressions.map {
-            expr => aggregateExpressionMap(expr.aggregateFunction)
+        val finalAggregateAttributes =
+          finalAggregateExpressions.map {
+            expr => aggregateFunctionMap(expr.aggregateFunction)
           }
         val rewrittenResultExpressions = resultExpressions.map { expr =>
           expr.transform {
             case agg: AggregateExpression2 =>
-              aggregateExpressionMap(agg.aggregateFunction).toAttribute
+              aggregateFunctionMap(agg.aggregateFunction).toAttribute
           }.asInstanceOf[NamedExpression]
         }
         val finalAggregate = Aggregate2Sort(
           false,
           namedGroupingAttributes,
-          postShuffleAggregateExpressions,
-          postShuffleAggregateAttributes,
+          finalAggregateExpressions,
+          finalAggregateAttributes,
           rewrittenResultExpressions,
           partialAggregate)

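The net effect of this strategy is the classic two-phase aggregation: a Partial operator folds each partition's rows into (currentSum, currentCount) buffers, and a Final operator merges those buffers and evaluates the result expression. A self-contained, plain-Scala analogy of that dataflow for Average (illustrative only; it is not Spark code and elides grouping keys):

object TwoPhaseAvg {
  type Buffer = (Double, Long) // (currentSum, currentCount), as in Average.bufferAttributes

  // Partial mode: mirrors Average's null-skipping updateExpressions.
  def update(b: Buffer, value: Option[Double]): Buffer =
    value.fold(b)(v => (b._1 + v, b._2 + 1L))

  // Final mode: merge two partial buffers field by field.
  def merge(b1: Buffer, b2: Buffer): Buffer = (b1._1 + b2._1, b1._2 + b2._2)

  // evaluateExpression: sum / count.
  def eval(b: Buffer): Option[Double] = if (b._2 == 0L) None else Some(b._1 / b._2)

  def main(args: Array[String]): Unit = {
    val partition1 = Seq(Some(1.0), None, Some(3.0))
    val partition2 = Seq(Some(5.0))
    // Each partition folds into its own buffer; the buffers are then merged.
    val partials = Seq(partition1, partition2).map(_.foldLeft((0.0, 0L))(update))
    println(eval(partials.reduce(merge))) // Some(3.0)
  }
}
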
@@ -49,22 +49,22 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
   test("test average2") {
     ctx.sql(
       """
-        |SELECT key, avg2(value)
+        |SELECT key, avg(value)
         |FROM agg2
         |GROUP BY key
       """.stripMargin).explain(true)

     ctx.sql(
       """
-        |SELECT key, avg2(value)
+        |SELECT key, avg(value)
         |FROM agg2
         |GROUP BY key
       """.stripMargin).queryExecution.executedPlan(3).execute().collect().foreach(println)

     checkAnswer(
       ctx.sql(
         """
-          |SELECT key, avg2(value)
+          |SELECT key, avg(value)
           |FROM agg2
           |GROUP BY key
         """.stripMargin),
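For reference, the same query could be written with the DataFrame API (a sketch; the suite itself only exercises SQL strings, and whether this route resolves avg to the new AggregateExpression2 path is not shown by this diff):

import org.apache.spark.sql.functions.avg

val byKey = ctx.table("agg2").groupBy("key").agg(avg("value"))
byKey.explain(true)
byKey.collect().foreach(println)
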
