From abecbf8ee3a45cae73d925b629531316955c8f8e Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 2 Oct 2017 16:49:17 +0900 Subject: [PATCH 1/7] Add ArrayDataBuffer. --- .../sql/catalyst/util/ArrayDataBuffer.scala | 174 ++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayDataBuffer.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayDataBuffer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayDataBuffer.scala new file mode 100644 index 0000000000000..6e4781637ae71 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayDataBuffer.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.{DataType, Decimal} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +class ArrayDataBuffer(val buffer: ArrayBuffer[Any]) extends ArrayData { + + def this(initialCapacity: Int) = this(new ArrayBuffer[Any](initialCapacity)) + def this() = this(new ArrayBuffer[Any]()) + + override def copy(): ArrayData = { + val newValues = new ArrayBuffer[Any](buffer.length) + var i = 0 + while (i < buffer.length) { + newValues(i) = InternalRow.copyValue(buffer(i)) + i += 1 + } + new ArrayDataBuffer(newValues) + } + + override def array: Array[Any] = { + val newValues = new Array[Any](buffer.length) + var i = 0 + while (i < buffer.length) { + newValues(i) = InternalRow.copyValue(buffer(i)) + i += 1 + } + newValues + } + + override def numElements(): Int = buffer.length + + private def getAs[T](ordinal: Int) = buffer(ordinal).asInstanceOf[T] + override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null + override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal) + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + override def getByte(ordinal: Int): Byte = getAs(ordinal) + override def getShort(ordinal: Int): Short = getAs(ordinal) + override def getInt(ordinal: Int): Int = getAs(ordinal) + override def getLong(ordinal: Int): Long = getAs(ordinal) + override def getFloat(ordinal: Int): Float = getAs(ordinal) + override def getDouble(ordinal: Int): Double = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getStruct(ordinal: Int, numFields: Int): 
InternalRow = getAs(ordinal) + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + override def getMap(ordinal: Int): MapData = getAs(ordinal) + + override def setNullAt(ordinal: Int): Unit = buffer(ordinal) = null + + override def update(ordinal: Int, value: Any): Unit = buffer(ordinal) = value + + def +=(value: Any): this.type = { + buffer += value + this + } + + def ++=(values: TraversableOnce[Any]): this.type = { + buffer ++= values + this + } + + def ++=(values: ArrayData): this.type = { + values match { + case buff: ArrayDataBuffer => buffer ++= buff.buffer + case _ => buffer ++= values.array + } + this + } + + def clear(): Unit = { + buffer.clear() + } + + override def toString(): String = buffer.mkString("[", ",", "]") + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[ArrayDataBuffer]) { + return false + } + + val other = o.asInstanceOf[ArrayDataBuffer] + if (other eq null) { + return false + } + + val len = numElements() + if (len != other.numElements()) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = buffer(i) + val o2 = other.buffer(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + val len = numElements() + while (i < len) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + buffer(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } +} From 1fdf77dc3a6b11c5d3eba2655c1b43388fed351e Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Tue, 19 Sep 2017 13:57:12 +0900 Subject: [PATCH 2/7] Add BufferInput. --- .../spark/sql/execution/python/udaf.scala | 127 ++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/udaf.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/udaf.scala new file mode 100644 index 0000000000000..73709593f3185 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/udaf.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.catalyst.util.{ArrayData, ArrayDataBuffer} +import org.apache.spark.sql.types._ + +case class BufferInputs( + children: Seq[Expression], + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends ImperativeAggregate + with NonSQLExpression + with Logging { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def nullable: Boolean = true + + override def dataType: DataType = aggBufferSchema + + override val aggBufferSchema: StructType = + StructType(children.zipWithIndex.map { + case (child, i) => + StructField(s"_$i", ArrayType(child.dataType, child.nullable), nullable = false) + }) + + override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. 
+ override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + private[this] lazy val childrenSchema: StructType = { + val inputFields = children.zipWithIndex.map { + case (child, index) => + StructField(s"input$index", child.dataType, child.nullable, Metadata.empty) + } + StructType(inputFields) + } + + private lazy val inputProjection = { + val inputAttributes = childrenSchema.toAttributes + log.debug( + s"Creating MutableProj: $children, inputSchema: $inputAttributes.") + GenerateMutableProjection.generate(children, inputAttributes) + } + + override def initialize(buffer: InternalRow): Unit = { + aggBufferSchema.zipWithIndex.foreach { case (_, i) => + buffer.update(i + mutableAggBufferOffset, new ArrayDataBuffer()) + } + } + + override def update(buffer: InternalRow, input: InternalRow): Unit = { + val projected = inputProjection(input) + aggBufferSchema.zip(childrenSchema).zipWithIndex.foreach { + case ((StructField(_, dt @ ArrayType(_, _), _, _), childSchema), i) => + val bufferOffset = i + mutableAggBufferOffset + val arrayDataBuffer = + buffer.get(bufferOffset, dt).asInstanceOf[ArrayDataBuffer] + if (projected.isNullAt(i)) { + arrayDataBuffer += null + } else { + arrayDataBuffer += InternalRow.copyValue(projected.get(i, childSchema.dataType)) + } + } + } + + override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { + aggBufferSchema.zipWithIndex.foreach { + case (StructField(_, dt @ ArrayType(elementType, _), _, _), i) => + val bufferOffset = i + mutableAggBufferOffset + val inputOffset = i + inputAggBufferOffset + val arrayDataBuffer1 = buffer1.get(bufferOffset, dt).asInstanceOf[ArrayDataBuffer] + buffer2.get(inputOffset, dt) match { + case arrayDataBuffer2: UnsafeArrayData => + elementType match { + case BooleanType => arrayDataBuffer1 ++= arrayDataBuffer2.toBooleanArray() + case ByteType => arrayDataBuffer1 ++= arrayDataBuffer2.toByteArray() + case ShortType => arrayDataBuffer1 ++= arrayDataBuffer2.toShortArray() + case IntegerType => arrayDataBuffer1 ++= arrayDataBuffer2.toIntArray() + case LongType => arrayDataBuffer1 ++= arrayDataBuffer2.toLongArray() + case FloatType => arrayDataBuffer1 ++= arrayDataBuffer2.toFloatArray() + case DoubleType => arrayDataBuffer1 ++= arrayDataBuffer2.toDoubleArray() + } + case arrayDataBuffer2: ArrayData => + arrayDataBuffer1 ++= arrayDataBuffer2 + } + } + } + + private val row = new GenericInternalRow(aggBufferSchema.size) + + override def eval(buffer: InternalRow): Any = { + aggBufferSchema.zipWithIndex.foreach { case (buffSchema, i) => + val bufferOffset = i + mutableAggBufferOffset + row.update(i, buffer.get(bufferOffset, buffSchema.dataType)) + } + row + } +} From f3f89a9b448b3cd92307bf664f126491e93875e8 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 2 Oct 2017 16:36:27 +0900 Subject: [PATCH 3/7] Add UserDefinedAggregatePythonFunction. 
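UserDefinedAggregatePythonFunction bundles the two Python callables of a UDAF (the
partial function that builds the aggregation buffer and the final evaluate function)
together with their data types, and PythonUDAF is the AggregateFunction wrapper the
planner sees. PythonUDAF wraps every field of the user-supplied buffer type in an
ArrayType, so all values belonging to one group can be accumulated into arrays and
handed to the Python worker as a single Arrow batch. A rough sketch of the resulting
buffer schema for a UDAF registered under the hypothetical name "my_avg" with a
(sum, count) buffer:

    from pyspark.sql.types import ArrayType, LongType, StructField, StructType

    buffer_type = StructType().add("sum", LongType()).add("count", LongType())

    # PythonUDAF.aggBufferSchema turns the buffer type into (roughly):
    agg_buffer_schema = StructType([
        StructField("my_avg.sum", ArrayType(LongType(), containsNull=True), nullable=False),
        StructField("my_avg.count", ArrayType(LongType(), containsNull=True), nullable=False),
    ])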
--- .../spark/sql/execution/python/udaf.scala | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/udaf.scala index 73709593f3185..f7b2b35d48d15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/udaf.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution.python +import org.apache.spark.api.python.PythonFunction import org.apache.spark.internal.Logging +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -25,6 +27,68 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjecti import org.apache.spark.sql.catalyst.util.{ArrayData, ArrayDataBuffer} import org.apache.spark.sql.types._ +/** + * A user-defined aggregate Python function. This is used by the Python API. + */ +case class UserDefinedAggregatePythonFunction( + name: String, + evaluate: PythonFunction, + returnType: DataType, + partial: PythonFunction, + bufferType: DataType) { + + /** + * Creates a `Column` for this UDAF using given `Column`s as input arguments. + */ + @scala.annotation.varargs + def apply(exprs: Column*): Column = { + val aggregateExpression = + AggregateExpression( + PythonUDAF(exprs.map(_.expr), this), + Complete, + isDistinct = false) + Column(aggregateExpression) + } +} + +/** + * The internal wrapper used to hook a [[UserDefinedAggregatePythonFunction]] `udaf` in the + * internal aggregation code path. + */ +case class PythonUDAF( + children: Seq[Expression], + udaf: UserDefinedAggregatePythonFunction) + extends AggregateFunction + with Unevaluable + with NonSQLExpression + with UserDefinedExpression { + + override def nullable: Boolean = true + + override def dataType: DataType = udaf.returnType + + override val aggBufferSchema: StructType = udaf.bufferType match { + case StructType(fields) => StructType(fields.map { field => + StructField(s"${udaf.name}.${field.name}", + ArrayType(field.dataType, containsNull = field.nullable), nullable = false) + }) + case dt => new StructType().add(udaf.name, ArrayType(dt, containsNull = true), nullable = false) + } + + override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + override def toString: String = { + s"${udaf.name}(${children.mkString(",")})" + } + + override def nodeName: String = udaf.name +} + case class BufferInputs( children: Seq[Expression], mutableAggBufferOffset: Int = 0, From 799c82cb4be0d522f1f02410c686df727c0a5a2d Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 2 Oct 2017 17:58:19 +0900 Subject: [PATCH 4/7] Add pandas_udaf decorator. 
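pandas_udaf mirrors pandas_udf but returns an aggregate Column. With algebraic=True
the same function is reused for both the partial and the final step and the buffer
type defaults to the return type; otherwise an explicit partial function and
bufferType must be supplied. A usage sketch based on the tests added later in this
series (my_sum and my_avg are hypothetical names):

    from pyspark.sql.functions import col, pandas_udaf
    from pyspark.sql.types import DoubleType, LongType, StructType

    @pandas_udaf(LongType(), algebraic=True)
    def my_sum(v):
        return v.sum()

    @pandas_udaf(DoubleType(), algebraic=False,
                 partial=lambda v: (v.sum(), v.count()),
                 bufferType=StructType().add("sum", LongType()).add("count", LongType()))
    def my_avg(sum, count):
        return sum.sum() / count.sum()

    # df.groupBy(col('g')).agg(my_sum(col('n')), my_avg(col('n')))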
--- python/pyspark/sql/functions.py | 36 ++++++++++- python/pyspark/sql/udf.py | 104 ++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b631e2041706f..b2a896cbaa1d8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -32,7 +32,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import StringType, DataType -from pyspark.sql.udf import UserDefinedFunction, _create_udf +from pyspark.sql.udf import UserDefinedFunction, UserDefinedAggregateFunction, _create_udf def _create_function(name, doc=""): @@ -2241,6 +2241,40 @@ def pandas_udf(f=None, returnType=None, functionType=None): return _create_udf(f=f, returnType=return_type, evalType=eval_type) +# ---------------------------- User Defined Aggregate Function ---------------------------------- + +def pandas_udaf(final=None, returnType=StringType(), algebraic=False, partial=None, + bufferType=None): + """ + Creates a :class:`Column` expression representing a vectorized user defined aggregate + function (UDAF). + """ + def _udaf(final, returnType, algebraic, partial, bufferType): + if algebraic: + partial = partial or final + bufferType = bufferType or returnType + else: + if partial is None or bufferType is None: + raise ValueError( + "If not algebraic, partial and bufferType must be defined.") + udaf_obj = UserDefinedAggregateFunction(final, returnType, partial, bufferType) + return udaf_obj._wrapped() + + # decorator @pandas_udaf, @pandas_udaf() or @pandas_udaf(dataType()) + if final is None or isinstance(final, (str, DataType)): + # If DataType has been passed as a positional argument + # for decorator use it as a returnType + if isinstance(returnType, bool): + algebraic = returnType + returnType = StringType() + return_type = final or returnType + return functools.partial(_udaf, returnType=return_type, algebraic=algebraic, + partial=partial, bufferType=bufferType) + else: + return _udaf(final=final, returnType=returnType, algebraic=algebraic, + partial=partial, bufferType=bufferType) + + blacklist = ['map', 'since', 'ignore_unicode_prefix'] __all__ = [k for k, v in globals().items() if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist] diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index c3301a41ccd5a..30d86439c7de8 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -159,3 +159,107 @@ def wrapper(*args): wrapper.evalType = self.evalType return wrapper + + +class UserDefinedAggregateFunction(object): + """ + User defined aggregate function in Python + + .. versionadded:: 2.3 + """ + def __init__(self, final, returnType, partial, bufferType, name=None): + for f in [final, partial]: + if not callable(f): + raise TypeError( + "Not a function or callable (__call__ is not defined): " + "{0}".format(type(f))) + + self.final = final + self._returnType = returnType + self.partial = partial + self._bufferType = bufferType + # Stores UserDefinedPythonFunctions jobj, once initialized + self._returnType_placeholder = None + self._bufferType_placeholder = None + self._judaf_placeholder = None + self._name = name or ( + final.__name__ if hasattr(final, '__name__') + else final.__class__.__name__) + + @property + def returnType(self): + # This makes sure this is called after SparkContext is initialized. 
+ # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string. + if self._returnType_placeholder is None: + if isinstance(self._returnType, DataType): + self._returnType_placeholder = self._returnType + else: + self._returnType_placeholder = _parse_datatype_string(self._returnType) + return self._returnType_placeholder + + @property + def bufferType(self): + # This makes sure this is called after SparkContext is initialized. + # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string. + if self._bufferType_placeholder is None: + if isinstance(self._bufferType, DataType): + self._bufferType_placeholder = self._bufferType + else: + self._bufferType_placeholder = _parse_datatype_string(self._bufferType) + return self._bufferType_placeholder + + @property + def _judaf(self): + # It is possible that concurrent access, to newly created UDF, + # will initialize multiple UserDefinedPythonFunctions. + # This is unlikely, doesn't affect correctness, + # and should have a minimal performance impact. + if self._judaf_placeholder is None: + self._judaf_placeholder = self._create_judaf() + return self._judaf_placeholder + + def _create_judaf(self): + from pyspark.sql import SparkSession + + spark = SparkSession.builder.getOrCreate() + sc = spark.sparkContext + + wrapped_final = _wrap_function(sc, self.final, self.returnType) + wrapped_partial = _wrap_function(sc, self.partial, self.bufferType) + jdt_final = spark._jsparkSession.parseDataType(self.returnType.json()) + jdt_partial = spark._jsparkSession.parseDataType(self.bufferType.json()) + judaf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedAggregatePythonFunction( + self._name, wrapped_final, jdt_final, wrapped_partial, jdt_partial) + return judaf + + def __call__(self, *cols): + judaf = self._judaf + sc = SparkContext._active_spark_context + return Column(judaf.apply(_to_seq(sc, cols, _to_java_column))) + + def _wrapped(self): + """ + Wrap this udf with a function and attach docstring from func + """ + + # It is possible for a callable instance without __name__ attribute or/and + # __module__ attribute to be wrapped here. For example, functools.partial. In this case, + # we should avoid wrapping the attributes from the wrapped function to the wrapper + # function. So, we take out these attribute names from the default names to set and + # then manually assign it after being wrapped. + assignments = tuple( + a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__') + + @functools.wraps(self.final, assigned=assignments) + def wrapper(*args): + return self(*args) + + wrapper.__name__ = self._name + wrapper.__module__ = (self.final.__module__ if hasattr(self.final, '__module__') + else self.final.__class__.__module__) + wrapper.final = self.final + wrapper.returnType = self.returnType + wrapper.partial = self.partial + wrapper.bufferType = self.bufferType + + return wrapper From e7bcb64d25914367f8b3b1e3d040094e059237ac Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 2 Oct 2017 17:55:41 +0900 Subject: [PATCH 5/7] Add ExtractPythonUDAFs. 
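ExtractPythonUDAFs rewrites physical aggregates that contain a PythonUDAF. In
Partial mode the UDAF is replaced by BufferInputs, which collects each group's raw
input values into arrays; an ArrowEvalPythonExec (batch size 1, i.e. one group per
Arrow batch) then runs the Python partial function over those arrays, and a
ProjectExec wraps each result back into a single-element array so it matches the
expected buffer layout. In Final mode the arrays of partial results are collected
again and the Python evaluate function produces the final value. A rough sketch, in
plain pandas, of what the two Python functions see for the average UDAF from the
tests (the group contents are made up):

    import pandas as pd

    # Partial step: per group, each child expression arrives as one Series
    # holding the group's collected values.
    values = pd.Series([0, 2, 4, 6])              # hypothetical group contents
    partial_out = (values.sum(), values.count())  # -> (12, 4), matches bufferType

    # Final step: each buffer field arrives as a Series of partial results,
    # one element per partial buffer merged for the group.
    sums, counts = pd.Series([12, 30]), pd.Series([4, 5])
    final_out = sums.sum() / counts.sum()         # -> 42 / 9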
--- .../spark/api/python/PythonRunner.scala | 1 + .../spark/sql/execution/QueryExecution.scala | 1 + .../sql/execution/aggregate/AggUtils.scala | 2 +- .../execution/aggregate/AggregateExec.scala | 37 ++++ .../aggregate/HashAggregateExec.scala | 2 +- .../aggregate/ObjectHashAggregateExec.scala | 2 +- .../aggregate/SortAggregateExec.scala | 2 +- .../python/ArrowEvalPythonExec.scala | 12 +- .../execution/python/ExtractPythonUDAFs.scala | 172 ++++++++++++++++++ 9 files changed, 223 insertions(+), 8 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDAFs.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index f524de68fbce0..1492b340b9b06 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -39,6 +39,7 @@ private[spark] object PythonEvalType { val SQL_PANDAS_SCALAR_UDF = 200 val SQL_PANDAS_GROUP_MAP_UDF = 201 + val SQL_PANDAS_GROUP_AGGREGATE_UDF = 202 } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index f404621399cea..5f2cf7d3cd189 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -103,6 +103,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { /** A sequence of rules that will be applied in order to the physical plan before execution. */ protected def preparations: Seq[Rule[SparkPlan]] = Seq( python.ExtractPythonUDFs, + python.ExtractPythonUDAFs, PlanSubqueries(sparkSession), new ReorderJoinPredicates, EnsureRequirements(sparkSession.sessionState.conf), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index ebbdf1aaa024d..34a259f0579a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.internal.SQLConf * Utility functions used by the query planner to convert our plan to new aggregation code path. */ object AggUtils { - private def createAggregate( + private[sql] def createAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]] = None, groupingExpressions: Seq[NamedExpression] = Nil, aggregateExpressions: Seq[AggregateExpression] = Nil, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala new file mode 100644 index 0000000000000..a94ebef04a992 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.execution.UnaryExecNode + +trait AggregateExec extends UnaryExecNode { + + def requiredChildDistributionExpressions: Option[Seq[Expression]] + + def groupingExpressions: Seq[NamedExpression] + + def aggregateExpressions: Seq[AggregateExpression] + + def aggregateAttributes: Seq[Attribute] + + def initialInputBufferOffset: Int + + def resultExpressions: Seq[NamedExpression] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 51f7c9e22b902..9d3042f808e8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -43,7 +43,7 @@ case class HashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with CodegenSupport { + extends UnaryExecNode with AggregateExec with CodegenSupport { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 66955b8ef723c..c95577795bd87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -65,7 +65,7 @@ case class ObjectHashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode { + extends UnaryExecNode with AggregateExec { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index fc87de2c52e41..76cb3de412de2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -38,7 +38,7 @@ case class SortAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode { + extends UnaryExecNode with AggregateExec { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index e27210117a1e7..2466de45b92f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -58,10 +58,15 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int) /** * A physical plan that evaluates a [[PythonUDF]], */ -case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) +case class ArrowEvalPythonExec( + udfs: Seq[PythonUDF], + output: Seq[Attribute], + child: SparkPlan, + evalType: Int = PythonEvalType.SQL_PANDAS_SCALAR_UDF, + _batchSize: Option[Int] = None) extends EvalPythonExec(udfs, output, child) { - private val batchSize = conf.arrowMaxRecordsPerBatch + private val batchSize = _batchSize.getOrElse(conf.arrowMaxRecordsPerBatch) private val sessionLocalTimeZone = conf.sessionLocalTimeZone protected override def evaluate( @@ -80,8 +85,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) val columnarBatchIter = new ArrowPythonRunner( - funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_SCALAR_UDF, argOffsets, schema, sessionLocalTimeZone) + funcs, bufferSize, reuseWorker, evalType, argOffsets, schema, sessionLocalTimeZone) .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDAFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDAFs.scala new file mode 100644 index 0000000000000..db6ba48e6484d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDAFs.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import scala.collection.mutable.{ArrayBuffer, Map} + +import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.aggregate.{AggregateExec, AggUtils} +import org.apache.spark.sql.types.{ArrayType, StructType} + +object ExtractPythonUDAFs extends Rule[SparkPlan] { + + private def isPythonUDAF(aggregateExpression: AggregateExpression): Boolean = { + aggregateExpression.aggregateFunction.isInstanceOf[PythonUDAF] + } + + private def hasPythonUDAF(aggregateExpressions: Seq[AggregateExpression]): Boolean = { + aggregateExpressions.exists(isPythonUDAF) + } + + private def hasDistinct(aggregateExpressions: Seq[AggregateExpression]): Boolean = { + aggregateExpressions.exists(_.isDistinct) + } + + override def apply(plan: SparkPlan): SparkPlan = plan transformUp { + case agg: AggregateExec if !hasPythonUDAF(agg.aggregateExpressions) => agg + case agg: AggregateExec if hasDistinct(agg.aggregateExpressions) => + throw new AnalysisException("Vectorized UDAF with distinct is not supported.") + case agg: AggregateExec => + + val newAggExprs = ArrayBuffer.empty[AggregateExpression] ++ agg.aggregateExpressions + val newAggAttrs = ArrayBuffer.empty[Attribute] ++ agg.aggregateAttributes + + val buffers = ArrayBuffer.empty[BufferInputs] + val udafs = ArrayBuffer.empty[PythonUDF] + val udafResultAttrs = ArrayBuffer.empty[AttributeReference] + + val replacingReslutExprs = Map.empty[Expression, NamedExpression] ++ + agg.groupingExpressions.map(expr => expr -> expr.toAttribute) + + agg.aggregateExpressions.foreach { + case aggExpr if isPythonUDAF(aggExpr) => + val pythonUDAF = aggExpr.aggregateFunction.asInstanceOf[PythonUDAF] + + aggExpr.mode match { + case Partial => + val buffer = buffers.find { buf => + buf.children.length == pythonUDAF.children.length && + buf.children.zip(pythonUDAF.children).forall { case (c, child) => + c.semanticEquals(child) + } + } match { + case Some(buf) => + newAggExprs -= aggExpr + newAggAttrs --= pythonUDAF.aggBufferAttributes + + buf + case None => + val buf = BufferInputs(pythonUDAF.children) + buffers += buf + + newAggExprs.update( + newAggExprs.indexOf(aggExpr), aggExpr.copy(aggregateFunction = buf)) + + val index = newAggAttrs.indexOfSlice(pythonUDAF.aggBufferAttributes) + newAggAttrs --= pythonUDAF.aggBufferAttributes + newAggAttrs.insertAll(index, buf.aggBufferAttributes) + + buf + } + + val udaf = PythonUDF(pythonUDAF.udaf.name, pythonUDAF.udaf.partial, + pythonUDAF.udaf.bufferType, buffer.aggBufferAttributes, + PythonEvalType.SQL_PANDAS_GROUP_AGGREGATE_UDF) + udafs += udaf + + val (resultAttrs, replacingExprs) = pythonUDAF.inputAggBufferAttributes.map { attr => + val arrayType = attr.dataType.asInstanceOf[ArrayType] + val resultAttr = AttributeReference( + attr.name, arrayType.elementType, arrayType.containsNull)() + (resultAttr, attr -> Alias(CreateArray(Seq(resultAttr)), attr.name)()) + }.unzip + udafResultAttrs ++= resultAttrs + replacingReslutExprs ++= replacingExprs + + case Final => + val buffer = BufferInputs(pythonUDAF.inputAggBufferAttributes.map { attr => + val arrayType = attr.dataType.asInstanceOf[ArrayType] + AttributeReference(attr.name, arrayType.elementType, arrayType.containsNull)() + }) + + 
newAggExprs.update( + newAggExprs.indexOf(aggExpr), aggExpr.copy(aggregateFunction = buffer)) + + val bufferOut = AttributeReference("buffer", buffer.dataType, buffer.nullable)() + newAggAttrs.update(newAggAttrs.indexOf(aggExpr.resultAttribute), bufferOut) + + val udafInputs = buffer.dataType.asInstanceOf[StructType].zipWithIndex.map { + case (field, idx) => + GetStructField(bufferOut, idx, Option(field.name)) + } + val udaf = PythonUDF(pythonUDAF.udaf.name, pythonUDAF.udaf.evaluate, + pythonUDAF.udaf.returnType, udafInputs, + PythonEvalType.SQL_PANDAS_GROUP_AGGREGATE_UDF) + udafs += udaf + + val resultAttr = AttributeReference(udaf.name, udaf.dataType, udaf.nullable)() + udafResultAttrs += resultAttr + replacingReslutExprs += aggExpr.resultAttribute -> resultAttr + + case _ => + throw new AnalysisException(s"Unsupported aggregate mode: ${aggExpr.mode}.") + } + case aggExpr => + aggExpr.mode match { + case Partial => + val af = aggExpr.aggregateFunction + replacingReslutExprs ++= + af.inputAggBufferAttributes.zip(af.aggBufferAttributes).map { + case (attr, buffer) => + attr -> Alias(buffer, attr.name)( + attr.exprId, attr.qualifier, Option(attr.metadata)) + } + case _ => + } + } + + val newAgg = AggUtils.createAggregate( + requiredChildDistributionExpressions = agg.requiredChildDistributionExpressions, + groupingExpressions = agg.groupingExpressions, + aggregateExpressions = newAggExprs, + aggregateAttributes = newAggAttrs, + initialInputBufferOffset = agg.initialInputBufferOffset, + resultExpressions = agg.groupingExpressions ++ newAggAttrs, + child = agg.child) + + val exec = ArrowEvalPythonExec( + udafs, + newAgg.output ++ udafResultAttrs, + newAgg, + PythonEvalType.SQL_PANDAS_GROUP_AGGREGATE_UDF, + Some(1)) + + val project = agg.resultExpressions.map { expr => + expr.transformUp { + case expr if replacingReslutExprs.contains(expr) => replacingReslutExprs(expr) + }.asInstanceOf[NamedExpression] + } + + ProjectExec(project, exec) + } +} From 2b96946420c6fdca3f5fc379e8ab756388f88e30 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 2 Oct 2017 16:30:38 +0900 Subject: [PATCH 6/7] Modify read_udfs to support pandas_udaf. 
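On the worker side the new eval type reuses the Arrow stream serializer, but each
wrapped UDF returns a list of (pandas.Series, arrow type) pairs, one per field of
its return type, and the generated mapper concatenates those lists with
sum(..., []) so every buffer or result field becomes its own Arrow column. A rough
sketch of what wrap_pandas_group_aggregate_udf does for the hypothetical
(sum, count) partial function, given one Arrow batch carrying one group:

    import pandas as pd

    def partial_fn(v):
        return (v.sum(), v.count())

    # The input column is a length-1 Series whose single element is the list of
    # values collected for the group.
    arg = pd.Series([[0, 2, 4, 6]])
    out = partial_fn(pd.Series(arg[0]))   # -> (12, 4)
    # The wrapper then pairs each element with its Arrow type, roughly
    # [(pd.Series(12), int64), (pd.Series(4), int64)], before serialization.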
--- python/pyspark/rdd.py | 1 + python/pyspark/worker.py | 30 +++++++++++++++++++++++++++--- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 340bc3a6b7470..48a1a1b785c23 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -70,6 +70,7 @@ class PythonEvalType(object): SQL_PANDAS_SCALAR_UDF = 200 SQL_PANDAS_GROUP_MAP_UDF = 201 + SQL_PANDAS_GROUP_AGGREGATE_UDF = 202 def portable_hash(x): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 939643071943a..d97df2a7d3d93 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -33,7 +33,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer -from pyspark.sql.types import to_arrow_type +from pyspark.sql.types import to_arrow_type, StructType from pyspark import shuffle pickleSer = PickleSerializer() @@ -110,6 +110,24 @@ def wrapped(*series): return wrapped +def wrap_pandas_group_aggregate_udf(f, return_type): + import pandas as pd + if isinstance(return_type, StructType): + arrow_return_types = [to_arrow_type(field.dataType) for field in return_type] + else: + arrow_return_types = [to_arrow_type(return_type)] + + def fn(*args): + out = f(*[pd.Series(arg[0]) for arg in args]) + if not isinstance(out, (tuple, list)): + out = (out,) + assert len(out) == len(arrow_return_types), \ + 'Columns of tuple don\'t match return schema' + + return [(pd.Series(v), t) for v, t in zip(out, arrow_return_types)] + return fn + + def read_single_udf(pickleSer, infile, eval_type): num_arg = read_int(infile) arg_offsets = [read_int(infile) for i in range(num_arg)] @@ -126,6 +144,8 @@ def read_single_udf(pickleSer, infile, eval_type): return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type) elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGGREGATE_UDF: + return arg_offsets, wrap_pandas_group_aggregate_udf(row_func, return_type) else: return arg_offsets, wrap_udf(row_func, return_type) @@ -143,13 +163,17 @@ def read_udfs(pickleSer, infile, eval_type): # lambda a: (f0(a0), f1(a1, a2), f2(a3)) # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. - mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) + if eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGGREGATE_UDF: + mapper_str = "lambda a: sum([%s], [])" % (", ".join(call_udf)) + else: + mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \ - or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: + or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF \ + or eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGGREGATE_UDF: ser = ArrowStreamPandasSerializer() else: ser = BatchedSerializer(PickleSerializer(), 100) From 646b0a78dc71784713503e2e0676664facf8fbd0 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 2 Oct 2017 18:26:34 +0900 Subject: [PATCH 7/7] Add tests. 
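The test compares the vectorized p_sum and p_avg against the built-in sum, count
and avg over spark.range(100) grouped by parity. A quick pandas cross-check of the
values the test expects (not part of the test itself):

    import pandas as pd

    pdf = pd.DataFrame({"n": range(100)})
    pdf["g"] = pdf["n"] % 2 == 0
    print(pdf.groupby("g")["n"].agg(["sum", "count", "mean"]))
    #         sum  count  mean
    # g
    # False  2500     50  50.0
    # True   2450     50  49.0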
--- python/pyspark/sql/tests.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 762afe0d730f3..1b69ea4279508 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3849,6 +3849,30 @@ def test_unsupported_types(self): df.groupby('id').apply(f).collect() +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +class VectorizedUDAFTests(ReusedSQLTestCase): + + def test_vectorized_udaf_basic(self): + from pyspark.sql.functions import pandas_udaf, col, expr + df = self.spark.range(100).select(col('id').alias('n'), (col('id') % 2 == 0).alias('g')) + + @pandas_udaf(LongType(), algebraic=True) + def p_sum(v): + return v.sum() + + @pandas_udaf( + DoubleType(), + algebraic=False, + partial=lambda v: (v.sum(), v.count()), + bufferType=StructType().add("sum", LongType()).add("count", LongType())) + def p_avg(sum, count): + return (sum.sum() / count.sum()) + + res = df.groupBy(col('g')).agg(p_sum(col('n')), expr('count(n)'), p_avg(col('n'))) + expected = df.groupBy(col('g')).agg(expr('sum(n)'), expr('count(n)'), expr('avg(n)')) + self.assertEquals(expected.collect(), res.collect()) + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: