From 70b169c981ece62cbf755169043cbe1239da9afa Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 16 Jul 2015 16:57:02 -0700 Subject: [PATCH] Remove groupOrdering. --- .../execution/aggregate2/Aggregate2Sort.scala | 25 ++++++----------- .../sql/hive/execution/Aggregate2Suite.scala | 28 ++++++++++++------- 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala index 3bdb3b9c32f59..c2201d2f91aaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala @@ -149,14 +149,6 @@ case class Aggregate2Sort( // This is used to project expressions for the grouping expressions. protected val groupGenerator = newMutableProjection(groupingExpressions, child.output)() - // A ordering used to compare if a new row belongs to the current group - // or a new group. - private val groupOrdering: Ordering[InternalRow] = { - val groupingAttributes = groupingExpressions.map(_.toAttribute) - newOrdering( - groupingAttributes.map(expr => SortOrder(expr, Ascending)), - groupingAttributes) - } // The partition key of the current partition. private var currentGroupingKey: InternalRow = _ // The partition key of next partition. @@ -182,18 +174,18 @@ case class Aggregate2Sort( // aggregate function, the size of the buffer matches the number of values in the // input rows. To simplify the code for code-gen, we need create some dummy // attributes and expressions for these grouping expressions. - val offsetAttributes = { + private val offsetAttributes = { if (partialAggregation) { Nil } else { Seq.fill(groupingExpressions.length)(AttributeReference("offset", NullType)()) } } - val offsetExpressions = + private val offsetExpressions = if (partialAggregation) Nil else Seq.fill(groupingExpressions.length)(NoOp) // This projection is used to initialize buffer values for all AlgebraicAggregates. - val algebraicInitialProjection = { + private val algebraicInitialProjection = { val initExpressions = offsetExpressions ++ aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.initialValues case agg: AggregateFunction2 => NoOp :: Nil @@ -202,7 +194,7 @@ case class Aggregate2Sort( } // This projection is used to update buffer values for all AlgebraicAggregates. - lazy val algebraicUpdateProjection = { + private lazy val algebraicUpdateProjection = { val bufferSchema = aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.bufferAttributes case agg: AggregateFunction2 => agg.bufferAttributes @@ -215,7 +207,7 @@ case class Aggregate2Sort( } // This projection is used to merge buffer values for all AlgebraicAggregates. - lazy val algebraicMergeProjection = { + private lazy val algebraicMergeProjection = { val bufferSchemata = offsetAttributes ++ aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.bufferAttributes @@ -233,7 +225,7 @@ case class Aggregate2Sort( } // This projection is used to evaluate all AlgebraicAggregates. - lazy val algebraicEvalProjection = { + private lazy val algebraicEvalProjection = { val bufferSchemata = offsetAttributes ++ aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.bufferAttributes @@ -313,8 +305,9 @@ case class Aggregate2Sort( // For the below compare method, we do not need to make a copy of groupingKey. val groupingKey = groupGenerator(currentRow) // Check if the current row belongs the current input row. - val comparing = groupOrdering.compare(currentGroupingKey, groupingKey) - if (comparing == 0) { + currentGroupingKey.equals(groupingKey) + + if (currentGroupingKey == groupingKey) { processRow(currentRow) } else { // We find a new group. diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/Aggregate2Suite.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/Aggregate2Suite.scala index 48771e8f403d7..348e449bd22b5 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/Aggregate2Suite.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/Aggregate2Suite.scala @@ -32,15 +32,19 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll { override def beforeAll(): Unit = { originalUseAggregate2 = ctx.conf.useSqlAggregate2 ctx.sql("set spark.sql.useAggregate2=true") - val data = Seq[(Int, Integer)]( + val data = Seq[(Integer, Integer)]( (1, 10), + (null, -60), (1, 20), (1, 30), (2, 0), + (null, -10), (2, -1), (2, null), (2, null), + (null, 100), (3, null), + (null, null), (3, null)).toDF("key", "value") data.write.saveAsTable("agg2") @@ -54,7 +58,7 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll { |FROM agg2 |GROUP BY key """.stripMargin), - Row(-0.5) :: Row(20.0) :: Row(null) :: Nil) + Row(-0.5) :: Row(20.0) :: Row(null) :: Row(10.0) :: Nil) } test("test average2") { @@ -79,7 +83,7 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll { |FROM agg2 |GROUP BY key """.stripMargin), - Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Nil) + Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) checkAnswer( ctx.sql( @@ -88,7 +92,7 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll { |FROM agg2 |GROUP BY key """.stripMargin), - Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Nil) + Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil) checkAnswer( ctx.sql( @@ -97,14 +101,14 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll { |FROM agg2 |GROUP BY key + 10 """.stripMargin), - Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Nil) + Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil) checkAnswer( ctx.sql( """ |SELECT avg(value) FROM agg2 """.stripMargin), - Row(11.8) :: Nil) + Row(11.125) :: Nil) checkAnswer( ctx.sql( @@ -137,14 +141,14 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll { |FROM agg2 |GROUP BY key """.stripMargin), - Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Nil) + Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil) checkAnswer( ctx.sql( """ |SELECT mydoublesum(cast(value as double)) FROM agg2 """.stripMargin), - Row(59.0) :: Nil) + Row(89.0) :: Nil) checkAnswer( ctx.sql( @@ -163,7 +167,10 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll { |FROM agg2 |GROUP BY key """.stripMargin), - Row(60.0, 1, 20.0) :: Row(-1.0, 2, -0.5) :: Row(null, 3, null) :: Nil) + Row(60.0, 1, 20.0) :: + Row(-1.0, 2, -0.5) :: + Row(null, 3, null) :: + Row(30.0, null, 10.0) :: Nil) checkAnswer( ctx.sql( @@ -179,7 +186,8 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll { """.stripMargin), Row(64.5, 19.0, 1, 55.5, 20.0) :: Row(5.0, -2.5, 2, -7.0, -0.5) :: - Row(null, null, 3, null, null) :: Nil) + Row(null, null, 3, null, null) :: + Row(null, null, null, null, 10.0) :: Nil) } override def afterAll(): Unit = {