# Spark UDF Compare

This notebook is inspired by the original Databricks blog post [Introducing Pandas UDF for PySpark](https://www.databricks.com/blog/2017/10/30/introducing-vectorized-udfs-for-pyspark.html). For the original implementation of the benchmark, see the [Pandas UDF Notebook](https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/1281142885375883/2174302049319883/7729323681064935/latest.html).

## Background

I wanted to see whether I could rewrite the functions from the original blog post in Rust and run the same speed comparison. The point is not to argue that you should rewrite your Spark UDFs in Rust. It is to show that it is possible, though probably not the best idea. I kept my Rust code as vanilla as possible, and this is not a comparison with Rust-based frameworks such as Polars or Ballista.

## Changes

I did *slightly* modify the original notebook. I adjusted the formatting (imports, multiline statements, etc.) and updated the syntax to match the new Pandas UDF interface introduced in Spark 3.0.0.

In [0]:
import pandas as pd

import pyspark.sql.functions as F
from pyspark.sql.functions import pandas_udf, udf

import pyspark.sql.types as T

from pyspark.sql import Row

from scipy import stats



In [0]:
# To get the pip install to work correctly, I had to install Cargo and the pip package
# directly on the driver using the web terminal.
# From the web terminal, run the following command:
# curl https://sh.rustup.rs -sSf | sh -s -- -y && source "$HOME/.cargo/env" && pip install rust-udf-example
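
Before importing the Rust functions, it can help to confirm from the notebook that the package is visible to the driver's Python process. The sketch below is a hypothetical helper (`check_module` is not part of the original notebook or of `rust_udf`) that uses only the standard library. It checks the driver only. Python UDFs run in Python worker processes on the executors, so on a multi-node cluster those nodes would also need `rust_udf` installed.

```python
import importlib.util


def check_module(module_name: str = "rust_udf") -> dict:
    """Report whether a top-level module can be found by this Python process.

    Hypothetical helper: inspects only the current (driver) interpreter.
    """
    # find_spec returns None (rather than raising) for a missing top-level module.
    spec = importlib.util.find_spec(module_name)
    return {
        "module": module_name,
        "importable": spec is not None,
        "location": spec.origin if spec is not None else None,
    }


print(check_module())
```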

In [0]:
from rust_udf import plus_one_rs, cdf_rs, subtract_mean_rs
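
The Rust functions come from the `rust_udf` package, whose source is not shown here. Based only on how the cells below call them, `plus_one_rs` and `cdf_rs` appear to take a single float and return a float, and `subtract_mean_rs` appears to take a list of floats and return a list. The pure-Python stand-ins below are a hypothetical sketch built on those assumptions. The `_py` names are made up for illustration. They can serve as a reference for spot-checking the Rust results. The normal CDF uses the identity Φ(x) = (1 + erf(x/√2)) / 2, which gives the same standard normal CDF as `scipy.stats.norm.cdf`.

```python
import math
from typing import List


def plus_one_py(v: float) -> float:
    # Hypothetical stand-in for plus_one_rs.
    return v + 1.0


def cdf_py(v: float) -> float:
    # Hypothetical stand-in for cdf_rs: standard normal CDF via the error function.
    return 0.5 * (1.0 + math.erf(v / math.sqrt(2.0)))


def subtract_mean_py(vs: List[float]) -> List[float]:
    # Hypothetical stand-in for subtract_mean_rs: center a list on its mean.
    if not vs:
        return []
    mean = sum(vs) / len(vs)
    return [x - mean for x in vs]


print(plus_one_py(0.5))                   # 1.5
print(round(cdf_py(1.96), 4))             # 0.975
print(subtract_mean_py([1.0, 2.0, 3.0]))  # [-1.0, 0.0, 1.0]
```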

In [0]:
df = (
    spark.range(0, 10 * 1000 * 1000)
    .withColumn("id", (F.col("id") / 10000).cast("integer"))
    .withColumn("v", F.rand())
)

df.cache()

print("Number of records:", df.count())
df.show()

Number of records: 10000000
+---+-------------------+
| id|                  v|
+---+-------------------+
|  0|0.15651868958313198|
|  0|0.21716444334624274|
|  0| 0.8683173173207329|
|  0|0.07800200127201706|
|  0| 0.6932806235978584|
|  0| 0.8377408833427843|
|  0| 0.8538987778780405|
|  0|0.25537961838485645|
|  0| 0.2631392753730042|
|  0|0.42540886550134116|
|  0|0.43707222732694373|
|  0| 0.7095514823050232|
|  0|   0.65171189281833|
|  0|0.16387925660322944|
|  0|0.42694974747514103|
|  0|0.13721004432109374|
|  0| 0.1174495995075876|
|  0|0.05475323785989383|
|  0| 0.3431083493195717|
|  0| 0.3549490254636438|
+---+-------------------+
only showing top 20 rows
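
The `id` column above is built by dividing the row index by 10,000 and casting the result to an integer, which truncates it. As a result, each `id` covers 10,000 consecutive rows, and the 10,000,000 rows fall into 1,000 groups. These are the groups the later `groupby('id')` benchmarks work on. The sketch below reproduces the same arithmetic in plain Python on a scaled-down range of 100,000 rows with a bucket size of 1,000. `bucket_sizes` is a hypothetical helper. It only models non-negative ids, where Python's `int()` and Spark's cast both truncate toward zero.

```python
from collections import Counter


def bucket_sizes(n_rows: int, bucket_size: int) -> Counter:
    # Mirrors (F.col("id") / bucket_size).cast("integer") for ids 0..n_rows-1:
    # true division, then truncation toward zero.
    return Counter(int(i / bucket_size) for i in range(n_rows))


sizes = bucket_sizes(100_000, 1_000)
print(len(sizes), set(sizes.values()))  # 100 {1000}
```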



In [0]:
@udf('double')
def plus_one(v):
    return v + 1

%timeit df.withColumn('v', plus_one(df.v)).agg(F.count(F.col('v'))).show()

+--------+
|count(v)|
+--------+
|10000000|
+--------+

3.48 s ± 141 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [0]:
@pandas_udf("double")
def pandas_plus_one(v: pd.Series) -> pd.Series:
    return v + 1

%timeit df.withColumn('v', pandas_plus_one(df.v)).agg(F.count(F.col('v'))).show()

+--------+
|count(v)|
+--------+
|10000000|
+--------+

1.77 s ± 409 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
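
A large part of the Pandas UDF's advantage comes from how often the Python function is called. It runs once per Arrow batch of rows instead of once per row. Spark caps the batch size with `spark.sql.execution.arrow.maxRecordsPerBatch`, which defaults to 10,000 records. The back-of-the-envelope sketch below counts function invocations for 10,000,000 rows. It assumes every batch is full, but batches do not span partitions, so the real count is somewhat higher. `count_calls` is a hypothetical helper.

```python
import math


def count_calls(n_rows: int, max_records_per_batch: int = 10_000) -> dict:
    # Row-at-a-time @udf: one Python function call per row.
    # @pandas_udf: one call per Arrow batch (assuming every batch is full).
    return {
        "udf_calls": n_rows,
        "pandas_udf_calls": math.ceil(n_rows / max_records_per_batch),
    }


print(count_calls(10_000_000))  # {'udf_calls': 10000000, 'pandas_udf_calls': 1000}
```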


In [0]:
# Rust plus_one
@udf('double')
def plus_one(v):
    return plus_one_rs(v)

%timeit df.withColumn('v', plus_one(df.v)).agg(F.count(F.col('v'))).show()

+--------+
|count(v)|
+--------+
|10000000|
+--------+

3.9 s ± 61.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
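
Here are the three `plus_one` timings side by side, using the recorded means: 3.48 s for the plain `@udf`, 1.77 s for the `@pandas_udf`, and 3.9 s for the `@udf` that calls `plus_one_rs`. The snippet below only does arithmetic on those numbers. The Rust version is still a row-at-a-time UDF, so every row still goes through a Python function call. One plausible reading of the extra time is per-call overhead from crossing into the extension module, but the benchmark does not isolate that.

```python
N_ROWS = 10_000_000

# Mean seconds per loop, copied from the %timeit output above.
timings = {"udf": 3.48, "pandas_udf": 1.77, "rust_udf": 3.9}

pandas_speedup = timings["udf"] / timings["pandas_udf"]
rust_slowdown = timings["rust_udf"] / timings["udf"]
extra_ns_per_row = (timings["rust_udf"] - timings["udf"]) / N_ROWS * 1e9

print(f"pandas_udf speedup over udf: {pandas_speedup:.2f}x")           # 1.97x
print(f"rust_udf time relative to udf: {rust_slowdown:.2f}x")          # 1.12x
print(f"extra time per row for rust_udf: {extra_ns_per_row:.0f} ns")   # 42 ns
```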


In [0]:
@udf('double')
def cdf(v):
    return float(stats.norm.cdf(v))

%timeit df.withColumn('cumulative_probability', cdf(df.v)).agg(F.count(F.col('cumulative_probability'))).show()

+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                     10000000|
+-----------------------------+

In [0]:
@pandas_udf('double')
def pandas_cdf(v: pd.Series) -> pd.Series:
    return pd.Series(stats.norm.cdf(v))

%timeit df.withColumn('cumulative_probability', pandas_cdf(df.v)).agg(F.count(F.col('cumulative_probability'))).show()

+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                     10000000|
+-----------------------------+

In [0]:
# Rust cdf
@udf('double')
def cdf(v):
    return cdf_rs(v)

%timeit df.withColumn('cumulative_probability', cdf(df.v)).agg(F.count(F.col('cumulative_probability'))).show()

+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                     10000000|
+-----------------------------+

In [0]:

@udf(T.ArrayType(df.schema))
def subtract_mean(rows):
    vs = pd.Series([r.v for r in rows])
    vs = vs - vs.mean()
    return [Row(id=rows[i]['id'], v=float(vs[i])) for i in range(len(rows))]

%timeit df.groupby('id')\
  .agg(F.collect_list(F.struct(df['id'], df['v'])).alias('rows'))\
  .withColumn('new_rows', subtract_mean(F.col('rows')))\
  .withColumn('new_row', F.explode(F.col('new_rows')))\
  .withColumn('id', F.col('new_row.id'))\
  .withColumn('v', F.col('new_row.v'))\
  .agg(F.count(F.col('v'))).show()

+--------+
|count(v)|
+--------+
|10000000|
+--------+

40.9 s ± 1.78 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
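
This row-at-a-time version emulates a grouped operation in three steps:

1. `collect_list` gathers each group's rows into one array.
2. The UDF transforms the whole array in a single Python call.
3. `explode` turns the array back into rows.

The plain-Python sketch below mimics that shape on a handful of rows. It does not touch Spark, and `collect_transform_explode` and `center` are hypothetical names. Spark's `collect_list` does not guarantee element order, whereas this sketch keeps the input order.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Tuple

IdValue = Tuple[int, float]  # one (id, v) row


def collect_transform_explode(
    rows: List[IdValue], fn: Callable[[List[float]], List[float]]
) -> List[IdValue]:
    # 1. collect_list: gather each id's v values into one list (input order kept).
    groups: Dict[int, List[float]] = defaultdict(list)
    for id_, v in rows:
        groups[id_].append(v)
    # 2. UDF: transform each group's whole list in a single call.
    # 3. explode: emit one output row per element of the transformed list.
    return [(id_, new_v) for id_, vs in groups.items() for new_v in fn(vs)]


def center(vs: List[float]) -> List[float]:
    # Subtract the group mean from every value.
    mean = sum(vs) / len(vs)
    return [x - mean for x in vs]


sample = [(0, 1.0), (0, 3.0), (1, 10.0), (1, 20.0), (1, 30.0)]
result = collect_transform_explode(sample, center)
print(result)  # [(0, -1.0), (0, 1.0), (1, -10.0), (1, 0.0), (1, 10.0)]
```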


In [0]:
def pandas_subtract_mean(pdf):
    return pdf.assign(v=pdf.v - pdf.v.mean())

%timeit df.groupby('id').applyInPandas(pandas_subtract_mean, schema=df.schema).agg(F.count(F.col('v'))).show()

+--------+
|count(v)|
+--------+
|10000000|
+--------+

2.82 s ± 124 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
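
`applyInPandas` passes each group to `pandas_subtract_mean` as a complete pandas DataFrame, including the `id` column, and then concatenates the returned frames. For a quick local check without Spark, the same contract can be imitated with a pandas `groupby` loop. The sketch below assumes only pandas and uses a tiny made-up frame. `local_apply_in_pandas` is a hypothetical helper. Its output is ordered by group, which Spark does not guarantee.

```python
import pandas as pd


def pandas_subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    # Same logic as the notebook cell above: center v within the group.
    return pdf.assign(v=pdf.v - pdf.v.mean())


def local_apply_in_pandas(pdf: pd.DataFrame, key: str, fn) -> pd.DataFrame:
    # Imitates groupby(key).applyInPandas(fn, ...): call fn once per group
    # with the full group DataFrame, then concatenate the results.
    parts = [fn(group) for _, group in pdf.groupby(key)]
    return pd.concat(parts, ignore_index=True)


pdf = pd.DataFrame({"id": [0, 0, 1, 1, 1], "v": [1.0, 3.0, 10.0, 20.0, 30.0]})
out = local_apply_in_pandas(pdf, "id", pandas_subtract_mean)
print(out)
```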


In [0]:
# Rust subtract mean
@udf(T.ArrayType(T.FloatType()))
def subtract_mean(v):
    return subtract_mean_rs(v)

%timeit df.groupby('id')\
  .agg(F.collect_list('v').alias('v_list'))\
  .withColumn('v_new', subtract_mean(F.col('v_list')))\
  .withColumn('v_explode', F.explode(F.col('v_new')))\
  .agg(F.count(F.col('v_explode'))).show()

+----------------+
|count(v_explode)|
+----------------+
|        10000000|
+----------------+

3.44 s ± 97.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
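
This Rust cell declares a different return type from the rest of the notebook. `T.FloatType()` is Spark's 32-bit single-precision type, while `v` and the other UDFs here use 64-bit doubles. As a result, the centered values are rounded to roughly seven significant digits when they come back. The standard-library sketch below shows the effect by round-tripping a double through a 32-bit float with `struct`. It does not use Spark, and `to_float32` is a hypothetical helper. Declaring `T.ArrayType(T.DoubleType())` instead would keep full precision, although that changes the benchmarked code.

```python
import struct


def to_float32(x: float) -> float:
    # Pack as an IEEE 754 single-precision float, then unpack back to a Python float.
    return struct.unpack("f", struct.pack("f", x))[0]


x = 0.15651868958313198 - 0.5  # an example centered value
y = to_float32(x)
print(x, y, abs(x - y))
```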
