In [461]:
import findspark
findspark.init()

import pandas as pd

%load_ext autotime

import chart_studio.plotly as py
init_notebook_mode(connected=True)

import pyspark

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

The autotime extension is already loaded. To reload it, use:
  %reload_ext autotime


time: 4 ms


In [6]:
spark = SparkSession \
    .builder \
    .appName("RFM Analysis with PySpark") \
    .getOrCreate()

time: 35.1 s


In [7]:
spark

time: 14 ms


Data Source: https://www.kaggle.com/carrie1/ecommerce-data

In [232]:
data = spark.read.format("csv").option("header", "true").load("data.csv")

time: 277 ms


In [233]:
data

DataFrame[InvoiceNo: string, StockCode: string, Description: string, Quantity: string, InvoiceDate: string, UnitPrice: string, CustomerID: string, Country: string]

time: 5 ms


In [234]:
data.columns

['InvoiceNo',
 'StockCode',
 'Description',
 'Quantity',
 'InvoiceDate',
 'UnitPrice',
 'CustomerID',
 'Country']

time: 2 ms


In [235]:
data.printSchema()

root
 |-- InvoiceNo: string (nullable = true)
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Quantity: string (nullable = true)
 |-- InvoiceDate: string (nullable = true)
 |-- UnitPrice: string (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- Country: string (nullable = true)

time: 2 ms


In [237]:
# cache data in memory
data.cache().count()

541909

time: 97 ms


In [432]:
data.groupBy("RecencyDays").count().show()

+-----------+-----+
|RecencyDays|count|
+-----------+-----+
|        148| 1455|
|         31| 3454|
|        137| 1035|
|         85| 2426|
|         65| 2540|
|        255| 1025|
|         53| 4070|
|        296| 1131|
|         78| 1810|
|        321|  624|
|        375| 1586|
|        155| 1131|
|        108| 1376|
|        211|  858|
|        193| 1538|
|         34| 2543|
|        115| 2112|
|        101| 3129|
|         81| 2467|
|        183| 1027|
+-----------+-----+
only showing top 20 rows

time: 484 ms


In [426]:
data.show()

+---------+---------+--------------------+--------+--------------+---------+----------+--------------+-----+----------+-----------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|Total|      Date|RecencyDays|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+-----+----------+-----------+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom| 15.3|2010-12-01|        395|
|   536365|    71053| WHITE METAL LANTERN|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|20.34|2010-12-01|        395|
|   536365|   84406B|CREAM CUPID HEART...|       8|12/1/2010 8:26|     2.75|     17850|United Kingdom| 22.0|2010-12-01|        395|
|   536365|   84029G|KNITTED UNION FLA...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|20.34|2010-12-01|        395|
|   536365|   84029E|RED WOOLLY HOTTIE...|       6|12/1/2010 8:26|     3.39|

# 1 Data Pre-Processing

In [239]:
data = data.withColumn("Quantity", data["Quantity"].cast(IntegerType()))
data = data.withColumn("UnitPrice", data["UnitPrice"].cast(DoubleType()))

time: 7.99 ms


In [240]:
# define Total column
data = data.withColumn("Total", round(data["UnitPrice"] * data["Quantity"], 2))

time: 8 ms


In [241]:
# change date format
data = data.withColumn("Date", to_date(unix_timestamp("InvoiceDate", "MM/dd/yyyy").cast("timestamp")))


time: 13 ms


In [416]:
# calculate difference in days between 2011-12-31 and the Invoice Date
data = data.withColumn("RecencyDays", expr("datediff('2011-12-31', Date)"))

time: 8 ms


# 2 Create RFM Table

In [436]:
# Creation of RFM table

rfm_table = data.groupBy("CustomerId")\
                        .agg(max("Date").alias("LastPurchase"), \
                             min("RecencyDays").alias("Recency"), \
                             count("InvoiceNo").alias("Frequency"), \
                             sum("Total").alias("Monetary"))

time: 21 ms


In [437]:
rfm_table.printSchema()

root
 |-- CustomerId: string (nullable = true)
 |-- LastPurchase: date (nullable = true)
 |-- Recency: integer (nullable = true)
 |-- Frequency: long (nullable = false)
 |-- Monetary: double (nullable = true)

time: 1 ms


In [438]:
rfm_table.show(5)

+----------+------------+-------+---------+------------------+
|CustomerId|LastPurchase|Recency|Frequency|          Monetary|
+----------+------------+-------+---------+------------------+
|     16250|  2011-03-23|    283|       24|            389.44|
|     15574|  2011-06-15|    199|      168| 702.2500000000002|
|     15555|  2011-11-27|     34|      925| 4758.200000000001|
|     15271|  2011-12-02|     29|      275|2485.8199999999997|
|     17714|  2011-01-23|    342|       10|             153.0|
+----------+------------+-------+---------+------------------+
only showing top 5 rows

time: 3.27 s


In [441]:
rfm_table.cache().count()

4373

time: 862 ms


# 3 Computing Quartiles of RFM values

In [442]:
r_quartile = rfm_table.approxQuantile("Recency", [0.25, 0.5, 0.75], 0)
f_quartile = rfm_table.approxQuantile("Frequency", [0.25, 0.5, 0.75], 0)
m_quartile = rfm_table.approxQuantile("Monetary", [0.25, 0.5, 0.75], 0)

time: 6.31 s


In [446]:
# calculate Recency based on quartile

rfm_table = rfm_table.withColumn("R_Quartile", \
                                 when(col("Recency") >= r_quartile[2] , 1).\
                                 when(col("Recency") >= r_quartile[1] , 2).\
                                 when(col("Recency") >= r_quartile[0] , 3).\
                                 otherwise(4))

time: 72 ms


In [447]:
# calculate Frequency based on quartile

rfm_table = rfm_table.withColumn("F_Quartile", \
                                 when(col("Frequency") > f_quartile[2] , 4).\
                                 when(col("Frequency") > f_quartile[1] , 3).\
                                 when(col("Frequency") > f_quartile[0] , 2).\
                                 otherwise(1))

time: 33 ms


In [448]:
# calculate Monetary based on quartile

rfm_table = rfm_table.withColumn("M_Quartile", \
                                 when(col("Monetary") >= m_quartile[2] , 4).\
                                 when(col("Monetary") >= m_quartile[1] , 3).\
                                 when(col("Monetary") >= m_quartile[0] , 2).\
                                 otherwise(1))

time: 25 ms


In [449]:
# combine the scores (R_Quartile, F_Quartile,M_Quartile) together.

rfm_table = rfm_table.withColumn("RFM_Score", concat(col("R_Quartile"), col("F_Quartile"), col("M_Quartile")))

time: 21 ms


In [450]:
rfm_table.show(5)

+----------+------------+-------+---------+------------------+----------+----------+----------+---------+
|CustomerId|LastPurchase|Recency|Frequency|          Monetary|R_Quartile|F_Quartile|M_Quartile|RFM_Score|
+----------+------------+-------+---------+------------------+----------+----------+----------+---------+
|     16250|  2011-03-23|    283|       24|            389.44|         1|         2|         2|      122|
|     15574|  2011-06-15|    199|      168| 702.2500000000002|         1|         4|         3|      143|
|     15555|  2011-11-27|     34|      925| 4758.200000000001|         4|         4|         4|      444|
|     15271|  2011-12-02|     29|      275|2485.8199999999997|         4|         4|         4|      444|
|     17714|  2011-01-23|    342|       10|             153.0|         1|         1|         1|      111|
+----------+------------+-------+---------+------------------+----------+----------+----------+---------+
only showing top 5 rows

time: 123 ms


# 4 RFM Analysis

In [457]:
# Best customers

rfm_table.select("CustomerID").where("RFM_Score == 444").show(11)

+----------+
|CustomerID|
+----------+
|     15555|
|     15271|
|     17686|
|     17757|
|     16549|
|     13985|
|     14525|
|     18283|
|     12957|
|     17491|
|     16133|
+----------+
only showing top 11 rows

time: 100 ms


In [467]:
# group by RFM Score

grouped_by_rfmscore = rfm_table.groupBy("R_Quartile", "F_Quartile", "M_Quartile").count().orderBy("count", ascending=False)

time: 18 ms


In [468]:
# convert Spark dataframe to pandas in order to visualize data

grouped_by_rfmscore_pandas = grouped_by_rfmscore.toPandas()

time: 5.25 s


In [469]:
grouped_by_rfmscore_pandas

Unnamed: 0,R_Quartile,F_Quartile,M_Quartile,count
0,4,4,4,442
1,1,1,1,395
2,3,4,4,234
3,1,2,2,210
4,3,3,3,187
5,2,1,1,183
6,2,2,2,175
7,2,3,3,168
8,4,3,3,141
9,3,1,1,128


time: 17 ms


In [453]:
grouped_by_rfmscore_pandas['RFM_Score'] = "Seg " + grouped_by_rfmscore_pandas['RFM_Score'].map(str)

time: 46 ms


In [462]:
data = [go.Bar(x=grouped_by_rfmscore_pandas['RFM_Score'], y=grouped_by_rfmscore_pandas['count'])]

layout = go.Layout(
    title=go.layout.Title(
        text='Customer RFM Segments'
    ),
    xaxis=go.layout.XAxis(
        title=go.layout.xaxis.Title(
            text='RFM Segment'
        )
    ),
    yaxis=go.layout.YAxis(
        title=go.layout.yaxis.Title(
            text='Number of Customers'
        )
    )
)

fig = go.Figure(data=data, layout=layout)
iplot(fig, filename='rfm_Segments')

time: 299 ms
