In [72]:
import findspark
findspark.init()

import pandas as pd

%load_ext autotime

import chart_studio.plotly as py
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
init_notebook_mode(connected=True)

import pyspark

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

The autotime extension is already loaded. To reload it, use:
  %reload_ext autotime


time: 7 ms


In [4]:
spark = SparkSession \
    .builder \
    .appName("RFM Analysis with PySpark") \
    .getOrCreate()

time: 35.9 s


In [5]:
spark

time: 14 ms


Data Source: https://www.kaggle.com/carrie1/ecommerce-data

In [6]:
data = spark.read.format("csv").option("header", "true").load("data.csv")

time: 4.32 s


In [7]:
data

DataFrame[InvoiceNo: string, StockCode: string, Description: string, Quantity: string, InvoiceDate: string, UnitPrice: string, CustomerID: string, Country: string]

time: 61 ms


In [8]:
data.columns

['InvoiceNo',
 'StockCode',
 'Description',
 'Quantity',
 'InvoiceDate',
 'UnitPrice',
 'CustomerID',
 'Country']

time: 2 ms


In [9]:
data.printSchema()

root
 |-- InvoiceNo: string (nullable = true)
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Quantity: string (nullable = true)
 |-- InvoiceDate: string (nullable = true)
 |-- UnitPrice: string (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- Country: string (nullable = true)

time: 5 ms


In [10]:
# cache data in memory
data.cache().count()

541909

time: 2.46 s


In [33]:
data.show(5)

+---------+---------+--------------------+--------+--------------+---------+----------+--------------+-----+----------+-----------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|Total|      Date|RecencyDays|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+-----+----------+-----------+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom| 15.3|2010-12-01|        395|
|   536365|    71053| WHITE METAL LANTERN|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|20.34|2010-12-01|        395|
|   536365|   84406B|CREAM CUPID HEART...|       8|12/1/2010 8:26|     2.75|     17850|United Kingdom| 22.0|2010-12-01|        395|
|   536365|   84029G|KNITTED UNION FLA...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|20.34|2010-12-01|        395|
|   536365|   84029E|RED WOOLLY HOTTIE...|       6|12/1/2010 8:26|     3.39|

# 1 Data Pre-Processing

In [13]:
data = data.withColumn("Quantity", data["Quantity"].cast(IntegerType()))
data = data.withColumn("UnitPrice", data["UnitPrice"].cast(DoubleType()))

time: 25 ms


In [14]:
# define Total column
data = data.withColumn("Total", round(data["UnitPrice"] * data["Quantity"], 2))

time: 19 ms


In [15]:
# change date format
data = data.withColumn("Date", to_date(unix_timestamp("InvoiceDate", "MM/dd/yyyy").cast("timestamp")))

time: 130 ms


In [16]:
# calculate difference in days between 2011-12-31 and the Invoice Date
data = data.withColumn("RecencyDays", expr("datediff('2011-12-31', Date)"))

time: 66 ms


# 2 Create RFM Table

In [50]:
# Creation of RFM table

rfm_table = data.groupBy("CustomerId")\
                        .agg(min("RecencyDays").alias("Recency"), \
                             count("InvoiceNo").alias("Frequency"), \
                             sum("Total").alias("Monetary"))

time: 95 ms


In [51]:
rfm_table = rfm_table.withColumn("Monetary", round(rfm_table["Monetary"], 2))

time: 6 ms


In [52]:
rfm_table.printSchema()

root
 |-- CustomerId: string (nullable = true)
 |-- Recency: integer (nullable = true)
 |-- Frequency: long (nullable = false)
 |-- Monetary: double (nullable = true)

time: 998 µs


In [54]:
rfm_table.show(5)

+----------+-------+---------+--------+
|CustomerId|Recency|Frequency|Monetary|
+----------+-------+---------+--------+
|     16250|    283|       24|  389.44|
|     15574|    199|      168|  702.25|
|     15555|     34|      925|  4758.2|
|     15271|     29|      275| 2485.82|
|     17714|    342|       10|   153.0|
+----------+-------+---------+--------+
only showing top 5 rows

time: 2.72 s


In [20]:
rfm_table.cache().count()

4373

time: 4.87 s


# 3 Computing Quartiles of RFM values

In [56]:
r_quartile = rfm_table.approxQuantile("Recency", [0.25, 0.5, 0.75], 0)
f_quartile = rfm_table.approxQuantile("Frequency", [0.25, 0.5, 0.75], 0)
m_quartile = rfm_table.approxQuantile("Monetary", [0.25, 0.5, 0.75], 0)

time: 13.6 s


In [57]:
# calculate Recency based on quartile

rfm_table = rfm_table.withColumn("R_Quartile", \
                                 when(col("Recency") >= r_quartile[2] , 1).\
                                 when(col("Recency") >= r_quartile[1] , 2).\
                                 when(col("Recency") >= r_quartile[0] , 3).\
                                 otherwise(4))

time: 78 ms


In [58]:
# calculate Frequency based on quartile

rfm_table = rfm_table.withColumn("F_Quartile", \
                                 when(col("Frequency") > f_quartile[2] , 4).\
                                 when(col("Frequency") > f_quartile[1] , 3).\
                                 when(col("Frequency") > f_quartile[0] , 2).\
                                 otherwise(1))

time: 43 ms


In [59]:
# calculate Monetary based on quartile

rfm_table = rfm_table.withColumn("M_Quartile", \
                                 when(col("Monetary") >= m_quartile[2] , 4).\
                                 when(col("Monetary") >= m_quartile[1] , 3).\
                                 when(col("Monetary") >= m_quartile[0] , 2).\
                                 otherwise(1))

time: 56 ms


In [60]:
# combine the scores (R_Quartile, F_Quartile,M_Quartile) together.

rfm_table = rfm_table.withColumn("RFM_Score", concat(col("R_Quartile"), col("F_Quartile"), col("M_Quartile")))

time: 25 ms


In [62]:
rfm_table.show(10)

+----------+-------+---------+--------+----------+----------+----------+---------+
|CustomerId|Recency|Frequency|Monetary|R_Quartile|F_Quartile|M_Quartile|RFM_Score|
+----------+-------+---------+--------+----------+----------+----------+---------+
|     16250|    283|       24|  389.44|         1|         2|         2|      122|
|     15574|    199|      168|  702.25|         1|         4|         3|      143|
|     15555|     34|      925|  4758.2|         4|         4|         4|      444|
|     15271|     29|      275| 2485.82|         4|         4|         4|      444|
|     17714|    342|       10|   153.0|         1|         1|         1|      111|
|     17686|     29|      286| 5739.46|         4|         4|         4|      444|
|     13865|     80|       30|  501.56|         2|         2|         2|      222|
|     14157|     41|       49|  400.43|         3|         3|         2|      332|
|     13610|     34|      228| 1115.43|         4|         4|         3|      443|
|   

# 4 RFM Analysis

In [63]:
# Best customers

rfm_table.select("CustomerID").where("RFM_Score == 444").show(10)

+----------+
|CustomerID|
+----------+
|     15555|
|     15271|
|     17686|
|     17757|
|     16549|
|     13985|
|     14525|
|     18283|
|     12957|
|     17491|
+----------+
only showing top 10 rows

time: 2.16 s


In [64]:
# group by RFM Score

grouped_by_rfmscore = rfm_table.groupBy("RFM_Score").count().orderBy("count", ascending=False)

time: 16 ms


In [65]:
# convert Spark dataframe to pandas in order to visualize data

grouped_by_rfmscore_pandas = grouped_by_rfmscore.toPandas()

time: 11.3 s


In [66]:
grouped_by_rfmscore_pandas

Unnamed: 0,RFM_Score,count
0,444,442
1,111,395
2,344,234
3,122,210
4,333,187
5,211,183
6,222,175
7,233,168
8,433,141
9,322,128


time: 11 ms


In [67]:
grouped_by_rfmscore_pandas['RFM_Score'] = "Seg " + grouped_by_rfmscore_pandas['RFM_Score'].map(str)

time: 311 ms


In [73]:
data = [go.Bar(x=grouped_by_rfmscore_pandas['RFM_Score'], y=grouped_by_rfmscore_pandas['count'])]

layout = go.Layout(
    title=go.layout.Title(
        text='Customer RFM Segments'
    ),
    xaxis=go.layout.XAxis(
        title=go.layout.xaxis.Title(
            text='RFM Segment'
        )
    ),
    yaxis=go.layout.YAxis(
        title=go.layout.yaxis.Title(
            text='Number of Customers'
        )
    )
)

fig = go.Figure(data=data, layout=layout)
iplot(fig, filename='rfm_Segments')

time: 1.38 s
