# Sparkify Data Science Nanodegree Capstone Project

* The goal of this data science project is to predict user churn, i.e., to identify what causes users to subscribe or to cancel their subscription.
* Instead of the pandas library we use the Apache Spark framework, which is built for big data analysis on clusters.
* The dataset used here is only a fraction of the full dataset, which is about 12 GB in size and stored on AWS S3. As a first step we use the small dataset; later we'll run this notebook on an AWS EMR cluster.

# Set up Spark and other Dependencies

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, split, trim, expr, datediff, from_unixtime

from pyspark.ml.feature import RegexTokenizer, CountVectorizer, IDF, StringIndexer, Normalizer, StandardScaler, VectorAssembler
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, GBTClassifier, LinearSVC
from pyspark.ml import Pipeline
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator


import datetime
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

In [None]:
spark = SparkSession \
    .builder \
    .appName("sparkify") \
    .getOrCreate()

In [None]:
spark

In [None]:
path = "mini_sparkify_event_data.json"
user_log = spark.read.json(path)

Run the cell below to increase the width of the notebook. The wide tables returned by the SQL queries are easier to read that way.

In [None]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))
# display(HTML("<style>.output_result { max-width:100% !important; }</style>"))
display(HTML("<style>pre { white-space: pre !important; }</style>"))

# Data Exploration

To learn more about the users, the following questions need to be answered:


## Who are the Users?

* What gender do users have?
* Where do they live?
* What are their favorite artists/songs?

## How do they interact with the Platform?

### Platform Usage

* How often do they use the platform?
* Do users who cancel their subscription simply not use the platform, or do they use it and dislike it?
* How long is an average session per user?
* What devices do they use to access the platform?

### Subscription Behavior

* How often do users switch their subscription level (paid/free)?
* Is there a leading indicator that a user might cancel their paid subscription, e.g., a high count of thumbs down?



In [None]:
user_log.printSchema()

In [None]:
user_log.head()

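The `ts` and `registration` columns hold Unix timestamps in milliseconds, which are hard to read. An earlier version used a user-defined function `get_datetime` for the conversion (kept below as a comment); we now apply PySpark's built-in `from_unixtime` to obtain the readable columns `ts_iso` and `registration_iso`.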

In [None]:
# get_datetime = udf(lambda x: datetime.datetime.fromtimestamp(x / 1000.0).strftime("%Y-%m-%d %H:%M:%S"))

# user_log = user_log.withColumn("ts_iso", get_datetime(user_log.ts)) \
#             .withColumn("registration_iso", get_datetime(user_log.registration))

# use PySpark's built-in function to handle Unix timestamps (divide by 1000 to convert from milliseconds to seconds)
user_log = user_log.withColumn("ts_iso", from_unixtime(user_log.ts/1000)) \
                .withColumn("registration_iso", from_unixtime(user_log.registration/1000))
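
As a sanity check on the millisecond conversion, the same arithmetic can be sketched in plain Python. The helper name `ms_to_iso` and the sample timestamp are illustrative assumptions and not part of the notebook. Note that `from_unixtime` renders the result in the Spark session time zone, whereas this sketch pins the output to UTC.

```python
from datetime import datetime, timezone

def ms_to_iso(ts_ms):
    """Convert a Unix timestamp in milliseconds to 'YYYY-MM-DD HH:MM:SS' in UTC."""
    seconds = ts_ms / 1000  # the log stores milliseconds, datetime expects seconds
    return datetime.fromtimestamp(seconds, tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

# Hypothetical sample value, chosen only for illustration
print(ms_to_iso(1538352117000))
```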

The location column contains both city and state; we want to split them into separate columns.

In [None]:
user_log = user_log.withColumn('city', trim(split(user_log['location'], ',').getItem(0))) \
            .withColumn('state', trim(split(user_log['location'], ',').getItem(1)))
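
The split-and-trim logic can be sketched in plain Python to see what ends up in each column. The function `split_location` and the sample strings are assumptions made for illustration. In Spark's default (non-ANSI) mode, `getItem(1)` yields null when there is no second element, which the sketch mirrors with `None`.

```python
def split_location(location):
    """Return (city, state) from a 'City, ST' string, mirroring split + trim."""
    if location is None:
        return None, None
    parts = [part.strip() for part in location.split(",")]
    city = parts[0]
    # No comma in the input means there is no state component
    state = parts[1] if len(parts) > 1 else None
    return city, state

print(split_location("Bakersfield, CA"))
print(split_location("New York-Newark-Jersey City, NY-NJ-PA"))
```

Metropolitan areas that span several states keep the combined state string (for example `NY-NJ-PA`), so the `state` column is not guaranteed to hold a single state code.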

The userAgent column contains information about which device was used to access the platform. It is therefore useful for learning more about the platform's users.

In [None]:
# Each fragment ends with a space so the concatenated CASE expression stays readable and unambiguous
user_log = user_log.withColumn('userDevice', expr("CASE WHEN lower(userAgent) LIKE '%windows%' THEN 'pc' " +
                                        "WHEN lower(userAgent) LIKE '%macintosh%' THEN 'mac' " +
                                        "WHEN lower(userAgent) LIKE '%linux%' THEN 'pc' " +
                                        "WHEN lower(userAgent) LIKE '%iphone%' THEN 'mobile' " +
                                        "WHEN lower(userAgent) LIKE '%ipad%' THEN 'mobile' " +
                                        "ELSE 'other' " +
                                        "END"))

In [None]:
user_log.head()
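
The device rules of the `CASE` expression can be tried out on individual user-agent strings with a small Python sketch. The function `classify_device` and the sample user agents are illustrative assumptions. The rules are evaluated in order, exactly as in the SQL expression, and a missing user agent falls through to `other`.

```python
# Ordered (substring, label) rules, mirroring the CASE WHEN branches
DEVICE_RULES = [
    ("windows", "pc"),
    ("macintosh", "mac"),
    ("linux", "pc"),
    ("iphone", "mobile"),
    ("ipad", "mobile"),
]

def classify_device(user_agent):
    """Return the device label of the first matching rule, else 'other'."""
    ua = (user_agent or "").lower()
    for needle, label in DEVICE_RULES:
        if needle in ua:
            return label
    return "other"

print(classify_device("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4)"))
print(classify_device("Mozilla/5.0 (iPhone; CPU iPhone OS 7_1_2 like Mac OS X)"))
```

One consequence of the rule order is worth keeping in mind: Android user agents typically contain the word "Linux", so they would be labelled `pc` rather than `mobile` under these rules.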

Since we are interested in predicting user behavior, actions recorded in the dataset need to be attributed to a userId. For records without a userId we cannot draw any conclusions about behavior.

In [None]:
print('{} records in the dataset have no userId. They will be dropped'.format(user_log.filter(user_log.userId == '').count()))
user_log = user_log.where(user_log.userId != '')
print('{} records in the dataset have no userId.'.format(user_log.filter(user_log.userId == '').count()))


## Leveraging PySpark's SQL Function for Data Exploration

In [None]:
# create temporary view against which SQL queries can be run
user_log.createOrReplaceTempView("user_log_table")

### Create a table that contains user subscription behavior

Idea of the query below:

1. `errors_tab`
    * Assign each entry a row number to maintain the order of events (some consecutive entries have exactly the same timestamp)
    * Some actions are not carried out because an error follows in the logs. Thus, if entry $t+1$ has page "Error", set a flag for entry $t$. (Note that we keep the actual errors themselves, since a user who encounters a large number of errors during a paid regime might be more likely to terminate the subscription.)

2. `records_tab`
    * When level at $t$ is different from $t-1$ a new regime starts. Set a flag
    * When level at $t$ is "paid" and level at $t+1$ is "free", the regime at $t$ is cancelled. Set a flag
    * Choose only actions that don't result in errors

3. `regime_tab`
    * Sum up the flag set in the `records_tab`. Whenever the level changes, the `regime_count` increases by 1

4. `regime_dates`
    * The `regime_start` of the first regime is, by assumption, the registration date, as we don't have other data available. For later regimes $n>1$ we take the timestamp of the first record of regime $n$
    * For the last regime, $N$, the `regime_end` is the maximum timestamp in the dataset (or the cancellation date if the user cancelled). Otherwise it is the timestamp of the latest record in the particular regime $n<N$

5. Final Result
    * Collect the pre-processed data
    * Obtain aggregate values per regime

**Note:** we might want to break the CTEs down into separate SQL queries and tables to reduce the complexity of the statement
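
The regime logic of steps 2 and 3 can be sketched in plain Python for a single user. The function names and the sample level sequence are assumptions made for illustration; the SQL query below remains the actual implementation. The sketch sets a change flag whenever the level differs from the previous record, accumulates the flags into a regime counter, and marks the last "paid" record before a switch to "free".

```python
from itertools import accumulate

def regime_counts(levels):
    """Regime number (starting at 1) for each record of one user, ordered by time."""
    if not levels:
        return []
    # First record has no predecessor, so its change flag is 0 (like LAG returning NULL)
    changes = [0] + [int(cur != prev) for prev, cur in zip(levels, levels[1:])]
    return [1 + total for total in accumulate(changes)]

def downgrade_flags(levels):
    """1 where a 'paid' record is directly followed by a 'free' record, else 0."""
    following = levels[1:] + [None]
    return [int(cur == "paid" and nxt == "free") for cur, nxt in zip(levels, following)]

levels = ["free", "free", "paid", "paid", "free"]
print(regime_counts(levels))    # [1, 1, 2, 2, 3]
print(downgrade_flags(levels))  # [0, 0, 0, 1, 0]
```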

In [None]:
spark.sql("""
            SELECT * FROM user_log_table
            WHERE userId = 100001
            ORDER BY userId, ts
        """).show(5000)

In [None]:
user_regime = spark.sql("""
                        with errors_tab AS(
                        SELECT
                            userId, auth, page, level, ts/1000 ts, registration/1000 registration, userDevice,
                            row_number() OVER(PARTITION BY userId ORDER BY userId, ts) record_number,
                            CASE WHEN LEAD(page) OVER(PARTITION BY userId ORDER BY userId, ts) = 'Error' THEN 1 ELSE 0 END leads_to_error
                        FROM user_log_table
                        ),


                        records_tab as(
                        SELECT
                            userId, page, level, ts, registration, userDevice, record_number,
                            CASE WHEN auth = 'Cancelled' THEN ts END date_fully_cancelled,
                            CASE WHEN LAG(level) OVER(PARTITION BY userId ORDER BY record_number) != level THEN 1 ELSE 0 END regime_change,
                            CASE WHEN LEAD(level) OVER(PARTITION BY userId ORDER BY record_number) = 'free' and level = 'paid' THEN 1 ELSE 0 END downgrade_regime
                        FROM errors_tab
                        WHERE leads_to_error = 0
                        ),


                        regime_tab AS(
                        SELECT userId, page, level, ts, registration, userDevice, record_number, date_fully_cancelled, downgrade_regime,
                            SUM(regime_change) OVER(PARTITION BY userId ORDER BY record_number)+1 regime_count
                        FROM records_tab),


                        regime_dates AS(            
                        SELECT userId, page, level, userDevice, record_number, FROM_UNIXTIME(date_fully_cancelled) date_fully_cancelled, downgrade_regime, regime_count,
                            CASE
                                WHEN regime_count = 1 THEN FROM_UNIXTIME(registration)
                                -- min(ts) is deterministic; first_value without ORDER BY is not
                                WHEN regime_count > 1 THEN FROM_UNIXTIME(min(ts) OVER(PARTITION BY userId, regime_count))
                            END regime_start,
                            
                            CASE
                                -- max(regime_count) per user identifies the last regime; last_value with ORDER BY only sees rows up to the current one
                                WHEN regime_count = max(regime_count) OVER(PARTITION BY userId) AND date_fully_cancelled IS NULL THEN FROM_UNIXTIME((select max(ts) FROM records_tab))
                                WHEN regime_count = max(regime_count) OVER(PARTITION BY userId) AND date_fully_cancelled IS NOT NULL THEN FROM_UNIXTIME(date_fully_cancelled)
                                ELSE FROM_UNIXTIME(max(ts) OVER(PARTITION BY userId, regime_count))
                            END regime_end
                        FROM regime_tab)

                        SELECT
                            DISTINCT userId,
                            level,
                            regime_count,
                            DATE(regime_start) regime_start,
                            DATE(regime_end) regime_end,
                            DATEDIFF(day, regime_start, regime_end) regime_lenght,
                            SUM(downgrade_regime) OVER(PARTITION BY userId, regime_count) = 1 regime_downgraded,
                            DATE(date_fully_cancelled) date_fully_cancelled,
                            count_if(page = 'NextSong') OVER(PARTITION BY userId, regime_count) next_song,
                            count_if(page = 'Thumbs Up') OVER(PARTITION BY userId, regime_count) thumbs_up,
                            count_if(page = 'Thumbs Down') OVER(PARTITION BY userId, regime_count) thumbs_down,
                            count_if(page = 'Add Friend') OVER(PARTITION BY userId, regime_count) add_friend,
                            count_if(page = 'Add to Playlist') OVER(PARTITION BY userId, regime_count) add_to_playlist,
                            count_if(page = 'Save Settings') OVER(PARTITION BY userId, regime_count) save_settings,
                            count_if(page = 'Error') OVER(PARTITION BY userId, regime_count) error,
                            count_if(page = 'Help') OVER(PARTITION BY userId, regime_count) help,
                            count_if(page = 'Roll Advert') OVER(PARTITION BY userId, regime_count) roll_advert,
                            userDevice
                        FROM regime_dates
                        ORDER BY userId, regime_count
                        
                        """)
# store results in temporary view: user_regime_table
user_regime.createOrReplaceTempView("user_regime_table")
user_regime.show(100)

### Aggregate the `user_regime_table` data per User

The aggregate table should have one row per user which contains all necessary information and can be used for ML model input.

In [None]:
user_features= spark.sql("""
                        SELECT
                            CAST(userId AS int),
                            count_if(level = 'paid') paid_regimes,
                            count_if(level = 'free') free_regimes,
                            sum(regime_lenght) membership_days,
                            count_if(regime_downgraded = 'true') downgraded_regimes,
                            CASE WHEN max(date_fully_cancelled) IS NOT NULL THEN 1 ELSE 0 END fully_cancelled,
                            sum(next_song) next_song,
                            sum(thumbs_up) thumbs_up,
                            sum(thumbs_down) thumbs_down,
                            sum(add_friend) add_friend,
                            sum(add_to_playlist) add_to_playlist,
                            sum(save_settings) save_settings,
                            sum(error) error,
                            sum(help) help,
                            sum(roll_advert) roll_advert,
                            CASE WHEN count_if(userDevice = 'mac')>0 THEN 1 ELSE 0 END device_mac,
                            CASE WHEN count_if(userDevice = 'pc')>0 THEN 1 ELSE 0 END device_pc,
                            CASE WHEN count_if(userDevice = 'mobile')>0 THEN 1 ELSE 0 END device_mobile
                        FROM
                            user_regime_table
                        GROUP BY
                            userId
                        ORDER BY
                            userId
                            
                        """)

In [None]:
user_log.select("auth").groupby("auth").count().show()

In [None]:
user_log.select("gender").groupby("gender").count().show()

In [None]:
user_log.select("level").groupby("level").count().orderBy("count", ascending=False).show()


In [None]:
user_log.select("page").groupby("page").count().orderBy("count", ascending=False).show()

In [None]:
user_log.select("userDevice").groupby("userDevice").count().orderBy("count", ascending=False).show()

In [None]:
spark.sql('''
        SELECT artist, count(*) as count
        FROM user_log_table
        WHERE artist is not null
        GROUP BY artist
        ORDER BY count(*) DESC
        ''').show()

In [None]:
spark.sql('''
        SELECT song, count(*) as count
        FROM user_log_table
        WHERE song is not null
        GROUP BY song
        ORDER BY count(*) DESC
        ''').show()

In [None]:
spark.sql('''
        SELECT artist, song, count(*) as count
        FROM user_log_table
        WHERE
            song is not null
            AND artist is not null
        GROUP BY artist, song
        ORDER BY artist, count(*) DESC
        ''').show()

In [None]:
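df_membership_days = spark.sql('''
                                -- user_log_table has no membership_days column, so derive it
                                -- as the number of days between registration and the latest event
                                SELECT userId, max(datediff(ts_iso, registration_iso)) AS membership_days
                                FROM user_log_table
                                GROUP BY userId
                                ORDER BY membership_days DESC
                                ''').toPandas()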

In [None]:
df_membership_days.plot(kind='hist', title='Membership in Days', legend=False, figsize=(10,6));

In [None]:
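spark.sql('''
            -- membership_days is derived as days between registration and the cancellation event
            SELECT userId, max(datediff(ts_iso, registration_iso)) AS membership_days
            FROM user_log_table
            WHERE page = 'Cancellation Confirmation'
            GROUP BY userId
            ORDER BY membership_days DESC
            ''').show()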

In [None]:
spark.sql('''
            -- membership_days is derived as days between registration and the cancellation event
            with pretab as(
            SELECT userId, max(datediff(ts_iso, registration_iso)) AS membership_days
            FROM user_log_table
            WHERE page = 'Cancellation Confirmation'
            GROUP BY userId
            )
            SELECT pretab.userId, count(*) records, max(pretab.membership_days) membership_days
            FROM
                user_log_table
                INNER JOIN pretab on pretab.userId = user_log_table.userId
            GROUP BY pretab.userId
            ORDER BY membership_days desc
            ''').show()

In [None]:
df_cancelled_userId = user_log.select('userId').filter(user_log.page == 'Cancellation Confirmation').drop_duplicates()

In [None]:
df_cancelled_userId.createOrReplaceTempView("cancelled_users_table")

In [None]:
spark.sql('''
            with is_home as(
            SELECT
                userID, page, ts,
                CASE WHEN page = 'Home' THEN 1 ELSE 0 END AS is_home                
            FROM user_log_table
            WHERE
                (page = 'NextSong') or (page = 'Home')
            ),
            cum_sum as (
            SELECT *, sum(is_home) OVER(PARTITION BY userID ORDER BY ts DESC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as period
            FROM is_home
            )
            SELECT AVG(count_results)
            FROM (
                SELECT COUNT(*) AS count_results FROM cum_sum
                GROUP BY userID, period, page HAVING page = 'NextSong'
            ) as counts
                
            ''').show()

In [None]:
spark.sql('''
            with is_home as(
            SELECT
                userId, auth, page, ts,
                CASE WHEN page = 'Home' THEN 1 ELSE 0 END AS is_home                
            FROM user_log_table
            WHERE
                (page = 'NextSong') or (page = 'Home')
            ),
            cum_sum as (
            SELECT *, sum(is_home) OVER(PARTITION BY userID ORDER BY ts DESC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as period
            FROM is_home
            )
            SELECT "cancelled_users" as user_state, AVG(count_results) as avg_songs_per_period
            FROM (
                SELECT COUNT(*) AS count_results FROM cum_sum WHERE userId in (select userId from cancelled_users_table)
                GROUP BY userID, period, page HAVING page = 'NextSong'
            ) as counts
            
            UNION
            
            SELECT "active_users" as user_state, AVG(count_results)
            FROM (
                SELECT COUNT(*) AS count_results FROM cum_sum WHERE userId not in (select userId from cancelled_users_table)
                GROUP BY userID, period, page HAVING page = 'NextSong'
            ) as counts
            
            
            ''').show()

In [None]:
spark.sql('''
            with is_home as(
            SELECT
                userId, auth, page, ts,
                CASE WHEN page = 'Home' THEN 1 ELSE 0 END AS is_home                
            FROM user_log_table
            WHERE
                (page = 'NextSong') or (page = 'Home')
            ),
            cum_sum as (
            SELECT *, sum(is_home) OVER(PARTITION BY userID ORDER BY ts DESC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as period
            FROM is_home
            )
            
            select userId, page, ts, period from cum_sum
            
                
            ''').show()

In [None]:
spark.sql('''
            SELECT userId, auth, ts, ts_iso, page, sessionId
            FROM user_log_table
            WHERE (page = 'Home' or page = 'NextSong')
                and userId=10
            ORDER BY userId, ts asc
            ''').show(1000)

### Todo

To understand how users interact with the platform, we need to know:

* Songs per session
* Session count
* Session frequency
    * Sessions over membership days
    * Avg time between sessions

Need to differentiate between cancelled and active users (makes sense to add `cancelled` indicator to user_log dataframe)
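
A minimal sketch of the session metrics listed above, in plain Python and for the events of a single user. It assumes each event is a `(ts_seconds, sessionId, page)` tuple and groups by the `sessionId` column rather than by the "Home"-delimited periods used in the queries above; the function name and sample events are hypothetical.

```python
from collections import defaultdict

def session_stats(events):
    """Session count, songs per session and average gap between session starts."""
    songs = defaultdict(int)
    starts = {}
    for ts, session_id, page in events:
        # Track the earliest timestamp seen for each session
        starts[session_id] = min(ts, starts.get(session_id, ts))
        # Adding 0 still registers sessions that contain no songs
        songs[session_id] += int(page == "NextSong")
    ordered = sorted(starts.values())
    gaps = [later - earlier for earlier, later in zip(ordered, ordered[1:])]
    return {
        "session_count": len(starts),
        "songs_per_session": sum(songs.values()) / len(starts) if starts else 0.0,
        "avg_gap_seconds": sum(gaps) / len(gaps) if gaps else None,
    }

events = [(0, 1, "Home"), (10, 1, "NextSong"), (20, 1, "NextSong"),
          (1000, 2, "NextSong"), (3000, 3, "Home")]
print(session_stats(events))
```

Dividing `session_count` by the membership days of the user would give the "sessions over membership days" frequency.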

# Feature Engineering

* User inactivity: if a user listens to only a handful of songs per month, or to no songs at all, a cancellation might be likely
    * The ratio of songs listened to over the time spent as a paying customer might be a good predictor
* Number of thumbs up/down they gave before cancelling

## Create New Features from Aggregate User Values

## Transform Features for Pipeline Usage

Split the `user_features` data into a training and a validation dataset

In [None]:
rest, validation = user_features.randomSplit([0.8, 0.2], seed=42)

Combine all features into a vector

In [None]:
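# userId is an identifier and fully_cancelled is the label; neither should be a model input
inputCols = [c for c in user_features.columns if c not in ('userId', 'fully_cancelled')]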

In [None]:
assembler = VectorAssembler(inputCols= inputCols, outputCol='inputFeatures')

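Normalize them. Note that Spark's `Normalizer` rescales each row (feature vector) to unit norm, using the L2 norm by default; it does not standardize the individual columns.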

In [None]:
scaler = Normalizer(inputCol="inputFeatures", outputCol="features")
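
To make the effect of row-wise normalization concrete, here is a minimal plain-Python sketch of the L2 case. The helper `l2_normalize` and the sample vectors are illustrative assumptions and not part of the pipeline.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit Euclidean length; a zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm else list(vec)

print(l2_normalize([3.0, 4.0]))
# A feature with a large raw count dominates the normalized vector
print(l2_normalize([1000.0, 10.0, 1.0]))
```

Because count features such as `next_song` are usually much larger than the binary device flags, they carry almost all of the weight after this step. Column-wise scaling (for example with the already imported `StandardScaler`) would be an alternative worth comparing.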

# Modeling

As of version 3.3.2, Spark supports the following classification algorithms:

* Logistic regression
    * Binomial logistic regression
    * Multinomial logistic regression
* Decision tree classifier
* Random forest classifier
* Gradient-boosted tree classifier
* Multilayer perceptron classifier
* Linear Support Vector Machine
* One-vs-Rest classifier (a.k.a. One-vs-All)
* Naive Bayes
* Factorization machines classifier

We'll compare four widely used ones:

1. [Logistic regression](#lr)
2. [Random forest classifier](#rf)
3. [Gradient boosted tree classifier](#gbt)
4. [Linear Support Vector Machine](#lsv)

To **evaluate the models**, we compute the **accuracy** and the **F1-score** for each model.

* Accuracy is defined as $$\frac{\text{# correct predictions}}{\text{# total predictions}}$$
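* The F1-score is the harmonic mean of precision and recall. In set notation it reads $$\frac{1}{N} \sum_{i=0}^{N-1}2\frac{|P_i\cap L_i|}{|P_i|+|L_i|},$$
where $L_0, L_1, ..., L_{N-1}$ are the label sets and $P_0, P_1, ..., P_{N-1}$ the prediction sets. The `MulticlassClassificationEvaluator` used in this notebook with `metricName="f1"` reports the per-class F1-scores weighted by class frequency.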


Details: [Spark Documentation](https://spark.apache.org/docs/2.2.0/mllib-evaluation-metrics.html#multilabel-classification)
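
Both metrics can be sketched in plain Python on small label and prediction lists. The function names and the toy data are assumptions for illustration. The weighted F1 follows our understanding of what `MulticlassClassificationEvaluator` computes for `metricName="f1"`: the per-class F1-score, weighted by the share of true labels in each class.

```python
from collections import Counter

def accuracy(labels, preds):
    """Share of predictions that match the true label."""
    return sum(l == p for l, p in zip(labels, preds)) / len(labels)

def weighted_f1(labels, preds):
    """Per-class F1 (harmonic mean of precision and recall), weighted by class frequency."""
    total = len(labels)
    score = 0.0
    for cls, support in Counter(labels).items():
        tp = sum(l == cls and p == cls for l, p in zip(labels, preds))
        predicted = sum(p == cls for p in preds)
        precision = tp / predicted if predicted else 0.0
        recall = tp / support
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        score += f1 * support / total
    return score

labels = [1, 0, 0, 0]
preds = [1, 1, 0, 0]
print(accuracy(labels, preds))     # 0.75
print(round(weighted_f1(labels, preds), 4))  # 0.7667
```

With imbalanced classes, as is typical for churn, accuracy can look good even if the minority class is predicted poorly, which is one reason to report F1 alongside it.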

## Logistic Regression <a id='lr'></a>

In [None]:
# Initialize model and pipeline
lr           =  LogisticRegression(maxIter=10, regParam=0.0, elasticNetParam=0, labelCol='fully_cancelled')
pipeline_lr  = Pipeline(stages=[assembler, scaler, lr ])

In [None]:
# Set up parameter grid
paramgrid =ParamGridBuilder()\
.addGrid(lr.regParam, [0.0, 0.1])\
.addGrid(lr.maxIter, [10])\
.build()

# Choose f1-score as evaluation metric
evaluator = MulticlassClassificationEvaluator(metricName="f1", labelCol='fully_cancelled')

# Set up cross validator
crossval = CrossValidator(estimator=pipeline_lr,
                          estimatorParamMaps=paramgrid,
                          evaluator = evaluator , 
                          numFolds=3)
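
To get a feeling for the cost of a grid search, the way a parameter grid expands into candidate settings can be sketched with `itertools.product`. The helper `expand_grid` and the example grid are illustrative assumptions. A cross validator fits one model per candidate and fold, plus one final refit of the best candidate on the full training data.

```python
from itertools import product

def expand_grid(grid):
    """Return every combination of parameter values as a list of dicts."""
    names = list(grid)
    return [dict(zip(names, values)) for values in product(*(grid[n] for n in names))]

# Hypothetical grid with two regularization strengths and one iteration limit
grid = {"regParam": [0.0, 0.1], "maxIter": [10]}
candidates = expand_grid(grid)
num_folds = 3

# One fit per candidate and fold, plus the final refit of the best candidate
total_fits = len(candidates) * num_folds + 1
print(candidates)
print(total_fits)  # 7
```

An empty grid, as used for the other classifiers in this notebook, still yields a single candidate (the default parameters), so cross-validation there only estimates the metric and does not tune anything.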


In [None]:
# Fit the model
cvModel_ls = crossval.fit(rest)

In [None]:
# Apply model to validation set
results_ls = cvModel_ls.transform(validation)

### Accuracy of Logistic Regression Model

In [None]:
# Correct predictions are true positives plus true negatives
ls_correct = results_ls.filter(results_ls.fully_cancelled == results_ls.prediction).count()
ls_totalNoPredictions = results_ls.count()

print('Number of Correct Predictions: {}'.format(ls_correct))
print('Total Number of Predictions: {}'.format(ls_totalNoPredictions))
print('Accuracy: {:.4f}'.format(ls_correct/ls_totalNoPredictions))

In [None]:
print('F1-score: {:.4f}'.format(evaluator.evaluate(cvModel_ls.transform(validation))))

## Random Forest Classifier <a id='rf'></a>

In [None]:
# Initialize model and pipeline
rf           = RandomForestClassifier(labelCol='fully_cancelled')
pipeline_rf  = Pipeline(stages=[assembler, scaler, rf])

In [None]:
paramgrid_rf = ParamGridBuilder()\
.build()

evaluator = MulticlassClassificationEvaluator(metricName="f1", labelCol='fully_cancelled')

crossval = CrossValidator(estimator=pipeline_rf,  
                          estimatorParamMaps=paramgrid_rf,
                          evaluator=evaluator, 
                          numFolds=3)

In [None]:
cvModel_rf = crossval.fit(rest)

In [None]:
results_rf = cvModel_rf.transform(validation)

### Accuracy of Random Forest Classifier Model

In [None]:
# Correct predictions are true positives plus true negatives
rf_correct = results_rf.filter(results_rf.fully_cancelled == results_rf.prediction).count()
rf_totalNoPredictions = results_rf.count()

print('Number of Correct Predictions: {}'.format(rf_correct))
print('Total Number of Predictions: {}'.format(rf_totalNoPredictions))
print('Accuracy: {:.4f}'.format(rf_correct/rf_totalNoPredictions))


In [None]:
print('F1-score: {:.4f}'.format(evaluator.evaluate(cvModel_rf.transform(validation))))

## Gradient Boosted Tree Classifier <a id='gbt'></a>

In [None]:
# Initialize model and pipeline
gbt          = GBTClassifier(labelCol='fully_cancelled')
pipeline_gbt = Pipeline(stages=[assembler, scaler, gbt])

In [None]:
paramgrid_gbt =ParamGridBuilder()\
.build()

evaluator = MulticlassClassificationEvaluator(metricName="f1", labelCol='fully_cancelled')

crossval = CrossValidator(estimator=pipeline_gbt,
                          estimatorParamMaps=paramgrid_gbt,
                          evaluator=evaluator, 
                          numFolds=3)

In [None]:
cvModel_gbt = crossval.fit(rest)

In [None]:
results_gbt = cvModel_gbt.transform(validation)

### Accuracy of Gradient Boosted Tree Classifier Model

In [None]:
# Correct predictions are true positives plus true negatives
gbt_correct = results_gbt.filter(results_gbt.fully_cancelled == results_gbt.prediction).count()
gbt_totalNoPredictions = results_gbt.count()

print('Number of Correct Predictions: {}'.format(gbt_correct))
print('Total Number of Predictions: {}'.format(gbt_totalNoPredictions))
print('Accuracy: {:.4f}'.format(gbt_correct/gbt_totalNoPredictions))

In [None]:
print('F1-score: {:.4f}'.format(evaluator.evaluate(cvModel_gbt.transform(validation))))

## Linear Support Vector Machine <a id='lsv'></a>

In [None]:
svm = LinearSVC(labelCol='fully_cancelled')
pipeline_svm = Pipeline(stages=[assembler, scaler, svm])

In [None]:
paramgrid_svm =ParamGridBuilder()\
.build()

evaluator = MulticlassClassificationEvaluator(metricName="f1", labelCol='fully_cancelled')

crossval = CrossValidator(estimator=pipeline_svm,  
                          estimatorParamMaps=paramgrid_svm,
                          evaluator=evaluator, 
                          numFolds=3)

In [None]:
cvModel_svm=crossval.fit(rest)

In [None]:
results_svm = cvModel_svm.transform(validation)

### Accuracy of Linear Support Vector Machine Model

In [None]:
# Correct predictions are true positives plus true negatives
svm_correct = results_svm.filter(results_svm.fully_cancelled == results_svm.prediction).count()
svm_totalNoPredictions = results_svm.count()

print('Number of Correct Predictions: {}'.format(svm_correct))
print('Total Number of Predictions: {}'.format(svm_totalNoPredictions))
print('Accuracy: {:.4f}'.format(svm_correct/svm_totalNoPredictions))

In [None]:
print('F1-score: {:.4f}'.format(evaluator.evaluate(cvModel_svm.transform(validation))))

## Summary 