Skip to content

Commit

Permalink
Change existing AggregateExpression to AggregateExpression1 and add a…
Browse files Browse the repository at this point in the history
…n AggregateExpression as the common interface for both AggregateExpression1 and AggregateExpression2.
  • Loading branch information
yhuai committed Jul 20, 2015
1 parent 380880f commit 594cdf5
Show file tree
Hide file tree
Showing 12 changed files with 83 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,6 @@ class Analyzer(
def containsAggregates(exprs: Seq[Expression]): Boolean = {
exprs.foreach(_.foreach {
case agg: AggregateExpression => return true
case agg2: AggregateExpression2 => return true
case _ =>
})
false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ trait CheckAnalysis {
case Aggregate(groupingExprs, aggregateExprs, child) =>
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
case _: AggregateExpression => // OK
case _: AggregateExpression2 => // OK
case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
failAnalysis(
s"expression '${e.prettyString}' is neither present in the group by, " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ abstract class Expression extends TreeNode[Expression] {
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
ve.code = genCode(ctx, ve)
ve.copy(s"/* $this */\n" + ve.code)
// We may want to print out $this in the comment of generated code for debugging.
// ve.copy(s"/* $this */\n" + ve.code)
ve
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,30 @@ package org.apache.spark.sql.catalyst.expressions.aggregate2
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

/** The mode of an [[AggregateFunction]]. */
/** The mode of an [[AggregateFunction1]]. */
private[sql] sealed trait AggregateMode

/**
* An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation.
* An [[AggregateFunction1]] with [[Partial]] mode is used for partial aggregation.
* This function updates the given aggregation buffer with the original input of this
* function. When it has processed all input rows, the aggregation buffer is returned.
*/
private[sql] case object Partial extends AggregateMode

/**
* An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers
* An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers
* containing intermediate results for this function.
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
* When it has processed all input rows, the aggregation buffer is returned.
*/
private[sql] case object PartialMerge extends AggregateMode

/**
* An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers
* An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers
* containing intermediate results for this function and the generate final result.
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
* When it has processed all input rows, the final result of this function is returned.
Expand All @@ -58,7 +58,7 @@ private[sql] case object Final extends AggregateMode
*/
private[sql] case object Complete extends AggregateMode

private[sql] case object NoOp extends Expression {
private[sql] case object NoOp extends Expression with Unevaluable {
override def nullable: Boolean = true
override def eval(input: InternalRow): Any = {
throw new TreeNodeException(
Expand All @@ -78,19 +78,14 @@ private[sql] case object NoOp extends Expression {
private[sql] case class AggregateExpression2(
aggregateFunction: AggregateFunction2,
mode: AggregateMode,
isDistinct: Boolean) extends Expression {
isDistinct: Boolean) extends Expression with Unevaluable {

override def children: Seq[Expression] = aggregateFunction :: Nil
override def dataType: DataType = aggregateFunction.dataType
override def foldable: Boolean = false
override def nullable: Boolean = aggregateFunction.nullable

override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"

override def eval(input: InternalRow = null): Any = {
throw new TreeNodeException(
this, s"No function to evaluate expression. type: ${this.nodeName}")
}
}

abstract class AggregateFunction2
Expand Down Expand Up @@ -136,6 +131,9 @@ abstract class AggregateFunction2
* and `buffer2`.
*/
def merge(buffer1: MutableRow, buffer2: InternalRow): Unit

override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
}

/**
Expand Down
Loading

0 comments on commit 594cdf5

Please sign in to comment.