In [24]:
from pyspark.sql.types import *
from pyspark.sql.functions import col, count, rand, collect_list, explode, struct, count, lit
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from pyspark.sql.session import SparkSession

In [25]:
spark = SparkSession.builder.appName("test_udf").getOrCreate()

In [72]:
df = spark.range(0, 10 * 1000).withColumn('id', (col('id') / 1000).cast('integer')).withColumn('v', rand()).withColumn('z', rand())
df.cache()
df.count()

df.show(5)

+---+-------------------+-------------------+
| id|                  v|                  z|
+---+-------------------+-------------------+
|  0| 0.8365971235722934| 0.7449485811905311|
|  0| 0.6510974938412822| 0.5464682438299392|
|  0|0.44093925882189744| 0.7859773974039915|
|  0| 0.9330163903774489|0.23823861718324213|
|  0| 0.6582840340877774| 0.6208003333990663|
+---+-------------------+-------------------+
only showing top 5 rows



In [11]:
@udf('double')
def plus_one(v):
    return v + 1

%timeit df.withColumn('v', plus_one(df.v)).agg(count(col('v'))).show()

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

420 ms ± 75.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
@pandas_udf("double", PandasUDFType.SCALAR)
def pandas_plus_one(v):
    return v + 1

%timeit df.withColumn('v', pandas_plus_one(df.v)).agg(count(col('v'))).show()

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

434 ms ± 107 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
import pandas as pd
from scipy import stats

@udf('double')
def cdf(v):
    return float(stats.norm.cdf(v))

%timeit df.withColumn('cumulative_probability', cdf(df.v)).agg(count(col('cumulative_probability'))).show()

+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                        10000|
+-----------------------------+

+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                        10000|
+-----------------------------+

+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                        10000|
+-----------------------------+

+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                        10000|
+-----------------------------+

+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                        10000|
+-----------------------------+

+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                        10000|
+-----------------------------+

+-----------------------------+
|c

In [15]:
import pandas as pd
from scipy import stats

@pandas_udf('double', PandasUDFType.SCALAR)
def pandas_cdf(v):
    return pd.Series(stats.norm.cdf(v))

%timeit df.withColumn('cumulative_probability', pandas_cdf(df.v)).agg(count(col('cumulative_probability'))).show()

+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                        10000|
+-----------------------------+

+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                        10000|
+-----------------------------+

+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                        10000|
+-----------------------------+

+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                        10000|
+-----------------------------+

+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                        10000|
+-----------------------------+

+-----------------------------+
|count(cumulative_probability)|
+-----------------------------+
|                        10000|
+-----------------------------+

+-----------------------------+
|c

In [16]:
from pyspark.sql import Row
@udf(ArrayType(df.schema))
def substract_mean(rows):
    vs = pd.Series([r.v for r in rows])
    vs = vs - vs.mean()
    return [Row(id=rows[i]['id'], v=float(vs[i])) for i in range(len(rows))]
  
%timeit df.groupby('id').agg(collect_list(struct(df['id'], df['v'])).alias('rows')).withColumn('new_rows', substract_mean(col('rows'))).withColumn('new_row', explode(col('new_rows'))).withColumn('id', col('new_row.id')).withColumn('v', col('new_row.v')).agg(count(col('v'))).show()

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

7.77 s ± 924 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [17]:
@pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
# Input/output are both a pandas.DataFrame
def pandas_subtract_mean(pdf):
    return pdf.assign(v=pdf.v - pdf.v.mean())

%timeit df.groupby('id').apply(pandas_subtract_mean).agg(count(col('v'))).show()

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

+--------+
|count(v)|
+--------+
|   10000|
+--------+

6.51 s ± 180 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [29]:
df2 = df.withColumn('y', rand()).withColumn('x1', rand()).withColumn('x2', rand()).select('id', 'y', 'x1', 'x2')
df2.show(30)

+---+-------------------+--------------------+--------------------+
| id|                  y|                  x1|                  x2|
+---+-------------------+--------------------+--------------------+
|  0| 0.8502610919417317| 0.43669007543764393|  0.6800893642619276|
|  0| 0.1455883495715783|  0.7184844266257804|  0.3984815415680507|
|  0| 0.9994549277407335|  0.8372445956077054|  0.5456350473248871|
|  0| 0.6160462489797837|0.003917068191399142|  0.7875559960531864|
|  0|0.39779731583851785|  0.7232312805840987|  0.9806549511160453|
|  0| 0.3906420458991622| 0.29890075952728123| 0.13838777368518385|
|  0|  0.498729518987718|0.001235759965507...|   0.233930478833418|
|  0|0.25831603360624056|   0.648782742518444| 0.19288905153955715|
|  0| 0.9474342051762257|  0.9321677940806428|  0.5540365429916454|
|  0|  0.555576313429177|  0.9748629686753693| 0.27188506427130144|
|  1| 0.2891660739801619|  0.1956211183035622| 0.40443111644860785|
|  1| 0.4124620760445288| 0.49168074129757977|  

In [40]:
import pandas as pd
import statsmodels.api as sm
# df has four columns: id, y, x1, x2

group_column = 'id'
y_column = 'y'
x_columns = ['x1', 'x2']
schema = df2.select(group_column, *x_columns).schema

@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
# Input/output are both a pandas.DataFrame
def ols(pdf):
    group_key = pdf[group_column].iloc[0]
    y = pdf[y_column]
    X = pdf[x_columns]
    X = sm.add_constant(X)
    model = sm.OLS(y, X).fit()
    return pd.DataFrame([[group_key] + [model.params[i] for i in   x_columns]], columns=[group_column] + x_columns)

beta = df2.groupby(group_column).apply(ols)
beta.show(10)

+---+--------------------+--------------------+
| id|                  x1|                  x2|
+---+--------------------+--------------------+
|148|  -0.597982841924691| -0.5667421850943335|
|463| 0.11315407143220102| 0.08908789917247106|
|471|1.598112473393809E-4| -0.1668047766170488|
|496| -0.5174565155997618| -0.8970306405166046|
|833|-0.01007891407192...|  -0.616031899311242|
|243|-0.35780117329274563| -0.0160277204652311|
|392|  0.4251997985221291|-0.02313769104947...|
|540|  0.8043585716685674| 0.16671267844593776|
|623| -0.2314952571236663|-0.09849386531058979|
|737|  0.1413629812752918|-0.06903615961904919|
+---+--------------------+--------------------+
only showing top 10 rows



In [80]:
import pandas as pd
import statsmodels.api as sm
# df has four columns: id, y, x1, x2

group_column = 'id'
y_column = 'y'
x_columns = ['x1', 'x2']
schema = df2.select(group_column, *x_columns).schema

@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
# Input/output are both a pandas.DataFrame
def my_sum(pdf):
    return pdf.groupby(by=pdf.id).sum()

res = df2.groupby(group_column).apply(my_sum)
res.show(10)

+---+------------------+------------------+
| id|                x1|                x2|
+---+------------------+------------------+
|  6| 4.145265574081791|5.1015533582358525|
|  3|  5.93146176912726| 3.883158863825751|
|  5| 5.633078964801301|3.4180928779596913|
|  4|  4.48452897320926| 4.198278878605744|
|  4| 6.363309710132798|3.7443245658833395|
|  4| 5.097097014576121|  4.78374794068377|
|  4|3.9526239874915414|5.5833011105912655|
|  5|3.7599003546248326| 6.593672132713081|
|  3|4.2134823631276666|3.7120460830980955|
|  6| 4.844121957968311| 3.969748829993887|
+---+------------------+------------------+
only showing top 10 rows

