Skip to content

Commit

Permalink
Support single distinct column set. WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
yhuai committed Jul 20, 2015
1 parent 3013579 commit 68b8ee9
Show file tree
Hide file tree
Showing 12 changed files with 680 additions and 205 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,12 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
}
}
| ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
{ case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
{ case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) }
| ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs =>
lexical.normalizeKeyword(udfName) match {
case "sum" => SumDistinct(exprs.head)
case "count" => CountDistinct(exprs)
case name => UnresolvedFunction(name, exprs, isDistinct = true)
case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.aggregate2.{Complete, AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.catalyst.expressions.aggregate2.{DistinctAggregateExpression1, Complete, AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
Expand Down Expand Up @@ -278,7 +278,7 @@ class Analyzer(
Project(
projectList.flatMap {
case s: Star => s.expand(child.output, resolver)
case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
Expand Down Expand Up @@ -518,10 +518,12 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan =>
q transformExpressions {
case u @ UnresolvedFunction(name, children) =>
case u @ UnresolvedFunction(name, children, isDistinct) =>
withPosition(u) {
registry.lookupFunction(name, children) match {
case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, false)
case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
case agg1: AggregateExpression1 if isDistinct =>
DistinctAggregateExpression1(agg1)
case other => other
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ object UnresolvedAttribute {
def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name))
}

case class UnresolvedFunction(name: String, children: Seq[Expression])
case class UnresolvedFunction(
name: String,
children: Seq[Expression],
isDistinct: Boolean)
extends Expression with Unevaluable {

override def dataType: DataType = throw new UnresolvedException(this, "dataType")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ private[sql] case object NoOp extends Expression with Unevaluable {
override def children: Seq[Expression] = Nil
}

// Wrapper that preserves the DISTINCT qualifier on an old-style
// AggregateExpression1 through analysis, so the planner can later rewrite it.
// NOTE(review): evaluation semantics are not visible in this chunk — the
// wrapper only delegates expression metadata; confirm eval handling elsewhere.
private[sql] case class DistinctAggregateExpression1(
    aggregateExpression: AggregateExpression1) extends AggregateExpression {
  // The wrapped aggregate is the sole child.
  override def children: Seq[Expression] = Seq(aggregateExpression)
  // All expression metadata is delegated to the wrapped aggregate.
  override def dataType: DataType = aggregateExpression.dataType
  override def foldable: Boolean = aggregateExpression.foldable
  override def nullable: Boolean = aggregateExpression.nullable

  override def toString: String = "DISTINCT " + aggregateExpression.toString
}

/**
* A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field
* (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.execution.aggregate2.Aggregate2Sort
import org.apache.spark.sql.execution.aggregate2.{FinalAndCompleteAggregate2Sort, Aggregate2Sort}
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
import org.apache.spark.sql.parquet._
import org.apache.spark.sql.sources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
Expand Down Expand Up @@ -200,6 +200,181 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
* Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
*/
object AggregateOperator2 extends Strategy {
/**
 * Plans an aggregation that contains no DISTINCT aggregate functions as a
 * two-phase physical plan: a partial aggregate over the child followed by a
 * final aggregate that merges the partial buffers, grouped by the (named)
 * grouping attributes.
 */
private def planAggregateWithoutDistinct(
    groupingExpressions: Seq[Expression],
    aggregateExpressions: Seq[AggregateExpression2],
    aggregateFunctionMap: Map[AggregateFunction2, Attribute],
    resultExpressions: Seq[NamedExpression],
    child: SparkPlan): Seq[SparkPlan] = {
  // Give every grouping expression a name: expressions that are already named
  // pass through, anything else gets an alias. This lets the aggregate
  // operators expose the grouping columns as plain attributes.
  val namedGroupingExpressions = groupingExpressions.map {
    case named: NamedExpression => named -> named
    case unnamed => unnamed -> Alias(unnamed, unnamed.toString)()
  }
  val groupExpressionMap = namedGroupingExpressions.toMap
  val namedGroupingAttributes =
    namedGroupingExpressions.map { case (_, named) => named.toAttribute }

  // Phase 1: partial aggregation directly on top of the child.
  val partialAggregateExpressions = aggregateExpressions.map {
    case AggregateExpression2(func, _, distinct) =>
      AggregateExpression2(func, Partial, distinct)
  }
  val partialAggregateAttributes =
    partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
  val partialAggregate = Aggregate2Sort(
    None: Option[Seq[Expression]],
    namedGroupingExpressions.map(_._2),
    partialAggregateExpressions,
    partialAggregateAttributes,
    namedGroupingAttributes ++ partialAggregateAttributes,
    child)

  // Phase 2: final aggregation merging the partial buffers.
  val finalAggregateExpressions = aggregateExpressions.map {
    case AggregateExpression2(func, _, distinct) =>
      AggregateExpression2(func, Final, distinct)
  }
  val finalAggregateAttributes =
    finalAggregateExpressions.map(e => aggregateFunctionMap(e.aggregateFunction))

  // Rewrite the result expressions so they reference the attributes produced
  // by the final aggregate rather than the original expressions.
  val rewrittenResultExpressions = resultExpressions.map { expr =>
    expr.transform {
      case agg: AggregateExpression2 =>
        aggregateFunctionMap(agg.aggregateFunction).toAttribute
      case other if groupExpressionMap.contains(other) =>
        groupExpressionMap(other).toAttribute
    }.asInstanceOf[NamedExpression]
  }
  val finalAggregate = Aggregate2Sort(
    Some(namedGroupingAttributes),
    namedGroupingAttributes,
    finalAggregateExpressions,
    finalAggregateAttributes,
    rewrittenResultExpressions,
    partialAggregate)

  finalAggregate :: Nil
}

/**
 * Plans an aggregation whose aggregate functions all share a single DISTINCT
 * column set (e.g. `avg(DISTINCT value) ... GROUP BY key`). The physical plan
 * has three phases:
 *   1. a partial aggregate grouped by (grouping columns ++ distinct columns) —
 *      grouping by the distinct columns de-duplicates their values;
 *   2. a partial-merge aggregate that still keeps the distinct columns in its
 *      grouping while merging the phase-1 buffers of the non-distinct
 *      functions;
 *   3. a combined operator evaluating the non-distinct functions in Final mode
 *      and the distinct functions in Complete mode, grouped only by the
 *      original grouping columns.
 *
 * Callers must guarantee `functionsWithDistinct` is non-empty and that all of
 * its functions use the same distinct column set.
 */
private def planAggregateWithOneDistinct(
    groupingExpressions: Seq[Expression],
    functionsWithDistinct: Seq[AggregateExpression2],
    functionsWithoutDistinct: Seq[AggregateExpression2],
    aggregateFunctionMap: Map[AggregateFunction2, Attribute],
    resultExpressions: Seq[NamedExpression],
    child: SparkPlan): Seq[SparkPlan] = {

  // 1. Create an Aggregate Operator for partial aggregations.
  // The grouping expressions are the original groupingExpressions plus the
  // distinct columns. For example, for avg(distinct value) ... group by key
  // the grouping expressions of this Aggregate Operator will be [key, value].
  val namedGroupingExpressions = groupingExpressions.map {
    case ne: NamedExpression => ne -> ne
    // If the expression is not a NamedExpression, we add an alias.
    // So, when we generate the result of the operator, the Aggregate Operator
    // can directly get the Seq of attributes representing the grouping
    // expressions.
    case other =>
      val withAlias = Alias(other, other.toString)()
      other -> withAlias
  }
  val groupExpressionMap = namedGroupingExpressions.toMap
  val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)

  // It is safe to call head here since functionsWithDistinct has at least one
  // AggregateExpression2.
  val distinctColumnExpressions =
    functionsWithDistinct.head.aggregateFunction.children
  // Name the distinct columns for the same reason as the grouping columns.
  val namedDistinctColumnExpressions = distinctColumnExpressions.map {
    case ne: NamedExpression => ne -> ne
    case other =>
      val withAlias = Alias(other, other.toString)()
      other -> withAlias
  }
  val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
  val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)

  // Only the non-distinct functions are evaluated in this phase; the distinct
  // functions run in phase 3, after de-duplication.
  val partialAggregateExpressions = functionsWithoutDistinct.map {
    case AggregateExpression2(aggregateFunction, mode, _) =>
      AggregateExpression2(aggregateFunction, Partial, false)
  }
  val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
    agg.aggregateFunction.bufferAttributes
  }
  val partialAggregate =
    Aggregate2Sort(
      None: Option[Seq[Expression]],
      (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2),
      partialAggregateExpressions,
      partialAggregateAttributes,
      namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes,
      child)

  // 2. Create an Aggregate Operator for partial merge aggregations.
  val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
    case AggregateExpression2(aggregateFunction, mode, _) =>
      AggregateExpression2(aggregateFunction, PartialMerge, false)
  }
  val partialMergeAggregateAttributes =
    partialMergeAggregateExpressions.map {
      expr => aggregateFunctionMap(expr.aggregateFunction)
    }
  val partialMergeAggregate =
    Aggregate2Sort(
      Some(namedGroupingAttributes),
      namedGroupingAttributes ++ distinctColumnAttributes,
      partialMergeAggregateExpressions,
      partialMergeAggregateAttributes,
      namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes,
      partialAggregate)

  // 3. Create an Aggregate Operator evaluating the non-distinct functions in
  // Final mode and the distinct functions in Complete mode.
  // TODO(WIP): the original author noted the children of these expressions
  // still need to be replaced with distinctColumnAttributes — confirm whether
  // that applies here or to the complete (distinct) expressions below.
  val finalAggregateExpressions = functionsWithoutDistinct.map {
    case AggregateExpression2(aggregateFunction, mode, _) =>
      AggregateExpression2(aggregateFunction, Final, false)
  }
  val finalAggregateAttributes =
    finalAggregateExpressions.map {
      expr => aggregateFunctionMap(expr.aggregateFunction)
    }
  // The distinct functions run in Complete mode over the de-duplicated rows,
  // so isDistinct is deliberately reset to false here.
  val completeAggregateExpressions = functionsWithDistinct.map {
    case AggregateExpression2(aggregateFunction, mode, _) =>
      AggregateExpression2(aggregateFunction, Complete, false)
  }
  val completeAggregateAttributes =
    completeAggregateExpressions.map {
      expr => aggregateFunctionMap(expr.aggregateFunction)
    }

  // Rewrite the result expressions so they reference the attributes emitted
  // by the final/complete aggregate rather than the original expressions.
  val rewrittenResultExpressions = resultExpressions.map { expr =>
    expr.transform {
      case agg: AggregateExpression2 =>
        aggregateFunctionMap(agg.aggregateFunction).toAttribute
      case expression if groupExpressionMap.contains(expression) =>
        groupExpressionMap(expression).toAttribute
      case expression if distinctColumnExpressionMap.contains(expression) =>
        distinctColumnExpressionMap(expression).toAttribute
    }.asInstanceOf[NamedExpression]
  }
  val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(
    namedGroupingAttributes,
    finalAggregateExpressions,
    finalAggregateAttributes,
    completeAggregateExpressions,
    completeAggregateAttributes,
    rewrittenResultExpressions,
    partialMergeAggregate)

  finalAndCompleteAggregate :: Nil
}

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Aggregate(groupingExpressions, resultExpressions, child)
if sqlContext.conf.useSqlAggregate2 =>
Expand All @@ -216,58 +391,33 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
aggregateFunction -> Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
}.toMap

// 2. Create an Aggregate Operator for partial aggregations.
val namedGroupingExpressions = groupingExpressions.map {
case ne: NamedExpression => ne -> ne
// If the expression is not a NamedExpressions, we add an alias.
// So, when we generate the result of the operator, the Aggregate Operator
// can directly get the Seq of attributes representing the grouping expressions.
case other =>
val withAlias = Alias(other, other.toString)()
other -> withAlias
}
val groupExpressionMap = namedGroupingExpressions.toMap
val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
val partialAggregateExpressions = aggregateExpressions.map {
case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
AggregateExpression2(aggregateFunction, Partial, isDistinct)
val (functionsWithDistinct, functionsWithoutDistinct) =
aggregateExpressions.partition(_.isDistinct)
println("functionsWithDistinct " + functionsWithDistinct)
if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
// This is a sanity check. We should not reach here since we check the same thing in
// CheckAggregateFunction.
sys.error("Having more than one distinct column sets is not allowed.")
}
val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
agg.aggregateFunction.bufferAttributes
}
val partialAggregate =
Aggregate2Sort(
namedGroupingExpressions.map(_._2),
partialAggregateExpressions,
partialAggregateAttributes,
namedGroupingAttributes ++ partialAggregateAttributes,
planLater(child))

// 3. Create an Aggregate Operator for final aggregations.
val finalAggregateExpressions = aggregateExpressions.map {
case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
AggregateExpression2(aggregateFunction, Final, isDistinct)
}
val finalAggregateAttributes =
finalAggregateExpressions.map {
expr => aggregateFunctionMap(expr.aggregateFunction)
val aggregate =
if (functionsWithDistinct.isEmpty) {
planAggregateWithoutDistinct(
groupingExpressions,
aggregateExpressions,
aggregateFunctionMap,
resultExpressions,
planLater(child))
} else {
planAggregateWithOneDistinct(
groupingExpressions,
functionsWithDistinct,
functionsWithoutDistinct,
aggregateFunctionMap,
resultExpressions,
planLater(child))
}
val rewrittenResultExpressions = resultExpressions.map { expr =>
expr.transform {
case agg: AggregateExpression2 =>
aggregateFunctionMap(agg.aggregateFunction).toAttribute
case expression if groupExpressionMap.contains(expression) =>
groupExpressionMap(expression).toAttribute
}.asInstanceOf[NamedExpression]
}
val finalAggregate = Aggregate2Sort(
namedGroupingAttributes,
finalAggregateExpressions,
finalAggregateAttributes,
rewrittenResultExpressions,
partialAggregate)

finalAggregate :: Nil
aggregate
case _ => Nil
}
}
Expand Down
Loading

0 comments on commit 68b8ee9

Please sign in to comment.