In [1]:
from pyspark.sql import SparkSession

Initialize a spark session

In [2]:
spark = SparkSession.builder.master("local[*]").appName("sparkify").getOrCreate()
spark

Read the data

In [3]:
data = spark.read.json("../data/mini_sparkify_event_data.json")

Look at the columns

In [4]:
data.printSchema()

root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)



Look at the data, we will do a vertical show

In [5]:
data.show(vertical=True, n=2)

-RECORD 0-----------------------------
 artist        | Martha Tilston       
 auth          | Logged In            
 firstName     | Colin                
 gender        | M                    
 itemInSession | 50                   
 lastName      | Freeman              
 length        | 277.89016            
 level         | paid                 
 location      | Bakersfield, CA      
 method        | PUT                  
 page          | NextSong             
 registration  | 1538173362000        
 sessionId     | 29                   
 song          | Rockpools            
 status        | 200                  
 ts            | 1538352117000        
 userAgent     | Mozilla/5.0 (Wind... 
 userId        | 30                   
-RECORD 1-----------------------------
 artist        | Five Iron Frenzy     
 auth          | Logged In            
 firstName     | Micah                
 gender        | M                    
 itemInSession | 79                   
 lastName      | Long    

The most important column in the data set seems to be the `Page` column. It holds all Sparkify pages that the customers have visited. 

In [6]:
data.select("page").dropDuplicates().show()

+--------------------+
|                page|
+--------------------+
|              Cancel|
|    Submit Downgrade|
|         Thumbs Down|
|                Home|
|           Downgrade|
|         Roll Advert|
|              Logout|
|       Save Settings|
|Cancellation Conf...|
|               About|
| Submit Registration|
|            Settings|
|               Login|
|            Register|
|     Add to Playlist|
|          Add Friend|
|            NextSong|
|           Thumbs Up|
|                Help|
|             Upgrade|
+--------------------+
only showing top 20 rows



From the list above we can use the `Cancellation Confirmation` to indicate if a particular user churned or not. We are building a model that predicts if a user is going to churn given the history of interactions with the service.

Therfore, it is important to understand the customers that reached the `Cancellation Confirmation` page. These are the customers that churned, and the task is to build a prediction model that can recognize them.

## Data cleaning

Now that we indentified our "goal", we can let go of some of the columns that are not needed for further analysis.

In [7]:
data = data.drop(*['firstName', 'lastName', 'id_copy'])
data.show(vertical=True, n=2)

-RECORD 0-----------------------------
 artist        | Martha Tilston       
 auth          | Logged In            
 gender        | M                    
 itemInSession | 50                   
 length        | 277.89016            
 level         | paid                 
 location      | Bakersfield, CA      
 method        | PUT                  
 page          | NextSong             
 registration  | 1538173362000        
 sessionId     | 29                   
 song          | Rockpools            
 status        | 200                  
 ts            | 1538352117000        
 userAgent     | Mozilla/5.0 (Wind... 
 userId        | 30                   
-RECORD 1-----------------------------
 artist        | Five Iron Frenzy     
 auth          | Logged In            
 gender        | M                    
 itemInSession | 79                   
 length        | 236.09424            
 level         | free                 
 location      | Boston-Cambridge-... 
 method        | PUT     

Since we are building a model that focuses on a user, remove any null/na values in the userID

In [8]:
from pyspark.sql.functions import isnan, isnull

In [9]:
data.filter((isnan(data['userId'])) | (data['userId'].isNull()) | (data['userId'] == "")).count()

8346

In [10]:
data = data.dropna(how = 'any', subset = ['userId'])
data = data.filter(data["userId"] != "") 
data = data.filter(data['userId'].isNotNull())
data.filter((isnan(data['userId'])) | (data['userId'].isNull()) | (data['userId'] == "")).count()

0

The times in `registration` and `ts` column are given in milliseconds, we will convert those to seconds.

In [11]:
from pyspark.sql.types import DoubleType
import pyspark.sql.functions as F

In [12]:
time_unit_udf = F.udf(lambda x: float(x/1000), DoubleType())
data = data.withColumn("registration", time_unit_udf("registration")). \
    withColumn("ts", time_unit_udf("ts"))

data.select('registration', 'ts').show()

+-------------+-------------+
| registration|           ts|
+-------------+-------------+
|1.538173362E9|1.538352117E9|
| 1.53833163E9| 1.53835218E9|
|1.538173362E9|1.538352394E9|
| 1.53833163E9|1.538352416E9|
|1.538173362E9|1.538352676E9|
| 1.53833163E9|1.538352678E9|
| 1.53833163E9|1.538352886E9|
|1.538173362E9|1.538352899E9|
|1.538173362E9|1.538352905E9|
|1.538173362E9|1.538353084E9|
| 1.53833163E9|1.538353146E9|
| 1.53833163E9| 1.53835315E9|
|1.538173362E9|1.538353218E9|
| 1.53833163E9|1.538353375E9|
| 1.53833163E9|1.538353376E9|
|1.538173362E9|1.538353441E9|
| 1.53833163E9|1.538353576E9|
|1.537365219E9|1.538353668E9|
|1.538173362E9|1.538353687E9|
| 1.53833163E9|1.538353744E9|
+-------------+-------------+
only showing top 20 rows



#### Label

As mentioned above we are deriving the label `churn` as a function of the `Page` column. There are only two values to churn, churned or not churned. Therefore this is going to be a boolean column represented by 1 for churned and 0 for not churned.

In [13]:
from pyspark.sql.types import IntegerType

In [14]:
cancelation_event_udf = F.udf(lambda x: 1 if x == "Cancellation Confirmation" else 0, IntegerType())
data = data.withColumn("churn", cancelation_event_udf("page"))
data.filter(data['userId'] == 125).describe('churn').show()

+-------+-------------------+
|summary|              churn|
+-------+-------------------+
|  count|                 11|
|   mean|0.09090909090909091|
| stddev|0.30151134457776363|
|    min|                  0|
|    max|                  1|
+-------+-------------------+



The problem here is that a particular user will have multiple entries, since this is a database of user activities log. A user may have visitied number of other pages before reaching the "Cancellation" page. For a user that has churned, we want to make sure that all of his/her logs record 1 for the churn column. For this we will use the Window function from `sql.window`.


In [15]:
from pyspark.sql.window import Window
from pyspark.sql.functions import sum as Fsum

In [16]:
window = Window.partitionBy("userId") \
        .rangeBetween(Window.unboundedPreceding,
                      Window.unboundedFollowing)
data = data.withColumn("churn", Fsum("churn").over(window))


In [17]:
data.filter(data['userId'] == 125).describe('churn').show()

+-------+-----+
|summary|churn|
+-------+-----+
|  count|   11|
|   mean|  1.0|
| stddev|  0.0|
|    min|    1|
|    max|    1|
+-------+-----+



Let's see how many unique users we have

In [18]:
data.select("userID").dropDuplicates().count()

225

The total churn instances, again a user has multiple entries here

In [19]:
data.filter(data['churn'] == 1).count()

44864

Now let's create a label dataframe that has one entry per each user

In [20]:
from pyspark.sql.functions import col

In [21]:
label = data \
    .select('userId', col('churn').alias('label')) \
    .dropDuplicates()
label.show()

+------+-----+
|userId|label|
+------+-----+
|100010|    0|
|200002|    0|
|   125|    1|
|   124|    0|
|    51|    1|
|     7|    0|
|    15|    0|
|    54|    1|
|   155|    0|
|100014|    1|
|   132|    0|
|   154|    0|
|   101|    1|
|    11|    0|
|   138|    0|
|300017|    0|
|100021|    1|
|    29|    1|
|    69|    0|
|   112|    0|
+------+-----+
only showing top 20 rows



## Feature engineering

Users that have been using the service tend to stay with the service, even as a paying customer, than those that recently signed up.

Every company selling products and services to customers have an idea of a "lifetime value" of a customer. With that in mind we create our first feature.

We converted the times to seconds above. After gaining the lifetime of a user, we convert that to days.

In [22]:
from pyspark.sql.functions import sum as Fsum, col

In [23]:
# time since registration
time_since_registration = data \
    .select('userId', 'registration', 'ts') \
    .withColumn('lifetime', (data.ts - data.registration)) \
    .groupBy('userId') \
    .agg({'lifetime': 'max'}) \
    .withColumnRenamed('max(lifetime)', 'lifetime') \
    .select('userId', (col('lifetime') / 3600 / 24).alias('lifetime'))

In [24]:
time_since_registration.describe('lifetime').show()

+-------+-------------------+
|summary|           lifetime|
+-------+-------------------+
|  count|                225|
|   mean|   79.8456834876543|
| stddev|  37.66147001861254|
|    min|0.31372685185185184|
|    max|  256.3776736111111|
+-------+-------------------+



Bussiness big and small not only rely on their customers coming back for more good and services, but also referring the bussiness to their friends and family. It is in our nature to refer things that we like. With this idea in mind we will create another feature.

Again we will user the "Page" column specifically looking for the `Add Friend` page.

In [25]:
# referring friends
referring_friends = data \
    .select('userID', 'page') \
    .where(data.page == 'Add Friend') \
    .groupBy('userID') \
    .count() \
    .withColumnRenamed('count', 'add_friend')

In [26]:
referring_friends.describe('add_friend').show()

+-------+------------------+
|summary|        add_friend|
+-------+------------------+
|  count|               206|
|   mean|20.762135922330096|
| stddev|20.646779074405007|
|    min|                 1|
|    max|               143|
+-------+------------------+



The more songs a user listens to, the more likely they are to enjoy the streaming service and keep their subscription. Thus we add a total_songs_listened feature

In [27]:
# total songs listened
total_songs_listened = data \
    .select('userID', 'song') \
    .groupBy('userID') \
    .count() \
    .withColumnRenamed('count', 'total_songs')

In [28]:
total_songs_listened.describe('total_songs').show()

+-------+-----------------+
|summary|      total_songs|
+-------+-----------------+
|  count|              225|
|   mean|          1236.24|
| stddev|1329.531716432519|
|    min|                6|
|    max|             9632|
+-------+-----------------+



The more songs a user liked on the streaming service, the more it potentially implies that they enjoy their subscription
and the value they get from it. They are more likely to keep their subscription if they are liking more songs. Thus we create a feature to count the number of songs a user gives a thumbs_up to.

In [29]:
# thumbs up
thumbs_up = data \
    .select('userID', 'page') \
    .where(data.page == 'Thumbs Up') \
    .groupBy('userID') \
    .count() \
    .withColumnRenamed('count', 'num_thumb_up')

In [30]:
thumbs_up.describe('num_thumb_up').show()

+-------+-----------------+
|summary|     num_thumb_up|
+-------+-----------------+
|  count|              220|
|   mean|            57.05|
| stddev|65.67028650524044|
|    min|                1|
|    max|              437|
+-------+-----------------+



Similarly, the more songs a user dislikes, the less likely they are to enjoy their subscription and cancel it. 

In [31]:
# thumbs down
thumbs_down = data \
    .select('userID', 'page') \
    .where(data.page == 'Thumbs Down') \
    .groupBy('userID') \
    .count() \
    .withColumnRenamed('count', 'num_thumb_down')

In [32]:
thumbs_down.describe('num_thumb_down').show()

+-------+------------------+
|summary|    num_thumb_down|
+-------+------------------+
|  count|               203|
|   mean|12.541871921182265|
| stddev|13.198108566983787|
|    min|                 1|
|    max|                75|
+-------+------------------+



Playlist length counts the number of times a user visited the add to playlist page indicating that they added a song to their playlist. A long playlist implies a user is enjoying several songs and wants frequent access to them. This would likely lead to them keeping their subscription. Thus, we create a feature to compute the lenght of the playlist. 

In [33]:
# Playlist length
playlist_length = data.select('userID', 'page') \
    .where(data.page == 'Add to Playlist') \
    .groupby('userID').count() \
    .withColumnRenamed('count', 'playlist_length')

In [34]:
playlist_length.describe('playlist_length').show()

+-------+-----------------+
|summary|  playlist_length|
+-------+-----------------+
|  count|              215|
|   mean|30.35348837209302|
| stddev| 32.8520568555997|
|    min|                1|
|    max|              240|
+-------+-----------------+



Avg songs per session allows us to measure how long each user session lasts when they open the application and start a session. The longer the session, the more songs the user is listening too. This could imply that the user is enjoying the application and is less likely to cancel.

In [35]:
#  avg_songs_played per session
avg_songs_played = data.where('page == "NextSong"') \
    .groupby(['userId', 'sessionId']) \
    .count() \
    .groupby(['userId']) \
    .agg({'count': 'avg'}) \
    .withColumnRenamed('avg(count)', 'avg_songs_played')

In [36]:
avg_songs_played.describe('avg_songs_played').show()

+-------+-----------------+
|summary| avg_songs_played|
+-------+-----------------+
|  count|              225|
|   mean|70.78971233958933|
| stddev| 42.6153697543817|
|    min|              3.0|
|    max|286.6666666666667|
+-------+-----------------+



Artist count measures the amount of different artists the user listens to. A user listening to a wide variety of artists could potentially be enjoying the music application more and is likely to keep their subscription. On the other hand, a user listening to only a handful of artists may not be happy with the current subscription and would be likely to churn. 

In [37]:
# artist count
artist_count = data \
    .filter(data.page == "NextSong") \
    .select("userId", "artist") \
    .dropDuplicates() \
    .groupby("userId") \
    .count() \
    .withColumnRenamed("count", "artist_count")

In [38]:
artist_count.describe('artist_count').show()

+-------+-----------------+
|summary|     artist_count|
+-------+-----------------+
|  count|              225|
|   mean|696.3777777777777|
| stddev|603.9518698630802|
|    min|                3|
|    max|             3544|
+-------+-----------------+



Total number of sessions computes how many times the user has started a session on the app. The more sessions implies more visits to the application meaning they are less likely to churn. 

In [39]:
# number of sessions
num_sessions = data \
    .select("userId", "sessionId") \
    .dropDuplicates() \
    .groupby("userId") \
    .count() \
    .withColumnRenamed('count', 'num_sessions')

In [40]:
num_sessions.describe('num_sessions').show()

+-------+------------------+
|summary|      num_sessions|
+-------+------------------+
|  count|               225|
|   mean|14.115555555555556|
| stddev|14.646884657111562|
|    min|                 1|
|    max|               107|
+-------+------------------+



## Modelling
Three different models were picked: Gradient Boosting Trees, Logistic Regression and Linear SVC. Results are shown below

In [41]:
cols = ["lifetime",
        "total_songs",
        "num_thumb_up",
        'num_thumb_down',
        'add_friend',
        'playlist_length',
        'avg_songs_played',
        'artist_count',
        'num_sessions'
        ]

In [42]:
features = [time_since_registration,
                  total_songs_listened,
                  thumbs_up,
                  thumbs_down,
                  referring_friends,
                  playlist_length,
                  avg_songs_played,
                  artist_count,
                  num_sessions]

In [43]:
data = features.pop()
while len(features) > 0:
    data = data.join(features.pop(), 'userID', 'outer')

data = data.join(label, 'userID', 'outer').fillna(0)
data.show(vertical=True, n=2)

-RECORD 0------------------------------
 userId           | 100010             
 num_sessions     | 7                  
 artist_count     | 252                
 avg_songs_played | 39.285714285714285 
 playlist_length  | 7                  
 add_friend       | 4                  
 num_thumb_down   | 5                  
 num_thumb_up     | 17                 
 total_songs      | 381                
 lifetime         | 55.6436574074074   
 label            | 0                  
-RECORD 1------------------------------
 userId           | 200002             
 num_sessions     | 6                  
 artist_count     | 339                
 avg_songs_played | 64.5               
 playlist_length  | 8                  
 add_friend       | 4                  
 num_thumb_down   | 6                  
 num_thumb_up     | 21                 
 total_songs      | 474                
 lifetime         | 70.07462962962963  
 label            | 0                  
only showing top 2 rows



In [44]:
data.printSchema()

root
 |-- userId: string (nullable = true)
 |-- num_sessions: long (nullable = true)
 |-- artist_count: long (nullable = true)
 |-- avg_songs_played: double (nullable = false)
 |-- playlist_length: long (nullable = true)
 |-- add_friend: long (nullable = true)
 |-- num_thumb_down: long (nullable = true)
 |-- num_thumb_up: long (nullable = true)
 |-- total_songs: long (nullable = true)
 |-- lifetime: double (nullable = false)
 |-- label: long (nullable = true)



In [45]:
data.groupby('label').count().show()

+-----+-----+
|label|count|
+-----+-----+
|    0|  173|
|    1|   52|
+-----+-----+



In [46]:
# Vector assembler
from pyspark.ml.feature import StandardScaler, VectorAssembler
assembler = VectorAssembler(inputCols=cols, outputCol="unScaled_features")
data = assembler.transform(data)

In [47]:
data.show(vertical=True, n=2)

-RECORD 0---------------------------------
 userId            | 100010               
 num_sessions      | 7                    
 artist_count      | 252                  
 avg_songs_played  | 39.285714285714285   
 playlist_length   | 7                    
 add_friend        | 4                    
 num_thumb_down    | 5                    
 num_thumb_up      | 17                   
 total_songs       | 381                  
 lifetime          | 55.6436574074074     
 label             | 0                    
 unScaled_features | [55.6436574074074... 
-RECORD 1---------------------------------
 userId            | 200002               
 num_sessions      | 6                    
 artist_count      | 339                  
 avg_songs_played  | 64.5                 
 playlist_length   | 8                    
 add_friend        | 4                    
 num_thumb_down    | 6                    
 num_thumb_up      | 21                   
 total_songs       | 474                  
 lifetime  

In [48]:
data = data.select('userID', 'unScaled_features', 'label')
# scale the features
scaler = StandardScaler(inputCol="unScaled_features", outputCol="features", withStd=True)
scalerModel = scaler.fit(data)
data = scalerModel.transform(data)
data.show(vertical=True, n=2)

-RECORD 0---------------------------------
 userID            | 100010               
 unScaled_features | [55.6436574074074... 
 label             | 0                    
 features          | [1.47746907860760... 
-RECORD 1---------------------------------
 userID            | 200002               
 unScaled_features | [70.0746296296296... 
 label             | 0                    
 features          | [1.86064509948757... 
only showing top 2 rows



In [49]:
# train test split
trainTest = data.randomSplit([0.8, 0.2])
trainingDF = trainTest[0]
testDF = trainTest[1]

In [50]:
from pyspark.ml.classification import GBTClassifier, LogisticRegression, LinearSVC
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

In [51]:
# Gradient Boosted Trees

# initialize classifier
GradBoostTree = GBTClassifier()

# We use a ParamGridBuilder to construct a grid of parameters to search over.
depth = [3, 5, 7] 
iterations = [10, 20, 30]
bins = [16, 32, 64]
param_grid = ParamGridBuilder() \
    .addGrid(GradBoostTree.maxDepth, depth) \
    .addGrid(GradBoostTree.maxIter, iterations) \
    .build()

evaluator = MulticlassClassificationEvaluator(metricName='f1')

cross_validator = CrossValidator(estimator=GradBoostTree,
                                 estimatorParamMaps=param_grid,
                                 evaluator=evaluator,
                                 numFolds=3)

# Fit the model
cvModel_GradBoostTree = GradBoostTree.fit(trainingDF) ## --> CHANGE THIS BACK TO cross_validator to run it

# Make Predictions
results_GradBoostTree = cvModel_GradBoostTree.transform(testDF)
results_GradBoostTree = results_GradBoostTree.select('userID', 'label', 'prediction')
results_GradBoostTree.show(10)

# Get Results
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
accuracy = evaluator.evaluate(results_GradBoostTree, {evaluator.metricName: "accuracy"})
f1Score = evaluator.evaluate(results_GradBoostTree, {evaluator.metricName: "f1"})
print('Gradient Boosted Trees Metrics:')
print('Accuracy: {:.2f}'.format(accuracy))
print('F1 Score: {:.2f}'.format(f1Score))

+------+-----+----------+
|userID|label|prediction|
+------+-----+----------+
|    51|    1|       1.0|
|     7|    0|       0.0|
|    69|    0|       0.0|
|     3|    1|       0.0|
|    30|    0|       0.0|
|    22|    0|       0.0|
|100022|    1|       0.0|
|    35|    0|       0.0|
|    52|    0|       0.0|
|    47|    0|       0.0|
+------+-----+----------+
only showing top 10 rows

Gradient Boosted Trees Metrics:
Accuracy: 0.78
F1 Score: 0.79


In [52]:
# Logistic Regression Classifier

# initialize classifier
maxIter = 10
lgr = LogisticRegression(maxIter=maxIter)

# Fit the model
cvModel_lgr = lgr.fit(trainingDF)

# Make Predictions
results_lgr = cvModel_lgr.transform(testDF).select('userID', 'label', 'prediction')
results_lgr.show(10)

# Get Results
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
accuracy = evaluator.evaluate(results_lgr, {evaluator.metricName: "accuracy"})
f1Score = evaluator.evaluate(results_lgr, {evaluator.metricName: "f1"})
print('Logistic Regression Metrics:')
print('Accuracy: {:.2f}'.format(accuracy))
print('F1 Score: {:.2f}'.format(f1Score))

+------+-----+----------+
|userID|label|prediction|
+------+-----+----------+
|    51|    1|       0.0|
|     7|    0|       0.0|
|    69|    0|       0.0|
|     3|    1|       0.0|
|    30|    0|       0.0|
|    22|    0|       0.0|
|100022|    1|       1.0|
|    35|    0|       0.0|
|    52|    0|       0.0|
|    47|    0|       0.0|
+------+-----+----------+
only showing top 10 rows

Logistic Regression Metrics:
Accuracy: 0.85
F1 Score: 0.82


In [53]:
# SVM

# initialize classifier
maxIter = 10
svc = LinearSVC(maxIter=maxIter)

# Fit the model
cvModel_svc = svc.fit(trainingDF)

# Make Predictions
results_svc = cvModel_svc.transform(testDF).select('userID', 'label', 'prediction')
results_svc.show(10)

# Get Results
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
accuracy = evaluator.evaluate(results_svc, {evaluator.metricName: "accuracy"})
f1Score = evaluator.evaluate(results_svc, {evaluator.metricName: "f1"})
print('Logistic Regression Metrics:')
print('Accuracy: {:.2f}'.format(accuracy))
print('F1 Score: {:.2f}'.format(f1Score))

+------+-----+----------+
|userID|label|prediction|
+------+-----+----------+
|    51|    1|       0.0|
|     7|    0|       0.0|
|    69|    0|       0.0|
|     3|    1|       0.0|
|    30|    0|       0.0|
|    22|    0|       0.0|
|100022|    1|       0.0|
|    35|    0|       0.0|
|    52|    0|       0.0|
|    47|    0|       0.0|
+------+-----+----------+
only showing top 10 rows

Logistic Regression Metrics:
Accuracy: 0.80
F1 Score: 0.72
