Commit dded1c5

wip

yhuai committed Jul 10, 2015
1 parent 39e4e7e commit dded1c5
Showing 11 changed files with 631 additions and 4 deletions.
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.aggregate2.{AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -482,7 +483,11 @@ class Analyzer(
      q transformExpressions {
        case u @ UnresolvedFunction(name, children) =>
          withPosition(u) {
            registry.lookupFunction(name, children)
            registry.lookupFunction(name, children) match {
              case agg2: AggregateFunction2 =>
                AggregateExpression2(agg2, aggregate2.Complete, false)
              case other => other
            }
          }
      }
    }
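Illustration (an inference from this rule plus the `avg2` registration below, not code in this diff): an unresolved call such as `avg2(x)` now resolves to

    AggregateExpression2(aggregate2.Average(x), aggregate2.Complete, isDistinct = false)

instead of the bare `AggregateFunction2`.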
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate2.AggregateExpression2
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._

@@ -85,6 +86,7 @@ trait CheckAnalysis {
        case Aggregate(groupingExprs, aggregateExprs, child) =>
          def checkValidAggregateExpression(expr: Expression): Unit = expr match {
            case _: AggregateExpression => // OK
            case _: AggregateExpression2 => // OK
            case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
              failAnalysis(
                s"expression '${e.prettyString}' is neither present in the group by, " +
@@ -148,6 +148,7 @@ object FunctionRegistry {

    // aggregate functions
    expression[Average]("avg"),
    expression[aggregate2.Average]("avg2"),
    expression[Count]("count"),
    expression[First]("first"),
    expression[Last]("last"),
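For context, a hypothetical query that would exercise the new `avg2` registration once wired into planning (the table and column names are assumptions; this WIP commit does not yet show the physical operator):

    sqlContext.sql("SELECT avg2(value) FROM records")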
@@ -0,0 +1,173 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.aggregate2

import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.trees.{LeafNode, UnaryNode}
import org.apache.spark.sql.types._

private[sql] sealed trait AggregateMode

// Partial: updates the aggregation buffer from input rows, producing partial results.
private[sql] case object Partial extends AggregateMode

// PartialMerge: merges partial aggregation buffers into a single buffer.
private[sql] case object PartialMerge extends AggregateMode

// Final: merges partial aggregation buffers and produces the final result.
private[sql] case object Final extends AggregateMode

// Complete: aggregates all input rows directly, with no partial aggregation.
// This is the only mode used so far in this commit (see the Analyzer change above).
private[sql] case object Complete extends AggregateMode

/**
 * A container for an aggregate function, its [[AggregateMode]], and a flag (`isDistinct`)
 * indicating whether the DISTINCT keyword was specified for this function.
 * @param aggregateFunction the aggregate function to evaluate
 * @param mode the [[AggregateMode]] in which the function is evaluated
 * @param isDistinct whether DISTINCT was specified for this function call
 */
private[sql] case class AggregateExpression2(
    aggregateFunction: AggregateFunction2,
    mode: AggregateMode,
    isDistinct: Boolean) extends Expression {

  override def children: Seq[Expression] = aggregateFunction :: Nil

  override def dataType: DataType = aggregateFunction.dataType
  override def foldable: Boolean = aggregateFunction.foldable
  override def nullable: Boolean = aggregateFunction.nullable

  override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"

  override def eval(input: InternalRow = null): Any =
    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
}

abstract class AggregateFunction2 extends Expression {

  self: Product =>

  // Index of the first slot that this function's values occupy in the shared
  // aggregation buffer.
  var bufferOffset: Int = 0

  def withBufferOffset(newBufferOffset: Int): AggregateFunction2 = {
    bufferOffset = newBufferOffset
    this
  }

  def bufferValueDataTypes: StructType

  def initialBufferValues: Array[Any]

  def initialize(buffer: MutableRow): Unit

  def updateBuffer(buffer: MutableRow, bufferValues: Array[Any]): Unit = {
    var i = 0
    while (i < bufferValues.length) {
      buffer.update(bufferOffset + i, bufferValues(i))
      i += 1
    }
  }

  def update(buffer: MutableRow, input: InternalRow): Unit

  def merge(buffer1: MutableRow, buffer2: InternalRow): Unit

  override def eval(buffer: InternalRow = null): Any
}
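Since `bufferOffset` is just the index of a function's first slot in a shared buffer row, several functions can occupy one `MutableRow` side by side. A minimal sketch of that contract, using only classes appearing in this commit (the offsets 0 and 2 are arbitrary; `Average` is defined below):

    val row = new GenericMutableRow(4)
    val avgA = Average(BoundReference(0, IntegerType, true)).withBufferOffset(0)
    val avgB = Average(BoundReference(1, IntegerType, true)).withBufferOffset(2)
    avgA.initialize(row)  // writes its (sum, count) pair into slots 0 and 1
    avgB.initialize(row)  // writes its (sum, count) pair into slots 2 and 3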

case class Average(child: Expression)
  extends AggregateFunction2 with UnaryNode[Expression] {

  override def nullable: Boolean = child.nullable

  override def bufferValueDataTypes: StructType = child match {
    case e @ DecimalType() =>
      StructType(
        StructField("Sum", DecimalType.Unlimited) ::
        StructField("Count", LongType) :: Nil)
    case _ =>
      StructType(
        StructField("Sum", DoubleType) ::
        StructField("Count", LongType) :: Nil)
  }

  override def dataType: DataType = child.dataType match {
    case DecimalType.Fixed(precision, scale) =>
      DecimalType(precision + 4, scale + 4)
    case DecimalType.Unlimited => DecimalType.Unlimited
    case _ => DoubleType
  }

  override def initialBufferValues: Array[Any] = {
    Array(
      Cast(Literal(0), bufferValueDataTypes("Sum").dataType).eval(null), // Sum
      0L) // Count
  }

  override def initialize(buffer: MutableRow): Unit =
    updateBuffer(buffer, initialBufferValues)

  // Reusable mutable literals: update/merge/eval set their values and then
  // evaluate the pre-built expression trees below, instead of constructing
  // new expressions per row.
  private val inputLiteral =
    MutableLiteral(null, child.dataType)
  private val bufferedSum =
    MutableLiteral(null, bufferValueDataTypes("Sum").dataType)
  private val bufferedCount = MutableLiteral(null, LongType)
  private val updateSum =
    Add(Cast(inputLiteral, bufferValueDataTypes("Sum").dataType), bufferedSum)
  private val inputBufferedSum =
    MutableLiteral(null, bufferValueDataTypes("Sum").dataType)
  private val mergeSum = Add(inputBufferedSum, bufferedSum)
  private val evaluateAvg =
    Cast(Divide(bufferedSum, Cast(bufferedCount, bufferValueDataTypes("Sum").dataType)), dataType)

  override def update(buffer: MutableRow, input: InternalRow): Unit = {
    val newInput = child.eval(input)
    // Null input values contribute to neither the sum nor the count.
    if (newInput != null) {
      inputLiteral.value = newInput
      bufferedSum.value = buffer(bufferOffset)
      val newSum = updateSum.eval(null)
      val newCount = buffer.getLong(bufferOffset + 1) + 1
      buffer.update(bufferOffset, newSum)
      buffer.update(bufferOffset + 1, newCount)
    }
  }

  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
    // If buffer2's count is zero, it has seen no values and buffer1 is unchanged.
    if (buffer2(bufferOffset + 1) != 0L) {
      inputBufferedSum.value = buffer2(bufferOffset)
      bufferedSum.value = buffer1(bufferOffset)
      val newSum = mergeSum.eval(null)
      val newCount =
        buffer1.getLong(bufferOffset + 1) + buffer2.getLong(bufferOffset + 1)
      buffer1.update(bufferOffset, newSum)
      buffer1.update(bufferOffset + 1, newCount)
    }
  }

  override def eval(buffer: InternalRow): Any = {
    if (buffer(bufferOffset + 1) == 0L) {
      // No non-null input was seen: the average is null.
      null
    } else {
      bufferedSum.value = buffer(bufferOffset)
      bufferedCount.value = buffer.getLong(bufferOffset + 1)
      evaluateAvg.eval(null)
    }
  }
}
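A standalone sketch of the partial-then-merge flow these methods are designed for (the per-partition split is an assumption; this WIP commit contains no physical operator that drives it yet):

    val avg = Average(BoundReference(0, IntegerType, true)) // bufferOffset defaults to 0
    val in = new GenericMutableRow(1)
    val b1 = new GenericMutableRow(2) // buffer for partition 1
    val b2 = new GenericMutableRow(2) // buffer for partition 2
    avg.initialize(b1)
    avg.initialize(b2)
    Seq(1, 2, 3).foreach { v => in.update(0, v); avg.update(b1, in) } // b1 = (6.0, 3)
    Seq(4, 5).foreach { v => in.update(0, v); avg.update(b2, in) }    // b2 = (9.0, 2)
    avg.merge(b1, b2)                                                 // b1 = (15.0, 5)
    assert(avg.eval(b1) == 3.0)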
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.aggregate2

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, BoundReference, Literal, ExpressionEvalHelper}
import org.apache.spark.sql.types._

class AggregateExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("Average") {
    val inputValues = Array(Int.MaxValue, null, 1000, Int.MinValue, 2)
    val avg = Average(child = BoundReference(0, IntegerType, true)).withBufferOffset(2)
    val inputRow = new GenericMutableRow(1)
    val buffer = new GenericMutableRow(4)
    avg.initialize(buffer)

    // When there is no input data, average should return null.
    assert(avg.eval(buffer) === null)
    // When input values are all nulls, average should return null.
    var i = 0
    while (i < 10) {
      inputRow.update(0, null)
      avg.update(buffer, inputRow)
      i += 1
    }
    assert(avg.eval(buffer) === null)

    // Add some values.
    i = 0
    while (i < inputValues.length) {
      inputRow.update(0, inputValues(i))
      avg.update(buffer, inputRow)
      i += 1
    }
    assert(avg.eval(buffer) === 1001 / 4.0)

    // eval should not reset the buffer.
    assert(buffer(2) === 1001L)
    assert(buffer(3) === 4L)
    assert(avg.eval(buffer) === 1001 / 4.0)

    // Merge with a just-initialized buffer.
    val inputBuffer = new GenericMutableRow(4)
    avg.initialize(inputBuffer)
    avg.merge(buffer, inputBuffer)
    assert(buffer(2) === 1001L)
    assert(buffer(3) === 4L)
    assert(avg.eval(buffer) === 1001 / 4.0)

    // Merge with a buffer containing partial results.
    inputBuffer.update(2, 2000.0)
    inputBuffer.update(3, 10L)
    avg.merge(buffer, inputBuffer)
    assert(buffer(2) === 3001L)
    assert(buffer(3) === 14L)
    assert(avg.eval(buffer) === 3001 / 14.0)
  }
}
5 changes: 5 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -380,6 +380,9 @@ private[spark] object SQLConf {
  val USE_SQL_SERIALIZER2 = booleanConf("spark.sql.useSerializer2",
    defaultValue = Some(true), doc = "<TODO>")

  val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2",
    defaultValue = Some(false), doc = "<TODO>")

  val USE_JACKSON_STREAMING_API = booleanConf("spark.sql.json.useJacksonStreamingAPI",
    defaultValue = Some(true), doc = "<TODO>")

@@ -479,6 +482,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf {

  private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)

  private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2)

  /**
   * Selects between the new (true) and old (false) JSON handlers, to be removed in Spark 1.5.0
   */
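The new flag defaults to `false`, so the aggregate2 code path is opt-in. A hypothetical way to enable it (no caller appears in this commit):

    sqlContext.setConf("spark.sql.useAggregate2", "true")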
@@ -864,6 +864,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
      DDLStrategy ::
      TakeOrderedAndProject ::
      HashAggregation ::
      AggregateOperator2 ::
      LeftSemiJoin ::
      HashJoin ::
      InMemoryScans ::
(Diffs for the remaining changed files are not shown.)
