diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d01d0e13b2154..1ae1f5a3c7976 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -541,7 +541,6 @@ class Analyzer( def containsAggregates(exprs: Seq[Expression]): Boolean = { exprs.foreach(_.foreach { case agg: AggregateExpression => return true - case agg2: AggregateExpression2 => return true case _ => }) false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 602e0be6876ce..52b2644c0052c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -86,7 +86,6 @@ trait CheckAnalysis { case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK - case _: AggregateExpression2 => // OK case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( s"expression '${e.prettyString}' is neither present in the group by, " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 82e27bcc8ffa3..101679adf3837 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -89,7 +89,9 @@ abstract class Expression extends TreeNode[Expression] { val primitive = ctx.freshName("primitive") val ve = GeneratedExpressionCode("", isNull, primitive) ve.code = genCode(ctx, ve) - ve.copy(s"/* $this */\n" + ve.code) + // We may want to print out $this in the comment of generated code for debugging. + // ve.copy(s"/* $this */\n" + ve.code) + ve } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala index 76d21d65cf9c5..bc0a6a8e8a30f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala @@ -20,22 +20,22 @@ package org.apache.spark.sql.catalyst.expressions.aggregate2 import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -import org.apache.spark.sql.Row -/** The mode of an [[AggregateFunction]]. */ +/** The mode of an [[AggregateFunction1]]. */ private[sql] sealed trait AggregateMode /** - * An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation. + * An [[AggregateFunction1]] with [[Partial]] mode is used for partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the aggregation buffer is returned. */ private[sql] case object Partial extends AggregateMode /** - * An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers + * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers * containing intermediate results for this function. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the aggregation buffer is returned. @@ -43,7 +43,7 @@ private[sql] case object Partial extends AggregateMode private[sql] case object PartialMerge extends AggregateMode /** - * An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers + * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers * containing intermediate results for this function and the generate final result. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the final result of this function is returned. @@ -58,7 +58,7 @@ private[sql] case object Final extends AggregateMode */ private[sql] case object Complete extends AggregateMode -private[sql] case object NoOp extends Expression { +private[sql] case object NoOp extends Expression with Unevaluable { override def nullable: Boolean = true override def eval(input: InternalRow): Any = { throw new TreeNodeException( @@ -78,7 +78,7 @@ private[sql] case object NoOp extends Expression { private[sql] case class AggregateExpression2( aggregateFunction: AggregateFunction2, mode: AggregateMode, - isDistinct: Boolean) extends Expression { + isDistinct: Boolean) extends Expression with Unevaluable { override def children: Seq[Expression] = aggregateFunction :: Nil override def dataType: DataType = aggregateFunction.dataType @@ -86,11 +86,6 @@ private[sql] case class AggregateExpression2( override def nullable: Boolean = aggregateFunction.nullable override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)" - - override def eval(input: InternalRow = null): Any = { - throw new TreeNodeException( - this, s"No function to evaluate expression. type: ${this.nodeName}") - } } abstract class AggregateFunction2 @@ -136,6 +131,9 @@ abstract class AggregateFunction2 * and `buffer2`. */ def merge(buffer1: MutableRow, buffer2: InternalRow): Unit + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index d705a1286065c..e07c920a41d0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -27,7 +27,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet -trait AggregateExpression extends Expression with Unevaluable { +trait AggregateExpression extends Expression with Unevaluable + +trait AggregateExpression1 extends AggregateExpression { /** * Aggregate expressions should not be foldable. @@ -38,7 +40,7 @@ trait AggregateExpression extends Expression with Unevaluable { * Creates a new instance that can be used to compute this aggregate expression for a group * of input rows/ */ - def newInstance(): AggregateFunction + def newInstance(): AggregateFunction1 } /** @@ -54,10 +56,10 @@ case class SplitEvaluation( partialEvaluations: Seq[NamedExpression]) /** - * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples. + * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples. * These partial evaluations can then be combined to compute the actual answer. */ -trait PartialAggregate extends AggregateExpression { +trait PartialAggregate1 extends AggregateExpression1 { /** * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation. @@ -67,13 +69,13 @@ trait PartialAggregate extends AggregateExpression { /** * A specific implementation of an aggregate function. Used to wrap a generic - * [[AggregateExpression]] with an algorithm that will be used to compute one specific result. + * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result. */ -abstract class AggregateFunction - extends LeafExpression with AggregateExpression with Serializable { +abstract class AggregateFunction1 + extends LeafExpression with AggregateExpression1 with Serializable { /** Base should return the generic aggregate expression that this function is computing */ - val base: AggregateExpression + val base: AggregateExpression1 override def nullable: Boolean = base.nullable override def dataType: DataType = base.dataType @@ -81,12 +83,12 @@ abstract class AggregateFunction def update(input: InternalRow): Unit // Do we really need this? - override def newInstance(): AggregateFunction = { + override def newInstance(): AggregateFunction1 = { makeCopy(productIterator.map { case a: AnyRef => a }.toArray) } } -case class Min(child: Expression) extends UnaryExpression with PartialAggregate { +case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -102,7 +104,7 @@ case class Min(child: Expression) extends UnaryExpression with PartialAggregate TypeUtils.checkForOrderingExpr(child.dataType, "function min") } -case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType) @@ -119,7 +121,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr override def eval(input: InternalRow): Any = currentMin.value } -case class Max(child: Expression) extends UnaryExpression with PartialAggregate { +case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -135,7 +137,7 @@ case class Max(child: Expression) extends UnaryExpression with PartialAggregate TypeUtils.checkForOrderingExpr(child.dataType, "function max") } -case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType) @@ -152,7 +154,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr override def eval(input: InternalRow): Any = currentMax.value } -case class Count(child: Expression) extends UnaryExpression with PartialAggregate { +case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -165,7 +167,7 @@ case class Count(child: Expression) extends UnaryExpression with PartialAggregat override def newInstance(): CountFunction = new CountFunction(child, this) } -case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. var count: Long = _ @@ -180,7 +182,7 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag override def eval(input: InternalRow): Any = count } -case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { +case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 { def this() = this(null) override def children: Seq[Expression] = expressions @@ -200,8 +202,8 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate case class CountDistinctFunction( @transient expr: Seq[Expression], - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -220,7 +222,7 @@ case class CountDistinctFunction( override def eval(input: InternalRow): Any = seen.size.toLong } -case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { +case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 { def this() = this(null) override def children: Seq[Expression] = expressions @@ -233,8 +235,8 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress case class CollectHashSetFunction( @transient expr: Seq[Expression], - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -255,7 +257,7 @@ case class CollectHashSetFunction( } } -case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression { +case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 { def this() = this(null) override def children: Seq[Expression] = inputSet :: Nil @@ -269,8 +271,8 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression case class CombineSetsAndCountFunction( @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -305,7 +307,7 @@ private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] { } case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression { + extends UnaryExpression with AggregateExpression1 { override def nullable: Boolean = false override def dataType: DataType = HyperLogLogUDT @@ -317,9 +319,9 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) case class ApproxCountDistinctPartitionFunction( expr: Expression, - base: AggregateExpression, + base: AggregateExpression1, relativeSD: Double) - extends AggregateFunction { + extends AggregateFunction1 { def this() = this(null, null, 0) // Required for serialization. private val hyperLogLog = new HyperLogLog(relativeSD) @@ -335,7 +337,7 @@ case class ApproxCountDistinctPartitionFunction( } case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression { + extends UnaryExpression with AggregateExpression1 { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -347,9 +349,9 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) case class ApproxCountDistinctMergeFunction( expr: Expression, - base: AggregateExpression, + base: AggregateExpression1, relativeSD: Double) - extends AggregateFunction { + extends AggregateFunction1 { def this() = this(null, null, 0) // Required for serialization. private val hyperLogLog = new HyperLogLog(relativeSD) @@ -363,7 +365,7 @@ case class ApproxCountDistinctMergeFunction( } case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) - extends UnaryExpression with PartialAggregate { + extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -381,7 +383,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) } -case class Average(child: Expression) extends UnaryExpression with PartialAggregate { +case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 { override def prettyName: String = "avg" @@ -427,8 +429,8 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg TypeUtils.checkForNumericExpr(child.dataType, "function average") } -case class AverageFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class AverageFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -474,7 +476,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) } } -case class Sum(child: Expression) extends UnaryExpression with PartialAggregate { +case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true @@ -509,7 +511,7 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate TypeUtils.checkForNumericExpr(child.dataType, "function sum") } -case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. private val calcType = @@ -554,7 +556,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr * <-- null <-- no data * null <-- null <-- no data */ -case class CombineSum(child: Expression) extends AggregateExpression { +case class CombineSum(child: Expression) extends AggregateExpression1 { def this() = this(null) override def children: Seq[Expression] = child :: Nil @@ -564,8 +566,8 @@ case class CombineSum(child: Expression) extends AggregateExpression { override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) } -case class CombineSumFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class CombineSumFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -601,7 +603,7 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) } } -case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate { +case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 { def this() = this(null) override def nullable: Boolean = true @@ -627,8 +629,8 @@ case class SumDistinct(child: Expression) extends UnaryExpression with PartialAg TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct") } -case class SumDistinctFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class SumDistinctFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -653,7 +655,7 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) } } -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { +case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 { def this() = this(null, null) override def children: Seq[Expression] = inputSet :: Nil @@ -667,8 +669,8 @@ case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends Agg case class CombineSetsAndSumFunction( @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -695,7 +697,7 @@ case class CombineSetsAndSumFunction( } } -case class First(child: Expression) extends UnaryExpression with PartialAggregate { +case class First(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true override def dataType: DataType = child.dataType override def toString: String = s"FIRST($child)" @@ -709,7 +711,7 @@ case class First(child: Expression) extends UnaryExpression with PartialAggregat override def newInstance(): FirstFunction = new FirstFunction(child, this) } -case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. var result: Any = null @@ -723,7 +725,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag override def eval(input: InternalRow): Any = result } -case class Last(child: Expression) extends UnaryExpression with PartialAggregate { +case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 { override def references: AttributeSet = child.references override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -738,7 +740,7 @@ case class Last(child: Expression) extends UnaryExpression with PartialAggregate override def newInstance(): LastFunction = new LastFunction(child, this) } -case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. var result: Any = null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 179a348d5baac..b8e3b0d53a505 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -129,10 +129,10 @@ object PartialAggregation { case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => // Collect all aggregate expressions. val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) + aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a}) // Collect all aggregate expressions that can be computed partially. val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p}) + aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p}) // Only do partial aggregation if supported by all aggregate expressions. if (allAggregates.size == partialAggregates.size) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5074e6bcef772..f5aeff4b16855 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -29,7 +29,6 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend override lazy val resolved: Boolean = { val hasSpecialExpressions = projectList.exists ( _.collect { case agg: AggregateExpression => agg - case agg: AggregateExpression2 => agg case generator: Generator => generator case window: WindowExpression => window }.nonEmpty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 3cd60a2aa55ed..c2c945321db95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -68,14 +68,14 @@ case class Aggregate( * output. */ case class ComputedAggregate( - unbound: AggregateExpression, - aggregate: AggregateExpression, + unbound: AggregateExpression1, + aggregate: AggregateExpression1, resultAttribute: AttributeReference) /** A list of aggregates that need to be computed for each group. */ private[this] val computedAggregates = aggregateExpressions.flatMap { agg => agg.collect { - case a: AggregateExpression => + case a: AggregateExpression1 => ComputedAggregate( a, BindReferences.bindReference(a, child.output), @@ -87,8 +87,8 @@ case class Aggregate( private[this] val computedSchema = computedAggregates.map(_.resultAttribute) /** Creates a new aggregate buffer for a group. */ - private[this] def newAggregateBuffer(): Array[AggregateFunction] = { - val buffer = new Array[AggregateFunction](computedAggregates.length) + private[this] def newAggregateBuffer(): Array[AggregateFunction1] = { + val buffer = new Array[AggregateFunction1](computedAggregates.length) var i = 0 while (i < computedAggregates.length) { buffer(i) = computedAggregates(i).aggregate.newInstance() @@ -146,7 +146,7 @@ case class Aggregate( } } else { child.execute().mapPartitions { iter => - val hashTable = new HashMap[InternalRow, Array[AggregateFunction]] + val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]] val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output) var currentRow: InternalRow = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index c069da016f9f0..d6b264105d6c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -69,7 +69,7 @@ case class GeneratedAggregate( protected override def doExecute(): RDD[InternalRow] = { val aggregatesToCompute = aggregateExpressions.flatMap { a => - a.collect { case agg: AggregateExpression => agg} + a.collect { case agg: AggregateExpression1 => agg} } // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index e7bb2bff47ff9..70197993730f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -184,7 +184,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => Nil } - def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists { + def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists { case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && @@ -192,8 +192,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => true } - def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression] = - exprs.flatMap(_.collect { case a: AggregateExpression => a }) + def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] = + exprs.flatMap(_.collect { case a: AggregateExpression1 => a }) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/rules.scala index ea082b5dcc5aa..c1c7869fdb426 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/rules.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate2 import org.apache.spark.sql.{SQLConf, AnalysisException, SQLContext} -import org.apache.spark.sql.catalyst.expressions.{Average => Average1, AggregateExpression} +import org.apache.spark.sql.catalyst.expressions.{Average => Average1, AggregateExpression1} import org.apache.spark.sql.catalyst.expressions.aggregate2.{Average => Average2, AggregateExpression2, Complete} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -38,7 +38,7 @@ case class CheckAggregateFunction(context: SQLContext) extends (LogicalPlan => U def apply(plan: LogicalPlan): Unit = plan.foreachUp { case p if context.conf.useSqlAggregate2 => p.transformExpressionsUp { - case agg: AggregateExpression => + case agg: AggregateExpression1 => failAnalysis(s"${SQLConf.USE_SQL_AGGREGATE2} is enabled. Please disable it to use $agg.") } case p if !context.conf.useSqlAggregate2 => p.transformExpressionsUp { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 4d23c7035c03d..3259b50acc765 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -409,7 +409,7 @@ private[hive] case class HiveWindowFunction( private[hive] case class HiveGenericUDAF( funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression + children: Seq[Expression]) extends AggregateExpression1 with HiveInspectors { type UDFType = AbstractGenericUDAFResolver @@ -441,7 +441,7 @@ private[hive] case class HiveGenericUDAF( /** It is used as a wrapper for the hive functions which uses UDAF interface */ private[hive] case class HiveUDAF( funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression + children: Seq[Expression]) extends AggregateExpression1 with HiveInspectors { type UDFType = UDAF @@ -550,9 +550,9 @@ private[hive] case class HiveGenericUDTF( private[hive] case class HiveUDAFFunction( funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], - base: AggregateExpression, + base: AggregateExpression1, isUDAFBridgeRequired: Boolean = false) - extends AggregateFunction + extends AggregateFunction1 with HiveInspectors { def this() = this(null, null, null)