Skip to content

Commit

Permalink
Remove groupOrdering.
Browse files Browse the repository at this point in the history
  • Loading branch information
yhuai committed Jul 16, 2015
1 parent 4721936 commit 70b169c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,6 @@ case class Aggregate2Sort(
// This is used to project expressions for the grouping expressions.
protected val groupGenerator =
newMutableProjection(groupingExpressions, child.output)()
// A ordering used to compare if a new row belongs to the current group
// or a new group.
private val groupOrdering: Ordering[InternalRow] = {
val groupingAttributes = groupingExpressions.map(_.toAttribute)
newOrdering(
groupingAttributes.map(expr => SortOrder(expr, Ascending)),
groupingAttributes)
}
// The partition key of the current partition.
private var currentGroupingKey: InternalRow = _
// The partition key of next partition.
Expand All @@ -182,18 +174,18 @@ case class Aggregate2Sort(
// aggregate function, the size of the buffer matches the number of values in the
// input rows. To simplify the code for code-gen, we need create some dummy
// attributes and expressions for these grouping expressions.
val offsetAttributes = {
private val offsetAttributes = {
if (partialAggregation) {
Nil
} else {
Seq.fill(groupingExpressions.length)(AttributeReference("offset", NullType)())
}
}
val offsetExpressions =
private val offsetExpressions =
if (partialAggregation) Nil else Seq.fill(groupingExpressions.length)(NoOp)

// This projection is used to initialize buffer values for all AlgebraicAggregates.
val algebraicInitialProjection = {
private val algebraicInitialProjection = {
val initExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.initialValues
case agg: AggregateFunction2 => NoOp :: Nil
Expand All @@ -202,7 +194,7 @@ case class Aggregate2Sort(
}

// This projection is used to update buffer values for all AlgebraicAggregates.
lazy val algebraicUpdateProjection = {
private lazy val algebraicUpdateProjection = {
val bufferSchema = aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.bufferAttributes
case agg: AggregateFunction2 => agg.bufferAttributes
Expand All @@ -215,7 +207,7 @@ case class Aggregate2Sort(
}

// This projection is used to merge buffer values for all AlgebraicAggregates.
lazy val algebraicMergeProjection = {
private lazy val algebraicMergeProjection = {
val bufferSchemata =
offsetAttributes ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.bufferAttributes
Expand All @@ -233,7 +225,7 @@ case class Aggregate2Sort(
}

// This projection is used to evaluate all AlgebraicAggregates.
lazy val algebraicEvalProjection = {
private lazy val algebraicEvalProjection = {
val bufferSchemata =
offsetAttributes ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.bufferAttributes
Expand Down Expand Up @@ -313,8 +305,9 @@ case class Aggregate2Sort(
// For the below compare method, we do not need to make a copy of groupingKey.
val groupingKey = groupGenerator(currentRow)
// Check if the current row belongs the current input row.
val comparing = groupOrdering.compare(currentGroupingKey, groupingKey)
if (comparing == 0) {
currentGroupingKey.equals(groupingKey)

if (currentGroupingKey == groupingKey) {
processRow(currentRow)
} else {
// We find a new group.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,19 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
override def beforeAll(): Unit = {
originalUseAggregate2 = ctx.conf.useSqlAggregate2
ctx.sql("set spark.sql.useAggregate2=true")
val data = Seq[(Int, Integer)](
val data = Seq[(Integer, Integer)](
(1, 10),
(null, -60),
(1, 20),
(1, 30),
(2, 0),
(null, -10),
(2, -1),
(2, null),
(2, null),
(null, 100),
(3, null),
(null, null),
(3, null)).toDF("key", "value")

data.write.saveAsTable("agg2")
Expand All @@ -54,7 +58,7 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(-0.5) :: Row(20.0) :: Row(null) :: Nil)
Row(-0.5) :: Row(20.0) :: Row(null) :: Row(10.0) :: Nil)
}

test("test average2") {
Expand All @@ -79,7 +83,7 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Nil)
Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil)

checkAnswer(
ctx.sql(
Expand All @@ -88,7 +92,7 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Nil)
Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil)

checkAnswer(
ctx.sql(
Expand All @@ -97,14 +101,14 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
|FROM agg2
|GROUP BY key + 10
""".stripMargin),
Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Nil)
Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil)

checkAnswer(
ctx.sql(
"""
|SELECT avg(value) FROM agg2
""".stripMargin),
Row(11.8) :: Nil)
Row(11.125) :: Nil)

checkAnswer(
ctx.sql(
Expand Down Expand Up @@ -137,14 +141,14 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Nil)
Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil)

checkAnswer(
ctx.sql(
"""
|SELECT mydoublesum(cast(value as double)) FROM agg2
""".stripMargin),
Row(59.0) :: Nil)
Row(89.0) :: Nil)

checkAnswer(
ctx.sql(
Expand All @@ -163,7 +167,10 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(60.0, 1, 20.0) :: Row(-1.0, 2, -0.5) :: Row(null, 3, null) :: Nil)
Row(60.0, 1, 20.0) ::
Row(-1.0, 2, -0.5) ::
Row(null, 3, null) ::
Row(30.0, null, 10.0) :: Nil)

checkAnswer(
ctx.sql(
Expand All @@ -179,7 +186,8 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
""".stripMargin),
Row(64.5, 19.0, 1, 55.5, 20.0) ::
Row(5.0, -2.5, 2, -7.0, -0.5) ::
Row(null, null, 3, null, null) :: Nil)
Row(null, null, 3, null, null) ::
Row(null, null, null, null, 10.0) :: Nil)
}

override def afterAll(): Unit = {
Expand Down

0 comments on commit 70b169c

Please sign in to comment.