
Commit 3b43b24
bug fix.
yhuai committed Jul 22, 2015
1 parent 00eb298 commit 3b43b24
Showing 5 changed files with 68 additions and 45 deletions.
@@ -34,6 +34,8 @@ case class Aggregate2Sort(
child: SparkPlan)
extends UnaryNode {

override def canProcessUnsafeRows: Boolean = true

override def references: AttributeSet = {
val referencesInResults =
AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes)
@@ -72,6 +74,7 @@ case class Aggregate2Sort(
if (aggregateExpressions.length == 0) {
new GroupingIterator(
groupingExpressions,
resultExpressions,
newMutableProjection,
child.output,
iter)
@@ -242,6 +242,7 @@ private[sql] abstract class SortAggregationIterator(
*/
class GroupingIterator(
groupingExpressions: Seq[NamedExpression],
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
inputAttributes: Seq[Attribute],
inputIter: Iterator[InternalRow])
@@ -251,14 +252,18 @@ class GroupingIterator(
newMutableProjection,
inputAttributes,
inputIter) {

private val resultProjection =
newMutableProjection(resultExpressions, groupingExpressions.map(_.toAttribute))()

override protected def initialBufferOffset: Int = 0

override protected def processRow(row: InternalRow): Unit = {
// Since we only do grouping, there is nothing to do here.
}

override protected def generateOutput(): InternalRow = {
currentGroupingKey
resultProjection(currentGroupingKey)
}
}
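
For context on this hunk: before the fix, a grouping-only Aggregate2Sort returned currentGroupingKey as-is, so the operator ignored resultExpressions entirely. That is wrong whenever the result expressions reorder or alias the grouping columns, which is why the test change further down now selects value1, key instead of key, value1. The new resultProjection re-maps each grouping key to the requested output columns. Below is a minimal, self-contained model of that re-mapping in plain Scala; it is not Spark's InternalRow/MutableProjection machinery, and all names are illustrative:

object ProjectionSketch extends App {
  type Row = Seq[Any]

  // "Compile" the result columns into ordinals over the grouping-key layout,
  // then route each incoming key through those ordinals.
  def buildResultProjection(
      groupingAttrs: Seq[String],
      resultAttrs: Seq[String]): Row => Row = {
    val ordinals = resultAttrs.map(a => groupingAttrs.indexOf(a))
    (key: Row) => ordinals.map(key)
  }

  // The key is laid out as (key, value1) but the query asks for (value1, key).
  val project = buildResultProjection(Seq("key", "value1"), Seq("value1", "key"))
  println(project(Seq(1, 10))) // List(10, 1): columns reordered as requested
}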

@@ -521,7 +526,6 @@ class FinalSortAggregationIterator(
nonAlgebraicAggregateFunctions(i).eval(buffer))
i += 1
}

resultProjection(joinedRow(currentGroupingKey, aggregateResult))
}
}
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types.{StructType, MapType, ArrayType}

object Utils {
// Right now, we do not support complex types in the grouping key schema.
private def groupingKeySchemaIsSupported(aggregate: Aggregate): Boolean = {
private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
case array: ArrayType => true
case map: MapType => true
@@ -39,7 +39,7 @@ object Utils {
}
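
The diff hunk cuts this predicate off mid-body; judging from the {StructType, MapType, ArrayType} import above, struct types are presumably rejected alongside arrays and maps. A self-contained sketch of the same check, with toy stand-ins for Catalyst's type hierarchy (the StructType case is an inference from the import, not a quote from the file):

object GroupingKeySketch extends App {
  // Tiny stand-ins for Catalyst's DataType hierarchy (illustrative only).
  sealed trait DataType
  case object IntegerType extends DataType
  case object StringType extends DataType
  case class ArrayType(elementType: DataType) extends DataType
  case class MapType(keyType: DataType, valueType: DataType) extends DataType
  case class StructType(fieldTypes: Seq[DataType]) extends DataType

  // Mirrors the predicate above: a grouping key is supported only if
  // no grouping expression has a complex (array/map/struct) type.
  def supportsGroupingKeySchema(groupingTypes: Seq[DataType]): Boolean =
    !groupingTypes.exists {
      case _: ArrayType  => true
      case _: MapType    => true
      case _: StructType => true
      case _             => false
    }

  println(supportsGroupingKeySchema(Seq(IntegerType, StringType))) // true
  println(supportsGroupingKeySchema(Seq(ArrayType(IntegerType))))  // false
}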

private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
case p: Aggregate if groupingKeySchemaIsSupported(p) =>
case p: Aggregate if supportsGroupingKeySchema(p) =>
val converted = p.transformExpressionsDown {
case expressions.Average(child) =>
aggregate.AggregateExpression2(
@@ -125,6 +125,33 @@ object Utils {
case other => None
}

private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
// If the plan cannot be converted, we do a final round of checks to see if the
// original logical.Aggregate contains both AggregateExpression1 and
// AggregateExpression2. If so, we need to throw an exception.
val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
expr.collect {
case agg: AggregateExpression2 => agg.aggregateFunction
}
}.distinct
if (aggregateFunction2s.nonEmpty) {
// For functions implemented based on the new interface, prepare a list of function names.
val invalidFunctions = {
if (aggregateFunction2s.length > 1) {
s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
s"and ${aggregateFunction2s.head.nodeName} are"
} else {
s"${aggregateFunction2s.head.nodeName} is"
}
}
val errorMessage =
s"${invalidFunctions} implemented based on the new Aggregate Function " +
s"interface and it cannot be used with functions implemented based on " +
s"the old Aggregate Function interface."
throw new AnalysisException(errorMessage)
}
}
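
The singular/plural phrasing built by invalidFunctions is easy to sanity-check in isolation. A standalone sketch of the same rule, with invented function names standing in for the nodeName values:

object InvalidFunctionsMessageSketch extends App {
  // Same phrasing rule as checkInvalidAggregateFunction2 above.
  def describe(names: Seq[String]): String =
    if (names.length > 1) {
      s"${names.tail.mkString(",")} and ${names.head} are"
    } else {
      s"${names.head} is"
    }

  println(describe(Seq("MyDoubleSum")))                // MyDoubleSum is
  println(describe(Seq("MyDoubleSum", "MyDoubleAvg"))) // MyDoubleAvg and MyDoubleSum are
}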

def tryConvert(
plan: LogicalPlan,
useNewAggregation: Boolean,
@@ -134,26 +161,12 @@ object Utils {
if (converted.isDefined) {
converted
} else {
// If the plan cannot be converted, we will do a final round check to if the original
// logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
// we need to throw an exception.
p match {
case Aggregate(_, aggregateExpressions, _) => aggregateExpressions.foreach { expr =>
expr.foreach {
case agg2: AggregateExpression2 =>
// TODO: Make this errorMessage more user-friendly.
val errorMessage =
s"${agg2.aggregateFunction} is implemented based on new Aggregate Function " +
s"interface and it cannot be used with old Aggregate Function implementaion."
throw new AnalysisException(errorMessage)
case other => // OK
}
}
case other => // OK
}

checkInvalidAggregateFunction2(p)
None
}
case p: Aggregate =>
checkInvalidAggregateFunction2(p)
None
case other => None
}

@@ -275,4 +275,6 @@ case class ScalaUDAF(
override def toString: String = {
s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
}

override def nodeName: String = udaf.getClass.getSimpleName
}
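
This one-line override feeds the error message above: checkInvalidAggregateFunction2 builds its message from aggregateFunction.nodeName, and the default nodeName would report the wrapper class (ScalaUDAF) rather than the user's UDAF class. A toy, Spark-free illustration of the difference; all class names here are invented:

// Stand-in for a user's UDAF class.
class MyDoubleSum

class Wrapper(val udaf: AnyRef) {
  // Default behavior: report the wrapper's own class name.
  def nodeName: String = getClass.getSimpleName
}

class NamedWrapper(udaf: AnyRef) extends Wrapper(udaf) {
  // The fix: report the wrapped UDAF's class name instead.
  override def nodeName: String = udaf.getClass.getSimpleName
}

object NodeNameSketch extends App {
  println(new Wrapper(new MyDoubleSum).nodeName)      // Wrapper
  println(new NamedWrapper(new MyDoubleSum).nodeName) // MyDoubleSum
}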
@@ -154,36 +154,36 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf
checkAnswer(
sqlContext.sql(
"""
|SELECT DISTINCT key, value1
|SELECT DISTINCT value1, key
|FROM agg2
""".stripMargin),
Row(1, 10) ::
Row(null, -60) ::
Row(1, 30) ::
Row(2, 1) ::
Row(null, -10) ::
Row(2, -1) ::
Row(2, null) ::
Row(null, 100) ::
Row(3, null) ::
Row(10, 1) ::
Row(-60, null) ::
Row(30, 1) ::
Row(1, 2) ::
Row(-10, null) ::
Row(-1, 2) ::
Row(null, 2) ::
Row(100, null) ::
Row(null, 3) ::
Row(null, null) :: Nil)

checkAnswer(
sqlContext.sql(
"""
|SELECT key, value1
|SELECT value1, key
|FROM agg2
|GROUP BY key, value1
""".stripMargin),
Row(1, 10) ::
Row(null, -60) ::
Row(1, 30) ::
Row(2, 1) ::
Row(null, -10) ::
Row(2, -1) ::
Row(2, null) ::
Row(null, 100) ::
Row(3, null) ::
Row(10, 1) ::
Row(-60, null) ::
Row(30, 1) ::
Row(1, 2) ::
Row(-10, null) ::
Row(-1, 2) ::
Row(null, 2) ::
Row(100, null) ::
Row(null, 3) ::
Row(null, null) :: Nil)
}

@@ -427,12 +427,13 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf
|SELECT
| key,
| sum(value + 1.5 * key),
| mydoublesum(value)
| mydoublesum(value),
| mydoubleavg(value)
|FROM agg1
|GROUP BY key
""".stripMargin).collect()
}.getMessage
assert(errorMessage.contains("is implemented based on new Aggregate Function interface"))
assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))

// TODO: once we support Hive UDAF in the new interface,
// we can remove the following two tests.
@@ -448,7 +449,7 @@ class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAf
|GROUP BY key
""".stripMargin).collect()
}.getMessage
assert(errorMessage.contains("is implemented based on new Aggregate Function interface"))
assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))

// This will fall back to the old aggregate
val newAggregateOperators = sqlContext.sql(
