A prototype of vectorized UDAF No. 1. #2

Closed · wants to merge 7 commits
@@ -39,6 +39,7 @@ private[spark] object PythonEvalType {

val SQL_PANDAS_SCALAR_UDF = 200
val SQL_PANDAS_GROUP_MAP_UDF = 201
val SQL_PANDAS_GROUP_AGGREGATE_UDF = 202
}

/**
1 change: 1 addition & 0 deletions python/pyspark/rdd.py
@@ -70,6 +70,7 @@ class PythonEvalType(object):

SQL_PANDAS_SCALAR_UDF = 200
SQL_PANDAS_GROUP_MAP_UDF = 201
SQL_PANDAS_GROUP_AGGREGATE_UDF = 202


def portable_hash(x):
36 changes: 35 additions & 1 deletion python/pyspark/sql/functions.py
@@ -32,7 +32,7 @@
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import StringType, DataType
from pyspark.sql.udf import UserDefinedFunction, _create_udf
from pyspark.sql.udf import UserDefinedFunction, UserDefinedAggregateFunction, _create_udf


def _create_function(name, doc=""):
@@ -2241,6 +2241,40 @@ def pandas_udf(f=None, returnType=None, functionType=None):
return _create_udf(f=f, returnType=return_type, evalType=eval_type)


# ---------------------------- User Defined Aggregate Function ----------------------------------

def pandas_udaf(final=None, returnType=StringType(), algebraic=False, partial=None,
bufferType=None):
"""
    Creates a vectorized user defined aggregate function (UDAF). Calling the returned
    function with input columns produces a :class:`Column` expression for use in aggregations.
"""
def _udaf(final, returnType, algebraic, partial, bufferType):
if algebraic:
partial = partial or final
bufferType = bufferType or returnType
else:
if partial is None or bufferType is None:
raise ValueError(
"If not algebraic, partial and bufferType must be defined.")
udaf_obj = UserDefinedAggregateFunction(final, returnType, partial, bufferType)
return udaf_obj._wrapped()

# decorator @pandas_udaf, @pandas_udaf() or @pandas_udaf(dataType())
if final is None or isinstance(final, (str, DataType)):
# If DataType has been passed as a positional argument
# for decorator use it as a returnType
if isinstance(returnType, bool):
algebraic = returnType
returnType = StringType()
return_type = final or returnType
return functools.partial(_udaf, returnType=return_type, algebraic=algebraic,
partial=partial, bufferType=bufferType)
else:
return _udaf(final=final, returnType=returnType, algebraic=algebraic,
partial=partial, bufferType=bufferType)


blacklist = ['map', 'since', 'ignore_unicode_prefix']
__all__ = [k for k, v in globals().items()
if not k.startswith('_') and k[0].islower() and callable(v) and k not in blacklist]
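A minimal usage sketch of the decorator (assuming a running SparkSession and that the DDL type strings below parse the same way they do for pandas_udf). Besides DataType instances, returnType and bufferType may be DDL strings, since UserDefinedAggregateFunction falls back to _parse_datatype_string (see udf.py below):

```python
# Sketch only; the column names and DataFrame are illustrative, not part of this PR.
from pyspark.sql.functions import pandas_udaf, col
from pyspark.sql.types import LongType

@pandas_udaf(LongType(), algebraic=True)                  # positional DataType used as returnType
def p_sum(v):
    return v.sum()

@pandas_udaf(returnType='double', algebraic=False,        # DDL strings also accepted for the types
             partial=lambda v: (v.sum(), v.count()),
             bufferType='sum bigint, count bigint')
def p_avg(sum, count):
    return sum.sum() / count.sum()

# df.groupBy('g').agg(p_sum(col('n')), p_avg(col('n')))
```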
24 changes: 24 additions & 0 deletions python/pyspark/sql/tests.py
@@ -3849,6 +3849,30 @@ def test_unsupported_types(self):
df.groupby('id').apply(f).collect()


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class VectorizedUDAFTests(ReusedSQLTestCase):

def test_vectorized_udaf_basic(self):
from pyspark.sql.functions import pandas_udaf, col, expr
df = self.spark.range(100).select(col('id').alias('n'), (col('id') % 2 == 0).alias('g'))

@pandas_udaf(LongType(), algebraic=True)
def p_sum(v):
return v.sum()

@pandas_udaf(
DoubleType(),
algebraic=False,
partial=lambda v: (v.sum(), v.count()),
bufferType=StructType().add("sum", LongType()).add("count", LongType()))
def p_avg(sum, count):
return (sum.sum() / count.sum())

res = df.groupBy(col('g')).agg(p_sum(col('n')), expr('count(n)'), p_avg(col('n')))
expected = df.groupBy(col('g')).agg(expr('sum(n)'), expr('count(n)'), expr('avg(n)'))
        self.assertEqual(expected.collect(), res.collect())
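A note on the two contracts exercised above, as I read the prototype: an algebraic UDAF reuses the same function as both partial and final, so it must be re-applicable to its own partial results (a sum of sums); a non-algebraic UDAF declares an explicit buffer, and the final function receives the buffer columns as separate pandas Series. A rough plain-pandas illustration (no Spark, made-up per-partition values):

```python
import pandas as pd

# algebraic: p_sum runs on raw values per partition, then again on the partial sums
partial_sums = pd.Series([10, 35, 5])      # one partial sum per partition (made up)
total = partial_sums.sum()                 # 50; the same callable is reused as the final step

# non-algebraic: p_avg's final sees the buffer columns (sum, count) as separate Series
sums, counts = pd.Series([10, 35, 5]), pd.Series([4, 10, 6])
avg = sums.sum() / counts.sum()            # 2.5
```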


if __name__ == "__main__":
from pyspark.sql.tests import *
if xmlrunner:
104 changes: 104 additions & 0 deletions python/pyspark/sql/udf.py
@@ -159,3 +159,107 @@ def wrapper(*args):
wrapper.evalType = self.evalType

return wrapper


class UserDefinedAggregateFunction(object):
"""
User defined aggregate function in Python

.. versionadded:: 2.3
"""
def __init__(self, final, returnType, partial, bufferType, name=None):
for f in [final, partial]:
if not callable(f):
raise TypeError(
"Not a function or callable (__call__ is not defined): "
"{0}".format(type(f)))

self.final = final
self._returnType = returnType
self.partial = partial
self._bufferType = bufferType
        # Stores the parsed return/buffer DataTypes and the
        # UserDefinedAggregatePythonFunction jobj, once initialized
self._returnType_placeholder = None
self._bufferType_placeholder = None
self._judaf_placeholder = None
self._name = name or (
final.__name__ if hasattr(final, '__name__')
else final.__class__.__name__)

@property
def returnType(self):
# This makes sure this is called after SparkContext is initialized.
# ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string.
if self._returnType_placeholder is None:
if isinstance(self._returnType, DataType):
self._returnType_placeholder = self._returnType
else:
self._returnType_placeholder = _parse_datatype_string(self._returnType)
return self._returnType_placeholder

@property
def bufferType(self):
# This makes sure this is called after SparkContext is initialized.
# ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string.
if self._bufferType_placeholder is None:
if isinstance(self._bufferType, DataType):
self._bufferType_placeholder = self._bufferType
else:
self._bufferType_placeholder = _parse_datatype_string(self._bufferType)
return self._bufferType_placeholder

@property
def _judaf(self):
        # It is possible that concurrent access to a newly created UDAF
        # will initialize multiple UserDefinedAggregatePythonFunctions.
        # This is unlikely, does not affect correctness,
        # and should have a minimal performance impact.
if self._judaf_placeholder is None:
self._judaf_placeholder = self._create_judaf()
return self._judaf_placeholder

def _create_judaf(self):
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

wrapped_final = _wrap_function(sc, self.final, self.returnType)
wrapped_partial = _wrap_function(sc, self.partial, self.bufferType)
jdt_final = spark._jsparkSession.parseDataType(self.returnType.json())
jdt_partial = spark._jsparkSession.parseDataType(self.bufferType.json())
judaf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedAggregatePythonFunction(
self._name, wrapped_final, jdt_final, wrapped_partial, jdt_partial)
return judaf

def __call__(self, *cols):
judaf = self._judaf
sc = SparkContext._active_spark_context
return Column(judaf.apply(_to_seq(sc, cols, _to_java_column)))

def _wrapped(self):
"""
        Wrap this UDAF with a function and attach the docstring from the final function
"""

# It is possible for a callable instance without __name__ attribute or/and
# __module__ attribute to be wrapped here. For example, functools.partial. In this case,
# we should avoid wrapping the attributes from the wrapped function to the wrapper
# function. So, we take out these attribute names from the default names to set and
# then manually assign it after being wrapped.
assignments = tuple(
a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__')

@functools.wraps(self.final, assigned=assignments)
def wrapper(*args):
return self(*args)

wrapper.__name__ = self._name
wrapper.__module__ = (self.final.__module__ if hasattr(self.final, '__module__')
else self.final.__class__.__module__)
wrapper.final = self.final
wrapper.returnType = self.returnType
wrapper.partial = self.partial
wrapper.bufferType = self.bufferType

return wrapper
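For reference, a minimal sketch (assuming an active SparkSession) of the direct construction that pandas_udaf performs for the algebraic case:

```python
from pyspark.sql.functions import col
from pyspark.sql.types import LongType
from pyspark.sql.udf import UserDefinedAggregateFunction

# partial produces the buffer from the input column(s); final consumes the buffer
# column(s) and produces the result. With algebraic=True, pandas_udaf passes the
# same callable and type for both, which is spelled out explicitly here.
p_sum = UserDefinedAggregateFunction(
    final=lambda v: v.sum(),
    returnType=LongType(),
    partial=lambda v: v.sum(),
    bufferType=LongType())

# df.groupBy('g').agg(p_sum._wrapped()(col('n')))   # pandas_udaf returns _wrapped()
```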
30 changes: 27 additions & 3 deletions python/pyspark/worker.py
@@ -33,7 +33,7 @@
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasSerializer
from pyspark.sql.types import to_arrow_type
from pyspark.sql.types import to_arrow_type, StructType
from pyspark import shuffle

pickleSer = PickleSerializer()
@@ -110,6 +110,24 @@ def wrapped(*series):
return wrapped


def wrap_pandas_group_aggregate_udf(f, return_type):
import pandas as pd
if isinstance(return_type, StructType):
arrow_return_types = [to_arrow_type(field.dataType) for field in return_type]
else:
arrow_return_types = [to_arrow_type(return_type)]

def fn(*args):
out = f(*[pd.Series(arg[0]) for arg in args])
if not isinstance(out, (tuple, list)):
out = (out,)
assert len(out) == len(arrow_return_types), \
'Columns of tuple don\'t match return schema'

return [(pd.Series(v), t) for v, t in zip(out, arrow_return_types)]
return fn


def read_single_udf(pickleSer, infile, eval_type):
num_arg = read_int(infile)
arg_offsets = [read_int(infile) for i in range(num_arg)]
@@ -126,6 +144,8 @@ def read_single_udf(pickleSer, infile, eval_type):
return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type)
elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF:
return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type)
elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGGREGATE_UDF:
return arg_offsets, wrap_pandas_group_aggregate_udf(row_func, return_type)
else:
return arg_offsets, wrap_udf(row_func, return_type)

@@ -143,13 +163,17 @@ def read_udfs(pickleSer, infile, eval_type):
# lambda a: (f0(a0), f1(a1, a2), f2(a3))
# In the special case of a single UDF this will return a single result rather
# than a tuple of results; this is the format that the JVM side expects.
mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
if eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGGREGATE_UDF:
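        # each group-aggregate UDF returns a list of (Series, arrow_type) pairs; flatten them into one list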
mapper_str = "lambda a: sum([%s], [])" % (", ".join(call_udf))
else:
mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
mapper = eval(mapper_str, udfs)

func = lambda _, it: map(mapper, it)

if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \
or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF:
or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF \
or eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGGREGATE_UDF:
ser = ArrowStreamPandasSerializer()
else:
ser = BatchedSerializer(PickleSerializer(), 100)
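To make the data shapes concrete: each wrapped group-aggregate UDF returns a list of (pandas.Series, arrow_type) pairs, and for SQL_PANDAS_GROUP_AGGREGATE_UDF the mapper concatenates those lists across all UDFs with sum([...], []) instead of building a tuple, so the serializer sees one flat list per group. A rough standalone illustration of wrap_pandas_group_aggregate_udf (plain pandas and pyarrow, values made up):

```python
import pandas as pd
import pyarrow as pa

partial = lambda v: (v.sum(), v.count())       # p_avg's partial: buffer is (sum, count)
arrow_return_types = [pa.int64(), pa.int64()]  # what to_arrow_type gives for the buffer fields

def fn(*args):                                  # mirrors the wrapper built above
    out = partial(*[pd.Series(arg[0]) for arg in args])
    if not isinstance(out, (tuple, list)):
        out = (out,)
    return [(pd.Series(v), t) for v, t in zip(out, arrow_return_types)]

batch = fn([pd.Series([1, 2, 3])])              # [(Series([6]), int64), (Series([3]), int64)]
# with several UDFs, sum([fn0(a), fn1(a)], []) flattens their lists into one
```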