# Sparkify Project Workspace
This workspace contains a tiny subset (128MB) of the full dataset available (12GB).

In [23]:
# import libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import avg, col, concat, count, desc, explode, lit, min, max, sum, split, stddev, udf, when, lag, isnull

from pyspark.sql.types import IntegerType
from pyspark.ml.feature import RegexTokenizer, VectorAssembler, StandardScaler
from pyspark.ml.regression import LinearRegression
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, GBTClassifier, RandomForestClassifier
from pyspark.ml.clustering import KMeans
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

from pyspark.sql import Window

import datetime

import os

# Set spark environments
os.environ['PYSPARK_PYTHON'] = '/Users/stillqe/anaconda/envs/Sparkify/bin/python3'
os.environ['PYSPARK_DRIVER_PYTHON'] = '/Users/stillqe/anaconda/envs/Sparkify/bin/python3'

In [2]:
# create a Spark session
spark = SparkSession.builder \
    .master("local") \
    .appName("Sparkify") \
    .getOrCreate()

# Load and Clean Dataset
In this workspace, the mini-dataset file is `mini_sparkify_event_data.json`. Load and clean the dataset, checking for invalid or missing data - for example, records without userids or sessionids. 

In [3]:
sample = 'mini_sparkify_event_data.json'
df = spark.read.json(sample)
df.printSchema()
print((df.count(), len(df.columns)))

root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)

(286500, 18)


In [5]:
def count_missings(spark_df):
    spark_df.select([count(when(isnull(c), c)).alias(c) for c in spark_df.columns]).show()

count_missings(df)

In [10]:
df = df.where(df.userId != "")
count_missings(df)

+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+-----+------+---+---------+------+------+--------+----------+---+--------+------+-----+----+---------+-------+-------+---+
|artist|auth|firstName|gender|itemInSession|lastName|length|level|location|method|page|registration|sessionId| song|status| ts|userAgent|userId|cancel|thumpsup|thumpsdown| ad|playlist|friend|error|home|downgrade|upgrade|setting|day|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+-----+------+---+---------+------+------+--------+----------+---+--------+------+-----+----+---------+-------+-------+---+
| 50046|   0|        0|     0|            0|       0| 50046|    0|       0|     0|   0|           0|        0|50046|     0|  0|        0|     0|     0|       0|         0|  0|       0|     0|    0|   0|        0|      0|      0|  0|
+------+----+---------+------+-------------+--------+------+-----+--

# Exploratory Data Analysis

### Define Churn

Churn is more like dynamic feature rather than static. Some users are stick to the service from the begining but they can be likely to churn at certain time period. So focus on behavioral events for certain period, not just entire service life time. 

User characteristics change over time. The probability of a user churn also changes over time. Therefore, it is more appropriate to view the user's characteristics in multiple time windows rather than consolidating them with one input without considering changes over time.

You need to know in advance when predicting churn. That way, you can act ahead. It is practically meaningless to make predictions by looking at the records up to the moment of churn. After a certain period of time, for example, by looking at the records up to 4 weeks ago, you have to learn whether it is churn or not after 4 weeks. 

### Explore Data
Once you've defined churn, perform some exploratory data analysis to observe the behavior for users who stayed vs users who churned. You can start by exploring aggregates on these two groups of users, observing how much of a specific action they experienced per a certain time unit or number of songs played.

In [4]:
df.select('page').dropDuplicates().show(30)

+--------------------+
|                page|
+--------------------+
|              Cancel|
|    Submit Downgrade|
|         Thumbs Down|
|                Home|
|           Downgrade|
|         Roll Advert|
|              Logout|
|       Save Settings|
|Cancellation Conf...|
|               About|
| Submit Registration|
|            Settings|
|               Login|
|            Register|
|     Add to Playlist|
|          Add Friend|
|            NextSong|
|           Thumbs Up|
|                Help|
|             Upgrade|
|               Error|
|      Submit Upgrade|
+--------------------+



In [6]:
df.where((df.page=='Cancel') | (df.page=='Cancellation Confirmation')).orderBy('userID', 'ts').show()

+------+---------+---------+------+-------------+--------+------+-----+--------------------+------+--------------------+-------------+---------+----+------+-------------+--------------------+------+
|artist|     auth|firstName|gender|itemInSession|lastName|length|level|            location|method|                page| registration|sessionId|song|status|           ts|           userAgent|userId|
+------+---------+---------+------+-------------+--------+------+-----+--------------------+------+--------------------+-------------+---------+----+------+-------------+--------------------+------+
|  null|Logged In|  Delaney|     F|           22|   Perez|  null| free|Miami-Fort Lauder...|   PUT|              Cancel|1534627466000|       53|null|   307|1538498074000|"Mozilla/5.0 (Mac...|100001|
|  null|Cancelled|  Delaney|     F|           23|   Perez|  null| free|Miami-Fort Lauder...|   GET|Cancellation Conf...|1534627466000|       53|null|   200|1538498205000|"Mozilla/5.0 (Mac...|100001|
|  nu

In [7]:
df.where(df.page=='NextSong').groupBy('gender').agg({'song':'count'}).show()

+------+-----------+
|gender|count(song)|
+------+-----------+
|     F|     126696|
|     M|     101412|
+------+-----------+



In [None]:
df.filter(df.page == 'NextSong') \
    .select('Artist') \
    .groupBy('Artist') \
    .agg({'Artist':'count'}) \
    .withColumnRenamed('count(Artist)', 'Artistcount') \
    .sort(desc('Artistcount')) \
    .show(5)

In [8]:
df.select('page').dropDuplicates().show(30)

+--------------------+
|                page|
+--------------------+
|              Cancel|
|    Submit Downgrade|
|         Thumbs Down|
|                Home|
|           Downgrade|
|         Roll Advert|
|              Logout|
|       Save Settings|
|Cancellation Conf...|
|               About|
| Submit Registration|
|            Settings|
|               Login|
|            Register|
|     Add to Playlist|
|          Add Friend|
|            NextSong|
|           Thumbs Up|
|                Help|
|             Upgrade|
|               Error|
|      Submit Upgrade|
+--------------------+



In [9]:
df = df.withColumn('cancel', when(df.page == 'Cancellation Confirmation', 1).otherwise(0)) \
    .withColumn('thumpsup', when(df.page == 'Thumbs Up', 1).otherwise(0)) \
    .withColumn('thumpsdown', when(df.page == 'Thumbs Down', 1).otherwise(0)) \
    .withColumn('ad', when(df.page == 'Roll Advert', 1).otherwise(0)) \
    .withColumn('playlist', when(df.page == 'Add to Playlist', 1).otherwise(0)) \
    .withColumn('friend', when(df.page == 'Add Friend', 1).otherwise(0)) \
    .withColumn('error', when(df.page == 'Error', 1).otherwise(0)) \
    .withColumn('home', when(df.page == 'Home', 1).otherwise(0)) \
    .withColumn('downgrade', when(df.page == 'Downgrade', 1).otherwise(0)) \
    .withColumn('upgrade', when(df.page == 'Upgrade', 1).otherwise(0)) \
    .withColumn('setting', when(df.page == 'Save Settings', 1).otherwise(0)) \
    .withColumn('ts', col('ts')/1000) \
    .withColumn('registration', col('registration')/1000) \
    .withColumn('day', ((col('ts')-col('registration'))/(60*60*24)).cast(IntegerType()))

### Aggregate

How to aggregate historical data into a single row?


In [11]:


sessions = df.groupBy('userId', 'sessionId').agg(count('song').alias('play'),
                                                 max('cancel').alias('churn'),
                                            min('day').alias('day'),
                                            min('ts').alias('start'), 
                                            max('ts').alias('end'),
                                           max('cancel').alias('cancel'),
                                           sum('thumpsup').alias('thumpsup'),
                                           sum('thumpsdown').alias('thumpsdown'),
                                           sum('ad').alias('ads'),
                                           sum('playlist').alias('playlists'),
                                           sum('friend').alias('friends'),
                                           sum('error').alias('errors'),
                                           sum('setting').alias('setting')) \
    .withColumn('duration', (col('end') - col('start')).cast(IntegerType()))


In [12]:
windowSpec = Window.partitionBy('userId').orderBy('start')
userWindow = Window.partitionBy('userId')
sessions = sessions.withColumn('gap', col('day') - lag('day', 1, 0).over(windowSpec))
sessions = sessions.withColumn('churn', max('cancel').over(userWindow))
sessions = sessions.withColumn('last', max('day').over(userWindow))
#It is practically not useful to predict churning with entire records including right before churning. 
# Exclude the record from 30 days before last sessions.
sessions = sessions.withColumn('first', min('day').over(userWindow)) \
                   .withColumn('last', max('day').over(userWindow))

It is practically not useful to predict churning with data 

In [13]:
sessions = sessions.where(sessions.day < (sessions.last-30))
sessions = sessions.withColumn('elast', max('day').over(userWindow))
sessions = sessions.withColumn('span', col('elast')-col('first'))


In [14]:
sessions = sessions.withColumn('quater', when(col('day') <= (col('first')+col('span')/4), 1) \
                    .when((col('day')>(col('first')+col('span')/4)) & (col('day')<=(col('first')+col('span')/2)), 2) \
                    .when((col('day')>(col('first')+col('span')/2)) & (col('day')<=(col('first')+col('span')*3/4)), 3) \
                    .otherwise(4))
sessions.show()

+------+---------+----+-----+---+-------------+-------------+------+--------+----------+---+---------+-------+------+-------+--------+---+----+-----+-----+----+------+
|userId|sessionId|play|churn|day|        start|          end|cancel|thumpsup|thumpsdown|ads|playlists|friends|errors|setting|duration|gap|last|first|elast|span|quater|
+------+---------+----+-----+---+-------------+-------------+------+--------+----------+---+---------+-------+------+-------+--------+---+----+-----+-----+----+------+
|100010|       31|  31|    0| 11|1.539003534E9|1.539010247E9|     0|       2|         0|  7|        0|      0|     0|      0|    6713| 11|  55|   11|   23|  12|     1|
|100010|       78|   7|    0| 18|1.539603322E9|1.539604675E9|     0|       0|         0|  0|        0|      0|     0|      0|    1353|  7|  55|   11|   23|  12|     3|
|100010|      113|  38|    0| 19|1.539699018E9|1.539707924E9|     0|       2|         1|  7|        1|      1|     0|      0|    8906|  1|  55|   11|   23|  12|

In [15]:
count_missings(sessions)

+------+---------+----+-----+---+-----+---+------+--------+----------+---+---------+-------+------+-------+--------+---+----+-----+-----+----+------+
|userId|sessionId|play|churn|day|start|end|cancel|thumpsup|thumpsdown|ads|playlists|friends|errors|setting|duration|gap|last|first|elast|span|quater|
+------+---------+----+-----+---+-----+---+------+--------+----------+---+---------+-------+------+-------+--------+---+----+-----+-----+----+------+
|     0|        0|   0|    0|  0|    0|  0|     0|       0|         0|  0|        0|      0|     0|      0|       0|  0|   0|    0|    0|   0|     0|
+------+---------+----+-----+---+-----+---+------+--------+----------+---+---------+-------+------+-------+--------+---+----+-----+-----+----+------+



### How to simulate setup

As mentioned before, it is not practically useful to predict churning with all users records inclduing data right before users churn. If we predict users churn within 1 month, we should use users records from the begining to the some point before churn. 

# Feature Engineering
Once you've familiarized yourself with the data, build out the features you find promising to train your model on. To work with the full dataset, you can follow the following steps.
- Write a script to extract the necessary features from the smaller subset of data
- Ensure that your script is scalable, using the best practices discussed in Lesson 3
- Try your script on the full data set, debugging your script if necessary

If you are working in the classroom workspace, you can just extract features based on the small subset of data contained here. Be sure to transfer over this work to the larger dataset when you work on your Spark cluster.

In [16]:
sessions.printSchema()

root
 |-- userId: string (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- play: long (nullable = false)
 |-- churn: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- start: double (nullable = true)
 |-- end: double (nullable = true)
 |-- cancel: integer (nullable = true)
 |-- thumpsup: long (nullable = true)
 |-- thumpsdown: long (nullable = true)
 |-- ads: long (nullable = true)
 |-- playlists: long (nullable = true)
 |-- friends: long (nullable = true)
 |-- errors: long (nullable = true)
 |-- setting: long (nullable = true)
 |-- duration: integer (nullable = true)
 |-- gap: integer (nullable = true)
 |-- last: integer (nullable = true)
 |-- first: integer (nullable = true)
 |-- elast: integer (nullable = true)
 |-- span: integer (nullable = true)
 |-- quater: integer (nullable = false)



In [17]:
users = sessions.groupBy('userId').agg(count('sessionId').alias('total_session'),
                                       avg('play').alias('avg_play'),
                                       when(count('play') > 1, stddev('play')).otherwise(0).alias('std_play'),
                                       avg('thumpsup').alias('avg_up'),
                                       avg('thumpsdown').alias('avg_down'),
                                       avg('ads').alias('avg_ads'),
                                       avg('setting').alias('avg_setting'),
                                       avg('playlists').alias('avg_playlists'),
                                       avg('friends').alias('avg_friends'),
                                       avg('duration').alias('avg_duration'),
                                       when(count('duration') > 1, stddev('duration')).otherwise(0).alias('std_duration'), 
                                       avg('gap').alias('avg_gap'),
                                       when(count('gap') > 1, stddev('gap')).otherwise(0).alias('std_gap'),
                                       max('day').alias('life_time'),
                                       max('churn').alias('label'))

recent = sessions.where(sessions.day>(sessions.elast-14)).groupBy('userId') \
                                  .agg(when(count('sessionId')>0, avg('play')).otherwise(0).alias('recent_play'),
                                       when(count('play') > 1, stddev('play')).otherwise(0).alias('recent_std_play'),
                                       when(count('sessionId')>0,avg('thumpsup')).otherwise(0).alias('recent_up'),
                                       when(count('sessionId')>0,avg('thumpsdown')).otherwise(0).alias('recent_down'),
                                       when(count('sessionId')>0,avg('ads')).otherwise(0).alias('recent_ads'),
                                       when(count('sessionId')>0,avg('setting')).otherwise(0).alias('recent_setting'),
                                       when(count('sessionId')>0,avg('playlists')).otherwise(0).alias('recent_playlists'),
                                       when(count('sessionId')>0,avg('friends')).otherwise(0).alias('recent_friends'),
                                       when(count('sessionId')>0,avg('duration')).otherwise(0).alias('recent_duration'),
                                       when(count('sessionId')>0,avg('gap')).otherwise(0).alias('recent_gap'),
                                       when(count('gap') > 1, stddev('gap')).otherwise(0).alias('recent_std_gap'))




users = users.join(recent, on='userId')

users.show()
                                                                           

+------+-------------+------------------+------------------+------------------+-------------------+-------------------+--------------------+------------------+------------------+------------------+------------------+------------------+------------------+---------+-----+------------------+------------------+------------------+-------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+
|userId|total_session|          avg_play|          std_play|            avg_up|           avg_down|            avg_ads|         avg_setting|     avg_playlists|       avg_friends|      avg_duration|      std_duration|           avg_gap|           std_gap|life_time|label|       recent_play|   recent_std_play|         recent_up|        recent_down|        recent_ads|    recent_setting|  recent_playlists|    recent_friends|   recent_duration|        recent_gap|    recent_std_gap|
+------+-------------+----------------

In [18]:
count_missings(users)


+------+-------------+--------+--------+------+--------+-------+-----------+-------------+-----------+------------+------------+-------+-------+---------+-----+-----------+---------------+---------+-----------+----------+--------------+----------------+--------------+---------------+----------+--------------+
|userId|total_session|avg_play|std_play|avg_up|avg_down|avg_ads|avg_setting|avg_playlists|avg_friends|avg_duration|std_duration|avg_gap|std_gap|life_time|label|recent_play|recent_std_play|recent_up|recent_down|recent_ads|recent_setting|recent_playlists|recent_friends|recent_duration|recent_gap|recent_std_gap|
+------+-------------+--------+--------+------+--------+-------+-----------+-------------+-----------+------------+------------+-------+-------+---------+-----+-----------+---------------+---------+-----------+----------+--------------+----------------+--------------+---------------+----------+--------------+
|     0|            0|       0|       0|     0|       0|      0|   

# Modeling
Split the full dataset into train, test, and validation sets. Test out several of the machine learning methods you learned. Evaluate the accuracy of the various models, tuning parameters as necessary. Determine your winning model based on test accuracy and report results on the validation set. Since the churned users are a fairly small subset, I suggest using F1 score as the metric to optimize.

### Build a pipeline

In [19]:
# Split the data into train and test subsets
train, test = users.randomSplit([0.8, 0.2], seed=123)

In [20]:



assembler = VectorAssembler(inputCols=["total_session",
                                       "avg_play",
                                       "std_play",
                                       "avg_up",
                                       "avg_down",
                                       "avg_ads",
                                       "avg_setting",
                                       "avg_playlists",
                                       "avg_friends",
                                       "avg_duration",
                                       "std_duration",
                                       "avg_gap",
                                       "std_gap",
                                       "life_time",
                                       "recent_play",
                                       "recent_std_play",
                                       "recent_up",
                                       "recent_down",
                                       "recent_ads",
                                       "recent_setting",
                                       "recent_playlists",
                                       "recent_friends",
                                       "recent_duration",
                                       "recent_gap",
                                       "recent_std_gap"
                                      ], outputCol="NumFeatures")

scaler = StandardScaler(inputCol="NumFeatures", outputCol="features", withStd=True)
#gbt = GBTClassifier(featuresCol="features", maxIter=10)
rf = RandomForestClassifier(featuresCol="features")


pipeline = Pipeline(stages=[assembler, scaler, rf])

## baseline classifier

In [28]:
SEED=123

def check_baseline():
    lr = LogisticRegression()
    dt = DecisionTreeClassifier(seed=SEED)
    rf = RandomForestClassifier(featuresCol="features")
    gbt = GBTClassifier(featuresCol="features", maxIter=10)
    
    for clf in [lr, dt, rf, gbt]:
        clf_name = clf.__class__.__name__
        pipeline = Pipeline(stages=[assembler, scaler, clf])
        model = pipeline.fit(train)
        pred_test = model.transform(test)
        f1_score_evaluator = MulticlassClassificationEvaluator(metricName='f1')
        f1_score = f1_score_evaluator.evaluate(pred_test.select('label','prediction'),{f1_score_evaluator.metricName: 'f1'})

        print('{} : The F1 score on the test set is {:.2%}'.format(clf_name, f1_score))

In [29]:
check_baseline()

LogisticRegression : The F1 score on the test set is 84.61%
DecisionTreeClassifier : The F1 score on the test set is 76.01%
RandomForestClassifier : The F1 score on the test set is 87.67%
GBTClassifier : The F1 score on the test set is 76.01%


In [31]:
clf = RandomForestClassifier(featuresCol="features")
pipeline = Pipeline(stages=[assembler, scaler, clf])

paramGrid = ParamGridBuilder() \
            .addGrid(clf.maxDepth, [3, 5, 10]) \
            .addGrid(clf.numTrees, [2, 4, 8]) \
            .addGrid(clf.maxBins, [16, 32, 64]).build()

crossval = CrossValidator(estimator=pipeline,
                         estimatorParamMaps=paramGrid,
                         evaluator=MulticlassClassificationEvaluator(metricName='f1'),
                         numFolds=3)

cvModel = crossval.fit(train)
cvModel.avgMetrics

[0.8345543082129314,
 0.8067056796776213,
 0.8000877953292405,
 0.8216518587564561,
 0.818416241077041,
 0.818416241077041,
 0.8339838355723397,
 0.8369227646265922,
 0.8339838355723397,
 0.7977634458374352,
 0.8282697431048369,
 0.8499090530401386,
 0.8268118174767862,
 0.8140340396990084,
 0.8268118174767862,
 0.827157949864783,
 0.8339838355723397,
 0.818566145506755,
 0.811652334726324,
 0.8282697431048369,
 0.8248246489081155,
 0.8305040895380005,
 0.818416241077041,
 0.818416241077041,
 0.8334164532661434,
 0.8339838355723397,
 0.818566145506755]

In [32]:
pred_test = cvModel.transform(test)

f1_score_evaluator = MulticlassClassificationEvaluator(metricName='f1')
f1_score = f1_score_evaluator.evaluate(pred_test.select('label','prediction'),{f1_score_evaluator.metricName: 'f1'})

print('The F1 score on the test set is {:.2%}'.format(f1_score))

The F1 score on the test set is 79.46%


In [33]:
inputCols=["total_session", "avg_play","std_play","avg_up","avg_down","avg_ads","avg_setting","avg_playlists",
           "avg_friends","avg_duration","std_duration","avg_gap","std_gap","life_time","recent_play",
                                       "recent_std_play",
                                       "recent_up",
                                       "recent_down",
                                       "recent_ads",
                                       "recent_setting",
                                       "recent_playlists",
                                       "recent_friends",
                                       "recent_duration",
                                       "recent_gap",
                                       "recent_std_gap"]
# print feature importances
for i in range(len(model.stages[-1].featureImportances)):
    print("{} : {} \n".format(inputCols[i], model.stages[-1].featureImportances[i]))


NameError: name 'model' is not defined

In [None]:
pred_test.select("prediction", "label", "features").show(100)

In [None]:
accuracy = evaluator.evaluate(predictions)

# Final Steps
Clean up your code, adding comments and renaming variables to make the code easier to read and maintain. Refer to the Spark Project Overview page and Data Scientist Capstone Project Rubric to make sure you are including all components of the capstone project and meet all expectations. Remember, this includes thorough documentation in a README file in a Github repository, as well as a web app or blog post.