From 0a827b3d5162e4e197f37745ef7c2d7a893a3bc4 Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Fri, 17 Jul 2015 10:52:59 -0700
Subject: [PATCH] Add comments and doc. Move some classes to the right places.

---
 .../expressions/aggregate2/aggregates.scala       |  72 -----
 .../apache/spark/sql/UDAFRegistration.scala       |   2 +-
 .../expressions/aggregate2/udaf.scala             | 149 ----------
 .../sql/expressions/aggregate2/udaf.scala         | 274 ++++++++++++++++++
 .../spark/sql/hive/aggregate2/MyJavaUDAF.java     |   6 +-
 5 files changed, 278 insertions(+), 225 deletions(-)
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/aggregate2/udaf.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate2/udaf.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala
index 8a158bdcbec9e..76d21d65cf9c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate2/aggregates.scala
@@ -279,77 +279,5 @@ case class Average(child: Expression) extends AlgebraicAggregate {
 
   override def children: Seq[Expression] = child :: Nil
 }
-
-abstract class AggregationBuffer(
-    toCatalystConverters: Array[Any => Any],
-    toScalaConverters: Array[Any => Any],
-    bufferOffset: Int)
-  extends Row {
-
-  override def length: Int = toCatalystConverters.length
-
-  protected val offsets: Array[Int] = {
-    val newOffsets = new Array[Int](length)
-    var i = 0
-    while (i < newOffsets.length) {
-      newOffsets(i) = bufferOffset + i
-      i += 1
-    }
-    newOffsets
-  }
-}
-
-class MutableAggregationBuffer(
-    toCatalystConverters: Array[Any => Any],
-    toScalaConverters: Array[Any => Any],
-    bufferOffset: Int,
-    var underlyingBuffer: MutableRow)
-  extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
-
-  override def apply(i: Int): Any = {
-    if (i >= length || i < 0) {
-      throw new IllegalArgumentException(
-        s"Could not access ${i}th value in this buffer because it only has $length values.")
-    }
-    toScalaConverters(i)(underlyingBuffer(offsets(i)))
-  }
-
-  def update(i: Int, value: Any): Unit = {
-    if (i >= length || i < 0) {
-      throw new IllegalArgumentException(
-        s"Could not update ${i}th value in this buffer because it only has $length values.")
-    }
-    underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
-  }
-
-  override def copy(): MutableAggregationBuffer = {
-    new MutableAggregationBuffer(
-      toCatalystConverters,
-      toScalaConverters,
-      bufferOffset,
-      underlyingBuffer)
-  }
-}
-
-class InputAggregationBuffer(
-    toCatalystConverters: Array[Any => Any],
-    toScalaConverters: Array[Any => Any],
-    bufferOffset: Int,
-    var underlyingInputBuffer: Row)
-  extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
-
-  override def apply(i: Int): Any = {
-    if (i >= length || i < 0) {
-      throw new IllegalArgumentException(
-        s"Could not access ${i}th value in this buffer because it only has $length values.")
-    }
-    toScalaConverters(i)(underlyingInputBuffer(offsets(i)))
-  }
-
-  override def copy(): InputAggregationBuffer = {
-    new InputAggregationBuffer(
-      toCatalystConverters,
-      toScalaConverters,
-      bufferOffset,
-      underlyingInputBuffer)
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
index 0e3d185403eba..97dbb2129cf44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.expressions.{Expression}
-import org.apache.spark.sql.execution.expressions.aggregate2.{ScalaUDAF, UserDefinedAggregateFunction}
+import org.apache.spark.sql.expressions.aggregate2.{ScalaUDAF, UserDefinedAggregateFunction}
 
 class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/aggregate2/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/aggregate2/udaf.scala
deleted file mode 100644
index 3404fa50dbca5..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/aggregate2/udaf.scala
+++ /dev/null
@@ -1,149 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.expressions.aggregate2
-
-import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
-import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate2.{InputAggregationBuffer, AggregateFunction2, MutableAggregationBuffer}
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.Row
-
-abstract class UserDefinedAggregateFunction extends Serializable {
-
-  def inputDataType: StructType
-
-  def bufferSchema: StructType
-
-  def returnDataType: DataType
-
-  def deterministic: Boolean
-
-  def initialize(buffer: MutableAggregationBuffer): Unit
-
-  def update(buffer: MutableAggregationBuffer, input: Row): Unit
-
-  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
-
-  def evaluate(buffer: Row): Any
-
-}
-
-case class ScalaUDAF(
-    children: Seq[Expression],
-    udaf: UserDefinedAggregateFunction)
-  extends AggregateFunction2 with ImplicitCastInputTypes with Logging {
-
-  require(
-    children.length == udaf.inputDataType.length,
-    s"$udaf only accepts ${udaf.inputDataType.length} arguments, " +
-      s"but ${children.length} are provided.")
-
-  override def nullable: Boolean = true
-
-  override def dataType: DataType = udaf.returnDataType
-
-  override def deterministic: Boolean = udaf.deterministic
-
-  override val inputTypes: Seq[DataType] = udaf.inputDataType.map(_.dataType)
-
-  override val bufferSchema: StructType = udaf.bufferSchema
-
-  override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes
-
-  override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
-
-  val childrenSchema: StructType = {
-    val inputFields = children.zipWithIndex.map {
-      case (child, index) =>
-        StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
-    }
-    StructType(inputFields)
-  }
-
-  lazy val inputProjection = {
-    val inputAttributes = childrenSchema.toAttributes
-    log.debug(
-      s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
-    try {
-      GenerateMutableProjection.generate(children, inputAttributes)()
-    } catch {
-      case e: Exception =>
-        log.error("Failed to generate mutable projection, fallback to interpreted", e)
-        new InterpretedMutableProjection(children, inputAttributes)
-    }
-  }
-
-  val inputToScalaConverters: Any => Any =
-    CatalystTypeConverters.createToScalaConverter(childrenSchema)
-
-  val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
-    CatalystTypeConverters.createToCatalystConverter(field.dataType)
-  }
-
-  val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
-    CatalystTypeConverters.createToScalaConverter(field.dataType)
-  }
-
-  lazy val inputAggregateBuffer: InputAggregationBuffer =
-    new InputAggregationBuffer(
-      bufferValuesToCatalystConverters,
-      bufferValuesToScalaConverters,
-      bufferOffset,
-      null)
-
-  lazy val mutableAggregateBuffer: MutableAggregationBuffer =
-    new MutableAggregationBuffer(
-      bufferValuesToCatalystConverters,
-      bufferValuesToScalaConverters,
-      bufferOffset,
-      null)
-
-
-  override def initialize(buffer: MutableRow): Unit = {
-    mutableAggregateBuffer.underlyingBuffer = buffer
-
-    udaf.initialize(mutableAggregateBuffer)
-  }
-
-  override def update(buffer: MutableRow, input: InternalRow): Unit = {
-    mutableAggregateBuffer.underlyingBuffer = buffer
-
-    udaf.update(
-      mutableAggregateBuffer,
-      inputToScalaConverters(inputProjection(input)).asInstanceOf[Row])
-  }
-
-  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
-    mutableAggregateBuffer.underlyingBuffer = buffer1
-    inputAggregateBuffer.underlyingInputBuffer = buffer2
-
-    udaf.update(mutableAggregateBuffer, inputAggregateBuffer)
-  }
-
-  override def eval(buffer: InternalRow = null): Any = {
-    inputAggregateBuffer.underlyingInputBuffer = buffer
-
-    udaf.evaluate(inputAggregateBuffer)
-  }
-
-  override def toString: String = {
-    s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate2/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate2/udaf.scala
new file mode 100644
index 0000000000000..8fa70d0f96359
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate2/udaf.scala
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.aggregate2
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate2.AggregateFunction2
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.Row
+
+/**
+ * The abstract class for implementing user-defined aggregate functions.
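+ *
+ * For example, a hypothetical UDAF that sums a single [[LongType]] column could be
+ * implemented against this interface as follows (the `LongSum` name and its field
+ * names are only illustrative):
+ *
+ * ```
+ *   class LongSum extends UserDefinedAggregateFunction {
+ *     def inputSchema: StructType = StructType(StructField("value", LongType) :: Nil)
+ *
+ *     def bufferSchema: StructType = StructType(StructField("sum", LongType) :: Nil)
+ *
+ *     def returnDataType: DataType = LongType
+ *
+ *     def deterministic: Boolean = true
+ *
+ *     // Start from zero; the buffer holds the running sum in slot 0.
+ *     def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, 0L)
+ *
+ *     // Fold one input row into the running sum, skipping null inputs.
+ *     def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+ *       if (!input.isNullAt(0)) {
+ *         buffer.update(0, buffer.getLong(0) + input.getLong(0))
+ *       }
+ *     }
+ *
+ *     // Combine two partial sums produced by different partitions.
+ *     def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+ *       buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
+ *     }
+ *
+ *     def evaluate(buffer: Row): Any = buffer.getLong(0)
+ *   }
+ * ```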
+ */
+abstract class UserDefinedAggregateFunction extends Serializable {
+
+  /**
+   * A [[StructType]] that represents the data types of the input arguments of this aggregate
+   * function. For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
+   * with types [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
+   *
+   * ```
+   *   StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType)))
+   * ```
+   *
+   * The name of a field of this [[StructType]] is only used to identify the corresponding
+   * input argument. Users can choose names to identify the input arguments.
+   */
+  def inputSchema: StructType
+
+  /**
+   * A [[StructType]] that represents the data types of the values in the aggregation buffer.
+   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
+   * (i.e. two intermediate values) with types [[DoubleType]] and [[LongType]],
+   * the returned [[StructType]] will look like
+   *
+   * ```
+   *   StructType(Seq(StructField("doubleBuffer", DoubleType), StructField("longBuffer", LongType)))
+   * ```
+   *
+   * The name of a field of this [[StructType]] is only used to identify the corresponding
+   * buffer value. Users can choose names to identify the buffer values.
+   */
+  def bufferSchema: StructType
+
+  /**
+   * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
+   */
+  def returnDataType: DataType
+
+  /** Indicates if this function is deterministic. */
+  def deterministic: Boolean
+
+  /** Initializes the given aggregation buffer. */
+  def initialize(buffer: MutableAggregationBuffer): Unit
+
+  /** Updates the given aggregation buffer `buffer` with new input data from `input`. */
+  def update(buffer: MutableAggregationBuffer, input: Row): Unit
+
+  /** Merges two aggregation buffers and stores the updated buffer values back in `buffer1`. */
+  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
+
+  /**
+   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
+   * aggregation buffer.
+   */
+  def evaluate(buffer: Row): Any
+}
+
+private[sql] abstract class AggregationBuffer(
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int)
+  extends Row {
+
+  override def length: Int = toCatalystConverters.length
+
+  protected val offsets: Array[Int] = {
+    val newOffsets = new Array[Int](length)
+    var i = 0
+    while (i < newOffsets.length) {
+      newOffsets(i) = bufferOffset + i
+      i += 1
+    }
+    newOffsets
+  }
+}
+
+/**
+ * A mutable [[Row]] representing an aggregation buffer.
+ */
+class MutableAggregationBuffer private[sql] (
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int,
+    var underlyingBuffer: MutableRow)
+  extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
+
+  override def apply(i: Int): Any = {
+    if (i >= length || i < 0) {
+      throw new IllegalArgumentException(
+        s"Could not access ${i}th value in this buffer because it only has $length values.")
+    }
+    toScalaConverters(i)(underlyingBuffer(offsets(i)))
+  }
+
+  def update(i: Int, value: Any): Unit = {
+    if (i >= length || i < 0) {
+      throw new IllegalArgumentException(
+        s"Could not update ${i}th value in this buffer because it only has $length values.")
+    }
+    underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
+  }
+
+  override def copy(): MutableAggregationBuffer = {
+    new MutableAggregationBuffer(
+      toCatalystConverters,
+      toScalaConverters,
+      bufferOffset,
+      underlyingBuffer)
+  }
+}
+
+/**
+ * A [[Row]] representing an immutable aggregation buffer.
+ */
+class InputAggregationBuffer private[sql] (
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int,
+    var underlyingInputBuffer: Row)
+  extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
+
+  override def apply(i: Int): Any = {
+    if (i >= length || i < 0) {
+      throw new IllegalArgumentException(
+        s"Could not access ${i}th value in this buffer because it only has $length values.")
+    }
+    toScalaConverters(i)(underlyingInputBuffer(offsets(i)))
+  }
+
+  override def copy(): InputAggregationBuffer = {
+    new InputAggregationBuffer(
+      toCatalystConverters,
+      toScalaConverters,
+      bufferOffset,
+      underlyingInputBuffer)
+  }
+}
+
+/**
+ * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` into the
+ * internal aggregation code path.
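+ *
+ * For intuition: when multiple aggregate functions share one underlying `MutableRow`,
+ * each function only sees its own slice of that row through the aggregation buffers
+ * above. For example, with a hypothetical `bufferOffset` of 2 and a two-field
+ * `bufferSchema`, logical index `i` maps to physical slot `bufferOffset + i`:
+ *
+ * ```
+ *   underlying MutableRow: [ slot0 | slot1 | sum | count ]
+ *   buffer(0) -> slot 2 (sum)
+ *   buffer(1) -> slot 3 (count)
+ * ```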
+ *
+ * @param children the input expressions of this UDAF.
+ * @param udaf the user-defined aggregate function.
+ */
+case class ScalaUDAF(
+    children: Seq[Expression],
+    udaf: UserDefinedAggregateFunction)
+  extends AggregateFunction2 with ImplicitCastInputTypes with Logging {
+
+  require(
+    children.length == udaf.inputSchema.length,
+    s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
+      s"but ${children.length} are provided.")
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = udaf.returnDataType
+
+  override def deterministic: Boolean = udaf.deterministic
+
+  override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType)
+
+  override val bufferSchema: StructType = udaf.bufferSchema
+
+  override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes
+
+  override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
+
+  val childrenSchema: StructType = {
+    val inputFields = children.zipWithIndex.map {
+      case (child, index) =>
+        StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
+    }
+    StructType(inputFields)
+  }
+
+  lazy val inputProjection = {
+    val inputAttributes = childrenSchema.toAttributes
+    log.debug(
+      s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
+    try {
+      GenerateMutableProjection.generate(children, inputAttributes)()
+    } catch {
+      case e: Exception =>
+        log.error("Failed to generate mutable projection, fallback to interpreted", e)
+        new InterpretedMutableProjection(children, inputAttributes)
+    }
+  }
+
+  val inputToScalaConverters: Any => Any =
+    CatalystTypeConverters.createToScalaConverter(childrenSchema)
+
+  val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+    CatalystTypeConverters.createToCatalystConverter(field.dataType)
+  }
+
+  val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+    CatalystTypeConverters.createToScalaConverter(field.dataType)
+  }
+
+  lazy val inputAggregateBuffer: InputAggregationBuffer =
+    new InputAggregationBuffer(
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      bufferOffset,
+      null)
+
+  lazy val mutableAggregateBuffer: MutableAggregationBuffer =
+    new MutableAggregationBuffer(
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      bufferOffset,
+      null)
+
+  override def initialize(buffer: MutableRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer
+
+    udaf.initialize(mutableAggregateBuffer)
+  }
+
+  override def update(buffer: MutableRow, input: InternalRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer
+
+    udaf.update(
+      mutableAggregateBuffer,
+      inputToScalaConverters(inputProjection(input)).asInstanceOf[Row])
+  }
+
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer1
+    inputAggregateBuffer.underlyingInputBuffer = buffer2
+
+    // Delegate to the UDAF's merge; buffer2 holds intermediate buffer values, not input rows.
+    udaf.merge(mutableAggregateBuffer, inputAggregateBuffer)
+  }
+
+  override def eval(buffer: InternalRow = null): Any = {
+    inputAggregateBuffer.underlyingInputBuffer = buffer
+
+    udaf.evaluate(inputAggregateBuffer)
+  }
+
+  override def toString: String = {
+    s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
+  }
+}
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate2/MyJavaUDAF.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate2/MyJavaUDAF.java
index bd089ea587d3c..6d13774457929 100644
--- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate2/MyJavaUDAF.java
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate2/MyJavaUDAF.java
@@ -20,8 +20,8 @@
 import java.util.ArrayList;
 import java.util.List;
 
-import org.apache.spark.sql.catalyst.expressions.aggregate2.MutableAggregationBuffer;
-import org.apache.spark.sql.execution.expressions.aggregate2.UserDefinedAggregateFunction;
+import org.apache.spark.sql.expressions.aggregate2.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.aggregate2.UserDefinedAggregateFunction;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.types.DataType;
@@ -48,7 +48,7 @@ public MyJavaUDAF() {
     _returnDataType = DataTypes.DoubleType;
   }
 
-  @Override public StructType inputDataType() {
+  @Override public StructType inputSchema() {
     return _inputDataType;
   }
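
To see how the pieces fit together, here is a minimal end-to-end sketch in Scala. It assumes the hypothetical `LongSum` UDAF from the scaladoc example above, that `UDAFRegistration` (only its import change is visible in this patch) exposes a `register(name, udaf)` method reachable as `sqlContext.udaf`, and an illustrative table `t` with a `LongType` column `value`; none of these names are fixed by the patch itself.

```scala
// Register the hypothetical LongSum UDAF under a SQL-callable name, then
// invoke it from a query. `sqlContext.udaf.register` is assumed here for
// illustration; the diff above only shows UDAFRegistration's new import path.
sqlContext.udaf.register("long_sum", new LongSum)
sqlContext.sql("SELECT long_sum(value) FROM t").show()
```

Internally, the registered function is wrapped in `ScalaUDAF`, which converts Catalyst rows to Scala `Row`s before calling the user's `update`, `merge`, and `evaluate`.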