
# A Tutorial of Customer Churn Analysis & Prediction 
## Data Visualization, EDA, and Churn Prediction Using ML Algorithms for Music Streaming Service


<a id="table-of-contents"></a>

1. [Introduction](#intro)
2. [Data Preparation](#data)
    * 2.1. [Load and Clean Missing Data](#clean)
    * 2.2. [Build the Website-Log Dataframe (df_log)](#log)
    * 2.3. [Transform the Website Logs into the User-log Data by Aggregation](#user)
3. [Exploratory Data Analysis(EDA): Churned Vs. Stayed](#eda)
4. [Customer Churn Prediction](#prediction)
    * 4.1. [Engineer Features](#prediction)
    * 4.2. [Build the Model Pipeline](#pipeline)
    * 4.3. [Select ML Algorithms](#model) 
    * 4.4. [Tune Hyperparameters](#tune)
5. [Conclusion](#conclusion)
6. [Reference](#reference)

<a id="intro"></a>
[back to top](#table-of-contents)

# 1. Introduction

Predicting churn is a hard problem, and many data scientists and analysts struggle with it in any customer-facing business. User-interaction services like Spotify communicate with their customers constantly, so they generate a large amount of logging data every day. In this project, I show how to manipulate a large, realistic dataset with Spark and how to build a churn prediction model with Spark MLlib. Let's dive in!

<img src="https://images.unsplash.com/photo-1616356607338-fd87169ecf1a?ixid=MXwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHw%3D&ixlib=rb-1.2.1&auto=format&fit=crop&w=1950&q=80">
<span>Photo by <a href="https://unsplash.com/@fhavlik?utm_source=unsplash&utm_medium=referral&utm_content=creditCopyText">Filip Havlik</a> on <a href="https://unsplash.com/s/photos/music-streaming-phone?utm_source=unsplash&utm_medium=referral&utm_content=creditCopyText">Unsplash</a></span>
  

<a id="data"></a>
[back to top](#table-of-contents)

# 2.  Data Preparation 

First, import the libraries below, including many from PySpark. Then create a SparkSession instance to wrangle the big data.

In [None]:
import os
import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly
from plotly import graph_objs as go
%matplotlib inline 

In [None]:
# import pyspark libraries
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql import Window
from pyspark.sql.types import IntegerType
from pyspark.sql.types import FloatType
from pyspark.sql.types import StringType
from pyspark.sql.types import DateType

from pyspark.ml import Pipeline, PipelineModel, Estimator, Transformer
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StandardScaler
from pyspark.ml.feature import MinMaxScaler
from pyspark.ml.feature import VectorSlicer
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.classification import LinearSVC
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.classification import GBTClassifier

from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator
from pyspark.ml.tuning import CrossValidatorModel
from pyspark.ml.tuning import ParamGridBuilder
from pyspark.mllib.evaluation import MulticlassMetrics

In [None]:
# create a Spark session
spark = SparkSession \
        .builder \
        .appName("Sparkify") \
        .getOrCreate()

<a id="clean"></a>
[back to top](#table-of-contents)

## 2.1. Load and Clean Missing Data 

### Load Data
In this project, we use the website log data collected from a fictional music streaming company called "Sparkify". The full dataset is 12GB, but we can start exploring with a small subset of it, `mini_sparkify_event_data.json`.


In [None]:
# load data 
df = spark.read.json("mini_sparkify_event_data.json")

In [None]:
# print the basic information of the dataset
print(f"Total records of the data: {df.count()}")
df.printSchema()
df.head()

In [None]:
# print the unique categories of each column to understand the data structure
for col in ['auth', 'level', 'page', 'gender', 'location', 'artist', 'song']:
    # .show() prints the table and returns None, so there is nothing to assign
    df.groupBy(col).count() \
      .orderBy("count", ascending=False) \
      .show()

### Clean Missing Data
Now it's time to clean up empty and invalid data. The code below shows that some records have an empty-string userId; these are probably not regular Sparkify users, so we can discard them from our dataset. There are also many None values in "song" and other columns, but we don't need to handle them for now: we are going to aggregate this dataset per user, so the song and artist values are not critical.

In [None]:
# Show the missing / invalid values in each column
for col in df.columns: 
    empty_count = df.where(df[col] == "").count()
    none_count = df.where(df[col].isNull()).count()
    print(f"{col}: \t\tempty({empty_count}), \tnone({none_count})")

In [None]:
# visualize the missing values for each column
pd_isnull = df.toPandas().isnull().replace({True:1, False:0}) 

trace = go.Heatmap(
            x=pd_isnull.columns.tolist(),
            y=pd_isnull.index.tolist(),
            z=pd_isnull.values.tolist(), 
            xgap=0.5,
            colorscale=[[0,'black'], [1,'whitesmoke']], 
            showscale=False,
        )
    
layout = dict(title = dict(
                text='Missing Data HeatMap',
                x=0.5, 
                y=0.9,
                xanchor='center',
                yanchor='top', 
                font_size=25, 
            ),
            plot_bgcolor = 'darkgrey', 
            paper_bgcolor = 'rgb(243,243,243)', 
            font = dict(
                family='Times New Roman', 
                size=15,
            ),
            xaxis=dict(
                title='columns', 
                ticks='outside', 
                tickangle=-45, 
                side='top'
            ),
            yaxis=dict(
                title='index', 
                showticklabels=False
            ),
            margin=dict(t=200,b=10),
        )
fig = go.Figure(data=[trace], layout=layout)
fig.show()

In [None]:
# Drop rows with an empty userId: these invalid IDs produce the 8346 None values in "firstName", "gender", "lastName", "location", "registration", and "userAgent"
# The None values in "artist" and "song" can stay, because that information is only present when a user plays a song
df_clean = df.where(df.userId != "")
df_clean.count()
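
To make the cleaning rule concrete, here is a plain-Python sketch on a few made-up records (the values are hypothetical, not rows from the Sparkify file). It counts empty strings and `None`s per column, like the Spark loop above, and then keeps only rows with a non-empty `userId`, mirroring `df.where(df.userId != "")`.

```python
# Hypothetical toy records standing in for Sparkify log rows
records = [
    {"userId": "30", "page": "NextSong", "song": "Rockpools"},
    {"userId": "", "page": "Home", "song": None},
    {"userId": "9", "page": "Home", "song": None},
]


def missing_summary(rows):
    """Return {column: (empty_count, none_count)}, like the Spark loop above."""
    columns = sorted({key for row in rows for key in row})
    summary = {}
    for col in columns:
        empty = sum(1 for row in rows if row.get(col) == "")
        none = sum(1 for row in rows if row.get(col) is None)
        summary[col] = (empty, none)
    return summary


# Keep only rows with a non-empty userId (the same filter that builds df_clean)
clean = [row for row in records if row["userId"] != ""]

print(missing_summary(records))
print(len(clean))
```

The `None` values in `song` survive the filter, just as the notebook leaves `artist` and `song` untouched.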

<a id="log"></a>
[back to top](#table-of-contents)

## 2.2. Build the Website-Log Dataframe (df_log)

### Understand the Website Log data
The table below describes the Sparkify user transaction data. The variables used in this analysis are highlighted in bold.


| Column | Type |  Description| Comment|
| --- | --- | --- | --- |
|**ts** |(long) | the timestamp of user log |  |
| sessionId |(long) | an identifier for the current session |  |
| auth |(string)| the authentication status |'LoggedIn', 'LoggedOut', 'Guest', 'Cancelled' |
| itemInSession |(long)| the sequence number of the log item within a session |  |
| method |(string)| HTTP request method |'PUT', 'GET'|
| status |(long)| HTTP status code |307: Temporary Redirect, 404: Not Found, 200: OK|
| userAgent |(string)| the name of HTTP user agent  | |
| **userId** | (string) | the user Id | |
| **gender** | (string) | the user gender | 'F', 'M' |
| **location** | (string) | the user location  | |
| firstName | (string) | the user first name  | |
| lastName | (string) | the user last name  | |
| **registration** | (long) | the timestamp when user registered | |
| **level** | (string) | the level of user subscription | 'free','paid' |
| **page** | (string) | the page name visited by users | 'Home', 'Login', 'Logout', 'Settings', 'Save Settings', 'About', 'NextSong', <br> 'Thumbs Up', 'Thumbs Down', 'Add to Playlist', 'Add Friend', 'Roll Advert',<br> 'Upgrade', 'Downgrade', 'Help', 'Submit Downgrade', 'Cancel', 'Cancellation Confirmation' |
| **artist** | (string) | the name of the artist played by users | |
| **song** | (string) | the name of songs played by users | |
| length | (double) | the length of a song in seconds | |


### Define Churn

Customer churn happens when existing customers cancel their subscription. In this project, I set the churn status to 1 only when a user visits the `Cancellation Confirmation` page. Downgrading does not count as churn: since this dataset covers only two months, a user who submitted a downgrade before October shows the level 'free' but has not churned. Thus we need to analyze the churn rate for both free and paid users.
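
The labeling rule can be sketched in plain Python, assuming hypothetical `(userId, page, level)` events (the IDs and events below are made up). A user is labeled 1 if any of their events is `Cancellation Confirmation`, which is what taking the maximum of a per-event flag over a `userId` window does in the Spark code further down. The downgraded user keeps the label 0 even though their last level is 'free'.

```python
from collections import defaultdict

# Hypothetical (userId, page, level) events in time order
events = [
    ("131", "NextSong", "paid"),
    ("131", "Submit Downgrade", "paid"),
    ("131", "NextSong", "free"),
    ("18", "NextSong", "paid"),
    ("18", "Cancel", "paid"),
    ("18", "Cancellation Confirmation", "paid"),
]


def label_churn(rows):
    """Map each userId to 1 if any event is 'Cancellation Confirmation', else 0.

    Equivalent to f.max(churnEvent) over Window.partitionBy('userId').
    """
    churn = defaultdict(int)
    for user_id, page, _level in rows:
        churn[user_id] = max(churn[user_id], int(page == "Cancellation Confirmation"))
    return dict(churn)


print(label_churn(events))  # {'131': 0, '18': 1}
```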


In [None]:
# a user's subscription level changes after the user hits the `Submit Downgrade` page
df_clean.where((df_clean.userId == '131') & (df_clean.page != 'NextSong')).select('page', 'level').show()

In [None]:
# see who cancelled 
df_clean.select(['userId', 'page','level']).where(df_clean.page == 'Cancellation Confirmation').show(5)

# show the total cancels 
churned_users = df_clean.where(df_clean.page == "Cancellation Confirmation").select("userId").dropDuplicates().count()
total_users = df_clean.select("userId").dropDuplicates().count()

print(f"The count of churned users (1): {churned_users}")
print(f"The count of Not-Churned users (0): {total_users - churned_users}")

### Build The Website-Log Dataframe (df_log) by Adding More Columns
From the cleaned dataset (df_clean), we can build the website-log dataframe (df_log) with our target variable `churn` and some additional columns as below.  
- `churn`(int): The status which indicates if a user is churned(1) or not(0)  
- `ds` (datetime): DateStamp(ds) converted from the timestamp data "ts"
- `dsRegistration` (datetime): DateStamp(ds) converted from the timestamp data "registration"
- `locCity` (str): the name of city from "location" data 
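
Both conversions can be checked in plain Python. This sketch assumes millisecond epoch timestamps and `City-City, ST` style location strings (the inputs below are illustrative). Unlike the notebook's UDF, which uses the machine's local timezone, it converts in UTC so the result is reproducible.

```python
from datetime import datetime, timezone


def ms_to_date(ts_ms):
    """Convert a millisecond epoch timestamp (like `ts` or `registration`) to a date.

    UTC is used for reproducibility; the notebook's UDF uses the local timezone.
    """
    return datetime.fromtimestamp(ts_ms / 1000.0, tz=timezone.utc).date()


def to_city(location):
    """Keep the first city of a 'City1-City2, ST1-ST2' location string."""
    first_part = location.split(",")[0]  # drop the state part
    return first_part.split("-")[0]      # drop additional metro-area cities


print(ms_to_date(1538352117000))  # 2018-10-01
print(to_city("Winston-Salem, NC"))  # Winston
```

Splitting on `-` is also why a name like "Winston-Salem" shrinks to "Winston", which the location analysis later has to account for.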

In [None]:
# add datetime columns from timestamp columns
# convert millisecond timestamps into "YYYY-MM-DD HH:MM:SS" strings (local timezone)
to_datestr = f.udf(lambda t: datetime.datetime.fromtimestamp(t / 1000.0).strftime("%Y-%m-%d %H:%M:%S"))
# flag the single event that defines churn
is_cancelled = f.udf(lambda page: 1 if page == 'Cancellation Confirmation' else 0, IntegerType())
w = Window.partitionBy('userId')
# 'churn' is 1 on every row of a user who ever reached 'Cancellation Confirmation'
df_log = df_clean.withColumn('churnEvent', is_cancelled(f.col('page'))) \
                 .withColumn('churn', f.max(f.col('churnEvent')).over(w)) \
                 .withColumn('ds', f.to_date(to_datestr(f.col('ts')))) \
                 .withColumn('dsRegistration', f.to_date(to_datestr(f.col('registration')))) \
                 .withColumn("locCities", f.split(f.col('location'), ',').getItem(0)) \
                 .withColumn("locCity", f.split(f.col('locCities'), '-').getItem(0)) \
                 .drop('churnEvent', 'locCities')  # intermediate columns

# show result
df_log.dropDuplicates(subset=['userId']).select(['userId', 'churn', 'ds', 'dsRegistration', 'locCity']).show(3)
df_log.printSchema()

In [None]:
# summarize the df_log

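# report (rows, columns) without collecting the whole dataset into pandas
print(f'The shape of the raw data: {(df.count(), len(df.columns))}')
print(f'The shape of the clean data: {(df_clean.count(), len(df_clean.columns))}')
print(f'The shape of the log data: {(df_log.count(), len(df_log.columns))}')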

total_users = df_log.dropDuplicates(["userId"]).count()
churned_users = df_log.where(f.col('churn')==1).dropDuplicates(['userId']).count()
stayed_users = df_log.where(f.col('churn')==0).dropDuplicates(['userId']).count()
print(f'The number of users (unique userId): {total_users}')  
print(f"The count of churned users (1): {churned_users}")
print(f"The count of Not-Churned users (0): {stayed_users}")

log_period = df_log.select(f.min('ds'), f.max('ds')).first()
print(f'The logging period: {log_period[0]} - {log_period[1]}')

<a id="user"></a>
[back to top](#table-of-contents)

## 2.3. Transform the Website Logs into the User-log Data by Aggregation

To predict the churn status of each user, the website-log data needs to be transformed per user. First, we discard columns that are not related to churn events, such as session logs and user names. Then we aggregate the data by userId into two types of features: user information and user activities. The user-information columns are churn, gender, level, and locCity, which we keep as a single value per user.

For user activity data, we need to aggregate the logging data to create some meaningful features. I listed the new columns that I added to the user-log dataset below. 
- lifeTime (int): how long a user has been active on the service, i.e., the number of days from the registration date to the user's last active log date
- playTime (double): the average duration (in seconds) of a user's listening stretches, where the NextSong and Home events are split into stretches at each visit to the Home page
- numSongs (long): the number of distinct songs played by each user
- numArtists (long): the number of distinct artists played by each user
- numPage_* (long): the number of visits to each page by each user. The Cancel and Cancellation Confirmation pages are excluded from this feature group because they are used to generate the churn labels. Login and Register have zero counts for all users in the cleaned dataset, so no columns are created for them
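
The aggregation for `lifeTime`, `numSongs`, and `numArtists` can be sketched in plain Python on hypothetical rows. The sketch assumes one registration date per user and treats `None` or empty song/artist values as missing, as the `is_valid` UDF does in the cells below.

```python
from datetime import date

# Hypothetical per-event rows: (userId, log date, registration date, song, artist)
logs = [
    ("9", date(2018, 10, 2), date(2018, 9, 20), "Song A", "Artist X"),
    ("9", date(2018, 10, 5), date(2018, 9, 20), "Song A", "Artist X"),
    ("9", date(2018, 10, 7), date(2018, 9, 20), "Song B", "Artist Y"),
    ("9", date(2018, 10, 7), date(2018, 9, 20), None, None),  # e.g. a Home visit
]


def user_features(rows):
    """Compute lifeTime, numSongs, and numArtists for each userId."""
    features = {}
    for user_id in {row[0] for row in rows}:
        user_rows = [row for row in rows if row[0] == user_id]
        last_log = max(row[1] for row in user_rows)
        registered = user_rows[0][2]
        features[user_id] = {
            # days from registration to the last active log date (like f.datediff)
            "lifeTime": (last_log - registered).days,
            # distinct non-missing songs / artists (like dropDuplicates + summed flags)
            "numSongs": len({row[3] for row in user_rows if row[3]}),
            "numArtists": len({row[4] for row in user_rows if row[4]}),
        }
    return features


print(user_features(logs))  # {'9': {'lifeTime': 17, 'numSongs': 2, 'numArtists': 2}}
```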

In [None]:
# 1. add lifeTime 
w =  Window.partitionBy('userId')
df_life = df_log.select('userId', 'ds', 'dsRegistration') \
               .withColumn('dsLastLog', f.max('ds').over(w)) \
               .withColumn('lifeTime', f.datediff(f.col('dsLastLog'), f.col('dsRegistration'))) \
               .dropDuplicates(subset=['userId']) \
               .drop('ds', 'dsRegistration', 'dsLastLog')
df_life.show(5)
df_life.count() # make sure total 225

In [None]:
# 2. add playTime
# flag Home page visits with 1 (NextSong events get 0)
is_homevisit = f.udf(lambda page: int(page == 'Home'), IntegerType())
# define windows to partition 
w1 = Window.partitionBy('userId') \
           .orderBy(f.desc('ts')) \
           .rangeBetween(Window.unboundedPreceding, 0)  # running count of Home visits, latest first
w2 = Window.partitionBy(['userId', 'songCnt'])  # one listening stretch of a user
w3 = Window.partitionBy('userId')  # all stretches of a user
# add playTime column (# 1000: msec -> sec )
df_play = df_log.select('userId', 'page', 'ts') \
               .where((f.col('page') == 'NextSong') | (f.col('page') == 'Home')) \
               .withColumn('homevisit', is_homevisit(f.col('page'))) \
               .withColumn('songCnt', f.sum('homevisit').over(w1)) \
               .withColumn('songTime', f.max('ts').over(w2) - f.min('ts').over(w2)) \
               .dropDuplicates(subset=['userId', 'songCnt']) \
               .withColumn('playTime', f.avg('songTime').over(w3) / 1000) \
               .dropDuplicates(subset=['userId']) \
               .drop('page','ts','homevisit', 'songCnt', 'songTime')

df_play.show(3)
df_play.count() # make sure total 225

In [None]:
# 3. num of songs 
is_valid = f.udf(lambda x: 1 if x else 0, IntegerType())  # 1 for a non-empty value, 0 for null/empty
w = Window.partitionBy('userId')
df_song = df_log.select('userId', 'song') \
                .dropDuplicates(['userId','song']) \
                .withColumn('songValid', is_valid(f.col('song'))) \
                .withColumn('numSongs', f.sum('songValid').over(w)) \
                .dropDuplicates(['userId']) \
                .drop('song', 'songValid')

df_song.show(3)
df_song.count() # make sure total 225

In [None]:
# 4. num of artists 
is_valid = f.udf(lambda x: 1 if x else 0, IntegerType())  # 1 for a non-empty value, 0 for null/empty
w = Window.partitionBy('userId')
df_arts = df_log.select('userId', 'artist') \
                .dropDuplicates(['userId','artist']) \
                .withColumn('artistValid', is_valid(f.col('artist'))) \
                .withColumn('numArtists', f.sum('artistValid').over(w)) \
                .dropDuplicates(['userId']) \
                .drop('artist', 'artistValid')

df_arts.show(3)
df_arts.count() # make sure total 225

In [None]:
# 5. pages  
print(df.where(df.page == 'Login').count(), df_clean.where(df_clean.page =='Login').count())
print(df.where(df.page == 'Register').count(), df_clean.where(df_clean.page =='Register').count())

w = Window.partitionBy('userId','page')
df_page = df_log.select('userId','page')\
    .withColumn('pageVisits',f.count('page').over(w))\
    .groupBy('userId') \
    .pivot('page') \
    .max('pageVisits')\
    .na.fill(0) \
    .drop('page', 'Cancel', 'Cancellation Confirmation') # don't include features related to churn

# rename the columns with prefix 'numPage'
page_names = df_page.columns
page_names.remove('userId')
df_page = df_page.select(['userId']+[f.col(c).alias('numPage_' + c.replace(" ", "")) for c in page_names])

print(df_page.take(1))
print(df_page.toPandas().shape)
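
The pivot above turns long `(userId, page)` rows into one wide row per user. A plain-Python sketch of the same idea on hypothetical visits drops the churn-defining pages, fills pages a user never visited with 0 (like `.na.fill(0)`), and builds the `numPage_` names by removing spaces.

```python
from collections import Counter

# Hypothetical (userId, page) visits
visits = [("9", "Home"), ("9", "Thumbs Up"), ("9", "Home"),
          ("18", "Home"), ("18", "Cancel"), ("18", "Cancellation Confirmation")]

EXCLUDED = {"Cancel", "Cancellation Confirmation"}  # pages used to build the churn label


def page_features(rows):
    """Pivot page visits into numPage_* counts per user, filling missing pages with 0."""
    counts = Counter(rows)  # (userId, page) -> number of visits
    pages = sorted({page for _, page in rows} - EXCLUDED)
    users = sorted({user for user, _ in rows})
    return {
        user: {"numPage_" + page.replace(" ", ""): counts[(user, page)] for page in pages}
        for user in users
    }


print(page_features(visits))
```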

In [None]:
# We predict churn for each user, so we need to build features per 'userId'
# user data 
df_user = df_log.select('userId', 'churn', 'gender', 'level', 'locCity') \
                .dropDuplicates(subset=['userId']) \
                .join(df_life, on='userId') \
                .join(df_play, on='userId') \
                .join(df_song, on='userId') \
                .join(df_arts, on='userId') \
                .join(df_page, on='userId')

# show result
print(f'The shape of the User-Log data (df_user): {(df_user.count(), len(df_user.columns))}')


In [None]:
df_user.select('userId', 'churn', 'gender', 'level', 'locCity').show(3)
df_user.select('lifeTime', 'playTime', 'numSongs', 'numArtists', 'numPage_About').show(3)
df_user.printSchema()

<a id="eda"></a>
[back to top](#table-of-contents)
# 3. Exploratory Data Analysis (EDA): Churned Vs. Stayed

For the visualization, I used the Plotly library to make interactive plots. The helper code is long, so it is collected in the next cell; the analysis cells after it call these functions.

In [None]:
# this cell contains helper functions to draw plots using Plotly

# define the global color map for churned users and stayed user
fig_colors = {'churned':'rgba(171, 50, 96, 0.6)', 'stayed':'rgba(12, 50, 196, 0.6)'} # churn, stayed


def draw_barplot(x, y_churned, y_stayed, topic=''): 
    """ draw a stacked bar plot for two user groups (churned and stayed) 
        using plotly library 
    [Args] 
        x (list): the data list for x axis
        y_churned (list) : the y axis data for 'churned' trace
        y_stayed (list): the y axis data for 'stayed' trace
        topic (str): the str used for title 
    [Returns] 
        fig (obj): plotly go.Figure() object 
    """
    # total users per category (works for any number of categories)
    y_total = [int(churned + stayed) for churned, stayed in zip(y_churned, y_stayed)]
    
    trace1 = go.Bar(
        name="Churned", 
        x=x,  
        y=y_churned,
        text=[f"Churn Rate: {100*(cnt/(total+1e-10)) :.2f}%<br>" \
              for cnt,total in zip(y_churned,y_total)], 
        textposition='auto',
        hovertext = [f"Churn Rate: {100*(cnt/(total+1e-10)) :.2f}%<br>" \
                     + f"Churned Users: {cnt :.0f} / {total :.0f}" 
                     for cnt,total in zip(y_churned,y_total)], 
        marker=dict(color=fig_colors['churned']), 
        opacity=0.75, 
    )

    trace2 = go.Bar(
        name="Stayed", 
        x=x,  
        y=y_stayed,
        text=[f"Stay Rate: {100*(cnt/(total+1e-10)) :.2f}%" for cnt,total in zip(y_stayed,y_total)], 
        textposition='auto',
        hovertext = [f"Stay Rate: {100*(cnt/(total+1e-10)) :.2f}%<br>" \
                     + f"Stayed Users: {cnt :.0f} / {total :.0f}" 
                     for cnt,total in zip(y_stayed,y_total)], 
        marker=dict(color=fig_colors['stayed']), 
        opacity=0.75, 
    )

    layout = dict(
        barmode = 'stack', 
        hovermode = 'x', 
        title = dict(
            text=f'The Churn Analysis: {topic}',
            x=0.5, 
            y=0.95,
            xanchor='center',
            yanchor='top', 
            font_size=25, 
        ),
        yaxis_title = 'Number of Users',
        legend = dict(
            orientation='h', 
            x=0.5, 
            y=1.15,
            xanchor='center',
            yanchor='top', 
        ), 
        plot_bgcolor = 'rgb(243,243,243)', 
        paper_bgcolor = 'rgb(243,243,243)', 
        font = dict(
            family='Times New Roman', 
            size=15,
        )
    )

    fig = go.Figure(data=[trace2, trace1], layout=layout)

    for item, total in zip(x, y_total): 
        fig.add_annotation(
            x=item, y=total, yshift=25,showarrow=False,
            text=f"Total in {item}: {total}",
        )

    return fig

def draw_timeplot(x, y_churned, y_stayed, topic=''): 
    """ draw a stacked bar plot for two user groups (churned and stayed) 
        with additional line traces  
    [Args] 
        x (list): the data list for x axis
        y_churned (list) : the y axis data for 'churned' trace
        y_stayed (list): the y axis data for 'stayed' trace
        topic (str): the str used for title 
    [Returns] 
        fig (obj): plotly go.Figure() object 
    """
    offset = 15

    trace1 = go.Scatter(
        name="Churned", 
        x=x,  
        y=np.array(y_churned) + np.array(y_stayed) - offset,
        marker=dict(color=fig_colors['churned'], symbol='square', size=15), 
        line=dict(color=fig_colors['churned'], width=2, dash='dash'), 
        showlegend=False,
        hoverinfo='none'
    )

    trace2 = go.Scatter(
        name="Stayed", 
        x=x,  
        y=np.array(y_stayed) - offset,
        marker=dict(color=fig_colors['stayed'], symbol='square', size=15), 
        line=dict(color=fig_colors['stayed'], width=2, dash='dash'), 
        showlegend=False,
        hoverinfo='none'
    )
    
    fig = draw_barplot(x, y_churned, y_stayed, topic=topic)
    # clear the text on bar plots 
    for trace in fig['data']: 
        trace['textposition'] = 'none'
    
    # add lineplots 
    fig.add_traces([trace1, trace2])
    
    # add annotations
    for xi, yi, shift, value in zip(trace1['x'], trace1['y'], [-30, +30], y_churned): 
        fig.add_annotation(
            x=xi, y=yi, xshift=shift,showarrow=False,
            text=f'{value:.0f}',
            font=dict(color=fig_colors['churned'])
        )
    for xi, yi, shift, value in zip(trace2['x'], trace2['y'], [-30, +30], y_stayed): 
        fig.add_annotation(
            x=xi, y=yi, xshift=shift,showarrow=False,
            text=f'{value:.0f}',
            font=dict(color=fig_colors['stayed'])
        )
    
    return fig

def stack_bars_horizontally(figs, topics=""):
    """ draw a combined multiple bar plots 
        using plotly library 
    [Args] 
        figs (list of objs): list of go.Figure() object to combine 
        topics (str): the str used for title 
    [Returns] 
        fig (obj): plotly go.Figure() object 
    """
    # Combine bar plots horizontally 
    if not topics:
        topics = ', '.join([str(f.data[0].y0) for f in figs]) 
    
    fig_bars = go.Figure().set_subplots(rows=1, cols=len(figs), shared_yaxes=True)

    for i, fig in enumerate(figs): 
        fig_bars.add_trace(fig.data[0], row=1, col=i+1)
        fig_bars.add_trace(fig.data[1], row=1, col=i+1)
        # the subplots sit side by side, so each one lives in row 1, column i+1
        fig_bars.update_xaxes(fig.layout.xaxis, row=1, col=i+1)
        fig_bars.update_yaxes(fig.layout.yaxis, row=1, col=i+1)

        if i < (len(figs)-1): 
            fig_bars.update_traces(showlegend=False)

        fig_bars.add_annotation(fig.layout.annotations[0], row=1, col=i+1)
        fig_bars.add_annotation(fig.layout.annotations[1], row=1, col=i+1)


    fig_bars.update_layout(dict(
        barmode = 'stack', 
        hovermode = 'x', 
        title = dict(
            text=f'Churn Analysis: ' + topics,
            x=0.5, 
            y=0.95,
            xanchor='center',
            yanchor='top', 
            font_size=25, 
        ),
        yaxis_title = 'Number of Users',
        legend = dict(
            orientation='h', 
            x=0.5, 
            y=1.15,
            xanchor='center',
            yanchor='top', 
        ), 
        plot_bgcolor = 'rgb(243,243,243)', 
        paper_bgcolor = 'rgb(243,243,243)', 
        font = dict(
            family='Times New Roman', 
            size=15,
        ),
    ))

    return fig_bars
    
    
def draw_geoplot(pd_city):
    """ draw a geoplot for US city using plotly library 
    [Args] 
        pd_city (pandas DataFrame): city-level data with 'lat', 'lon', 'churnRate', and 'text' columns 
    [Returns] 
        fig (obj): plotly go.Figure() object 
    """
    limits = [(0, 0.05), (0.05,0.25),(0.25,0.50),(0.50,0.75),(0.75,.95),(0.95,1.001)]
    scale = 150

    data = []
    for i in range(len(limits)):
        lim = limits[i]
        pd_sub = pd_city.loc[(pd_city['churnRate']>=lim[0]) & (pd_city['churnRate'] <lim[1]), :]
        data.append(go.Scattergeo(
            locationmode = 'USA-states',
            lon = pd_sub['lon'],
            lat = pd_sub['lat'],
            text = pd_sub['text'],
            marker = dict(
                size = pd_sub['churnRate']*scale + 10,
                opacity = 0.5,
                line_color='rgb(217, 217, 217)',
                line_width= 0.5,
                sizemode = 'area',
            ),
            name = f'{100*lim[0]:.0f}% - {100*lim[1]:.0f}%'
          )
        )
    
    layout = dict(
        geo = dict(
            scope = 'usa',
            landcolor = 'rgb(217, 217, 217)',
            bgcolor = 'rgb(243,243,243)',
        ),
        title = dict(
            text='The Churn Analysis: Location',
            x=0.5, 
            y=0.9,
            xanchor='center',
            yanchor='top', 
            font_size=25, 
        ),
        showlegend = True,
        legend_title='   Churn Rate',
        plot_bgcolor = 'rgb(243,243,243)', 
        paper_bgcolor = 'rgb(243,243,243)', 
        font = dict(
            family='Times New Roman', 
            size=15,
        )
    )


    fig = go.Figure(data=data, layout=layout)
    return fig


def draw_violinplot(x_churned, x_stayed, topic='', unit=''):
    """ draw a horizontal violin plot for two user groups (churned and stayed) 
        using plotly library 
    [Args] 
        x_churned (list) : the data for 'churned' trace
        x_stayed (list): the data for 'stayed' trace
        topic (str): the str used for title 
        unit (str): additional unit information 
    [Returns] 
        fig (obj): plotly go.Figure() object 
    """
    trace1 = go.Violin(
        x=x_churned,
        y0=topic,
        width=1,
        name='Churned', scalegroup='Churned', legendgroup='Churned', 
        side='positive',
        marker=dict(color=fig_colors['churned']),
        pointpos=0.05,
    )

    trace2 = go.Violin(
        x=x_stayed,
        y0=topic,
        width=1,
        name='Stayed',scalegroup='Stayed', legendgroup='Stayed', 
        side='negative',
        marker=dict(color=fig_colors['stayed']),
        pointpos=-0.05,
    )

    #update characteristics shared by all traces 
    trace_all = dict(
        box_visible=True,
        box_width=0.8,
        meanline_visible=True,
        meanline_width=3,
        opacity=0.75,
        jitter=0.02,
        scalemode='width',
        orientation='h',
#         points='all',
    )

    layout = dict(
        violinmode='overlay',
        violingroupgap=0, violingap=0,
        title = dict(
            text=f'The Churn Analysis: {topic}',
            x=0.5, 
            y=0.9,
            xanchor='center',
            yanchor='top', 
            font_size=25, 
        ),
        yaxis=dict(title=f'{topic}', showticklabels=False),
        xaxis=dict(ticksuffix=f' {unit}',rangemode='tozero'),
        showlegend = True,
        plot_bgcolor = 'rgb(243,243,243)', 
        paper_bgcolor = 'rgb(243,243,243)', 
        font = dict(
            family='Times New Roman', 
            size=15,
        )
    )

    fig = go.Figure(data=[trace1, trace2], layout=layout)
    fig.update_traces(trace_all)

    return fig

# Combine violin plots vertically  
def stack_violins_vertically(figs, topics=""):
    """ draw a vertically stacked violin plots 
    [Args] 
        figs (list of objs): list of go.Figure() object to combine 
        topics (str): the str used for title 
    [Returns] 
        fig (obj): plotly go.Figure() object 
    """
    if not topics:
        topics = ', '.join([str(fig.data[0].y0) for fig in figs]) 
    
    fig_violins = go.Figure().set_subplots(rows=len(figs), cols=1, vertical_spacing=0.05)
    for i, fig in enumerate(figs): 
        fig_violins.add_trace(fig.data[0], row=i+1, col=1)
        fig_violins.add_trace(fig.data[1], row=i+1, col=1)

        fig_violins.update_xaxes(fig.layout.xaxis, row=i+1, col=1)
        fig_violins.update_yaxes(fig.layout.yaxis, row=i+1, col=1)
        fig_violins.update_yaxes(title="", showticklabels=True, ticks='outside', row=i+1, col=1)
        if i < (len(figs)-1): 
            fig_violins.update_traces(showlegend=False)    

    # fig_violins.update_traces(box_visible=True,points='all', jitter=0.1, scalemode='count')
    fig_violins.update_layout(dict(
        violinmode='overlay',
        title = dict(
            text=f'The Churn Analysis: {topics}',
            x=0.5, 
            y=0.97,
            xanchor='center',
            yanchor='top', 
            font_size=25, 
        ),
        showlegend = True,
#             legend=dict(orientation='h', x=0.42, xanchor='center', y=1.06, yanchor='top'),
        legend = dict(
            orientation='h', 
            x=0.42, 
            y=1.05,
            xanchor='center',
            yanchor='top', 
        ), 
        plot_bgcolor = 'rgb(248,248,248)', 
        paper_bgcolor = 'rgb(243,243,243)', 
        font = dict(
            family='Times New Roman', 
            size=15,
        ), 
        height=1000,
#         margin_t=100, 
        ),
    )
    
    return fig_violins


In [None]:
# 1. Churn Analysis: Time 

# churn rate per month 
df_oct_log = df_log.where((df_log['ds'] >= datetime.date(2018, 10, 1)) & (df_log['ds'] < datetime.date(2018, 11, 1)))
df_nov_log = df_log.where((df_log['ds'] >= datetime.date(2018, 11, 1)) & (df_log['ds'] < datetime.date(2018, 12, 1)))

# total count of users across oct and nov 
y_usercount = []
y_usercount.append(df_oct_log.dropDuplicates(["userId"]).count())
y_usercount.append(df_nov_log.dropDuplicates(["userId"]).count())

y_churnrate = [] 
y_churnrate.append(df_oct_log.dropDuplicates(['userId']).select(f.mean('churn')).collect()[0][0])
y_churnrate.append(df_nov_log.dropDuplicates(['userId']).select(f.mean('churn')).collect()[0][0])

# Plot
x = ['October', 'November']
y_churned = [cnt*rate for cnt,rate in zip(y_usercount, y_churnrate)]
y_stayed = [cnt*(1-rate) for cnt,rate in zip(y_usercount, y_churnrate)]
fig_time = draw_timeplot(x, y_churned, y_stayed, topic='Time')
fig_time.show()
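
Multiplying each month's active-user count by the mean churn label simply counts the churned users among those active in that month. A plain-Python sketch with hypothetical user IDs and labels makes this explicit. Because `churn` is one label per user for the whole logging period, a churned user who was active in both months is counted as churned in both.

```python
def monthly_split(active_users, churn_label):
    """Split the users active in one month into (churned, stayed) counts.

    active_users: set of userIds with at least one log in the month.
    churn_label: {userId: 0 or 1} over the whole logging period.
    Same as count * mean(churn) and count * (1 - mean(churn)) in the cell above.
    """
    churned = sum(churn_label[user] for user in active_users)
    return churned, len(active_users) - churned


# Hypothetical labels and monthly activity
labels = {"9": 0, "18": 1, "30": 0, "131": 0}
october = {"9", "18", "30"}
november = {"9", "131"}
print(monthly_split(october, labels), monthly_split(november, labels))  # (1, 2) (0, 2)
```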

In [None]:
# 2. Churn Analysis: Gender
pd_gender = df_user.groupBy(['gender', 'churn']) \
                   .agg(f.count('churn').alias('churnCnt')) \
                   .orderBy('gender') \
                   .toPandas()

x = ['Female', 'Male'] 
y_churned = pd_gender[pd_gender.churn == 1]["churnCnt"].tolist() #[femaleChurned, maleChurned]
y_stayed = pd_gender[pd_gender.churn == 0]["churnCnt"].tolist() #[femaleStayed, maleStayed]

fig_gender = draw_barplot(x, y_churned, y_stayed, topic='Gender')
# fig_gender.show()

In [None]:
# 3. Churn Analysis: Subscription Level
pd_level = df_user.groupBy(['level', 'churn']) \
                   .agg(f.count('churn').alias('churnCnt')) \
                   .orderBy('level') \
                   .toPandas()

x = ['Free', 'Paid'] 
y_churned = pd_level[pd_level.churn == 1]["churnCnt"].tolist()
y_stayed = pd_level[pd_level.churn == 0]["churnCnt"].tolist()

fig_level = draw_barplot(x, y_churned, y_stayed, topic='Subscription Level')
# fig_level.show()

In [None]:
# Plot the integrated bar plots 
figs = [fig_gender, fig_level]
fig_bars = stack_bars_horizontally(figs, topics='Gender And Subscription Level')
fig_bars.show()
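Both bar-plot cells rely on `orderBy` to line up the churned and stayed lists with the x labels. If one category had no churned users, the filtered `.tolist()` would be one element short, and the bars would shift. The hypothetical helper below looks up counts by key instead. The category codes `'F'`/`'M'` and the numbers are illustrative only.

```python
def aligned_counts(rows, categories):
    """Return (churned, stayed) count lists in the order of `categories`.

    rows: (category, churn_flag, count) tuples, like the grouped gender/level tables.
    Missing (category, flag) combinations are filled with 0 so the bars stay aligned.
    """
    lookup = {(cat, flag): cnt for cat, flag, cnt in rows}
    churned = [lookup.get((cat, 1), 0) for cat in categories]
    stayed = [lookup.get((cat, 0), 0) for cat in categories]
    return churned, stayed


# toy counts: no churned users in category "M"
rows = [("F", 0, 8), ("F", 1, 2), ("M", 0, 9)]
print(aligned_counts(rows, ["F", "M"]))  # ([2, 0], [8, 9])
```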

In [None]:
# 4. Churn Analysis: City Location  

# count each user once per city, then aggregate the users and churners by city
pd_city = df_log.select('userId', 'locCity', 'churn')\
            .dropDuplicates(['userId', 'locCity'])\
            .groupBy('locCity')\
            .agg(f.count('churn').alias('userCount'), f.sum('churn').alias('churnCount'))\
            .toPandas()

# pull the US city coordinates (lat, lon) from an external source
pd_usloc = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/2014_us_cities.csv')
pd_usloc = pd_usloc.loc[:,['name', 'lat', 'lon']]
pd_usloc['name'] = pd_usloc['name'].str.strip()
pd_usloc = pd_usloc.rename(columns={'name':'locCity'})
pd_usloc = pd_usloc.drop_duplicates('locCity')

# display the city names that have no match in the reference data
print(list(set(pd_city.locCity.unique()) - set(pd_usloc.locCity.unique())))

# rename reference entries to match the city names used in the logs
pd_usloc = pd_usloc.replace({'Winston-Salem':'Winston', 'New London':'London'})
# add coordinates for the cities missing from the reference data
pd_missing = pd.DataFrame([['Alexandria', 38.8048, -77.0469], 
                          ['Birmingham', 33.5186, -86.8104], 
                          ['Anchorage', 61.2181, -149.9003],
                          ['Morgantown', 39.6295, -79.9559],
                          ['Hagerstown', 39.6418, -77.7200],
                          ['North Wilkesboro', 36.1585, -81.1476],
                          ['Santa Maria', 34.9530, -120.4357],
                          ['Allentown', 40.602, -75.4714],
                          ['Manchester', 42.9956, -71.4548],
                          ['Fairbanks', 64.8378, -147.7164],
                          ['Flint', 43.0125, -83.6875]], columns=pd_usloc.columns)
pd_usloc = pd.concat([pd_usloc, pd_missing], axis=0, ignore_index=True)

# check which city names are still unmatched
print(list(set(pd_city.locCity.unique()) - set(pd_usloc.locCity.unique())))

# join two dataframe 
pd_city = pd.merge(pd_city, pd_usloc, how='left', on='locCity')

# add more detail info for the figure 
pd_city['churnRate'] = pd_city['churnCount'] / pd_city['userCount']
pd_city['text'] = pd_city['locCity'].apply(lambda x: f'In {x} city <br>')
pd_city['text'] += pd_city['churnCount'].apply(lambda x: f' Churn Count: {x} <br>')
pd_city['text'] += pd_city['userCount'].apply(lambda x: f' User Count: {x} <br>')
pd_city['text'] += pd_city['churnRate'].apply(lambda x: f' Churn Rate: {x*100:.2f} %')

# sort dataframe along churnCount
# pd_city = pd_city.sort_values('churnCount')
print(pd_city.loc[pd_city.churnRate > 0, :].describe())

#plot
fig_city = draw_geoplot(pd_city)
fig_city.show()
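It matters whether the city counts are taken over log events or over distinct users. A single active churner can contribute hundreds of rows and inflate the rate. The hypothetical helper below counts each user once per city before computing the rate. The toy rows are made up.

```python
from collections import defaultdict


def city_churn_rates(log_rows):
    """Churn rate per city, counting each user once.

    log_rows: (user_id, city, churn_flag) tuples, one per log event.
    """
    user_flags = {}
    for user_id, city, churn in log_rows:
        user_flags[(user_id, city)] = churn  # collapse repeated events of the same user
    users = defaultdict(int)
    churners = defaultdict(int)
    for (_, city), churn in user_flags.items():
        users[city] += 1
        churners[city] += churn
    return {city: churners[city] / users[city] for city in users}


# one churned user with 5 events and one retained user with 1 event
rows = [("u1", "Boston", 1)] * 5 + [("u2", "Boston", 0)]
print(city_churn_rates(rows))  # {'Boston': 0.5}
```

Counting events instead would give 5/6 (about 0.83) for the same city, which overstates churn.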

In [None]:
# 5. Churn Analysis: User Lifetime
# User lifetime is defined as the duration from registration to the last activity log
pd_lifetime = df_user.select('userId','lifeTime','churn') \
                 .toPandas()
x_churned = pd_lifetime['lifeTime'][pd_lifetime['churn'] == 1]
x_stayed = pd_lifetime['lifeTime'][pd_lifetime['churn'] == 0]

fig_life = draw_violinplot(x_churned, x_stayed, topic='User Life Time', unit='days')
# fig_life.show()
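The lifetime definition in the comment above reduces to simple arithmetic: subtract the registration time from the latest event time and convert to days. The sketch assumes both timestamps are epoch milliseconds. This is an assumption about the raw `registration` and `ts` columns, and `lifetime_days` is a hypothetical name.

```python
MS_PER_DAY = 1000 * 60 * 60 * 24


def lifetime_days(registration_ms, event_ts_ms):
    """Days between registration and the latest activity timestamp (epoch milliseconds)."""
    return (max(event_ts_ms) - registration_ms) / MS_PER_DAY


registration = 1_538_352_000_000  # an arbitrary epoch-millisecond timestamp
events = [registration + 2 * MS_PER_DAY, registration + 45 * MS_PER_DAY]
print(lifetime_days(registration, events))  # 45.0
```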

In [None]:
# 6. Churn Analysis: Song Playing Time
pd_playtime = df_user.select('playTime','churn').toPandas()
x_churned = pd_playtime['playTime'][pd_playtime['churn'] == 1] / 3600 # sec->hours
x_stayed = pd_playtime['playTime'][pd_playtime['churn'] == 0] / 3600

fig_play = draw_violinplot(x_churned, x_stayed, topic='Average Song Playing Time', unit='hrs')
# fig_play.show()

In [None]:
# 7. Churn Analysis : number of songs
pd_numsongs = df_user.select('numSongs','churn').toPandas()
x_churned = pd_numsongs['numSongs'][pd_numsongs['churn'] == 1] 
x_stayed = pd_numsongs['numSongs'][pd_numsongs['churn'] == 0] 

fig_song = draw_violinplot(x_churned, x_stayed, topic='Number of Songs Played', unit='songs')
# fig_song.show()


In [None]:
# 8. Churn Analysis : number of Artists
pd_numarts = df_user.select('numArtists','churn').toPandas()
x_churned = pd_numarts['numArtists'][pd_numarts['churn'] == 1] 
x_stayed = pd_numarts['numArtists'][pd_numarts['churn'] == 0] 

fig_arts = draw_violinplot(x_churned, x_stayed, topic='Number of Artists', unit='artists')
# fig_arts.show()

In [None]:
# Show the combined violin plots for numerical variables 
fig_list = [fig_life, fig_play, fig_song, fig_arts]
fig_violins1 = stack_violins_vertically(fig_list, topics='Numerical Features')
fig_violins1.show()
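Violin plots show the full distributions. When comparing churned and stayed users, it can help to print a numeric summary next to them. The hypothetical `group_summary` below reports the median and interquartile range (IQR) of each group using only the standard library. The toy values stand in for a column such as `lifeTime`.

```python
import statistics


def group_summary(churned, stayed):
    """Median and interquartile range (IQR) of a numeric feature for each user group."""
    def summarize(values):
        q1, _, q3 = statistics.quantiles(values, n=4)  # quartile cut points
        return {"median": statistics.median(values), "iqr": q3 - q1}
    return {"churned": summarize(churned), "stayed": summarize(stayed)}


# toy lifetimes in days
print(group_summary([1, 2, 3, 4, 5], [10, 20, 30, 40, 50]))
# {'churned': {'median': 3, 'iqr': 3.0}, 'stayed': {'median': 30, 'iqr': 30.0}}
```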

In [None]:
# 9. Churn Analysis : Pages
pd_pages = df_user.select(['churn'] + [c for c in df_user.columns if c.lower().find('page') > -1]).toPandas()
y_churned = pd_pages.loc[pd_pages['churn'] == 1, pd_pages.columns.drop('churn')]
y_stayed = pd_pages.loc[pd_pages['churn'] == 0, pd_pages.columns.drop('churn')]

fig_pages = {}
for col in y_churned.columns: 
    fig_pages[col] = draw_violinplot(y_churned[col], y_stayed[col], 
                                     topic=f'Page Visits<br>({col.split("_")[1]})', 
                                     unit='visits')
#     fig_pages[col].show()

In [None]:
page_names = ['numPage_About', 'numPage_Error', 'numPage_SubmitDowngrade', 'numPage_RollAdvert', 'numPage_ThumbsDown', 'numPage_Upgrade'] 
fig_list = [fig_pages[name] for name in page_names]
fig_violins2 = stack_violins_vertically(fig_list, topics='Numerical Features - PageVisits')
fig_violins2.show()

fig_list = list(fig_pages.values())[:6]
fig_violins_all = stack_violins_vertically(fig_list, topics='Numerical Features - PageVisits')
fig_violins_all.show()

fig_list = list(fig_pages.values())[6:12]
fig_violins_all = stack_violins_vertically(fig_list, topics='Numerical Features - PageVisits')
fig_violins_all.show()

fig_list = list(fig_pages.values())[12:]
fig_violins_all = stack_violins_vertically(fig_list, topics='Numerical Features - PageVisits')
fig_violins_all.show()

<a id="prediction"></a>
[back to top](#table-of-contents)

# 4. Customer Churn Prediction

Now it's time to build a model that predicts whether a user is going to churn. With this information, the business can personalize its offering, for example by providing special promotions. To get there, we first engineer features, then find a suitable model, and finally, in the future, deploy this model in the real business.  

## 4.1. Engineer Features
From the visualizations, we can now select the features for the prediction model by modifying the user-log dataset as follows:
- `locCity` will be dropped from the feature set since it has many extreme values
- `numSongs` and `numArtists` are highly correlated (0.99), so only `numSongs` is kept in the feature set
- `numPage_SubmitDowngrade` will be converted to the categorical feature `page_SubmitDowngrade`, which has only two values: `'visited'` or `'none'`
- The features derived from the `page` column are strongly correlated with each other, so I kept only 7 of them: `numPage_About`, `numPage_Error`, `numPage_RollAdvert`, `numPage_SaveSettings`, `numPage_SubmitUpgrade`, `numPage_ThumbsDown`, `numPage_Upgrade`


In [None]:
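# build the feature dataframe
def show_corr(df): 
    """Plot and save the correlation heatmap of the numeric columns of a Spark dataframe."""
    plt.figure(figsize=(16, 16))
    # keep numeric columns only: pandas >= 2.0 raises an error when corr() meets string columns
    pd_num = df.drop('userId').toPandas().select_dtypes('number')
    sns.heatmap(pd_num.corr(), annot=True, fmt='.2g',
                vmin=-1, vmax=1, center=0, cmap='coolwarm', square=True)
    plt.savefig("corr_heatmap.png")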
# Feature Selection from correlation graph 
show_corr(df_user) 
non_features = ['userId', 'churn']
category_cols = ['gender', 'level', 'page_SubmitDowngrade'] # drop LocCity
numeric_cols = ['lifeTime', 'playTime', 'numSongs', 'numPage_About', 'numPage_Error', 'numPage_RollAdvert', 'numPage_SaveSettings', 
                 'numPage_SubmitUpgrade', 'numPage_ThumbsDown', 'numPage_Upgrade'] 
# show_corr(df_user.select(non_features + numeric_cols)) 

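# 'visited' if the user opened the Submit Downgrade page at least once, otherwise 'none'
# (a built-in expression avoids Python UDF overhead and maps nulls to 'none')
df_feat = df_user.withColumn('page_SubmitDowngrade',
                             f.when(f.col('numPage_SubmitDowngrade') > 0, 'visited').otherwise('none')) \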
                 .drop('numPage_SubmitDowngrade') \
                 .withColumnRenamed('churn', 'label') \
                 .select(['userId', 'label'] + category_cols + numeric_cols)


df_feat.printSchema()
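The feature list above was chosen by reading the correlation heatmap. The sketch below shows one way that rule could be automated. The hypothetical `drop_correlated` keeps features greedily in the given order and skips any feature whose absolute Pearson correlation with an already-kept feature exceeds a threshold. The threshold and the toy columns are assumptions, not values from this analysis.

```python
import math


def pearson(x, y):
    """Pearson correlation coefficient of two equal-length numeric sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    if sx == 0 or sy == 0:
        return 0.0  # a constant column carries no correlation signal
    return cov / (sx * sy)


def drop_correlated(features, threshold=0.95):
    """Keep features in insertion order, skipping any whose |r| with a kept one exceeds threshold."""
    kept = []
    for name, values in features.items():
        if all(abs(pearson(values, features[k])) <= threshold for k in kept):
            kept.append(name)
    return kept


toy = {"numSongs": [10, 20, 30, 40], "numArtists": [9, 19, 31, 40], "lifeTime": [5, 1, 4, 2]}
print(drop_correlated(toy))  # ['numSongs', 'lifeTime']
```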

<a id="pipeline"></a>
<a id="model"></a>
[back to top](#table-of-contents)

## 4.2. Build the Model Pipeline
Let’s build the pipeline for the prediction model using Spark's ML library. I combined all feature transformations and the estimator into one pipeline and fed it into the `CrossValidator`. This way, the transformations are fit on the training portion of each fold rather than on the whole dataset. There are three main parts of the pipeline.

1. **Feature Transformation**: categorical variables are transformed into one-hot encoded vectors by `StringIndexer` and `OneHotEncoder`. Then, the categorical vectors and numerical variables are assembled into a single feature vector using `VectorAssembler`.

2. **Feature Importance Selection**: I built a custom `FeatureSelector` class that keeps only the important features, as ranked by a tree-based estimator's feature importances. This step is optional. I didn’t use it for the Logistic Regression or LinearSVC models, which don't provide `featureImportances`.

3. **Estimator**: The final step is using ML algorithms to estimate the churn label for each user.
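The transformation stage can be approximated in plain Python to show what happens to one user's row. This sketch mimics two defaults: `StringIndexer` orders labels by descending frequency and breaks ties alphabetically, and `OneHotEncoder` drops the last category (`dropLast=True`). It ignores `handleInvalid='keep'`. The function names are hypothetical, and the sketch does not reproduce Spark's implementation.

```python
from collections import Counter


def string_index(values):
    """Map each distinct string to an index: most frequent first, ties broken alphabetically."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda v: (-counts[v], v))
    return {v: i for i, v in enumerate(ordered)}


def one_hot(index, size, drop_last=True):
    """One-hot encode `index` over `size` categories, optionally dropping the last slot."""
    length = size - 1 if drop_last else size
    return [1.0 if i == index else 0.0 for i in range(length)]


def assemble(numeric, *vectors):
    """Concatenate numeric values and encoded vectors into one flat feature list."""
    out = list(numeric)
    for vec in vectors:
        out.extend(vec)
    return out


levels = ["paid", "free", "paid", "paid", "free"]
idx = string_index(levels)                                  # {'paid': 0, 'free': 1}
row = assemble([3.5, 120.0], one_hot(idx["paid"], len(idx)))
print(row)  # [3.5, 120.0, 1.0]
```

With `dropLast`, the least frequent category becomes the all-zero vector. This keeps the encoded columns from being perfectly collinear.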


In [None]:
class FeatureSelector(Estimator): 
    """ Custom FeatureSelector class inherited from the Estimator class defined in PySpark,
        so that it can be used as a stage in an ML Pipeline 
    """
    def __init__(self, estimator=None, threshold = 0.01, outputCol='features'): 
        """ initialize the FeatureSelector variables to use later
        [Args]
            estimator (object) : tree-based estimator with featureImportances attributes 
            threshold (float): the threshold value to select features 
            outputCol (str): the name of output column having the selected features 
        [Returns]
            None
        """
        super(FeatureSelector, self).__init__()
        self._setDefault()
        self.estimator = estimator
        self.threshold = threshold
        self.outputCol = outputCol 
        
    def _fit(self, data): 
        """ fit the tree-based estimator to extract feature importances, 
            then select only the features whose importance exceeds the threshold
        [Args]
            data (Spark Dataframe): dataframe for feature selection
        [Returns]
            (VectorSlicer): transformer that keeps only the selected features 
        """
        model = self.estimator.fit(data)
        pred = model.transform(data)
        df_featImp = self.ExtractFeatureImportance(model.featureImportances, pred, self.estimator.getFeaturesCol())
        # cast to plain Python ints for the VectorSlicer 'indices' param
        feat_idx = [int(x) for x in df_featImp.loc[df_featImp['score'] > self.threshold, 'idx']] 
        return VectorSlicer(inputCol = self.estimator.getFeaturesCol(), 
                            outputCol = self.outputCol, 
                            indices=feat_idx)
        
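    def ExtractFeatureImportance(self, featureImp, dataset, featuresCol):
        """ map feature importance scores to the attribute metadata of featuresCol
        [Args]
            featureImp (Vector): featureImportances of a fitted tree-based model
            dataset (Spark Dataframe): dataframe containing featuresCol with ml_attr metadata
            featuresCol (str): name of the feature vector column
        [Returns]
            (pandas Dataframe): one row per feature (including idx, name, and score), sorted by score in descending order
        """
        list_extract = []
        # the attributes are grouped by type (e.g. 'numeric', 'binary'); flatten them into one list
        for i in dataset.schema[featuresCol].metadata["ml_attr"]["attrs"]:
            list_extract = list_extract + dataset.schema[featuresCol].metadata["ml_attr"]["attrs"][i]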
            
        varlist = pd.DataFrame(list_extract)
        varlist['score'] = varlist['idx'].apply(lambda x: featureImp[x])
        
        return(varlist.sort_values('score', ascending = False))
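The selection rule in `FeatureSelector` keeps every feature whose importance is above `threshold`, then slices the feature vector down to those positions. The hypothetical sketch below reproduces that logic on plain lists. It returns indices in ascending order, whereas the class above passes them to `VectorSlicer` sorted by importance, which only changes the column order of the output vector.

```python
def select_indices(importances, threshold=0.01):
    """Positions whose importance score is strictly above `threshold`, in ascending order."""
    return [i for i, score in enumerate(importances) if score > threshold]


def slice_vector(vector, indices):
    """Keep only the given positions of a feature vector, in the order of `indices`."""
    return [vector[i] for i in indices]


importances = [0.5, 0.005, 0.3, 0.0, 0.195]  # toy scores that sum to 1
keep = select_indices(importances)
print(keep, slice_vector([1.0, 2.0, 3.0, 4.0, 5.0], keep))  # [0, 2, 4] [1.0, 3.0, 5.0]
```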

In [None]:
class Trainer: 
    """Custom trainer class to run the ML pipeline Using Cross Validation 
    """

    def __init__(self, data, category_cols, numeric_cols): 
        """ initialize the dataset information and split the data for the pipeline 
        [Args]
            data (Spark Dataframe): the dataset used for ML prediction learning and testing 
            category_cols (list of str): the list of category column names in the dataset 
            numeric_cols (list of str): the list of numerical column names in the dataset 
        [Returns]
            None
        """
        # split the data into train and test 
        self.train_data, self.test_data = data.randomSplit([0.7, 0.3], seed=42)
        
        # initialize internal variables 
        self.cat_cols = category_cols
        self.num_cols = numeric_cols
        
        
    def run(self, est, params=ParamGridBuilder().build(), feature_selector_on=False, restart=False, prefix='base'):
        """ run the ML pipeline using cross validation 
        [Args] 
            est (object): estimator object defined in the PySpark ml library  
            params (list): param maps built by ParamGridBuilder for cross validation 
            feature_selector_on (bool): the FeatureSelector step in the pipeline is turned on if True
            restart (bool): train a new model if True; otherwise, load the pre-trained model from ./models if it exists
            prefix (str): it is the prefix string for the file name which is used to save the trained model
        [Returns] 
            None (just print some status and result directly)
        """
        print('--------------------')
        # construct feature transform pipeline
        string_indexers = [StringIndexer(inputCol = c, outputCol = c + 'Idx', handleInvalid = 'keep') for c in self.cat_cols]
        onehot_indexers = [OneHotEncoder(inputCols = [c + 'Idx'], outputCols = [c + 'Vec']) for c in self.cat_cols]
        assembler = VectorAssembler(inputCols = self.num_cols + [c + 'Vec' for c in self.cat_cols], outputCol = "features")
#         label_indexer = StringIndexer(inputCol = 'churn', outputCol = 'label', handleInvalid = 'keep') # not used: 'keep' adds an extra label, so the target looks like 3 classes instead of 2
       
        stages = string_indexers + onehot_indexers + [assembler] #, label_indexer]
        
        # make FeatureSelector optional depending on the estimator
        if est and feature_selector_on: # est should be tree-based algorithms 
            selector = FeatureSelector(estimator = est.copy(), threshold=0.01, outputCol='features_subset')
            stages.append(selector)
            # update the estimator feature columns 
            est.setFeaturesCol('features_subset')
        
        # define the crossValidator 
        pipe = Pipeline(stages = stages + [est])
        evaluator = MulticlassClassificationEvaluator(metricName='f1')
        crossval = CrossValidator(estimator=pipe,
                                  evaluator=evaluator, 
                                  estimatorParamMaps=params,
                                  numFolds=3,
                                  seed=42)
        
        # define the model path to save or load 
        model_name = str(est).split("_")[0]
        file_name = f'{prefix}_pipeline_{model_name}.pth'
        model_path = os.path.join('models', file_name)
        
        # train and select the best model
        if not restart and os.path.exists(model_path):           
            best_pipe = PipelineModel.load(model_path)
            print(f"load best pipeline from {model_path}")
        else: 
            print("running cross-validation ...")
            cv_model = crossval.fit(self.train_data)
            best_pipe = cv_model.bestModel
            best_pipe.write().overwrite().save(model_path) # save to kaggle output
            print(f'cv valid_f1: {cv_model.avgMetrics}')

        train_pred = best_pipe.transform(self.train_data) 
        test_pred = best_pipe.transform(self.test_data)
        
        # show the final results 
        print(f'best estimator: {best_pipe.stages[-1]}')
        print(f'parameters: {[(p.name, v) for p, v in best_pipe.stages[-1].extractParamMap().items()]}')
        print(f'train_f1: {evaluator.evaluate(train_pred):.4f}, test_f1: {evaluator.evaluate(test_pred):.4f}')
        if feature_selector_on: 
            feat_importance = selector.ExtractFeatureImportance(
                                best_pipe.stages[-1].featureImportances, 
                                train_pred, 
                                est.getFeaturesCol())
            print(feat_importance)
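`MulticlassClassificationEvaluator(metricName='f1')` reports a weighted F1. It computes F1 per class and averages them with weights equal to each class's share of the true labels. On imbalanced churn data, this rewards the majority class. The hypothetical sketch below computes the metric by hand and shows that always predicting "stayed" on a 3:1 toy set still scores about 0.64.

```python
from collections import Counter


def weighted_f1(labels, predictions):
    """Per-class F1 averaged with weights equal to each class's share of the true labels."""
    total = len(labels)
    support = Counter(labels)
    score = 0.0
    for cls, n_cls in support.items():
        tp = sum(1 for y, p in zip(labels, predictions) if y == cls and p == cls)
        fp = sum(1 for y, p in zip(labels, predictions) if y != cls and p == cls)
        fn = n_cls - tp
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        score += (n_cls / total) * f1
    return score


labels = [0, 0, 0, 1]       # three users stayed, one churned
always_stay = [0, 0, 0, 0]  # a model that never predicts churn
print(round(weighted_f1(labels, always_stay), 4))  # 0.6429
```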

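[back to top](#table-of-contents)

## 4.3. Select ML Algorithms

With the pipeline in place, I compared five MLlib classifiers with their default parameters: `LogisticRegression`, `LinearSVC`, `DecisionTreeClassifier`, `RandomForestClassifier`, and `GBTClassifier`. The `FeatureSelector` step is turned on only for the three tree-based models, since they provide `featureImportances`.

In [None]: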
# Model Selection 
trainer = Trainer(df_feat, category_cols, numeric_cols)

trainer.run(LogisticRegression(), feature_selector_on=False)
trainer.run(LinearSVC(), feature_selector_on=False)
trainer.run(DecisionTreeClassifier(), feature_selector_on=True)
trainer.run(RandomForestClassifier(), feature_selector_on=True)
trainer.run(GBTClassifier(), feature_selector_on=True)

<a id="tune"></a>
[back to top](#table-of-contents)

## 4.4. Tune Hyper Parameters

Finally, I ran cross-validation to tune the hyperparameters of the `RandomForestClassifier`. **Because our dataset is very small, the train score is almost perfect, which indicates that the model is overfitting**. Thus, I chose a parameter grid that includes simpler models than the default one used for the model selection above (numTrees=20). **The result shows that the model with 10 trees and 16 max bins performs slightly better, but it doesn't resolve the overfitting problem. I expect that more data would help.**

In [None]:
# 3. Hyper Parameter Tuning 
est = RandomForestClassifier()
params=ParamGridBuilder() \
        .addGrid(est.numTrees, [5, 10, 20]) \
        .addGrid(est.maxDepth, [3, 5]) \
        .addGrid(est.maxBins, [16, 32]) \
        .build()

trainer.run(est=est, params=params, feature_selector_on=True, prefix='best')
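To gauge the cost of this search, count the fits. `CrossValidator` fits the pipeline once per (parameter map, fold) pair. It then refits the best map on the full training split. The short sketch below does the arithmetic for the grid above. With `FeatureSelector` turned on, each pipeline fit also trains the selector's own forest, so the number of forests trained is roughly twice this count.

```python
from itertools import product

grid = {"numTrees": [5, 10, 20], "maxDepth": [3, 5], "maxBins": [16, 32]}
combos = [dict(zip(grid, values)) for values in product(*grid.values())]

num_folds = 3
cv_fits = len(combos) * num_folds  # one pipeline fit per (param map, fold)
total_fits = cv_fits + 1           # plus the refit of the best param map on the full training split
print(len(combos), cv_fits, total_fits)  # 12 36 37
```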

<a id="conclusion"></a>
[back to top](#table-of-contents)

# 5. Conclusion

In this article, we tackled one of the most challenging and common business problems: how to predict customer churn. From the messy and large website logs, we extracted several meaningful features for each user and visualized them for two user groups (churned vs. stayed) for further analysis. Finally, we built an ML pipeline including feature transformations and an estimator, which we fed into the cross-validator for model selection and hyperparameter tuning. The final model achieves a fairly high test score (F1 score: 0.73), but it also overfits because of the small dataset size (128MB). Since Udacity provides the full dataset (12GB) on AWS, I plan to deploy this pipeline on a Spark cluster to address the overfitting problem. 

Beyond data size, there are many ways to improve this model. First, most features are simply aggregated regardless of time. The logs cover 2 months, so it would be better to emphasize more recent activity with methods like a weighted sum. We can also apply strategies to handle the class imbalance (this [blog](https://www.analyticsvidhya.com/blog/2017/03/imbalanced-data-classification/) offers some ideas). Furthermore, we could frame this problem as a time-series problem, because churn rates are usually reported periodically to business stakeholders. 
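Two of these ideas can be sketched briefly in plain Python. The first is a recency-weighted count in which an event loses half its weight every `half_life_days`. The second is "balanced" class weights that let the minority (churned) class count as much as the majority. Both helpers are hypothetical, and the 7-day half-life is an arbitrary assumption. Per-row weights like these could be stored in a column and passed through `weightCol` to MLlib estimators that support it, such as `LogisticRegression`.

```python
from collections import Counter


def decayed_count(event_days, reference_day, half_life_days=7.0):
    """Weighted event count in which an event loses half its weight every `half_life_days`."""
    return sum(0.5 ** ((reference_day - day) / half_life_days) for day in event_days)


def balanced_class_weights(labels):
    """Per-class weight n_samples / (n_classes * n_in_class), so each class contributes equally."""
    counts = Counter(labels)
    total, n_classes = len(labels), len(counts)
    return {cls: total / (n_classes * n) for cls, n in counts.items()}


print(decayed_count([60, 53], reference_day=60))  # 1.5
print(balanced_class_weights([0, 0, 0, 1]))       # {0: 0.6666666666666666, 1: 2.0}
```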

I hope this project serves as a useful tutorial on how to deal with large data to solve a real-world problem using data science and machine learning skills. It is also good practice for working with the Spark and Plotly libraries. 


<a id="reference"></a>
[back to top](#table-of-contents)

# 6. Reference

- This project is part of the [Udacity Data Scientist Nanodegree Program](https://www.udacity.com/course/data-scientist-nanodegree--nd025). The topic and dataset were provided by Udacity, but the code and content are my own. 

- To create the pipeline, I got some help from this [blog](https://www.timlrx.com/blog/feature-selection-using-feature-importance-score-creating-a-pyspark-estimator)