In [1]:
# Kernel sanity check: prints a fixed greeting; nothing later in the notebook depends on it.
print("hello world")

hello world


In [2]:
import os

import findspark

# Locate the Spark installation before pyspark is imported.
# An explicitly configured SPARK_HOME wins, so the notebook is not tied to one
# machine's directory layout; otherwise fall back to the original local install.
findspark.init(os.environ.get("SPARK_HOME", "/usr/local/spark-2.4.4-bin-hadoop2.7"))

import pyspark

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

In [3]:
# Obtain the SparkSession used by every later cell; getOrCreate() reuses an
# already-running session instead of starting a second one.
spark = (
    SparkSession.builder
    .appName("RFM Analysis with PySpark")
    .getOrCreate()
)

In [4]:
# Load the raw transactions from data.csv, taking column names from the header row.
# No schema is supplied or inferred, so every column arrives as a string
# (see the printed schema); numeric/date columns are cast in the next cell.
data = (
    spark.read
    .format("csv")
    .option("header", "true")
    .load("data.csv")
)
data.printSchema()

root
 |-- InvoiceNo: string (nullable = true)
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Quantity: string (nullable = true)
 |-- InvoiceDate: string (nullable = true)
 |-- UnitPrice: string (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- Country: string (nullable = true)



In [5]:
# Type the raw string columns and derive the per-row values the RFM table needs.
# Each withColumn sees the columns produced by the ones before it, so Total is
# computed from the already-cast UnitPrice and Quantity.
# `round` here is pyspark.sql.functions.round (star import), not the builtin.
data = (
    data
    .withColumn("Quantity", col("Quantity").cast(IntegerType()))
    .withColumn("UnitPrice", col("UnitPrice").cast(DoubleType()))
    # Parse InvoiceDate to epoch seconds, then reduce it to a calendar date.
    # NOTE(review): the pattern has no time-of-day fields - confirm it matches
    # the InvoiceDate values actually present in data.csv.
    .withColumn("Date", to_date(unix_timestamp("InvoiceDate", "MM/dd/yyyy").cast("timestamp")))
    # Line value: unit price times quantity, rounded to 2 decimals.
    .withColumn("Total", round(col("UnitPrice") * col("Quantity"), 2))
    # Days from the invoice date to the fixed reference date 2011-12-31.
    .withColumn("RecencyDays", expr("datediff('2011-12-31', Date)"))
)

In [6]:
# Re-print the schema to confirm the Quantity/UnitPrice casts and the derived
# Date, Total and RecencyDays columns added in the previous cell.
data.printSchema()

root
 |-- InvoiceNo: string (nullable = true)
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- InvoiceDate: string (nullable = true)
 |-- UnitPrice: double (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- Country: string (nullable = true)
 |-- Date: date (nullable = true)
 |-- Total: double (nullable = true)
 |-- RecencyDays: integer (nullable = true)



In [7]:
# Build the per-customer RFM table:
#   Recency   - smallest RecencyDays among the customer's rows
#   Frequency - number of the customer's rows with a non-null InvoiceNo
#   Monetary  - sum of Total, rounded to 2 decimals
# min / count / round / sum are the pyspark.sql.functions versions brought in by
# the star import, not the Python builtins.
# The grouping key is spelled "CustomerID", exactly as in the schema, so it
# resolves even when spark.sql.caseSensitive is enabled and the resulting column
# carries the same name the later select uses.
# NOTE(review): Frequency counts rows, not distinct invoices - confirm intended.
# NOTE(review): rows with a null CustomerID, if any, form one group under a null
# key - confirm whether they should be dropped before aggregating.
rfm_table = (
    data
    .groupBy("CustomerID")
    .agg(
        min("RecencyDays").alias("Recency"),
        count("InvoiceNo").alias("Frequency"),
        round(sum("Total"), 2).alias("Monetary"),
    )
)

In [8]:
# Preview up to five rows of the per-customer RFM table.
rfm_table.show(5)

+----------+-------+---------+--------+
|CustomerId|Recency|Frequency|Monetary|
+----------+-------+---------+--------+
|     16250|    283|       24|  389.44|
|     15574|    199|      168|  702.25|
|     15555|     34|      925|  4758.2|
|     15271|     29|      275| 2485.82|
|     17714|    342|       10|   153.0|
+----------+-------+---------+--------+
only showing top 5 rows



In [9]:
# Quartile cut points (25th / 50th / 75th percentile) of each RFM metric.
# A relativeError of 0 asks approxQuantile for exact quantiles.
quartile_probs = [0.25, 0.5, 0.75]
r_quartile = rfm_table.approxQuantile("Recency", quartile_probs, 0)
f_quartile = rfm_table.approxQuantile("Frequency", quartile_probs, 0)
m_quartile = rfm_table.approxQuantile("Monetary", quartile_probs, 0)


def _quartile_score(metric, cuts, scores, inclusive):
    """Return a Column that buckets `metric` by its quartile cut points.

    metric    -- name of the column to score
    cuts      -- [q25, q50, q75] as returned by approxQuantile
    scores    -- four labels, assigned in order to values that clear q75,
                 clear q50, clear q25, and everything else
    inclusive -- True compares with >=, False with strict >
    """
    value = col(metric)
    if inclusive:
        clears = [value >= cut for cut in cuts]
    else:
        clears = [value > cut for cut in cuts]
    # Test the highest cut first so each row takes the first bucket it qualifies for.
    return (
        when(clears[2], scores[0])
        .when(clears[1], scores[1])
        .when(clears[0], scores[2])
        .otherwise(scores[3])
    )


# Recency is scored in reverse (larger day counts get the lower score);
# Frequency and Monetary score higher for larger values.
# NOTE(review): Frequency uses strict > while Recency and Monetary use >= -
# confirm this asymmetry at the cut points is intended.
rfm_table = (
    rfm_table
    .withColumn("R_Quartile", _quartile_score("Recency", r_quartile, (1, 2, 3, 4), True))
    .withColumn("F_Quartile", _quartile_score("Frequency", f_quartile, (4, 3, 2, 1), False))
    .withColumn("M_Quartile", _quartile_score("Monetary", m_quartile, (4, 3, 2, 1), True))
    # Concatenate the three scores (R, then F, then M) into a single RFM_Score column.
    .withColumn("RFM_Score", concat(col("R_Quartile"), col("F_Quartile"), col("M_Quartile")))
)

In [10]:
# Best customers: score 4 on Recency, Frequency and Monetary alike.
# RFM_Score is built with concat(), i.e. it is a string, so it is compared with
# the string "444" rather than relying on an implicit string-to-number cast.
# The filter runs before the projection so RFM_Score is still a column of the
# frame being filtered.
best_customers = (
    rfm_table
    .where(col("RFM_Score") == "444")
    .select("CustomerID")
)

In [12]:
# Number of customers selected as best_customers; count() triggers the Spark job.
best_customers.count()

442