## Examples for PySpark UDFs

### [How to Turn Python Functions into PySpark Functions (UDF)](https://changhsinlee.com/pyspark-udf/)

In [1]:
# Start a Spark session
import findspark
findspark.init()
import pyspark
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()

In [2]:
import pandas as pd
from pyspark.sql.functions import udf

In [3]:
print(pyspark.__version__)

2.4.3


In [4]:
# example data
df_pd = pd.DataFrame(
    data={'integers': [1, 2, 3], 
     'floats': [-1.0, 0.5, 2.7],
     'array_integer': [[1, 2], [3, 4, 5], [6, 7, 8, 9]]}
)

In [5]:
df_pd

   integers  floats  array_integer
0         1    -1.0         [1, 2]
1         2     0.5      [3, 4, 5]
2         3     2.7   [6, 7, 8, 9]


In [6]:
df = spark.createDataFrame(df_pd)

In [7]:
df.printSchema()

root
 |-- integers: long (nullable = true)
 |-- floats: double (nullable = true)
 |-- array_integer: array (nullable = true)
 |    |-- element: long (containsNull = true)



In [8]:
df.show()

+--------+------+-------------+
|integers|floats|array_integer|
+--------+------+-------------+
|       1|  -1.0|       [1, 2]|
|       2|   0.5|    [3, 4, 5]|
|       3|   2.7| [6, 7, 8, 9]|
+--------+------+-------------+



#### Turn python function into a Spark function

##### For numeric types

As long as the Python function's output has a corresponding data type in Spark, I can turn it into a UDF. When registering a UDF, I have to specify its return type using the types from `pyspark.sql.types`. All the types supported by PySpark [can be found here](http://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=types#module-pyspark.sql.types).

Unlike a Python function, which works for both integers and floats, a Spark UDF returns a column of NULLs if the type of the value returned by the Python function doesn't match the declared return type, as in the following example.

In [9]:
def square(x):
    return x**2

In Python, this function works for both integers and floats:

In [10]:
square(5)

25

In [11]:
square(3.0)

9.0

Not so much in Spark!

In [12]:
# Integer type output
from pyspark.sql.types import IntegerType
udf_square_int = udf(lambda z: square(z), IntegerType())

In [13]:
(
    df.select('integers', 
              'floats', 
              udf_square_int('integers').alias('int_squared'), 
              udf_square_int('floats').alias('float_squared')) 
    .show()
)

+--------+------+-----------+-------------+
|integers|floats|int_squared|float_squared|
+--------+------+-----------+-------------+
|       1|  -1.0|          1|         null|
|       2|   0.5|          4|         null|
|       3|   2.7|          9|         null|
+--------+------+-----------+-------------+



In [14]:
# float type output
from pyspark.sql.types import FloatType
udf_square_float = udf(lambda z: square(z), FloatType())

In [15]:
(
    df.select('integers', 
              'floats', 
              udf_square_float('integers').alias('int_squared'), 
              udf_square_float('floats').alias('float_squared')) 
    .show()
)

+--------+------+-----------+-------------+
|integers|floats|int_squared|float_squared|
+--------+------+-----------+-------------+
|       1|  -1.0|       null|          1.0|
|       2|   0.5|       null|         0.25|
|       3|   2.7|       null|         7.29|
+--------+------+-----------+-------------+



In [16]:
(
    df.select('integers', 
              'floats', 
              udf_square_int('integers').alias('int_squared'), 
              udf_square_float('floats').alias('float_squared')) 
    .show()
)

+--------+------+-----------+-------------+
|integers|floats|int_squared|float_squared|
+--------+------+-----------+-------------+
|       1|  -1.0|          1|          1.0|
|       2|   0.5|          4|         0.25|
|       3|   2.7|          9|         7.29|
+--------+------+-----------+-------------+



In [17]:
# Wrap the chained expression in () so it can span several lines without "\" continuations
(
    df.select(
        'integers',
        'floats',
        udf_square_int('integers').alias('int_squared'),
        udf_square_float('floats').alias('float_squared'))
    .show()
)

+--------+------+-----------+-------------+
|integers|floats|int_squared|float_squared|
+--------+------+-----------+-------------+
|       1|  -1.0|          1|          1.0|
|       2|   0.5|          4|         0.25|
|       3|   2.7|          9|         7.29|
+--------+------+-----------+-------------+



We can also force the output type inside the Python function itself:

In [18]:
## Force the output to be float
def square_float(x):
    return float(x**2)
udf_square_float2 = udf(lambda z: square_float(z), FloatType())

In [19]:
(
    df.select('integers', 
              'floats', 
              udf_square_float2('integers').alias('int_squared'), 
              udf_square_float2('floats').alias('float_squared')) 
    .show()
)

+--------+------+-----------+-------------+
|integers|floats|int_squared|float_squared|
+--------+------+-----------+-------------+
|       1|  -1.0|        1.0|          1.0|
|       2|   0.5|        4.0|         0.25|
|       3|   2.7|        9.0|         7.29|
+--------+------+-----------+-------------+
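If several functions need the same cast, a small wrapper can apply it once instead of writing a `square_float`-style twin for each. This is a minimal sketch, not part of PySpark: `returning` is a hypothetical helper name. It uses `functools.wraps` so the wrapped function keeps its `__name__`; as far as I know, PySpark uses that name for the default output column, which is why a lambda shows up as `<lambda>(...)`.

```python
import functools


def returning(py_type):
    """Hypothetical helper: cast a function's non-None result to py_type."""
    def decorator(fn):
        @functools.wraps(fn)  # keep fn's __name__ for a readable column name
        def wrapper(*args):
            result = fn(*args)
            # Pass None through untouched so Spark still stores a null
            return None if result is None else py_type(result)
        return wrapper
    return decorator


def square(x):
    return x**2


square_as_float = returning(float)(square)  # would pair with FloatType()
square_as_int = returning(int)(square)      # would pair with IntegerType(); int() truncates 7.29 to 7

print(square_as_float(3))   # 9.0
print(square_as_int(2.7))   # 7
```

With Spark available, this would be registered as, e.g., `udf(returning(float)(square), FloatType())`. Note that casting to `int` silently truncates, so it only makes sense when that is the intended behavior.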



##### Array types, or lists

When the input or output of the Python function is a list, it maps to Spark's `ArrayType`, and the values in the list have to be of the same type (the array's element type).

In [20]:
def square_list(x):
    return [float(val)**2 for val in x]

In [21]:
from pyspark.sql.types import ArrayType
udf_square_list = udf(lambda y: square_list(y), ArrayType(FloatType()))

In [22]:
df.select('array_integer', udf_square_list('array_integer')).show()

+-------------+-----------------------+
|array_integer|<lambda>(array_integer)|
+-------------+-----------------------+
|       [1, 2]|             [1.0, 4.0]|
|    [3, 4, 5]|      [9.0, 16.0, 25.0]|
| [6, 7, 8, 9]|   [36.0, 49.0, 64.0...|
+-------------+-----------------------+



In [23]:
(
    df.select(
        'integers',
        'floats',
        'array_integer',
        udf_square_int('integers').alias('int_squared'),
        udf_square_float('floats').alias('float_squared'),
        udf_square_list('array_integer').alias('array_squared')
    )
    .show()
)

+--------+------+-------------+-----------+-------------+--------------------+
|integers|floats|array_integer|int_squared|float_squared|       array_squared|
+--------+------+-------------+-----------+-------------+--------------------+
|       1|  -1.0|       [1, 2]|          1|          1.0|          [1.0, 4.0]|
|       2|   0.5|    [3, 4, 5]|          4|         0.25|   [9.0, 16.0, 25.0]|
|       3|   2.7| [6, 7, 8, 9]|          9|         7.29|[36.0, 49.0, 64.0...|
+--------+------+-------------+-----------+-------------+--------------------+
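The schema printed earlier marks both the array column and its elements as nullable (`nullable = true`, `containsNull = true`), but `square_list` raises a `TypeError` on a null array or a null element. A null-safe variant, sketched under the assumption that nulls should simply pass through (`square_list_safe` is a hypothetical name):

```python
def square_list_safe(x):
    """Square each element as a float, passing nulls through instead of raising."""
    if x is None:
        # The whole array cell is null
        return None
    # A null element stays null; float(None) would raise TypeError
    return [None if val is None else float(val)**2 for val in x]


print(square_list_safe([1, 2]))     # [1.0, 4.0]
print(square_list_safe([None, 3]))  # [None, 9.0]
print(square_list_safe(None))       # None
```

It would be registered with the same return type as before: `udf(square_list_safe, ArrayType(FloatType()))`.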



##### Struct types, or tuples

If the Python function returns a tuple, then we can use the Struct data type for PySpark.

In [24]:
from pyspark.sql.types import StructType
from pyspark.sql.types import StructField
from pyspark.sql.types import StringType

A `StructType` contains a list of `StructField`s, each with its own name and data type.

In [25]:
array_schema_return = StructType([
    StructField('number', IntegerType(), nullable=False),
    StructField('letters', StringType(), nullable=False)
])

In [26]:
import string

# takes a number and returns [number, the letter at that index of string.ascii_letters]
def convert_ascii(number):
    return [number, string.ascii_letters[number]]

[convert_ascii(i) for i in range(5)]

[[0, 'a'], [1, 'b'], [2, 'c'], [3, 'd'], [4, 'e']]

In [27]:
udf_convert_ascii = udf(lambda z: convert_ascii(z), array_schema_return)

In [28]:
df_ascii = df.select('integers', udf_convert_ascii('integers').alias('ascii_map'))

The schema now looks like a tree: `ascii_map` is a struct column with two nested fields.

In [29]:
df_ascii.printSchema()

root
 |-- integers: long (nullable = true)
 |-- ascii_map: struct (nullable = true)
 |    |-- number: integer (nullable = false)
 |    |-- letters: string (nullable = false)



In [30]:
df.printSchema()

root
 |-- integers: long (nullable = true)
 |-- floats: double (nullable = true)
 |-- array_integer: array (nullable = true)
 |    |-- element: long (containsNull = true)



In [31]:
df_ascii.show()

+--------+---------+
|integers|ascii_map|
+--------+---------+
|       1|   [1, b]|
|       2|   [2, c]|
|       3|   [3, d]|
+--------+---------+
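My understanding is that PySpark's `StructType` conversion also accepts a dict keyed by field name (as well as a tuple, list, or `Row`), which avoids depending on field order. Below is a sketch of a dict-returning variant of `convert_ascii` (hypothetical name `convert_ascii_dict`). It also returns `None`, i.e. a null struct, which the schema allows since `ascii_map` is nullable, for inputs outside 0..51, where the original would raise an `IndexError` (or silently wrap around for negative numbers).

```python
import string


def convert_ascii_dict(number):
    """Return the struct as a dict keyed by field name, or None for a null struct."""
    if number is None or not 0 <= number < len(string.ascii_letters):
        # Missing or out-of-range input: null struct instead of IndexError
        return None
    return {'number': number, 'letters': string.ascii_letters[number]}


print(convert_ascii_dict(1))   # {'number': 1, 'letters': 'b'}
print(convert_ascii_dict(52))  # None
```

It would be registered with the same schema as before: `udf(convert_ascii_dict, array_schema_return)`.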



##### Wrong type example

Spark doesn't know what to do when the Python function returns numpy objects (here an `np.ndarray`) instead of native Python types.

In [32]:
import numpy as np

In [33]:
d_np = pd.DataFrame({'int_arrays': [[1,2,3], [4,5]]})
df_np = spark.createDataFrame(d_np)
df_np.show()

+----------+
|int_arrays|
+----------+
| [1, 2, 3]|
|    [4, 5]|
+----------+



In [34]:
# squares with a numpy function, which returns a np.ndarray
def square_array_wrong(x):
    return np.square(x)

In [35]:
square_array_wrong([1,2,3])

array([1, 4, 9])

In [36]:
udf_square_array_wrong = udf(square_array_wrong, ArrayType(FloatType()))

Calling it raises a `Py4JJavaError`:

In [37]:
df_np.withColumn('squared', udf_square_array_wrong('int_arrays')).show()

Py4JJavaError: An error occurred while calling o247.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 2 in stage 21.0 failed 1 times, most recent failure: Lost task 2.0 in stage 21.0 (TID 43, localhost, executor driver): net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.core.multiarray._reconstruct)
	at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
	at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:707)
	...
Caused by: net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.core.multiarray._reconstruct)
	...


I have to make sure the numbers are returned as a list whose values are native Python types, here by calling `.tolist()` on the numpy array:

In [38]:
def square_array_right(x):
    return np.square(x).tolist()

In [39]:
udf_square_array_right = udf(square_array_right, ArrayType(IntegerType()))

Now it returns the expected values.

In [40]:
zz = df_np.withColumn('squared', udf_square_array_right('int_arrays'))
zz.show()

+----------+---------+
|int_arrays|  squared|
+----------+---------+
| [1, 2, 3]|[1, 4, 9]|
|    [4, 5]| [16, 25]|
+----------+---------+



In [41]:
(
    df_np.select(
        'int_arrays',
        udf_square_array_right('int_arrays').alias('squared')
    )
    .show()
)

+----------+---------+
|int_arrays|  squared|
+----------+---------+
| [1, 2, 3]|[1, 4, 9]|
|    [4, 5]| [16, 25]|
+----------+---------+
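When a function's result may mix numpy arrays, numpy scalars, and plain containers, a recursive converter generalizes the `.tolist()` fix. This is a sketch with a hypothetical helper name, `to_native`; it relies on both `np.ndarray` and numpy scalars providing `.tolist()`, so it does not need to import numpy itself.

```python
def to_native(value):
    """Hypothetical helper: recursively turn numpy-like objects into plain Python values."""
    if isinstance(value, dict):
        return {key: to_native(val) for key, val in value.items()}
    if isinstance(value, list):
        return [to_native(val) for val in value]
    if isinstance(value, tuple):
        # Plain tuple out; tuple subclasses such as Row lose their field names here
        return tuple(to_native(val) for val in value)
    if hasattr(value, 'tolist'):
        # np.ndarray.tolist() gives nested lists; a numpy scalar's tolist() gives a Python scalar
        return value.tolist()
    return value
```

With numpy and Spark available, `udf(lambda x: to_native(np.square(x)), ArrayType(IntegerType()))` should then behave like `udf_square_array_right`.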



##### Slowness

Spark may not distribute the work of a Python UDF across executors as desired if the DataFrame is small, i.e. split into only a few partitions.

To fix this, I repartitioned the DataFrame before calling the UDF, making sure the number of partitions is at least the number of executors when I submit a job. For example:

In [42]:
df_repartitioned = df.repartition(100)
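Why the partition count matters can be seen with a toy model: each partition is processed by one task, so the number of non-empty partitions caps how many executors can run the UDF at the same time. This is a simplified pure-Python sketch of round-robin repartitioning with hypothetical function names; as I understand it, Spark's actual `repartition` also starts each input partition at a random offset, which this model ignores.

```python
def round_robin_partition(rows, num_partitions):
    """Toy model of repartition(n): deal rows out to partitions like cards."""
    partitions = [[] for _ in range(num_partitions)]
    for i, row in enumerate(rows):
        partitions[i % num_partitions].append(row)
    return partitions


def busy_partitions(rows, num_partitions):
    """Number of partitions that actually receive rows (i.e. tasks with work to do)."""
    return sum(1 for part in round_robin_partition(rows, num_partitions) if part)


print(busy_partitions(range(3), 1))       # 1: one task, so one executor does all the UDF work
print(busy_partitions(range(3), 100))     # 3: a 3-row frame can't keep more than 3 tasks busy
print(busy_partitions(range(1000), 100))  # 100
```

So `repartition(100)` on the three-row `df` above still leaves at most three partitions with rows; the benefit shows up on frames with many rows but few input partitions.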

### Writing an UDF for withColumn in PySpark

https://gist.github.com/zoltanctoth/2deccd69e3d1cde1dd78

In [45]:
from pyspark.sql.types import StringType
from pyspark.sql.functions import udf

udf_maturity = udf(lambda age: "adult" if age >= 18 else "child", StringType())

# Use the SparkSession created at the top instead of the legacy sqlContext, which is not defined here
df = spark.createDataFrame([{'name': 'Alice', 'age': 1}, {'name': 'Bob', 'age': 21}])
df.withColumn("maturity", udf_maturity(df.age)).show()

+---+-----+--------+
|age| name|maturity|
+---+-----+--------+
|  1|Alice|   child|
| 21|  Bob|   adult|
+---+-----+--------+
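A named function is easier to test than an inline lambda, and a factory lets the cutoff be a parameter that the returned closure captures. A minimal sketch (`make_maturity` is a hypothetical name); it also returns `None` for a null age, where `None >= 18` would raise a `TypeError` in Python 3.

```python
def make_maturity(threshold=18):
    """Hypothetical factory: return a labeller that uses `threshold` as the adult cutoff."""
    def maturity(age):
        if age is None:
            # A null age stays null; `None >= threshold` would raise TypeError
            return None
        return "adult" if age >= threshold else "child"
    return maturity


label_18 = make_maturity()
label_21 = make_maturity(21)
print(label_18(1), label_18(21), label_21(18))  # child adult child
```

With Spark, this would be registered as, e.g., `udf(make_maturity(21), StringType())`; PySpark pickles the closure together with the captured `threshold`.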



### Spark: Custom UDF Example


https://ragrawal.wordpress.com/2015/10/02/spark-custom-udf-example/

The post has both PySpark and Scala examples.

In [46]:
# Generate Random Data
import itertools
import random
students = ['John', 'Mike','Matt']
subjects = ['Math', 'Sci', 'Geography', 'History']
random.seed(1)
data = []
 
for (student, subject) in itertools.product(students, subjects):
    data.append((student, subject, random.randint(0, 100)))
 
# Create Schema Object
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
schema = StructType([
            StructField("student", StringType(), nullable=False),
            StructField("subject", StringType(), nullable=False),
            StructField("score", IntegerType(), nullable=False)
    ])
 
# Create DataFrame with the SparkSession from the top; `sc` and the legacy HiveContext are not defined here
df = spark.createDataFrame(data, schema)
 
# Define udf
from pyspark.sql.functions import udf
def scoreToCategory(score):
    if score >= 80: return 'A'
    elif score >= 60: return 'B'
    elif score >= 35: return 'C'
    else: return 'D'
 
udf_ScoreToCategory = udf(scoreToCategory, StringType())
df.withColumn("category", udf_ScoreToCategory("score")).show(10)

+-------+---------+-----+--------+
|student|  subject|score|category|
+-------+---------+-----+--------+
|   John|     Math|   17|       D|
|   John|      Sci|   72|       B|
|   John|Geography|   97|       A|
|   John|  History|    8|       D|
|   Mike|     Math|   32|       D|
|   Mike|      Sci|   15|       D|
|   Mike|Geography|   63|       B|
|   Mike|  History|   97|       A|
|   Matt|     Math|   57|       C|
|   Matt|      Sci|   60|       B|
+-------+---------+-----+--------+
only showing top 10 rows
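The if/elif ladder in `scoreToCategory` can also be written as a table lookup with the standard-library `bisect` module, which keeps all the cutoffs in one place. A sketch with hypothetical names `CUTOFFS`, `GRADES`, and `score_to_category`, intended to give the same grades as `scoreToCategory`:

```python
import bisect

# Lower bounds for C, B and A, matching the thresholds in scoreToCategory
CUTOFFS = [35, 60, 80]
GRADES = ['D', 'C', 'B', 'A']


def score_to_category(score):
    # bisect_right returns how many cutoffs are <= score, which is the index into GRADES
    return GRADES[bisect.bisect_right(CUTOFFS, score)]


print([score_to_category(s) for s in (17, 72, 97, 8)])  # ['D', 'B', 'A', 'D']
```

It would be registered the same way as before: `udf(score_to_category, StringType())`.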



### More examples from ProgramCreek

[Python pyspark.sql.functions.udf() Examples](https://www.programcreek.com/python/example/98239/pyspark.sql.functions.udf)