
Commit a101960
Change MyJavaUDAF to MyDoubleSum.
yhuai committed Jul 20, 2015
1 parent 594cdf5 commit a101960
Showing 4 changed files with 60 additions and 137 deletions.
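
In short, this commit deletes the hand-rolled MyDoubleSum aggregate that lived inside catalyst, along with its hard-wired "mydoublesum" entry in FunctionRegistry; renames the Java example UDAF from MyJavaUDAF to MyDoubleSum; fixes a bug in that class, whose update and merge multiplied values where a sum must add; and has the test suite register the UDAF at runtime instead. The four diffs below are, in order: catalyst's FunctionRegistry, the aggregate2 expression interfaces, the Java example UDAF, and Aggregate2Suite. A minimal usage sketch of the new arrangement, assuming ctx is the suite's SQLContext (TestHive):

  import test.org.apache.spark.sql.hive.aggregate2.MyDoubleSum

  // Register the Java UDAF under the name the tests use, then call it
  // from SQL like any other aggregate function.
  ctx.udaf.register("mydoublesum", new MyDoubleSum)
  ctx.sql(
    """
      |SELECT key, mydoublesum(cast(value as double))
      |FROM agg2
      |GROUP BY key
    """.stripMargin).show()
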
@@ -24,7 +24,6 @@ import scala.util.{Failure, Success, Try}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions._
- import org.apache.spark.sql.catalyst.expressions.aggregate2.MyDoubleSum
import org.apache.spark.sql.catalyst.util.StringKeyHashMap


@@ -149,7 +148,6 @@ object FunctionRegistry {
expression[Max]("max"),
expression[Min]("min"),
expression[Sum]("sum"),
- expression[MyDoubleSum]("mydoublesum"),

// string functions
expression[Ascii]("ascii"),
@@ -78,13 +78,22 @@ private[sql] case object NoOp extends Expression with Unevaluable {
private[sql] case class AggregateExpression2(
aggregateFunction: AggregateFunction2,
mode: AggregateMode,
- isDistinct: Boolean) extends Expression with Unevaluable {
+ isDistinct: Boolean) extends AggregateExpression {

override def children: Seq[Expression] = aggregateFunction :: Nil
override def dataType: DataType = aggregateFunction.dataType
override def foldable: Boolean = false
override def nullable: Boolean = aggregateFunction.nullable

+ override def references: AttributeSet = {
+   val childReferences = mode match {
+     case Partial | Complete => aggregateFunction.references.toSeq
+     case PartialMerge | Final => aggregateFunction.bufferAttributes
+   }
+
+   AttributeSet(childReferences)
+ }
+
override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"
}
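
The new references override is the substantive change in this hunk: which columns an AggregateExpression2 needs depends on its mode. In Partial and Complete modes the function still consumes raw input rows, so the expression references whatever its aggregate function references; in PartialMerge and Final modes it consumes partial aggregation buffers, so it references the function's bufferAttributes instead. A self-contained sketch of that dispatch, with plain Scala types standing in for Spark's internal classes:

  sealed trait AggregateMode
  case object Partial extends AggregateMode
  case object PartialMerge extends AggregateMode
  case object Final extends AggregateMode
  case object Complete extends AggregateMode

  case class Attribute(name: String)

  // Partial/Complete read the child's input columns; PartialMerge/Final
  // read the buffer columns produced by an earlier partial pass.
  def references(
      mode: AggregateMode,
      inputAttrs: Seq[Attribute],
      bufferAttrs: Seq[Attribute]): Seq[Attribute] = mode match {
    case Partial | Complete => inputAttrs
    case PartialMerge | Final => bufferAttrs
  }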

@@ -136,59 +145,6 @@ abstract class AggregateFunction2
throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
}

- /**
-  * An example [[AggregateFunction2]] that is not an [[AlgebraicAggregate]].
-  * This function calculate the sum of double values.
-  * @param child
-  */
- case class MyDoubleSum(child: Expression) extends AggregateFunction2 {
-   override val bufferSchema: StructType =
-     StructType(StructField("currentSum", DoubleType, true) :: Nil)
-
-   override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes
-
-   override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
-
-   override def initialize(buffer: MutableRow): Unit = {
-     buffer.update(bufferOffset, null)
-   }
-
-   override def update(buffer: MutableRow, input: InternalRow): Unit = {
-     val inputValue = child.eval(input)
-     if (inputValue != null) {
-       if (buffer.isNullAt(bufferOffset)) {
-         buffer.setDouble(bufferOffset, inputValue.asInstanceOf[Double])
-       } else {
-         val currentSum = buffer.getDouble(bufferOffset)
-         buffer.setDouble(bufferOffset, currentSum + inputValue.asInstanceOf[Double])
-       }
-     }
-   }
-
-   override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
-     if (!buffer2.isNullAt(bufferOffset)) {
-       if (buffer1.isNullAt(bufferOffset)) {
-         buffer1.setDouble(bufferOffset, buffer2.getDouble(bufferOffset))
-       } else {
-         val currentSum = buffer1.getDouble(bufferOffset)
-         buffer1.setDouble(bufferOffset, currentSum + buffer2.getDouble(bufferOffset))
-       }
-     }
-   }
-
-   override def eval(buffer: InternalRow = null): Any = {
-     if (buffer.isNullAt(bufferOffset)) {
-       null
-     } else {
-       buffer.getDouble(bufferOffset)
-     }
-   }
-
-   override def nullable: Boolean = true
-   override def dataType: DataType = DoubleType
-   override def children: Seq[Expression] = child :: Nil
- }
-
/**
* A helper class for aggregate functions that can be implemented in terms of catalyst expressions.
*/
@@ -28,15 +28,15 @@
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.Row;

- public class MyJavaUDAF extends UserDefinedAggregateFunction {
+ public class MyDoubleSum extends UserDefinedAggregateFunction {

private StructType _inputDataType;

private StructType _bufferSchema;

private DataType _returnDataType;

- public MyJavaUDAF() {
+ public MyDoubleSum() {
List<StructField> inputfields = new ArrayList<StructField>();
inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
_inputDataType = DataTypes.createStructType(inputfields);
@@ -73,7 +73,7 @@ public MyJavaUDAF() {
if (buffer.isNullAt(0)) {
buffer.update(0, input.getDouble(0));
} else {
- Double newValue = input.getDouble(0) * buffer.getDouble(0);
+ Double newValue = input.getDouble(0) + buffer.getDouble(0);
buffer.update(0, newValue);
}
}
@@ -84,7 +84,7 @@ public MyJavaUDAF() {
if (buffer1.isNullAt(0)) {
buffer1.update(0, buffer2.getDouble(0));
} else {
- Double newValue = buffer2.getDouble(0) * buffer1.getDouble(0);
+ Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
buffer1.update(0, newValue);
}
}
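
Beyond the rename, this file carries the actual bug fix: update and merge previously multiplied the incoming value into the buffer, computing a product rather than a sum. The null handling is unchanged: the buffer starts null, the first non-null input seeds it, and null inputs are ignored. A plain-Scala sketch of those semantics, with Option[Double] standing in for the nullable buffer slot (an illustration, not Spark code):

  // Null inputs leave the buffer alone; the first value seeds the sum;
  // after that we accumulate by addition. Merging two partial buffers
  // follows the same rule, and must add rather than multiply.
  def update(buffer: Option[Double], input: Option[Double]): Option[Double] =
    (buffer, input) match {
      case (b, None) => b
      case (None, v) => v
      case (Some(s), Some(v)) => Some(s + v)
    }

  val merge = update _
  val partial1 = Seq(Some(1.0), None).foldLeft(Option.empty[Double])(update)
  val partial2 = Seq(Some(2.5)).foldLeft(Option.empty[Double])(update)
  assert(merge(partial1, partial2) == Some(3.5))
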
@@ -18,9 +18,9 @@
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.hive.test.TestHive
- import org.apache.spark.sql.{QueryTest, Row}
+ import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row}
import org.scalatest.BeforeAndAfterAll
- import test.org.apache.spark.sql.hive.aggregate2.MyJavaUDAF
+ import test.org.apache.spark.sql.hive.aggregate2.MyDoubleSum

class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {

@@ -48,6 +48,10 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
(3, null)).toDF("key", "value")

data.write.saveAsTable("agg2")
+
+ // Register a UDAF
+ val javaUDAF = new MyDoubleSum
+ ctx.udaf.register("mydoublesum", javaUDAF)
}

test("test average2 no key in output") {
@@ -62,20 +66,6 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
}

test("test average2") {
- ctx.sql(
-   """
-     |SELECT key, avg(value)
-     |FROM agg2
-     |GROUP BY key
-   """.stripMargin).explain(true)
-
- ctx.sql(
-   """
-     |SELECT key, avg(value)
-     |FROM agg2
-     |GROUP BY key
-   """.stripMargin).collect().foreach(println)
-
checkAnswer(
ctx.sql(
"""
@@ -116,23 +106,29 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
|SELECT avg(null)
""".stripMargin),
Row(null) :: Nil)
}

test("udaf") {
checkAnswer(
ctx.sql(
"""
|SELECT
| key,
| mydoublesum(cast(value as double) + 1.5 * key),
| avg(value - key),
| mydoublesum(cast(value as double) - 1.5 * key),
| avg(value)
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(1, 64.5, 19.0, 55.5, 20.0) ::
Row(2, 5.0, -2.5, -7.0, -0.5) ::
Row(3, null, null, null, null) ::
Row(null, null, null, null, 10.0) :: Nil)
}

test("non-AlgebraicAggregate aggreguate function") {
- ctx.sql(
-   """
-     |SELECT key, mydoublesum(cast(value as double))
-     |FROM agg2
-     |GROUP BY key
-   """.stripMargin).explain(true)
-
- ctx.sql(
-   """
-     |SELECT key, mydoublesum(cast(value as double))
-     |FROM agg2
-     |GROUP BY key
-   """.stripMargin).collect().foreach(println)


checkAnswer(
ctx.sql(
@@ -156,7 +152,6 @@
|SELECT mydoublesum(null)
""".stripMargin),
Row(null) :: Nil)
-
}

test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") {
@@ -190,53 +185,27 @@ class Aggregate2Suite extends QueryTest with BeforeAndAfterAll {
Row(null, null, null, null, 10.0) :: Nil)
}

test("udaf") {
val myJavaUDAF = new MyJavaUDAF
ctx.udaf.register("myjavaudaf", myJavaUDAF)

ctx.sql(
"""
|SELECT
| key,
| mydoublesum(cast(value as double) + 1.5 * key),
| avg(value - key),
| myjavaudaf(value),
| mydoublesum(cast(value as double) - 1.5 * key),
| avg(value)
|FROM agg2
|GROUP BY key
""".stripMargin).explain(true)

ctx.sql(
"""
|SELECT
| key,
| mydoublesum(cast(value as double) + 1.5 * key),
| avg(value - key),
| myjavaudaf(value),
| mydoublesum(cast(value as double) - 1.5 * key),
| avg(value)
|FROM agg2
|GROUP BY key
""".stripMargin).collect().foreach(println)

checkAnswer(
ctx.sql(
"""
|SELECT
| key,
| mydoublesum(cast(value as double) + 1.5 * key),
| avg(value - key),
| myjavaudaf(value),
| mydoublesum(cast(value as double) - 1.5 * key),
| avg(value)
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(1, 64.5, 19.0, 6000.0, 55.5, 20.0) ::
Row(2, 5.0, -2.5, -0.0, -7.0, -0.5) ::
Row(3, null, null, null, null, null) ::
Row(null, null, null, 60000.0, null, 10.0) :: Nil)
test("Cannot use AggregateExpression1 and AggregateExpressions2 together") {
Seq(true, false).foreach { useAggregate2 =>
ctx.sql(s"set spark.sql.useAggregate2=$useAggregate2")
val errorMessage = intercept[AnalysisException] {
ctx.sql(
"""
|SELECT
| key,
| sum(cast(value as double) + 1.5 * key),
| mydoublesum(value)
|FROM agg2
|GROUP BY key
""".stripMargin).collect()
}.getMessage
val expectedErrorMessage =
s"${SQLConf.USE_SQL_AGGREGATE2.key} is ${if (useAggregate2) "enabled" else "disabled"}. " +
s"Please ${if (useAggregate2) "disable" else "enable"} it to use"
assert(errorMessage.contains(expectedErrorMessage))
}

ctx.sql(s"set spark.sql.useAggregate2=true")
}

override def afterAll(): Unit = {
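
A closing note on the suite: the debugging-only explain(true) and collect().foreach(println) calls are dropped from the tests, and the new guard test flips spark.sql.useAggregate2 both ways, asserting that mixing an old-path aggregate (the built-in sum) with the mydoublesum UDAF fails with an AnalysisException whose message names the SQLConf.USE_SQL_AGGREGATE2 key, before resetting the flag to true.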
