In [1]:
from pyspark.sql import SparkSession


Initialize a Spark session

In [2]:
spark = SparkSession.builder.master("local[*]").appName("sparkify").getOrCreate()
spark

Read the data

In [56]:
data = spark.read.json("../data/mini_sparkify_event_data.json")

Look at the columns

In [57]:
data.printSchema()

root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)



Look at the data, using a vertical show since the rows are wide

In [58]:
data.show(vertical=True, n=2)

-RECORD 0-----------------------------
 artist        | Martha Tilston       
 auth          | Logged In            
 firstName     | Colin                
 gender        | M                    
 itemInSession | 50                   
 lastName      | Freeman              
 length        | 277.89016            
 level         | paid                 
 location      | Bakersfield, CA      
 method        | PUT                  
 page          | NextSong             
 registration  | 1538173362000        
 sessionId     | 29                   
 song          | Rockpools            
 status        | 200                  
 ts            | 1538352117000        
 userAgent     | Mozilla/5.0 (Wind... 
 userId        | 30                   
-RECORD 1-----------------------------
 artist        | Five Iron Frenzy     
 auth          | Logged In            
 firstName     | Micah                
 gender        | M                    
 itemInSession | 79                   
 lastName      | Long    

The most important column in the dataset appears to be `page`. Each event records which Sparkify page the user visited, so its distinct values cover every page users interacted with.

In [59]:
data.select("page").dropDuplicates().show()

+--------------------+
|                page|
+--------------------+
|              Cancel|
|    Submit Downgrade|
|         Thumbs Down|
|                Home|
|           Downgrade|
|         Roll Advert|
|              Logout|
|       Save Settings|
|Cancellation Conf...|
|               About|
| Submit Registration|
|            Settings|
|               Login|
|            Register|
|     Add to Playlist|
|          Add Friend|
|            NextSong|
|           Thumbs Up|
|                Help|
|             Upgrade|
+--------------------+
only showing top 20 rows



From the list above, we can use the `Cancellation Confirmation` page to indicate whether a particular user churned. We are building a model that predicts whether a user is going to churn, given that user's history of interactions with the service.

Therefore, it is important to understand the customers who reached the `Cancellation Confirmation` page. These are the customers who churned, and the task is to build a prediction model that can recognize them.

#### Data cleaning

Now that we have identified our "goal", we can drop the columns that are not needed for further analysis.

In [60]:
data = data.drop('artist', 'song', 'firstName', 'lastName')
data.show(vertical=True, n=2)

-RECORD 0-----------------------------
 auth          | Logged In            
 gender        | M                    
 itemInSession | 50                   
 length        | 277.89016            
 level         | paid                 
 location      | Bakersfield, CA      
 method        | PUT                  
 page          | NextSong             
 registration  | 1538173362000        
 sessionId     | 29                   
 status        | 200                  
 ts            | 1538352117000        
 userAgent     | Mozilla/5.0 (Wind... 
 userId        | 30                   
-RECORD 1-----------------------------
 auth          | Logged In            
 gender        | M                    
 itemInSession | 79                   
 length        | 236.09424            
 level         | free                 
 location      | Boston-Cambridge-... 
 method        | PUT                  
 page          | NextSong             
 registration  | 1538331630000        
 sessionId     | 8       

Since we are building a model that focuses on users, remove any rows with a null or empty `userId`.

In [61]:
from pyspark.sql.functions import isnan, isnull

In [62]:
data.filter((isnan(data['userId'])) | (data['userId'].isNull()) | (data['userId'] == "")).count()

8346

In [63]:
data = data.dropna(how='any', subset=['userId'])  # drop rows with a null userId
data = data.filter(data['userId'] != "")           # drop rows with an empty userId
data.filter((isnan(data['userId'])) | (data['userId'].isNull()) | (data['userId'] == "")).count()

0

#### Label

As mentioned above, we derive the label `churn` from the `page` column. Churn has only two values, churned or not churned, so it is a binary column: 1 for churned and 0 for not churned.

In [64]:
import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType

In [78]:
cancelation_event_udf = F.udf(lambda x: 1 if x == "Cancellation Confirmation" else 0, IntegerType())
data = data.withColumn("churn", cancelation_event_udf("page"))
data.filter(data['userId'] == 125).show(vertical=True, n=2)

-RECORD 0-----------------------------
 auth          | Logged In            
 gender        | M                    
 itemInSession | 0                    
 length        | 337.91955            
 level         | free                 
 location      | Corpus Christi, TX   
 method        | PUT                  
 page          | NextSong             
 registration  | 1533157139000        
 sessionId     | 174                  
 status        | 200                  
 ts            | 1539317144000        
 userAgent     | "Mozilla/5.0 (Mac... 
 userId        | 125                  
 churn         | 0                    
-RECORD 1-----------------------------
 auth          | Logged In            
 gender        | M                    
 itemInSession | 1                    
 length        | 230.03383            
 level         | free                 
 location      | Corpus Christi, TX   
 method        | PUT                  
 page          | NextSong             
 registration  | 15331571

In [76]:
data.filter(data['churn'] == 1).show(vertical=True, n=5)

-RECORD 0-----------------------------
 auth          | Logged In            
 gender        | M                    
 itemInSession | 0                    
 length        | 337.91955            
 level         | free                 
 location      | Corpus Christi, TX   
 method        | PUT                  
 page          | NextSong             
 registration  | 1533157139000        
 sessionId     | 174                  
 status        | 200                  
 ts            | 1539317144000        
 userAgent     | "Mozilla/5.0 (Mac... 
 userId        | 125                  
 churn         | 1                    
-RECORD 1-----------------------------
 auth          | Logged In            
 gender        | M                    
 itemInSession | 1                    
 length        | 230.03383            
 level         | free                 
 location      | Corpus Christi, TX   
 method        | PUT                  
 page          | NextSong             
 registration  | 15331571

The problem here is that each user has multiple entries, since this dataset is a log of user activity. A user may have visited a number of other pages before reaching the `Cancellation Confirmation` page. For a user who has churned, we want every one of his/her log records to have 1 in the `churn` column. For this we use `Window` from `pyspark.sql.window`: summing `churn` over all of a user's rows gives 1 on every row of a user who cancelled and 0 for everyone else, assuming each user confirms cancellation at most once.


In [72]:
from pyspark.sql.window import Window

In [79]:
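# The frame spans the user's entire partition, so every row sees all of that user's events
window = Window.partitionBy("userId") \
        .rangeBetween(Window.unboundedPreceding,
                      Window.unboundedFollowing)
# Sum the per-event churn flag, spreading it to every row of that user
data = data.withColumn("churn", F.sum("churn").over(window))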


In [81]:
data.filter(data['userId'] == 125).show(vertical=True, n=5)

-RECORD 0-----------------------------
 auth          | Logged In            
 gender        | M                    
 itemInSession | 0                    
 length        | 337.91955            
 level         | free                 
 location      | Corpus Christi, TX   
 method        | PUT                  
 page          | NextSong             
 registration  | 1533157139000        
 sessionId     | 174                  
 status        | 200                  
 ts            | 1539317144000        
 userAgent     | "Mozilla/5.0 (Mac... 
 userId        | 125                  
 churn         | 1                    
-RECORD 1-----------------------------
 auth          | Logged In            
 gender        | M                    
 itemInSession | 1                    
 length        | 230.03383            
 level         | free                 
 location      | Corpus Christi, TX   
 method        | PUT                  
 page          | NextSong             
 registration  | 15331571

Let's see how many unique users we have

In [82]:
data.select("userId").dropDuplicates().count()

225

Count the rows labeled as churned. Again, each user has multiple entries, so this counts log events belonging to churned users, not the churned users themselves.

In [84]:
data.filter(data['churn'] == 1).count()

44864

Now let's create a label DataFrame that has one entry per user.

In [85]:
label = data \
    .select('userId', F.col('churn').alias('label')) \
    .dropDuplicates()
label.show()

+------+-----+
|userId|label|
+------+-----+
|100010|    0|
|200002|    0|
|   125|    1|
|   124|    0|
|    51|    1|
|     7|    0|
|    15|    0|
|    54|    1|
|   155|    0|
|100014|    1|
|   132|    0|
|   154|    0|
|   101|    1|
|    11|    0|
|   138|    0|
|300017|    0|
|100021|    1|
|    29|    1|
|    69|    0|
|   112|    0|
+------+-----+
only showing top 20 rows



#### Feature engineering

#### Modeling