## Customer churn
In this notebook, I build machine learning models to predict customer churn. The data is the popular Sparkify dataset, which I downloaded from a Medium post.

I will not be plotting anything, as the main purpose of these notebook exercises is to get familiar with PySpark, e.g. how to transform data. There are many notebooks available for this dataset, but the reference below is the most convincing to me: the author did a lot of interesting data engineering and, most importantly, transformed the data into a user-based dataframe for prediction.

Reference
https://github.com/abduygur/churn-prediction-using-spark/blob/main/sparkify_churn_ready_ibm.ipynb

In [48]:
import findspark
findspark.init()

## import libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import when, isnan, count, isnull, mean, column, udf, from_unixtime, month, year, col, concat, first, avg, countDistinct, datediff
# Note: min, max and sum are Python built-ins, so the Spark versions are imported under different names
from pyspark.sql.functions import min as Fmin, max as Fmax, sum as Fsum
from pyspark.sql.types import StringType, DateType, LongType, TimestampType

## machine learning libraries
from pyspark.ml.feature import VectorAssembler, OneHotEncoder, StringIndexer
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.pipeline import Pipeline
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.classification import GBTClassifier, RandomForestClassifier


import httpagentparser
import time

In [0]:
spark = SparkSession.builder.appName("CustomerChurn").getOrCreate()

# Assign the loaded DataFrame to df1 so the following cells can use it
df1 = spark.read.format("json").load("dbfs:/FileStore/shared_uploads/zuozhe.teo@digipen.edu/medium_sparkify_event_data.json")
df1.show(5)

+-----------------+---------+---------+------+-------------+--------+---------+-----+--------------------+------+--------+-------------+---------+------------------+------+-------------+--------------------+------+
|           artist|     auth|firstName|gender|itemInSession|lastName|   length|level|            location|method|    page| registration|sessionId|              song|status|           ts|           userAgent|userId|
+-----------------+---------+---------+------+-------------+--------+---------+-----+--------------------+------+--------+-------------+---------+------------------+------+-------------+--------------------+------+
|    Martin Orford|Logged In|   Joseph|     M|           20| Morales|597.55057| free|  Corpus Christi, TX|   PUT|NextSong|1532063507000|      292|     Grand Designs|   200|1538352011000|"Mozilla/5.0 (Mac...|   293|
|John Brown's Body|Logged In|   Sawyer|     M|           74|  Larson|380.21179| free|Houston-The Woodl...|   PUT|NextSong|1538069638000|    

## Basic EDA
1. printSchema
2. Describe
3. Check for NA values

In [0]:
df1.printSchema()

root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)



In [0]:
df1.describe().show()

+-------+-----------------+----------+---------+------+------------------+--------+-----------------+------+----------------+------+-------+--------------------+------------------+--------------------+------------------+--------------------+--------------------+------------------+
|summary|           artist|      auth|firstName|gender|     itemInSession|lastName|           length| level|        location|method|   page|        registration|         sessionId|                song|            status|                  ts|           userAgent|            userId|
+-------+-----------------+----------+---------+------+------------------+--------+-----------------+------+----------------+------+-------+--------------------+------------------+--------------------+------------------+--------------------+--------------------+------------------+
|  count|           432877|    543705|   528005|528005|            543705|  528005|           432877|543705|          528005|543705| 543705|              

In [0]:
print(f"Total rows in df {df1.count()}")
print(f"Total Unique customers {df1.select('userId').distinct().count()}")
print(f"Total columns {len(df1.columns)}")

Total rows in df 543705
Total Unique customers 449
Total columns 18


## Checking all the unique categorical values
1. gender
2. auth
3. level
4. page

In [0]:
def print_unique_col_values(df, column_name):
  results = df.select(column_name).distinct().collect()
  for result in results:
    print(result[column_name])
    
for column_name in ['gender', 'auth', 'level', 'page']:
  print("Column {}".format(column_name))
  print_unique_col_values(df1, column_name)
  print()

Column gender
F
None
M

Column auth
Logged Out
Cancelled
Guest
Logged In

Column level
free
paid

Column page
Cancel
Submit Downgrade
Thumbs Down
Home
Downgrade
Roll Advert
Logout
Save Settings
Cancellation Confirmation
About
Settings
Login
Add to Playlist
Add Friend
NextSong
Thumbs Up
Help
Upgrade
Error
Submit Upgrade
Submit Registration
Register



## Checking the time range of the dataset
1. Find the minimum and maximum ts and convert them to dates
2. The data covers about 2 months (2018-10-01 to 2018-12-01)

In [0]:
print(df1.select(Fmax(from_unixtime(col('ts')/1000).cast(DateType())).alias('time')).collect())
print(df1.select(Fmin(from_unixtime(col('ts')/1000).cast(DateType())).alias('time')).collect())

[Row(time=datetime.date(2018, 12, 1))]
[Row(time=datetime.date(2018, 10, 1))]
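
The ts column holds epoch milliseconds, which is why the cell above divides by 1000 before calling from_unixtime. Below is a minimal plain-Python sketch of the same conversion, using the ts value from the first row shown earlier. The helper name `ms_to_date` is made up for illustration, and UTC is an assumption: Spark's from_unixtime uses the session time zone, so dates close to midnight may differ.

```python
from datetime import datetime, timezone

def ms_to_date(ts_ms):
    """Convert epoch milliseconds to a calendar date (UTC assumed)."""
    return datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc).date()

# ts of the first event row printed by df1.show(5)
first_event = ms_to_date(1538352011000)
print(first_event)
```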


In [0]:
## Checking for NA values is more complex than in pandas
## There are no nan columns
df1.select([count(when(isnan(col), 1)).alias(col) for col in df1.columns]).show()

+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+----+------+---+---------+------+
|artist|auth|firstName|gender|itemInSession|lastName|length|level|location|method|page|registration|sessionId|song|status| ts|userAgent|userId|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+----+------+---+---------+------+
|     0|   0|        0|     0|            0|       0|     0|    0|       0|     0|   0|           0|        0|   0|     0|  0|        0|     0|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+----+------+---+---------+------+



In [0]:
## Checking for NA values is more complex than in pandas
## In pyspark, we have to check isnan and isnull just to be sure
df1.select([count(when(isnull(col), 1)).alias(col) for col in df1.columns]).show()

+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+----+------+---+---------+------+
|artist|auth|firstName|gender|itemInSession|lastName|length|level|location|method|page|registration|sessionId|song|status| ts|userAgent|userId|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+----+------+---+---------+------+
|     0|   0|        0|     0|            0|       0|     0|    0|       0|     0|   0|           0|        0|   0|     0|  0|        0|     0|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+----+------+---+---------+------+
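
Spark treats null (a missing value) and NaN (a floating point "not a number") as different things, which is why the two checks above are run separately. The plain-Python sketch below shows the same distinction with None and float('nan'). The helper `count_missing` and the sample list are hypothetical.

```python
import math

def count_missing(values):
    """Return (null_count, nan_count) for a list of values."""
    nulls = sum(1 for v in values if v is None)
    nans = sum(1 for v in values if isinstance(v, float) and math.isnan(v))
    return nulls, nans

# Hypothetical 'length' values: two real numbers, two nulls, one NaN
lengths = [597.55057, None, float("nan"), 380.21179, None]
print(count_missing(lengths))
```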



## Removing rows with missing information

In [0]:
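# Drop events without a userId, then cast the id from string to long
df1 = df1.filter(df1['userId']!='')
df1 = df1.withColumn('userId', col('userId').cast(LongType()))
# Drop logged-out events and any rows still missing userId or sessionId
df1 = df1.filter(df1['auth'] != 'Logged Out')
df1 = df1.dropna(how='any', subset=['userId','sessionId'])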

## Checking for NA values

In [0]:
## Checking for NA values is more complex than in pandas
## After filtering, only the song column reports a NaN match (1 row)
df1.select([count(when(isnan(col), 1)).alias(col) for col in df1.columns]).show()


+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+----+------+---+---------+------+
|artist|auth|firstName|gender|itemInSession|lastName|length|level|location|method|page|registration|sessionId|song|status| ts|userAgent|userId|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+----+------+---+---------+------+
|     0|   0|        0|     0|            0|       0|     0|    0|       0|     0|   0|           0|        0|   1|     0|  0|        0|     0|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+----+------+---+---------+------+



In [0]:
## Checking for NA values is more complex than in pandas
## In pyspark, we have to check isnan and isnull just to be sure
df1.select([count(when(isnull(col), 1)).alias(col) for col in df1.columns]).show()

+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+------+------+---+---------+------+
|artist|auth|firstName|gender|itemInSession|lastName|length|level|location|method|page|registration|sessionId|  song|status| ts|userAgent|userId|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+------+------+---+---------+------+
|110828|   0|    15700| 15700|            0|   15700|110828|    0|   15700|     0|   0|       15700|        0|110828|     0|  0|    15700|     0|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+------+------+---+---------+------+



## Feature Engineering
1. userAgent 
2. location (only keeping the state)
3. timestamp/registration

#### As we are predicting user churn, we have to transform the data to a user dataframe

In [0]:
label = df1.withColumn('target', when((col('page').isin(['Cancellation Confirmation','Cancel']))|(col('auth')=='Cancelled'), 1).otherwise(0))
label = label.groupby('userId').agg(Fsum('target').alias('target'))
label = label.withColumn('target', when(col('target')>0, 1).otherwise(0))
label.show(5)

+------+------+
|userId|target|
+------+------+
|    29|     0|
|    26|     0|
|    65|     0|
|   191|     0|
|   293|     1|
+------+------+
only showing top 5 rows
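
The labelling rule above flags a user as churned if any of their events is a cancellation page or has auth equal to Cancelled. Here is a minimal plain-Python sketch of that rule. The function `churn_labels` and the event tuples are invented for illustration; only the rule itself comes from the cell above.

```python
from collections import defaultdict

CHURN_PAGES = {"Cancellation Confirmation", "Cancel"}

def churn_labels(events):
    """events: iterable of (user_id, page, auth) tuples -> {user_id: 0 or 1}."""
    labels = defaultdict(int)
    for user_id, page, auth in events:
        is_churn_event = page in CHURN_PAGES or auth == "Cancelled"
        # A single churn event is enough to label the user as churned
        labels[user_id] = max(labels[user_id], int(is_churn_event))
    return dict(labels)

# Hypothetical events for two users
events = [
    (293, "NextSong", "Logged In"),
    (293, "Cancel", "Logged In"),
    (293, "Cancellation Confirmation", "Cancelled"),
    (29, "NextSong", "Logged In"),
    (29, "Thumbs Up", "Logged In"),
]
print(churn_labels(events))
```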



### Converting the userAgent column

In [0]:
## Understanding userAgent column
df1.select("userAgent").distinct().show(10, truncate=False)

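# Adding operating system and browser columns
# simple_detect returns an (os, browser) pair; keep the first word of each and pass nulls through
operating_system_udf = udf(lambda x: httpagentparser.simple_detect(x)[0].split(' ')[0] if x else None, StringType())
browser_udf = udf(lambda x: httpagentparser.simple_detect(x)[1].split(' ')[0] if x else None, StringType())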
df1 = df1.withColumn("operating_system", operating_system_udf('userAgent'))
df1 = df1.withColumn("browser", browser_udf('userAgent'))
df1.select('operating_system', 'browser').show()

+--------------------------------------------------------------------------------------------------------------------------+
|userAgent                                                                                                                 |
+--------------------------------------------------------------------------------------------------------------------------+
|"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"|
|Mozilla/5.0 (Windows NT 6.1; WOW64; rv:30.0) Gecko/20100101 Firefox/30.0                                                  |
|Mozilla/5.0 (compatible; MSIE 9.0; Windows NT 6.1; WOW64; Trident/5.0)                                                    |
|"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"           |
|Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0                                                  |


### Converting location column

In [0]:
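# Understanding the location column
df1.select("location").distinct().show(10, truncate=False)

# Keep only the state part after the comma; strip the leading space and pass nulls through
udf_location = udf(lambda x: x.split(',')[1].strip() if x else None, StringType())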
df1 = df1.withColumn('location', udf_location('location'))
df1.select("location").show(5)

+------------------------------------+
|location                            |
+------------------------------------+
|Corpus Christi, TX                  |
|Mobile, AL                          |
|Los Angeles-Long Beach-Anaheim, CA  |
|Orlando-Kissimmee-Sanford, FL       |
|Auburn, IN                          |
|El Campo, TX                        |
|Rochester, MN                       |
|Dallas-Fort Worth-Arlington, TX     |
|Houston-The Woodlands-Sugar Land, TX|
|Charlotte-Concord-Gastonia, NC-SC   |
+------------------------------------+
only showing top 10 rows

+--------+
|location|
+--------+
|      TX|
|      TX|
|      FL|
|      FL|
|      AL|
+--------+
only showing top 5 rows
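
The state extraction can also be written as a small null-safe function, which is easy to test outside Spark. This is only a sketch: the name `extract_state` is made up, and returning None for missing or comma-less values is my assumption about the desired behaviour.

```python
def extract_state(location):
    """Return the part after the first comma, e.g. 'Corpus Christi, TX' -> 'TX'."""
    if location is None or "," not in location:
        return None
    return location.split(",")[1].strip()

print(extract_state("Corpus Christi, TX"))
print(extract_state("Charlotte-Concord-Gastonia, NC-SC"))
```

Multi-state metro areas such as NC-SC are kept as a single category, which matches values like MO-IL in the final dataframe.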



### Converting ts and registration

In [0]:
df1=df1.withColumn("time", from_unixtime(col('ts')/1000).cast(TimestampType()))
df1=df1.withColumn("date", from_unixtime(col('ts')/1000).cast(DateType()))
df1=df1.withColumn("month", month('date'))
df1=df1.withColumn("year", year('date'))
df1=df1.withColumn("yearmonth", concat(col("year"),col("month")))
df1=df1.withColumn("first_date", from_unixtime(col('registration')/1000).cast(DateType()))
df1.printSchema()

root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: long (nullable = true)
 |-- operating_system: string (nullable = true)
 |-- browser: string (nullable = true)
 |-- time: timestamp (nullable = true)
 |-- date: date (nullable = true)
 |-- month: integer (nullable = true)
 |-- year: integer (nullable = true)
 |-- yearmonth: string (nullable = true)
 |-- first_date: date (nullable 

In [0]:
last_date = df1.groupby('userId').agg(Fmax('date').alias('last_date'))
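# Note: first() without an explicit ordering is not guaranteed to return the user's most recent level
current_charging = df1.groupby('userId').agg(first('level').alias('current_charging'))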
length_mean = df1.groupby("userId").agg(avg('length').alias("avg_length"))
length_max = df1.groupby("userId").agg(Fmax('length').alias("max_length"))


In [0]:
distinct_pages = df1.select('page').distinct().collect()
distinct_pages = [item.page for item in distinct_pages]
distinct_pages.remove("Cancellation Confirmation")
distinct_pages.remove("Cancel")

agg_expression = {x:'avg' for x in distinct_pages}
# Group by day first so that we can find the event counts per login day
event_day = df1.groupby('userId','date').pivot('page').count()
event_day = event_day.groupBy('userId').agg(agg_expression).fillna(0)

event_month = df1.groupby("userId", 'yearmonth').pivot('page').count()
event_month = event_month.groupby("userId").agg(agg_expression).fillna(0)

for column in distinct_pages:
    event_day = event_day.withColumnRenamed('avg({})'.format(column),'event_{}_daily_avg'.format(column))
    event_month = event_month.withColumnRenamed('avg({})'.format(column),'event_{}_monthly_avg'.format(column))
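
One detail of the pivot above: pivot(...).count() leaves a null where a page did not occur on a given day, and avg ignores nulls, so each daily average is taken only over the days on which that page occurred. Users who never visited the page get 0 from fillna. The sketch below reproduces this in plain Python; `daily_event_avg` and the sample events are hypothetical.

```python
from collections import Counter, defaultdict

def daily_event_avg(events, page):
    """events: list of (user_id, date, page) tuples.
    Average daily count of `page` per user, over days on which the page occurred."""
    per_day = Counter((u, d) for u, d, p in events if p == page)
    per_user = defaultdict(list)
    for (u, _d), n in per_day.items():
        per_user[u].append(n)
    users = {u for u, _d, _p in events}
    # Users without the page get 0.0, mirroring fillna(0)
    return {u: (sum(per_user[u]) / len(per_user[u]) if per_user[u] else 0.0) for u in users}

events = [
    (1, "2018-10-01", "NextSong"), (1, "2018-10-01", "NextSong"), (1, "2018-10-01", "NextSong"),
    (1, "2018-10-02", "NextSong"),
    (1, "2018-10-03", "Home"),      # a login day without NextSong is not counted as 0
    (2, "2018-10-01", "Home"),
]
print(daily_event_avg(events, "NextSong"))
```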

In [0]:
#Calculate users item count in each session
item_per_session = df1.groupBy('userId','sessionId').agg(Fmax('itemInSession').alias('max_item_cnt'))
item_per_session = item_per_session.groupBy('userId').avg('max_item_cnt') \
    .withColumnRenamed('avg(max_item_cnt)','avg_item_per_session')

In [0]:
## Copied from author github
def grouped_session(df, agg_col='sessionId', group_col='yearmonth', col_suffixes='_monthly'):
    
    '''
    Group the dataframe by userId and group_col, count the distinct values of agg_col,
    then average the counts per user. The output column is named using col_suffixes.
    
    INPUT: Spark DataFrame, aggregate column, group column, column suffix
    OUTPUT: DataFrame with one row per userId and the aggregated column
    '''
    
    #Group by userId and group_col (default: yearmonth) and count distinct agg_col values (default: sessionId)
    #The output column name is built from the col_suffixes argument
    session = df.groupBy('userId',group_col).agg(countDistinct(agg_col).alias('session_count'))
    session = session.groupBy('userId').avg('session_count') \
        .withColumnRenamed('avg(session_count)','session_count'+col_suffixes)
    
    return session

def grouped_session_length(df, agg_col='time', group_col='yearmonth', col_suffixes='_monthly'):
    
    '''
    Group the dataframe by userId and sessionId, and take the difference between the max and min of agg_col.
    Used to calculate the average session length (in hours) per user.
    
    INPUT: Spark DataFrame, aggregate column, group column, column suffix
    OUTPUT: DataFrame with one row per userId and the aggregated column
    '''
    #Group by userId, sessionId and group_col (default: yearmonth)
    session_length = df.groupBy('userId','sessionId',group_col) \
        .agg(Fmin(agg_col).alias('start_time'), Fmax(agg_col).alias('end_time'))
    #Calculate session durations
    session_length = session_length.withColumn('duration',(col('end_time') \
                                                           .cast(LongType()) - col('start_time').cast(LongType())))
    #Calculate average session duration by user and group_col(default:yearmonth)
    session_length = session_length.groupBy('userId',group_col).avg('duration') \
        .withColumnRenamed('avg(duration)','avg_duration')
    #Calculate average session duration per user
    session_length = session_length.groupBy('userId').avg('avg_duration') \
        .withColumnRenamed('avg(avg_duration)','avg_duration')
    #Convert seconds to hours and name the column using col_suffixes
    session_length = session_length.withColumn('avg_duration'+col_suffixes,col('avg_duration')/3600)
    session_length = session_length.drop('avg_duration')
    
    return session_length
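
To make the session length logic easier to follow, here is a simplified plain-Python sketch: take the first and last event time of each session, convert the difference to hours, and average per user. It skips the intermediate averaging per month or day that the Spark function does, and the name `avg_session_hours` and the sample events are made up.

```python
from collections import defaultdict
from datetime import datetime

def avg_session_hours(events):
    """events: iterable of (user_id, session_id, timestamp) -> {user_id: mean session length in hours}."""
    bounds = {}
    for user_id, session_id, ts in events:
        start, end = bounds.get((user_id, session_id), (ts, ts))
        bounds[(user_id, session_id)] = (min(start, ts), max(end, ts))
    durations = defaultdict(list)
    for (user_id, _sid), (start, end) in bounds.items():
        durations[user_id].append((end - start).total_seconds() / 3600)
    return {u: sum(d) / len(d) for u, d in durations.items()}

events = [
    (1, 10, datetime(2018, 10, 1, 10, 0)),
    (1, 10, datetime(2018, 10, 1, 11, 30)),   # session 10 lasts 1.5 hours
    (1, 11, datetime(2018, 10, 2, 9, 0)),
    (1, 11, datetime(2018, 10, 2, 9, 30)),    # session 11 lasts 0.5 hours
]
print(avg_session_hours(events))
```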

In [0]:
monthly_session = grouped_session(df1)
daily_session = grouped_session(df1,group_col='date',col_suffixes='_daily')

In [0]:
#Calculate daily and monthly average session length
session_length_monthly = grouped_session_length(df1)
session_length_daily = grouped_session_length(df1, group_col='date',col_suffixes='_daily')

In [0]:
df_ready = df1.select('userId','location','gender','first_date').distinct() \
    .join(session_length_monthly,on='userId') \
    .join(session_length_daily,on='userId') \
    .join(monthly_session,on='userId') \
    .join(daily_session,on='userId') \
    .join(item_per_session,on='userId') \
    .join(event_month,on='userId') \
    .join(event_day,on='userId') \
    .join(length_mean,on='userId') \
    .join(length_max,on='userId') \
    .join(current_charging,on='userId') \
    .join(last_date,on='userId') \
    .join(label,on='userId')

df_ready = df_ready.withColumn('tenure_days', datediff(col('last_date'),col('first_date')))
df_ready = df_ready.drop('userId','first_date','last_date')
df_ready.printSchema()

root
 |-- location: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- avg_duration_monthly: double (nullable = true)
 |-- avg_duration_daily: double (nullable = true)
 |-- session_count_monthly: double (nullable = true)
 |-- session_count_daily: double (nullable = true)
 |-- avg_item_per_session: double (nullable = true)
 |-- event_Settings_monthly_avg: double (nullable = false)
 |-- event_Add Friend_monthly_avg: double (nullable = false)
 |-- event_Thumbs Down_monthly_avg: double (nullable = false)
 |-- event_Downgrade_monthly_avg: double (nullable = false)
 |-- event_Submit Upgrade_monthly_avg: double (nullable = false)
 |-- event_Roll Advert_monthly_avg: double (nullable = false)
 |-- event_NextSong_monthly_avg: double (nullable = false)
 |-- event_Error_monthly_avg: double (nullable = false)
 |-- event_About_monthly_avg: double (nullable = false)
 |-- event_Upgrade_monthly_avg: double (nullable = false)
 |-- event_Add to Playlist_monthly_avg: double (nullable = fa
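
The tenure_days feature is the number of days between the registration date and the last activity date, which is what datediff returns. A tiny plain-Python equivalent is sketched below. The two dates are illustrative: they correspond to the registration and ts values of the first event row when read as UTC, not to a user's actual last activity.

```python
from datetime import date

def tenure_days(first_date, last_date):
    """Days between registration and last activity, like datediff(last_date, first_date)."""
    return (last_date - first_date).days

print(tenure_days(date(2018, 7, 20), date(2018, 10, 1)))
```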

In [0]:
df_pandas_ready = df_ready.toPandas()

In [0]:
display(df_ready)

location,gender,avg_duration_monthly,avg_duration_daily,session_count_monthly,session_count_daily,avg_item_per_session,event_Settings_monthly_avg,event_Add Friend_monthly_avg,event_Thumbs Down_monthly_avg,event_Downgrade_monthly_avg,event_Submit Upgrade_monthly_avg,event_Roll Advert_monthly_avg,event_NextSong_monthly_avg,event_Error_monthly_avg,event_About_monthly_avg,event_Upgrade_monthly_avg,event_Add to Playlist_monthly_avg,event_Home_monthly_avg,event_Thumbs Up_monthly_avg,event_Logout_monthly_avg,event_Submit Downgrade_monthly_avg,event_Save Settings_monthly_avg,event_Help_monthly_avg,event_Settings_daily_avg,event_Add Friend_daily_avg,event_Thumbs Down_daily_avg,event_Downgrade_daily_avg,event_Submit Upgrade_daily_avg,event_Roll Advert_daily_avg,event_NextSong_daily_avg,event_Error_daily_avg,event_About_daily_avg,event_Upgrade_daily_avg,event_Add to Playlist_daily_avg,event_Home_daily_avg,event_Thumbs Up_daily_avg,event_Logout_daily_avg,event_Submit Downgrade_daily_avg,event_Save Settings_daily_avg,event_Help_daily_avg,avg_length,max_length,current_charging,target,tenure_days
FL,M,1.4739930555555556,1.120486111111111,6.0,1.0833333333333333,28.666666666666668,0.0,3.0,1.0,0.0,0.0,18.5,124.5,0.0,0.0,2.0,3.0,9.5,4.5,4.5,0.0,0.0,1.5,0.0,1.5,1.0,0.0,0.0,4.625,24.9,0.0,0.0,2.0,3.0,1.9,1.2857142857142858,1.125,0.0,0.0,1.5,247.34535767068263,562.20689,free,0,182
CA,M,3.5122795893719805,3.037351851851852,19.5,1.2857142857142858,61.15384615384615,6.5,20.0,8.5,4.5,1.0,36.5,953.5,5.0,2.5,2.0,24.0,45.0,45.0,15.5,1.0,1.0,7.0,1.0833333333333333,2.857142857142857,1.4166666666666667,1.125,1.0,3.842105263157895,54.48571428571429,1.25,1.25,1.3333333333333333,2.2857142857142856,3.2142857142857144,3.4615384615384617,1.7222222222222223,1.0,1.0,1.5555555555555556,248.49965038804413,2520.99873,free,0,76
CA,M,5.39875,3.571203703703704,2.0,1.0,85.5,1.0,8.0,0.0,0.0,1.0,0.0,129.0,0.0,1.0,2.0,4.0,10.0,7.0,3.0,0.0,0.0,0.0,1.0,2.6666666666666665,0.0,0.0,1.0,0.0,43.0,0.0,1.0,2.0,2.0,3.333333333333333,3.5,3.0,0.0,0.0,0.0,237.17984046511637,655.77751,free,0,46
WA,M,1.7336111111111112,2.147037037037037,1.5,1.0,40.0,0.0,0.0,2.0,0.0,0.0,7.0,49.5,0.0,0.0,1.0,2.0,1.5,2.5,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,3.5,33.0,0.0,0.0,1.0,2.0,1.0,2.5,1.0,0.0,0.0,0.0,240.1308181818181,472.89424,free,0,72
NH,M,4.63616681929182,4.166661111111111,13.5,1.32,81.18518518518519,7.0,18.0,7.5,10.5,1.0,14.0,891.0,1.0,3.0,3.5,23.0,38.0,44.5,12.0,1.0,1.5,4.0,2.0,2.4,1.875,1.9090909090909087,1.0,3.111111111111111,74.25,1.0,1.0,1.4,3.5384615384615383,3.3043478260869565,5.933333333333334,1.5,1.0,1.0,1.6,248.4982112401797,978.442,free,0,71
MO-IL,M,1.893888888888889,1.6560833333333331,6.0,1.2,33.5,0.0,2.0,0.0,3.0,1.0,6.0,163.0,0.0,0.0,1.0,6.0,6.0,8.0,2.0,0.0,0.0,0.0,0.0,1.0,0.0,3.0,1.0,2.0,32.6,0.0,0.0,1.0,2.0,2.0,2.6666666666666665,2.0,0.0,0.0,0.0,256.8356251533741,2594.87302,free,1,40
VA-NC,M,4.8853935185185176,4.104786324786325,5.5,1.0,98.1,5.0,5.5,4.5,2.5,1.0,10.0,382.5,0.0,0.0,2.0,12.5,17.5,24.0,8.5,0.0,1.0,2.5,2.5,2.75,1.2857142857142858,1.25,1.0,2.5,58.84615384615385,0.0,0.0,1.0,2.272727272727273,2.9166666666666665,4.0,1.8888888888888888,0.0,1.0,1.25,250.39803994771245,1001.74322,free,0,78
IL-IN-WI,F,1.956276655443322,1.4893121693121691,10.0,1.1904761904761905,36.5,3.0,5.0,2.0,0.0,0.0,29.5,280.0,0.0,1.0,3.0,6.5,16.0,14.5,5.5,0.0,1.0,1.0,1.0,1.6666666666666667,1.3333333333333333,0.0,0.0,3.6875,28.0,0.0,1.0,1.5,1.625,2.1333333333333333,2.4166666666666665,1.2222222222222223,0.0,1.0,1.0,256.7948352857143,1314.48118,free,0,67
TX,M,5.987434413580248,4.781111111111112,29.0,1.619047619047619,90.25862068965516,13.0,42.0,19.5,19.0,2.0,25.5,2123.0,2.0,4.5,9.0,51.5,100.0,113.5,31.0,1.0,3.0,9.5,1.4444444444444444,3.652173913043478,1.625,2.235294117647059,1.0,3.4,106.15,1.0,1.5,1.5,3.21875,5.0,6.305555555555555,2.0,1.0,1.5,1.1875,247.11603391662737,1731.00363,free,1,124
GA-AL,M,1.1334920634920636,1.0518055555555557,7.0,1.1428571428571428,21.142857142857142,1.0,2.0,3.0,0.0,0.0,9.0,122.0,0.0,1.0,0.0,4.0,6.0,5.0,0.0,0.0,0.0,0.0,1.0,1.0,1.5,0.0,0.0,2.25,17.428571428571427,0.0,1.0,0.0,2.0,1.5,1.25,0.0,0.0,0.0,0.0,246.8560476229509,655.77751,free,1,43


## Training machine learning model
#### Overview
1. Converting Categorical to OneHot (location, gender, current_charging)
2. Normalizing the numerical features (left commented out below; tree-based models do not require scaling)
3. Putting all in pipeline
4. CrossValidator with 5 folds

#### Reloading data to reduce training time

1. The data was prepared on Databricks Community Edition, but training was too slow on that platform, so the notebook was exported to a local cluster to train the models

In [5]:
spark = SparkSession.builder.appName("CustomerChurn").getOrCreate()
df_ready = spark.read.csv("../data/df_ready.csv", header=True, inferSchema=True)

#### Categorical columns

In [6]:
# Categorical columns 
categorical_cols = ['location', 'gender', 'current_charging']
categorical_cols_indexed = [f"{col_name}_indexed" for col_name in categorical_cols]
categorical_cols_ohe = [f"{col_name}_ohe" for col_name in categorical_cols] # at this point the categorical columns will become a vector

string_indexer =StringIndexer(inputCols=categorical_cols, outputCols=categorical_cols_indexed, handleInvalid='keep')
ohe = OneHotEncoder(inputCols=categorical_cols_indexed, outputCols=categorical_cols_ohe)
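
As I understand the defaults, StringIndexer assigns indices by descending frequency, handleInvalid='keep' reserves one extra index for unseen labels, and OneHotEncoder drops the last category, so an unseen label ends up as an all-zero vector. The plain-Python sketch below imitates that behaviour to show what the two stages produce. The functions `fit_string_index`, `one_hot` and `encode` are hypothetical stand-ins, not Spark APIs.

```python
from collections import Counter

def fit_string_index(values):
    """Map each label to an index, most frequent label first (ties broken alphabetically)."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: i for i, label in enumerate(ordered)}

def one_hot(index, size):
    """One-hot vector of length `size`; an index equal to `size` maps to all zeros."""
    return [1.0 if i == index else 0.0 for i in range(size)]

def encode(label, index):
    # Unseen labels get the extra index len(index), which encodes to all zeros
    return one_hot(index.get(label, len(index)), len(index))

index = fit_string_index(["free", "free", "paid"])
print(index)
print(encode("paid", index), encode("trial", index))
```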


#### Numerical columns

In [7]:
numerical_cols = df_ready.columns
numerical_cols = [column for column in numerical_cols if column not in categorical_cols]
numerical_cols.remove('target')
# numerical_cols_scaled = [f"{col_name}_scaled" for col_name in numerical_cols]

# std_scalers = [StandardScaler(inputCol=column, outputCol=column_scaled) for column, column_scaled in zip(numerical_cols, numerical_cols_scaled)]

#### Assembling features

In [35]:
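# Note: only the numerical columns are assembled, so the one-hot columns are computed but not fed to the models.
# To include them, use inputCols=numerical_cols + categorical_cols_ohe (the metrics below were produced without them).
feature_assembler = VectorAssembler(inputCols=numerical_cols, outputCol='features')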

#### Gradient boosting

In [55]:
gbt_model = GBTClassifier(featuresCol='features', labelCol='target')
gbt_pipeline = Pipeline(stages=[string_indexer,ohe, feature_assembler,gbt_model])
gbt_param_grid = ParamGridBuilder().addGrid(gbt_model.maxDepth, [3,5,7]).build()
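# Note: recallByLabel scores the label given by metricLabel, which defaults to 0.0 (pass metricLabel=1.0 to score churners)
binary_evaluator = MulticlassClassificationEvaluator(predictionCol="prediction",labelCol='target',metricName='recallByLabel')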
gbt_cross_validator = CrossValidator(estimator=gbt_pipeline,
                                 estimatorParamMaps=gbt_param_grid,
                                 evaluator=binary_evaluator,
                                 numFolds=5)

In [56]:
start_time = time.time()
gbt_cross_validator = gbt_cross_validator.fit(df_ready)
end_time = time.time()

In [57]:
gbt_cross_validator.avgMetrics

[0.9228899794682727, 0.9122272566845888, 0.8652133749648248]

#### Random forest

In [58]:
rf_model = RandomForestClassifier(featuresCol='features', labelCol='target')
rf_pipeline = Pipeline(stages=[string_indexer,ohe, feature_assembler,rf_model])
rf_param_grid = ParamGridBuilder().addGrid(rf_model.numTrees,[20,100,200]).addGrid(rf_model.maxDepth,[5,7,11]).build()
binary_evaluator = MulticlassClassificationEvaluator(predictionCol="prediction",labelCol='target',metricName='recallByLabel')


rf_cross_validator = CrossValidator(estimator=rf_pipeline,
                                    estimatorParamMaps=rf_param_grid,
                                    evaluator=binary_evaluator,
                                    numFolds=5)

In [59]:
start_time = time.time()
rf_cross_validator = rf_cross_validator.fit(df_ready)
end_time = time.time()

In [60]:
rf_cross_validator.avgMetrics

[0.986004446004446,
 0.9826465426465427,
 0.962629364858031,
 0.9912366912366914,
 0.9883795483795486,
 0.974299960770549,
 0.9888615888615888,
 0.9802714402714403,
 0.9692322320491336]
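
To see which setting each score belongs to, the metrics can be paired with the parameter grid. The sketch below assumes the grid is enumerated in the order of itertools.product over numTrees and then maxDepth, which I believe matches ParamGridBuilder but have not verified here. In the notebook itself, zipping rf_cross_validator.getEstimatorParamMaps() with rf_cross_validator.avgMetrics avoids relying on that assumption.

```python
from itertools import product

num_trees = [20, 100, 200]
max_depth = [5, 7, 11]
# Values copied from the avgMetrics output above
avg_metrics = [0.986004446004446, 0.9826465426465427, 0.962629364858031,
               0.9912366912366914, 0.9883795483795486, 0.974299960770549,
               0.9888615888615888, 0.9802714402714403, 0.9692322320491336]

# Assumed ordering: every maxDepth for the first numTrees, then the next numTrees, and so on
grid = list(product(num_trees, max_depth))
best_params, best_metric = max(zip(grid, avg_metrics), key=lambda pair: pair[1])
print(best_params, best_metric)
```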

## Conclusion

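I was able to get an F1 score of around 76% and quite a high recall of 98-99% with the random forest. In this task, we want the recall to be as high as possible so that we can act on all the potential churners. One caveat: recallByLabel scores label 0 unless metricLabel is set, so the recall figures above are for the non-churn class; recall for churners needs metricLabel=1.0.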
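
Recall and F1 depend on which label is treated as the positive class. The plain-Python sketch below computes precision, recall and F1 for a chosen label, so the two classes can be compared side by side. The function `per_label_metrics` and the toy predictions are made up and are not the notebook's results.

```python
def per_label_metrics(y_true, y_pred, label):
    """Return (precision, recall, f1) treating `label` as the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == label and p == label)
    fp = sum(1 for t, p in pairs if t != label and p == label)
    fn = sum(1 for t, p in pairs if t == label and p != label)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Toy data: 8 non-churners (all predicted correctly), 2 churners (one missed)
y_true = [0] * 8 + [1, 1]
y_pred = [0] * 8 + [0, 1]
print(per_label_metrics(y_true, y_pred, label=0))
print(per_label_metrics(y_true, y_pred, label=1))
```

In this toy example the recall for label 0 is perfect while half of the churners are missed, which is why it matters which label the evaluator is scoring.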